Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| # File : replicate.py | |
| # Author : Jiayuan Mao | |
| # Email : [email protected] | |
| # Date : 27/01/2018 | |
| # | |
| # This file is part of Synchronized-BatchNorm-PyTorch. | |
| # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch | |
| # Distributed under MIT License. | |
| import functools | |
| from torch.nn.parallel.data_parallel import DataParallel | |
| __all__ = [ | |
| 'CallbackContext', | |
| 'execute_replication_callbacks', | |
| 'DataParallelWithCallback', | |
| 'patch_replication_callback' | |
| ] | |
class CallbackContext(object):
    """Empty attribute holder handed to replication callbacks.

    One instance is created per sub-module position and passed to the
    `__data_parallel_replicate__` callback of every copy of that sub-module,
    so the copies can exchange information by setting attributes on it.
    """
def execute_replication_callbacks(modules):
    """
    Execute a replication callback `__data_parallel_replicate__` on each module created by original replication.

    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`

    Note that, as all modules are isomorphic, we assign each sub-module a context
    (shared among multiple copies of this module on different devices).
    Through this context, different copies can share some information.

    We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
    of any slave copies.

    Args:
        modules: sequence of replicated modules; the first one is the master copy.
            An empty sequence is accepted and results in no callback being run.
    """
    if not modules:
        # Nothing was replicated, so there is no master copy to index into.
        return

    master_copy = modules[0]
    # One context per sub-module position of the master copy; built straight
    # from the iterator instead of materializing a list only to count it.
    ctxs = [CallbackContext() for _ in master_copy.modules()]

    # Copies are visited in order, so the master (index 0) always runs first.
    for i, module in enumerate(modules):
        for j, m in enumerate(module.modules()):
            if hasattr(m, '__data_parallel_replicate__'):
                m.__data_parallel_replicate__(ctxs[j], i)
class DataParallelWithCallback(DataParallel):
    """
    `DataParallel` variant that runs replication callbacks.

    After the parent `replicate` has produced the per-device copies, the
    `__data_parallel_replicate__` callback of every sub-module that defines
    one is invoked as `__data_parallel_replicate__(ctx, copy_id)`.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """

    def replicate(self, module, device_ids):
        """Replicate `module` through the parent class, run the callbacks, return the copies."""
        replicas = super(DataParallelWithCallback, self).replicate(module, device_ids)
        execute_replication_callbacks(replicas)
        return replicas
def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object. Add the replication callback.
    Useful when you have customized `DataParallel` implementation.

    The object's `replicate` attribute is replaced in place by a wrapper that
    calls the previous `replicate` and then runs the replication callbacks on
    the returned copies. Calling this function again on an object it has
    already patched is a no-op, so callbacks never run more than once per
    replication.

    Args:
        data_parallel: the `DataParallel` instance to patch.

    Raises:
        AssertionError: if `data_parallel` is not a `DataParallel` instance.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """
    assert isinstance(data_parallel, DataParallel)

    old_replicate = data_parallel.replicate
    if getattr(old_replicate, '_replication_callback_patched', False):
        # Already wrapped by an earlier call; wrapping again would execute
        # every callback twice on each replication.
        return

    # Carry the wrapped method's name and docstring over to the replacement.
    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

    # Marker read by the guard above on subsequent calls.
    new_replicate._replication_callback_patched = True
    data_parallel.replicate = new_replicate