import os
import sys
import torch
import inspect
import warnings
import functools
from pathlib import Path

sys.path.append(os.getcwd())  # make the project root importable
from main.configs.config import Config

translations = Config().translations  # localized message strings

def load_model(path_or_package, strict=False):
    """Rebuild a model from a package dict, or from a checkpoint path holding such a dict."""
    if isinstance(path_or_package, dict): package = path_or_package
    elif isinstance(path_or_package, (str, Path)):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            package = torch.load(path_or_package, map_location="cpu")
    else: raise ValueError(f"{translations['type_not_valid']} {path_or_package}.")

    klass = package["klass"]
    args = package["args"]
    kwargs = package["kwargs"]

    if strict: model = klass(*args, **kwargs)
    else:
        # Drop saved kwargs the constructor no longer accepts, so older checkpoints still load.
        sig = inspect.signature(klass)
        for key in list(kwargs):
            if key not in sig.parameters:
                warnings.warn(translations["del_parameter"] + key)
                del kwargs[key]
        model = klass(*args, **kwargs)

    state = package["state"]
    set_state(model, state)
    return model
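
# For reference, a package accepted by load_model carries at least these keys
# (a sketch inferred from the lookups above, not an authoritative spec):
#   {"klass": ModelClass, "args": (positional, ...),
#    "kwargs": {"name": value, ...}, "state": state_dict_or_quantized_state}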

def restore_quantized_state(model, state):
    """Rebuild the quantizer recorded in the checkpoint and restore the quantized weights onto the model."""
    assert "meta" in state
    quantizer = state["meta"]["klass"](model, **state["meta"]["init_kwargs"])
    quantizer.restore_quantized_state(state)
    quantizer.detach()

def set_state(model, state, quantizer=None):
    """Load a state onto the model, dispatching between a plain state dict and a quantized state."""
    if state.get("__quantized"):
        if quantizer is not None: quantizer.restore_quantized_state(model, state["quantized"])
        else: restore_quantized_state(model, state)
    else: model.load_state_dict(state)
    return state
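
# Shape of a quantized state as read by set_state/restore_quantized_state above
# (inferred from the key accesses, not an authoritative spec):
#   {"__quantized": True, "quantized": <quantizer payload>,
#    "meta": {"klass": QuantizerClass, "init_kwargs": {...}}}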

def capture_init(init):
    """Decorator for __init__ that records the constructor arguments on the instance, so load_model can re-instantiate the class later."""
    @functools.wraps(init)
    def __init__(self, *args, **kwargs):
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)
    return __init__
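
# Minimal round-trip sketch (illustrative addition, not part of the original
# module): DemoModel and its "hidden" parameter are hypothetical stand-ins.
# capture_init records the constructor arguments, from which a package dict is
# built by hand and fed back through load_model.
if __name__ == "__main__":
    class DemoModel(torch.nn.Module):
        @capture_init
        def __init__(self, hidden=8):
            super().__init__()
            self.linear = torch.nn.Linear(hidden, hidden)

        def forward(self, x):
            return self.linear(x)

    model = DemoModel(hidden=4)
    args, kwargs = model._init_args_kwargs  # recorded by capture_init
    package = {"klass": DemoModel, "args": args, "kwargs": kwargs, "state": model.state_dict()}
    restored = load_model(package)
    assert isinstance(restored, DemoModel)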