#!/usr/bin/env python3
import sys
# sys.path.insert(0,'/home/shensq/LIT/pip_package') # make sure the modified version of pytorch_transformer
import transformers
# assert pytorch_transformers.__file__[-36:]=='pip_package/transformers/__init__.py'
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import argparse
import logging
import pickle
import re
from tqdm import trange
import random 
import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset,DataLoader
from torch.autograd import Variable
from tqdm import tqdm, trange
from rouge import Rouge 
from utils import clean_text,text_standardize,values_lexicon_encode
from gpt_loader import GptDataset, collate_fn, GptDataset_aug, get_data
# import nltk
# from nltk.translate.meteor_score import meteor_score

def top_k_logits(logits, k):
    """
    Keep the k largest logits of every row and push all others to -1e10.

    -1e10 acts as -infinity for a softmax: exp(-1e10) underflows to 0, so the
    masked entries add nothing to the normalizing denominator.

    Args:
        logits: tensor assumed to be 2-D (rows x vocabulary); the per-row
            threshold is read with ``[:, -1]`` on the top-k values.
        k: number of entries kept per row. ``0`` disables masking.

    Returns:
        A new tensor of the same shape, or ``logits`` itself when ``k == 0``.
    """
    if k == 0:
        return logits
    top_values, _ = torch.topk(logits, k)
    # Smallest surviving value of each row, broadcast across that row.
    row_thresholds = top_values[:, -1].view(-1, 1).expand_as(logits)
    floor = torch.ones_like(logits) * -1e10
    return torch.where(logits < row_thresholds, floor, logits)

