Spaces:
Sleeping
Sleeping
from torch.optim import SGD, lr_scheduler | |
from configs.qsenn_training_params import QSENNScheduler | |
from configs.sldd_training_params import OptimizationScheduler | |
from training.img_net import get_default_img_schedule, get_default_img_optimizer | |
def get_optimizer(model, schedulingClass, *, classifier_lr=0.01,
                  classifier_param_names=("linear.bias", "linear.weight")):
    """Build the SGD optimizer, LR schedule and epoch count for a training phase.

    Args:
        model: The ``torch.nn.Module`` whose parameters are optimized.
        schedulingClass: Object whose ``get_params()`` returns the tuple
            ``(lr, weight_decay, step_lr, step_lr_gamma, n_epochs, finetune)``.
        classifier_lr: Learning rate of the classifier parameter group used in
            non-finetune training (previously hard-coded to ``0.01``).
        classifier_param_names: Parameter names, as yielded by
            ``model.named_parameters()``, that form the classifier group in
            non-finetune training.

    Returns:
        ``(optimizer, schedule, n_epochs)``. If ``lr`` is ``None`` the project's
        default ImageNet optimizer and schedule are returned with a fixed epoch
        count of 600 instead of ``n_epochs``.
    """
    lr, weight_decay, step_lr, step_lr_gamma, n_epochs, finetune = schedulingClass.get_params()
    print("Optimizer LR set to ", lr)
    if lr is None:  # Dense Training on ImageNet
        print("Learning rate is None, using Default Recipe for Resnet50")
        default_img_optimizer = get_default_img_optimizer(model)
        default_img_schedule = get_default_img_schedule(default_img_optimizer)
        return default_img_optimizer, default_img_schedule, 600

    if finetune:
        # Only hand the optimizer parameters that still have requires_grad=True.
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        optimizer = SGD(trainable_params, lr, momentum=0.95,
                        weight_decay=weight_decay)
    else:
        # Split parameters into base and classifier groups in a single pass;
        # the classifier group overrides the base learning rate.
        classifier_names = set(classifier_param_names)
        base_params, classifier_params = [], []
        for name, param in model.named_parameters():
            if name in classifier_names:
                classifier_params.append(param)
            else:
                base_params.append(param)
        optimizer = SGD([
            {"params": base_params},
            {"params": classifier_params, "lr": classifier_lr},
        ], momentum=0.9, lr=lr, weight_decay=weight_decay)

    # Step decay: multiply every group's LR by step_lr_gamma every step_lr epochs.
    schedule = lr_scheduler.StepLR(optimizer, step_size=step_lr, gamma=step_lr_gamma)
    return optimizer, schedule, n_epochs
def get_scheduler_for_model(model, dataset):
    """Return the training-parameter scheduler for the given method name.

    Args:
        model: Method identifier, either ``"qsenn"`` or ``"sldd"``.
        dataset: Forwarded unchanged to the scheduler constructor.

    Returns:
        ``QSENNScheduler(dataset)`` for ``"qsenn"`` or
        ``OptimizationScheduler(dataset)`` for ``"sldd"``.

    Raises:
        ValueError: If ``model`` is not one of the supported identifiers
            (instead of silently returning ``None``).
    """
    # Plain if-chain (rather than a dict of constructors) so each scheduler
    # class is only looked up when its identifier actually matches.
    if model == "qsenn":
        return QSENNScheduler(dataset)
    if model == "sldd":
        return OptimizationScheduler(dataset)
    raise ValueError(f"Unknown model {model!r}; expected 'qsenn' or 'sldd'")