From 2eaec2db12d1ffef63e1a68a9f2d3f5f10931306 Mon Sep 17 00:00:00 2001 From: Siqi Shen <shensq@umich.edu> Date: Tue, 10 Mar 2020 19:01:09 -0400 Subject: [PATCH] sampling now compatible with kbert embedding --- code/gpt_sample.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/code/gpt_sample.py b/code/gpt_sample.py index c439b25..019238a 100644 --- a/code/gpt_sample.py +++ b/code/gpt_sample.py @@ -18,7 +18,7 @@ from torch.autograd import Variable from tqdm import tqdm, trange from rouge import Rouge from utils import clean_text,text_standardize,values_lexicon_encode -from gpt_loader import GptDataset, collate_fn, GptDataset_aug, GptDataset_keyword, collate_fn_keyword, get_data +from gpt_loader import GptDataset, collate_fn, GptDataset_aug, get_data # import nltk # from nltk.translate.meteor_score import meteor_score @@ -231,6 +231,7 @@ if __name__ == '__main__': parser.add_argument('--augment', action='store_true') parser.add_argument('--special_input',type=str) parser.add_argument('--keyword', action='store_true') + parser.add_argument('--kbert', action='store_true') parser.add_argument('--cross_attention', action='store_true') parser.add_argument('--num_turns', type=int, default=5) args = parser.parse_args() -- GitLab