|
'''
Utilities for dealing with simple state dicts as npz files instead of pth files.
'''

import torch
from collections.abc import MutableMapping, Mapping

|
def load_from_numpy_dict(model, numpy_dict, prefix='', examples=None):
    '''
    Loads a model from numpy_dict using load_state_dict.
    Converts numpy types to torch types using the current state_dict
    of the model to determine types and devices for the tensors.
    Supports loading a subdict by prepending the given prefix to all keys.
    '''
    if prefix:
        if not prefix.endswith('.'):
            prefix = prefix + '.'
        numpy_dict = PrefixSubDict(numpy_dict, prefix)
    if examples is None:
        examples = model.state_dict()
    torch_state_dict = TorchTypeMatchingDict(numpy_dict, examples)
    model.load_state_dict(torch_state_dict)
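
# Minimal usage sketch (illustrative, not part of the module): assumes
# 'weights.npz' was previously written by save_to_numpy_dict for a model
# with matching keys.
#
#   import numpy
#   import torch
#   model = torch.nn.Linear(4, 2)
#   with numpy.load('weights.npz') as data:
#       load_from_numpy_dict(model, data)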
|
|
|
def save_to_numpy_dict(model, numpy_dict, prefix=''):
    '''
    Saves a model by copying tensors to numpy_dict.
    Converts torch types to numpy types using `t.detach().cpu().numpy()`.
    Supports saving a subdict by prepending the given prefix to all keys.
    '''
    if prefix:
        if not prefix.endswith('.'):
            prefix = prefix + '.'
    for k, v in model.state_dict().items():
        if isinstance(v, torch.Tensor):
            v = v.detach().cpu().numpy()
        numpy_dict[prefix + k] = v
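
# Minimal usage sketch (illustrative; the filename is hypothetical): gather
# the arrays into a plain dict, then write them with numpy.savez.
#
#   import numpy
#   import torch
#   arrays = {}
#   save_to_numpy_dict(torch.nn.Linear(4, 2), arrays)
#   numpy.savez('weights.npz', **arrays)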
|
|
|
class TorchTypeMatchingDict(Mapping):
    '''
    Provides a view of a dict of numpy values as torch tensors, where the
    types are converted to match the types and devices in the given
    dict of examples.
    '''
    def __init__(self, data, examples):
        self.data = data
        self.examples = examples
        self.cached_data = {}

    def __getitem__(self, key):
        if key in self.cached_data:
            return self.cached_data[key]
        val = self.data[key]
        if key not in self.examples:
            return val
        example = self.examples.get(key, None)
        example_type = type(example)
        if example is not None and type(val) != example_type:
            if isinstance(example, torch.Tensor):
                # Numpy arrays become torch tensors.
                val = torch.from_numpy(val)
            else:
                # Other values are cast to the example's type.
                val = example_type(val)
        if isinstance(example, torch.Tensor):
            # Match the dtype and device of the example tensor.
            val = val.to(dtype=example.dtype, device=example.device)
        self.cached_data[key] = val
        return val

    def __iter__(self):
        return iter(self.data)

    def __len__(self):
        return len(self.data)
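
# Minimal sketch of the type-matching view (values are made up for
# illustration): a float64 numpy array is converted to match the dtype
# and device of the example tensor.
#
#   import numpy
#   examples = {'w': torch.zeros(2, dtype=torch.float16)}
#   view = TorchTypeMatchingDict({'w': numpy.ones(2)}, examples)
#   view['w']   # tensor([1., 1.], dtype=torch.float16)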
|
|
|
class PrefixSubDict(MutableMapping):
    '''
    Provides a view of the subset of a dict where string keys begin with
    the given prefix. The prefix is stripped from all keys of the view.
    '''
    def __init__(self, data, prefix=''):
        self.data = data
        self.prefix = prefix
        self._cached_keys = None

    def __getitem__(self, key):
        return self.data[self.prefix + key]

    def __setitem__(self, key, value):
        pkey = self.prefix + key
        if self._cached_keys is not None and pkey not in self.data:
            # Adding a new key invalidates the cached key list.
            self._cached_keys = None
        self.data[pkey] = value

    def __delitem__(self, key):
        pkey = self.prefix + key
        if self._cached_keys is not None and pkey in self.data:
            # Removing a key invalidates the cached key list.
            self._cached_keys = None
        del self.data[pkey]

    def __cached_keys(self):
        if self._cached_keys is None:
            plen = len(self.prefix)
            self._cached_keys = list(k[plen:] for k in self.data
                                     if k.startswith(self.prefix))
        return self._cached_keys

    def __iter__(self):
        return iter(self.__cached_keys())

    def __len__(self):
        return len(self.__cached_keys())
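
# Minimal sketch of the prefix view (keys are made up for illustration):
#
#   d = {'encoder.weight': 1, 'encoder.bias': 2, 'decoder.weight': 3}
#   sub = PrefixSubDict(d, 'encoder.')
#   list(sub)       # ['weight', 'bias']
#   sub['weight']   # 1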
|
|