import importlib
import os
import sys
from typing import Callable, Dict, Union

import numpy as np
import torch
import yaml

def merge_a_into_b(a, b):
    """Recursively merge dict ``a`` into dict ``b`` in place.

    Values from ``a`` take precedence; nested dicts are merged key by key.
    """
    for k, v in a.items():
        if isinstance(v, dict) and k in b:
            assert isinstance(
                b[k], dict
            ), "Cannot inherit key '{}' from base!".format(k)
            merge_a_into_b(v, b[k])
        else:
            b[k] = v
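
# Usage sketch of the merge semantics; the config values below are
# hypothetical, not taken from this module:
#
#   base = {"optimizer": {"type": "Adam", "args": {"lr": 1e-3}}}
#   override = {"optimizer": {"args": {"lr": 5e-4}}}
#   merge_a_into_b(override, base)
#   # base == {"optimizer": {"type": "Adam", "args": {"lr": 5e-4}}}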


def load_config(config_file):
    """Load a YAML config, resolving ``inherit_from`` recursively.

    A config may name a base file via the ``inherit_from`` key (resolved
    relative to its own directory); the current file's values are merged
    on top of the base values.
    """
    with open(config_file, "r") as reader:
        config = yaml.load(reader, Loader=yaml.FullLoader)
    if "inherit_from" in config:
        base_config_file = config["inherit_from"]
        base_config_file = os.path.join(
            os.path.dirname(config_file), base_config_file
        )
        assert not os.path.samefile(config_file, base_config_file), \
            "Config file cannot inherit from itself!"
        base_config = load_config(base_config_file)
        del config["inherit_from"]
        merge_a_into_b(config, base_config)
        return base_config
    return config
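
# A sketch of the inheritance mechanism; file names and keys are hypothetical:
#
#   # configs/base.yaml
#   model:
#     type: torch.nn.Linear
#     args: {in_features: 8, out_features: 2}
#
#   # configs/exp1.yaml
#   inherit_from: base.yaml
#   model:
#     args: {out_features: 4}
#
#   config = load_config("configs/exp1.yaml")
#   # config["model"]["args"] == {"in_features": 8, "out_features": 4}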


def get_cls_from_str(string, reload=False):
    """Resolve a dotted path such as "package.module.Class" to the class object."""
    module_name, cls_name = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module_name)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module_name, package=None), cls_name)
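
# Usage sketch:
#
#   linear_cls = get_cls_from_str("torch.nn.Linear")
#   layer = linear_cls(in_features=8, out_features=2)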


def init_obj_from_dict(config, **kwargs):
    """Instantiate ``config["type"]`` with ``config["args"]`` plus ``kwargs``.

    Any other key whose value is a dict is itself instantiated recursively
    and passed as a keyword argument, unless overridden by ``kwargs``.
    """
    obj_args = config["args"].copy()
    obj_args.update(kwargs)
    for k in config:
        if (k not in ["type", "args"] and isinstance(config[k], dict)
                and k not in kwargs):
            obj_args[k] = init_obj_from_dict(config[k])
    try:
        obj = get_cls_from_str(config["type"])(**obj_args)
        return obj
    except Exception:
        print(f"Initializing {config} failed, detailed error stack:")
        raise
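
# Usage sketch with a hypothetical config:
#
#   config = {"type": "torch.nn.Linear",
#             "args": {"in_features": 8, "out_features": 2}}
#   layer = init_obj_from_dict(config)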


def init_model_from_config(config, print_fn=sys.stdout.write):
    """Build a model from a (possibly nested) config dict.

    Every dict-valued key besides "type", "args" and "pretrained" is treated
    as a sub-model config; a sub-model may carry its own "pretrained"
    checkpoint, which is loaded before the parent model is assembled.
    """
    kwargs = {}
    for k in config:
        if k not in ["type", "args", "pretrained"] and isinstance(config[k], dict):
            sub_model = init_model_from_config(config[k], print_fn)
            if "pretrained" in config[k]:
                load_pretrained_model(sub_model,
                                      config[k]["pretrained"],
                                      print_fn)
            kwargs[k] = sub_model
    model = init_obj_from_dict(config, **kwargs)
    return model
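
# Sketch of a nested model config; the class names and checkpoint path are
# hypothetical:
#
#   config = {
#       "type": "models.CaptioningModel",
#       "args": {"hidden_size": 256},
#       "encoder": {
#           "type": "models.CnnEncoder",
#           "args": {"embed_dim": 256},
#           "pretrained": "experiments/encoder.pth",
#       },
#   }
#   model = init_model_from_config(config)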


def merge_load_state_dict(state_dict,
                          model: torch.nn.Module,
                          output_fn: Callable = sys.stdout.write):
    """Load the entries of ``state_dict`` matching ``model`` in name and shape.

    Mismatched keys are reported via ``output_fn`` and skipped; the set of
    loaded keys is returned.
    """
    model_dict = model.state_dict()
    pretrained_dict = {}
    mismatch_keys = []
    for key, value in state_dict.items():
        if key in model_dict and model_dict[key].shape == value.shape:
            pretrained_dict[key] = value
        else:
            mismatch_keys.append(key)
    output_fn(f"Loading pre-trained model, with mismatched keys {mismatch_keys}")
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict, strict=True)
    return pretrained_dict.keys()
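
# Usage sketch: keys absent from the model or with different shapes are
# skipped rather than raising; the checkpoint path is hypothetical:
#
#   state_dict = torch.load("experiments/model.pth", map_location="cpu")
#   loaded_keys = merge_load_state_dict(state_dict, model, print)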


def load_pretrained_model(model: torch.nn.Module,
                          pretrained: Union[str, Dict],
                          output_fn: Callable = sys.stdout.write):
    """Load pre-trained weights from a checkpoint path or an in-memory dict."""
    if not isinstance(pretrained, dict) and not os.path.exists(pretrained):
        output_fn(f"Pretrained file {pretrained} does not exist!")
        return

    # A model may implement its own loading logic.
    if hasattr(model, "load_pretrained"):
        model.load_pretrained(pretrained, output_fn)
        return

    if isinstance(pretrained, dict):
        state_dict = pretrained
    else:
        state_dict = torch.load(pretrained, map_location="cpu")

    # Checkpoints saved as {"model": state_dict, ...} are unwrapped first.
    if "model" in state_dict:
        state_dict = state_dict["model"]

    merge_load_state_dict(state_dict, model, output_fn)
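
# Usage sketch (hypothetical path); a checkpoint saved as
# {"model": state_dict, "optimizer": optimizer_state} is unwrapped
# automatically:
#
#   load_pretrained_model(model, "experiments/checkpoint.pth", print)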


def pad_sequence(data, pad_value=0):
    """Pad a list of variable-length sequences into a batch-first tensor.

    Returns the padded tensor and an array of the original lengths.
    """
    if isinstance(data[0], (np.ndarray, torch.Tensor)):
        data = [torch.as_tensor(arr) for arr in data]
    padded_seq = torch.nn.utils.rnn.pad_sequence(data,
                                                 batch_first=True,
                                                 padding_value=pad_value)
    length = np.array([x.shape[0] for x in data])
    return padded_seq, length
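
# Usage sketch:
#
#   batch = [np.zeros((5, 64)), np.zeros((3, 64))]
#   padded, lengths = pad_sequence(batch)
#   # padded.shape == (2, 5, 64); lengths == array([5, 3])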