import torch from torch.utils.data import Dataset,DataLoader from torch.autograd import Variable import json import random import sys import pickle from tqdm import tqdm from collections import deque import copy sys.path.append("..") from utils import text_standardize USE_CUDA = torch.cuda.is_available() FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor # ==== Code for data loading ===== class GptDataset(Dataset): """Take a list of samples with form [[x,...],y,meta] """ # need 3 special tokens # # as <ref start> 2 # $ as <speaker1> 3 # % as <speaker2> 4 # '<|endoftext|>' as <eos> 50256 def _split(self,x_y_meta): x_all = [] y_all = [] meta_all = [] for x,y,meta in x_y_meta: meta_all.append(meta) x_all.append([self.tokenizer.encode(text_standardize(x_i)) for x_i in x]) y_all.append(self.tokenizer.encode(text_standardize(y))) return x_all,y_all,meta_all def _filter(self,x_all,y_all,meta_all,filter_mode=None): allowed_pattern = ['SR_only','CR_only','Smoking_only','Diet_only'] data = zip(x_all,y_all,meta_all) if filter_mode not in allowed_pattern: data_filt = data if filter_mode=='SR_only': data_filt = [x for x in data if x[2][2]=='SR'] if filter_mode=='CR_only': data_filt = [x for x in data if x[2][2]=='CR'] if filter_mode=='Smoking_only': data_filt = [x for x in data if x[2][1]=='Smoking cessation'] if filter_mode=='Diet_only': data_filt = [x for x in data if x[2][1]=='Weight management'] x_filt,y_filt,meta_filt = zip(*data_filt) return x_filt, y_filt, meta_filt def __init__(self,x_y_meta,tokenizer,filter_mode=None,num_turns=5): self.x_y_meta = x_y_meta self.num_turns = num_turns self.tokenizer = tokenizer self.x_encoded,self.y_encoded,self.meta = self._split(x_y_meta) self.x_encoded,self.y_encoded,self.meta = self._filter(self.x_encoded,self.y_encoded,self.meta,filter_mode) self.ref_start, self.speaker1,self.speaker2,self.eos = 2,3,4,50256 def __getitem__(self,index): x = [] type_x = [] lm_x = [] is_speaker1 = bool(self.num_turns % 2) # which speaker start the conversation for utt in self.x_encoded[index][-self.num_turns:]: if is_speaker1: # add the prefix special token for each utterance x+=[self.speaker1] type_x += [self.speaker1]*(len(utt)+1) else: x+=[self.speaker2] type_x += [self.speaker2]*(len(utt)+1) x += utt is_speaker1 = not is_speaker1 lm_x += [-1]*len(x) # all position for the input is masked for loss calculation total_input_length = len(x) x += [self.ref_start] + self.y_encoded[index] + [self.eos] type_x += [self.ref_start]*(len(self.y_encoded[index])+2) lm_x += [-1] + self.y_encoded[index] + [self.eos] position_x = list(range(len(x))) x = torch.Tensor(x) type_x = torch.Tensor(type_x) position_x = torch.Tensor(position_x) lm_x = torch.Tensor(lm_x) x_len = x.shape[0] return x,type_x,position_x,lm_x,total_input_length,self.meta[index] def __len__(self): return len(self.x_encoded) class GptDataset_aug(Dataset): def _split(self,x_y_meta): x_all = [] y_all = [] meta_all = [] aug_all = [] for x,y,meta,aug in x_y_meta: meta_all.append(meta) x_all.append([self.tokenizer.encode(text_standardize(x_i)) for x_i in x]) y_all.append(self.tokenizer.encode(text_standardize(y))) aug_all.append(self.tokenizer.encode(text_standardize(aug))) return x_all,y_all,meta_all,aug_all def __init__(self,x_y_meta,tokenizer,num_turns=5): self.x_y_meta = x_y_meta self.num_turns = num_turns self.tokenizer = tokenizer self.x_encoded,self.y_encoded,self.meta,self.aug_encoded = self._split(x_y_meta) self.ref_start, self.speaker1,self.speaker2,self.eos = 2,3,4,50256 self.augment = 5 def __getitem__(self,index): x = [] type_x = [] lm_x = [] x += [self.augment] + self.aug_encoded[index] type_x += [self.augment] * len(x) is_speaker1 = bool(self.num_turns % 2) # which speaker start the conversation for utt in self.x_encoded[index][-self.num_turns:]: if is_speaker1: # add the prefix special token for each utterance x+=[self.speaker1] type_x += [self.speaker1]*(len(utt)+1) else: x+=[self.speaker2] type_x += [self.speaker2]*(len(utt)+1) x += utt is_speaker1 = not is_speaker1 lm_x += [-1]*len(x) # all position for the input is masked for loss calculation total_input_length = len(x) x += [self.ref_start] + self.y_encoded[index] + [self.eos] type_x += [self.ref_start]*(len(self.y_encoded[index])+2) lm_x += [-1] + self.y_encoded[index] + [self.eos] position_x = list(range(len(x))) x = torch.Tensor(x) type_x = torch.Tensor(type_x) position_x = torch.Tensor(position_x) lm_x = torch.Tensor(lm_x) x_len = x.shape[0] return x,type_x,position_x,lm_x,total_input_length,self.meta[index] def __len__(self): return len(self.x_encoded) def collate_fn(data): """Creates mini-batch tensors from the list of tuples (src_seq, trg_seq). We should build a custom collate_fn rather than using default collate_fn, because merging sequences (including padding) is not supported in default. Seqeuences are padded to the maximum length of mini-batch sequences (dynamic padding). Args: data: list of tuple (src_seq, trg_seq). - src_seq: torch tensor of shape (?); variable length. - trg_seq: torch tensor of shape (?); variable length. Returns: src_seqs: torch tensor of shape (batch_size, padded_length). src_lengths: list of length (batch_size); valid length for each padded source sequence. trg_seqs: torch tensor of shape (batch_size, padded_length). trg_lengths: list of length (batch_size); valid length for each padded target sequence. """ def merge(sequences): lengths = [len(seq) for seq in sequences] padded_seqs = torch.zeros(len(sequences), max(lengths)).long() for i, seq in enumerate(sequences): end = lengths[i] padded_seqs[i, :end] = seq[:end] return padded_seqs, lengths def merge_matrix(matrices): max_size = max([m.shape[-1] for m in attention_mask]) padded_matrices = torch.zeros(len(matrices), 1, max_size, max_size) for i,m in enumerate(matrices): m_size = m.shape[-1] padded_matrices[i,:,:m_size,:m_size] = m return padded_matrices # sort a list by sequence length (descending order) to use pack_padded_sequence data.sort(key=lambda x: len(x[0]), reverse=True) # seperate source and target sequences src_seqs, trg_seqs, pos_seqs,lm_seqs,total_input_length,attention_mask = zip(*data) # merge sequences (from tuple of 1D tensor to 2D tensor) src_seqs, src_lengths = merge(src_seqs) trg_seqs, trg_lengths = merge(trg_seqs) pos_seqs, pos_lengths = merge(pos_seqs) lm_seqs, lm_lengths = merge(lm_seqs) if type(attention_mask[0]) is not list: attention_mask = merge_matrix(attention_mask) if USE_CUDA: src_seqs = src_seqs.cuda() trg_seqs = trg_seqs.cuda() pos_seqs = pos_seqs.cuda() lm_seqs = lm_seqs.cuda() attention_mask = attention_mask.cuda() return Variable(LongTensor(src_seqs)), Variable(LongTensor(trg_seqs)), Variable(LongTensor(pos_seqs)),Variable(LongTensor(lm_seqs)), total_input_length, attention_mask def collate_fn_conditional(data): def merge(sequences): lengths = [len(seq) for seq in sequences] padded_seqs = torch.zeros(len(sequences), max(lengths)).long() for i, seq in enumerate(sequences): end = lengths[i] padded_seqs[i, :end] = seq[:end] return padded_seqs, lengths # sort a list by sequence length (descending order) to use pack_padded_sequence data.sort(key=lambda x: len(x[0]), reverse=True) # seperate source and target sequences src_seqs, trg_seqs, pos_seqs,lm_seqs,total_input_length, meta = zip(*data) # merge sequences (from tuple of 1D tensor to 2D tensor) src_seqs, src_lengths = merge(src_seqs) trg_seqs, trg_lengths = merge(trg_seqs) pos_seqs, pos_lengths = merge(pos_seqs) lm_seqs, lm_lengths = merge(lm_seqs) if USE_CUDA: src_seqs = src_seqs.cuda() trg_seqs = trg_seqs.cuda() pos_seqs = pos_seqs.cuda() lm_seqs = lm_seqs.cuda() return Variable(LongTensor(src_seqs)), Variable(LongTensor(trg_seqs)), Variable(LongTensor(pos_seqs)),Variable(LongTensor(lm_seqs)), total_input_length, meta def collate_fn_nli(data): """Creates mini-batch tensors from the list of tuples (src_seq, trg_seq). We should build a custom collate_fn rather than using default collate_fn, because merging sequences (including padding) is not supported in default. Seqeuences are padded to the maximum length of mini-batch sequences (dynamic padding). Args: data: list of tuple (src_seq, trg_seq). - src_seq: torch tensor of shape (?); variable length. - trg_seq: torch tensor of shape (?); variable length. Returns: src_seqs: torch tensor of shape (batch_size, padded_length). src_lengths: list of length (batch_size); valid length for each padded source sequence. trg_seqs: torch tensor of shape (batch_size, padded_length). trg_lengths: list of length (batch_size); valid length for each padded target sequence. """ def merge(sequences): lengths = [len(seq) for seq in sequences] padded_seqs = torch.zeros(len(sequences), max(lengths)).long() for i, seq in enumerate(sequences): end = lengths[i] padded_seqs[i, :end] = seq[:end] return padded_seqs, lengths # sort a list by sequence length (descending order) to use pack_padded_sequence data.sort(key=lambda x: len(x[0]), reverse=True) # seperate source and target sequences src_seqs, trg_seqs, pos_seqs,lm_seqs,label = zip(*data) # merge sequences (from tuple of 1D tensor to 2D tensor) src_seqs, src_lengths = merge(src_seqs) trg_seqs, trg_lengths = merge(trg_seqs) pos_seqs, pos_lengths = merge(pos_seqs) # lm_seqs, lm_lengths = merge(lm_seqs) label = torch.tensor(label) if USE_CUDA: src_seqs = src_seqs.cuda() trg_seqs = trg_seqs.cuda() pos_seqs = pos_seqs.cuda() # lm_seqs = lm_seqs.cuda() label = label.cuda() return Variable(LongTensor(src_seqs)), Variable(LongTensor(trg_seqs)), Variable(LongTensor(pos_seqs)),lm_seqs, label def collate_fn_keyword(data): def merge(sequences): lengths = [len(seq) for seq in sequences] padded_seqs = torch.zeros(len(sequences), max(lengths)).long() for i, seq in enumerate(sequences): end = lengths[i] padded_seqs[i, :end] = seq[:end] return padded_seqs, lengths # sort a list by sequence length (descending order) to use pack_padded_sequence data.sort(key=lambda x: len(x[0]), reverse=True) # seperate source and target sequences src_seqs, trg_seqs, pos_seqs, lm_seqs, total_input_length, meta, keyword_x = zip(*data) # merge sequences (from tuple of 1D tensor to 2D tensor) src_seqs, src_lengths = merge(src_seqs) trg_seqs, trg_lengths = merge(trg_seqs) pos_seqs, pos_lengths = merge(pos_seqs) lm_seqs, lm_lengths = merge(lm_seqs) keyword_x, _ = merge(keyword_x) if USE_CUDA: src_seqs = src_seqs.cuda() trg_seqs = trg_seqs.cuda() pos_seqs = pos_seqs.cuda() lm_seqs = lm_seqs.cuda() keyword_x = keyword_x.cuda() return Variable(LongTensor(src_seqs)), Variable(LongTensor(trg_seqs)), Variable(LongTensor(pos_seqs)),Variable(LongTensor(lm_seqs)), total_input_length, meta,Variable(LongTensor(keyword_x)) class SnliDataset(Dataset): """Take a list of samples with form [[x,...],y,meta] """ # need 3 special tokens # # as <ref start> 2 # $ as <speaker1> 3 # % as <speaker2> 4 # '<|endoftext|>' as <eos> 50256 def _split(self,data): positive_label = set(['entailment']) premise = [] hypothesis = [] label = [] for p,h,l in tqdm(data): premise.append(self.tokenizer.encode(text_standardize(p))) hypothesis.append(self.tokenizer.encode(text_standardize(h))) if l in positive_label: label.append(torch.tensor(1)) else: label.append(torch.tensor(0)) return premise,hypothesis,label def _filter(self,premise,hypothesis,label,filter_mode=None): data = zip(premise,hypothesis,label) if filter_mode == None: data_filt = data else: data_filt = [x for x in data if x[2]!='-'] premise_filt,hypothesis_filt,label_filt = zip(*data_filt) return premise_filt,hypothesis_filt,label_filt def parse_snli(self,path=None): with open(path) as f: data = [json.loads(line) for line in f] data_processed = [(line['sentence1'],line['sentence2'],line['gold_label']) for line in data] return data_processed def __init__(self,tokenizer,path='../data/snli_1.0/snli_1.0_train.jsonl',filter_mode=None,num_turns=5): self.data = self.parse_snli(path) self.tokenizer = tokenizer self.premise_encoded,self.hypothesis_encoded,self.label = self._split(self.data) self.premise_encoded,self.hypothesis_encoded,self.label = self._filter(self.premise_encoded,self.hypothesis_encoded,self.label,filter_mode) self.ref_start, self.speaker1,self.speaker2,self.eos = 2,3,4,50256 def __getitem__(self,index): x = [] type_x = [] lm_x = [] x += [self.speaker1] x += self.premise_encoded[index] type_x += [self.speaker1]*(len(self.premise_encoded[index])+1) # the premise part x += [self.ref_start] x += self.hypothesis_encoded[index] x += [self.eos] type_x += [self.ref_start]*(len(self.hypothesis_encoded[index])+2) # the hypothesis part label = self.label[index] position_x = list(range(len(x))) x = torch.Tensor(x) type_x = torch.Tensor(type_x) position_x = torch.Tensor(position_x) return x,type_x,position_x,lm_x,label def __len__(self): return len(self.premise_encoded) class GptDataset_full(Dataset): def _split(self,x_y_meta): x_all = [] y_all = [] meta_all = [] aug_all = [] keyword_all = [] for x, y, meta, aug, keyword in x_y_meta: meta_all.append(meta) # update for the new data format aug = ''.join([a[1] for a in aug]) x_all.append([self.tokenizer.encode(text_standardize(x_i)) for x_i in x]) y_all.append(self.tokenizer.encode(text_standardize(y))) aug_all.append(self.tokenizer.encode(text_standardize(aug))) keyword_all.append(self.tokenizer.encode(text_standardize(keyword))) return x_all,y_all,meta_all,aug_all, keyword_all def _filt(self, length=1024): data = zip(self.x_encoded,self.y_encoded,self.meta,self.aug_encoded, self.keyword_encoded) data = [sample for sample in data if sum([len(sen) for sen in sample[0]][-self.args.num_turns:])+len(sample[1])+len(sample[3])+len(sample[4]) < 850] self.x_encoded,self.y_encoded,self.meta,self.aug_encoded, self.keyword_encoded = zip(*data) self.x_encoded = list(self.x_encoded) self.y_encoded = list(self.y_encoded) self.meta = list(self.meta) self.aug_encoded = list(self.aug_encoded) self.keyword_encoded = list(self.keyword_encoded) def __init__(self,x_y_meta,tokenizer,args): self.x_y_meta = x_y_meta self.num_turns = args.num_turns self.tokenizer = tokenizer self.args = args self.x_encoded,self.y_encoded,self.meta,self.aug_encoded, self.keyword_encoded = self._split(x_y_meta) self._filt() # TODO: add back filt for mix-review self.ref_start, self.speaker1,self.speaker2,self.eos = 2,3,4,50256 self.augment = 5 if self.args.augment: print("Using augment sentences.") if self.args.keyword: print("Using keywords.") def __getitem__(self,index): x = [] type_x = [] lm_x = [] if self.args.augment: x += [self.augment] + self.aug_encoded[index] if self.args.keyword: x += [self.augment] + self.keyword_encoded[index] type_x += [self.augment] * len(x) is_speaker1 = bool(self.num_turns % 2) # which speaker start the conversation for utt in self.x_encoded[index][-self.num_turns:]: if is_speaker1: # add the prefix special token for each utterance x+=[self.speaker1] type_x += [self.speaker1]*(len(utt)+1) else: x+=[self.speaker2] type_x += [self.speaker2]*(len(utt)+1) x += utt is_speaker1 = not is_speaker1 lm_x += [-100]*len(x) # all position for the input is masked for loss calculation total_input_length = len(x) x += [self.ref_start] + self.y_encoded[index] + [self.eos] type_x += [self.ref_start]*(len(self.y_encoded[index])+2) lm_x += [-100] + self.y_encoded[index] + [self.eos] position_x = list(range(len(x))) x = torch.Tensor(x) type_x = torch.Tensor(type_x) position_x = torch.Tensor(position_x) lm_x = torch.Tensor(lm_x) x_len = x.shape[0] return x,type_x,position_x,lm_x,total_input_length,self.meta[index] def __len__(self): return len(self.x_encoded) class GptDataset_full_condition(Dataset): def _split(self, x_y_meta): x_all = [] y_all = [] meta_all = [] aug_all = [] keyword_all = [] for x, y, meta, aug, keyword in x_y_meta: meta_all.append(meta) # update for the new data format aug = ''.join([a[1] for a in aug]) x_all.append([self.tokenizer.encode(text_standardize(x_i)) for x_i in x]) y_all.append(self.tokenizer.encode(text_standardize(y))) aug_all.append(self.tokenizer.encode(text_standardize(aug))) keyword_all.append(self.tokenizer.encode(text_standardize(keyword))) return x_all, y_all, meta_all, aug_all, keyword_all def _filt(self, length=1024): data = zip(self.x_encoded, self.y_encoded, self.meta, self.aug_encoded, self.keyword_encoded) data = [sample for sample in data if sum([len(sen) for sen in sample[0]][-self.args.num_turns:]) + len(sample[1]) + len(sample[3]) + len( sample[4]) < 850] self.x_encoded, self.y_encoded, self.meta, self.aug_encoded, self.keyword_encoded = zip(*data) self.x_encoded = list(self.x_encoded) self.y_encoded = list(self.y_encoded) self.meta = list(self.meta) self.aug_encoded = list(self.aug_encoded) self.keyword_encoded = list(self.keyword_encoded) def __init__(self, x_y_meta, tokenizer, args): self.x_y_meta = x_y_meta self.num_turns = args.num_turns self.tokenizer = tokenizer self.args = args self.x_encoded, self.y_encoded, self.meta, self.aug_encoded, self.keyword_encoded = self._split(x_y_meta) self._filt() # TODO: add back filt for mix-review self.ref, self.speaker1, self.speaker2 = tokenizer.ref, tokenizer.speaker1, tokenizer.speaker2 self.eos = tokenizer.eos self.augment = tokenizer.augment self.is_ref, self.is_non_ref = tokenizer.is_ref, tokenizer.is_non_ref self.code_set = set(['GIV', 'QUEST', 'SEEK', 'AF', 'EMPH', 'PWOP', 'PWP', 'CON']) def __getitem__(self, index): x = [] type_x = [] lm_x = [] is_speaker1 = bool(self.num_turns % 2) # which speaker start the conversation if self.meta[index][2] in self.code_set: x += [self.is_non_ref] type_x += [self.is_non_ref] else: x += [self.is_ref] type_x += [self.is_ref] for utt in self.x_encoded[index][-self.num_turns:]: if is_speaker1: # add the prefix special token for each utterance x += [self.speaker1] type_x += [self.speaker1] * (len(utt) + 1) else: x += [self.speaker2] type_x += [self.speaker2] * (len(utt) + 1) x += utt is_speaker1 = not is_speaker1 lm_x += [-100] * len(x) # all position for the input is masked for loss calculation total_input_length = len(x) x += [self.ref] + self.y_encoded[index] + [self.eos] type_x += [self.ref] * (len(self.y_encoded[index]) + 2) lm_x += [-100] + self.y_encoded[index] + [self.eos] position_x = list(range(len(x))) x = torch.tensor(x) type_x = torch.tensor(type_x) position_x = torch.tensor(position_x) lm_x = torch.tensor(lm_x) x_len = x.shape[0] return x, type_x, position_x, lm_x, total_input_length, self.meta[index] def __len__(self): return len(self.x_encoded) class GptDataset_KBERT_old(Dataset): def get_comet_aug_deque(self, comet_data, num_turns=5): clause_dq = deque() for comet_in, comet_out in comet_data: if comet_out == "": continue loc = int(comet_in.split()[0]) if loc >= (10 - num_turns): clause_dq.append((loc, comet_out)) return clause_dq def __init__(self, x_y_meta, tokenizer, args): self.data = x_y_meta self.num_turns = args.num_turns self.tokenizer = tokenizer self.args = args self.ref_start, self.speaker1, self.speaker2, self.eos = 2, 3, 4, 50256 self.augment = 5 if self.args.augment: print("Using augment sentences.") if self.args.keyword: print("Using keywords.") def __getitem__(self, index): x = [] type_x = [] lm_x = [] soft_position_x = [] dq = self.get_comet_aug_deque(self.data[index][3]) # the comet info mask_info = [] context = self.data[index][0] response = self.data[index][1] is_speaker1 = bool(self.args.num_turns % 2) soft_loc = 0 # keep tract of the location of main sentences, point to the next token to be added utterance_start_loc = 0 for i in range(10 - self.args.num_turns, 10): utternace_encoded = self.tokenizer.encode(text_standardize(context[i])) # add the prefix special token for each utterance if is_speaker1: x += [self.speaker1] type_x += [self.speaker1] * (len(utternace_encoded) + 1) else: x += [self.speaker2] type_x += [self.speaker2] * (len(utternace_encoded) + 1) x += utternace_encoded utterance_end_loc = len(x) soft_position_x += list(range(soft_loc, soft_loc + len(utternace_encoded) + 1)) # add the aug, if it is the right place while len(dq) != 0 and dq[0][0] == i: comet_output = dq.popleft()[1] comet_encoded = self.tokenizer.encode(text_standardize(comet_output)) x += [self.augment] + comet_encoded type_x += [self.augment] * (len(comet_encoded) + 1) soft_position_x += list(range(soft_loc, soft_loc + len(comet_encoded) + 1)) mask_info.append([utterance_start_loc, utterance_end_loc, len(comet_encoded)+1]) # update the pointer to the new seq end, add one for the delimiter token soft_loc += len(utternace_encoded) + 1 is_speaker1 = not is_speaker1 utterance_start_loc = len(x) lm_x += [-100] * len(x) # all position for the input is masked for loss calculation total_input_length = len(x) response_encoded = self.tokenizer.encode(text_standardize(response)) x += [self.ref_start] + response_encoded + [self.eos] type_x += [self.ref_start] * (len(response_encoded) + 2) lm_x += [-100] + response_encoded + [self.eos] soft_position_x += list(range(soft_loc, soft_loc + len(response_encoded) + 2)) x = torch.Tensor(x) type_x = torch.Tensor(type_x) soft_position_x = torch.Tensor(soft_position_x) lm_x = torch.Tensor(lm_x) x_len = x.shape[0] # process the mask attention_mask = torch.tril(torch.ones(x_len, x_len)) for u_start, u_end, branch_len in mask_info: attention_mask[u_end+branch_len+1: u_end+1:u_end+branch_len+1] = 0 # [1st token after branch: , 1st token in branch: last token in branch+1] attention_mask = attention_mask.view(1, x_len, x_len) return x, type_x, soft_position_x, lm_x, total_input_length, attention_mask def __len__(self): return len(self.data) class GptDataset_KBERT(Dataset): def __init__(self, tokenizer, args, file_path="../data_processed/data_comet_dict"): pickle_handler = open(file_path, 'rb') self.data = pickle.load(pickle_handler) self.max_length = 510 self.tokenizer = tokenizer self.args = args self.num_turns = args.num_turns self.ref, self.speaker1, self.speaker2 = tokenizer.ref, tokenizer.speaker1, tokenizer.speaker2 self.eos = tokenizer.eos self.augment = tokenizer.augment if not self.args.kbert: self.args.kbert_mask = False self.args.kbert_position = False print("Not using kbert scheme.") if self.args.kbert_mask: print("using kbert-style attention mask") if self.args.kbert_position: print("using kbert-style soft-postional encoding") def __getitem__(self, index): # preprare variables x = [] type_x = [] lm_x = [] soft_position_x = [] attention_mask = [] # 0. unpack needed input info context = self.data[index]['context'] srl_mask = self.data[index]['srl_mask'] comet_output = self.data[index]['comet'] # a list of dict or None response = self.data[index]['response'] # 1. encode the response. response_encoded = self.tokenizer.encode(text_standardize(response)) # 2. encode each utterance. context_encoded = [] for i in range(10 - self.args.num_turns, 10): context_encoded.append(self.tokenizer.encode(text_standardize(context[i]))) # 3. encode the comet output for each utterance. comet_encoded = [] for i in range(len(comet_output)): comet_text_i = "" if comet_output[i] is None: comet_encoded.append(None) continue for rel in comet_output[i]: for candidate in comet_output[i][rel]['beams']: if candidate != 'none': comet_text_i += rel + " " + candidate + " " break comet_encoded.append(self.tokenizer.encode(text_standardize(comet_text_i))) # 4. use the encoded seq to build the input and attention mask is_speaker1 = bool(self.args.num_turns % 2) soft_loc = 0 for i in range(self.args.num_turns): # add an utterance. update x & type_x if is_speaker1: x += [self.speaker1] type_x += [self.speaker1] * (len(context_encoded[i]) + 1) else: x += [self.speaker2] type_x += [self.speaker2] * (len(context_encoded[i]) + 1) x += context_encoded[i] # update pos_x # concate aug part after x. but the index is from the last related token soft_position_x += list(range(soft_loc, soft_loc + (len(context_encoded[i]) + 1))) last_related_token_index = len(srl_mask[i]) - 1 - srl_mask[i][::-1].index(1) # add comet output if self.args.kbert: if comet_encoded[i] is not None: x += [self.augment] + comet_encoded[i] type_x += [self.augment] * (len(comet_encoded[i]) + 1) # +2 for the special token and the requirement of one-number larger than the utterance soft_position_x += list(range(soft_loc + 2 + last_related_token_index, soft_loc + 2 + last_related_token_index + (len(comet_encoded[i]) + 1))) soft_loc += (len(context_encoded[i]) + 1) is_speaker1 = not is_speaker1 lm_x += [-100] * len(x) # all position for the input is masked for loss calculation total_input_length = len(x) response_encoded = self.tokenizer.encode(text_standardize(response)) x += [self.ref] + response_encoded + [self.eos] type_x += [self.ref] * (len(response_encoded) + 2) lm_x += [-100] + response_encoded + [self.eos] soft_position_x += list(range(soft_loc, soft_loc + len(response_encoded) + 2)) x = x[:self.max_length] type_x = type_x[:self.max_length] lm_x = lm_x[:self.max_length] soft_position_x = soft_position_x[:self.max_length] # build attention mask attention_mask = torch.tril(torch.ones(len(x), len(x))) if self.args.kbert_mask: aug_start = 0 # where the aug begin utt_start = 0 # where the utt begin for turn in range(self.args.num_turns): aug_start += len(context_encoded[turn]) + 1 # iter through every token in the comet output if comet_encoded[turn] is not None: for aug_token_pos in range(aug_start, aug_start + len(comet_encoded[turn]) + 1): # set the attention related to the aug part to be all zero attention_mask[aug_token_pos, :] = torch.zeros_like(attention_mask[aug_token_pos, :]) attention_mask[:, aug_token_pos] = torch.zeros_like(attention_mask[:, aug_token_pos]) # set attention on related token to be one for normal_token_pos in range(len(context_encoded[turn])): attention_mask[aug_token_pos, utt_start + normal_token_pos + 1] += srl_mask[turn][ normal_token_pos] # set attention on previous aug tokens to be one for previous_aug_token_poc in range(aug_start, aug_token_pos + 1): attention_mask[aug_token_pos, previous_aug_token_poc] += 1 aug_start += len(comet_encoded[turn]) + 1 utt_start += len(comet_encoded[turn]) + 1 utt_start += (len(context_encoded[turn]) + 1) x = torch.tensor(x) type_x = torch.tensor(type_x) if not self.args.kbert_position: soft_position_x = list(range(len(x))) soft_position_x = torch.tensor(soft_position_x) lm_x = torch.tensor(lm_x) return x, type_x, soft_position_x, lm_x, total_input_length, attention_mask def __len__(self): return len(self.data) def get_data(args, tokenizer, split_size): """ Return the data loaders needed for training and evaluation. :param args: command line arguments. :param tokenizer: the tokenizer used in preparing the data. :param split_size: the portion of train, test, validation set. :return data_loader: The data loader for the training set. :return val_loader: The data loader for the validation set. """ # random.seed(args.seed) # torch.random.manual_seed(args.seed) # torch.cuda.manual_seed(args.seed) # torch.manual_seed(args.seed) if args.special_input: print("Using mutated data.") pickle_handler = open('../data_processed/' + args.special_input, 'rb') x_y_meta = pickle.load(pickle_handler) gpt_data = GptDataset(x_y_meta, tokenizer, args.output_dir, num_turns=args.num_turns) # TODO: not finished # #======================origin without kbert====== # elif not args.kbert: # print("Using full data.") # pickle_handler = open('../data_processed/x_y_with_comet', 'rb') # TODO: change back to the old data. # x_y_meta = pickle.load(pickle_handler) # gpt_data = GptDataset_full(x_y_meta, tokenizer, args=args) elif args.conditional: print("using conditional generation data") file_path = "../data_processed/" data_train = pickle.load(open(file_path+'train_ref', 'rb')) + pickle.load(open(file_path+'train_non_ref','rb')) gpt_train= GptDataset_full_condition(data_train, tokenizer, args=args) data_test = pickle.load(open(file_path+'test_ref', 'rb')) gpt_test = GptDataset_full_condition(data_test, tokenizer, args=args) data_val = pickle.load(open(file_path+'test_ref', 'rb')) gpt_val = GptDataset_full_condition(data_val, tokenizer, args=args) elif args.kbert: print("Using KBERT data") gpt_data = GptDataset_KBERT(tokenizer, args=args) print("Dataset initialized. There are {} samples.".format(len(gpt_data))) test_size = int(len(gpt_data) * split_size['test']) val_size = int(len(gpt_data) * split_size['val']) gpt_train, gpt_test, gpt_val = torch.utils.data.random_split(gpt_data, [len(gpt_data) - test_size - val_size, test_size, val_size]) else: print("Using full data.") file_path = "../data_processed/" data_train = pickle.load(open(file_path+'train_ref', 'rb')) gpt_train= GptDataset_full_condition(data_train, tokenizer, args=args) data_test = pickle.load(open(file_path+'test_ref', 'rb')) gpt_test = GptDataset_full_condition(data_test, tokenizer, args=args) data_val = pickle.load(open(file_path+'test_ref', 'rb')) gpt_val = GptDataset_full_condition(data_val, tokenizer, args=args) # # ======= one-time plug in======== # x_y_meta_pre = pickle.load(open("../data_processed/data_comet_dict",'rb')) # x_y_meta = [] # for x in x_y_meta_pre: # context_i = x['context'] # response_i = x['response'] # meta_i = x['meta'] # x_y_meta.append([context_i, response_i, meta_i, "", ""]) # # with open('../data_processed/train_ref', 'wb') as f: # idx = gpt_train.indices # gpt_train =[x_y_meta[i] for i in idx] # pickle.dump(gpt_train, f) # with open('../data_processed/test_ref', 'wb') as f: # idx = gpt_test.indices # gpt_test = [x_y_meta[i] for i in idx] # pickle.dump(gpt_test, f) # with open('../data_processed/val_ref', 'wb') as f: # idx = gpt_val.indices # gpt_val = [x_y_meta[i] for i in idx] # pickle.dump(gpt_val, f) if 'train_batch_size' not in args: args.train_batch_size = 1 if args.conditional: data_loader = DataLoader(dataset=gpt_train, batch_size=args.train_batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn_conditional) test_loader = DataLoader(dataset=gpt_test, batch_size=1, shuffle=False, drop_last=False, collate_fn=collate_fn_conditional) val_loader = DataLoader(dataset=gpt_val, batch_size=1, shuffle=False, drop_last=False, collate_fn=collate_fn_conditional) else: data_loader = DataLoader(dataset=gpt_train, batch_size=args.train_batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn) test_loader = DataLoader(dataset=gpt_test, batch_size=1, shuffle=False, drop_last=False, collate_fn=collate_fn) val_loader = DataLoader(dataset=gpt_val, batch_size=1, shuffle=False, drop_last=False, collate_fn=collate_fn) return data_loader, test_loader, val_loader def prepare_mix_review(args, tokenizer): print("Preparing Alexander dataset") pickle_handler = open('../data_processed/data_alex', 'rb') data = pickle.load(pickle_handler) gpt_alex = GptDataset_full(data, tokenizer, args=args) print("Alexander dataset prepared. Has {} samples".format(len(gpt_alex))) return gpt_alex def update_mix_review(gpt_train, gpt_alex, epoch, args, mix_ratio=4, mix_decay=0.7, collate_fn=collate_fn): mix_amount = int(mix_ratio*(0.7**epoch)*len(gpt_train)) gpt_alex_active,_ = torch.utils.data.random_split(gpt_alex, [mix_amount, len(gpt_alex)-mix_amount]) data_loader = DataLoader(dataset=gpt_train+gpt_alex_active, batch_size=args.train_batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn) return data_loader