    #!/usr/bin/env python3
    import sys
    # sys.path.insert(0, '/home/shensq/LIT/pip_package')  # ensure the modified pytorch_transformers build is importable
    import transformers
    # assert pytorch_transformers.__file__[-36:] == 'pip_package/transformers/__init__.py'
    from transformers import GPT2LMHeadModel, GPT2Tokenizer
    import argparse
    import logging
    import pickle
    import re
    import random
    import time
    import torch
    import torch.nn.functional as F
    import numpy as np
    from torch.utils.data import Dataset, DataLoader
    from tqdm import tqdm, trange
    from rouge import Rouge
    from utils import clean_text, text_standardize, values_lexicon_encode
    
    from gpt_loader import GptDataset, collate_fn, GptDataset_aug, get_data
    
    # import nltk
    # from nltk.translate.meteor_score import meteor_score
    
    def top_k_logits(logits, k):
        """
        Masks everything but the k top entries as -infinity (-1e10), so that
        after a softmax the masked entries (e^-inf -> 0) contribute nothing
        to the normalizing denominator.
        """
        if k == 0:
            return logits
        else:
            values = torch.topk(logits, k)[0]
            batch_mins = values[:, -1].view(-1, 1).expand_as(logits)
            return torch.where(logits < batch_mins, torch.ones_like(logits) * -1e10, logits)
    
    def get_topic_keywords(meta):
        # TODO: temporary function; the lists below are hard-coded GPT-2 token ids
        keywords_up = []
        keywords_down = []
        if meta[1]=='Weight management':
            keywords_up += [6551, 4483, 2057, 9799, 4425, 4461, 4255, 5517]
            keywords_down += [46040, 21856, 2526, 13230, 7523, 15220]
        if meta[1]=='Smoking cessation':
            keywords_up += [46040, 21856, 2526, 13230, 7523, 15220]
            keywords_down += [6551, 4483, 2057, 9799, 4425, 4461, 4255, 5517]
        return keywords_up, keywords_down
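    # The hard-coded lists above are GPT-2 vocabulary ids. A sketch of how such id
    # lists could be built (assuming a loaded GPT2Tokenizer named `tokenizer`; the
    # example words are illustrative):
    #     [tokenizer.encode(' ' + w)[0] for w in ('diet', 'exercise')]
    # The leading space matters: GPT-2's BPE assigns different ids to word-initial
    # pieces with and without it.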
    
    def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
        """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
            Args:
                logits: logits distribution shape (vocabulary size)
                top_k > 0: keep only top k tokens with highest probability (top-k filtering).
                top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                    Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
            From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
        """
        assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
        top_k = min(top_k, logits.size(-1))  # Safety check
        if top_k > 0:
            # Remove all tokens with a probability less than the last token of the top-k
            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
            logits[indices_to_remove] = filter_value
    
        if top_p > 0.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    
            # Remove tokens with cumulative probability above the threshold
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift the indices to the right to keep also the first token above the threshold
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
    
            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            logits[indices_to_remove] = filter_value
        return logits
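    # Example (hypothetical values): for logits = torch.tensor([3.0, 2.0, 1.0, 0.0])
    # and top_p=0.8, softmax gives ~[0.64, 0.24, 0.09, 0.03]; the cumulative mass
    # first exceeds 0.8 at the second token, so after the one-position shift the two
    # largest logits are kept and the rest become -inf:
    #     top_k_top_p_filtering(torch.tensor([3.0, 2.0, 1.0, 0.0]), top_p=0.8)
    #     -> tensor([3., 2., -inf, -inf])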
    
    
    def sample_sequence(model, length, context, num_samples=1, temperature=1,
                            top_k=0, top_p=0.0, device='cuda', attention_mask=None):
    
        context = torch.tensor(context, dtype=torch.long, device=device)
        context = context.unsqueeze(0).repeat(num_samples, 1)
        generated = context
        # Note: the full `generated` sequence is re-fed at every step with past=None,
        # so the model's key/value cache is not actually reused here.
    
        with torch.no_grad():
            for i in trange(length):
    #             inputs = {'input_ids': generated, 'past': None, 'key_word': key_word, 'use_keyword':use_keyword}
                inputs = {'input_ids': generated, 'past': None}
                logits, past = model(**inputs)
                next_token_logits = logits[0, -1, :] / (temperature if temperature>0 else 1.)
                filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
                # if top_k > 0 or top_p > 0.0: # greedy, top_p, top_k
                if temperature == 0:
                    next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1)
                else: # temperature
                    next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
                # 50256 is the GPT-2 <|endoftext|> id: resample if it comes up as the
                # very first token so the generation does not end immediately.
                while (i == 0) and (next_token[0] == 50256):
                    next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
                generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
                if next_token[0] == 50256:  # stop once <|endoftext|> is generated
                    break
        return generated
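    # Minimal usage sketch (assumes `model` and `tokenizer` are already loaded;
    # the prompt text is hypothetical):
    #     context = tokenizer.encode("How have things been going this week?")
    #     out = sample_sequence(model, length=32, context=context, top_p=0.9)
    #     print(tokenizer.decode(out[0].tolist()))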
    
    
    
    def load_model_data(args):
        # ====== Load the fine-tuned GPT-2 model and tokenizer ========
        # Relies on the module-level USE_CUDA flag set in __main__.
        model_dir = '../models/' + args.model_dir
        model = GPT2LMHeadModel.from_pretrained(model_dir)
        if USE_CUDA:
            model.cuda()
        tokenizer = GPT2Tokenizer.from_pretrained(model_dir)
        return model, tokenizer
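    # `from_pretrained` on a local directory expects the usual Hugging Face checkpoint
    # files (config.json, pytorch_model.bin, vocab.json, merges.txt), as saved by a
    # prior fine-tuning run via model.save_pretrained / tokenizer.save_pretrained.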
    
    def run_model(args, model, tokenizer, test_loader):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.eval()
    
        if args.length == -1:
            args.length = model.config.n_ctx // 2
        elif args.length > model.config.n_ctx:
            raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)
    
        hyp = []
        ref = []
        context = []
        f = open('../result/'+args.output_dir+'.txt','w')
        f_ref = open('../result/reference_'+args.output_dir+'.txt','w')
        
        for sample in test_loader:
    
            # if args.cross_attention:
            #     x, type_x, pos_x, lm_x, x_len, meta, keyword_x = sample
            # else:
            #     x, type_x, pos_x, lm_x, x_len, meta = sample
            #     keyword_x = None
            x, type_x, pos_x, lm_x, x_len, attention_mask = sample
    
            input_len = x_len[0]  # number of tokens in the context utterances
            context_tokens = x[0][:input_len+1]  # at evaluation time the input excludes the ground-truth response
    
            for _ in range(args.nsamples // args.batch_size):
                decode_length = int(len(context_tokens))
                # if args.augment:
                #     decode_length = int(0.5 * (5/6) * len(context_tokens))
                out = sample_sequence(
                    model=model, length=decode_length,
                    context=context_tokens,
                    temperature=args.temperature, top_k=args.top_k, top_p=args.top_p,
                    device=device, attention_mask=attention_mask
                )
                out = out[:, len(context_tokens):-1].tolist()  # keep only the generated tokens and strip the trailing EOS
    
                ref.append(tokenizer.decode(x[0].tolist()[len(context_tokens):-1]))
                f_ref.write(tokenizer.decode(x[0].tolist()[len(context_tokens):-1]))
                f_ref.write('\n')
    
                hyp.append(tokenizer.decode(out[0]))
                f.write(tokenizer.decode(out[0]))
                f.write('\n')
    
                context.append(tokenizer.decode(x[0].tolist()[:len(context_tokens)]))
        f.close()
        f_ref.close()
        return hyp, ref, context
    
    def calculate_metric(hyp, ref, context, effective_length=1024):
        # ===== Calculate rouge ========
        with open('../result/rouge.txt','a') as f_result:
            rouge = Rouge()
            print(len(hyp))
            print(len(ref))
            # Filter hyp/ref/context together so the pickle dump below stays aligned.
            hyp, ref, context = zip(*[(x, y, c) for x, y, c in zip(hyp, ref, context) if len(x) > 3 and len(y) > 3])
            print(len(hyp))
            hyp = [x[:effective_length] for x in hyp]
            ref = [x[:effective_length] for x in ref]
            scores = rouge.get_scores(hyp, ref,avg=True)
            print("ROUGE",scores)
            f_result.write(time.asctime() + '\n')
            f_result.write(args.model_dir+ '\t' + str(effective_length) +'\n')
            f_result.write(str(scores))
            f_result.write('\n')
        # == dump output====
        print("#ref{} #hyp{}".format(len(ref),len(hyp)))
        with open("../data_processed/output_" + args.model_dir+'p{}k{}'.format(args.top_p,args.top_k),'wb') as f_output:
            pickle.dump(list(zip(hyp, ref, context)), f_output)  # zip objects are not picklable in Python 3
            
    #     # ====== Calculate Meteor =========
    #     meteor_sum = 0
    #     for i in range(min(len(ref),len(hyp))):
    #         meteor_sum += meteor_score([ref[i]],hyp[i])
    
    #     meteor_sum/=min(len(ref),len(hyp))
    #     print(meteor_sum)   
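    # For reference, Rouge().get_scores(hyp, ref, avg=True) returns a single dict of
    # the form {'rouge-1': {'f': ..., 'p': ..., 'r': ...}, 'rouge-2': {...},
    # 'rouge-l': {...}}, which is what calculate_metric appends to ../result/rouge.txt;
    # with avg=False (as in rouge_rank below) it returns one such dict per hyp/ref pair.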
    
    def rouge_rank(hyp, ref, context):
        rouge = Rouge()
        # import pdb;pdb.set_trace()
        # Filter hyp/ref/context together so the three lists stay aligned.
        hyp, ref, context = zip(*[(x, y, c) for x, y, c in zip(hyp, ref, context) if len(x) > 3 and len(y) > 3])
        scores = rouge.get_scores(hyp, ref, avg=False)  # one score dict per hyp/ref pair
        scores_content = zip(scores, hyp, ref, context, range(len(hyp)))
        scores_content = sorted(scores_content, key=lambda x:x[0]['rouge-1']['f'], reverse=True)
        return scores_content
    
    if __name__ == '__main__':
        USE_CUDA = torch.cuda.is_available()
        logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                            datefmt = '%m/%d/%Y %H:%M:%S',
                            level = logging.INFO)
        logger = logging.getLogger(__name__)
    
        # Parse command line arguments
        parser = argparse.ArgumentParser()
        parser.add_argument('--model_dir', type=str, default='345M_Alex', help='pretrained model name or path to local checkpoint')
        parser.add_argument("--seed", type=int, default=42)
        parser.add_argument("--nsamples", type=int, default=1)
        parser.add_argument("--batch_size", type=int, default=-1)
        parser.add_argument("--length", type=int, default=64)
        parser.add_argument("--temperature", type=float, default=1.0)
        parser.add_argument("--top_k", type=int, default=0)
        parser.add_argument("--top_p", type=float, default=0)
        parser.add_argument('--output_dir',type=str,default='generate', help="The name of the output file.")
        parser.add_argument('--modified_decoding', action='store_true')
        parser.add_argument('--augment', action='store_true')
        parser.add_argument('--special_input',type=str)
        parser.add_argument('--keyword', action='store_true')
    
        parser.add_argument('--kbert', action='store_true')
    
        parser.add_argument('--cross_attention', action='store_true')
        parser.add_argument('--num_turns', type=int, default=5)
        args = parser.parse_args()
        if args.batch_size == -1:
            args.batch_size = 1
        assert args.nsamples % args.batch_size == 0
        print(args)
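        # Example invocation (flag values are illustrative; the model directory must
        # exist under ../models/):
        #     python gpt_sample.py --model_dir 345M_Alex --top_p 0.9 --length 64 --output_dir generate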
    
        # Set random seeds for reproducibility.
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        if USE_CUDA:
            torch.cuda.manual_seed(args.seed)
    
        model, tokenizer = load_model_data(args)
    
        split_size = {'train': 0.90, 'test': 0.05, 'val': 0.05}
    
        data_loader, test_loader, val_loader = get_data(args, split_size=split_size, tokenizer=tokenizer)
        # model, tokenizer, test_loader = load_model_data(args) # TODO: this is for old get_data
        # import pdb;pdb.set_trace()
        hyp, ref, context = run_model(args, model, tokenizer, test_loader)
        sample_ranked = rouge_rank(hyp, ref, context)
        with open("../data_processed/rouge_rank_" + args.model_dir,'wb') as f:
            pickle.dump(sample_ranked, f)
        calculate_metric(hyp, ref, context)
        # calculate_metric(hyp, ref, context, 5)