File size: 1,837 Bytes
e0202f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os
import sys
import torch
import inspect
import warnings
import functools

from pathlib import Path

# Make the repository root importable so the `main.configs` package below
# resolves when this file is executed from the project root.
sys.path.append(os.getcwd())

from main.configs.config import Config
# Localized message strings used by the warnings/errors in this module.
translations = Config().translations

def load_model(path_or_package, strict=False):
    """Instantiate a model from a serialized package.

    Args:
        path_or_package: either an already-loaded package dict, or a path
            (str or Path) to a torch-serialized package.  The package must
            contain the keys "klass", "args", "kwargs" and "state".
        strict: when False (default), constructor kwargs that the class no
            longer accepts are dropped with a warning instead of raising,
            so old checkpoints survive signature changes.

    Returns:
        The instantiated model with its state restored via ``set_state``.

    Raises:
        ValueError: if ``path_or_package`` is neither a dict nor a path.
    """
    if isinstance(path_or_package, dict):
        package = path_or_package
    elif isinstance(path_or_package, (str, Path)):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # weights_only=False is required here: the package pickles the
            # model *class* itself, which the weights_only=True default of
            # torch>=2.6 would reject.  NOTE(review): this is a full pickle
            # load — only open checkpoints from trusted sources.
            package = torch.load(path_or_package, map_location="cpu", weights_only=False)
    else:
        raise ValueError(f"{translations['type_not_valid']} {path_or_package}.")
    klass = package["klass"]
    args = package["args"]
    kwargs = package["kwargs"]
    if not strict:
        # Drop kwargs the constructor no longer accepts so old checkpoints
        # still load after the class signature changed.
        sig = inspect.signature(klass)
        for key in list(kwargs):
            if key not in sig.parameters:
                warnings.warn(translations["del_parameter"] + key)
                del kwargs[key]
    model = klass(*args, **kwargs)
    state = package["state"]
    set_state(model, state)
    return model

def restore_quantized_state(model, state):
    """Re-apply a quantized state dict to ``model``.

    Args:
        model: the model instance to restore into.
        state: package state dict; must contain a "meta" entry holding the
            quantizer class ("klass") and its constructor kwargs
            ("init_kwargs").

    Raises:
        ValueError: if ``state`` has no "meta" entry.
    """
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would turn a malformed state into an obscure
    # KeyError further down.
    if "meta" not in state:
        raise ValueError('state is missing the required "meta" entry')
    meta = state["meta"]
    quantizer = meta["klass"](model, **meta["init_kwargs"])
    quantizer.restore_quantized_state(state)
    quantizer.detach()

def set_state(model, state, quantizer=None):
    """Load ``state`` into ``model``, handling quantized checkpoints.

    Args:
        model: target model.
        state: either a plain state dict, or a package flagged with
            "__quantized" that carries quantizer metadata.
        quantizer: optional quantizer instance; when given, it restores the
            "quantized" payload itself instead of rebuilding one from the
            state's metadata.

    Returns:
        The ``state`` argument, unchanged.
    """
    if not state.get("__quantized"):
        # Ordinary checkpoint: defer entirely to torch's state-dict loading.
        model.load_state_dict(state)
        return state
    if quantizer is None:
        restore_quantized_state(model, state)
    else:
        quantizer.restore_quantized_state(model, state["quantized"])
    return state

def capture_init(init):
    """Decorator for ``__init__`` that records its call arguments.

    The positional and keyword arguments are stashed on the instance as
    ``self._init_args_kwargs`` before the wrapped constructor runs, so the
    object can later be re-built with the exact construction arguments.
    """
    @functools.wraps(init)
    def wrapped(self, *args, **kwargs):
        # Record first, then delegate to the real constructor.
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)
    return wrapped