Skip to content
Snippets Groups Projects
Commit b98e456d authored by shensq's avatar shensq
Browse files

collate_fn for conditional experiment

parent 89d4adaa
No related branches found
No related tags found
No related merge requests found
......@@ -211,6 +211,34 @@ def collate_fn(data):
attention_mask = attention_mask.cuda()
return Variable(LongTensor(src_seqs)), Variable(LongTensor(trg_seqs)), Variable(LongTensor(pos_seqs)),Variable(LongTensor(lm_seqs)), total_input_length, attention_mask
def collate_fn_conditional(data):
def merge(sequences):
lengths = [len(seq) for seq in sequences]
padded_seqs = torch.zeros(len(sequences), max(lengths)).long()
for i, seq in enumerate(sequences):
end = lengths[i]
padded_seqs[i, :end] = seq[:end]
return padded_seqs, lengths
# sort a list by sequence length (descending order) to use pack_padded_sequence
data.sort(key=lambda x: len(x[0]), reverse=True)
# seperate source and target sequences
src_seqs, trg_seqs, pos_seqs,lm_seqs,total_input_length, meta = zip(*data)
# merge sequences (from tuple of 1D tensor to 2D tensor)
src_seqs, src_lengths = merge(src_seqs)
trg_seqs, trg_lengths = merge(trg_seqs)
pos_seqs, pos_lengths = merge(pos_seqs)
lm_seqs, lm_lengths = merge(lm_seqs)
if USE_CUDA:
src_seqs = src_seqs.cuda()
trg_seqs = trg_seqs.cuda()
pos_seqs = pos_seqs.cuda()
lm_seqs = lm_seqs.cuda()
return Variable(LongTensor(src_seqs)), Variable(LongTensor(trg_seqs)), Variable(LongTensor(pos_seqs)),Variable(LongTensor(lm_seqs)), total_input_length, meta
def collate_fn_nli(data):
"""Creates mini-batch tensors from the list of tuples (src_seq, trg_seq).
We should build a custom collate_fn rather than using default collate_fn,
......@@ -857,10 +885,16 @@ def get_data(args, tokenizer, split_size):
if 'train_batch_size' not in args:
args.train_batch_size = 1
data_loader = DataLoader(dataset=gpt_train, batch_size=args.train_batch_size, shuffle=True, drop_last=True,
collate_fn=collate_fn)
test_loader = DataLoader(dataset=gpt_test, batch_size=1, shuffle=False, drop_last=False, collate_fn=collate_fn)
val_loader = DataLoader(dataset=gpt_val, batch_size=1, shuffle=False, drop_last=False, collate_fn=collate_fn)
if args.conditional:
data_loader = DataLoader(dataset=gpt_train, batch_size=args.train_batch_size, shuffle=True, drop_last=True,
collate_fn=collate_fn_conditional)
test_loader = DataLoader(dataset=gpt_test, batch_size=1, shuffle=False, drop_last=False, collate_fn=collate_fn_conditional)
val_loader = DataLoader(dataset=gpt_val, batch_size=1, shuffle=False, drop_last=False, collate_fn=collate_fn_conditional)
else:
data_loader = DataLoader(dataset=gpt_train, batch_size=args.train_batch_size, shuffle=True, drop_last=True,
collate_fn=collate_fn)
test_loader = DataLoader(dataset=gpt_test, batch_size=1, shuffle=False, drop_last=False, collate_fn=collate_fn)
val_loader = DataLoader(dataset=gpt_val, batch_size=1, shuffle=False, drop_last=False, collate_fn=collate_fn)
return data_loader, test_loader, val_loader
def prepare_mix_review(args, tokenizer):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment