Spaces:
Build error
Build error
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import bisect | |
import copy | |
import logging | |
import math | |
from collections import defaultdict | |
from typing import List, Sequence, Tuple, Union | |
import numpy as np | |
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset | |
from mmengine.logging import print_log | |
from mmengine.registry import DATASETS | |
from .base_dataset import BaseDataset, force_full_init | |
class ConcatDataset(_ConcatDataset):
    """A wrapper of concatenated dataset.

    Same as ``torch.utils.data.dataset.ConcatDataset`` and support lazy_init.

    Note:
        ``ConcatDataset`` should not inherit from ``BaseDataset`` since
        ``get_subset`` and ``get_subset_`` could produce ambiguous meaning
        sub-dataset which conflicts with original dataset. If you want to use
        a sub-dataset of ``ConcatDataset``, you should set ``indices``
        arguments for wrapped dataset which inherit from ``BaseDataset``.

    Args:
        datasets (Sequence[BaseDataset] or Sequence[dict]): A list of datasets
            which will be concatenated. Dicts are built via ``DATASETS``.
        lazy_init (bool, optional): If True, defer ``full_init`` (which calls
            ``full_init`` on every wrapped dataset and computes the
            cumulative sizes) until it is first needed. Defaults to False.
        ignore_keys (List[str] or str): Ignore the keys that can be
            unequal in `dataset.metainfo`. Defaults to None.
            `New in version 0.3.0.`

    Raises:
        TypeError: If an element of ``datasets`` is neither a dict nor a
            ``BaseDataset``, if ``ignore_keys`` has an unsupported type, or
            if a shared meta key has a different type than in the first
            dataset.
        ValueError: If a dataset lacks a meta key present in another dataset,
            or a shared meta value differs from the first dataset's.
    """

    def __init__(self,
                 datasets: Sequence[Union[BaseDataset, dict]],
                 lazy_init: bool = False,
                 ignore_keys: Union[str, List[str], None] = None):
        self.datasets: List[BaseDataset] = []
        for dataset in datasets:
            if isinstance(dataset, dict):
                self.datasets.append(DATASETS.build(dataset))
            elif isinstance(dataset, BaseDataset):
                self.datasets.append(dataset)
            else:
                raise TypeError(
                    'elements in datasets sequence should be config or '
                    f'`BaseDataset` instance, but got {type(dataset)}')
        if ignore_keys is None:
            self.ignore_keys = []
        elif isinstance(ignore_keys, str):
            self.ignore_keys = [ignore_keys]
        elif isinstance(ignore_keys, list):
            self.ignore_keys = ignore_keys
        else:
            raise TypeError('ignore_keys should be a list or str, '
                            f'but got {type(ignore_keys)}')

        # Read each dataset's `metainfo` once. It is presumably a copying
        # property (as the wrappers' own `metainfo` is), so re-reading it
        # per key would repeat that copy needlessly.
        all_metainfo = [dataset.metainfo for dataset in self.datasets]
        meta_keys: set = set()
        for metainfo in all_metainfo:
            meta_keys.update(metainfo.keys())
        # Only use metainfo of first dataset.
        self._metainfo = all_metainfo[0]
        for i, metainfo in enumerate(all_metainfo, 1):
            for key in meta_keys:
                if key in self.ignore_keys:
                    continue
                if key not in metainfo:
                    raise ValueError(
                        f'{key} does not in the meta information of '
                        f'the {i}-th dataset')
                first_value = self._metainfo[key]
                cur_value = metainfo[key]
                first_type = type(first_value)
                cur_type = type(cur_value)
                if first_type is not cur_type:  # type: ignore
                    raise TypeError(
                        f'The type {cur_type} of {key} in the {i}-th dataset '
                        'should be the same with the first dataset '
                        f'{first_type}')
                if isinstance(first_value, np.ndarray):
                    # `!=` on arrays is element-wise and its truth value is
                    # ambiguous, so compare the whole arrays explicitly.
                    mismatch = not np.array_equal(first_value, cur_value)
                else:
                    mismatch = first_value != cur_value
                if mismatch:
                    raise ValueError(
                        f'The meta information of the {i}-th dataset does not '
                        'match meta information of the first dataset')
        self._fully_initialized = False
        if not lazy_init:
            self.full_init()

    @property
    def metainfo(self) -> dict:
        """Get the meta information of the first dataset in ``self.datasets``.

        Returns:
            dict: A deep copy of the first dataset's meta information.
        """
        # Prevent `self._metainfo` from being modified by outside.
        return copy.deepcopy(self._metainfo)

    def full_init(self):
        """Loop to ``full_init`` each dataset, then compute cumulative sizes.

        Calling it again after initialization is a no-op.
        """
        if self._fully_initialized:
            return
        for d in self.datasets:
            d.full_init()
        # Get the cumulative sizes of `self.datasets`. For example, the length
        # of `self.datasets` is [2, 3, 4], the cumulative sizes is [2, 5, 9]
        super().__init__(self.datasets)
        self._fully_initialized = True

    @force_full_init
    def _get_ori_dataset_idx(self, idx: int) -> Tuple[int, int]:
        """Convert global idx to local index.

        Args:
            idx (int): Global index of ``ConcatDataset``. Negative values
                count from the end.

        Returns:
            Tuple[int, int]: The index of ``self.datasets`` and the local
            index of data.

        Raises:
            ValueError: If ``idx`` is negative and its absolute value exceeds
                the length of the concatenated dataset.
        """
        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    f'absolute value of index({idx}) should not exceed '
                    f'dataset length({len(self)}).')
            idx = len(self) + idx
        # Get `dataset_idx` to tell idx belongs to which dataset:
        # `bisect_right` returns the first position whose cumulative size
        # is greater than `idx`.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        # Get the inner index of single dataset.
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return dataset_idx, sample_idx

    @force_full_init
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        Args:
            idx (int): Global index of ``ConcatDataset``.

        Returns:
            dict: The idx-th annotation of the datasets.
        """
        dataset_idx, sample_idx = self._get_ori_dataset_idx(idx)
        return self.datasets[dataset_idx].get_data_info(sample_idx)

    @force_full_init
    def __len__(self):
        # `torch`'s `ConcatDataset.__len__` reads `cumulative_sizes`, which
        # only exists after `full_init`.
        return super().__len__()

    def __getitem__(self, idx):
        if not self._fully_initialized:
            print_log(
                'Please call `full_init` method manually to '
                'accelerate the speed.',
                logger='current',
                level=logging.WARNING)
            self.full_init()
        dataset_idx, sample_idx = self._get_ori_dataset_idx(idx)
        return self.datasets[dataset_idx][sample_idx]

    def get_subset_(self, indices: Union[List[int], int]) -> None:
        """Not supported in ``ConcatDataset`` for the ambiguous meaning of sub-
        dataset."""
        raise NotImplementedError(
            '`ConcatDataset` dose not support `get_subset` and '
            '`get_subset_` interfaces because this will lead to ambiguous '
            'implementation of some methods. If you want to use `get_subset` '
            'or `get_subset_` interfaces, please use them in the wrapped '
            'dataset first and then use `ConcatDataset`.')

    def get_subset(self, indices: Union[List[int], int]) -> 'BaseDataset':
        """Not supported in ``ConcatDataset`` for the ambiguous meaning of sub-
        dataset."""
        raise NotImplementedError(
            '`ConcatDataset` dose not support `get_subset` and '
            '`get_subset_` interfaces because this will lead to ambiguous '
            'implementation of some methods. If you want to use `get_subset` '
            'or `get_subset_` interfaces, please use them in the wrapped '
            'dataset first and then use `ConcatDataset`.')
class RepeatDataset:
    """A wrapper of repeated dataset.

    The length of repeated dataset will be `times` larger than the original
    dataset. This is useful when the data loading time is long but the dataset
    is small. Using RepeatDataset can reduce the data loading time between
    epochs.

    Note:
        ``RepeatDataset`` should not inherit from ``BaseDataset`` since
        ``get_subset`` and ``get_subset_`` could produce ambiguous meaning
        sub-dataset which conflicts with original dataset. If you want to use
        a sub-dataset of ``RepeatDataset``, you should set ``indices``
        arguments for wrapped dataset which inherit from ``BaseDataset``.

    Args:
        dataset (BaseDataset or dict): The dataset to be repeated. A dict is
            built via ``DATASETS``.
        times (int): Repeat times.
        lazy_init (bool): If True, defer ``full_init`` (which calls
            ``full_init`` on the wrapped dataset and records its length)
            until it is first needed. Defaults to False.

    Raises:
        TypeError: If ``dataset`` is neither a dict nor a ``BaseDataset``.
    """

    def __init__(self,
                 dataset: Union[BaseDataset, dict],
                 times: int,
                 lazy_init: bool = False):
        self.dataset: BaseDataset
        if isinstance(dataset, dict):
            self.dataset = DATASETS.build(dataset)
        elif isinstance(dataset, BaseDataset):
            self.dataset = dataset
        else:
            raise TypeError(
                'elements in datasets sequence should be config or '
                f'`BaseDataset` instance, but got {type(dataset)}')
        self.times = times
        self._metainfo = self.dataset.metainfo

        self._fully_initialized = False
        if not lazy_init:
            self.full_init()

    @property
    def metainfo(self) -> dict:
        """Get the meta information of the repeated dataset.

        Returns:
            dict: A deep copy of the wrapped dataset's meta information.
        """
        # Prevent `self._metainfo` from being modified by outside.
        return copy.deepcopy(self._metainfo)

    def full_init(self):
        """Call ``full_init`` on the wrapped dataset and cache its length.

        Calling it again after initialization is a no-op.
        """
        if self._fully_initialized:
            return

        self.dataset.full_init()
        self._ori_len = len(self.dataset)
        self._fully_initialized = True

    @force_full_init
    def _get_ori_dataset_idx(self, idx: int) -> int:
        """Convert global index to local index.

        Args:
            idx (int): Global index of ``RepeatDataset``.

        Returns:
            int: Local index of data, i.e. ``idx`` modulo the length of the
            wrapped dataset.
        """
        return idx % self._ori_len

    @force_full_init
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        Args:
            idx (int): Global index of ``RepeatDataset``.

        Returns:
            dict: The idx-th annotation of the datasets.
        """
        sample_idx = self._get_ori_dataset_idx(idx)
        return self.dataset.get_data_info(sample_idx)

    def __getitem__(self, idx):
        if not self._fully_initialized:
            print_log(
                'Please call `full_init` method manually to accelerate the '
                'speed.',
                logger='current',
                level=logging.WARNING)
            self.full_init()

        sample_idx = self._get_ori_dataset_idx(idx)
        return self.dataset[sample_idx]

    @force_full_init
    def __len__(self):
        # `_ori_len` is only set by `full_init`.
        return self.times * self._ori_len

    def get_subset_(self, indices: Union[List[int], int]) -> None:
        """Not supported in ``RepeatDataset`` for the ambiguous meaning of sub-
        dataset."""
        raise NotImplementedError(
            '`RepeatDataset` dose not support `get_subset` and '
            '`get_subset_` interfaces because this will lead to ambiguous '
            'implementation of some methods. If you want to use `get_subset` '
            'or `get_subset_` interfaces, please use them in the wrapped '
            'dataset first and then use `RepeatDataset`.')

    def get_subset(self, indices: Union[List[int], int]) -> 'BaseDataset':
        """Not supported in ``RepeatDataset`` for the ambiguous meaning of sub-
        dataset."""
        raise NotImplementedError(
            '`RepeatDataset` dose not support `get_subset` and '
            '`get_subset_` interfaces because this will lead to ambiguous '
            'implementation of some methods. If you want to use `get_subset` '
            'or `get_subset_` interfaces, please use them in the wrapped '
            'dataset first and then use `RepeatDataset`.')