def get_topic_keywords(meta):
    """
    Return (boosted, suppressed) token-id lists for the topic in ``meta[1]``.

    For 'Weight management' the weight-related ids are boosted and the
    smoking-related ids suppressed; 'Smoking cessation' swaps the two lists.
    Any other topic yields two empty lists. Fresh lists are built on every
    call, so callers may mutate the result.
    """
    # TODO: temporary function
    weight_ids = [6551, 4483, 2057, 9799, 4425, 4461, 4255, 5517]
    smoking_ids = [46040, 21856, 2526, 13230, 7523, 15220]
    topic = meta[1]
    if topic == 'Weight management':
        return weight_ids, smoking_ids
    if topic == 'Smoking cessation':
        return smoking_ids, weight_ids
    return [], []

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a 1-D logits distribution with top-k and/or nucleus (top-p) filtering.

        The input tensor is modified in place and also returned.

        Args:
            logits: 1-D logits tensor (a single distribution over the vocabulary).
            top_k: if > 0, keep only the ``top_k`` highest logits (clamped to the
                vocabulary size).
            top_p: if > 0.0, keep the smallest set of highest-probability tokens
                whose cumulative probability exceeds ``top_p`` (nucleus filtering,
                Holtzman et al., http://arxiv.org/abs/1904.09751).
            filter_value: value written into every removed position.
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    k = min(top_k, logits.size(-1))  # never ask topk for more entries than exist
    if k > 0:
        kth_best = torch.topk(logits, k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value

    if top_p > 0.0:
        ordered_logits, order = torch.sort(logits, descending=True)
        cum_probs = F.softmax(ordered_logits, dim=-1).cumsum(dim=-1)

        drop_sorted = cum_probs > top_p
        # Shift right by one so the first token that crosses the threshold survives,
        # and the most likely token is always kept.
        drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
        drop_sorted[..., 0] = False

        logits[order[drop_sorted]] = filter_value
    return logits

def sample_sequence(model, length, context, num_samples=1, temperature=1,
                        top_k=0, top_p=0.0, device='cuda', attention_mask=None):
    """
    Autoregressively extend ``context`` by at most ``length`` tokens.

    At every step the whole sequence generated so far is fed back into the
    model (no key/value cache is reused). The last-position logits are divided
    by ``temperature``, filtered with ``top_k_top_p_filtering``, and one token
    is chosen: by argmax when ``temperature == 0``, otherwise by sampling.
    Generation stops early once the end-of-text token (id 50256) is produced;
    that token is kept in the output. End-of-text is never chosen as the first
    generated token.

    Args:
        model: callable taking ``input_ids``, ``past`` and ``attention_mask``
            keyword arguments and returning ``(logits, past)``.
        length: maximum number of new tokens.
        context: prompt token ids (list or 1-D tensor).
        num_samples: number of rows in the output. Only row 0's logits drive
            decoding and one token is appended per step, so values other than
            1 fail at ``torch.cat``.
        temperature: logits divisor; ``0`` selects greedy decoding.
        top_k, top_p: forwarded to ``top_k_top_p_filtering``.
        device: device for the prompt ids and the attention mask.
        attention_mask: optional mask for the prompt region. It is copied into
            the top-left corner of a lower-triangular (causal) mask that also
            covers the generated positions. When ``None``, the model receives
            ``attention_mask=None``.

    Returns:
        Long tensor of shape ``(num_samples, len(context) + n_new)``: the prompt
        followed by the generated tokens.
    """
    eos_id = 50256  # GPT-2 '<|endoftext|>'
    neg_inf = -float('Inf')

    # as_tensor accepts both lists and tensors (torch.tensor warns on tensors).
    prompt = torch.as_tensor(context, dtype=torch.long, device=device)
    generated = prompt.unsqueeze(0).repeat(num_samples, 1)
    prompt_len = generated.shape[-1]

    full_mask = None
    if attention_mask is not None:
        attention_size = attention_mask.shape[-1]
        # Big enough for the prompt mask and every position this call can reach
        # (never smaller than 512), so long inputs cannot overflow the mask.
        mask_size = max(512, attention_size, prompt_len + length)
        full_mask = torch.tril(torch.ones(mask_size, mask_size, dtype=attention_mask.dtype))
        full_mask = full_mask.view(1, 1, mask_size, mask_size)
        full_mask[:, :, :attention_size, :attention_size] = attention_mask
        # Must live on the same device as the input ids and the model.
        full_mask = full_mask.to(device)

    with torch.no_grad():
        for step in trange(length):
            current_length = generated.shape[-1]
            step_mask = None
            if full_mask is not None:
                step_mask = full_mask[:, :, :current_length, :current_length]
            logits, _ = model(input_ids=generated, past=None, attention_mask=step_mask)

            next_token_logits = logits[0, -1, :] / (temperature if temperature > 0 else 1.)
            # top_k_top_p_filtering works in place; keep a copy for the fallback below.
            unfiltered = next_token_logits.clone() if step == 0 else None
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)

            if step == 0:
                # Forbid an empty response. Masking EOS draws from the same distribution
                # as redrawing until a non-EOS token appears, but cannot loop forever
                # when filtering left EOS as the only candidate (e.g. top_k=1). In that
                # case fall back to the unfiltered logits without EOS.
                filtered_logits[eos_id] = neg_inf
                if bool((filtered_logits == neg_inf).all()):
                    unfiltered[eos_id] = neg_inf
                    filtered_logits = unfiltered

            if temperature == 0:
                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1)
            else:
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)

            generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
            if next_token[0].item() == eos_id:
                break
    return generated



def load_model_data(args):
    """
    Load the GPT-2 LM-head model and its tokenizer from ``'../models/' + args.model_dir``.

    The model is moved to the GPU when CUDA is available. This is the same check
    ``run_model`` uses to pick its device, so the model and the inputs agree.

    Args:
        args: namespace providing ``model_dir``.

    Returns:
        ``(model, tokenizer)``.
    """
    model_dir = '../models/' + args.model_dir
    model = GPT2LMHeadModel.from_pretrained(model_dir)
    # Check CUDA here instead of reading the USE_CUDA global, which exists only
    # when this file runs as a script (importing it would raise NameError).
    if torch.cuda.is_available():
        model.cuda()
    tokenizer = GPT2Tokenizer.from_pretrained(model_dir)
    return model, tokenizer

