diff --git a/code/gpt_sample.py b/code/gpt_sample.py
index c439b2593896f45407d7b62caec6126533ad4b68..019238abdad569fab2b57e794b957af776342233 100644
--- a/code/gpt_sample.py
+++ b/code/gpt_sample.py
@@ -18,7 +18,7 @@ from torch.autograd import Variable
 from tqdm import tqdm, trange
 from rouge import Rouge 
 from utils import clean_text,text_standardize,values_lexicon_encode
-from gpt_loader import GptDataset, collate_fn, GptDataset_aug, GptDataset_keyword, collate_fn_keyword, get_data
+from gpt_loader import GptDataset, collate_fn, GptDataset_aug, get_data
 # import nltk
 # from nltk.translate.meteor_score import meteor_score
 
@@ -231,6 +231,7 @@ if __name__ == '__main__':
     parser.add_argument('--augment', action='store_true')
     parser.add_argument('--special_input',type=str)
     parser.add_argument('--keyword', action='store_true')
+    parser.add_argument('--kbert', action='store_true')
     parser.add_argument('--cross_attention', action='store_true')
     parser.add_argument('--num_turns', type=int, default=5)
     args = parser.parse_args()