""" | |
Meta-learning modules based on: https://github.com/tristandeleu/pytorch-meta | |
MIT License | |
Copyright (c) 2019-2020 Tristan Deleu | |
Permission is hereby granted, free of charge, to any person obtaining a copy | |
of this software and associated documentation files (the "Software"), to deal | |
in the Software without restriction, including without limitation the rights | |
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
copies of the Software, and to permit persons to whom the Software is | |
furnished to do so, subject to the following conditions: | |
The above copyright notice and this permission notice shall be included in all | |
copies or substantial portions of the Software. | |
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
SOFTWARE. | |
""" | |
import itertools
import re
from collections import OrderedDict

import torch.nn as nn

from shap_e.util.collections import AttrDict

__all__ = [
    "MetaModule",
    "subdict",
    "superdict",
    "leveldict",
    "leveliter",
    "batch_meta_parameters",
    "batch_meta_state_dict",
]


def subdict(dictionary, key=None):
    """
    Returns the entries of `dictionary` whose keys start with `key + "."`,
    with that prefix stripped from the returned keys.
    """
    if dictionary is None:
        return None
    if (key is None) or (key == ""):
        return dictionary
    key_re = re.compile(r"^{0}\.(.+)".format(re.escape(key)))
    return AttrDict(
        OrderedDict(
            (key_re.sub(r"\1", k), value)
            for (k, value) in dictionary.items()
            if key_re.match(k) is not None
        )
    )
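

# A minimal usage sketch (illustrative only; the key names and shapes are
# hypothetical): `subdict` keeps the entries under a prefix and strips it.
def _example_subdict():  # pragma: no cover - illustrative sketch
    import torch

    params = AttrDict(
        OrderedDict(
            [
                ("encoder.weight", torch.zeros(3, 2)),
                ("encoder.bias", torch.zeros(3)),
                ("head.weight", torch.zeros(1, 3)),
            ]
        )
    )
    inner = subdict(params, "encoder")
    assert set(inner.keys()) == {"weight", "bias"}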


def superdict(dictionary, key=None):
    """
    Prefixes every key of `dictionary` with `key + "."`; the inverse of
    `subdict` for the selected subtree.
    """
    if dictionary is None:
        return None
    if (key is None) or (key == ""):
        return dictionary
    return AttrDict(OrderedDict((key + "." + k, value) for (k, value) in dictionary.items()))
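

# A small round-trip sketch (illustrative only): `superdict` re-adds the
# prefix that `subdict` removed.
def _example_superdict():  # pragma: no cover - illustrative sketch
    import torch

    params = AttrDict(OrderedDict([("encoder.weight", torch.zeros(3, 2))]))
    inner = subdict(params, "encoder")  # keys: "weight"
    outer = superdict(inner, "encoder")  # keys: "encoder.weight"
    assert set(outer.keys()) == set(params.keys())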


def leveldict(dictionary, depth=0):
    """
    Returns an AttrDict of the entries of `dictionary` whose keys sit exactly
    `depth` levels below the root.
    """
    return AttrDict(leveliter(dictionary, depth=depth))


def leveliter(dictionary, depth=0):
    """
    Iterates over the (key, value) pairs of `dictionary` whose keys contain
    exactly `depth` dots; depth == 0 is the root level.
    """
    for key, value in dictionary.items():
        if key.count(".") == depth:
            yield key, value
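

# A small sketch (illustrative only) of depth selection: keys with zero dots
# are root-level, keys with one dot sit one level down, and so on.
def _example_leveldict():  # pragma: no cover - illustrative sketch
    params = AttrDict(OrderedDict([("w", 1), ("block.w", 2), ("block.sub.w", 3)]))
    assert dict(leveldict(params, depth=0)) == {"w": 1}
    assert dict(leveldict(params, depth=1)) == {"block.w": 2}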


class MetaModule(nn.Module):
    """
    Base class for PyTorch meta-learning modules. These modules accept an
    additional argument `params` in their `forward` method.

    Notes
    -----
    Objects inherited from `MetaModule` are fully compatible with PyTorch
    modules from `torch.nn.Module`. The argument `params` is a dictionary of
    tensors, with full support of the computation graph (for differentiation).

    Based on SIREN's torchmeta with some additional features/changes.

    All meta weights must not have the batch dimension, as they are later
    tiled to the given batch size after unsqueezing the first dimension (e.g.
    a weight of dimension [d_out x d_in] is tiled to have the dimension
    [batch x d_out x d_in]). Requiring all meta weights to have a batch
    dimension of 1 (e.g. [1 x d_out x d_in] from the earlier example) could
    be a more natural choice, but this results in silent failures.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._meta_state_dict = set()
        self._meta_params = set()

    def register_meta_buffer(self, name: str, param: nn.Parameter):
        """
        Registers a trainable or non-trainable parameter as a meta buffer. It
        can later be retrieved via `meta_state_dict`.
        """
        self.register_buffer(name, param)
        self._meta_state_dict.add(name)

    def register_meta_parameter(self, name: str, parameter: nn.Parameter):
        """
        Registers a meta parameter so it is included in `named_meta_parameters`
        and `meta_state_dict`.
        """
        self.register_parameter(name, parameter)
        self._meta_params.add(name)
        self._meta_state_dict.add(name)

    def register_meta(self, name: str, parameter: nn.Parameter, trainable: bool = True):
        """
        Registers a meta parameter (trainable) or meta buffer (non-trainable).
        """
        if trainable:
            self.register_meta_parameter(name, parameter)
        else:
            self.register_meta_buffer(name, parameter)

    def register(self, name: str, parameter: nn.Parameter, meta: bool, trainable: bool = True):
        """
        Dispatches to one of the four registration methods depending on
        whether the value is a meta parameter/buffer and whether it is
        trainable.
        """
        if meta:
            if trainable:
                self.register_meta_parameter(name, parameter)
            else:
                self.register_meta_buffer(name, parameter)
        else:
            if trainable:
                self.register_parameter(name, parameter)
            else:
                self.register_buffer(name, parameter)
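
    # A hypothetical registration sketch (the names below are illustrative,
    # not part of this module): inside a subclass's __init__ one might write
    #
    #     self.register("scale", nn.Parameter(torch.ones(d)), meta=True)
    #     self.register("bias", nn.Parameter(torch.zeros(d)), meta=False)
    #
    # so that "scale" is tiled per batch element by `batch_meta_parameters`
    # while "bias" remains a single shared parameter.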

    def named_meta_parameters(self, prefix="", recurse=True):
        """
        Returns an iterator over the names and values of all meta parameters.
        """

        def meta_iterator(module):
            meta = module._meta_params if isinstance(module, MetaModule) else set()
            for name, param in module._parameters.items():
                if name in meta:
                    yield name, param

        gen = self._named_members(
            meta_iterator,
            prefix=prefix,
            recurse=recurse,
        )
        for name, param in gen:
            yield name, param

    def named_nonmeta_parameters(self, prefix="", recurse=True):
        """
        Returns an iterator over the names and values of all parameters that
        are not registered as meta parameters.
        """

        def _iterator(module):
            meta = module._meta_params if isinstance(module, MetaModule) else set()
            for name, param in module._parameters.items():
                if name not in meta:
                    yield name, param

        gen = self._named_members(
            _iterator,
            prefix=prefix,
            recurse=recurse,
        )
        for name, param in gen:
            yield name, param

    def nonmeta_parameters(self, prefix="", recurse=True):
        """
        Returns an iterator over the values of all non-meta parameters.
        """
        for _, param in self.named_nonmeta_parameters(prefix=prefix, recurse=recurse):
            yield param

    def meta_state_dict(self, prefix="", recurse=True):
        """
        Returns a dictionary of the names and values of all meta
        parameters/buffers. One difference from `nn.Module.state_dict()` is
        that this preserves `requires_grad`, because we may want to compute
        gradients w.r.t. meta buffers without necessarily updating them
        automatically.
        """

        def meta_iterator(module):
            meta = module._meta_state_dict if isinstance(module, MetaModule) else set()
            for name, param in itertools.chain(module._buffers.items(), module._parameters.items()):
                if name in meta:
                    yield name, param

        gen = self._named_members(
            meta_iterator,
            prefix=prefix,
            recurse=recurse,
        )
        return dict(gen)

    def update(self, params=None):
        """
        Completes the parameter dictionary before the forward pass: if
        `params` is None or is missing a key, the module falls back to the
        default parameter/buffer registered on the module.
        """
        if params is None:
            params = AttrDict()
        params = AttrDict(params)
        named_params = set(name for name, _ in self.named_parameters())
        for name, param in self.named_parameters():
            params.setdefault(name, param)
        for name, param in self.state_dict().items():
            if name not in named_params:
                params.setdefault(name, param)
        return params
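

# A minimal end-to-end sketch (illustrative only; `_ExampleMetaLinear` is a
# hypothetical subclass, not part of this module): the forward pass reads its
# weight from the optional `params` dict and falls back to the registered
# default via `update`.
def _example_meta_module():  # pragma: no cover - illustrative sketch
    import torch

    class _ExampleMetaLinear(MetaModule):
        def __init__(self, d_in, d_out):
            super().__init__()
            # No batch dimension here; batching is added later by
            # `batch_meta_parameters` (see the class docstring above).
            self.register_meta_parameter("weight", nn.Parameter(torch.randn(d_out, d_in)))

        def forward(self, x, params=None):
            params = self.update(params)
            return x @ params["weight"].T

    net = _ExampleMetaLinear(2, 3)
    x = torch.randn(5, 2)
    y_default = net(x)  # uses the registered weight
    y_adapted = net(x, params={"weight": torch.randn(3, 2)})  # uses the supplied weight
    assert y_default.shape == y_adapted.shape == (5, 3)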


def batch_meta_parameters(net, batch_size):
    """
    Clones every meta parameter of `net` and tiles it `batch_size` times
    along a new leading batch dimension.
    """
    params = AttrDict()
    for name, param in net.named_meta_parameters():
        params[name] = param.clone().unsqueeze(0).repeat(batch_size, *[1] * len(param.shape))
    return params


def batch_meta_state_dict(net, batch_size):
    """
    Clones every meta parameter/buffer in `net.meta_state_dict()` and tiles
    it `batch_size` times along a new leading batch dimension.
    """
    state_dict = AttrDict()
    for name, param in net.meta_state_dict().items():
        state_dict[name] = param.clone().unsqueeze(0).repeat(batch_size, *[1] * len(param.shape))
    return state_dict
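

# A short sketch (illustrative only) of the tiling described in the class
# docstring: a [d_out x d_in] meta weight becomes [batch x d_out x d_in].
def _example_batching():  # pragma: no cover - illustrative sketch
    import torch

    class _Tiny(MetaModule):
        def __init__(self):
            super().__init__()
            self.register_meta_parameter("weight", nn.Parameter(torch.randn(3, 2)))

    net = _Tiny()
    batched = batch_meta_parameters(net, batch_size=4)
    assert batched["weight"].shape == (4, 3, 2)
    assert batch_meta_state_dict(net, batch_size=4)["weight"].shape == (4, 3, 2)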