Skip to content
Snippets Groups Projects
Commit e8b32c57 authored by DeepLearning VM's avatar DeepLearning VM
Browse files
parents b36eedad 892a8c4a
No related branches found
No related tags found
No related merge requests found
...@@ -84,11 +84,11 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, ...@@ -84,11 +84,11 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1,
generated = context generated = context
prev = context prev = context
past = None past = None
import pdb;pdb.set_trace()
with torch.no_grad(): with torch.no_grad():
for i in trange(length): for i in trange(length):
# inputs = {'input_ids': generated, 'past': None, 'key_word': key_word, 'use_keyword':use_keyword} # inputs = {'input_ids': generated, 'past': None, 'key_word': key_word, 'use_keyword':use_keyword}
inputs = {'input_ids': generated, 'past': None} inputs = {'input_ids': generated, 'past': None, 'attention_mask':attention_mask}
logits, past = model(**inputs) logits, past = model(**inputs)
next_token_logits = logits[0, -1, :] / (temperature if temperature>0 else 1.) next_token_logits = logits[0, -1, :] / (temperature if temperature>0 else 1.)
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p) filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
...@@ -138,6 +138,7 @@ def run_model(args, model, tokenizer, test_loader): ...@@ -138,6 +138,7 @@ def run_model(args, model, tokenizer, test_loader):
# else: # else:
# x, type_x, pos_x, lm_x, x_len, meta = sample # x, type_x, pos_x, lm_x, x_len, meta = sample
# keyword_x = None # keyword_x = None
import pdb;pdb.set_trace()
x, type_x, pos_x, lm_x, x_len, attention_mask = sample x, type_x, pos_x, lm_x, x_len, attention_mask = sample
input_len = x_len[0] # The number of tokens of the context utterances input_len = x_len[0] # The number of tokens of the context utterances
context_tokens = x[0][:input_len+1] # at evaluation stage, the input is without the ground truth context_tokens = x[0][:input_len+1] # at evaluation stage, the input is without the ground truth
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment