Commit f3f3048d authored by shensq

Add K-BERT soft positions. The attention mask is not implemented yet.

parent 98e0ca30
from .load_data import GptDataset,collate_fn,collate_fn_nli,GptDataset_nli,SnliDataset,GptDataset_aug,GptDataset_keyword, collate_fn_keyword, get_data, prepare_mix_review, update_mix_review, XLDataset_nli
from .load_data import GptDataset,collate_fn,collate_fn_nli,SnliDataset,GptDataset_aug, collate_fn_keyword, get_data, prepare_mix_review, update_mix_review
@@ -6,6 +6,7 @@ import random
import sys
import pickle
from tqdm import tqdm
from collections import deque
import copy
sys.path.append("..")
from utils import text_standardize
@@ -98,29 +99,6 @@ class GptDataset(Dataset):
def __len__(self):
return len(self.x_encoded)
# class GptDataset_keyword(Dataset):
# def _split(self,x_y_meta):
# x_all = []
# y_all = []
# meta_all = []
# aug_all = []
# for x,y,meta,aug in x_y_meta:
# meta_all.append(meta)
# x_all.append([self.tokenizer.encode(text_standardize(x_i)) for x_i in x])
# y_all.append(self.tokenizer.encode(text_standardize(y)))
# aug_all.append(self.tokenizer.encode(text_standardize(aug)))
# return x_all,y_all,meta_all,aug_all
# def __init__(self,x_y_meta,tokenizer,num_turns=5):
# self.x_y_meta = x_y_meta
# self.num_turns = num_turns
# self.tokenizer = tokenizer
# self.x_encoded,self.y_encoded,self.meta,self.aug_encoded = self._split(x_y_meta)
# self.ref_start, self.speaker1,self.speaker2,self.eos = 2,3,4,50256
# self.augment = 5
# self.keyword = 10 # '+'
class GptDataset_aug(Dataset):
def _split(self,x_y_meta):
x_all = []
@@ -180,6 +158,7 @@ class GptDataset_aug(Dataset):
return x,type_x,position_x,lm_x,total_input_length,self.meta[index]
def __len__(self):
return len(self.x_encoded)
def collate_fn(data):
"""Creates mini-batch tensors from the list of tuples (src_seq, trg_seq).
We should build a custom collate_fn rather than using default collate_fn,
@@ -293,119 +272,6 @@ def collate_fn_keyword(data):
keyword_x = keyword_x.cuda()
return Variable(LongTensor(src_seqs)), Variable(LongTensor(trg_seqs)), Variable(LongTensor(pos_seqs)),Variable(LongTensor(lm_seqs)), total_input_length, meta,Variable(LongTensor(keyword_x))
class GptDataset_keyword(Dataset):
def _split(self, x_y_meta):
x_all = []
y_all = []
meta_all = []
keyword_all = []
for x, y, meta, keyword in x_y_meta:
meta_all.append(meta)
x_all.append([self.tokenizer.encode(text_standardize(x_i)) for x_i in x])
y_all.append(self.tokenizer.encode(text_standardize(y)))
keyword_all.append(self.tokenizer.encode(text_standardize(keyword)))
return x_all, y_all, meta_all, keyword_all
def __init__(self, x_y_meta, tokenizer, num_turns=5):
self.x_y_meta = x_y_meta
self.num_turns = num_turns
self.tokenizer = tokenizer
self.x_encoded, self.y_encoded, self.meta, self.keyword = self._split(x_y_meta)
self.ref_start, self.speaker1, self.speaker2, self.eos = 2, 3, 4, 50256
def __getitem__(self, index):
x = []
type_x = []
lm_x = []
is_speaker1 = bool(len(self.x_encoded[index]) % 2) # which speaker starts the conversation
for utt in self.x_encoded[index]:
if is_speaker1: # add the prefix special token for each utterance
x += [self.speaker1]
type_x += [self.speaker1] * (len(utt) + 1)
else:
x += [self.speaker2]
type_x += [self.speaker2] * (len(utt) + 1)
x += utt
is_speaker1 = not is_speaker1
lm_x += [-1] * len(x) # all positions of the input are masked for the loss calculation
total_input_length = len(x)
x += [self.ref_start] + self.y_encoded[index] + [self.eos]
type_x += [self.ref_start] * (len(self.y_encoded[index]) + 2)
lm_x += [-1] + self.y_encoded[index] + [self.eos]
position_x = list(range(len(x)))
x = torch.Tensor(x)
type_x = torch.Tensor(type_x)
position_x = torch.Tensor(position_x)
lm_x = torch.Tensor(lm_x)
x_len = x.shape[0]
keyword_x = list(self.keyword[index])
keyword_x = torch.Tensor(keyword_x)
return x, type_x, position_x, lm_x, total_input_length, self.meta[index], keyword_x
def __len__(self):
return len(self.x_encoded)
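A typical way to wire this dataset up, mirroring the keyword branch of get_data_old further down (the pickle path comes from that function; the batch size here is a placeholder, and a GPT2Tokenizer instance named tokenizer is assumed to be in scope):

# Usage sketch for GptDataset_keyword, following get_data_old below.
import pickle
from torch.utils.data import DataLoader

with open('../data_processed/x_y_meta_keyword', 'rb') as f:
    x_y_meta = pickle.load(f)
gpt_data = GptDataset_keyword(x_y_meta, tokenizer)
loader = DataLoader(dataset=gpt_data, batch_size=4, shuffle=True, drop_last=True,
                    collate_fn=collate_fn_keyword)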
# class GptDataset_nli(GptDataset):
# def __init__(self, x_y_meta, tokenizer, filter_mode=None,num_turns=5,augment=True):
# super(GptDataset_nli, self).__init__(x_y_meta,tokenizer, num_turns=num_turns)
# self.augment = augment
# self.pos_len = len(self.x_encoded)
# def __len__(self):
# if self.augment:
# return 2 * len(self.x_encoded)
# else:
# return len(self.x_encoded)
# def __getitem__(self,index):
# # client utterances - premise -speaker1
# # response - hypothesis - ref_start
# true_index = index
# if index >= self.pos_len:
# index = index - self.pos_len
# x = []
# type_x = []
# lm_x = []
# is_speaker1 = bool(len(self.x_encoded[index])%2) # which speaker start the conversation
# x+=[self.speaker1]
# type_x += [self.speaker1]
# for utt in self.x_encoded[index][-self.num_turns:]:
# if is_speaker1: # add the prefix special token for each utterance
# type_x += [self.speaker1]*(len(utt))
# x += utt
# # else:
# # x+=[self.speaker2]
# # type_x += [self.speaker2]*(len(utt)+1)
# # x += utt
# is_speaker1 = not is_speaker1
# total_input_length = len(x)
# if true_index >= self.pos_len:
# rand_index = random.randint(0,self.pos_len-1)
# x += [self.ref_start] + self.y_encoded[rand_index] + [self.eos]
# type_x += [self.ref_start]*(len(self.y_encoded[rand_index])+2)
# else:
# x += [self.ref_start] + self.y_encoded[index] + [self.eos]
# type_x += [self.ref_start]*(len(self.y_encoded[index])+2)
# position_x = list(range(len(x)))
# x = torch.Tensor(x)
# type_x = torch.Tensor(type_x)
# position_x = torch.Tensor(position_x)
# x_len = x.shape[0]
# label = torch.tensor(0) if true_index>self.pos_len else torch.tensor(1)
# return x,type_x,position_x,lm_x, label
class SnliDataset(Dataset):
"""Take a list of samples with form [[x,...],y,meta]
"""
@@ -561,151 +427,89 @@ class GptDataset_full(Dataset):
def __len__(self):
return len(self.x_encoded)
class GptDataset_KBERT(Dataset):
    def get_comet_aug_deque(self, comet_data, num_turns=5):
        clause_dq = deque()
        for comet_in, comet_out in comet_data:
            if comet_out == "":
                continue
            loc = int(comet_in.split()[0])
            if loc >= (10 - num_turns):
                clause_dq.append((loc, comet_out))
        return clause_dq

    def __init__(self, x_y_meta, tokenizer, args):
        self.data = x_y_meta
        self.num_turns = args.num_turns
        self.tokenizer = tokenizer
        self.args = args
        self.ref_start, self.speaker1, self.speaker2, self.eos = 2, 3, 4, 50256
        self.augment = 5
        if self.args.augment:
            print("Using augment sentences.")
        if self.args.keyword:
            print("Using keywords.")

    def __getitem__(self, index):
        x = []
        type_x = []
        lm_x = []
        soft_position_x = []
        dq = self.get_comet_aug_deque(self.data[index][3])  # the COMET info
        context = self.data[index][0]
        response = self.data[index][1]
        is_speaker1 = bool(self.args.num_turns % 2)  # which speaker starts the conversation
        soft_loc = 0  # keeps track of the location in the main sequence; points to the next token to be added
        for i in range(10 - self.args.num_turns, 10):
            utterance_encoded = self.tokenizer.encode(text_standardize(context[i]))
            # add the prefix special token for each utterance
            if is_speaker1:
                x += [self.speaker1]
                type_x += [self.speaker1] * (len(utterance_encoded) + 1)
            else:
                x += [self.speaker2]
                type_x += [self.speaker2] * (len(utterance_encoded) + 1)
            x += utterance_encoded
            soft_position_x += list(range(soft_loc, soft_loc + len(utterance_encoded) + 1))
            # add the augmentation clause if this is the right utterance
            if len(dq) != 0 and dq[0][0] == i:
                comet_output = dq.popleft()[1]
                comet_encoded = self.tokenizer.encode(text_standardize(comet_output))
                x += [self.augment] + comet_encoded
                type_x += [self.augment] * (len(comet_encoded) + 1)
                # the clause reuses the soft-position span of the utterance it follows
                soft_position_x += list(range(soft_loc, soft_loc + len(comet_encoded) + 1))
            # update the pointer to the new sequence end; add one for the delimiter token
            soft_loc += len(utterance_encoded) + 1
            is_speaker1 = not is_speaker1
        lm_x += [-100] * len(x)  # all positions of the input are masked for the loss calculation
        total_input_length = len(x)
        response_encoded = self.tokenizer.encode(text_standardize(response))
        x += [self.ref_start] + response_encoded + [self.eos]
        type_x += [self.ref_start] * (len(response_encoded) + 2)
        lm_x += [-100] + response_encoded + [self.eos]
        soft_position_x += list(range(soft_loc, soft_loc + len(response_encoded) + 2))
        x = torch.Tensor(x)
        type_x = torch.Tensor(type_x)
        soft_position_x = torch.Tensor(soft_position_x)
        lm_x = torch.Tensor(lm_x)
        x_len = x.shape[0]
        return x, type_x, soft_position_x, lm_x, total_input_length, self.data[index][2]

    def __len__(self):
        return len(self.data)
class GptDataset_nli(GptDataset_full):
    def __init__(self, x_y_meta, tokenizer, args, infer=False):
        super(GptDataset_nli, self).__init__(x_y_meta, tokenizer, args)
        self.pos_len = len(self.x_encoded)
        self.num_turns = 5
        self.infer = infer

    def __len__(self):
        if self.infer:
            return len(self.x_encoded)
        else:
            return 2 * len(self.x_encoded)

    def __getitem__(self, index):
        # client utterances - premise - speaker1
        # response - hypothesis - ref_start
        true_index = index
        if index >= self.pos_len:
            index = index - self.pos_len
        x = []
        type_x = []
        lm_x = []
        is_speaker1 = bool(len(self.x_encoded[index]) % 2)  # which speaker starts the conversation
        x += [self.speaker1]
        type_x += [self.speaker1]
        for utt in self.x_encoded[index][-self.num_turns:]:
            if is_speaker1:  # add the prefix special token for each utterance
                type_x += [self.speaker1] * (len(utt))
                x += utt
            # else:
            #     x += [self.speaker2]
            #     type_x += [self.speaker2] * (len(utt) + 1)
            #     x += utt
            is_speaker1 = not is_speaker1
        total_input_length = len(x)
        if true_index >= self.pos_len:
            rand_index = random.randint(0, self.pos_len - 1)
            x += [self.ref_start] + self.y_encoded[rand_index] + [self.eos]
            type_x += [self.ref_start] * (len(self.y_encoded[rand_index]) + 2)
        else:
            x += [self.ref_start] + self.y_encoded[index] + [self.eos]
            type_x += [self.ref_start] * (len(self.y_encoded[index]) + 2)
        position_x = list(range(len(x)))
        x = torch.Tensor(x)
        type_x = torch.Tensor(type_x)
        position_x = torch.Tensor(position_x)
        x_len = x.shape[0]
        label = torch.tensor(0) if true_index >= self.pos_len else torch.tensor(1)  # >= so every shuffled pair is labeled negative
        return x, type_x, position_x, lm_x, label
class XLDataset_nli(GptDataset_nli):
    def __init__(self, x_y_meta, tokenizer, args, infer=False):
        super(GptDataset_nli, self).__init__(x_y_meta, tokenizer, args)
        self.pos_len = len(self.x_encoded)
        self.num_turns = 5
        self.infer = infer
        # self.ref_start, self.speaker1, self.speaker2, self.eos = 2, 3, 4, 50256
        self.pad, self.sep, self.cls = 5, 4, 3
        self.unk, self.s, self.s_bar = 0, 1, 2

    def __len__(self):
        if self.infer:
            return len(self.x_encoded)
        else:
            return 2 * len(self.x_encoded)

    def __getitem__(self, index):
        # client utterances - premise - speaker1
        # response - hypothesis - ref_start
        true_index = index
        if index >= self.pos_len:
            index = index - self.pos_len
        x = []
        type_x = []
        lm_x = []
        mask_x = []
        is_speaker1 = bool(len(self.x_encoded[index]) % 2)  # which speaker starts the conversation
        for utt in self.x_encoded[index][-self.num_turns:]:
            if is_speaker1:  # add the prefix special token for each utterance
                type_x += [self.unk] * (len(utt))
                x += utt
            else:
                type_x += [self.unk] * (len(utt))
                x += utt
            is_speaker1 = not is_speaker1
        # import pdb;pdb.set_trace()
        x += [self.sep]
        type_x += [self.unk]
        lm_x += [-100] * len(x)  # all positions of the input are masked for the loss calculation
        total_input_length = len(x)
        if true_index >= self.pos_len:
            rand_index = random.randint(0, self.pos_len - 1)
            x += self.y_encoded[rand_index] + [self.sep, self.cls]
            type_x += [self.s] * (len(self.y_encoded[rand_index]) + 1) + [self.s_bar]
        else:
            # x += [self.ref_start] + self.y_encoded[index] + [self.eos]
            # type_x += [self.ref_start] * (len(self.y_encoded[index]) + 2)
            x += self.y_encoded[index] + [self.sep, self.cls]
            type_x += [self.s] * (len(self.y_encoded[index]) + 1) + [self.s_bar]
        position_x = list(range(len(x)))
        mask_x = [self.s] * len(x)
        # left padding
        x = [self.pad] * (self.args.max_length - len(x)) + x[-self.args.max_length:]
        mask_x = [self.unk] * (self.args.max_length - len(mask_x)) + mask_x[-self.args.max_length:]
        type_x = [self.sep] * (self.args.max_length - len(type_x)) + type_x[-self.args.max_length:]
        x = torch.Tensor(x).long()
        mask_x = torch.Tensor(mask_x).long()
        type_x = torch.Tensor(type_x).long()
        position_x = torch.Tensor(position_x)
        x_len = x.shape[0]
        label = torch.tensor(0) if true_index >= self.pos_len else torch.tensor(1)  # >= as above
        if USE_CUDA:
            x = x.cuda()
            mask_x = mask_x.cuda()
            type_x = type_x.cuda()
            label = label.cuda()
        # return x, mask_x, type_x, label, position_x, lm_x
        return x, mask_x, type_x, label
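For context: this follows K-BERT (Liu et al., 2020), where injected knowledge tokens reuse the position indices of the tokens they attach to — the soft positions built above, with each COMET clause restarting at its utterance's soft_loc — while a visible matrix keeps the injected clause and the unrelated parts of the input from attending to each other. The commit message notes that this attention mask is not implemented. A rough sketch of how it could be derived; token_tags is a hypothetical bookkeeping structure, not code from this commit:

# Hypothetical sketch of the K-BERT-style visible matrix that this commit leaves unimplemented.
# token_tags: one (is_main, utt_id) pair per token; is_main is False for COMET clause tokens,
# and utt_id is the index of the utterance a clause is attached to.
import torch

def build_visible_matrix(token_tags):
    n = len(token_tags)
    visible = torch.zeros(n, n, dtype=torch.long)
    for i, (main_i, utt_i) in enumerate(token_tags):
        for j, (main_j, utt_j) in enumerate(token_tags):
            if main_i and main_j:
                visible[i, j] = 1  # main-sequence tokens all see each other
            elif utt_i == utt_j:
                visible[i, j] = 1  # a clause sees itself and its anchor utterance
    return visible

Applied inside attention, positions with visible == 0 would get a large negative additive value before the softmax, so the clause influences only the utterance it annotates.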
def get_data(args, tokenizer, split_size):
"""
@@ -725,11 +529,16 @@ def get_data(args, tokenizer, split_size):
pickle_handler = open('../data_processed/' + args.special_input, 'rb')
x_y_meta = pickle.load(pickle_handler)
gpt_data = GptDataset(x_y_meta, tokenizer, args.output_dir, num_turns=args.num_turns)
else:
elif not args.kbert:
print("Using full data.")
pickle_handler = open('../data_processed/x_y_meta_all', 'rb') # TODO: change back to the old data.
x_y_meta = pickle.load(pickle_handler)
gpt_data = GptDataset_full(x_y_meta, tokenizer, args=args)
else:
print("Using KBERT data")
pickle_handler = open("../data_processed/x_y_with_comet",'rb')
x_y_meta = pickle.load(pickle_handler)
gpt_data = GptDataset_KBERT(x_y_meta, tokenizer, args=args)
print("Dataset initialized. There are {} samples.".format(len(gpt_data)))
test_size = int(len(gpt_data) * split_size['test'])
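Judging from how GptDataset_KBERT indexes self.data, each entry of the x_y_with_comet pickle is a 4-tuple of context, response, meta, and COMET pairs. An illustrative layout only; all field values below are made up:

# Illustrative sample shape inferred from GptDataset_KBERT.__getitem__ (values are made up).
sample = (
    ["utterance 0", "...", "utterance 9"],             # [0] context: 10 turns, indexed 0-9
    "the reference response",                          # [1] response
    "meta info",                                       # [2] meta, returned untouched
    [("7 xReact ...", "feels heard"), ("9 ...", "")],  # [3] (comet_in, comet_out) pairs; the first
)                                                      #     token of comet_in is the utterance index,
                                                       #     and empty comet_out entries are skipped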
@@ -769,55 +578,4 @@ def update_mix_review(gpt_train, gpt_alex, epoch, args, mix_ratio=4, mix_decay=0
data_loader = DataLoader(dataset=gpt_train+gpt_alex_active, batch_size=args.train_batch_size, shuffle=True, drop_last=True,
collate_fn=collate_fn)
return data_loader
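Only the tail of update_mix_review is visible above. A hedged usage sketch, on the assumption that the loader is rebuilt at the start of each epoch so the mixed-in portion of gpt_alex can decay over training:

# Usage sketch: rebuild the loader every epoch (assumed calling pattern, not shown in this diff).
for epoch in range(args.num_train_epochs):
    data_loader = update_mix_review(gpt_train, gpt_alex, epoch, args)
    for batch in data_loader:
        pass  # one training step per batch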
def get_data_old(args, tokenizer, split_size):
"""
Return the data loaders needed for training and evaluation.
:param args: command line arguments.
:param tokenizer: the tokenizer used in preparing the data.
:param split_size: the portion of train, test, validation set.
:return data_loader: The data loader for the training set.
:return val_loader: The data loader for the validation set.
"""
if args.special_input:
print("Using mutated data.")
pickle_handler = open('../data_processed/' + args.special_input, 'rb')
x_y_meta = pickle.load(pickle_handler)
if args.augment:
print("testing keywords with augment loader.")
gpt_data = GptDataset_aug(x_y_meta, tokenizer, num_turns=args.num_turns)
else:
gpt_data = GptDataset(x_y_meta, tokenizer, args.output_dir, num_turns=args.num_turns)
elif args.augment:
print("Using augmented data")
pickle_handler = open('../data_processed/x_y_meta_aug', 'rb')
x_y_meta = pickle.load(pickle_handler)
gpt_data = GptDataset_aug(x_y_meta, tokenizer, num_turns=args.num_turns)
elif args.keyword:
print("Using keyword cross attention")
pickle_handler = open('../data_processed/x_y_meta_keyword', 'rb')
x_y_meta = pickle.load(pickle_handler)
gpt_data = GptDataset_keyword(x_y_meta, tokenizer)
else:
print("Using vanilla data.")
pickle_handler = open('../data_processed/x_y_meta', 'rb')
x_y_meta = pickle.load(pickle_handler)
gpt_data = GptDataset(x_y_meta, tokenizer, args.output_dir, num_turns=args.num_turns)
print("Dataset initialized. There are {} samples.".format(len(gpt_data)))
test_size = int(len(gpt_data) * split_size['test'])
val_size = int(len(gpt_data) * split_size['val'])
gpt_train, gpt_test, gpt_val = torch.utils.data.random_split(gpt_data,
[len(gpt_data) - test_size - val_size, test_size,
val_size])
if args.keyword:
data_loader = DataLoader(dataset=gpt_train, batch_size=args.train_batch_size, shuffle=True, drop_last=True,
collate_fn=collate_fn_keyword)
val_loader = DataLoader(dataset=gpt_val, batch_size=1, shuffle=False, drop_last=False,
collate_fn=collate_fn_keyword)
else:
data_loader = DataLoader(dataset=gpt_train, batch_size=args.train_batch_size, shuffle=True, drop_last=True,
collate_fn=collate_fn)
val_loader = DataLoader(dataset=gpt_val, batch_size=1, shuffle=False, drop_last=False, collate_fn=collate_fn)
return data_loader, val_loader
return data_loader
\ No newline at end of file
@@ -16,7 +16,7 @@ from torch.autograd import Variable
from tqdm import tqdm, trange
import random
from utils import clean_text, text_standardize, construct_grouped_parameters, get_unfreezing_funcs
from gpt_loader import GptDataset, collate_fn, GptDataset_aug, GptDataset_keyword, collate_fn_keyword, prepare_mix_review, update_mix_review, get_data
from gpt_loader import GptDataset, collate_fn,collate_fn_keyword, prepare_mix_review, update_mix_review, get_data
# OPTIONAL: if you want to have more information on what's happening, activate the logger as follows
import logging
@@ -76,6 +76,7 @@ def parse_arguments():
parser.add_argument('--use_disc_lr', action='store_true')
parser.add_argument('--use_unfreezing', action='store_true')
parser.add_argument('--num_turns', type=int, default=5)
parser.add_argument('--kbert', action='store_true')
args = parser.parse_args()
print(args)
return args
@@ -91,11 +92,11 @@ def load_model(args):
# ====== Load GPT2 model ========
model_dir = '../models/' + args.model_dir
# model = GPT2LMHeadModel.from_pretrained(model_dir)
model = GPT2LMHeadModel.from_pretrained('gpt2-medium')
model = GPT2LMHeadModel.from_pretrained('gpt2')
if USE_CUDA:
model.cuda()
# tokenizer = GPT2Tokenizer.from_pretrained(model_dir)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
print('Model loaded.')
return model, tokenizer
@@ -120,7 +121,7 @@ def main():
data_loader, test_loader, val_loader = get_data(args, split_size=split_size, tokenizer=tokenizer)
# gpt_alex = prepare_mix_review(args, tokenizer)
# data_loader, val_loader = get_data(args, split_size=split_size, tokenizer=tokenizer) # TODO: this is for old get_data
# import pdb;pdb.set_trace()
import pdb;pdb.set_trace()
# ========== Prepare optimizer =============
# the gpt2 model from library has unnamed LM head. LM head's weights are tied to input embedding
num_train_optimization_steps = len(data_loader) * args.num_train_epochs // args.train_batch_size
@@ -153,7 +154,6 @@ def main():
x, type_x, pos_x, lm_x, x_len, _ = sample
keyword_x = None
input_len = x_len[0]
lm_x[:, x_len[0] + 1 + args.first_K_tokens:-1] = -1
# loss = model(x, position_ids=pos_x, token_type_ids=type_x, labels=lm_x, key_word=keyword_x,
# use_keyword=args.cross_attention)[0]
......
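The lm_x slice in the loop above keeps only the first first_K_tokens response labels active: the dataset already masks the context with -1, and the slice additionally masks the tail of the response, while the final eos label survives because the slice stops at -1. A toy illustration with a 3-token context and first_K_tokens = 2:

# Toy illustration of the lm_x masking above (made-up token ids).
import torch
lm_x = torch.tensor([[-1, -1, -1, -1, 11, 12, 13, 50256]])  # 3 context labels, ref_start, 3 response tokens, eos
x_len, first_K_tokens = [3], 2
lm_x[:, x_len[0] + 1 + first_K_tokens:-1] = -1
print(lm_x)  # tensor([[   -1,    -1,    -1,    -1,    11,    12,    -1, 50256]])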
@@ -3,13 +3,13 @@ pwd
# python retrieve_candidate.py --model_dir mi_nli
mkdir -p ../models/mi_tuned_5turn
python gpt_tuning.py --output_dir mi_tuned_5turn --num_train_epochs 10 --num_turns 5
python gpt_sample.py --model_dir mi_tuned_5turn --output_dir mi_tuned_5turn --num_turns 5 --top_p 0.95
mkdir -p ../models/mi_tuned_aug
python gpt_tuning.py --output_dir mi_tuned_aug --num_train_epochs 10 --num_turns 5 --augment
python gpt_sample.py --model_dir mi_tuned_aug --output_dir mi_tuned_aug --num_turns 5 --augment --top_p 0.95
#mkdir -p ../models/mi_tuned_5turn
#python gpt_tuning.py --output_dir mi_tuned_5turn --num_train_epochs 10 --num_turns 5
#python gpt_sample.py --model_dir mi_tuned_5turn --output_dir mi_tuned_5turn --num_turns 5 --top_p 0.95
#
#mkdir -p ../models/mi_tuned_aug
#python gpt_tuning.py --output_dir mi_tuned_aug --num_train_epochs 10 --num_turns 5 --augment
#python gpt_sample.py --model_dir mi_tuned_aug --output_dir mi_tuned_aug --num_turns 5 --augment --top_p 0.95
# mkdir -p ../models/mi_tuned_keyword
#python gpt_tuning.py --output_dir mi_tuned_keyword --num_train_epochs 10 --num_turns 5 --keyword
@@ -18,4 +18,10 @@ python gpt_sample.py --model_dir mi_tuned_aug --output_dir mi_tuned_aug --num_tu
# mkdir -p ../models/mi_tuned_both
# python gpt_tuning.py --output_dir mi_tuned_both --num_train_epochs 10 --num_turns 10 --keyword --augment
# python gpt_sample.py --model_dir mi_tuned_both --output_dir mi_tuned_both --num_turns 10 --keyword --augment --top_p 0.95
mkdir -p ../models/mi_tuned_kbert
python gpt_tuning.py --output_dir mi_tuned_kbert --num_train_epochs 10 --num_turns 5 --kbert
#python gpt_sample.py --model_dir mi_tuned_5turn --output_dir mi_tuned_5turn --num_turns 5 --top_p 0.95
echo "Finished."