import torch from torch.utils.data import Dataset,DataLoader from torch.autograd import Variable import json import random import sys import pickle from tqdm import tqdm import copy sys.path.append("..") from utils import text_standardize USE_CUDA = torch.cuda.is_available() FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor # ==== Code for data loading ===== class GptDataset(Dataset): """Take a list of samples with form [[x,...],y,meta] """ # need 3 special tokens # # as <ref start> 2 # $ as <speaker1> 3 # % as <speaker2> 4 # '<|endoftext|>' as <eos> 50256 def _split(self,x_y_meta): x_all = [] y_all = [] meta_all = [] for x,y,meta in x_y_meta: meta_all.append(meta) x_all.append([self.tokenizer.encode(text_standardize(x_i)) for x_i in x]) y_all.append(self.tokenizer.encode(text_standardize(y))) return x_all,y_all,meta_all def _filter(self,x_all,y_all,meta_all,filter_mode=None): allowed_pattern = ['SR_only','CR_only','Smoking_only','Diet_only'] data = zip(x_all,y_all,meta_all) if filter_mode not in allowed_pattern: data_filt = data if filter_mode=='SR_only': data_filt = [x for x in data if x[2][2]=='SR'] if filter_mode=='CR_only': data_filt = [x for x in data if x[2][2]=='CR'] if filter_mode=='Smoking_only': data_filt = [x for x in data if x[2][1]=='Smoking cessation'] if filter_mode=='Diet_only': data_filt = [x for x in data if x[2][1]=='Weight management'] x_filt,y_filt,meta_filt = zip(*data_filt) return x_filt, y_filt, meta_filt def __init__(self,x_y_meta,tokenizer,filter_mode=None,num_turns=5): self.x_y_meta = x_y_meta self.num_turns = num_turns self.tokenizer = tokenizer self.x_encoded,self.y_encoded,self.meta = self._split(x_y_meta) self.x_encoded,self.y_encoded,self.meta = self._filter(self.x_encoded,self.y_encoded,self.meta,filter_mode) self.ref_start, self.speaker1,self.speaker2,self.eos = 2,3,4,50256 def __getitem__(self,index): x = [] type_x = [] lm_x = [] is_speaker1 = bool(self.num_turns % 2) # which speaker start the conversation for utt in self.x_encoded[index][-self.num_turns:]: if is_speaker1: # add the prefix special token for each utterance x+=[self.speaker1] type_x += [self.speaker1]*(len(utt)+1) else: x+=[self.speaker2] type_x += [self.speaker2]*(len(utt)+1) x += utt is_speaker1 = not is_speaker1 lm_x += [-1]*len(x) # all position for the input is masked for loss calculation total_input_length = len(x) x += [self.ref_start] + self.y_encoded[index] + [self.eos] type_x += [self.ref_start]*(len(self.y_encoded[index])+2) lm_x += [-1] + self.y_encoded[index] + [self.eos] position_x = list(range(len(x))) x = torch.Tensor(x) type_x = torch.Tensor(type_x) position_x = torch.Tensor(position_x) lm_x = torch.Tensor(lm_x) x_len = x.shape[0] return x,type_x,position_x,lm_x,total_input_length,self.meta[index] def __len__(self): return len(self.x_encoded) # class GptDataset_keyword(Dataset): # def _split(self,x_y_meta): # x_all = [] # y_all = [] # meta_all = [] # aug_all = [] # for x,y,meta,aug in x_y_meta: # meta_all.append(meta) # x_all.append([self.tokenizer.encode(text_standardize(x_i)) for x_i in x]) # y_all.append(self.tokenizer.encode(text_standardize(y))) # key_word.append(self.tokenizer.encode(text_standardize(aug))) # return x_all,y_all,meta_all,aug_all # def __init__(self,x_y_meta,tokenizer,num_turns=5): # self.x_y_meta = x_y_meta # self.num_turns = num_turns # self.tokenizer = tokenizer # self.x_encoded,self.y_encoded,self.meta,self.aug_encoded = self._split(x_y_meta) # self.ref_start, self.speaker1,self.speaker2,self.eos = 2,3,4,50256 # self.augment = 5 # self.keyword = 10 # '+' class GptDataset_aug(Dataset): def _split(self,x_y_meta): x_all = [] y_all = [] meta_all = [] aug_all = [] for x,y,meta,aug in x_y_meta: meta_all.append(meta) x_all.append([self.tokenizer.encode(text_standardize(x_i)) for x_i in x]) y_all.append(self.tokenizer.encode(text_standardize(y))) aug_all.append(self.tokenizer.encode(text_standardize(aug))) return x_all,y_all,meta_all,aug_all def __init__(self,x_y_meta,tokenizer,num_turns=5): self.x_y_meta = x_y_meta self.num_turns = num_turns self.tokenizer = tokenizer self.x_encoded,self.y_encoded,self.meta,self.aug_encoded = self._split(x_y_meta) self.ref_start, self.speaker1,self.speaker2,self.eos = 2,3,4,50256 self.augment = 5 def __getitem__(self,index): x = [] type_x = [] lm_x = [] x += [self.augment] + self.aug_encoded[index] type_x += [self.augment] * len(x) is_speaker1 = bool(self.num_turns % 2) # which speaker start the conversation for utt in self.x_encoded[index][-self.num_turns:]: if is_speaker1: # add the prefix special token for each utterance x+=[self.speaker1] type_x += [self.speaker1]*(len(utt)+1) else: x+=[self.speaker2] type_x += [self.speaker2]*(len(utt)+1) x += utt is_speaker1 = not is_speaker1 lm_x += [-1]*len(x) # all position for the input is masked for loss calculation total_input_length = len(x) x += [self.ref_start] + self.y_encoded[index] + [self.eos] type_x += [self.ref_start]*(len(self.y_encoded[index])+2) lm_x += [-1] + self.y_encoded[index] + [self.eos] position_x = list(range(len(x))) x = torch.Tensor(x) type_x = torch.Tensor(type_x) position_x = torch.Tensor(position_x) lm_x = torch.Tensor(lm_x) x_len = x.shape[0] return x,type_x,position_x,lm_x,total_input_length,self.meta[index] def __len__(self): return len(self.x_encoded) def collate_fn(data): """Creates mini-batch tensors from the list of tuples (src_seq, trg_seq). We should build a custom collate_fn rather than using default collate_fn, because merging sequences (including padding) is not supported in default. Seqeuences are padded to the maximum length of mini-batch sequences (dynamic padding). Args: data: list of tuple (src_seq, trg_seq). - src_seq: torch tensor of shape (?); variable length. - trg_seq: torch tensor of shape (?); variable length. Returns: src_seqs: torch tensor of shape (batch_size, padded_length). src_lengths: list of length (batch_size); valid length for each padded source sequence. trg_seqs: torch tensor of shape (batch_size, padded_length). trg_lengths: list of length (batch_size); valid length for each padded target sequence. """ def merge(sequences): lengths = [len(seq) for seq in sequences] padded_seqs = torch.zeros(len(sequences), max(lengths)).long() for i, seq in enumerate(sequences): end = lengths[i] padded_seqs[i, :end] = seq[:end] return padded_seqs, lengths # sort a list by sequence length (descending order) to use pack_padded_sequence data.sort(key=lambda x: len(x[0]), reverse=True) # seperate source and target sequences src_seqs, trg_seqs, pos_seqs,lm_seqs,total_input_length,meta = zip(*data) # merge sequences (from tuple of 1D tensor to 2D tensor) src_seqs, src_lengths = merge(src_seqs) trg_seqs, trg_lengths = merge(trg_seqs) pos_seqs, pos_lengths = merge(pos_seqs) lm_seqs, lm_lengths = merge(lm_seqs) if USE_CUDA: src_seqs = src_seqs.cuda() trg_seqs = trg_seqs.cuda() pos_seqs = pos_seqs.cuda() lm_seqs = lm_seqs.cuda() return Variable(LongTensor(src_seqs)), Variable(LongTensor(trg_seqs)), Variable(LongTensor(pos_seqs)),Variable(LongTensor(lm_seqs)), total_input_length, meta def collate_fn_nli(data): """Creates mini-batch tensors from the list of tuples (src_seq, trg_seq). We should build a custom collate_fn rather than using default collate_fn, because merging sequences (including padding) is not supported in default. Seqeuences are padded to the maximum length of mini-batch sequences (dynamic padding). Args: data: list of tuple (src_seq, trg_seq). - src_seq: torch tensor of shape (?); variable length. - trg_seq: torch tensor of shape (?); variable length. Returns: src_seqs: torch tensor of shape (batch_size, padded_length). src_lengths: list of length (batch_size); valid length for each padded source sequence. trg_seqs: torch tensor of shape (batch_size, padded_length). trg_lengths: list of length (batch_size); valid length for each padded target sequence. """ def merge(sequences): lengths = [len(seq) for seq in sequences] padded_seqs = torch.zeros(len(sequences), max(lengths)).long() for i, seq in enumerate(sequences): end = lengths[i] padded_seqs[i, :end] = seq[:end] return padded_seqs, lengths # sort a list by sequence length (descending order) to use pack_padded_sequence data.sort(key=lambda x: len(x[0]), reverse=True) # seperate source and target sequences src_seqs, trg_seqs, pos_seqs,lm_seqs,label = zip(*data) # merge sequences (from tuple of 1D tensor to 2D tensor) src_seqs, src_lengths = merge(src_seqs) trg_seqs, trg_lengths = merge(trg_seqs) pos_seqs, pos_lengths = merge(pos_seqs) # lm_seqs, lm_lengths = merge(lm_seqs) label = torch.tensor(label) if USE_CUDA: src_seqs = src_seqs.cuda() trg_seqs = trg_seqs.cuda() pos_seqs = pos_seqs.cuda() # lm_seqs = lm_seqs.cuda() label = label.cuda() return Variable(LongTensor(src_seqs)), Variable(LongTensor(trg_seqs)), Variable(LongTensor(pos_seqs)),lm_seqs, label def collate_fn_keyword(data): def merge(sequences): lengths = [len(seq) for seq in sequences] padded_seqs = torch.zeros(len(sequences), max(lengths)).long() for i, seq in enumerate(sequences): end = lengths[i] padded_seqs[i, :end] = seq[:end] return padded_seqs, lengths # sort a list by sequence length (descending order) to use pack_padded_sequence data.sort(key=lambda x: len(x[0]), reverse=True) # seperate source and target sequences src_seqs, trg_seqs, pos_seqs, lm_seqs, total_input_length, meta, keyword_x = zip(*data) # merge sequences (from tuple of 1D tensor to 2D tensor) src_seqs, src_lengths = merge(src_seqs) trg_seqs, trg_lengths = merge(trg_seqs) pos_seqs, pos_lengths = merge(pos_seqs) lm_seqs, lm_lengths = merge(lm_seqs) keyword_x, _ = merge(keyword_x) if USE_CUDA: src_seqs = src_seqs.cuda() trg_seqs = trg_seqs.cuda() pos_seqs = pos_seqs.cuda() lm_seqs = lm_seqs.cuda() keyword_x = keyword_x.cuda() return Variable(LongTensor(src_seqs)), Variable(LongTensor(trg_seqs)), Variable(LongTensor(pos_seqs)),Variable(LongTensor(lm_seqs)), total_input_length, meta,Variable(LongTensor(keyword_x)) class GptDataset_keyword(Dataset): def _split(self, x_y_meta): x_all = [] y_all = [] meta_all = [] keyword_all = [] for x, y, meta, keyword in x_y_meta: meta_all.append(meta) x_all.append([self.tokenizer.encode(text_standardize(x_i)) for x_i in x]) y_all.append(self.tokenizer.encode(text_standardize(y))) keyword_all.append(self.tokenizer.encode(text_standardize(keyword))) return x_all, y_all, meta_all, keyword_all def __init__(self, x_y_meta, tokenizer, num_turns=5): self.x_y_meta = x_y_meta self.num_turns = num_turns self.tokenizer = tokenizer self.x_encoded, self.y_encoded, self.meta, self.keyword = self._split(x_y_meta) self.ref_start, self.speaker1, self.speaker2, self.eos = 2, 3, 4, 50256 def __getitem__(self, index): x = [] type_x = [] lm_x = [] is_speaker1 = bool(len(self.x_encoded[index]) % 2) # which speaker start the conversation for utt in self.x_encoded[index]: if is_speaker1: # add the prefix special token for each utterance x += [self.speaker1] type_x += [self.speaker1] * (len(utt) + 1) else: x += [self.speaker2] type_x += [self.speaker2] * (len(utt) + 1) x += utt is_speaker1 = not is_speaker1 lm_x += [-1] * len(x) # all position for the input is masked for loss calculation total_input_length = len(x) x += [self.ref_start] + self.y_encoded[index] + [self.eos] type_x += [self.ref_start] * (len(self.y_encoded[index]) + 2) lm_x += [-1] + self.y_encoded[index] + [self.eos] position_x = list(range(len(x))) x = torch.Tensor(x) type_x = torch.Tensor(type_x) position_x = torch.Tensor(position_x) lm_x = torch.Tensor(lm_x) x_len = x.shape[0] keyword_x = [] + self.keyword[index] keyword_x = torch.Tensor(keyword_x) return x, type_x, position_x, lm_x, total_input_length, self.meta[index], keyword_x def __len__(self): return len(self.x_encoded) # class GptDataset_nli(GptDataset): # def __init__(self, x_y_meta, tokenizer, filter_mode=None,num_turns=5,augment=True): # super(GptDataset_nli, self).__init__(x_y_meta,tokenizer, num_turns=num_turns) # self.augment = augment # self.pos_len = len(self.x_encoded) # def __len__(self): # if self.augment: # return 2 * len(self.x_encoded) # else: # return len(self.x_encoded) # def __getitem__(self,index): # # client utterances - premise -speaker1 # # response - hypothesis - ref_start # true_index = index # if index >= self.pos_len: # index = index - self.pos_len # x = [] # type_x = [] # lm_x = [] # is_speaker1 = bool(len(self.x_encoded[index])%2) # which speaker start the conversation # x+=[self.speaker1] # type_x += [self.speaker1] # for utt in self.x_encoded[index][-self.num_turns:]: # if is_speaker1: # add the prefix special token for each utterance # type_x += [self.speaker1]*(len(utt)) # x += utt # # else: # # x+=[self.speaker2] # # type_x += [self.speaker2]*(len(utt)+1) # # x += utt # is_speaker1 = not is_speaker1 # total_input_length = len(x) # if true_index >= self.pos_len: # rand_index = random.randint(0,self.pos_len-1) # x += [self.ref_start] + self.y_encoded[rand_index] + [self.eos] # type_x += [self.ref_start]*(len(self.y_encoded[rand_index])+2) # else: # x += [self.ref_start] + self.y_encoded[index] + [self.eos] # type_x += [self.ref_start]*(len(self.y_encoded[index])+2) # position_x = list(range(len(x))) # x = torch.Tensor(x) # type_x = torch.Tensor(type_x) # position_x = torch.Tensor(position_x) # x_len = x.shape[0] # label = torch.tensor(0) if true_index>self.pos_len else torch.tensor(1) # return x,type_x,position_x,lm_x, label class SnliDataset(Dataset): """Take a list of samples with form [[x,...],y,meta] """ # need 3 special tokens # # as <ref start> 2 # $ as <speaker1> 3 # % as <speaker2> 4 # '<|endoftext|>' as <eos> 50256 def _split(self,data): positive_label = set(['entailment']) premise = [] hypothesis = [] label = [] for p,h,l in tqdm(data): premise.append(self.tokenizer.encode(text_standardize(p))) hypothesis.append(self.tokenizer.encode(text_standardize(h))) if l in positive_label: label.append(torch.tensor(1)) else: label.append(torch.tensor(0)) return premise,hypothesis,label def _filter(self,premise,hypothesis,label,filter_mode=None): data = zip(premise,hypothesis,label) if filter_mode == None: data_filt = data else: data_filt = [x for x in data if x[2]!='-'] premise_filt,hypothesis_filt,label_filt = zip(*data_filt) return premise_filt,hypothesis_filt,label_filt def parse_snli(self,path=None): with open(path) as f: data = [json.loads(line) for line in f] data_processed = [(line['sentence1'],line['sentence2'],line['gold_label']) for line in data] return data_processed def __init__(self,tokenizer,path='../data/snli_1.0/snli_1.0_train.jsonl',filter_mode=None,num_turns=5): self.data = self.parse_snli(path) self.tokenizer = tokenizer self.premise_encoded,self.hypothesis_encoded,self.label = self._split(self.data) self.premise_encoded,self.hypothesis_encoded,self.label = self._filter(self.premise_encoded,self.hypothesis_encoded,self.label,filter_mode) self.ref_start, self.speaker1,self.speaker2,self.eos = 2,3,4,50256 def __getitem__(self,index): x = [] type_x = [] lm_x = [] x += [self.speaker1] x += self.premise_encoded[index] type_x += [self.speaker1]*(len(self.premise_encoded[index])+1) # the premise part x += [self.ref_start] x += self.hypothesis_encoded[index] x += [self.eos] type_x += [self.ref_start]*(len(self.hypothesis_encoded[index])+2) # the hypothesis part label = self.label[index] position_x = list(range(len(x))) x = torch.Tensor(x) type_x = torch.Tensor(type_x) position_x = torch.Tensor(position_x) return x,type_x,position_x,lm_x,label def __len__(self): return len(self.premise_encoded) class GptDataset_full(Dataset): def _split(self,x_y_meta): x_all = [] y_all = [] meta_all = [] aug_all = [] keyword_all = [] for x, y, meta, aug, keyword in x_y_meta: meta_all.append(meta) x_all.append([self.tokenizer.encode(text_standardize(x_i)) for x_i in x]) y_all.append(self.tokenizer.encode(text_standardize(y))) aug_all.append(self.tokenizer.encode(text_standardize(aug))) keyword_all.append(self.tokenizer.encode(text_standardize(keyword))) return x_all,y_all,meta_all,aug_all, keyword_all def _filt(self, length=1024): data = zip(self.x_encoded,self.y_encoded,self.meta,self.aug_encoded, self.keyword_encoded) data = [sample for sample in data if sum([len(sen) for sen in sample[0]][-self.args.num_turns:])+len(sample[1])+len(sample[3])+len(sample[4]) < 850] self.x_encoded,self.y_encoded,self.meta,self.aug_encoded, self.keyword_encoded = zip(*data) self.x_encoded = list(self.x_encoded) self.y_encoded = list(self.y_encoded) self.meta = list(self.meta) self.aug_encoded = list(self.aug_encoded) self.keyword_encoded = list(self.keyword_encoded) def __init__(self,x_y_meta,tokenizer,args): self.x_y_meta = x_y_meta self.num_turns = args.num_turns self.tokenizer = tokenizer self.args = args self.x_encoded,self.y_encoded,self.meta,self.aug_encoded, self.keyword_encoded = self._split(x_y_meta) self._filt() # TODO: add back filt for mix-review self.ref_start, self.speaker1,self.speaker2,self.eos = 2,3,4,50256 self.augment = 5 if self.args.augment: print("Using augment sentences.") if self.args.keyword: print("Using keywords.") def __getitem__(self,index): x = [] type_x = [] lm_x = [] if self.args.augment: x += [self.augment] + self.aug_encoded[index] if self.args.keyword: x += [self.augment] + self.keyword_encoded[index] type_x += [self.augment] * len(x) is_speaker1 = bool(self.num_turns % 2) # which speaker start the conversation for utt in self.x_encoded[index][-self.num_turns:]: if is_speaker1: # add the prefix special token for each utterance x+=[self.speaker1] type_x += [self.speaker1]*(len(utt)+1) else: x+=[self.speaker2] type_x += [self.speaker2]*(len(utt)+1) x += utt is_speaker1 = not is_speaker1 lm_x += [-100]*len(x) # all position for the input is masked for loss calculation total_input_length = len(x) x += [self.ref_start] + self.y_encoded[index] + [self.eos] type_x += [self.ref_start]*(len(self.y_encoded[index])+2) lm_x += [-100] + self.y_encoded[index] + [self.eos] position_x = list(range(len(x))) x = torch.Tensor(x) type_x = torch.Tensor(type_x) position_x = torch.Tensor(position_x) lm_x = torch.Tensor(lm_x) x_len = x.shape[0] return x,type_x,position_x,lm_x,total_input_length,self.meta[index] def __len__(self): return len(self.x_encoded) class GptDataset_nli(GptDataset_full): def __init__(self, x_y_meta, tokenizer, args, infer=False): super(GptDataset_nli, self).__init__(x_y_meta,tokenizer, args) self.pos_len = len(self.x_encoded) self.num_turns = 5 self.infer = infer def __len__(self): if self.infer: return len(self.x_encoded) else: return 2 * len(self.x_encoded) def __getitem__(self,index): # client utterances - premise -speaker1 # response - hypothesis - ref_start true_index = index if index >= self.pos_len: index = index - self.pos_len x = [] type_x = [] lm_x = [] is_speaker1 = bool(len(self.x_encoded[index])%2) # which speaker start the conversation x+=[self.speaker1] type_x += [self.speaker1] for utt in self.x_encoded[index][-self.num_turns:]: if is_speaker1: # add the prefix special token for each utterance type_x += [self.speaker1]*(len(utt)) x += utt # else: # x+=[self.speaker2] # type_x += [self.speaker2]*(len(utt)+1) # x += utt is_speaker1 = not is_speaker1 total_input_length = len(x) if true_index >= self.pos_len: rand_index = random.randint(0,self.pos_len-1) x += [self.ref_start] + self.y_encoded[rand_index] + [self.eos] type_x += [self.ref_start]*(len(self.y_encoded[rand_index])+2) else: x += [self.ref_start] + self.y_encoded[index] + [self.eos] type_x += [self.ref_start]*(len(self.y_encoded[index])+2) position_x = list(range(len(x))) x = torch.Tensor(x) type_x = torch.Tensor(type_x) position_x = torch.Tensor(position_x) x_len = x.shape[0] label = torch.tensor(0) if true_index>self.pos_len else torch.tensor(1) return x,type_x,position_x,lm_x, label class XLDataset_nli(GptDataset_nli): def __init__(self, x_y_meta, tokenizer, args, infer=False): super(GptDataset_nli, self).__init__(x_y_meta,tokenizer, args) self.pos_len = len(self.x_encoded) self.num_turns = 5 self.infer = infer # self.ref_start, self.speaker1,self.speaker2,self.eos = 2,3,4,50256 self.pad, self.sep, self.cls = 5, 4, 3 self.unk, self.s, self.s_bar = 0, 1, 2 def __len__(self): if self.infer: return len(self.x_encoded) else: return 2 * len(self.x_encoded) def __getitem__(self,index): # client utterances - premise -speaker1 # response - hypothesis - ref_start true_index = index if index >= self.pos_len: index = index - self.pos_len x = [] type_x = [] lm_x = [] mask_x = [] is_speaker1 = bool(len(self.x_encoded[index])%2) # which speaker start the conversation for utt in self.x_encoded[index][-self.num_turns:]: if is_speaker1: # add the prefix special token for each utterance type_x += [self.unk]*(len(utt)) x += utt else: type_x += [self.unk]*(len(utt)) x += utt is_speaker1 = not is_speaker1 # import pdb;pdb.set_trace() x += [self.sep] type_x += [self.unk] total_input_length = len(x) if true_index >= self.pos_len: rand_index = random.randint(0,self.pos_len-1) x += self.y_encoded[rand_index] + [self.sep, self.cls] type_x += [self.s]*(len(self.y_encoded[rand_index])+1) + [self.s_bar] else: # x += [self.ref_start] + self.y_encoded[index] + [self.eos] # type_x += [self.ref_start]*(len(self.y_encoded[index])+2) x += self.y_encoded[index] + [self.sep, self.cls] type_x += [self.s]*(len(self.y_encoded[index])+1) + [self.s_bar] position_x = list(range(len(x))) mask_x = [self.s] * len(x) # #### # x = x[-100:] # mask_x = mask_x[-100:] # type_x = type_x[-100:] # left padding x = [self.pad] * (self.args.max_length-len(x)) + x[-self.args.max_length:] mask_x = [self.unk] * (self.args.max_length-len(mask_x)) + mask_x[-self.args.max_length:] type_x = [self.sep] * (self.args.max_length-len(type_x)) + type_x[-self.args.max_length:] x = torch.Tensor(x).long() mask_x = torch.Tensor(mask_x).long() type_x = torch.Tensor(type_x).long() position_x = torch.Tensor(position_x) x_len = x.shape[0] label = torch.tensor(0) if true_index>self.pos_len else torch.tensor(1) # label = torch.tensor(0) if true_index>self.pos_len else torch.tensor(0) if USE_CUDA: x = x.cuda() mask_x = mask_x.cuda() type_x = type_x.cuda() label = label.cuda() # return x, mask_x, type_x, label, position_x, lm_x return x, mask_x, type_x, label def get_data(args, tokenizer, split_size): """ Return the data loaders needed for training and evaluation. :param args: command line arguments. :param tokenizer: the tokenizer used in preparing the data. :param split_size: the portion of train, test, validation set. :return data_loader: The data loader for the training set. :return val_loader: The data loader for the validation set. """ # random.seed(args.seed) # torch.random.manual_seed(args.seed) # torch.cuda.manual_seed(args.seed) # torch.manual_seed(args.seed) if args.special_input: print("Using mutated data.") pickle_handler = open('../data_processed/' + args.special_input, 'rb') x_y_meta = pickle.load(pickle_handler) gpt_data = GptDataset(x_y_meta, tokenizer, args.output_dir, num_turns=args.num_turns) else: print("Using full data.") pickle_handler = open('../data_processed/x_y_meta_all', 'rb') # TODO: change back to the old data. x_y_meta = pickle.load(pickle_handler) gpt_data = GptDataset_full(x_y_meta, tokenizer, args=args) print("Dataset initialized. There are {} samples.".format(len(gpt_data))) test_size = int(len(gpt_data) * split_size['test']) val_size = int(len(gpt_data) * split_size['val']) gpt_train, gpt_test, gpt_val = torch.utils.data.random_split(gpt_data, [len(gpt_data) - test_size - val_size, test_size, val_size]) # import pdb;pdb.set_trace() # with open('../../mi_data/train','wb') as f: # pickle.dump(gpt_train, f) # with open('../../mi_data/test','wb') as f: # pickle.dump(gpt_test, f) # with open('../../mi_data/val','wb') as f: # pickle.dump(gpt_val, f) if 'train_batch_size' not in args: args.train_batch_size = 1 data_loader = DataLoader(dataset=gpt_train, batch_size=args.train_batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn) test_loader = DataLoader(dataset=gpt_test, batch_size=1, shuffle=False, drop_last=False, collate_fn=collate_fn) val_loader = DataLoader(dataset=gpt_val, batch_size=1, shuffle=False, drop_last=False, collate_fn=collate_fn) return data_loader, test_loader, val_loader def prepare_mix_review(args, tokenizer): print("Preparing Alexander dataset") pickle_handler = open('../data_processed/data_alex', 'rb') data = pickle.load(pickle_handler) gpt_alex = GptDataset_full(data, tokenizer, args=args) print("Alexander dataset prepared. Has {} samples".format(len(gpt_alex))) return gpt_alex def update_mix_review(gpt_train, gpt_alex, epoch, args, mix_ratio=4, mix_decay=0.7, collate_fn=collate_fn): mix_amount = int(mix_ratio*(0.7**epoch)*len(gpt_train)) gpt_alex_active,_ = torch.utils.data.random_split(gpt_alex, [mix_amount, len(gpt_alex)-mix_amount]) data_loader = DataLoader(dataset=gpt_train+gpt_alex_active, batch_size=args.train_batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn) return data_loader def get_data_old(args, tokenizer, split_size): """ Return the data loaders needed for training and evaluation. :param args: command line arguments. :param tokenizer: the tokenizer used in preparing the data. :param split_size: the portion of train, test, validation set. :return data_loader: The data loader for the training set. :return val_loader: The data loader for the validation set. """ if args.special_input: print("Using mutated data.") pickle_handler = open('../data_processed/' + args.special_input, 'rb') x_y_meta = pickle.load(pickle_handler) if args.augment: print("testing keywords with augment loader.") gpt_data = GptDataset_aug(x_y_meta, tokenizer, num_turns=args.num_turns) else: gpt_data = GptDataset(x_y_meta, tokenizer, args.output_dir, num_turns=args.num_turns) elif args.augment: print("Using augmented data") pickle_handler = open('../data_processed/x_y_meta_aug', 'rb') x_y_meta = pickle.load(pickle_handler) gpt_data = GptDataset_aug(x_y_meta, tokenizer, num_turns=args.num_turns) elif args.keyword: print("Using keyword cross attention") pickle_handler = open('../data_processed/x_y_meta_keyword', 'rb') x_y_meta = pickle.load(pickle_handler) gpt_data = GptDataset_keyword(x_y_meta, tokenizer) else: print("Using vanilla data.") pickle_handler = open('../data_processed/x_y_meta', 'rb') x_y_meta = pickle.load(pickle_handler) gpt_data = GptDataset(x_y_meta, tokenizer, args.output_dir, num_turns=args.num_turns) print("Dataset initialized. There are {} samples.".format(len(gpt_data))) test_size = int(len(gpt_data) * split_size['test']) val_size = int(len(gpt_data) * split_size['val']) gpt_train, gpt_test, gpt_val = torch.utils.data.random_split(gpt_data, [len(gpt_data) - test_size - val_size, test_size, val_size]) if args.keyword: data_loader = DataLoader(dataset=gpt_train, batch_size=args.train_batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn_keyword) val_loader = DataLoader(dataset=gpt_val, batch_size=1, shuffle=False, drop_last=False, collate_fn=collate_fn_keyword) else: data_loader = DataLoader(dataset=gpt_train, batch_size=args.train_batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn) val_loader = DataLoader(dataset=gpt_val, batch_size=1, shuffle=False, drop_last=False, collate_fn=collate_fn) return data_loader, val_loader