diff --git a/code/gpt_loader/__init__.py b/code/gpt_loader/__init__.py
index ad82fded6eec648ab7d3fb1181f395c1f35b5126..d083d72ca7c1ef5628bb74d9f6c7dbcba0b8c9ae 100644
--- a/code/gpt_loader/__init__.py
+++ b/code/gpt_loader/__init__.py
@@ -1 +1 @@
-from .load_data import GptDataset,collate_fn,collate_fn_nli,GptDataset_nli,SnliDataset,GptDataset_aug,GptDataset_keyword, collate_fn_keyword, get_data, prepare_mix_review, update_mix_review, XLDataset_nli
+from .load_data import GptDataset,collate_fn,collate_fn_nli,SnliDataset,GptDataset_aug, collate_fn_keyword, get_data, prepare_mix_review, update_mix_review
diff --git a/code/gpt_loader/load_data.py b/code/gpt_loader/load_data.py
index c873c275f271c3b8dd218007e7c6fbd87788815a..9c9b40f5c889e2cd08a8c916bd4ceea73d4d08bd 100644
--- a/code/gpt_loader/load_data.py
+++ b/code/gpt_loader/load_data.py
@@ -6,6 +6,7 @@ import random
 import sys
 import pickle
 from tqdm import tqdm
+from collections import deque
 import copy
 sys.path.append("..")
 from utils import text_standardize
@@ -98,29 +99,6 @@ class GptDataset(Dataset):
     def __len__(self):
         return len(self.x_encoded)
 
-# class GptDataset_keyword(Dataset):
-#     def _split(self,x_y_meta):
-#         x_all = []
-#         y_all = []
-#         meta_all = []
-#         aug_all = []
-#         for x,y,meta,aug in x_y_meta:
-#             meta_all.append(meta)
-#             x_all.append([self.tokenizer.encode(text_standardize(x_i)) for x_i in x])
-#             y_all.append(self.tokenizer.encode(text_standardize(y)))
-#             key_word.append(self.tokenizer.encode(text_standardize(aug)))
-
-#         return x_all,y_all,meta_all,aug_all
-
-#     def __init__(self,x_y_meta,tokenizer,num_turns=5):
-#         self.x_y_meta = x_y_meta
-#         self.num_turns = num_turns
-#         self.tokenizer = tokenizer
-#         self.x_encoded,self.y_encoded,self.meta,self.aug_encoded = self._split(x_y_meta)
-#         self.ref_start, self.speaker1,self.speaker2,self.eos = 2,3,4,50256
-#         self.augment = 5
-#         self.keyword = 10 # '+'
-
 class GptDataset_aug(Dataset):
     def _split(self,x_y_meta):
         x_all = []
@@ -180,6 +158,7 @@ class GptDataset_aug(Dataset):
         return x,type_x,position_x,lm_x,total_input_length,self.meta[index]
     def __len__(self):
         return len(self.x_encoded)
