Spaces:
Runtime error
Runtime error
File size: 1,722 Bytes
1b2a9b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
import os
import importlib
from swapae.optimizers.base_optimizer import BaseOptimizer
import torch
def find_optimizer_using_name(optimizer_name):
    """Import ``swapae.optimizers.<optimizer_name>_optimizer`` and return its optimizer class.

    The imported module is scanned for an attribute whose name, compared
    case-insensitively, equals ``optimizer_name`` with underscores removed
    followed by ``optimizer``. For example, ``"foo_bar"`` matches a class
    named ``FooBarOptimizer``. The attribute must be a class that subclasses
    ``BaseOptimizer``. If several attributes match, the last one found in the
    module's ``__dict__`` wins.

    Args:
        optimizer_name: Base name of the optimizer module, without the
            ``_optimizer`` suffix.

    Returns:
        The matching optimizer class (not an instance).

    Raises:
        SystemExit: With a non-zero status when the module contains no
            matching ``BaseOptimizer`` subclass. Any error from
            ``importlib.import_module`` (e.g. a missing module) propagates
            unchanged.
    """
    optimizer_filename = "swapae.optimizers." + optimizer_name + "_optimizer"
    optimizerlib = importlib.import_module(optimizer_filename)

    target_optimizer_name = optimizer_name.replace('_', '') + 'optimizer'
    target_lower = target_optimizer_name.lower()

    optimizer = None
    for name, cls in optimizerlib.__dict__.items():
        # issubclass() raises TypeError for non-class objects, so make sure a
        # same-named function or instance in the module cannot crash the scan.
        if (name.lower() == target_lower
                and isinstance(cls, type)
                and issubclass(cls, BaseOptimizer)):
            optimizer = cls

    if optimizer is None:
        import sys

        print("In %s.py, there should be a subclass of BaseOptimizer with class name that matches %s in lowercase." % (optimizer_filename, target_optimizer_name))
        # Exit with a failure status. The previous exit(0) signalled success
        # to shells/CI and relied on the interactive-only `site` helper.
        sys.exit(1)
    return optimizer
def get_option_setter(optimizer_name):
    """Resolve ``optimizer_name`` to its optimizer class and hand back that
    class's ``modify_commandline_options`` attribute (not called here)."""
    return find_optimizer_using_name(optimizer_name).modify_commandline_options
def create_optimizer(opt, model):
    """Create an optimizer instance selected by ``opt.optimizer``.

    This is the main interface between this package and 'train.py'/'test.py'.
    The optimizer class is resolved with ``find_optimizer_using_name`` and
    instantiated with ``model`` as its only argument.

    Args:
        opt: Options object. Only its ``optimizer`` attribute (the base name
            of the optimizer module, without the ``_optimizer`` suffix) is
            read here.
        model: Passed as the sole positional argument to the optimizer
            class constructor.

    Returns:
        A new instance of the resolved optimizer class.

    Example::

        optimizer = create_optimizer(opt, model)
    """
    optimizer_class = find_optimizer_using_name(opt.optimizer)
    return optimizer_class(model)
|