File size: 1,891 Bytes
8d4ee22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from torch.optim import SGD, lr_scheduler

from configs.qsenn_training_params import QSENNScheduler
from configs.sldd_training_params import OptimizationScheduler
from training.img_net import get_default_img_schedule, get_default_img_optimizer


def get_optimizer(model, schedulingClass, *, classifier_lr=0.01):
    """Build the optimizer, LR schedule and epoch count for a training run.

    Args:
        model: torch ``nn.Module`` whose parameters are optimized.
        schedulingClass: object whose ``get_params()`` returns the tuple
            ``(lr, weight_decay, step_lr, step_lr_gamma, n_epochs, finetune)``.
        classifier_lr: learning rate of the parameter group holding
            ``linear.weight`` / ``linear.bias`` when ``finetune`` is falsy.
            Defaults to 0.01, the value previously hard-coded here.

    Returns:
        ``(optimizer, schedule, n_epochs)``. When ``lr`` is None, the optimizer
        and schedule come from the ``training.img_net`` default helpers, and the
        epoch count is fixed at 600.
    """
    lr, weight_decay, step_lr, step_lr_gamma, n_epochs, finetune = schedulingClass.get_params()
    print("Optimizer LR set to ", lr)
    if lr is None:  # Dense Training on ImageNet
        print("Learning rate is None, using Default Recipe for Resnet50")
        default_img_optimizer = get_default_img_optimizer(model)
        default_img_schedule = get_default_img_schedule(default_img_optimizer)
        return default_img_optimizer, default_img_schedule, 600

    if finetune:
        # Only parameters that are still trainable (requires_grad=True) are handed
        # to the optimizer; frozen ones are excluded entirely.
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        optimizer = SGD(trainable_params, lr, momentum=0.95,
                        weight_decay=weight_decay)
    else:
        # Two parameter groups: the classifier head gets its own LR, everything
        # else uses ``lr``. Both lists keep named_parameters() order.
        # NOTE(review): matching is by exact parameter name. If the model has no
        # top-level ``linear`` submodule (or names carry a wrapper prefix), the
        # classifier group is empty and every parameter trains at ``lr``; confirm.
        classifier_param_names = {"linear.bias", "linear.weight"}
        base_params, classifier_params = [], []
        for name, param in model.named_parameters():
            if name in classifier_param_names:
                classifier_params.append(param)
            else:
                base_params.append(param)

        optimizer = SGD([
            {"params": base_params},
            {"params": classifier_params, "lr": classifier_lr},
        ], momentum=0.9, lr=lr, weight_decay=weight_decay)

    # StepLR multiplies each group's LR by step_lr_gamma every step_lr scheduler steps.
    schedule = lr_scheduler.StepLR(optimizer, step_size=step_lr, gamma=step_lr_gamma)
    return optimizer, schedule, n_epochs


def get_scheduler_for_model(model, dataset):
    """Return the training-parameter scheduler object for a method name.

    Args:
        model: method identifier; ``"qsenn"`` and ``"sldd"`` are supported.
        dataset: value forwarded unchanged to the scheduler constructor.

    Returns:
        ``QSENNScheduler(dataset)`` for ``"qsenn"``, or
        ``OptimizationScheduler(dataset)`` for ``"sldd"``.

    Raises:
        ValueError: if ``model`` is not a supported name. Failing here avoids
            silently returning None and crashing later on ``.get_params()``.
    """
    if model == "qsenn":
        return QSENNScheduler(dataset)
    if model == "sldd":
        return OptimizationScheduler(dataset)
    raise ValueError(
        f"Unknown model {model!r} for scheduler selection; expected 'qsenn' or 'sldd'"
    )