# -*- coding: utf-8 -*-
import collections
import collections.abc  # Mapping/Sequence live here; collections.Mapping was removed in Python 3.10

import torch
import torch.cuda as cuda
import torch.nn as nn
from torch.nn.parallel._functions import Gather

__all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to']

def async_copy_to(obj, dev, main_stream=None):
    """Recursively copy tensors in a (possibly nested) structure to device
    `dev` with non-blocking transfers; non-tensor leaves pass through."""
    if torch.is_tensor(obj):
        v = obj.cuda(dev, non_blocking=True)
        if main_stream is not None:
            # Keep the tensor's memory alive until `main_stream` has consumed it.
            v.record_stream(main_stream)
        return v
    elif isinstance(obj, collections.abc.Mapping):
        return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
    elif isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
        # The str guard prevents infinite recursion: a string is a Sequence
        # whose elements are themselves strings.
        return [async_copy_to(o, dev, main_stream) for o in obj]
    else:
        return obj

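# Example (illustrative sketch, not part of the original module): copying a
# nested sample dict to GPU 0. The key names and shapes are assumptions made
# up for this demonstration.
#
#   sample = {'img_data': torch.randn(1, 3, 224, 224), 'name': 'img0.jpg'}
#   if torch.cuda.is_available():
#       sample_gpu = async_copy_to(sample, 0)  # tensor moves; the string passes through
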
def dict_gather(outputs, target_device, dim=0):
    """
    Gathers variables from different GPUs on a specified device
    (-1 means the CPU), with dictionary support.
    """
    def gather_map(outputs):
        out = outputs[0]
        if torch.is_tensor(out):
            # MJY(20180330) HACK:: force nr_dims > 0 so Gather can concatenate
            if out.dim() == 0:
                outputs = [o.unsqueeze(0) for o in outputs]
            return Gather.apply(target_device, dim, *outputs)
        elif out is None:
            return None
        elif isinstance(out, collections.abc.Mapping):
            return {k: gather_map([o[k] for o in outputs]) for k in out}
        elif isinstance(out, collections.abc.Sequence):
            return type(out)(map(gather_map, zip(*outputs)))
        # any other type falls through and yields None (original behavior)
    return gather_map(outputs)

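# Example (illustrative sketch): gathering per-GPU dict outputs onto device 0.
# The keys 'loss' and 'pred' are assumptions for demonstration only.
#
#   per_gpu = [{'loss': torch.tensor([0.5], device='cuda:0'),
#               'pred': torch.zeros(2, 10, device='cuda:0')},
#              {'loss': torch.tensor([0.7], device='cuda:1'),
#               'pred': torch.zeros(2, 10, device='cuda:1')}]
#   merged = dict_gather(per_gpu, target_device=0)
#   # merged['loss'] has shape (2,), merged['pred'] has shape (4, 10)
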
class DictGatherDataParallel(nn.DataParallel):
    def gather(self, outputs, output_device):
        return dict_gather(outputs, output_device, dim=self.dim)

class UserScatteredDataParallel(DictGatherDataParallel):
    def scatter(self, inputs, kwargs, device_ids):
        # The caller passes a single positional argument: a list that is
        # already split into one chunk per device. We only copy each chunk
        # to its GPU; no automatic splitting is performed.
        assert len(inputs) == 1
        inputs = inputs[0]
        inputs = _async_copy_stream(inputs, device_ids)
        inputs = [[i] for i in inputs]
        assert len(kwargs) == 0
        kwargs = [{} for _ in range(len(inputs))]

        return inputs, kwargs

def user_scattered_collate(batch):
    """Identity collate function: keep the batch as a plain list of samples
    so the user controls how it is scattered across devices."""
    return batch

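# Example (illustrative sketch): wiring the pieces together. `MyDataset` and
# `MyModel` are hypothetical placeholders; the pattern assumes the DataLoader
# yields one already-prepared chunk per GPU.
#
#   loader = torch.utils.data.DataLoader(
#       MyDataset(), batch_size=2, collate_fn=user_scattered_collate)
#   model = UserScatteredDataParallel(MyModel().cuda(), device_ids=[0, 1])
#   for batch in loader:        # batch is a plain list of per-device chunks
#       out = model(batch)      # scatter copies chunks; gather merges dict outputs
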
def _async_copy(inputs, device_ids):
    """Simple variant: copy each input chunk to its device on the current
    stream, without background copy streams."""
    nr_devs = len(device_ids)
    assert type(inputs) in (tuple, list)
    assert len(inputs) == nr_devs

    outputs = []
    for i, dev in zip(inputs, device_ids):
        with cuda.device(dev):
            outputs.append(async_copy_to(i, dev))

    return tuple(outputs)

def _async_copy_stream(inputs, device_ids):
    nr_devs = len(device_ids)
    assert type(inputs) in (tuple, list)
    assert len(inputs) == nr_devs

    outputs = []
    streams = [_get_stream(d) for d in device_ids]
    for i, dev, stream in zip(inputs, device_ids, streams):
        with cuda.device(dev):
            main_stream = cuda.current_stream()
            # Issue the copy on a background stream so it can overlap with
            # compute, then make the main stream wait for it to finish.
            with cuda.stream(stream):
                outputs.append(async_copy_to(i, dev, main_stream=main_stream))
            main_stream.wait_stream(stream)

    return outputs

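# The copy/compute overlap above follows the standard CUDA stream recipe:
# enqueue the transfer on a side stream, have the consumer stream wait on it,
# and record_stream() the tensor so the caching allocator does not reuse its
# memory early. A minimal standalone sketch of the same pattern (variable
# names are illustrative):
#
#   x_cpu = torch.randn(8).pin_memory()   # pinned memory enables a true async copy
#   copy_stream = torch.cuda.Stream()
#   with torch.cuda.stream(copy_stream):
#       x_gpu = x_cpu.cuda(non_blocking=True)
#   main = torch.cuda.current_stream()
#   main.wait_stream(copy_stream)         # compute waits for the transfer
#   x_gpu.record_stream(main)             # protect the allocation until main is done
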
"""Adapted from: torch/nn/parallel/_functions.py""" | |
# background streams used for copying | |
_streams = None | |
def _get_stream(device): | |
"""Gets a background stream for copying between CPU and GPU""" | |
global _streams | |
if device == -1: | |
return None | |
if _streams is None: | |
_streams = [None] * cuda.device_count() | |
if _streams[device] is None: _streams[device] = cuda.Stream(device) | |
return _streams[device] | |