From 2eaec2db12d1ffef63e1a68a9f2d3f5f10931306 Mon Sep 17 00:00:00 2001
From: Siqi Shen <shensq@umich.edu>
Date: Tue, 10 Mar 2020 19:01:09 -0400
Subject: [PATCH] sampling now compatible with kbert embedding

---
 code/gpt_sample.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/code/gpt_sample.py b/code/gpt_sample.py
index c439b25..019238a 100644
--- a/code/gpt_sample.py
+++ b/code/gpt_sample.py
@@ -18,7 +18,7 @@ from torch.autograd import Variable
 from tqdm import tqdm, trange
 from rouge import Rouge 
 from utils import clean_text,text_standardize,values_lexicon_encode
-from gpt_loader import GptDataset, collate_fn, GptDataset_aug, GptDataset_keyword, collate_fn_keyword, get_data
+from gpt_loader import GptDataset, collate_fn, GptDataset_aug, get_data
 # import nltk
 # from nltk.translate.meteor_score import meteor_score
 
@@ -231,6 +231,7 @@ if __name__ == '__main__':
     parser.add_argument('--augment', action='store_true')
     parser.add_argument('--special_input',type=str)
     parser.add_argument('--keyword', action='store_true')
+    parser.add_argument('--kbert', action='store_true')
     parser.add_argument('--cross_attention', action='store_true')
     parser.add_argument('--num_turns', type=int, default=5)
     args = parser.parse_args()
-- 
GitLab