import torch
from torch.utils.data import Dataset,DataLoader
from torch.autograd import Variable
import json
import random
import sys
import pickle
from tqdm import tqdm
from collections import deque
import copy
sys.path.append("..")
from utils import text_standardize


USE_CUDA = torch.cuda.is_available()
FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor
ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor


# ==== Code for data loading =====
class GptDataset(Dataset):
    """Take a list of samples with form [[x,...],y,meta]
    """
    # need 3 special tokens
    # # as <ref start> 2
    # $ as <speaker1> 3
    # % as <speaker2> 4
    # '<|endoftext|>' as <eos> 50256
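    #
    # Illustrative layout (a sketch, not executed): with num_turns=3 and
    # utterances u1, u2, u3 plus reference y, __getitem__ builds
    #   x      = [$, u1..., %, u2..., $, u3..., #, y..., <|endoftext|>]
    #   type_x = [3]*(len(u1)+1) + [4]*(len(u2)+1) + [3]*(len(u3)+1) + [2]*(len(y)+2)
    #   lm_x   = [-1]*total_input_length + [-1] + y + [<|endoftext|>]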
    def _split(self,x_y_meta):
        x_all = []
        y_all = []
        meta_all = []
        for x,y,meta in x_y_meta:
            meta_all.append(meta)
            x_all.append([self.tokenizer.encode(text_standardize(x_i)) for x_i in x])
            y_all.append(self.tokenizer.encode(text_standardize(y)))

        return x_all,y_all,meta_all
    
    def _filter(self,x_all,y_all,meta_all,filter_mode=None):
        allowed_pattern = ['SR_only','CR_only','Smoking_only','Diet_only']
        data = zip(x_all,y_all,meta_all)
        if filter_mode not in allowed_pattern:
            data_filt = data
        if filter_mode=='SR_only':
            data_filt = [x for x in data if x[2][2]=='SR']
        if filter_mode=='CR_only':
            data_filt = [x for x in data if x[2][2]=='CR']
        if filter_mode=='Smoking_only':
            data_filt = [x for x in data if x[2][1]=='Smoking cessation']
        if filter_mode=='Diet_only':
            data_filt = [x for x in data if x[2][1]=='Weight management']
        x_filt,y_filt,meta_filt = zip(*data_filt)
        return x_filt, y_filt, meta_filt

    def __init__(self,x_y_meta,tokenizer,filter_mode=None,num_turns=5):
        
        self.x_y_meta = x_y_meta
        self.num_turns = num_turns
        self.tokenizer = tokenizer
        self.x_encoded,self.y_encoded,self.meta = self._split(x_y_meta)
        self.x_encoded,self.y_encoded,self.meta = self._filter(self.x_encoded,self.y_encoded,self.meta,filter_mode)
        self.ref_start, self.speaker1,self.speaker2,self.eos = 2,3,4,50256

    def __getitem__(self,index):
        x = []
        type_x = []
        lm_x = []
        is_speaker1 = bool(self.num_turns % 2) # which speaker starts the conversation

        for utt in self.x_encoded[index][-self.num_turns:]:
            if is_speaker1: # add the prefix special token for each utterance
                x+=[self.speaker1]
                type_x += [self.speaker1]*(len(utt)+1)
            else:
                x+=[self.speaker2]
                type_x += [self.speaker2]*(len(utt)+1)
            x += utt
            is_speaker1 = not is_speaker1
        lm_x += [-1]*len(x) # all input positions are masked out of the loss calculation

        total_input_length = len(x)

        x += [self.ref_start] + self.y_encoded[index] + [self.eos]

        type_x += [self.ref_start]*(len(self.y_encoded[index])+2)
        lm_x += [-1] + self.y_encoded[index] + [self.eos]
        position_x = list(range(len(x)))

        x = torch.Tensor(x)
        type_x = torch.Tensor(type_x)
        position_x = torch.Tensor(position_x)
        lm_x = torch.Tensor(lm_x)
        x_len = x.shape[0]
        
        return x,type_x,position_x,lm_x,total_input_length,self.meta[index]

    def __len__(self):
        return len(self.x_encoded)
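
# Minimal usage sketch for GptDataset (illustrative only; assumes a GPT-2 BPE
# tokenizer such as transformers.GPT2Tokenizer and x_y_meta loaded from one of
# the project's preprocessed pickles):
#   tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
#   dataset = GptDataset(x_y_meta, tokenizer, filter_mode='SR_only', num_turns=5)
#   loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)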

class GptDataset_aug(Dataset):
    def _split(self,x_y_meta):
        x_all = []
        y_all = []
        meta_all = []
        aug_all = []
        for x,y,meta,aug in x_y_meta: 
            meta_all.append(meta)
            x_all.append([self.tokenizer.encode(text_standardize(x_i)) for x_i in x])
            y_all.append(self.tokenizer.encode(text_standardize(y)))
            aug_all.append(self.tokenizer.encode(text_standardize(aug)))
        return x_all,y_all,meta_all,aug_all

    def __init__(self,x_y_meta,tokenizer,num_turns=5):
        self.x_y_meta = x_y_meta
        self.num_turns = num_turns
        self.tokenizer = tokenizer
        self.x_encoded,self.y_encoded,self.meta,self.aug_encoded = self._split(x_y_meta)
        self.ref_start, self.speaker1,self.speaker2,self.eos = 2,3,4,50256
        self.augment = 5

    def __getitem__(self,index):
        x = []
        type_x = []
        lm_x = []

        x += [self.augment] + self.aug_encoded[index]
        type_x += [self.augment] * len(x)

        is_speaker1 = bool(self.num_turns % 2) # which speaker starts the conversation

        for utt in self.x_encoded[index][-self.num_turns:]:
            if is_speaker1: # add the prefix special token for each utterance
                x+=[self.speaker1]
                type_x += [self.speaker1]*(len(utt)+1)
            else:
                x+=[self.speaker2]
                type_x += [self.speaker2]*(len(utt)+1)
            x += utt
            is_speaker1 = not is_speaker1
        lm_x += [-1]*len(x) # all input positions are masked out of the loss calculation

        total_input_length = len(x)

        x += [self.ref_start] + self.y_encoded[index] + [self.eos]

        type_x += [self.ref_start]*(len(self.y_encoded[index])+2)
        lm_x += [-1] + self.y_encoded[index] + [self.eos]
        position_x = list(range(len(x)))

        x = torch.Tensor(x)
        type_x = torch.Tensor(type_x)
        position_x = torch.Tensor(position_x)
        lm_x = torch.Tensor(lm_x)
        x_len = x.shape[0]
        
        return x,type_x,position_x,lm_x,total_input_length,self.meta[index]
    def __len__(self):
        return len(self.x_encoded)
def collate_fn(data):
    """Creates mini-batch tensors from the list of tuples (src_seq, trg_seq).
    We should build a custom collate_fn rather than using default collate_fn,
    because merging sequences (including padding) is not supported in default.
    Seqeuences are padded to the maximum length of mini-batch sequences (dynamic padding).
    Args:
        data: list of tuple (src_seq, trg_seq).
            - src_seq: torch tensor of shape (?); variable length.
            - trg_seq: torch tensor of shape (?); variable length.
    Returns:
        src_seqs: torch tensor of shape (batch_size, padded_length).
        src_lengths: list of length (batch_size); valid length for each padded source sequence.
        trg_seqs: torch tensor of shape (batch_size, padded_length).
        trg_lengths: list of length (batch_size); valid length for each padded target sequence.
    """
    def merge(sequences):
        lengths = [len(seq) for seq in sequences]
        padded_seqs = torch.zeros(len(sequences), max(lengths)).long()
        for i, seq in enumerate(sequences):
            end = lengths[i]
            padded_seqs[i, :end] = seq[:end]
        return padded_seqs, lengths
    def merge_matrix(matrices):
        max_size = max([m.shape[-1] for m in matrices])
        padded_matrices = torch.zeros(len(matrices), 1,  max_size, max_size)
        for i,m in enumerate(matrices):
            m_size = m.shape[-1]
            padded_matrices[i,:,:m_size,:m_size] = m
        return padded_matrices


    # sort a list by sequence length (descending order) to use pack_padded_sequence
    data.sort(key=lambda x: len(x[0]), reverse=True)

    # separate the fields of each sample
    src_seqs, trg_seqs, pos_seqs,lm_seqs,total_input_length,attention_mask = zip(*data)

    # merge sequences (from tuple of 1D tensor to 2D tensor)
    src_seqs, src_lengths = merge(src_seqs)
    trg_seqs, trg_lengths = merge(trg_seqs)
    pos_seqs, pos_lengths = merge(pos_seqs)
    lm_seqs, lm_lengths = merge(lm_seqs)
    if not isinstance(attention_mask[0], list):  # only merge when the last field holds attention-mask tensors (K-BERT data); otherwise it is metadata
        attention_mask = merge_matrix(attention_mask)
    if USE_CUDA:
        src_seqs = src_seqs.cuda()
        trg_seqs = trg_seqs.cuda()
        pos_seqs = pos_seqs.cuda()
        lm_seqs = lm_seqs.cuda()
        if torch.is_tensor(attention_mask):
            attention_mask = attention_mask.cuda()
    return Variable(LongTensor(src_seqs)), Variable(LongTensor(trg_seqs)), Variable(LongTensor(pos_seqs)),Variable(LongTensor(lm_seqs)), total_input_length, attention_mask
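
# Padding sketch (illustrative): the inner merge() helper pads each 1-D tensor in
# the batch to the longest length, e.g.
#   merge([torch.tensor([5, 6, 7]), torch.tensor([8])])
#   -> (tensor([[5, 6, 7], [8, 0, 0]]), [3, 1])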

def collate_fn_nli(data):
    """Creates mini-batch tensors from the list of tuples (src_seq, trg_seq).
    We should build a custom collate_fn rather than using default collate_fn,
    because merging sequences (including padding) is not supported in default.
    Seqeuences are padded to the maximum length of mini-batch sequences (dynamic padding).
    Args:
        data: list of tuple (src_seq, trg_seq).
            - src_seq: torch tensor of shape (?); variable length.
            - trg_seq: torch tensor of shape (?); variable length.
    Returns:
        src_seqs: torch tensor of shape (batch_size, padded_length).
        src_lengths: list of length (batch_size); valid length for each padded source sequence.
        trg_seqs: torch tensor of shape (batch_size, padded_length).
        trg_lengths: list of length (batch_size); valid length for each padded target sequence.
    """
    def merge(sequences):
        lengths = [len(seq) for seq in sequences]
        padded_seqs = torch.zeros(len(sequences), max(lengths)).long()
        for i, seq in enumerate(sequences):
            end = lengths[i]
            padded_seqs[i, :end] = seq[:end]
        return padded_seqs, lengths

    # sort a list by sequence length (descending order) to use pack_padded_sequence
    data.sort(key=lambda x: len(x[0]), reverse=True)

    # separate the fields of each sample
    src_seqs, trg_seqs, pos_seqs,lm_seqs,label = zip(*data)

    # merge sequences (from tuple of 1D tensor to 2D tensor)
    src_seqs, src_lengths = merge(src_seqs)
    trg_seqs, trg_lengths = merge(trg_seqs)
    pos_seqs, pos_lengths = merge(pos_seqs)
    # lm_seqs, lm_lengths = merge(lm_seqs)
    label = torch.tensor(label)
    if USE_CUDA:
        src_seqs = src_seqs.cuda()
        trg_seqs = trg_seqs.cuda()
        pos_seqs = pos_seqs.cuda()
        # lm_seqs = lm_seqs.cuda()
        label = label.cuda()
    return Variable(LongTensor(src_seqs)), Variable(LongTensor(trg_seqs)), Variable(LongTensor(pos_seqs)),lm_seqs, label

def collate_fn_keyword(data):
    def merge(sequences):
        lengths = [len(seq) for seq in sequences]
        padded_seqs = torch.zeros(len(sequences), max(lengths)).long()
        for i, seq in enumerate(sequences):
            end = lengths[i]
            padded_seqs[i, :end] = seq[:end]
        return padded_seqs, lengths

    # sort a list by sequence length (descending order) to use pack_padded_sequence
    data.sort(key=lambda x: len(x[0]), reverse=True)

    # separate the fields of each sample
    src_seqs, trg_seqs, pos_seqs, lm_seqs, total_input_length, meta, keyword_x = zip(*data)

    # merge sequences (from tuple of 1D tensor to 2D tensor)
    src_seqs, src_lengths = merge(src_seqs)
    trg_seqs, trg_lengths = merge(trg_seqs)
    pos_seqs, pos_lengths = merge(pos_seqs)
    lm_seqs, lm_lengths = merge(lm_seqs)
    keyword_x, _ = merge(keyword_x)
    if USE_CUDA:
        src_seqs = src_seqs.cuda()
        trg_seqs = trg_seqs.cuda()
        pos_seqs = pos_seqs.cuda()
        lm_seqs = lm_seqs.cuda()
        keyword_x = keyword_x.cuda()
    return Variable(LongTensor(src_seqs)), Variable(LongTensor(trg_seqs)), Variable(LongTensor(pos_seqs)),Variable(LongTensor(lm_seqs)), total_input_length, meta,Variable(LongTensor(keyword_x))

class SnliDataset(Dataset):
    """Take a list of samples with form [[x,...],y,meta]
    """
    # need 3 special tokens
    # # as <ref start> 2
    # $ as <speaker1> 3
    # % as <speaker2> 4
    # '<|endoftext|>' as <eos> 50256
    def _split(self,data):
        positive_label = set(['entailment'])
        premise = []
        hypothesis = []
        label = []
        for p,h,l in tqdm(data):
            premise.append(self.tokenizer.encode(text_standardize(p)))
            hypothesis.append(self.tokenizer.encode(text_standardize(h)))
            if l in positive_label:
                label.append(torch.tensor(1))
            else:
                label.append(torch.tensor(0))
        return premise,hypothesis,label
    
    def _filter(self,premise,hypothesis,label,filter_mode=None):
        data = zip(premise,hypothesis,label)
        if filter_mode is None:
            data_filt = data
        else:
            data_filt = [x for x in data if x[2]!='-']
            
        premise_filt,hypothesis_filt,label_filt = zip(*data_filt)
        return premise_filt,hypothesis_filt,label_filt

    def parse_snli(self,path=None):
        with open(path) as f:
            data = [json.loads(line) for line in f]
        data_processed = [(line['sentence1'],line['sentence2'],line['gold_label']) for line in data]
        return data_processed

    def __init__(self,tokenizer,path='../data/snli_1.0/snli_1.0_train.jsonl',filter_mode=None,num_turns=5):
        
        self.data = self.parse_snli(path)
        self.tokenizer = tokenizer
        self.premise_encoded,self.hypothesis_encoded,self.label = self._split(self.data)
        self.premise_encoded,self.hypothesis_encoded,self.label = self._filter(self.premise_encoded,self.hypothesis_encoded,self.label,filter_mode)
        self.ref_start, self.speaker1,self.speaker2,self.eos = 2,3,4,50256

    def __getitem__(self,index):
        x = []
        type_x = []
        lm_x = []
        
        x += [self.speaker1]
        x += self.premise_encoded[index]
        type_x += [self.speaker1]*(len(self.premise_encoded[index])+1) # the premise part
        
        x += [self.ref_start] 
        x += self.hypothesis_encoded[index]
        x += [self.eos]
        type_x += [self.ref_start]*(len(self.hypothesis_encoded[index])+2) # the hypothesis part
        
        label = self.label[index]
        
        position_x = list(range(len(x)))

        x = torch.Tensor(x)
        type_x = torch.Tensor(type_x)
        position_x = torch.Tensor(position_x)
        
        return x,type_x,position_x,lm_x,label

    def __len__(self):
        return len(self.premise_encoded)
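
# Usage sketch for SnliDataset (illustrative; uses the default SNLI train split path):
#   snli = SnliDataset(tokenizer)
#   x, type_x, position_x, lm_x, label = snli[0]
#   nli_loader = DataLoader(snli, batch_size=8, collate_fn=collate_fn_nli)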

class GptDataset_full(Dataset):
    def _split(self,x_y_meta):
        x_all = []
        y_all = []
        meta_all = []
        aug_all = []
        keyword_all = []
        for x, y, meta, aug, keyword in x_y_meta:
            meta_all.append(meta)
            # update for the new data format
            aug = ''.join([a[1] for a in aug])
            x_all.append([self.tokenizer.encode(text_standardize(x_i)) for x_i in x])
            y_all.append(self.tokenizer.encode(text_standardize(y)))
            aug_all.append(self.tokenizer.encode(text_standardize(aug)))
            keyword_all.append(self.tokenizer.encode(text_standardize(keyword)))
        return x_all,y_all,meta_all,aug_all, keyword_all

    def _filt(self, length=850):
        # drop samples whose encoded length would overflow the GPT-2 context window
        # (850 tokens leaves headroom below the 1024-token limit)
        data = zip(self.x_encoded, self.y_encoded, self.meta, self.aug_encoded, self.keyword_encoded)
        data = [sample for sample in data
                if sum([len(sen) for sen in sample[0]][-self.args.num_turns:]) + len(sample[1]) + len(sample[3]) + len(sample[4]) < length]
        self.x_encoded,self.y_encoded,self.meta,self.aug_encoded, self.keyword_encoded = zip(*data)
        self.x_encoded = list(self.x_encoded)
        self.y_encoded = list(self.y_encoded)
        self.meta = list(self.meta)
        self.aug_encoded = list(self.aug_encoded)
        self.keyword_encoded = list(self.keyword_encoded)

    def __init__(self,x_y_meta,tokenizer,args):
        self.x_y_meta = x_y_meta
        self.num_turns = args.num_turns
        self.tokenizer = tokenizer
        self.args = args
        self.x_encoded,self.y_encoded,self.meta,self.aug_encoded, self.keyword_encoded = self._split(x_y_meta)
        self._filt() # TODO: add back filt for mix-review
        self.ref_start, self.speaker1,self.speaker2,self.eos = 2,3,4,50256
        self.augment = 5
        if self.args.augment:
            print("Using augment sentences.")
        if self.args.keyword:
            print("Using keywords.")

    def __getitem__(self,index):
        x = []
        type_x = []
        lm_x = []

        if self.args.augment:
            x += [self.augment] + self.aug_encoded[index]
        if self.args.keyword:
            x += [self.augment] + self.keyword_encoded[index]
        type_x += [self.augment] * len(x)
        
        is_speaker1 = bool(self.num_turns % 2) # which speaker starts the conversation

        for utt in self.x_encoded[index][-self.num_turns:]:
            if is_speaker1: # add the prefix special token for each utterance
                x+=[self.speaker1]
                type_x += [self.speaker1]*(len(utt)+1)
            else:
                x+=[self.speaker2]
                type_x += [self.speaker2]*(len(utt)+1)
            x += utt
            is_speaker1 = not is_speaker1
        lm_x += [-100]*len(x) # all input positions are masked out of the loss calculation

        total_input_length = len(x)

        x += [self.ref_start] + self.y_encoded[index] + [self.eos]

        type_x += [self.ref_start]*(len(self.y_encoded[index])+2)
        lm_x += [-100] + self.y_encoded[index] + [self.eos]
        position_x = list(range(len(x)))

        x = torch.Tensor(x)
        type_x = torch.Tensor(type_x)
        position_x = torch.Tensor(position_x)
        lm_x = torch.Tensor(lm_x)
        x_len = x.shape[0]
        
        return x,type_x,position_x,lm_x,total_input_length,self.meta[index]

    def __len__(self):
        return len(self.x_encoded)
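
# Construction sketch for GptDataset_full (illustrative; argparse.Namespace stands in
# for the real command-line args, which must provide num_turns, augment and keyword):
#   from argparse import Namespace
#   args = Namespace(num_turns=5, augment=True, keyword=False)
#   dataset = GptDataset_full(x_y_meta, tokenizer, args=args)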

class GptDataset_KBERT_old(Dataset):
    def get_comet_aug_deque(self, comet_data, num_turns=5):
        clause_dq = deque()
        for comet_in, comet_out in comet_data:
            if comet_out == "":
                continue
            loc = int(comet_in.split()[0])
            if loc >= (10 - num_turns):
                clause_dq.append((loc, comet_out))
        return clause_dq

    def __init__(self, x_y_meta, tokenizer, args):
        self.data = x_y_meta
        self.num_turns = args.num_turns
        self.tokenizer = tokenizer
        self.args = args
        self.ref_start, self.speaker1, self.speaker2, self.eos = 2, 3, 4, 50256
        self.augment = 5

        if self.args.augment:
            print("Using augment sentences.")
        if self.args.keyword:
            print("Using keywords.")

    def __getitem__(self, index):
        x = []
        type_x = []
        lm_x = []
        soft_position_x = []

        dq = self.get_comet_aug_deque(self.data[index][3])  # the comet info

        mask_info = []

        context = self.data[index][0]
        response = self.data[index][1]

        is_speaker1 = bool(self.args.num_turns % 2)
        soft_loc = 0  # keep track of the location in the main sequence; points to the next soft position to assign
        utterance_start_loc = 0
        for i in range(10 - self.args.num_turns, 10):
            utterance_encoded = self.tokenizer.encode(text_standardize(context[i]))

            # add the prefix special token for each utterance
            if is_speaker1:
                x += [self.speaker1]
                type_x += [self.speaker1] * (len(utterance_encoded) + 1)
            else:
                x += [self.speaker2]
                type_x += [self.speaker2] * (len(utterance_encoded) + 1)
            x += utterance_encoded
            utterance_end_loc = len(x)

            soft_position_x += list(range(soft_loc, soft_loc + len(utterance_encoded) + 1))

            # add the comet augmentation if it belongs after this utterance
            while len(dq) != 0 and dq[0][0] == i:
                comet_output = dq.popleft()[1]
                comet_encoded = self.tokenizer.encode(text_standardize(comet_output))

                x += [self.augment] + comet_encoded
                type_x += [self.augment] * (len(comet_encoded) + 1)
                soft_position_x += list(range(soft_loc, soft_loc + len(comet_encoded) + 1))
                mask_info.append([utterance_start_loc, utterance_end_loc, len(comet_encoded)+1])
            # update the pointer to the new seq end, add one for the delimiter token
            soft_loc += len(utterance_encoded) + 1
            is_speaker1 = not is_speaker1
            utterance_start_loc = len(x)

        lm_x += [-100] * len(x)  # all input positions are masked out of the loss calculation
        total_input_length = len(x)

        response_encoded = self.tokenizer.encode(text_standardize(response))
        x += [self.ref_start] + response_encoded + [self.eos]

        type_x += [self.ref_start] * (len(response_encoded) + 2)
        lm_x += [-100] + response_encoded + [self.eos]

        soft_position_x += list(range(soft_loc, soft_loc + len(response_encoded) + 2))

        x = torch.Tensor(x)
        type_x = torch.Tensor(type_x)
        soft_position_x = torch.Tensor(soft_position_x)
        lm_x = torch.Tensor(lm_x)
        x_len = x.shape[0]

        # process the mask
        attention_mask = torch.tril(torch.ones(x_len, x_len))
        for u_start, u_end, branch_len in mask_info:
            attention_mask[u_end+branch_len+1:, u_end+1:u_end+branch_len+1] = 0 # rows: tokens after the branch, columns: the branch tokens
        attention_mask = attention_mask.view(1, x_len, x_len)

        return x, type_x, soft_position_x, lm_x, total_input_length, attention_mask

    def __len__(self):
        return len(self.data)


class GptDataset_KBERT(Dataset):
    def __init__(self, tokenizer, args):
        pickle_handler = open("../data_processed/data_comet_dict", 'rb')

        self.data = pickle.load(pickle_handler)
        self.tokenizer = tokenizer
        self.args = args
        self.num_turns = args.num_turns
        self.ref, self.speaker1, self.speaker2 = tokenizer.ref, tokenizer.speaker1, tokenizer.speaker2
        self.eos = tokenizer.eos
        self.augment = tokenizer.augment
        self.max_length = 1024  # GPT-2 context limit; __getitem__ truncates each sample to this length

        if not self.args.kbert:
            self.args.kbert_mask = False
            self.args.kbert_position = False
            print("Not using kbert scheme.")
        if self.args.kbert_mask:
            print("using kbert-style attention mask")
        if self.args.kbert_position:
            print("using kbert-style soft-postional encoding")

    def __getitem__(self, index):
        # prepare variables
        x = []
        type_x = []
        lm_x = []
        soft_position_x = []
        attention_mask = []

        # 0. unpack needed input info
        context = self.data[index]['context']
        srl_mask = self.data[index]['srl_mask']
        comet_output = self.data[index]['comet']  # a list of dict or None
        response = self.data[index]['response']

        # 1. encode the response.
        response_encoded = self.tokenizer.encode(text_standardize(response))

        # 2. encode each utterance.
        context_encoded = []
        for i in range(10 - self.args.num_turns, 10):
            context_encoded.append(self.tokenizer.encode(text_standardize(context[i])))

        # 3. encode the comet output for each utterance.
        comet_encoded = []
        for i in range(len(comet_output)):
            comet_text_i = ""
            if comet_output[i] is None:
                comet_encoded.append(None)
                continue
            for rel in comet_output[i]:
                for candidate in comet_output[i][rel]['beams']:
                    if candidate != 'none':
                        comet_text_i += rel + " " + candidate + " "
                        break
            comet_encoded.append(self.tokenizer.encode(text_standardize(comet_text_i)))

        # 4. use the encoded seq to build the input and attention mask
        is_speaker1 = bool(self.args.num_turns % 2)
        soft_loc = 0
        for i in range(self.args.num_turns):

            # add an utterance. update x & type_x
            if is_speaker1:
                x += [self.speaker1]
                type_x += [self.speaker1] * (len(context_encoded[i]) + 1)
            else:
                x += [self.speaker2]
                type_x += [self.speaker2] * (len(context_encoded[i]) + 1)
            x += context_encoded[i]

            # update soft_position_x for this utterance
            # (the aug part is appended after x, but its soft positions continue from the last SRL-related token)
            soft_position_x += list(range(soft_loc, soft_loc + (len(context_encoded[i]) + 1)))

            last_related_token_index = len(srl_mask[i]) - 1 - srl_mask[i][::-1].index(1)

            # add comet output
            if self.args.kbert:
                if comet_encoded[i] is not None:
                    x += [self.augment] + comet_encoded[i]
                    type_x += [self.augment] * (len(comet_encoded[i]) + 1)
                    # +2: one for the speaker special token at the start of the utterance, and one so the branch starts one soft position past the related token
                    soft_position_x += list(range(soft_loc + 2 + last_related_token_index,
                                                  soft_loc + 2 + last_related_token_index + (len(comet_encoded[i]) + 1)))

            soft_loc += (len(context_encoded[i]) + 1)
            is_speaker1 = not is_speaker1

        lm_x += [-100] * len(x)  # all position for the input is masked for loss calculation
        total_input_length = len(x)

        response_encoded = self.tokenizer.encode(text_standardize(response))
        x += [self.ref] + response_encoded + [self.eos]

        type_x += [self.ref] * (len(response_encoded) + 2)
        lm_x += [-100] + response_encoded + [self.eos]

        soft_position_x += list(range(soft_loc, soft_loc + len(response_encoded) + 2))
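        # Worked example of the soft-position scheme used above (illustrative): for the
        # first turn, a 4-token utterance occupies soft positions 0..4 (speaker token at
        # 0, utterance tokens at 1..4). If the last SRL-related token sits at soft
        # position 3, an injected comet branch of 3 tokens (augment token + 2 comet
        # tokens) gets soft positions 4, 5, 6, i.e. it continues right after the related
        # token, while the next utterance still starts at soft position 5.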
        
        
        x = x[:self.max_length]
        type_x = type_x[:self.max_length]
        lm_x = lm_x[:self.max_length]
        soft_position_x = soft_position_x[:self.max_length]
        
        # build attention mask
        attention_mask = torch.tril(torch.ones(len(x), len(x)))
        if self.args.kbert_mask:
            aug_start = 0  # where the aug part begins
            utt_start = 0  # where the utterance begins

            for turn in range(self.args.num_turns):
                aug_start += len(context_encoded[turn]) + 1
                # iterate through every token in the comet output
                if comet_encoded[turn] is not None:
                    for aug_token_pos in range(aug_start, aug_start + len(comet_encoded[turn]) + 1):
                        # set the attention related to the aug part to be all zero
                        attention_mask[aug_token_pos, :] = torch.zeros_like(attention_mask[aug_token_pos, :])

                        attention_mask[:, aug_token_pos] = torch.zeros_like(attention_mask[:, aug_token_pos])
                        # set attention on related token to be one
                        for normal_token_pos in range(len(context_encoded[turn])):
                            attention_mask[aug_token_pos, utt_start + normal_token_pos + 1] += srl_mask[turn][
                                normal_token_pos]
                        # set attention on previous aug tokens to be one
                        for previous_aug_token_pos in range(aug_start, aug_token_pos + 1):
                            attention_mask[aug_token_pos, previous_aug_token_pos] += 1

                    aug_start += len(comet_encoded[turn]) + 1
                    utt_start += len(comet_encoded[turn]) + 1
                utt_start += (len(context_encoded[turn]) + 1)

        x = torch.tensor(x)
        type_x = torch.tensor(type_x)
        if not self.args.kbert_position:
            soft_position_x = list(range(len(x)))
        soft_position_x = torch.tensor(soft_position_x)
        lm_x = torch.tensor(lm_x)
        return x, type_x, soft_position_x, lm_x, total_input_length, attention_mask

    def __len__(self):
        return len(self.data)

def get_data(args, tokenizer, split_size):
    """
    Return the data loaders needed for training and evaluation.
    :param args: command line arguments.
    :param tokenizer: the tokenizer used in preparing the data.
    :param split_size: the portion of train, test, validation set.
    :return data_loader: The data loader for the training set.
    :return val_loader: The data loader for the validation set.
    """
    # random.seed(args.seed)
    # torch.random.manual_seed(args.seed)
    # torch.cuda.manual_seed(args.seed)
    # torch.manual_seed(args.seed)
    if args.special_input:
        print("Using mutated data.")
        with open('../data_processed/' + args.special_input, 'rb') as pickle_handler:
            x_y_meta = pickle.load(pickle_handler)
        gpt_data = GptDataset(x_y_meta, tokenizer, args.output_dir, num_turns=args.num_turns)
#     #======================origin without kbert======
#     elif not args.kbert:
#         print("Using full data.")
#         pickle_handler = open('../data_processed/x_y_with_comet', 'rb') # TODO: change back to the old data.
#         x_y_meta = pickle.load(pickle_handler)
#         gpt_data = GptDataset_full(x_y_meta, tokenizer, args=args)
    else:
        print("Using KBERT data")
        gpt_data = GptDataset_KBERT(tokenizer, args=args)
    print("Dataset initialized. There are {} samples.".format(len(gpt_data)))

    test_size = int(len(gpt_data) * split_size['test'])
    val_size = int(len(gpt_data) * split_size['val'])

    gpt_train, gpt_test, gpt_val = torch.utils.data.random_split(gpt_data,
                                                                 [len(gpt_data) - test_size - val_size, test_size,
                                                                  val_size])

    # import pdb;pdb.set_trace()
    # with open('../../mi_data/train','wb') as f:
    #     pickle.dump(gpt_train, f)
    # with open('../../mi_data/test','wb') as f:
    #     pickle.dump(gpt_test, f)
    # with open('../../mi_data/val','wb') as f:
    #     pickle.dump(gpt_val, f)

    if 'train_batch_size' not in args:
        args.train_batch_size = 1
    data_loader = DataLoader(dataset=gpt_train, batch_size=args.train_batch_size, shuffle=True, drop_last=True,
                            collate_fn=collate_fn)
    test_loader = DataLoader(dataset=gpt_test, batch_size=1, shuffle=False, drop_last=False, collate_fn=collate_fn)
    val_loader = DataLoader(dataset=gpt_val, batch_size=1, shuffle=False, drop_last=False, collate_fn=collate_fn)
    return data_loader, test_loader, val_loader
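
# Usage sketch for get_data (illustrative; the split fractions are placeholders):
#   split_size = {'test': 0.1, 'val': 0.1}
#   train_loader, test_loader, val_loader = get_data(args, tokenizer, split_size)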

def prepare_mix_review(args, tokenizer):
    print("Preparing Alexander dataset")
    with open('../data_processed/data_alex', 'rb') as pickle_handler:
        data = pickle.load(pickle_handler)
    gpt_alex = GptDataset_full(data, tokenizer, args=args)
    print("Alexander dataset prepared. Has {} samples".format(len(gpt_alex)))
    return gpt_alex

def update_mix_review(gpt_train, gpt_alex, epoch, args, mix_ratio=4, mix_decay=0.7, collate_fn=collate_fn):
    mix_amount = int(mix_ratio*(mix_decay**epoch)*len(gpt_train))
    gpt_alex_active,_ = torch.utils.data.random_split(gpt_alex, [mix_amount, len(gpt_alex)-mix_amount])

    data_loader = DataLoader(dataset=gpt_train+gpt_alex_active, batch_size=args.train_batch_size, shuffle=True, drop_last=True,
                                collate_fn=collate_fn)