+
 def collate_fn(data):
     """Creates mini-batch tensors from the list of tuples (src_seq, trg_seq).
     We should build a custom collate_fn rather than using default collate_fn,
@@ -293,119 +272,6 @@ def collate_fn_keyword(data):
         keyword_x = keyword_x.cuda()
     return Variable(LongTensor(src_seqs)), Variable(LongTensor(trg_seqs)), Variable(LongTensor(pos_seqs)),Variable(LongTensor(lm_seqs)), total_input_length, meta,Variable(LongTensor(keyword_x))
 
-class GptDataset_keyword(Dataset):
-    def _split(self, x_y_meta):
-        x_all = []
-        y_all = []
-        meta_all = []
-        keyword_all = []
-        for x, y, meta, keyword in x_y_meta:
-            meta_all.append(meta)
-            x_all.append([self.tokenizer.encode(text_standardize(x_i)) for x_i in x])
-            y_all.append(self.tokenizer.encode(text_standardize(y)))
-            keyword_all.append(self.tokenizer.encode(text_standardize(keyword)))
-        return x_all, y_all, meta_all, keyword_all
-
-    def __init__(self, x_y_meta, tokenizer, num_turns=5):
-
-        self.x_y_meta = x_y_meta
-        self.num_turns = num_turns
-        self.tokenizer = tokenizer
-        self.x_encoded, self.y_encoded, self.meta, self.keyword = self._split(x_y_meta)
-        self.ref_start, self.speaker1, self.speaker2, self.eos = 2, 3, 4, 50256
-
-    def __getitem__(self, index):
-        x = []
-        type_x = []
-        lm_x = []
-        is_speaker1 = bool(len(self.x_encoded[index]) % 2)  # which speaker start the conversation
-
-        for utt in self.x_encoded[index]:
-            if is_speaker1:  # add the prefix special token for each utterance
-                x += [self.speaker1]
-                type_x += [self.speaker1] * (len(utt) + 1)
-            else:
-                x += [self.speaker2]
-                type_x += [self.speaker2] * (len(utt) + 1)
-            x += utt
-            is_speaker1 = not is_speaker1
-        lm_x += [-1] * len(x)  # all position for the input is masked for loss calculation
-
-        total_input_length = len(x)
-
-        x += [self.ref_start] + self.y_encoded[index] + [self.eos]
-
-        type_x += [self.ref_start] * (len(self.y_encoded[index]) + 2)
-        lm_x += [-1] + self.y_encoded[index] + [self.eos]
-        position_x = list(range(len(x)))
-
-        x = torch.Tensor(x)
-        type_x = torch.Tensor(type_x)
-        position_x = torch.Tensor(position_x)
-        lm_x = torch.Tensor(lm_x)
-        x_len = x.shape[0]
-
-        keyword_x = [] + self.keyword[index]
-        keyword_x = torch.Tensor(keyword_x)
-        return x, type_x, position_x, lm_x, total_input_length, self.meta[index], keyword_x
-
-    def __len__(self):
-        return len(self.x_encoded)
-
-# class GptDataset_nli(GptDataset):
-#     def __init__(self, x_y_meta, tokenizer, filter_mode=None,num_turns=5,augment=True):
-#         super(GptDataset_nli, self).__init__(x_y_meta,tokenizer, num_turns=num_turns)
-#         self.augment = augment
-#         self.pos_len = len(self.x_encoded)
-
-#     def __len__(self):
-#         if self.augment:
-#             return 2 * len(self.x_encoded)
-#         else:
-#             return len(self.x_encoded)
-
-#     def __getitem__(self,index):
-#         # client utterances - premise -speaker1 
-#         # response - hypothesis - ref_start
-#         true_index = index
-#         if index >= self.pos_len:
-#             index = index - self.pos_len
-
-#         x = []
-#         type_x = []
-#         lm_x = []
-#         is_speaker1 = bool(len(self.x_encoded[index])%2) # which speaker start the conversation
-        
-#         x+=[self.speaker1]
-#         type_x += [self.speaker1]
-#         for utt in self.x_encoded[index][-self.num_turns:]:
-#             if is_speaker1: # add the prefix special token for each utterance
-#                 type_x += [self.speaker1]*(len(utt))
-#                 x += utt
-#             # else:
-#             #     x+=[self.speaker2]
-#             #     type_x += [self.speaker2]*(len(utt)+1)
-#             #     x += utt
-#             is_speaker1 = not is_speaker1
-
-#         total_input_length = len(x)
-        
-#         if true_index >= self.pos_len:
-#             rand_index = random.randint(0,self.pos_len-1)
-#             x += [self.ref_start] + self.y_encoded[rand_index] + [self.eos]
-#             type_x += [self.ref_start]*(len(self.y_encoded[rand_index])+2)
-#         else:
-#             x += [self.ref_start] + self.y_encoded[index] + [self.eos]
-#             type_x += [self.ref_start]*(len(self.y_encoded[index])+2)
-#         position_x = list(range(len(x)))
-
-#         x = torch.Tensor(x)
-#         type_x = torch.Tensor(type_x)
-#         position_x = torch.Tensor(position_x)
-#         x_len = x.shape[0]
-#         label = torch.tensor(0) if true_index>self.pos_len else torch.tensor(1)
-#         return x,type_x,position_x,lm_x, label
-
 class SnliDataset(Dataset):
     """Take a list of samples with form [[x,...],y,meta]
     """
@@ -561,151 +427,89 @@ class GptDataset_full(Dataset):
     def __len__(self):
         return len(self.x_encoded)
 
+class GptDataset_KBERT(Dataset):
+    def get_comet_aug_deque(self, comet_data, num_turns=5):
+        clause_dq = deque()
+        for comet_in, comet_out in comet_data:
+            if comet_out == "":
+                continue
+            loc = int(comet_in.split()[0])
+            if loc >= (10 - num_turns):
+                clause_dq.append((loc, comet_out))
+        return clause_dq
+
+    def __init__(self, x_y_meta, tokenizer, args):
+        self.data = x_y_meta
+        self.num_turns = args.num_turns
+        self.tokenizer = tokenizer
+        self.args = args
+        self.ref_start, self.speaker1, self.speaker2, self.eos = 2, 3, 4, 50256
+        self.augment = 5
 
-class GptDataset_nli(GptDataset_full):
-    def __init__(self, x_y_meta, tokenizer, args, infer=False):
-        super(GptDataset_nli, self).__init__(x_y_meta,tokenizer, args)
-        self.pos_len = len(self.x_encoded)
-        self.num_turns = 5
-        self.infer = infer
-    def __len__(self):
-        if self.infer:
-            return len(self.x_encoded)
-        else:
-            return 2 * len(self.x_encoded)
-
-    def __getitem__(self,index):
-        # client utterances - premise -speaker1 
-        # response - hypothesis - ref_start
-        true_index = index
-        if index >= self.pos_len:
-            index = index - self.pos_len
+        if self.args.augment:
+            print("Using augment sentences.")
+        if self.args.keyword:
+            print("Using keywords.")
 
+    def __getitem__(self, index):
         x = []
         type_x = []
         lm_x = []
-        is_speaker1 = bool(len(self.x_encoded[index])%2) # which speaker start the conversation
-        
-        x+=[self.speaker1]
-        type_x += [self.speaker1]
-        for utt in self.x_encoded[index][-self.num_turns:]:
-            if is_speaker1: # add the prefix special token for each utterance
-                type_x += [self.speaker1]*(len(utt))
-                x += utt
-            # else:
-            #     x+=[self.speaker2]
-            #     type_x += [self.speaker2]*(len(utt)+1)
-            #     x += utt
-            is_speaker1 = not is_speaker1
+        soft_position_x = []
 
-        total_input_length = len(x)
-        
-        if true_index >= self.pos_len:
-            rand_index = random.randint(0,self.pos_len-1)
-            x += [self.ref_start] + self.y_encoded[rand_index] + [self.eos]
-            type_x += [self.ref_start]*(len(self.y_encoded[rand_index])+2)
-        else:
-            x += [self.ref_start] + self.y_encoded[index] + [self.eos]
-            type_x += [self.ref_start]*(len(self.y_encoded[index])+2)
-        position_x = list(range(len(x)))
+        dq = self.get_comet_aug_deque(self.data[index][3])  # the comet info
+        context = self.data[index][0]
+        response = self.data[index][1]
 
-        x = torch.Tensor(x)
-        type_x = torch.Tensor(type_x)
-        position_x = torch.Tensor(position_x)
-        x_len = x.shape[0]
-        label = torch.tensor(0) if true_index>self.pos_len else torch.tensor(1)
+        is_speaker1 = bool(self.args.num_turns % 2)
+        soft_loc = 0  # keep tract of the location of main sentences, point to the next token to be added
+        for i in range(10 - self.args.num_turns, 10):
+            utternace_encoded = self.tokenizer.encode(text_standardize(context[i]))
 
-        return x,type_x,position_x,lm_x, label
+            # add the prefix special token for each utterance
+            if is_speaker1:
+                x += [self.speaker1]
+                type_x += [self.speaker1] * (len(utternace_encoded) + 1)
+            else:
+                x += [self.speaker2]
+                type_x += [self.speaker2] * (len(utternace_encoded) + 1)
+            x += utternace_encoded
 
-class XLDataset_nli(GptDataset_nli):
-    def __init__(self, x_y_meta, tokenizer, args, infer=False):
-        super(GptDataset_nli, self).__init__(x_y_meta,tokenizer, args)
-        self.pos_len = len(self.x_encoded)
-        self.num_turns = 5
-        self.infer = infer
-#         self.ref_start, self.speaker1,self.speaker2,self.eos = 2,3,4,50256
-        self.pad, self.sep, self.cls = 5, 4, 3
-        self.unk, self.s, self.s_bar = 0, 1, 2
+            soft_position_x += list(range(soft_loc, soft_loc + len(utternace_encoded) + 1))
 
-    def __len__(self):
-        if self.infer:
-            return len(self.x_encoded)
-        else:
-            return 2 * len(self.x_encoded)
+            # add the aug, if it is the right place
+            if len(dq) != 0 and dq[0][0] == i:
+                comet_output = dq.popleft()[1]
+                comet_encoded = self.tokenizer.encode(text_standardize(comet_output))
+                x += [self.augment] + comet_encoded
+                type_x += [self.augment] * (len(comet_encoded) + 1)
+                soft_position_x += list(range(soft_loc, soft_loc + len(comet_encoded) + 1))
 
-    def __getitem__(self,index):
-        # client utterances - premise -speaker1 
-        # response - hypothesis - ref_start
-        true_index = index
-        if index >= self.pos_len:
-            index = index - self.pos_len
-    
-        
-        x = []
-        type_x = []
-        lm_x = []
-        mask_x = []
-        is_speaker1 = bool(len(self.x_encoded[index])%2) # which speaker start the conversation
-        
-        
-        for utt in self.x_encoded[index][-self.num_turns:]:
-            if is_speaker1: # add the prefix special token for each utterance
-                type_x += [self.unk]*(len(utt))
-                x += utt
-            else:
-                type_x += [self.unk]*(len(utt))
-                x += utt
+            # update the pointer to the new seq end, add one for the delimiter token
+            soft_loc += len(utternace_encoded) + 1
             is_speaker1 = not is_speaker1
-#         import pdb;pdb.set_trace()
-        x += [self.sep]
-        type_x += [self.unk]
-        
+
+        lm_x += [-100] * len(x)  # all position for the input is masked for loss calculation
         total_input_length = len(x)
-        
-        if true_index >= self.pos_len:
-            rand_index = random.randint(0,self.pos_len-1)
-            x += self.y_encoded[rand_index] + [self.sep, self.cls]
-            type_x += [self.s]*(len(self.y_encoded[rand_index])+1) + [self.s_bar]
-        else:
-#             x += [self.ref_start] + self.y_encoded[index] + [self.eos]
-#             type_x += [self.ref_start]*(len(self.y_encoded[index])+2)
-            x += self.y_encoded[index] + [self.sep, self.cls]
-            type_x += [self.s]*(len(self.y_encoded[index])+1) + [self.s_bar]
-        
-        position_x = list(range(len(x)))
-        mask_x = [self.s] * len(x)
-        
-#         ####
-#         x = x[-100:]
-#         mask_x = mask_x[-100:]
-#         type_x = type_x[-100:]
-        
-        # left padding 
-        x = [self.pad] * (self.args.max_length-len(x)) + x[-self.args.max_length:]
-        mask_x = [self.unk] * (self.args.max_length-len(mask_x)) + mask_x[-self.args.max_length:]
-        type_x = [self.sep] * (self.args.max_length-len(type_x)) + type_x[-self.args.max_length:]
-        
-        x = torch.Tensor(x).long()
-        mask_x = torch.Tensor(mask_x).long()
-        type_x = torch.Tensor(type_x).long()
-        position_x = torch.Tensor(position_x)
-        
-        x_len = x.shape[0]
-        
-        label = torch.tensor(0) if true_index>self.pos_len else torch.tensor(1)
-        # label = torch.tensor(0) if true_index>self.pos_len else torch.tensor(0)
-        
-        if USE_CUDA:
-            x = x.cuda()
-            mask_x = mask_x.cuda()
-            type_x = type_x.cuda()
-            label = label.cuda()
-#         return x, mask_x, type_x, label, position_x, lm_x
-        return x, mask_x, type_x, label
 
+        response_encoded = self.tokenizer.encode(text_standardize(response))
+        x += [self.ref_start] + response_encoded + [self.eos]
+
+        type_x += [self.ref_start] * (len(response_encoded) + 2)
+        lm_x += [-100] + response_encoded + [self.eos]
+
+        soft_position_x += list(range(soft_loc, soft_loc + len(response_encoded) + 2))
 
+        x = torch.Tensor(x)
+        type_x = torch.Tensor(type_x)
+        soft_position_x = torch.Tensor(soft_position_x)
+        lm_x = torch.Tensor(lm_x)
+        x_len = x.shape[0]
 
+        return x, type_x, soft_position_x, lm_x, total_input_length, self.data[index][2]
 
+    def __len__(self):
+        return len(self.data)
 
 def get_data(args, tokenizer, split_size):
     """
