import os
import sys
import torch
import inspect
import warnings
import functools
from pathlib import Path
from diffq import restore_quantized_state

# Make the project root importable so that main.configs resolves.
now_dir = os.getcwd()
sys.path.append(now_dir)

from main.configs.config import Config

translations = Config().translations

def load_model(path_or_package, strict=False):
    """Rebuild a model from a serialized package (a dict, or a path to a torch checkpoint)."""
    if isinstance(path_or_package, dict):
        package = path_or_package
    elif isinstance(path_or_package, (str, Path)):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            path = path_or_package
            package = torch.load(path, map_location="cpu")
    else:
        raise ValueError(f"{translations['type_not_valid']} {path_or_package}.")

    klass = package["klass"]
    args = package["args"]
    kwargs = package["kwargs"]

    if strict:
        model = klass(*args, **kwargs)
    else:
        # Drop saved kwargs that the current class signature no longer accepts.
        sig = inspect.signature(klass)
        for key in list(kwargs):
            if key not in sig.parameters:
                warnings.warn(translations["del_parameter"] + key)
                del kwargs[key]
        model = klass(*args, **kwargs)

    state = package["state"]
    set_state(model, state)
    return model

def set_state(model, state, quantizer=None):
    """Load a state dict into the model, handling diffq-quantized checkpoints."""
    if state.get("__quantized"):
        if quantizer is not None:
            quantizer.restore_quantized_state(model, state["quantized"])
        else:
            restore_quantized_state(model, state)
    else:
        model.load_state_dict(state)
    return state

def capture_init(init):
    """Decorator for __init__ that records its (args, kwargs) in self._init_args_kwargs for later re-serialization."""
    @functools.wraps(init)
    def __init__(self, *args, **kwargs):
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)
    return __init__
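
# Illustrative usage sketch (assumption: "MyModel" and the checkpoint name are
# hypothetical placeholders, not part of this module). A class whose __init__ is
# wrapped with capture_init records its constructor arguments, so a package that
# stores klass/args/kwargs/state can be rebuilt later with load_model:
#
#   class MyModel(torch.nn.Module):
#       @capture_init
#       def __init__(self, channels=64, depth=4):
#           super().__init__()
#           ...
#
#   model = load_model("model_checkpoint.th")  # reconstructs MyModel and loads its state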