Spaces:
Build error
Build error
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from abc import abstractmethod | |
from typing import Dict, List, Optional, Union | |
import torch | |
import torch.nn as nn | |
from mmengine.registry import MODELS | |
from mmengine.structures import BaseDataElement | |
from .base_model import BaseModel | |
# Type aliases describing "multi-batch" data, i.e. the same batch after it has
# been processed by several different augmentations.

# Multi-batch inputs: each entry is either a tensor or a list of tensors.
EnhancedBatchInputs = List[Union[torch.Tensor, List[torch.Tensor]]]

# Multi-batch data samples. The outer list runs over the batch and the inner
# list runs over the different augmentations.
EnhancedBatchDataSamples = List[List[BaseDataElement]]

# A batch as accepted for test-time augmentation: a string-keyed mapping of
# multi-batch inputs / data samples, or a plain tuple or dict.
DATA_BATCH = Union[
    Dict[str, Union[EnhancedBatchInputs, EnhancedBatchDataSamples]],
    tuple,
    dict,
]

# Predictions after the per-augmentation results have been merged.
MergedDataSamples = List[BaseDataElement]
class BaseTTAModel(BaseModel):
    """Base model for inference with test-time augmentation.

    ``BaseTTAModel`` is a wrapper for inference given multi-batch data.
    It implements the :meth:`test_step` for multi-batch data inference.
    ``multi-batch`` data means data processed by different augmentation
    from the same batch.

    During test time augmentation, the data processed by
    :obj:`mmcv.transforms.TestTimeAug`, and then collated by
    ``pseudo_collate`` will have the following format:

    .. code-block::

        result = dict(
            inputs=[
                [image1_aug1, image2_aug1],
                [image1_aug2, image2_aug2]
            ],
            data_samples=[
                [data_sample1_aug1, data_sample2_aug1],
                [data_sample1_aug2, data_sample2_aug2],
            ]
        )

    ``image{i}_aug{j}`` means the i-th image of the batch, which is
    augmented by the j-th augmentation.

    ``BaseTTAModel`` will collate the data to:

    .. code-block::

        data1 = dict(
            inputs=[image1_aug1, image2_aug1],
            data_samples=[data_sample1_aug1, data_sample2_aug1]
        )

        data2 = dict(
            inputs=[image1_aug2, image2_aug2],
            data_samples=[data_sample1_aug2, data_sample2_aug2]
        )

    ``data1`` and ``data2`` will be passed to model, and the results will be
    merged by :meth:`merge_preds`.

    Note:
        :meth:`merge_preds` is an abstract method, all subclasses should
        implement it.

    Warning:
        If ``module`` is a config dict and ``data_preprocessor`` is not None,
        it will overwrite the ``data_preprocessor`` entry of that config.
        If ``module`` is already an ``nn.Module``, ``data_preprocessor`` is
        ignored.

    Args:
        module (dict or nn.Module): Tested model, or the config used to
            build it through the ``MODELS`` registry.
        data_preprocessor (dict or :obj:`BaseDataPreprocessor`, optional):
            Value stored under the ``data_preprocessor`` key of ``module``
            before the model is built. Only used when ``module`` is a dict.
            Defaults to None.
    """

    def __init__(
        self,
        module: Union[dict, nn.Module],
        data_preprocessor: Union[dict, nn.Module, None] = None,
    ):
        """Wrap an existing model, or build one from its config.

        Raises:
            TypeError: If ``module`` is neither an ``nn.Module`` nor a dict.
            AssertionError: If the wrapped model has no ``test_step``.
        """
        super().__init__()
        if isinstance(module, nn.Module):
            self.module = module
        elif isinstance(module, dict):
            if data_preprocessor is not None:
                # Work on a shallow copy so the caller's config is not
                # modified as a side effect of constructing the wrapper.
                module = module.copy()
                module['data_preprocessor'] = data_preprocessor
            self.module = MODELS.build(module)
        else:
            raise TypeError('The type of module should be a `nn.Module` '
                            f'instance or a dict, but got {module}')
        assert hasattr(self.module, 'test_step'), (
            'Model wrapped by BaseTTAModel must implement `test_step`!')

    @abstractmethod
    def merge_preds(self, data_samples_list: EnhancedBatchDataSamples) \
            -> MergedDataSamples:
        """Merge predictions of enhanced data to one prediction.

        Args:
            data_samples_list (EnhancedBatchDataSamples): List of predictions
                of all enhanced data. The outer sequence runs over the batch
                and each inner sequence holds the predictions of the
                different augmentations.

        Returns:
            List[BaseDataElement]: Merged prediction.
        """

    def test_step(self, data: Union[dict, tuple, list]) -> MergedDataSamples:
        """Run the wrapped model on every augmentation and merge the results.

        Args:
            data (dict or tuple or list): Enhanced data batch sampled from
                dataloader. Every value (for a dict) or element (for a
                tuple/list) is indexed by augmentation.

        Returns:
            MergedDataSamples: Merged prediction.

        Raises:
            TypeError: If ``data`` is not a dict, tuple or list.
        """
        data_list: Union[List[dict], List[list]]
        if isinstance(data, dict):
            # The number of augmentations is read from the first value.
            num_augs = len(next(iter(data.values())))
            # Regroup ``{key: [aug1, aug2, ...]}`` into one dict per
            # augmentation: ``[{key: aug1}, {key: aug2}, ...]``.
            data_list = [{key: value[idx]
                          for key, value in data.items()}
                         for idx in range(num_augs)]
        elif isinstance(data, (tuple, list)):
            num_augs = len(data[0])
            # Same regrouping for positional data.
            data_list = [[field[idx] for field in data]
                         for idx in range(num_augs)]
        else:
            raise TypeError('data given by dataLoader should be a dict, '
                            f'tuple or a list, but got {type(data)}')

        # One ``test_step`` result per augmentation.
        predictions = [
            self.module.test_step(aug_data)  # type: ignore
            for aug_data in data_list
        ]
        # ``zip`` regroups the per-augmentation results element-wise, so each
        # entry gathers the predictions at one position across augmentations
        # (presumably one entry per sample -- depends on what the wrapped
        # model's ``test_step`` returns).
        return self.merge_preds(list(zip(*predictions)))  # type: ignore

    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[list] = None,
                mode: str = 'tensor') -> Union[Dict[str, torch.Tensor], list]:
        """``BaseTTAModel.forward`` should not be called.

        Raises:
            NotImplementedError: Always.
        """
        # Each fragment ends with a space so that the implicitly
        # concatenated message reads correctly.
        raise NotImplementedError(
            '`BaseTTAModel.forward` will not be called during training or '
            'testing. Please call `test_step` instead. If you want to use '
            '`BaseTTAModel.forward`, please implement this method')