import os
import sys
import inspect
import warnings
import functools

from pathlib import Path

import torch
from diffq import restore_quantized_state

# Make project-local imports resolvable relative to the working directory.
now_dir = os.getcwd()
sys.path.append(now_dir)

from main.configs.config import Config

translations = Config().translations

def load_model(path_or_package, strict=False):
    """Instantiate a model from a serialized package dict or a checkpoint path."""
    if isinstance(path_or_package, dict):
        package = path_or_package
    elif isinstance(path_or_package, (str, Path)):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            path = path_or_package
            package = torch.load(path, map_location="cpu")
    else:
        raise ValueError(f"{translations['type_not_valid']} {path_or_package}.")

    klass = package["klass"]
    args = package["args"]
    kwargs = package["kwargs"]

    if strict:
        model = klass(*args, **kwargs)
    else:
        # Drop any saved kwargs the current constructor no longer accepts.
        sig = inspect.signature(klass)
        for key in list(kwargs):
            if key not in sig.parameters:
                warnings.warn(translations["del_parameter"] + key)
                del kwargs[key]
        model = klass(*args, **kwargs)

    state = package["state"]
    set_state(model, state)
    return model
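
# Hedged usage sketch (not part of the original module). "model.th" is a
# hypothetical checkpoint path; any file saved in the package format read
# above would work the same way:
#
#     model = load_model("model.th")               # lenient: prunes stale kwargs
#     model = load_model("model.th", strict=True)  # passes saved kwargs through as-is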

def set_state(model, state, quantizer=None):
    if state.get("__quantized"):
        # Quantized checkpoints are restored through diffq (or a custom quantizer).
        if quantizer is not None:
            quantizer.restore_quantized_state(model, state["quantized"])
        else:
            restore_quantized_state(model, state)
    else:
        model.load_state_dict(state)
    return state
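
# Hedged sketch of the package layout that load_model/set_state expect,
# inferred from the key lookups above (not an official schema):
#
#     package = {
#         "klass": MyModel,             # class to instantiate
#         "args": (),                   # positional __init__ arguments
#         "kwargs": {"channels": 32},   # keyword __init__ arguments
#         "state": model.state_dict(),  # or a diffq state with "__quantized" set
#     }
#     torch.save(package, "model.th")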

def capture_init(init):
    """Record the (args, kwargs) passed to __init__ on the instance."""
    @functools.wraps(init)
    def __init__(self, *args, **kwargs):
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)
    return __init__
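
# Hedged usage sketch (assumption, not from the original file): a model class
# decorates its __init__ with capture_init so the constructor arguments are
# recorded on the instance and can later be re-packaged for load_model:
#
#     class MyModel(torch.nn.Module):
#         @capture_init
#         def __init__(self, channels=64):
#             super().__init__()
#             self.channels = channels
#
#     m = MyModel(channels=32)
#     m._init_args_kwargs   # ((), {"channels": 32})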