Spaces:
Build error
Build error
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import random | |
import warnings | |
from typing import Any, Mapping, Sequence | |
import numpy as np | |
import torch | |
from torch.utils.data._utils.collate import \ | |
default_collate as torch_default_collate | |
from mmengine.registry import FUNCTIONS | |
from mmengine.structures import BaseDataElement | |
# ``FUNCTIONS`` is new in MMEngine v0.7.0. ``COLLATE_FUNCTIONS`` is kept as an
# alias bound to the very same object (not a copy) so that the old name keeps
# working for backward compatibility.
COLLATE_FUNCTIONS = FUNCTIONS
def worker_init_fn(worker_id: int,
                   num_workers: int,
                   rank: int,
                   seed: int,
                   disable_subprocess_warning: bool = False) -> None:
    """This function will be called on each worker subprocess after seeding and
    before data loading.

    A per-worker seed is derived from ``worker_id``, ``num_workers``,
    ``rank`` and ``seed`` and applied to Python's ``random`` module, NumPy's
    global RNG and PyTorch's RNG.

    Args:
        worker_id (int): Worker id in [0, num_workers - 1].
        num_workers (int): How many subprocesses to use for data loading.
        rank (int): Rank of process in distributed environment. If in
            non-distributed environment, it is a constant number `0`.
        seed (int): Random seed.
        disable_subprocess_warning (bool): If True, every worker except
            worker ``0`` ignores all warnings. Defaults to False.
    """
    # The seed of each worker equals to
    # num_workers * rank + worker_id + user_seed
    worker_seed = num_workers * rank + worker_id + seed
    # ``np.random.seed`` only accepts values in [0, 2**32 - 1]; fold larger
    # (or negative) seeds into that range instead of raising ValueError.
    # ``random.seed`` and ``torch.manual_seed`` accept the unreduced value,
    # so they keep receiving it unchanged.
    np.random.seed(worker_seed % 2**32)
    random.seed(worker_seed)
    torch.manual_seed(worker_seed)
    if disable_subprocess_warning and worker_id != 0:
        # Only worker 0 keeps emitting warnings.
        warnings.simplefilter('ignore')
def pseudo_collate(data_batch: Sequence) -> Any:
    """Convert list of data sampled from dataset into a batch of data, of which
    type consistent with the type of each data_item in ``data_batch``.

    The default behavior of dataloader is to merge a list of samples to form
    a mini-batch of Tensor(s). However, in MMEngine, ``pseudo_collate``
    neither stacks tensors into batched tensors nor converts int, float or
    ndarray to tensors: containers (named tuples, sequences, mappings) are
    transposed recursively, and leaf values are returned as the sequence
    that groups them, unchanged.

    This code is referenced from:
    `Pytorch default_collate <https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py>`_.

    Args:
        data_batch (Sequence): Batch of data from dataloader.

    Returns:
        Any: Transposed data in the same format as the data_item of
        ``data_batch``. Tuples become lists (as in PyTorch); sequence or
        mapping types that cannot be rebuilt from their collated contents
        fall back to ``list`` / ``dict`` respectively.

    Raises:
        RuntimeError: If the sequence items in ``data_batch`` do not all
            have the same length.
    """  # noqa: E501
    data_item = data_batch[0]
    data_item_type = type(data_item)
    if isinstance(data_item, (str, bytes)):
        return data_batch
    elif isinstance(data_item, tuple) and hasattr(data_item, '_fields'):
        # named tuple
        return data_item_type(*(pseudo_collate(samples)
                                for samples in zip(*data_batch)))
    elif isinstance(data_item, Sequence):
        # check to make sure that the data_items in batch have
        # consistent size
        it = iter(data_batch)
        data_item_size = len(next(it))
        if not all(len(item) == data_item_size for item in it):
            raise RuntimeError(
                'each data_itement in list of batch should be of equal size')
        transposed = list(zip(*data_batch))
        if isinstance(data_item, tuple):
            return [pseudo_collate(samples)
                    for samples in transposed]  # Compat with Pytorch.
        else:
            try:
                return data_item_type(
                    [pseudo_collate(samples) for samples in transposed])
            except TypeError:
                # The sequence type may not support `__init__(iterable)`
                # (e.g., `range`).
                return [pseudo_collate(samples) for samples in transposed]
    elif isinstance(data_item, Mapping):
        collated = {
            key: pseudo_collate([d[key] for d in data_batch])
            for key in data_item
        }
        try:
            return data_item_type(collated)
        except TypeError:
            # The mapping type may not support `__init__(mapping)`
            # (e.g., `collections.defaultdict`, whose first positional
            # argument must be a factory callable).
            return collated
    else:
        return data_batch
def default_collate(data_batch: Sequence) -> Any:
    """Convert list of data sampled from dataset into a batch of data, of which
    type consistent with the type of each data_item in ``data_batch``.

    Different from :func:`pseudo_collate`, ``default_collate`` will stack
    tensor contained in ``data_batch`` into a batched tensor with the
    first dimension batch size (leaf values are delegated to PyTorch's
    ``default_collate``). It does not move tensors to any device.

    Different from ``default_collate`` in pytorch, ``default_collate`` will
    not process ``BaseDataElement``: a batch whose items are
    ``BaseDataElement`` (or ``str`` / ``bytes``) is returned unchanged.

    This code is referenced from:
    `Pytorch default_collate <https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py>`_.

    Note:
        ``default_collate`` only accept input tensor with the same shape.

    Args:
        data_batch (Sequence): Data sampled from dataset.

    Returns:
        Any: Data in the same format as the data_item of ``data_batch``, of
        which tensors have been stacked, and ndarray, int, float have been
        converted to tensors. Tuples become lists (as in PyTorch); sequence
        or mapping types that cannot be rebuilt from their collated contents
        fall back to ``list`` / ``dict`` respectively.

    Raises:
        RuntimeError: If the sequence items in ``data_batch`` do not all
            have the same length. Errors raised by PyTorch's
            ``default_collate`` for leaf values propagate unchanged.
    """  # noqa: E501
    data_item = data_batch[0]
    data_item_type = type(data_item)
    if isinstance(data_item, (BaseDataElement, str, bytes)):
        return data_batch
    elif isinstance(data_item, tuple) and hasattr(data_item, '_fields'):
        # named_tuple
        return data_item_type(*(default_collate(samples)
                                for samples in zip(*data_batch)))
    elif isinstance(data_item, Sequence):
        # check to make sure that the data_items in batch have
        # consistent size
        it = iter(data_batch)
        data_item_size = len(next(it))
        if not all(len(item) == data_item_size for item in it):
            raise RuntimeError(
                'each data_itement in list of batch should be of equal size')
        transposed = list(zip(*data_batch))
        if isinstance(data_item, tuple):
            return [default_collate(samples)
                    for samples in transposed]  # Compat with Pytorch.
        else:
            try:
                return data_item_type(
                    [default_collate(samples) for samples in transposed])
            except TypeError:
                # The sequence type may not support `__init__(iterable)`
                # (e.g., `range`).
                return [default_collate(samples) for samples in transposed]
    elif isinstance(data_item, Mapping):
        collated = {
            key: default_collate([d[key] for d in data_batch])
            for key in data_item
        }
        try:
            return data_item_type(collated)
        except TypeError:
            # The mapping type may not support `__init__(mapping)`
            # (e.g., `collections.defaultdict`, whose first positional
            # argument must be a factory callable).
            return collated
    else:
        return torch_default_collate(data_batch)