Spaces:
Sleeping
Sleeping
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from torch import nn | |
class ModuleProxyWrapper(nn.Module):
    """
    Wrap a DistributedDataParallel module and forward requests for missing
    attributes to the module wrapped by DDP (the twice-wrapped module).
    Also forward calls to :func:`state_dict` and :func:`load_state_dict`.

    Attribute lookup order for names not found on the wrapper itself:

    1. ``nn.Module``'s own lookup (parameters, buffers, sub-modules),
    2. the once-wrapped module (``self.module``),
    3. the twice-wrapped module (``self.module.module``).

    Usage::

        module.xyz = "hello world"
        wrapped_module = DistributedDataParallel(module, **ddp_args)
        wrapped_module = ModuleProxyWrapper(wrapped_module)
        assert wrapped_module.xyz == "hello world"
        assert wrapped_module.state_dict().keys() == module.state_dict().keys()

    Args:
        module (nn.Module): module to wrap; it must itself expose a
            ``module`` attribute (i.e. wrap another module)

    Raises:
        AssertionError: if *module* has no ``module`` attribute
    """

    def __init__(self, module: nn.Module):
        super().__init__()
        # Raise explicitly instead of using an ``assert`` statement so the
        # check is not stripped when Python runs with ``-O``. The exception
        # type and message are unchanged.
        if not hasattr(module, "module"):
            raise AssertionError(
                "ModuleProxyWrapper expects input to wrap another module"
            )
        self.module = module

    def __getattr__(self, name):
        """Forward missing attributes to twice-wrapped module."""
        try:
            # defer to nn.Module's logic
            return super().__getattr__(name)
        except AttributeError:
            if name == "module":
                # ``self.module`` itself is missing (e.g. an instance created
                # without running ``__init__``). Looking it up again below
                # would re-enter this method forever, so surface the
                # AttributeError instead of a RecursionError.
                raise
            try:
                # forward to the once-wrapped module
                return getattr(self.module, name)
            except AttributeError:
                # forward to the twice-wrapped module
                return getattr(self.module.module, name)

    def state_dict(self, *args, **kwargs):
        """Forward to the twice-wrapped module.

        All positional and keyword arguments are passed through unchanged.
        """
        return self.module.module.state_dict(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        """Forward to the twice-wrapped module.

        All positional and keyword arguments are passed through unchanged.
        """
        return self.module.module.load_state_dict(*args, **kwargs)

    def forward(self, *args, **kwargs):
        """Call the once-wrapped module with the given arguments."""
        return self.module(*args, **kwargs)