    # Path to the PyTorch checkpoint, e.g.:
    # /Users/shensq/Documents/LIT_ai_counseling/gpt2/models/pytorch_345M
    import sys
    
    # sys.path.insert(0, '/home/shensq/LIT/pip_package')
    
    import argparse
    import os
    import random
    import torch
    from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, WEIGHTS_NAME, CONFIG_NAME
    from tqdm import tqdm, trange
    from utils import clean_text, text_standardize, construct_grouped_parameters, get_unfreezing_funcs
    
    from gpt_loader import GptDataset, collate_fn, collate_fn_keyword, prepare_mix_review, update_mix_review, get_data
    
    
    # OPTIONAL: to see more information about what is happening, activate the logger as follows:
    import logging
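    # A minimal example (main() below makes the same call):
    # logging.basicConfig(level=logging.INFO)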
    
    
    
    def evaluate(model, data_loader, use_keyword=None):
        """
        Evaluate the model on the validation set.
        :param model: The model being trained.
        :param data_loader: The data loader for the validation set.
        :param use_keyword: Whether the input contains keywords.
        :return eval_loss: The average loss on the validation set.
        """
        model.eval()
        eval_loss = 0
        with torch.no_grad():  # no gradients are needed during evaluation
            for sample in tqdm(data_loader):
                if use_keyword:
                    x, type_x, pos_x, lm_x, x_len, _, keyword_x = sample
                else:
                    x, type_x, pos_x, lm_x, x_len, _ = sample
                    keyword_x = None
                # loss = model(x, position_ids=pos_x, token_type_ids=type_x, labels=lm_x, key_word=keyword_x,
                #              use_keyword=use_keyword)[0]
                loss = model(x, position_ids=pos_x, token_type_ids=type_x, labels=lm_x)[0]
                eval_loss += loss.item()
        eval_loss /= len(data_loader)
        model.train()
        return eval_loss
    
    
    def parse_arguments():
        """
        Parse command line arguments using argparse.
        :return args: A namespace with hyper-parameters' names and values.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument("--model_dir", default='345M_Alex', type=str, required=False,
                            help="The directory of the model to be tuned.")
        parser.add_argument("--output_dir", default='mi_tuned', type=str, required=False,
                            help="The output directory where the model predictions and checkpoints will be written.")
        parser.add_argument('--seed', type=int, default=42)
        parser.add_argument('--num_train_epochs', type=int, default=1)
    
        parser.add_argument('--train_batch_size', type=int, default=2)
    
        parser.add_argument('--max_grad_norm', type=float, default=1.0)
        parser.add_argument('--learning_rate', type=float, default=6.25e-5)
        parser.add_argument('--warmup_proportion', type=float, default=0.1)
        parser.add_argument('--lr_schedule', type=str, default='warmup_linear')
        parser.add_argument('--weight_decay', type=float, default=0.01)
        parser.add_argument('--lm_coef', type=float, default=0.9)
        parser.add_argument('--n_valid', type=int, default=374)
        parser.add_argument('--augment', action='store_true')
        parser.add_argument('--keyword', action='store_true')
        parser.add_argument('--cross_attention', action='store_true')
        parser.add_argument('--special_input', type=str)
        parser.add_argument('--first_K_tokens', type=int, default=1024)
        parser.add_argument('--use_disc_lr', action='store_true')
        parser.add_argument('--use_unfreezing', action='store_true')
        parser.add_argument('--num_turns', type=int, default=5)
    
        parser.add_argument('--kbert', action='store_true')
    
        args = parser.parse_args()
        print(args)
        return args
    
    def load_model(args):
        """
        Load the model and the corresponding tokenizer from pre-trained weights.
        :param args: The command line arguments.
        :return model: The main model.
        :return tokenizer: The tokenizer that comes with the main model.
        """
        USE_CUDA = torch.cuda.is_available()
        # ====== Load GPT2 model ========
        model_dir = '../models/' + args.model_dir
        # model = GPT2LMHeadModel.from_pretrained(model_dir)  # load a previously fine-tuned checkpoint instead
        model = GPT2LMHeadModel.from_pretrained('gpt2')
    
        if USE_CUDA:
            model.cuda()
        # tokenizer = GPT2Tokenizer.from_pretrained(model_dir)
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    
        print('Model loaded.')
        return model, tokenizer
    
    def main():
        args = parse_arguments()
    
        # ====== Set random seed =========
        random.seed(args.seed)
        torch.manual_seed(args.seed)  # seeds the CPU and all CUDA devices
        # ======= Prepare ==========
        logging.basicConfig(level=logging.INFO)
    
        model, tokenizer = load_model(args)
        # =============== Load & process data ==============
    
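        # Fractions for the train/test/validation split (90/5/5).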
        split_size = {'train': 0.90, 'test': 0.05, 'val': 0.05}
    
        data_loader, test_loader, val_loader = get_data(args, split_size=split_size, tokenizer=tokenizer)
        # gpt_alex = prepare_mix_review(args, tokenizer)
        # data_loader, val_loader = get_data(args, split_size=split_size, tokenizer=tokenizer) # TODO: this is for old get_data
    
        # ========== Prepare optimizer =============
        # The library's GPT-2 model has an unnamed LM head whose weights are tied to the input embedding.
        # len(data_loader) already counts batches, so no further division by batch size is needed.
        num_train_optimization_steps = len(data_loader) * args.num_train_epochs
        param_optimizer = list(model.named_parameters())
        optimizer_grouped_parameters = construct_grouped_parameters(param_optimizer, args.learning_rate,
                                                                    use_discr=args.use_disc_lr)
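        # The parameter groups allow discriminative (per-layer) learning rates when --use_disc_lr is set.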
    
        lm_funcs = get_unfreezing_funcs(optimizer_grouped_parameters, warmup_portion=args.warmup_proportion,
                                        total_steps=num_train_optimization_steps, use_unfreezing=args.use_unfreezing)
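        # lm_funcs holds one LR-scaling function per parameter group (warmup plus
        # optional gradual unfreezing); LambdaLR below applies them group by group.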
    
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, correct_bias=False)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lm_funcs)
    
        # Training
        print("Start training.")
        model.train()
        exp_average_loss = None
        progress_bar = trange(int(args.num_train_epochs), desc="Epoch", leave=True)
        min_eval_loss = float('inf')  # any real loss will be smaller
        early_terminate_counter = 0
        for epo in progress_bar:
        # for _ in range(int(args.num_train_epochs)):
            # data_loader = update_mix_review(gpt_train, gpt_alex, epo, mix_ratio=4, mix_decay=0.7)
            for sample in tqdm(data_loader):
            # for sample in data_loader:
                # import pdb; pdb.set_trace()
    
                # if args.cross_attention:
                #     x, type_x, pos_x, lm_x, x_len, _, keyword_x = sample
                # else:
                #     x, type_x, pos_x, lm_x, x_len, _ = sample
                #     keyword_x = None
    
                x, type_x, pos_x, lm_x, x_len, attention_mask = sample
    
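                # Compute the LM loss only on the first K target tokens; later positions get the
                # ignore index (-1 here; note that newer transformers releases expect -100 instead).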
                input_len = x_len[0]
                lm_x[:, input_len + 1 + args.first_K_tokens:-1] = -1
                # loss = model(x, position_ids=pos_x, token_type_ids=type_x, labels=lm_x, key_word=keyword_x,
                #              use_keyword=args.cross_attention)[0]
    
                loss = model(x, position_ids=pos_x, token_type_ids=type_x, labels=lm_x, attention_mask=attention_mask)[0]
    
                loss.backward()
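                # Assumption: --max_grad_norm (parsed above but otherwise unused) was meant for
                # gradient clipping, so apply it here between backward() and step().
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)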
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
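                # Exponentially smoothed training loss (decay 0.7) for the progress-bar display.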
                exp_average_loss = loss.item() if exp_average_loss is None else 0.7 * exp_average_loss + 0.3 * loss.item()
                progress_bar.set_description("Training loss: {}".format(exp_average_loss))
    
            eval_loss = evaluate(model, val_loader, use_keyword=args.cross_attention)
            print("Eval loss: {}".format(eval_loss))
    
            if eval_loss < min_eval_loss:  # save the model only when the loss is the smallest
            # if True:
    
                early_terminate_counter = 0
                min_eval_loss = eval_loss
                # ==== Save the model ====
                # Save a trained model, configuration and tokenizer
                model_to_save = model.module if hasattr(model, 'module') else model  # only save the model itself
                # If we save using the predefined names, we can load using `from_pretrained`
                output_dir = os.path.join('../models', args.output_dir)
                os.makedirs(output_dir, exist_ok=True)  # make sure the directory exists before saving
                output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
                output_config_file = os.path.join(output_dir, CONFIG_NAME)

                torch.save(model_to_save.state_dict(), output_model_file)
                model_to_save.config.to_json_file(output_config_file)
                tokenizer.save_vocabulary(output_dir)
            else:
                print("eval loss increasing!")
                early_terminate_counter += 1
    
                if early_terminate_counter > 3:  # terminate early after 4 consecutive epochs without improvement
    
                    return
    
    if __name__ == '__main__':
        main()
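
    # Example invocation (flags are illustrative; defaults are set in parse_arguments):
    #   python gpt_tuning.py --output_dir mi_tuned --num_train_epochs 10 --train_batch_size 2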