|
from torch.optim import SGD, lr_scheduler |
|
|
|
from configs.qsenn_training_params import QSENNScheduler |
|
from configs.sldd_training_params import OptimizationScheduler |
|
from training.img_net import get_default_img_schedule, get_default_img_optimizer |
|
|
|
|
|
def get_optimizer(model, schedulingClass):
    """Build an SGD optimizer and StepLR schedule from a scheduling config.

    Parameters
    ----------
    model : torch.nn.Module
        Model whose parameters will be optimized. The non-finetune branch
        assumes the classifier head is named ``linear`` (params
        ``linear.weight`` / ``linear.bias``) — other heads fall into the
        base-parameter group.
    schedulingClass : object
        Must provide ``get_params()`` returning
        ``(lr, weight_decay, step_lr, step_lr_gamma, n_epochs, finetune)``.

    Returns
    -------
    (optimizer, schedule, n_epochs) : tuple
        When ``lr`` is ``None`` the default ImageNet/ResNet50 recipe is
        returned instead, with a fixed 600 epochs.
    """
    lr, weight_decay, step_lr, step_lr_gamma, n_epochs, finetune = schedulingClass.get_params()
    print("Optimizer LR set to ", lr)
    if lr is None:
        # No explicit LR configured: fall back to the default recipe.
        print("Learning rate is None, using Default Recipe for Resnet50")
        default_img_optimizer = get_default_img_optimizer(model)
        default_img_schedule = get_default_img_schedule(default_img_optimizer)
        return default_img_optimizer, default_img_schedule, 600

    if finetune:
        # Finetuning: optimize only the parameters left trainable
        # (frozen layers are excluded), with a single LR.
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        optimizer = SGD(trainable_params, lr, momentum=0.95,
                        weight_decay=weight_decay)
    else:
        # Full training: give the classifier head its own (fixed) LR of
        # 0.01 while the backbone uses the configured LR.
        classifier_params_name = ["linear.bias", "linear.weight"]
        classifier_params = [p for name, p in model.named_parameters()
                             if name in classifier_params_name]
        base_params = [p for name, p in model.named_parameters()
                       if name not in classifier_params_name]
        optimizer = SGD([
            {'params': base_params},
            {"params": classifier_params, 'lr': 0.01}
        ], momentum=0.9, lr=lr, weight_decay=weight_decay)

    # Decay the LR by step_lr_gamma every step_lr epochs.
    schedule = lr_scheduler.StepLR(optimizer, step_size=step_lr, gamma=step_lr_gamma)
    return optimizer, schedule, n_epochs
|
|
|
|
|
def get_scheduler_for_model(model, dataset):
    """Return the training-parameter scheduler for the given model family.

    Parameters
    ----------
    model : str
        Model family identifier; one of ``"qsenn"`` or ``"sldd"``.
    dataset : str
        Dataset name passed through to the scheduler constructor.

    Raises
    ------
    ValueError
        If ``model`` is not a recognized family (previously this fell
        through and silently returned ``None``, deferring the failure
        to a confusing AttributeError at the call site).
    """
    if model == "qsenn":
        return QSENNScheduler(dataset)
    elif model == "sldd":
        return OptimizationScheduler(dataset)
    raise ValueError(f"Unknown model type: {model!r} (expected 'qsenn' or 'sldd')")