def run_model(args, model, tokenizer, test_loader):
    """
    Generate responses for the test set and write hypotheses/references to disk.

    For each batch from ``test_loader`` only the first example (``x[0]``) is used:
    - its first ``x_len[0] + 1`` tokens form the prompt;
    - the tokens after the prompt, minus the final one, form the reference
      (presumably the final one is EOS — confirm against the data loader).

    Each example is generated ``args.nsamples // args.batch_size`` times, with a
    maximum of ``len(prompt)`` new tokens per generation. Decoded hypotheses go
    one per line to ``../result/<output_dir>.txt``. The matching references go to
    ``../result/reference_<output_dir>.txt``. Both files are written as UTF-8.

    Args:
        args: namespace providing ``length``, ``nsamples``, ``batch_size``,
            ``output_dir``, ``temperature``, ``top_k`` and ``top_p``.
            ``args.length`` is normalized (``-1`` -> ``n_ctx // 2``) and
            validated, but the decode length actually used is the prompt length.
        model: GPT-2 LM-head model (put into eval mode here).
        tokenizer: tokenizer used to decode token ids.
        test_loader: iterable of ``(x, type_x, pos_x, lm_x, x_len, attention_mask)``.

    Returns:
        ``(hyp, ref, context)``: three equally long lists of decoded strings.

    Raises:
        ValueError: if ``args.length`` exceeds ``model.config.n_ctx``.
    """
    eos_id = 50256  # GPT-2 '<|endoftext|>'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()

    if args.length == -1:
        args.length = model.config.n_ctx // 2
    elif args.length > model.config.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)

    rounds = args.nsamples // args.batch_size
    hyp, ref, context = [], [], []
    hyp_path = '../result/' + args.output_dir + '.txt'
    ref_path = '../result/reference_' + args.output_dir + '.txt'

    # ``with`` guarantees both files are closed even if generation fails midway.
    with open(hyp_path, 'w', encoding='utf-8') as f, \
            open(ref_path, 'w', encoding='utf-8') as f_ref:
        for sample in test_loader:
            x, _type_x, _pos_x, _lm_x, x_len, attention_mask = sample
            input_len = x_len[0]  # number of tokens of the context utterances
            context_tokens = x[0][:input_len + 1]  # prompt only: ground truth excluded
            prompt_len = len(context_tokens)

            # Reference and context are the same for every round; decode once.
            all_tokens = x[0].tolist()
            ref_text = tokenizer.decode(all_tokens[prompt_len:-1])
            context_text = tokenizer.decode(all_tokens[:prompt_len])

            for _ in range(rounds):
                out = sample_sequence(
                    model=model, length=prompt_len,
                    context=context_tokens,
                    temperature=args.temperature, top_k=args.top_k, top_p=args.top_p,
                    device=device, attention_mask=attention_mask,
                )
                new_tokens = out[0, prompt_len:].tolist()
                # Strip EOS only when one was generated: a run that reached the length
                # limit ends in a real token that must not be discarded.
                if new_tokens and new_tokens[-1] == eos_id:
                    new_tokens = new_tokens[:-1]
                hyp_text = tokenizer.decode(new_tokens)

                ref.append(ref_text)
                f_ref.write(ref_text)
                f_ref.write('\n')

                hyp.append(hyp_text)
                f.write(hyp_text)
                f.write('\n')

                context.append(context_text)
    return hyp, ref, context

def calculate_metric(hyp, ref, context, effective_length=1024):
    """
    Compute averaged ROUGE, append it to ``../result/rouge.txt``, and pickle the triples.

    Pairs whose hypothesis or reference has 3 characters or fewer are dropped.
    The matching context is dropped with them, so the triples stay aligned. The
    remaining hypotheses and references are truncated to ``effective_length``
    characters before scoring. The averaged scores are appended to
    ``../result/rouge.txt`` together with a timestamp, ``args.model_dir`` and
    ``effective_length``. A list of ``(hyp, ref, context)`` tuples is pickled to
    ``../data_processed/output_<model_dir>p<top_p>k<top_k>``.

    NOTE: reads the module-level ``args`` (``model_dir``, ``top_p``, ``top_k``),
    which only exists when this file runs as a script.

    Raises:
        ValueError: if no hypothesis/reference pair survives the length filter.
    """
    import time

    # ===== Calculate rouge ========
    with open('../result/rouge.txt', 'a') as f_result:
        rouge = Rouge()
        print(len(hyp))
        print(len(ref))
        kept = [(h, r, c) for h, r, c in zip(hyp, ref, context) if len(h) > 3 and len(r) > 3]
        if not kept:
            raise ValueError("no hypothesis/reference pair is longer than 3 characters; "
                             "cannot compute ROUGE")
        print(len(kept))
        hyp = [h[:effective_length] for h, _, _ in kept]
        ref = [r[:effective_length] for _, r, _ in kept]
        context = [c for _, _, c in kept]
        scores = rouge.get_scores(hyp, ref, avg=True)
        print("ROUGE", scores)
        f_result.write(time.asctime() + '\n')
        f_result.write(args.model_dir + '\t' + str(effective_length) + '\n')
        f_result.write(str(scores))
        f_result.write('\n')

    # == dump output====
    print("#ref{} #hyp{}".format(len(ref), len(hyp)))
    with open("../data_processed/output_" + args.model_dir + 'p{}k{}'.format(args.top_p, args.top_k), 'wb') as f_output:
        # Materialize the triples: pickling a zip object stores a one-shot iterator.
        pickle.dump(list(zip(hyp, ref, context)), f_output)
        
