Spaces:
Running
on
Zero
Running
on
Zero
| import dataclasses | |
| import warnings | |
| import numpy as np | |
| import torch | |
def to_device(data, device=None, dtype=None, non_blocking=False, copy=False):
    """Change the device (and optionally dtype) of tensors inside ``data`` recursively.

    Containers are rebuilt with converted elements:

    * ``dict`` -> a new plain ``dict`` with converted values (keys untouched).
    * dataclass instance -> a new instance of the same class, constructed
      positionally from the converted values of its fields (nested dataclasses
      keep their own type).
    * namedtuple -> the same namedtuple type, one positional arg per field.
    * ``list`` / ``tuple`` / other tuple subclasses (e.g. ``torch.Size``) ->
      the same type, constructed from a list of converted elements.
    * ``numpy.ndarray`` -> wrapped with ``torch.from_numpy`` and then converted.
    * ``torch.Tensor`` -> ``data.to(device, dtype, non_blocking, copy)``.

    Any other object is returned unchanged.

    Args:
        data: Object (possibly nested) to convert.
        device: Target device forwarded to ``Tensor.to``.
        dtype: Target dtype forwarded to ``Tensor.to``. It is applied to every
            tensor found, including integer ones.
        non_blocking: Forwarded to ``Tensor.to``.
        copy: Forwarded to ``Tensor.to``.

    Returns:
        A structure mirroring ``data`` with its tensors converted.
    """
    if isinstance(data, dict):
        return {
            k: to_device(v, device, dtype, non_blocking, copy) for k, v in data.items()
        }
    elif dataclasses.is_dataclass(data) and not isinstance(data, type):
        # Walk the fields shallowly rather than using ``dataclasses.astuple``:
        # astuple recurses into nested dataclasses (turning them into plain
        # tuples, so the rebuilt object gets the wrong field types) and
        # deep-copies every other leaf, tensors included, before they are moved.
        return type(data)(
            *[
                to_device(getattr(data, f.name), device, dtype, non_blocking, copy)
                for f in dataclasses.fields(data)
            ]
        )
    elif isinstance(data, tuple) and hasattr(type(data), "_fields"):
        # namedtuple: its constructor takes one positional argument per field.
        return type(data)(
            *[to_device(o, device, dtype, non_blocking, copy) for o in data]
        )
    elif isinstance(data, (list, tuple)):
        # list, tuple and other tuple subclasses (e.g. torch.Size) are built
        # from a single iterable, not from unpacked positional arguments.
        return type(data)(
            [to_device(v, device, dtype, non_blocking, copy) for v in data]
        )
    elif isinstance(data, np.ndarray):
        return to_device(torch.from_numpy(data), device, dtype, non_blocking, copy)
    elif isinstance(data, torch.Tensor):
        return data.to(device, dtype, non_blocking, copy)
    else:
        return data
def force_gatherable(data, device):
    """Change an object into one gatherable by ``torch.nn.DataParallel``, recursively.

    This differs from ``to_device()`` in two ways. Python ``float`` and ``int``
    values are converted to 1-element tensors. 0-dimensional tensors are
    reshaped to 1 dimension.

    The restriction on values returned inside DataParallel: each object must be

    - a ``torch.cuda.Tensor``,
    - with 1 or more dimensions (a 0-dimensional tensor sends a warning),

    or a list, tuple, or dict of such objects.

    Args:
        data: Object (possibly nested in dicts, namedtuples, lists, tuples,
            tuple subclasses, or sets) to convert.
        device: Device every produced tensor is moved to.

    Returns:
        A structure mirroring ``data`` in which:

        - tensors and ndarrays become tensors of at least 1 dimension on ``device``;
        - ``float`` becomes a 1-element ``torch.float`` tensor;
        - ``int`` becomes a 1-element ``torch.long`` tensor;
        - ``None`` stays ``None``.

        Any other leaf is returned unchanged after a warning is emitted.
    """
    if isinstance(data, dict):
        return {k: force_gatherable(v, device) for k, v in data.items()}
    # DataParallel can't handle NamedTuple well; rebuild it with one positional
    # argument per field.
    elif isinstance(data, tuple) and hasattr(type(data), "_fields"):
        return type(data)(*[force_gatherable(o, device) for o in data])
    # list, tuple, set and other tuple subclasses are built from one iterable.
    elif isinstance(data, (list, tuple, set)):
        return type(data)([force_gatherable(v, device) for v in data])
    elif isinstance(data, np.ndarray):
        return force_gatherable(torch.from_numpy(data), device)
    elif isinstance(data, torch.Tensor):
        if data.dim() == 0:
            # To 1-dim array
            data = data[None]
        return data.to(device)
    elif isinstance(data, float):
        return torch.tensor([data], dtype=torch.float, device=device)
    elif isinstance(data, int):
        # bool is a subclass of int, so True/False also land here.
        return torch.tensor([data], dtype=torch.long, device=device)
    elif data is None:
        return None
    else:
        warnings.warn(f"{type(data)} may not be gatherable by DataParallel")
        return data