From 0a58afd0f52b44358aad46e1e136ad2bd212e6de Mon Sep 17 00:00:00 2001
From: Siqi Shen <shensq@umich.edu>
Date: Tue, 10 Mar 2020 19:17:26 -0400
Subject: [PATCH] GptDataset_full now uses x_y_with_comet

In the x_y_with_comet pickle, aug is a sequence of pairs rather than a plain
string, so GptDataset_full now joins the text half (the second element) of
each pair before tokenizing. get_data loads x_y_with_comet instead of
x_y_meta_all, and a leftover pdb breakpoint is removed from gpt_tuning.py.
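
A minimal sketch of the assumed aug shape (the relation names here are
hypothetical; only a[1] is used by the new code):

    # hypothetical aug entry from x_y_with_comet: (relation, generated text)
    aug = [("xReact", "feels happy"), ("xWant", "to be heard")]
    aug = ''.join([a[1] for a in aug])  # -> "feels happyto be heard"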
---
 code/gpt_loader/load_data.py | 4 +++-
 code/gpt_tuning.py           | 2 +-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/code/gpt_loader/load_data.py b/code/gpt_loader/load_data.py
index 9c9b40f..db508be 100644
--- a/code/gpt_loader/load_data.py
+++ b/code/gpt_loader/load_data.py
@@ -354,6 +354,8 @@ class GptDataset_full(Dataset):
         keyword_all = []
         for x, y, meta, aug, keyword in x_y_meta:
             meta_all.append(meta)
+            # x_y_with_comet stores aug as pairs; keep only the text part (a[1]) of each
+            aug = ''.join([a[1] for a in aug])
             x_all.append([self.tokenizer.encode(text_standardize(x_i)) for x_i in x])
             y_all.append(self.tokenizer.encode(text_standardize(y)))
             aug_all.append(self.tokenizer.encode(text_standardize(aug)))
@@ -531,7 +533,7 @@ def get_data(args, tokenizer, split_size):
         gpt_data = GptDataset(x_y_meta, tokenizer, args.output_dir, num_turns=args.num_turns)
     elif not args.kbert:
         print("Using full data.")
-        pickle_handler = open('../data_processed/x_y_meta_all', 'rb') # TODO: change back to the old data.
+        pickle_handler = open('../data_processed/x_y_with_comet', 'rb') # TODO: change back to the old data.
         x_y_meta = pickle.load(pickle_handler)
         gpt_data = GptDataset_full(x_y_meta, tokenizer, args=args)
     else:
diff --git a/code/gpt_tuning.py b/code/gpt_tuning.py
index 7934382..697bebb 100644
--- a/code/gpt_tuning.py
+++ b/code/gpt_tuning.py
@@ -121,7 +121,7 @@ def main():
     data_loader, test_loader, val_loader = get_data(args, split_size=split_size, tokenizer=tokenizer)
     # gpt_alex = prepare_mix_review(args, tokenizer)
     # data_loader, val_loader = get_data(args, split_size=split_size, tokenizer=tokenizer) # TODO: this is for old get_data
-    import pdb;pdb.set_trace()
+
     # ========== Prepare optimizer =============
     # the gpt2 model from library has unnamed LM head. LM head's weights are tied to input embedding
     num_train_optimization_steps = len(data_loader) * args.num_train_epochs // args.train_batch_size
-- 
GitLab