File size: 1,684 Bytes
98bb602
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
import sys
import torch
import inspect
import warnings
import functools

from pathlib import Path
from diffq import restore_quantized_state

# Put the current working directory on the import path so the project-local
# ``main`` package imported below can be resolved.
# NOTE(review): this assumes the process is launched from the project root —
# confirm against the launch scripts.
now_dir = os.getcwd()
sys.path.append(now_dir)

from main.configs.config import Config

# Message table read from the project Config (presumably localized strings);
# the functions below look up keys such as "type_not_valid" and
# "del_parameter" in it.
translations = Config().translations


def load_model(path_or_package, strict=False):
    """Rebuild a model from a serialized package and restore its weights.

    Args:
        path_or_package: Either an already-loaded package ``dict`` or a path
            (``str`` / ``Path``) to a file readable by ``torch.load``. The
            package must provide the keys ``"klass"`` (the model class),
            ``"args"`` / ``"kwargs"`` (constructor arguments) and ``"state"``
            (the weights passed to ``set_state``).
        strict: When True, call ``klass`` with the stored kwargs unchanged.
            When False, drop (with a warning) every keyword argument that
            ``klass``'s signature does not name.

    Returns:
        The constructed model with its state loaded via ``set_state``.

    Raises:
        ValueError: If ``path_or_package`` is neither a dict nor a path.
    """
    if isinstance(path_or_package, dict):
        package = path_or_package
    elif isinstance(path_or_package, (str, Path)):
        # SECURITY: the package pickles a class object ("klass"), so it can only
        # be loaded with weights_only=False (torch >= 2.6 defaults to True and
        # would reject it). This executes arbitrary pickle code — only load
        # checkpoint files from trusted sources.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            package = torch.load(path_or_package, map_location="cpu", weights_only=False)
    else:
        raise ValueError(f"{translations['type_not_valid']} {path_or_package}.")

    klass = package["klass"]
    args = package["args"]
    # Work on a copy so filtering below never mutates the caller's package.
    kwargs = dict(package["kwargs"])

    if not strict:
        # Drop kwargs the current constructor no longer accepts (e.g. a
        # checkpoint saved by an older version of the model class).
        params = inspect.signature(klass).parameters
        for key in list(kwargs):
            if key not in params:
                warnings.warn(translations["del_parameter"] + key)
                del kwargs[key]

    model = klass(*args, **kwargs)
    set_state(model, package["state"])

    return model


def set_state(model, state, quantizer=None):
    """Load ``state`` into ``model``, handling diffq-quantized states.

    Args:
        model: Object receiving the weights (via ``load_state_dict`` for plain
            states, or via a quantized-state restore otherwise).
        state: Either a plain state dict, or a quantized state flagged by a
            truthy ``"__quantized"`` entry.
        quantizer: Optional object whose ``restore_quantized_state`` is used
            for quantized states; it receives only ``state["quantized"]``.
            When omitted, diffq's ``restore_quantized_state`` receives the
            whole ``state``.

    Returns:
        The same ``state`` object that was passed in.
    """
    if not state.get("__quantized"):
        model.load_state_dict(state)
    elif quantizer is None:
        restore_quantized_state(model, state)
    else:
        quantizer.restore_quantized_state(model, state["quantized"])

    return state


def capture_init(init):
    """Wrap an ``__init__`` so each instance records its constructor arguments.

    Before the wrapped ``init`` runs, the positional arguments (a tuple) and
    keyword arguments (a dict) are stored on the instance as
    ``self._init_args_kwargs = (args, kwargs)``.
    NOTE(review): presumably consumed elsewhere when serializing a model as
    ``klass`` / ``args`` / ``kwargs`` — confirm with the caller.
    """
    @functools.wraps(init)
    def recording_init(self, *args, **kwargs):
        # Stored first, so the attribute already exists while ``init`` runs.
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)

    return recording_init