Skip to content
Snippets Groups Projects
Commit 2eaec2db authored by shensq's avatar shensq
Browse files

sampling now compatible with kbert embedding

parent f3f3048d
No related branches found
No related tags found
No related merge requests found
...@@ -18,7 +18,7 @@ from torch.autograd import Variable ...@@ -18,7 +18,7 @@ from torch.autograd import Variable
from tqdm import tqdm, trange from tqdm import tqdm, trange
from rouge import Rouge from rouge import Rouge
from utils import clean_text,text_standardize,values_lexicon_encode from utils import clean_text,text_standardize,values_lexicon_encode
from gpt_loader import GptDataset, collate_fn, GptDataset_aug, GptDataset_keyword, collate_fn_keyword, get_data from gpt_loader import GptDataset, collate_fn, GptDataset_aug, get_data
# import nltk # import nltk
# from nltk.translate.meteor_score import meteor_score # from nltk.translate.meteor_score import meteor_score
...@@ -231,6 +231,7 @@ if __name__ == '__main__': ...@@ -231,6 +231,7 @@ if __name__ == '__main__':
parser.add_argument('--augment', action='store_true') parser.add_argument('--augment', action='store_true')
parser.add_argument('--special_input',type=str) parser.add_argument('--special_input',type=str)
parser.add_argument('--keyword', action='store_true') parser.add_argument('--keyword', action='store_true')
parser.add_argument('--kbert', action='store_true')
parser.add_argument('--cross_attention', action='store_true') parser.add_argument('--cross_attention', action='store_true')
parser.add_argument('--num_turns', type=int, default=5) parser.add_argument('--num_turns', type=int, default=5)
args = parser.parse_args() args = parser.parse_args()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment