from torch.optim import SGD, lr_scheduler
from configs.qsenn_training_params import QSENNScheduler
from configs.sldd_training_params import OptimizationScheduler
from training.img_net import get_default_img_schedule, get_default_img_optimizer


def get_optimizer(model, schedulingClass):
    # Build the optimizer and LR schedule from the hyperparameters supplied
    # by the scheduling class (see get_scheduler_for_model below).
    lr, weight_decay, step_lr, step_lr_gamma, n_epochs, finetune = schedulingClass.get_params()
    print("Optimizer LR set to", lr)
    if lr is None:  # Dense training on ImageNet
        print("Learning rate is None, using default recipe for ResNet50")
        default_img_optimizer = get_default_img_optimizer(model)
        default_img_schedule = get_default_img_schedule(default_img_optimizer)
        return default_img_optimizer, default_img_schedule, 600  # epoch count hard-coded by the default recipe
    if finetune:
        # Finetuning: only optimize the parameters that are still trainable.
        param_list = [p for p in model.parameters() if p.requires_grad]
        optimizer = SGD(param_list, lr, momentum=0.95, weight_decay=weight_decay)
    else:
        # Full training: the classifier head gets its own fixed learning
        # rate, while the backbone uses the scheduled one.
        classifier_params_name = ["linear.bias", "linear.weight"]
        classifier_params = [v for k, v in model.named_parameters() if k in classifier_params_name]
        base_params = [v for k, v in model.named_parameters() if k not in classifier_params_name]
        optimizer = SGD(
            [
                {"params": base_params},
                {"params": classifier_params, "lr": 0.01},
            ],
            momentum=0.9,
            lr=lr,
            weight_decay=weight_decay,
        )
    # Step decay: multiply the LR by step_lr_gamma every step_lr epochs.
    schedule = lr_scheduler.StepLR(optimizer, step_size=step_lr, gamma=step_lr_gamma)
    return optimizer, schedule, n_epochs


def get_scheduler_for_model(model, dataset):
    # Map a model name to the scheduling class that holds its training recipe.
    if model == "qsenn":
        return QSENNScheduler(dataset)
    elif model == "sldd":
        return OptimizationScheduler(dataset)
    raise ValueError(f"Unknown model type: {model}")
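

# Minimal usage sketch, kept out of the library code path. It assumes a model
# exposing a trainable `linear` classifier head and that "qsenn" paired with
# the hypothetical dataset key "CUB2011" is a valid combination for
# QSENNScheduler; adjust both to your setup.
if __name__ == "__main__":
    import torchvision

    model = torchvision.models.resnet50(num_classes=200)
    scheduling = get_scheduler_for_model("qsenn", "CUB2011")  # dataset key is an assumption
    optimizer, schedule, n_epochs = get_optimizer(model, scheduling)
    for epoch in range(n_epochs):
        # ... run one training epoch here, then advance the LR schedule ...
        schedule.step()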