File size: 2,890 Bytes
ae81e0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""
Functions for finetuning a model after distillation
"""
from os.path import join
from omegaconf import OmegaConf

import torch
from torch.nn import Module

from src.utils.setup import update_config_from_args
from src.dataloaders import load_data
from src.trainer import get_trainer, get_optimizer, get_scheduler


def prepare_finetune_configs(args, model_config: dict,
                             finetune_config_name: str = None,
                             finetune_checkpoint_name: str = None,
                             config_dir='./configs/experiment'):
    """
    Prepare finetuning configs

    Loads `<config_dir>/<name>.yaml` and copies its trainer, dataloader,
    optimizer and lr_scheduler settings onto `args`.

    Args:
        args: Namespace-like object; it is mutated in place via `setattr`
            and also returned.
        model_config: Mapping whose `['model']` entry provides
            'pretrained_model_name_or_path' and 'cache_dir'; these are only
            read when the dataset config has a non-None
            `pretrained_model_config`.
        finetune_config_name: Name of the config file (without `.yaml`).
            Takes precedence over `finetune_checkpoint_name`.
        finetune_checkpoint_name: Fallback source for the config name: the
            text after the last '-f=' up to the next '-' is used.
        config_dir: Directory containing the experiment YAML files.

    Returns:
        Tuple `(finetune_config, args)`: the loaded config and the updated
        `args` object.

    Raises:
        ValueError: If both `finetune_config_name` and
            `finetune_checkpoint_name` are None.
    """
    # Resolve which config file to load
    if finetune_config_name is not None:
        config_name = finetune_config_name
    elif finetune_checkpoint_name is not None:
        # NOTE(review): if the checkpoint name has no '-f=' marker, this
        # silently yields the text before the first '-' -- confirm callers
        # always pass names containing '-f=<config>'.
        config_name = finetune_checkpoint_name.split('-f=')[-1].split('-')[0]
    else:
        # Fail early with an explicit message instead of an AttributeError
        # from calling .split() on None
        raise ValueError(
            'Either finetune_config_name or finetune_checkpoint_name '
            'must be provided to locate the finetuning config.'
        )

    # Load finetuning config
    finetune_config_path = join(config_dir, f'{config_name}.yaml')
    finetune_config = OmegaConf.load(finetune_config_path)
    # 'lr' and 'weight_decay' are passed as ignore_args; presumably so the
    # values in the finetuning YAML are not overridden by args -- confirm
    # against update_config_from_args
    finetune_config = update_config_from_args(finetune_config, args,
                                              ignore_args=['lr', 'weight_decay'])

    # Update data tokenizer to match model
    if getattr(finetune_config.dataset, 'pretrained_model_config', None) is not None:
        for k in ['pretrained_model_name_or_path', 'cache_dir']:
            finetune_config.dataset.pretrained_model_config[k] = model_config['model'][k]

    # Set finetuning args: every trainer entry except its 'name'
    for arg, argv in finetune_config.trainer.items():
        if arg != 'name':
            setattr(args, arg, argv)
    # These sections are stored on args as plain Python containers
    for _config in ['dataloader', 'optimizer', 'lr_scheduler']:
        setattr(args, _config, OmegaConf.to_container(getattr(finetune_config, _config)))
    return finetune_config, args


def get_finetuner(model: Module, finetune_config: dict, device: torch.device,
                  args: any, wandb: any, initial_eval: bool = False):
    """
    Build the trainer used for finetuning

    Moves `model` to `device` and puts it in training mode, creates the
    optimizer and scheduler from `finetune_config.optimizer` and
    `finetune_config.lr_scheduler`, loads the dataloaders, and instantiates
    the trainer class named by `finetune_config.trainer.name`.

    Args:
        model: Module to finetune; moved to `device` and set to train mode.
        finetune_config: Config exposing `optimizer`, `lr_scheduler`,
            `dataset`, `dataloader` and `trainer` sections via attribute
            access.
        device: Target device for the model and trainer.
        args: Forwarded unchanged to the trainer.
        wandb: Forwarded unchanged to the trainer.
        initial_eval: Not used within this function.

    Returns:
        The constructed trainer instance.
    """
    # Model goes on-device before the optimizer is created
    # (needed if using a fused optimizer)
    model.to(device)
    model.train()

    # Optimizer first, then the scheduler that wraps it
    optim = get_optimizer(model=model, **finetune_config.optimizer)
    sched = get_scheduler(optimizer=optim, **finetune_config.lr_scheduler)

    # Pick the train / eval loaders by the split names in the trainer config
    loaders = load_data(finetune_config.dataset, finetune_config.dataloader)
    trainer_config = finetune_config.trainer
    finetune_train_loader = loaders[trainer_config.train_split]
    finetune_eval_loader = loaders[trainer_config.val_split]

    # Look up the trainer class by name; the whole trainer section is
    # forwarded as extra keyword arguments
    trainer_cls = get_trainer(trainer_config.name)
    return trainer_cls(
        model=model,
        args=args,
        train_loader=finetune_train_loader,
        eval_loader=finetune_eval_loader,
        optimizer_and_scheduler=(optim, sched),
        device=device,
        wandb=wandb,
        checkpoint_suffix='_ft',
        **trainer_config,
    )