Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from copy import deepcopy | |
| from typing import Any, Callable, List, Tuple, Union | |
| from mmengine.dataset import BaseDataset | |
| from mmengine.registry import build_from_cfg | |
| from mmpose.registry import DATASETS | |
| from .datasets.utils import parse_pose_metainfo | |
class _SubsetIndexError(IndexError, ValueError):
    """Raised when a global index is outside a :class:`CombinedDataset`.

    Inherits from ``IndexError`` so that Python's ``__getitem__``-based
    iteration protocol (``for sample in dataset``) and sequence-style
    callers terminate cleanly, and from ``ValueError`` so that handlers
    written against the previously raised type keep working.
    """


# NOTE(review): no ``@DATASETS.register_module()`` decorator is visible on
# this class; configs that reference ``type='CombinedDataset'`` need it to
# be registered — confirm it is registered elsewhere.
class CombinedDataset(BaseDataset):
    """A wrapper of combined dataset.

    The wrapped sub-datasets are concatenated end to end: global indices
    ``[0, len(datasets[0]))`` map to the first sub-dataset, the next
    ``len(datasets[1])`` indices map to the second, and so on.

    Args:
        metainfo (dict): The meta information of combined dataset.
        datasets (list): The configs of datasets to be combined.
        pipeline (list, optional): Processing pipeline applied to each
            sample returned by a sub-dataset. ``None`` is treated as an
            empty pipeline. Defaults to ``None``.
        **kwargs: Forwarded to ``BaseDataset.__init__``.
    """

    # Metainfo entries written into every sample (overwriting any value the
    # sub-dataset sample already carries) because the pipeline and the
    # model require them.
    _METAINFO_KEYS = ('upper_body_ids', 'lower_body_ids', 'flip_pairs',
                      'dataset_keypoint_weights', 'flip_indices')

    def __init__(self,
                 metainfo: dict,
                 datasets: list,
                 pipeline: Union[List[Union[dict, Callable]], None] = None,
                 **kwargs):
        # Avoid a shared mutable default argument.
        if pipeline is None:
            pipeline = []

        # Sub-datasets must exist before ``super().__init__()`` runs: the
        # base initializer may call ``self.full_init()``, which iterates
        # over ``self.datasets``.
        self.datasets = [build_from_cfg(cfg, DATASETS) for cfg in datasets]

        self._lens = [len(dataset) for dataset in self.datasets]
        self._len = sum(self._lens)

        super().__init__(pipeline=pipeline, **kwargs)
        # Assigned after the base initializer so the parsed combined
        # metainfo takes precedence over anything stored there by it.
        self._metainfo = parse_pose_metainfo(metainfo)

    @property
    def metainfo(self):
        """A deep copy of the combined dataset's parsed metainfo.

        Exposed as a property to match ``BaseDataset.metainfo``, which is
        read as an attribute (``dataset.metainfo``) rather than called.
        """
        return deepcopy(self._metainfo)

    def __len__(self):
        """Return the total number of samples across all sub-datasets."""
        return self._len

    def _get_subset_index(self, index: int) -> Tuple[int, int]:
        """Given a data sample's global index, return the index of the sub-
        dataset the data sample belongs to, and the local index within that
        sub-dataset.

        Args:
            index (int): The global data sample index. Negative values
                count from the end, as with Python sequences.

        Returns:
            tuple[int, int]:
            - subset_index (int): The index of the sub-dataset
            - local_index (int): The index of the data sample within
                the sub-dataset

        Raises:
            IndexError: If ``index`` is out of range. The raised exception
                is also a ``ValueError`` for backward compatibility.
        """
        total = len(self)
        if index >= total or index < -total:
            raise _SubsetIndexError(
                f'index({index}) is out of bounds for dataset with '
                f'length({total}).')
        if index < 0:
            index += total

        # Subtract each sub-dataset's length until the index falls inside
        # one; zero-length sub-datasets are skipped automatically.
        subset_index = 0
        while index >= self._lens[subset_index]:
            index -= self._lens[subset_index]
            subset_index += 1
        return subset_index, index

    def prepare_data(self, idx: int) -> Any:
        """Get data processed by ``self.pipeline``. The source dataset is
        selected by the index.

        Args:
            idx (int): The index of ``data_info``.

        Returns:
            Any: Depends on ``self.pipeline``.

        Raises:
            IndexError: If ``idx`` is out of range (also a ``ValueError``).
        """
        data_info = self.get_data_info(idx)
        return self.pipeline(data_info)

    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        Args:
            idx (int): Global index of ``CombinedDataset``.

        Returns:
            dict: The idx-th sample of the combined datasets, with the
            combined dataset's metainfo entries added.

        Raises:
            IndexError: If ``idx`` is out of range (also a ``ValueError``).
        """
        subset_idx, sample_idx = self._get_subset_index(idx)
        # Get data sample processed by ``subset.pipeline``
        data_info = self.datasets[subset_idx][sample_idx]
        # Add metainfo items that are required in the pipeline and the
        # model; deep-copied so per-sample transforms cannot mutate the
        # shared ``self._metainfo``.
        for key in self._METAINFO_KEYS:
            data_info[key] = deepcopy(self._metainfo[key])
        return data_info

    def full_init(self):
        """Fully initialize all sub datasets (no-op if already done)."""
        if self._fully_initialized:
            return
        for dataset in self.datasets:
            dataset.full_init()
        self._fully_initialized = True