Spaces:
Running
Running
# -*- coding: utf-8 -*-
# File   : comm.py
# Author : Jiayuan Mao
# Email  : [email protected]
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import queue | |
import collections | |
import threading | |
# Names exported by a star-import of this module.
__all__ = [
    "FutureResult",
    "SlavePipe",
    "SyncMaster",
]
class FutureResult(object):
    """A thread-safe, single-slot future used as a one-to-one pipe.

    One thread calls :meth:`put` to deposit a value and another calls
    :meth:`get` to block until that value is available. Each stored value is
    handed out by exactly one ``get`` call, which empties the slot again.
    """

    # Private marker meaning "no value stored". A dedicated object (instead of
    # ``None``) lets ``None`` itself be delivered as a result without making
    # ``get`` block forever.
    _EMPTY = object()

    def __init__(self):
        self._result = self._EMPTY
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        """Store ``result`` and wake a thread blocked in :meth:`get`.

        Args:
            result: the value to hand over; any object, including ``None``.

        Raises:
            AssertionError: (when assertions are enabled) if the previously
                stored value has not been fetched yet.
        """
        with self._lock:
            assert self._result is self._EMPTY, 'Previous result has\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        """Block until a value is available, then remove and return it."""
        with self._lock:
            # wait_for re-checks the predicate after every wake-up, and returns
            # immediately if put() already happened before we got here.
            self._cond.wait_for(lambda: self._result is not self._EMPTY)
            res = self._result
            self._result = self._EMPTY
            return res
# Per-slave entry stored in SyncMaster's registry: ``result`` is the
# FutureResult through which the master hands that slave its reply.
# The typename matches the module-level name so instances pickle by reference
# and their repr names the actual class.
_MasterRegistry = collections.namedtuple('_MasterRegistry', ['result'])
# Field layout of SlavePipe: the slave's identifier, the queue shared with the
# master, and the slave's own FutureResult used to receive the master's reply.
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
class SlavePipe(_SlavePipeBase):
    """Slave-side endpoint of the master/slave channel.

    Fields come from ``_SlavePipeBase``: ``identifier`` tags outgoing messages,
    ``queue`` is shared with the master, and ``result`` is the object this
    slave reads its reply from.
    """

    def run_slave(self, msg):
        """Send ``msg`` to the master, wait for the reply, then acknowledge it.

        Args:
            msg: payload forwarded to the master together with ``identifier``.

        Returns: the reply obtained from ``self.result.get()``.
        """
        shared_queue = self.queue
        shared_queue.put((self.identifier, msg))
        reply = self.result.get()
        # Acknowledge receipt so the master knows this reply was picked up.
        shared_queue.put(True)
        return reply
class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During replication, data parallelism triggers a callback on each module
      copy; every slave device should call `register_slave(identifier)` and
      obtain a `SlavePipe` to communicate with the master.
    - During the forward pass, the master device invokes `run_master`: the
      messages from all slave devices are collected and passed, together with
      the master's own message, to the registered callback.
    - The callback's result determines the message sent back to each device.
      `run_master` waits until every slave has acknowledged its reply before
      returning.
    """

    def __init__(self, master_callback):
        """
        Args:
            master_callback: a callback invoked with the list of
                ``(identifier, message)`` pairs collected in `run_master`
                (the master's own message first, with identifier ``0``). It
                must return a sequence of ``(identifier, reply)`` pairs whose
                first entry is addressed to the master (identifier ``0``).
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        # Only the callback is pickled; the queue and registry are runtime
        # state that __setstate__ rebuilds from scratch.
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register a slave device.

        Args:
            identifier: an identifier, usually the device id. The master
                callback uses the same identifier to address this slave's reply.

        Returns: a `SlavePipe` object which can be used to communicate with the
            master device.
        """
        if self._activated:
            # First registration after a completed round: start a fresh registry.
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.

        Messages are first collected from every device (including the master
        device), and then the callback is invoked to compute the message sent
        back to each device (including the master device). Blocks until every
        registered slave has sent its message and acknowledged its reply.

        Args:
            master_msg: the message that the master sends to itself. It is
                placed first, with identifier ``0``, in the list passed to
                `master_callback`. For detailed usage, see
                `_SynchronizedBatchNorm` for an example.

        Returns: the message sent back to the master device (the payload of the
            callback's first result).

        Raises:
            AssertionError: (when assertions are enabled) if the callback's first
                result is not addressed to the master.
        """
        # Mark the round as started so the next register_slave call resets the
        # registry.
        self._activated = True
        # nr_slaves is a regular method and must be called; passing the bound
        # method itself to range() raises TypeError. One count serves both the
        # collect and the acknowledge loops.
        nr_slaves = self.nr_slaves()

        intermediates = [(0, master_msg)]
        for _ in range(nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belongs to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        # Each slave puts True on the shared queue after fetching its reply;
        # drain one acknowledgement per slave so the queue is clean for the
        # next round.
        for _ in range(nr_slaves):
            assert self._queue.get() is True

        return results[0][1]

    def nr_slaves(self):
        """Return the number of slaves currently in the registry."""
        return len(self._registry)