Skip to content
Snippets Groups Projects
Commit 9dda1ef1 authored by shensq's avatar shensq
Browse files

attention_mask move to gpu

parent 95f573b4
Branches
No related tags found
No related merge requests found
......@@ -207,6 +207,7 @@ def collate_fn(data):
trg_seqs = trg_seqs.cuda()
pos_seqs = pos_seqs.cuda()
lm_seqs = lm_seqs.cuda()
attention_mask = attention_mask.cuda()
return Variable(LongTensor(src_seqs)), Variable(LongTensor(trg_seqs)), Variable(LongTensor(pos_seqs)),Variable(LongTensor(lm_seqs)), total_input_length, attention_mask
def collate_fn_nli(data):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment