# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import random
import warnings
from typing import Any, Mapping, Sequence
import numpy as np
import torch
from torch.utils.data._utils.collate import \
default_collate as torch_default_collate
from mmengine.registry import FUNCTIONS
from mmengine.structures import BaseDataElement
# FUNCTIONS is new in MMEngine v0.7.0. Reserve the `COLLATE_FUNCTIONS` to keep
# the compatibility.
COLLATE_FUNCTIONS = FUNCTIONS
def worker_init_fn(worker_id: int,
num_workers: int,
rank: int,
seed: int,
disable_subprocess_warning: bool = False) -> None:
"""This function will be called on each worker subprocess after seeding and
before data loading.
Args:
worker_id (int): Worker id in [0, num_workers - 1].
num_workers (int): How many subprocesses to use for data loading.
rank (int): Rank of process in distributed environment. If in
non-distributed environment, it is a constant number `0`.
seed (int): Random seed.
"""
# The seed of each worker equals to
# num_worker * rank + worker_id + user_seed
worker_seed = num_workers * rank + worker_id + seed
np.random.seed(worker_seed)
random.seed(worker_seed)
torch.manual_seed(worker_seed)
if disable_subprocess_warning and worker_id != 0:
warnings.simplefilter('ignore')
@FUNCTIONS.register_module()
def pseudo_collate(data_batch: Sequence) -> Any:
    """Convert a list of samples into a batch with the same container type.

    Unlike the dataloader default, ``pseudo_collate`` does not stack
    tensors into batched tensors, nor does it convert int, float or
    ndarray to tensors; it only transposes/regroups the containers.

    This code is referenced from:
    `Pytorch default_collate <https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py>`_.

    Args:
        data_batch (Sequence): Batch of data from dataloader.

    Returns:
        Any: Transposed data in the same container format as the elements
        of ``data_batch``.
    """  # noqa: E501
    first = data_batch[0]
    batch_type = type(first)
    if isinstance(first, (str, bytes)):
        # Strings/bytes are sequences too, but must be kept whole.
        return data_batch
    if isinstance(first, tuple) and hasattr(first, '_fields'):
        # Named tuple: collate field-wise and rebuild the same type.
        return batch_type(*(pseudo_collate(field_samples)
                            for field_samples in zip(*data_batch)))
    if isinstance(first, Sequence):
        # Every sample must have the same length before transposing.
        rest = iter(data_batch)
        expected_len = len(next(rest))
        if any(len(sample) != expected_len for sample in rest):
            raise RuntimeError(
                'each data_itement in list of batch should be of equal size')
        transposed = list(zip(*data_batch))
        if isinstance(first, tuple):
            # Plain tuples become a list, matching PyTorch's behaviour.
            return [pseudo_collate(samples) for samples in transposed]
        try:
            return batch_type(
                [pseudo_collate(samples) for samples in transposed])
        except TypeError:
            # The sequence type may not support `__init__(iterable)`
            # (e.g., `range`); fall back to a plain list.
            return [pseudo_collate(samples) for samples in transposed]
    if isinstance(first, Mapping):
        # Collate value-wise, keyed by the first sample's keys.
        return batch_type({
            key: pseudo_collate([sample[key] for sample in data_batch])
            for key in first
        })
    # Scalars and unknown types are returned untouched.
    return data_batch
@FUNCTIONS.register_module()
def default_collate(data_batch: Sequence) -> Any:
    """Convert a list of samples into a batch with the same container type.

    Different from :func:`pseudo_collate`, ``default_collate`` stacks the
    tensors in ``data_batch`` into a batched tensor whose first dimension
    is the batch size (via PyTorch's ``default_collate``).

    Different from ``default_collate`` in pytorch, this function leaves
    ``BaseDataElement`` instances untouched.

    This code is referenced from:
    `Pytorch default_collate <https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py>`_.

    Note:
        ``default_collate`` only accept input tensor with the same shape.

    Args:
        data_batch (Sequence): Data sampled from dataset.

    Returns:
        Any: Data in the same container format as the elements of
        ``data_batch``, of which tensors have been stacked, and ndarray,
        int, float have been converted to tensors.
    """  # noqa: E501
    first = data_batch[0]
    batch_type = type(first)
    if isinstance(first, (BaseDataElement, str, bytes)):
        # BaseDataElement and strings/bytes are returned as-is.
        return data_batch
    if isinstance(first, tuple) and hasattr(first, '_fields'):
        # Named tuple: collate field-wise and rebuild the same type.
        return batch_type(*(default_collate(field_samples)
                            for field_samples in zip(*data_batch)))
    if isinstance(first, Sequence):
        # Every sample must have the same length before transposing.
        rest = iter(data_batch)
        expected_len = len(next(rest))
        if any(len(sample) != expected_len for sample in rest):
            raise RuntimeError(
                'each data_itement in list of batch should be of equal size')
        transposed = list(zip(*data_batch))
        if isinstance(first, tuple):
            # Plain tuples become a list, matching PyTorch's behaviour.
            return [default_collate(samples) for samples in transposed]
        try:
            return batch_type(
                [default_collate(samples) for samples in transposed])
        except TypeError:
            # The sequence type may not support `__init__(iterable)`
            # (e.g., `range`); fall back to a plain list.
            return [default_collate(samples) for samples in transposed]
    if isinstance(first, Mapping):
        # Collate value-wise, keyed by the first sample's keys.
        return batch_type({
            key: default_collate([sample[key] for sample in data_batch])
            for key in first
        })
    # Leaves (tensor, ndarray, int, float, ...) are handed to PyTorch.
    return torch_default_collate(data_batch)