class ClassBalancedDataset:
    """A wrapper of class balanced dataset.

    Suitable for training on class imbalanced datasets like LVIS. Following
    the sampling strategy in the `paper <https://arxiv.org/abs/1908.03195>`_,
    in each epoch, an image may appear multiple times based on its
    "repeat factor".

    The repeat factor for an image is a function of the frequency the rarest
    category labeled in that image. The "frequency of category c" in [0, 1]
    is defined by the fraction of images in the training set (without repeats)
    in which category c appears.

    The dataset needs to implement :meth:`get_cat_ids` to support
    ClassBalancedDataset.

    The repeat factor is computed as followed.

    1. For each category c, compute the fraction of images
       that contain it: :math:`f(c)`
    2. For each category c, compute the category-level repeat factor:
       :math:`r(c) = max(1, sqrt(t/f(c)))`
    3. For each image I, compute the image-level repeat factor:
       :math:`r(I) = max_{c in I} r(c)`

    Each image index is then repeated ``ceil(r(I))`` times.

    Note:
        ``ClassBalancedDataset`` should not inherit from ``BaseDataset``
        since ``get_subset`` and ``get_subset_`` could produce ambiguous
        meaning sub-dataset which conflicts with original dataset. If you
        want to use a sub-dataset of ``ClassBalancedDataset``, you should set
        ``indices`` arguments for wrapped dataset which inherit from
        ``BaseDataset``.

    Args:
        dataset (BaseDataset or dict): The dataset to be repeated. A dict is
            built via ``DATASETS``.
        oversample_thr (float): frequency threshold below which data is
            repeated. For categories with ``f_c >= oversample_thr``, there is
            no oversampling. For categories with ``f_c < oversample_thr``, the
            degree of oversampling following the square-root inverse frequency
            heuristic above.
        lazy_init (bool, optional): If True, defer ``full_init`` (which calls
            ``full_init`` on the wrapped dataset and computes the repeated
            indices) until it is first needed. Defaults to False.

    Raises:
        TypeError: If ``dataset`` is neither a dict nor a ``BaseDataset``.
    """

    def __init__(self,
                 dataset: Union[BaseDataset, dict],
                 oversample_thr: float,
                 lazy_init: bool = False):
        self.dataset: BaseDataset
        if isinstance(dataset, dict):
            self.dataset = DATASETS.build(dataset)
        elif isinstance(dataset, BaseDataset):
            self.dataset = dataset
        else:
            raise TypeError(
                'elements in datasets sequence should be config or '
                f'`BaseDataset` instance, but got {type(dataset)}')
        self.oversample_thr = oversample_thr
        self._metainfo = self.dataset.metainfo

        self._fully_initialized = False
        if not lazy_init:
            self.full_init()

    @property
    def metainfo(self) -> dict:
        """Get the meta information of the repeated dataset.

        Returns:
            dict: A deep copy of the wrapped dataset's meta information.
        """
        # Prevent `self._metainfo` from being modified by outside.
        return copy.deepcopy(self._metainfo)

    def full_init(self):
        """Initialize the wrapped dataset and build the repeated indices.

        Calling it again after initialization is a no-op.
        """
        if self._fully_initialized:
            return

        self.dataset.full_init()
        # Get repeat factors for each image.
        repeat_factors = self._get_repeat_factors(self.dataset,
                                                  self.oversample_thr)
        # Repeat dataset's indices according to repeat_factors. For example,
        # if `repeat_factors = [1, 2, 3]`, and the `len(dataset) == 3`,
        # the repeated indices will be [0, 1, 1, 2, 2, 2]. Fractional
        # factors are rounded up by `math.ceil`.
        repeat_indices = []
        for dataset_index, repeat_factor in enumerate(repeat_factors):
            repeat_indices.extend([dataset_index] * math.ceil(repeat_factor))
        self.repeat_indices = repeat_indices

        self._fully_initialized = True

    def _get_repeat_factors(self, dataset: BaseDataset,
                            repeat_thr: float) -> List[float]:
        """Get repeat factor for each images in the dataset.

        Args:
            dataset (BaseDataset): The dataset.
            repeat_thr (float): The threshold of frequency. If an image
                contains the categories whose frequency below the threshold,
                it would be repeated.

        Returns:
            List[float]: The repeat factors for each images in the dataset,
            one entry per image (1.0 for images without categories).
        """
        num_images = len(dataset)
        # Collect each image's distinct category ids once; both passes below
        # need them, so this avoids calling `get_cat_ids` twice per image.
        image_cat_ids = [
            set(dataset.get_cat_ids(idx)) for idx in range(num_images)
        ]

        # 1. For each category c, compute the fraction of images
        #   that contain it: f(c)
        category_count: defaultdict = defaultdict(int)
        for cat_ids in image_cat_ids:
            for cat_id in cat_ids:
                category_count[cat_id] += 1
        category_freq = {
            cat_id: count / num_images
            for cat_id, count in category_count.items()
        }

        # 2. For each category c, compute the category-level repeat factor:
        #    r(c) = max(1, sqrt(t/f(c)))
        category_repeat = {
            cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq))
            for cat_id, cat_freq in category_freq.items()
        }

        # 3. For each image I and its labels L(I), compute the image-level
        # repeat factor:
        #    r(I) = max_{c in L(I)} r(c)
        # The length of `repeat_factors` must equal the length of the
        # dataset, so an image without categories gets a factor of 1.
        repeat_factors: List[float] = []
        for cat_ids in image_cat_ids:
            repeat_factor: float = 1.
            if cat_ids:
                repeat_factor = max(
                    category_repeat[cat_id] for cat_id in cat_ids)
            repeat_factors.append(repeat_factor)

        return repeat_factors

    @force_full_init
    def _get_ori_dataset_idx(self, idx: int) -> int:
        """Convert global index to local index.

        Args:
            idx (int): Global index of ``ClassBalancedDataset``.

        Returns:
            int: Local index of data.
        """
        return self.repeat_indices[idx]

    @force_full_init
    def get_cat_ids(self, idx: int) -> List[int]:
        """Get category ids of class balanced dataset by index.

        Args:
            idx (int): Index of data.

        Returns:
            List[int]: All categories in the image of specified index.
        """
        sample_idx = self._get_ori_dataset_idx(idx)
        return self.dataset.get_cat_ids(sample_idx)

    @force_full_init
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        Args:
            idx (int): Global index of ``ClassBalancedDataset``.

        Returns:
            dict: The idx-th annotation of the dataset.
        """
        sample_idx = self._get_ori_dataset_idx(idx)
        return self.dataset.get_data_info(sample_idx)

    def __getitem__(self, idx):
        if not self._fully_initialized:
            print_log(
                'Please call `full_init` method manually to accelerate '
                'the speed.',
                logger='current',
                level=logging.WARNING)
            self.full_init()

        ori_index = self._get_ori_dataset_idx(idx)
        return self.dataset[ori_index]

    @force_full_init
    def __len__(self):
        # `repeat_indices` is only built by `full_init`.
        return len(self.repeat_indices)

    def get_subset_(self, indices: Union[List[int], int]) -> None:
        """Not supported in ``ClassBalancedDataset`` for the ambiguous meaning
        of sub-dataset."""
        raise NotImplementedError(
            '`ClassBalancedDataset` dose not support `get_subset` and '
            '`get_subset_` interfaces because this will lead to ambiguous '
            'implementation of some methods. If you want to use `get_subset` '
            'or `get_subset_` interfaces, please use them in the wrapped '
            'dataset first and then use `ClassBalancedDataset`.')

    def get_subset(self, indices: Union[List[int], int]) -> 'BaseDataset':
        """Not supported in ``ClassBalancedDataset`` for the ambiguous meaning
        of sub-dataset."""
        raise NotImplementedError(
            '`ClassBalancedDataset` dose not support `get_subset` and '
            '`get_subset_` interfaces because this will lead to ambiguous '
            'implementation of some methods. If you want to use `get_subset` '
            'or `get_subset_` interfaces, please use them in the wrapped '
            'dataset first and then use `ClassBalancedDataset`.')