#     # ====== Calculate Meteor =========
#     meteor_sum = 0
#     for i in range(min(len(ref),len(hyp))):
#         meteor_sum += meteor_score([ref[i]],hyp[i])

#     meteor_sum/=min(len(ref),len(hyp))
#     print(meteor_sum)   

def rouge_rank(hyp, ref, context):
    """
    Score every hypothesis against its reference and rank by ROUGE-1 F1.

    Pairs whose hypothesis or reference has 3 characters or fewer are skipped.
    Their context is skipped too, so every tuple stays aligned.

    Returns:
        A list of ``(scores, hyp, ref, context, index)`` tuples, sorted by
        ``scores['rouge-1']['f']`` in descending order. ``scores`` is the
        per-pair dict from ``Rouge.get_scores(..., avg=False)``. ``index`` is the
        position of the pair in the input lists, which is also its line in the
        result files written by ``run_model``. Returns ``[]`` when no pair
        survives the filter.
    """
    kept = [(h, r, c, idx)
            for idx, (h, r, c) in enumerate(zip(hyp, ref, context))
            if len(h) > 3 and len(r) > 3]
    if not kept:
        return []
    scores = Rouge().get_scores([k[0] for k in kept], [k[1] for k in kept], avg=False)  # type: list
    ranked = [(score, h, r, c, idx) for score, (h, r, c, idx) in zip(scores, kept)]
    return sorted(ranked, key=lambda item: item[0]['rouge-1']['f'], reverse=True)

if __name__ == '__main__':
    # Script entry point: generate responses for the test split, rank them by
    # ROUGE and record the aggregate scores. ``USE_CUDA`` and ``args`` stay
    # module-level names because ``calculate_metric`` reads ``args``.
    USE_CUDA = torch.cuda.is_available()
    logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                        datefmt = '%m/%d/%Y %H:%M:%S',
                        level = logging.INFO)
    logger = logging.getLogger(__name__)

    # Parse command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, default='345M_Alex', help='pretrained model name or path to local checkpoint')
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--nsamples", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=-1)
    parser.add_argument("--length", type=int, default=64)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument("--top_p", type=float, default=0)
    parser.add_argument('--output_dir',type=str,default='generate', help="The name of the output file.")
    parser.add_argument('--modified_decoding', action='store_true')
    parser.add_argument('--augment', action='store_true')
    parser.add_argument('--special_input',type=str)
    parser.add_argument('--keyword', action='store_true')
    parser.add_argument('--kbert', action='store_true')
    parser.add_argument('--cross_attention', action='store_true')
    parser.add_argument('--num_turns', type=int, default=5)
    args = parser.parse_args()
    if args.batch_size == -1:
        args.batch_size = 1
    # Validate with parser.error rather than assert: asserts are stripped under
    # ``python -O``, and batch_size == 0 would otherwise raise ZeroDivisionError.
    if args.batch_size <= 0:
        parser.error("--batch_size must be a positive integer (or -1 for the default of 1)")
    if args.nsamples % args.batch_size != 0:
        parser.error("--nsamples must be a multiple of --batch_size")
    print(args)

    # Setup the random seeds.
    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.manual_seed(args.seed)

    model, tokenizer = load_model_data(args)
    split_size = {'train': 0.90, 'test': 0.05, 'val': 0.05}
    data_loader, test_loader, val_loader = get_data(args, split_size=split_size, tokenizer=tokenizer)
    hyp, ref, context = run_model(args, model, tokenizer, test_loader)
    sample_ranked = rouge_rank(hyp, ref, context)
    with open("../data_processed/rouge_rank_" + args.model_dir,'wb') as f:
        pickle.dump(sample_ranked, f)
    calculate_metric(hyp, ref, context)