@@ -725,11 +529,16 @@ def get_data(args, tokenizer, split_size):
         pickle_handler = open('../data_processed/' + args.special_input, 'rb')
         x_y_meta = pickle.load(pickle_handler)
         gpt_data = GptDataset(x_y_meta, tokenizer, args.output_dir, num_turns=args.num_turns)
-    else:
+    elif not args.kbert:
         print("Using full data.")
         pickle_handler = open('../data_processed/x_y_meta_all', 'rb') # TODO: change back to the old data.
         x_y_meta = pickle.load(pickle_handler)
         gpt_data = GptDataset_full(x_y_meta, tokenizer, args=args)
+    else:
+        print("Using KBERT data")
+        pickle_handler = open("../data_processed/x_y_with_comet",'rb')
+        x_y_meta = pickle.load(pickle_handler)
+        gpt_data = GptDataset_KBERT(x_y_meta, tokenizer, args=args)
     print("Dataset initialized. There are {} samples.".format(len(gpt_data)))
 
     test_size = int(len(gpt_data) * split_size['test'])
@@ -769,55 +578,4 @@ def update_mix_review(gpt_train, gpt_alex, epoch, args, mix_ratio=4, mix_decay=0
 
     data_loader = DataLoader(dataset=gpt_train+gpt_alex_active, batch_size=args.train_batch_size, shuffle=True, drop_last=True,
                                 collate_fn=collate_fn)
-    return data_loader
-
-def get_data_old(args, tokenizer, split_size):
-    """
-    Return the data loaders needed for training and evaluation.
-    :param args: command line arguments.
-    :param tokenizer: the tokenizer used in preparing the data.
-    :param split_size: the portion of train, test, validation set.
-    :return data_loader: The data loader for the training set.
-    :return val_loader: The data loader for the validation set.
-    """
-    if args.special_input:
-        print("Using mutated data.")
-        pickle_handler = open('../data_processed/' + args.special_input, 'rb')
-        x_y_meta = pickle.load(pickle_handler)
-        if args.augment:
-            print("testing keywords with augment loader.")
-            gpt_data = GptDataset_aug(x_y_meta, tokenizer, num_turns=args.num_turns)
-        else:
-            gpt_data = GptDataset(x_y_meta, tokenizer, args.output_dir, num_turns=args.num_turns)
-    elif args.augment:
-        print("Using augmented data")
-        pickle_handler = open('../data_processed/x_y_meta_aug', 'rb')
-        x_y_meta = pickle.load(pickle_handler)
-        gpt_data = GptDataset_aug(x_y_meta, tokenizer, num_turns=args.num_turns)
-    elif args.keyword:
-        print("Using keyword cross attention")
-        pickle_handler = open('../data_processed/x_y_meta_keyword', 'rb')
-        x_y_meta = pickle.load(pickle_handler)
-        gpt_data = GptDataset_keyword(x_y_meta, tokenizer)
-    else:
-        print("Using vanilla data.")
-        pickle_handler = open('../data_processed/x_y_meta', 'rb')
-        x_y_meta = pickle.load(pickle_handler)
-        gpt_data = GptDataset(x_y_meta, tokenizer, args.output_dir, num_turns=args.num_turns)
-
-    print("Dataset initialized. There are {} samples.".format(len(gpt_data)))
-    test_size = int(len(gpt_data) * split_size['test'])
-    val_size = int(len(gpt_data) * split_size['val'])
-    gpt_train, gpt_test, gpt_val = torch.utils.data.random_split(gpt_data,
-                                                                 [len(gpt_data) - test_size - val_size, test_size,
-                                                                  val_size])
-    if args.keyword:
-        data_loader = DataLoader(dataset=gpt_train, batch_size=args.train_batch_size, shuffle=True, drop_last=True,
-                                 collate_fn=collate_fn_keyword)
-        val_loader = DataLoader(dataset=gpt_val, batch_size=1, shuffle=False, drop_last=False,
-                                collate_fn=collate_fn_keyword)
-    else:
-        data_loader = DataLoader(dataset=gpt_train, batch_size=args.train_batch_size, shuffle=True, drop_last=True,
-                                 collate_fn=collate_fn)
-        val_loader = DataLoader(dataset=gpt_val, batch_size=1, shuffle=False, drop_last=False, collate_fn=collate_fn)
-    return data_loader, val_loader
+    return data_loader
\ No newline at end of file
diff --git a/code/gpt_tuning.py b/code/gpt_tuning.py
index 550ceb50727393c377b22f4e9cf5020e5fe418ae..7934382fd9cfdf5837d51d14a4a3d4130dc4b861 100644
--- a/code/gpt_tuning.py
+++ b/code/gpt_tuning.py
@@ -16,7 +16,7 @@ from torch.autograd import Variable
 from tqdm import tqdm, trange
 import random
 from utils import clean_text, text_standardize, construct_grouped_parameters, get_unfreezing_funcs
-from gpt_loader import GptDataset, collate_fn, GptDataset_aug, GptDataset_keyword, collate_fn_keyword, prepare_mix_review, update_mix_review, get_data
+from gpt_loader import GptDataset, collate_fn,collate_fn_keyword, prepare_mix_review, update_mix_review, get_data
 
 # OPTIONAL: if you want to have more information on what's happening, activate the logger as follows
 import logging
@@ -76,6 +76,7 @@ def parse_arguments():
     parser.add_argument('--use_disc_lr', action='store_true')
     parser.add_argument('--use_unfreezing', action='store_true')
     parser.add_argument('--num_turns', type=int, default=5)
+    parser.add_argument('--kbert', action='store_true')
     args = parser.parse_args()
     print(args)
     return args
@@ -91,11 +92,11 @@ def load_model(args):
     # ====== Load GPT2 model ========
     model_dir = '../models/' + args.model_dir
 #     model = GPT2LMHeadModel.from_pretrained(model_dir)
-    model = GPT2LMHeadModel.from_pretrained('gpt2-medium')
+    model = GPT2LMHeadModel.from_pretrained('gpt2')
     if USE_CUDA:
         model.cuda()
 #     tokenizer = GPT2Tokenizer.from_pretrained(model_dir)
-    tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
+    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
     print('Model loaded.')
     return model, tokenizer
 
@@ -120,7 +121,7 @@ def main():
     data_loader, test_loader, val_loader = get_data(args, split_size=split_size, tokenizer=tokenizer)
     # gpt_alex = prepare_mix_review(args, tokenizer)
     # data_loader, val_loader = get_data(args, split_size=split_size, tokenizer=tokenizer) # TODO: this is for old get_data
-    # import pdb;pdb.set_trace()
+    import pdb;pdb.set_trace()
     # ========== Prepare optimizer =============
     # the gpt2 model from library has unnamed LM head. LM head's weights are tied to input embedding
     num_train_optimization_steps = len(data_loader) * args.num_train_epochs // args.train_batch_size
@@ -153,7 +154,6 @@ def main():
                 x, type_x, pos_x, lm_x, x_len, _ = sample
                 keyword_x = None
             input_len = x_len[0]
-
             lm_x[:, x_len[0] + 1 + args.first_K_tokens:-1] = -1
 #             loss = model(x, position_ids=pos_x, token_type_ids=type_x, labels=lm_x, key_word=keyword_x,
 #                          use_keyword=args.cross_attention)[0]
diff --git a/code/run_compare_aug.sh b/code/run_compare_aug.sh
index 169590b5265b7c2a7a50acdb610c34232dc3a3dc..df718d7f4e429a195564da5632818ea7ca3f8aaa 100644
--- a/code/run_compare_aug.sh
+++ b/code/run_compare_aug.sh
@@ -3,13 +3,13 @@ pwd
 
 # python retrieve_candidate.py --model_dir mi_nli
 
-mkdir -p ../models/mi_tuned_5turn
-python gpt_tuning.py --output_dir mi_tuned_5turn --num_train_epochs 10 --num_turns 5
-python gpt_sample.py --model_dir mi_tuned_5turn --output_dir mi_tuned_5turn --num_turns 5 --top_p 0.95
-
-mkdir -p ../models/mi_tuned_aug
-python gpt_tuning.py --output_dir mi_tuned_aug --num_train_epochs 10 --num_turns 5 --augment
-python gpt_sample.py --model_dir mi_tuned_aug --output_dir mi_tuned_aug --num_turns 5 --augment  --top_p 0.95
+#mkdir -p ../models/mi_tuned_5turn
+#python gpt_tuning.py --output_dir mi_tuned_5turn --num_train_epochs 10 --num_turns 5
+#python gpt_sample.py --model_dir mi_tuned_5turn --output_dir mi_tuned_5turn --num_turns 5 --top_p 0.95
+#
+#mkdir -p ../models/mi_tuned_aug
+#python gpt_tuning.py --output_dir mi_tuned_aug --num_train_epochs 10 --num_turns 5 --augment
+#python gpt_sample.py --model_dir mi_tuned_aug --output_dir mi_tuned_aug --num_turns 5 --augment  --top_p 0.95
 
 # mkdir -p ../models/mi_tuned_keyword
 #python gpt_tuning.py --output_dir mi_tuned_keyword --num_train_epochs 10 --num_turns 5 --keyword
@@ -18,4 +18,10 @@ python gpt_sample.py --model_dir mi_tuned_aug --output_dir mi_tuned_aug --num_tu
 # mkdir -p ../models/mi_tuned_both
 # python gpt_tuning.py --output_dir mi_tuned_both --num_train_epochs 10 --num_turns 10 --keyword --augment
 # python gpt_sample.py --model_dir mi_tuned_both --output_dir mi_tuned_both --num_turns 10 --keyword --augment --top_p 0.95
+
+mkdir -p ../models/mi_tuned_kbert
+python gpt_tuning.py --output_dir mi_tuned_kbert --num_train_epochs 10 --num_turns 5 --kbert
+#python gpt_sample.py --model_dir mi_tuned_5turn --output_dir mi_tuned_5turn --num_turns 5 --top_p 0.95
+
 echo "Finished."
+