Spaces:
Running
Running
| # -*- coding: utf-8 -*- | |
| # File : comm.py | |
| # Author : Jiayuan Mao | |
| # Email : [email protected] | |
| # Date : 27/01/2018 | |
| # | |
| # This file is part of Synchronized-BatchNorm-PyTorch. | |
| # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch | |
| # Distributed under MIT License. | |
| import queue | |
| import collections | |
| import threading | |
# Names exported by `from <this module> import *`.
__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
class FutureResult(object):
    """A thread-safe, single-slot future. Used only as a one-to-one pipe.

    One thread calls `put` to deposit a value; another calls `get` to
    retrieve it, blocking until a value is available. The slot holds at most
    one value at a time: a second `put` before the pending value has been
    fetched trips an assertion.
    """

    # Private marker for "slot is empty". Using a dedicated object instead of
    # `None` lets `None` itself be a legal payload (otherwise `put(None)`
    # followed by `get()` would wait forever).
    _EMPTY = object()

    def __init__(self):
        self._result = self._EMPTY
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        """Store `result` in the slot and wake up a waiting `get`.

        Raises:
            AssertionError: if the previously stored result has not been
                fetched yet.
        """
        with self._lock:
            assert self._result is self._EMPTY, 'Previous result has\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        """Block until a result is available, then remove and return it."""
        with self._lock:
            # Re-check the predicate after every wake-up: `wait()` returning
            # does not by itself guarantee that the slot has been filled.
            while self._result is self._EMPTY:
                self._cond.wait()
            res = self._result
            self._result = self._EMPTY
            return res
# Registry entry kept on the master side; `result` is the object results are
# delivered through.
_MasterRegistry = collections.namedtuple(
    'MasterRegistry',
    ['result'],
)

# Plain record backing `SlavePipe`: the slave's identifier, the queue used to
# reach the master, and the object the reply is read from.
_SlavePipeBase = collections.namedtuple(
    '_SlavePipeBase',
    ['identifier', 'queue', 'result'],
)


class SlavePipe(_SlavePipeBase):
    """Slave-side endpoint of the master-slave channel."""

    def run_slave(self, msg):
        """Send `msg` to the master and return the master's reply.

        The message is enqueued together with this pipe's identifier, the
        call then blocks on `result.get()`, and finally `True` is enqueued as
        an acknowledgement before the reply is returned.
        """
        ident, channel, future = self
        channel.put((ident, msg))
        reply = future.get()
        channel.put(True)
        return reply
class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During the replication, as the data parallel will trigger a callback of each module, all slave devices should
      call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
    - During the forward pass, the master device invokes `run_master`; all messages from slave devices will be
      collected and passed to a registered callback.
    - After receiving the messages, the master device should gather the information and determine the message to be
      passed back to each slave device.
    """

    def __init__(self, master_callback):
        """
        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        # Only the callback is serialized; queue and registry are rebuilt empty on restore.
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register a slave device.

        Args:
            identifier: an identifier, usually is the device id.

        Returns: a `SlavePipe` object which can be used to communicate with the master device.
        """
        if self._activated:
            # First registration after a `run_master` call: start over with a fresh registry.
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.

        The messages are first collected from each device (including the master device), and then
        a callback is invoked to compute the message to be sent back to each device
        (including the master device).

        Args:
            master_msg: the message that the master wants to send to itself. This will be placed as the first
            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.

        Returns: the message to be sent back to the master device.
        """
        self._activated = True

        # The master's own message always comes first, under identifier 0.
        intermediates = [(0, master_msg)]
        for _ in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belongs to the master.'

        # Hand every non-master result to the corresponding registered slave.
        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        # Wait for one acknowledgement per slave. The `get()` is kept outside the
        # `assert` so the queue is still drained when assertions are disabled (-O).
        for _ in range(self.nr_slaves):
            ack = self._queue.get()
            assert ack is True

        return results[0][1]

    @property
    def nr_slaves(self):
        """Number of currently registered slave devices."""
        return len(self._registry)