diff --git a/code/gpt_loader/__init__.py b/code/gpt_loader/__init__.py
index ad82fded6eec648ab7d3fb1181f395c1f35b5126..d083d72ca7c1ef5628bb74d9f6c7dbcba0b8c9ae 100644
--- a/code/gpt_loader/__init__.py
+++ b/code/gpt_loader/__init__.py
@@ -1 +1 @@
-from .load_data import GptDataset,collate_fn,collate_fn_nli,GptDataset_nli,SnliDataset,GptDataset_aug,GptDataset_keyword, collate_fn_keyword, get_data, prepare_mix_review, update_mix_review, XLDataset_nli
+from .load_data import GptDataset, collate_fn, collate_fn_nli, SnliDataset, GptDataset_aug, collate_fn_keyword, get_data, prepare_mix_review, update_mix_review
diff --git a/code/gpt_loader/load_data.py b/code/gpt_loader/load_data.py
index c873c275f271c3b8dd218007e7c6fbd87788815a..9c9b40f5c889e2cd08a8c916bd4ceea73d4d08bd 100644
--- a/code/gpt_loader/load_data.py
+++ b/code/gpt_loader/load_data.py
@@ -6,6 +6,7 @@ import random
 import sys
 import pickle
 from tqdm import tqdm
+from collections import deque
 import copy
 sys.path.append("..")
 from utils import text_standardize
@@ -98,29 +99,6 @@ class GptDataset(Dataset):
     def __len__(self):
         return len(self.x_encoded)
-# class GptDataset_keyword(Dataset):
-#     def _split(self,x_y_meta):
-#         x_all = []
-#         y_all = []
-#         meta_all = []
-#         aug_all = []
-#         for x,y,meta,aug in x_y_meta:
-#             meta_all.append(meta)
-#             x_all.append([self.tokenizer.encode(text_standardize(x_i)) for x_i in x])
-#             y_all.append(self.tokenizer.encode(text_standardize(y)))
-#             key_word.append(self.tokenizer.encode(text_standardize(aug)))
-
-#         return x_all,y_all,meta_all,aug_all
-
-#     def __init__(self,x_y_meta,tokenizer,num_turns=5):
-#         self.x_y_meta = x_y_meta
-#         self.num_turns = num_turns
-#         self.tokenizer = tokenizer
-#         self.x_encoded,self.y_encoded,self.meta,self.aug_encoded = self._split(x_y_meta)
-#         self.ref_start, self.speaker1,self.speaker2,self.eos = 2,3,4,50256
-#         self.augment = 5
-#         self.keyword = 10 # '+'
-
 class GptDataset_aug(Dataset):
     def _split(self,x_y_meta):
         x_all = []
         y_all = []
         meta_all = []
         aug_all = []
         for x,y,meta,aug in x_y_meta:
             meta_all.append(meta)
             x_all.append([self.tokenizer.encode(text_standardize(x_i)) for x_i in x])
             y_all.append(self.tokenizer.encode(text_standardize(y)))
@@ -180,6 +158,7 @@ class GptDataset_aug(Dataset):
         return x,type_x,position_x,lm_x,total_input_length,self.meta[index]
     def __len__(self):
         return len(self.x_encoded)
+
 def collate_fn(data):
     """Creates mini-batch tensors from the list of tuples (src_seq, trg_seq).
     We should build a custom collate_fn rather than using default collate_fn,
@@ -293,119 +272,6 @@ def collate_fn_keyword(data):
         keyword_x = keyword_x.cuda()
     return Variable(LongTensor(src_seqs)), Variable(LongTensor(trg_seqs)), Variable(LongTensor(pos_seqs)),Variable(LongTensor(lm_seqs)), total_input_length, meta,Variable(LongTensor(keyword_x))
 
-class GptDataset_keyword(Dataset):
-    def _split(self, x_y_meta):
-        x_all = []
-        y_all = []
-        meta_all = []
-        keyword_all = []
-        for x, y, meta, keyword in x_y_meta:
-            meta_all.append(meta)
-            x_all.append([self.tokenizer.encode(text_standardize(x_i)) for x_i in x])
-            y_all.append(self.tokenizer.encode(text_standardize(y)))
-            keyword_all.append(self.tokenizer.encode(text_standardize(keyword)))
-        return x_all, y_all, meta_all, keyword_all
-
-    def __init__(self, x_y_meta, tokenizer, num_turns=5):
-
-        self.x_y_meta = x_y_meta
-        self.num_turns = num_turns
-        self.tokenizer = tokenizer
-        self.x_encoded, self.y_encoded, self.meta, self.keyword = self._split(x_y_meta)
-        self.ref_start, self.speaker1, self.speaker2, self.eos = 2, 3, 4, 50256
-
-    def __getitem__(self, index):
-        x = []
-        type_x = []
-        lm_x = []
-        is_speaker1 = bool(len(self.x_encoded[index]) % 2) # which speaker start the conversation
-
-        for utt in self.x_encoded[index]:
-            if is_speaker1: # add the prefix special token for each utterance
-                x += [self.speaker1]
-                type_x += [self.speaker1] * (len(utt) + 1)
-            else:
-                x += [self.speaker2]
-                type_x += [self.speaker2] * (len(utt) + 1)
-            x += utt
-            is_speaker1 = not is_speaker1
-        lm_x += [-1] * len(x) # all position for the input is masked for loss calculation
-
-        total_input_length = len(x)
-
-        x += [self.ref_start] + self.y_encoded[index] + [self.eos]
-
-        type_x += [self.ref_start] * (len(self.y_encoded[index]) + 2)
-        lm_x += [-1] + self.y_encoded[index] + [self.eos]
-        position_x = list(range(len(x)))
-
-        x = torch.Tensor(x)
-        type_x = torch.Tensor(type_x)
-        position_x = torch.Tensor(position_x)
-        lm_x = torch.Tensor(lm_x)
-        x_len = x.shape[0]
-
-        keyword_x = [] + self.keyword[index]
-        keyword_x = torch.Tensor(keyword_x)
-        return x, type_x, position_x, lm_x, total_input_length, self.meta[index], keyword_x
-
-    def __len__(self):
-        return len(self.x_encoded)
-
-# class GptDataset_nli(GptDataset):
-#     def __init__(self, x_y_meta, tokenizer, filter_mode=None,num_turns=5,augment=True):
-#         super(GptDataset_nli, self).__init__(x_y_meta,tokenizer, num_turns=num_turns)
-#         self.augment = augment
-#         self.pos_len = len(self.x_encoded)
-
-#     def __len__(self):
-#         if self.augment:
-#             return 2 * len(self.x_encoded)
-#         else:
-#             return len(self.x_encoded)
-
-#     def __getitem__(self,index):
-#         # client utterances - premise -speaker1
-#         # response - hypothesis - ref_start
-#         true_index = index
-#         if index >= self.pos_len:
-#             index = index - self.pos_len
-
-#         x = []
-#         type_x = []
-#         lm_x = []
-#         is_speaker1 = bool(len(self.x_encoded[index])%2) # which speaker start the conversation
-
-#         x+=[self.speaker1]
-#         type_x += [self.speaker1]
-#         for utt in self.x_encoded[index][-self.num_turns:]:
-#             if is_speaker1: # add the prefix special token for each utterance
-#                 type_x += [self.speaker1]*(len(utt))
-#                 x += utt
-#             # else:
-#             #     x+=[self.speaker2]
-#             #     type_x += [self.speaker2]*(len(utt)+1)
-#             #     x += utt
-#             is_speaker1 = not is_speaker1
-
-#         total_input_length = len(x)
-
-#         if true_index >= self.pos_len:
-#             rand_index = random.randint(0,self.pos_len-1)
-#             x += [self.ref_start] + self.y_encoded[rand_index] + [self.eos]
-#             type_x += [self.ref_start]*(len(self.y_encoded[rand_index])+2)
-#         else:
-#             x += [self.ref_start] + self.y_encoded[index] + [self.eos]
-#             type_x += [self.ref_start]*(len(self.y_encoded[index])+2)
-#         position_x = list(range(len(x)))
-
-#         x = torch.Tensor(x)
-#         type_x = torch.Tensor(type_x)
-#         position_x = torch.Tensor(position_x)
-#         x_len = x.shape[0]
-#         label = torch.tensor(0) if true_index>self.pos_len else torch.tensor(1)
-#         return x,type_x,position_x,lm_x, label
-
 class SnliDataset(Dataset):
     """Take a list of samples with form [[x,...],y,meta]
     """
@@ -561,151 +427,89 @@ class GptDataset_full(Dataset):
     def __len__(self):
         return len(self.x_encoded)
 
+class GptDataset_KBERT(Dataset):
+    def get_comet_aug_deque(self, comet_data, num_turns=5):
+        clause_dq = deque()
+        for comet_in, comet_out in comet_data:
+            if comet_out == "":
+                continue
+            loc = int(comet_in.split()[0])
+            if loc >= (10 - num_turns):
+                clause_dq.append((loc, comet_out))
+        return clause_dq
+
+    def __init__(self, x_y_meta, tokenizer, args):
+        self.data = x_y_meta
+        self.num_turns = args.num_turns
+        self.tokenizer = tokenizer
+        self.args = args
+        self.ref_start, self.speaker1, self.speaker2, self.eos = 2, 3, 4, 50256
+        self.augment = 5
 
-class GptDataset_nli(GptDataset_full):
-    def __init__(self, x_y_meta, tokenizer, args, infer=False):
-        super(GptDataset_nli, self).__init__(x_y_meta,tokenizer, args)
-        self.pos_len = len(self.x_encoded)
-        self.num_turns = 5
-        self.infer = infer
-    def __len__(self):
-        if self.infer:
-            return len(self.x_encoded)
-        else:
-            return 2 * len(self.x_encoded)
-
-    def __getitem__(self,index):
-        # client utterances - premise -speaker1
-        # response - hypothesis - ref_start
-        true_index = index
-        if index >= self.pos_len:
-            index = index - self.pos_len
+        if self.args.augment:
+            print("Using augmented sentences.")
+        if self.args.keyword:
+            print("Using keywords.")
+    def __getitem__(self, index):
         x = []
         type_x = []
         lm_x = []
-        is_speaker1 = bool(len(self.x_encoded[index])%2) # which speaker start the conversation
-
-        x+=[self.speaker1]
-        type_x += [self.speaker1]
-        for utt in self.x_encoded[index][-self.num_turns:]:
-            if is_speaker1: # add the prefix special token for each utterance
-                type_x += [self.speaker1]*(len(utt))
-                x += utt
-            # else:
-            #     x+=[self.speaker2]
-            #     type_x += [self.speaker2]*(len(utt)+1)
-            #     x += utt
-            is_speaker1 = not is_speaker1
+        soft_position_x = []
 
-        total_input_length = len(x)
-
-        if true_index >= self.pos_len:
-            rand_index = random.randint(0,self.pos_len-1)
-            x += [self.ref_start] + self.y_encoded[rand_index] + [self.eos]
-            type_x += [self.ref_start]*(len(self.y_encoded[rand_index])+2)
-        else:
-            x += [self.ref_start] + self.y_encoded[index] + [self.eos]
-            type_x += [self.ref_start]*(len(self.y_encoded[index])+2)
-        position_x = list(range(len(x)))
+        dq = self.get_comet_aug_deque(self.data[index][3]) # COMET clauses anchored to turns of this sample
+        context = self.data[index][0]
+        response = self.data[index][1]
 
-        x = torch.Tensor(x)
-        type_x = torch.Tensor(type_x)
-        position_x = torch.Tensor(position_x)
-        x_len = x.shape[0]
-        label = torch.tensor(0) if true_index>self.pos_len else torch.tensor(1)
+        is_speaker1 = bool(self.args.num_turns % 2) # parity decides which speaker token prefixes the first utterance in the window
+        soft_loc = 0 # keeps track of the soft position in the main sentence branch; points to the next position to be assigned
+        for i in range(10 - self.args.num_turns, 10):
+            utterance_encoded = self.tokenizer.encode(text_standardize(context[i]))
 
-        return x,type_x,position_x,lm_x, label
+            # add the prefix special token for each utterance
+            if is_speaker1:
+                x += [self.speaker1]
+                type_x += [self.speaker1] * (len(utterance_encoded) + 1)
+            else:
+                x += [self.speaker2]
+                type_x += [self.speaker2] * (len(utterance_encoded) + 1)
+            x += utterance_encoded
 
-class XLDataset_nli(GptDataset_nli):
-    def __init__(self, x_y_meta, tokenizer, args, infer=False):
-        super(GptDataset_nli, self).__init__(x_y_meta,tokenizer, args)
-        self.pos_len = len(self.x_encoded)
-        self.num_turns = 5
-        self.infer = infer
-#         self.ref_start, self.speaker1,self.speaker2,self.eos = 2,3,4,50256
-        self.pad, self.sep, self.cls = 5, 4, 3
-        self.unk, self.s, self.s_bar = 0, 1, 2
+            soft_position_x += list(range(soft_loc, soft_loc + len(utterance_encoded) + 1))
 
-    def __len__(self):
-        if self.infer:
-            return len(self.x_encoded)
-        else:
-            return 2 * len(self.x_encoded)
+            # add the COMET clause, if one is anchored at this turn
+            if len(dq) != 0 and dq[0][0] == i:
+                comet_output = dq.popleft()[1]
+                comet_encoded = self.tokenizer.encode(text_standardize(comet_output))
+                x += [self.augment] + comet_encoded
+                type_x += [self.augment] * (len(comet_encoded) + 1)
+                soft_position_x += list(range(soft_loc, soft_loc + len(comet_encoded) + 1))
 
-    def __getitem__(self,index):
-        # client utterances - premise -speaker1
-        # response - hypothesis - ref_start
-        true_index = index
-        if index >= self.pos_len:
-            index = index - self.pos_len
-
-
-        x = []
-        type_x = []
-        lm_x = []
-        mask_x = []
-        is_speaker1 = bool(len(self.x_encoded[index])%2) # which speaker start the conversation
-
-
-        for utt in self.x_encoded[index][-self.num_turns:]:
-            if is_speaker1: # add the prefix special token for each utterance
-                type_x += [self.unk]*(len(utt))
-                x += utt
-            else:
-                type_x += [self.unk]*(len(utt))
-                x += utt
+            # advance the pointer past the main sentence, plus one for the delimiter token
+            soft_loc += len(utterance_encoded) + 1
             is_speaker1 = not is_speaker1
-#         import pdb;pdb.set_trace()
-        x += [self.sep]
-        type_x += [self.unk]
-
+
+        lm_x += [-100] * len(x) # every input position is masked out of the loss calculation
         total_input_length = len(x)
-
-        if true_index >= self.pos_len:
-            rand_index = random.randint(0,self.pos_len-1)
-            x += self.y_encoded[rand_index] + [self.sep, self.cls]
-            type_x += [self.s]*(len(self.y_encoded[rand_index])+1) + [self.s_bar]
-        else:
-#             x += [self.ref_start] + self.y_encoded[index] + [self.eos]
-#             type_x += [self.ref_start]*(len(self.y_encoded[index])+2)
-            x += self.y_encoded[index] + [self.sep, self.cls]
-            type_x += [self.s]*(len(self.y_encoded[index])+1) + [self.s_bar]
-
-        position_x = list(range(len(x)))
-        mask_x = [self.s] * len(x)
-
-# ####
-#         x = x[-100:]
-#         mask_x = mask_x[-100:]
-#         type_x = type_x[-100:]
-
-        # left padding
-        x = [self.pad] * (self.args.max_length-len(x)) + x[-self.args.max_length:]
-        mask_x = [self.unk] * (self.args.max_length-len(mask_x)) + mask_x[-self.args.max_length:]
-        type_x = [self.sep] * (self.args.max_length-len(type_x)) + type_x[-self.args.max_length:]
-
-        x = torch.Tensor(x).long()
-        mask_x = torch.Tensor(mask_x).long()
-        type_x = torch.Tensor(type_x).long()
-        position_x = torch.Tensor(position_x)
-
-        x_len = x.shape[0]
-
-        label = torch.tensor(0) if true_index>self.pos_len else torch.tensor(1)
-        # label = torch.tensor(0) if true_index>self.pos_len else torch.tensor(0)
-
-        if USE_CUDA:
-            x = x.cuda()
-            mask_x = mask_x.cuda()
-            type_x = type_x.cuda()
-            label = label.cuda()
-#         return x, mask_x, type_x, label, position_x, lm_x
-        return x, mask_x, type_x, label
+        response_encoded = self.tokenizer.encode(text_standardize(response))
+        x += [self.ref_start] + response_encoded + [self.eos]
+
+        type_x += [self.ref_start] * (len(response_encoded) + 2)
+        lm_x += [-100] + response_encoded + [self.eos]
+
+        soft_position_x += list(range(soft_loc, soft_loc + len(response_encoded) + 2))
+        x = torch.Tensor(x)
+        type_x = torch.Tensor(type_x)
+        soft_position_x = torch.Tensor(soft_position_x)
+        lm_x = torch.Tensor(lm_x)
+        x_len = x.shape[0]
+        return x, type_x, soft_position_x, lm_x, total_input_length, self.data[index][2]
+    def __len__(self):
+        return len(self.data)
 def get_data(args, tokenizer, split_size):
     """
@@ -725,11 +529,16 @@ def get_data(args, tokenizer, split_size):
         pickle_handler = open('../data_processed/' + args.special_input, 'rb')
         x_y_meta = pickle.load(pickle_handler)
         gpt_data = GptDataset(x_y_meta, tokenizer, args.output_dir, num_turns=args.num_turns)
-    else:
+    elif not args.kbert:
         print("Using full data.")
         pickle_handler = open('../data_processed/x_y_meta_all', 'rb') # TODO: change back to the old data.
         x_y_meta = pickle.load(pickle_handler)
        gpt_data = GptDataset_full(x_y_meta, tokenizer, args=args)
+    else:
+        print("Using KBERT data.")
+        pickle_handler = open('../data_processed/x_y_with_comet', 'rb')
+        x_y_meta = pickle.load(pickle_handler)
+        gpt_data = GptDataset_KBERT(x_y_meta, tokenizer, args=args)
 
     print("Dataset initialized. There are {} samples.".format(len(gpt_data)))
     test_size = int(len(gpt_data) * split_size['test'])
@@ -769,55 +578,4 @@ def update_mix_review(gpt_train, gpt_alex, epoch, args, mix_ratio=4, mix_decay=0
     data_loader = DataLoader(dataset=gpt_train+gpt_alex_active, batch_size=args.train_batch_size, shuffle=True,
                              drop_last=True, collate_fn=collate_fn)
-    return data_loader
-
-def get_data_old(args, tokenizer, split_size):
-    """
-    Return the data loaders needed for training and evaluation.
-    :param args: command line arguments.
-    :param tokenizer: the tokenizer used in preparing the data.
-    :param split_size: the portion of train, test, validation set.
-    :return data_loader: The data loader for the training set.
-    :return val_loader: The data loader for the validation set.
-    """
-    if args.special_input:
-        print("Using mutated data.")
-        pickle_handler = open('../data_processed/' + args.special_input, 'rb')
-        x_y_meta = pickle.load(pickle_handler)
-        if args.augment:
-            print("testing keywords with augment loader.")
-            gpt_data = GptDataset_aug(x_y_meta, tokenizer, num_turns=args.num_turns)
-        else:
-            gpt_data = GptDataset(x_y_meta, tokenizer, args.output_dir, num_turns=args.num_turns)
-    elif args.augment:
-        print("Using augmented data")
-        pickle_handler = open('../data_processed/x_y_meta_aug', 'rb')
-        x_y_meta = pickle.load(pickle_handler)
-        gpt_data = GptDataset_aug(x_y_meta, tokenizer, num_turns=args.num_turns)
-    elif args.keyword:
-        print("Using keyword cross attention")
-        pickle_handler = open('../data_processed/x_y_meta_keyword', 'rb')
-        x_y_meta = pickle.load(pickle_handler)
-        gpt_data = GptDataset_keyword(x_y_meta, tokenizer)
-    else:
-        print("Using vanilla data.")
-        pickle_handler = open('../data_processed/x_y_meta', 'rb')
-        x_y_meta = pickle.load(pickle_handler)
-        gpt_data = GptDataset(x_y_meta, tokenizer, args.output_dir, num_turns=args.num_turns)
-
-    print("Dataset initialized. There are {} samples.".format(len(gpt_data)))
-    test_size = int(len(gpt_data) * split_size['test'])
-    val_size = int(len(gpt_data) * split_size['val'])
-    gpt_train, gpt_test, gpt_val = torch.utils.data.random_split(gpt_data,
-                                                                 [len(gpt_data) - test_size - val_size, test_size,
-                                                                  val_size])
-    if args.keyword:
-        data_loader = DataLoader(dataset=gpt_train, batch_size=args.train_batch_size, shuffle=True, drop_last=True,
-                                 collate_fn=collate_fn_keyword)
-        val_loader = DataLoader(dataset=gpt_val, batch_size=1, shuffle=False, drop_last=False,
-                                collate_fn=collate_fn_keyword)
-    else:
-        data_loader = DataLoader(dataset=gpt_train, batch_size=args.train_batch_size, shuffle=True, drop_last=True,
-                                 collate_fn=collate_fn)
-        val_loader = DataLoader(dataset=gpt_val, batch_size=1, shuffle=False, drop_last=False, collate_fn=collate_fn)
-    return data_loader, val_loader
+    return data_loader
\ No newline at end of file
diff --git a/code/gpt_tuning.py b/code/gpt_tuning.py
index 550ceb50727393c377b22f4e9cf5020e5fe418ae..7934382fd9cfdf5837d51d14a4a3d4130dc4b861 100644
--- a/code/gpt_tuning.py
+++ b/code/gpt_tuning.py
@@ -16,7 +16,7 @@ from torch.autograd import Variable
 from tqdm import tqdm, trange
 import random
 from utils import clean_text, text_standardize, construct_grouped_parameters, get_unfreezing_funcs
-from gpt_loader import GptDataset, collate_fn, GptDataset_aug, GptDataset_keyword, collate_fn_keyword, prepare_mix_review, update_mix_review, get_data
+from gpt_loader import GptDataset, collate_fn, collate_fn_keyword, prepare_mix_review, update_mix_review, get_data
 
 # OPTIONAL: if you want to have more information on what's happening, activate the logger as follows
 import logging
@@ -76,6 +76,7 @@ def parse_arguments():
     parser.add_argument('--use_disc_lr', action='store_true')
     parser.add_argument('--use_unfreezing', action='store_true')
     parser.add_argument('--num_turns', type=int, default=5)
+    parser.add_argument('--kbert', action='store_true')
     args = parser.parse_args()
     print(args)
     return args
@@ -91,11 +92,11 @@ def load_model(args):
     # ====== Load GPT2 model ========
     model_dir = '../models/' + args.model_dir
     # model = GPT2LMHeadModel.from_pretrained(model_dir)
-    model = GPT2LMHeadModel.from_pretrained('gpt2-medium')
+    model = GPT2LMHeadModel.from_pretrained('gpt2')
     if USE_CUDA:
         model.cuda()
     # tokenizer = GPT2Tokenizer.from_pretrained(model_dir)
-    tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
+    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
     print('Model loaded.')
     return model, tokenizer
 
@@ -120,7 +121,7 @@ def main():
     data_loader, test_loader, val_loader = get_data(args, split_size=split_size, tokenizer=tokenizer)
     # gpt_alex = prepare_mix_review(args, tokenizer)
     # data_loader, val_loader = get_data(args, split_size=split_size, tokenizer=tokenizer) # TODO: this is for old get_data
     # import pdb;pdb.set_trace()
     # ========== Prepare optimizer =============
     # the gpt2 model from library has unnamed LM head. LM head's weights are tied to input embedding
     num_train_optimization_steps = len(data_loader) * args.num_train_epochs // args.train_batch_size
@@ -153,7 +154,6 @@ def main():
             x, type_x, pos_x, lm_x, x_len, _ = sample
             keyword_x = None
             input_len = x_len[0]
-            lm_x[:, x_len[0] + 1 + args.first_K_tokens:-1] = -1
 
             # loss = model(x, position_ids=pos_x, token_type_ids=type_x, labels=lm_x, key_word=keyword_x,
             #              use_keyword=args.cross_attention)[0]
diff --git a/code/run_compare_aug.sh b/code/run_compare_aug.sh
index 169590b5265b7c2a7a50acdb610c34232dc3a3dc..df718d7f4e429a195564da5632818ea7ca3f8aaa 100644
--- a/code/run_compare_aug.sh
+++ b/code/run_compare_aug.sh
@@ -3,13 +3,13 @@ pwd
 
 # python retrieve_candidate.py --model_dir mi_nli
 
-mkdir -p ../models/mi_tuned_5turn
-python gpt_tuning.py --output_dir mi_tuned_5turn --num_train_epochs 10 --num_turns 5
-python gpt_sample.py --model_dir mi_tuned_5turn --output_dir mi_tuned_5turn --num_turns 5 --top_p 0.95
-
-mkdir -p ../models/mi_tuned_aug
-python gpt_tuning.py --output_dir mi_tuned_aug --num_train_epochs 10 --num_turns 5 --augment
-python gpt_sample.py --model_dir mi_tuned_aug --output_dir mi_tuned_aug --num_turns 5 --augment --top_p 0.95
+#mkdir -p ../models/mi_tuned_5turn
+#python gpt_tuning.py --output_dir mi_tuned_5turn --num_train_epochs 10 --num_turns 5
+#python gpt_sample.py --model_dir mi_tuned_5turn --output_dir mi_tuned_5turn --num_turns 5 --top_p 0.95
+#
+#mkdir -p ../models/mi_tuned_aug
+#python gpt_tuning.py --output_dir mi_tuned_aug --num_train_epochs 10 --num_turns 5 --augment
+#python gpt_sample.py --model_dir mi_tuned_aug --output_dir mi_tuned_aug --num_turns 5 --augment --top_p 0.95
 
 # mkdir -p ../models/mi_tuned_keyword
 #python gpt_tuning.py --output_dir mi_tuned_keyword --num_train_epochs 10 --num_turns 5 --keyword
@@ -18,4 +18,10 @@
 # mkdir -p ../models/mi_tuned_both
 # python gpt_tuning.py --output_dir mi_tuned_both --num_train_epochs 10 --num_turns 10 --keyword --augment
 # python gpt_sample.py --model_dir mi_tuned_both --output_dir mi_tuned_both --num_turns 10 --keyword --augment --top_p 0.95
+
+mkdir -p ../models/mi_tuned_kbert
+python gpt_tuning.py --output_dir mi_tuned_kbert --num_train_epochs 10 --num_turns 5 --kbert
+#python gpt_sample.py --model_dir mi_tuned_5turn --output_dir mi_tuned_5turn --num_turns 5 --top_p 0.95
+
 echo "Finished."
+
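
Note on the K-BERT-style soft positions introduced in GptDataset_KBERT above: an injected COMET clause branches off the utterance it is anchored to and re-uses that utterance's soft-position range, so only main-branch tokens advance the position counter and later utterances keep the positions they would have had without the clause. A minimal, self-contained sketch of that scheme, assuming the same (turn index, clause) deque layout that get_comet_aug_deque produces; the helper name soft_positions and the toy token ids are illustrative only, not part of this patch:

    from collections import deque

    def soft_positions(turns, clauses):
        """turns: one token list per utterance; clauses: deque of (turn_idx, token_list)."""
        soft, soft_loc = [], 0
        for i, utt in enumerate(turns):
            # main branch: one speaker/delimiter token plus the utterance tokens
            soft += list(range(soft_loc, soft_loc + len(utt) + 1))
            if clauses and clauses[0][0] == i:
                # the clause re-uses its anchor turn's positions, so it does not
                # shift the soft positions of the utterances that follow
                _, clause = clauses.popleft()
                soft += list(range(soft_loc, soft_loc + len(clause) + 1))
            soft_loc += len(utt) + 1  # only main-branch tokens advance the pointer
        return soft

    # two turns ([10, 11] and [12]) with one clause ([99]) anchored at turn 0:
    # soft_positions([[10, 11], [12]], deque([(0, [99])])) == [0, 1, 2, 0, 1, 3, 4]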