| import sys | |
| import warnings | |
| from bisect import bisect_right | |
| import torch | |
| import torch.nn as nn | |
| from torch.optim import lr_scheduler | |
| import step1x3d_geometry | |
def get_scheduler(name):
    """Return the attribute ``name`` of ``torch.optim.lr_scheduler``.

    Any attribute of that module is accepted (the lookup is a plain
    ``hasattr``/``getattr``), so callers are expected to pass a scheduler
    class name such as ``"StepLR"``.

    Raises:
        NotImplementedError: if ``torch.optim.lr_scheduler`` has no attribute
            called ``name``. The message names the offending value so a typo
            in a config file is easy to spot.
    """
    if not hasattr(lr_scheduler, name):
        raise NotImplementedError(
            f"Unknown learning-rate scheduler {name!r}: "
            "torch.optim.lr_scheduler has no such attribute"
        )
    return getattr(lr_scheduler, name)
def getattr_recursive(m, attr):
    """Resolve a dotted attribute path (e.g. ``"encoder.layer1"``) on ``m``.

    Each ``.``-separated segment is looked up with ``getattr`` from left to
    right; an ``AttributeError`` from any segment propagates unchanged.
    """
    head, dot, tail = attr.partition(".")
    child = getattr(m, head)
    # No separator left means ``head`` was the final segment.
    return getattr_recursive(child, tail) if dot else child
def get_parameters(model, name):
    """Return the trainable tensors found at dotted path ``name`` in ``model``.

    - ``nn.Parameter`` -> the parameter object itself (torch.optim wraps a
      lone tensor into a list when it builds a param group).
    - ``nn.Module`` -> ``module.parameters()`` (an iterator).
    - anything else -> an empty list.
    """
    target = getattr_recursive(model, name)
    if isinstance(target, nn.Parameter):
        return target
    if isinstance(target, nn.Module):
        return target.parameters()
    return []
def parse_optimizer(config, model):
    """Instantiate an optimizer for ``model`` as described by ``config``.

    Attributes read from ``config`` (presumably an OmegaConf ``DictConfig``;
    confirm against callers):

    - ``params`` (optional): mapping of dotted sub-module/parameter path ->
      extra param-group kwargs. When present, one param group is built per
      entry and tagged with ``"name"``.
    - ``only_requires_grad`` (optional): when truthy and ``params`` is absent,
      only parameters with ``requires_grad=True`` are optimized.
    - ``name``: optimizer class name. ``"FusedAdam"`` is looked up in
      ``apex.optimizers``, ``"Prodigy"`` in ``prodigyopt``, and anything else
      in ``torch.optim``.
    - ``args``: keyword arguments forwarded to the optimizer constructor.

    Returns:
        The constructed optimizer instance.
    """
    if hasattr(config, "params"):
        params = [
            {"params": get_parameters(model, name), "name": name, **args}
            for name, args in config.params.items()
        ]
        step1x3d_geometry.debug(f"Specify optimizer params: {config.params}")
    elif getattr(config, "only_requires_grad", False):
        params = [p for p in model.parameters() if p.requires_grad]
    else:
        params = model.parameters()

    # Optional backends are imported lazily so they are only required when
    # actually selected. The submodule is imported explicitly rather than
    # relying on ``apex/__init__`` to import it as a side effect.
    if config.name == "FusedAdam":
        import apex.optimizers

        optim_module = apex.optimizers
    elif config.name == "Prodigy":
        import prodigyopt

        optim_module = prodigyopt
    else:
        optim_module = torch.optim
    return getattr(optim_module, config.name)(params, **config.args)
def parse_scheduler_to_instance(config, optimizer):
    """Recursively build a bare ``lr_scheduler`` instance from ``config``.

    ``config.name`` selects the scheduler:

    - ``"ChainedScheduler"`` -> ``ChainedScheduler`` over the schedulers
      built from ``config.schedulers``.
    - ``"Sequential"`` -> ``SequentialLR`` over ``config.schedulers`` with
      ``config.milestones``.
    - anything else -> ``torch.optim.lr_scheduler.<name>(optimizer, **config.args)``.
    """
    kind = config.name
    if kind not in ("ChainedScheduler", "Sequential"):
        return getattr(lr_scheduler, kind)(optimizer, **config.args)

    # Composite schedulers: build every child against the same optimizer.
    children = [
        parse_scheduler_to_instance(child_conf, optimizer)
        for child_conf in config.schedulers
    ]
    if kind == "ChainedScheduler":
        return lr_scheduler.ChainedScheduler(children)
    return lr_scheduler.SequentialLR(
        optimizer, children, milestones=config.milestones
    )
def parse_scheduler(config, optimizer):
    """Build ``{"scheduler": <lr_scheduler>, "interval": <str>}`` from ``config``.

    The returned dict shape looks like a PyTorch Lightning lr-scheduler
    config; presumably that is its consumer (confirm against callers).

    ``config`` must support both ``config.get(...)`` and attribute access
    (e.g. an OmegaConf ``DictConfig``). Fields read:

    - ``interval`` (optional, default ``"epoch"``): ``"epoch"`` or ``"step"``.
    - ``name``: ``"SequentialLR"``, ``"ChainedScheduler"``, or any attribute
      name of ``torch.optim.lr_scheduler`` (resolved via ``get_scheduler``).
    - ``schedulers``: child configs (composite names only). Each child is
      parsed recursively and only its scheduler instance is kept; the
      child's own ``interval`` is validated but otherwise ignored.
    - ``milestones``: passed to ``SequentialLR``.
    - ``args``: kwargs for a leaf scheduler's constructor.

    Raises:
        ValueError: if ``interval`` is not ``"epoch"`` or ``"step"``. This is
            an explicit check rather than ``assert``, so it also holds under
            ``python -O``.
    """
    interval = config.get("interval", "epoch")
    if interval not in ("epoch", "step"):
        raise ValueError(
            f"Scheduler interval must be 'epoch' or 'step', got {interval!r}"
        )

    if config.name in ("SequentialLR", "ChainedScheduler"):
        children = [
            parse_scheduler(conf, optimizer)["scheduler"]
            for conf in config.schedulers
        ]
        if config.name == "SequentialLR":
            instance = lr_scheduler.SequentialLR(
                optimizer, children, milestones=config.milestones
            )
        else:
            instance = lr_scheduler.ChainedScheduler(children)
    else:
        instance = get_scheduler(config.name)(optimizer, **config.args)

    return {"scheduler": instance, "interval": interval}