Skip to content
Snippets Groups Projects
Commit e8b32c57 authored by DeepLearning VM's avatar DeepLearning VM
Browse files
parents b36eedad 892a8c4a
No related branches found
No related tags found
No related merge requests found
......@@ -84,11 +84,11 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1,
generated = context
prev = context
past = None
import pdb;pdb.set_trace()
with torch.no_grad():
for i in trange(length):
# inputs = {'input_ids': generated, 'past': None, 'key_word': key_word, 'use_keyword':use_keyword}
inputs = {'input_ids': generated, 'past': None}
inputs = {'input_ids': generated, 'past': None, 'attention_mask':attention_mask}
logits, past = model(**inputs)
next_token_logits = logits[0, -1, :] / (temperature if temperature>0 else 1.)
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
......@@ -138,6 +138,7 @@ def run_model(args, model, tokenizer, test_loader):
# else:
# x, type_x, pos_x, lm_x, x_len, meta = sample
# keyword_x = None
import pdb;pdb.set_trace()
x, type_x, pos_x, lm_x, x_len, attention_mask = sample
input_len = x_len[0] # The number of tokens of the context utterances
context_tokens = x[0][:input_len+1] # at evaluation stage, the input is without the ground truth
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment