import os
import sys
import torch
import inspect
import warnings
import functools

from pathlib import Path
from diffq import restore_quantized_state

now_dir = os.getcwd()
# Make the project root (assumed to be the current working directory) importable
# so the ``main.configs`` import below resolves.
sys.path.append(now_dir)

from main.configs.config import Config

# Message strings used for the errors/warnings raised by this module.
translations = Config().translations


def load_model(path_or_package, strict=False):
    """Rebuild a model from a serialized package and load its weights.

    Args:
        path_or_package: Either an already-loaded package ``dict`` or a
            ``str``/``Path`` pointing at a file written with ``torch.save``.
            The package must provide the keys ``"klass"`` (the model class),
            ``"args"``, ``"kwargs"`` (constructor arguments) and ``"state"``
            (the weights passed to :func:`set_state`).
        strict: If True, call ``klass(*args, **kwargs)`` as-is. If False,
            drop (with a warning) any keyword argument that is not a named
            parameter of ``klass``'s signature before constructing it.

    Returns:
        The constructed model with its state restored.

    Raises:
        ValueError: If ``path_or_package`` is neither a dict nor a path.
    """
    if isinstance(path_or_package, dict):
        package = path_or_package
    elif isinstance(path_or_package, (str, Path)):
        load_kwargs = {"map_location": "cpu"}
        # torch>=2.6 defaults ``weights_only=True``, which refuses to unpickle
        # the model class stored under "klass". Opt out explicitly when the
        # installed torch exposes the flag; versions without it always perform
        # a full unpickle anyway.
        # SECURITY: full pickle deserialization can execute arbitrary code —
        # only load packages from trusted sources.
        if "weights_only" in inspect.signature(torch.load).parameters:
            load_kwargs["weights_only"] = False
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            package = torch.load(path_or_package, **load_kwargs)
    else:
        raise ValueError(f"{translations['type_not_valid']} {path_or_package}.")

    klass = package["klass"]
    args = package["args"]
    # Work on a copy so filtering below never mutates the caller's package.
    kwargs = dict(package["kwargs"])

    if not strict:
        # Tolerate packages saved by a different version of ``klass``: drop
        # keyword arguments its constructor no longer declares.
        sig = inspect.signature(klass)
        for key in list(kwargs):
            if key not in sig.parameters:
                warnings.warn(translations["del_parameter"] + key)
                del kwargs[key]

    model = klass(*args, **kwargs)
    set_state(model, package["state"])
    return model


def set_state(model, state, quantizer=None):
    """Load ``state`` into ``model`` and return ``state``.

    If ``state`` carries a truthy ``"__quantized"`` entry, it is restored
    through ``quantizer.restore_quantized_state`` (fed ``state["quantized"]``)
    when a quantizer is given, otherwise through diffq's
    ``restore_quantized_state`` (fed the whole ``state``). Any other state is
    passed to ``model.load_state_dict``.
    """
    if state.get("__quantized"):
        if quantizer is not None:
            quantizer.restore_quantized_state(model, state["quantized"])
        else:
            restore_quantized_state(model, state)
    else:
        model.load_state_dict(state)
    return state


def capture_init(init):
    """Decorate an ``__init__`` so it records its call arguments.

    Before running the wrapped ``init``, the positional and keyword arguments
    are stored on the instance as ``self._init_args_kwargs = (args, kwargs)``
    (presumably so a serializer can later rebuild the object — confirm with
    callers).
    """

    @functools.wraps(init)
    def __init__(self, *args, **kwargs):
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)

    return __init__