Skip to content
Snippets Groups Projects
Commit 0a58afd0 authored by shensq's avatar shensq
Browse files

Gptdataset_full now use x_y_with_comet

parent 2eaec2db
No related branches found
No related tags found
No related merge requests found
......@@ -354,6 +354,8 @@ class GptDataset_full(Dataset):
keyword_all = []
for x, y, meta, aug, keyword in x_y_meta:
meta_all.append(meta)
# update for the new data format
aug = ''.join([a[1] for a in aug])
x_all.append([self.tokenizer.encode(text_standardize(x_i)) for x_i in x])
y_all.append(self.tokenizer.encode(text_standardize(y)))
aug_all.append(self.tokenizer.encode(text_standardize(aug)))
......@@ -531,7 +533,7 @@ def get_data(args, tokenizer, split_size):
gpt_data = GptDataset(x_y_meta, tokenizer, args.output_dir, num_turns=args.num_turns)
elif not args.kbert:
print("Using full data.")
pickle_handler = open('../data_processed/x_y_meta_all', 'rb') # TODO: change back to the old data.
pickle_handler = open('../data_processed/x_y_with_comet', 'rb') # TODO: change back to the old data.
x_y_meta = pickle.load(pickle_handler)
gpt_data = GptDataset_full(x_y_meta, tokenizer, args=args)
else:
......
......@@ -121,7 +121,7 @@ def main():
data_loader, test_loader, val_loader = get_data(args, split_size=split_size, tokenizer=tokenizer)
# gpt_alex = prepare_mix_review(args, tokenizer)
# data_loader, val_loader = get_data(args, split_size=split_size, tokenizer=tokenizer) # TODO: this is for old get_data
import pdb;pdb.set_trace()
# ========== Prepare optimizer =============
# the gpt2 model from library has unnamed LM head. LM head's weights are tied to input embedding
num_train_optimization_steps = len(data_loader) * args.num_train_epochs // args.train_batch_size
......
0% Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment