Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import copy | |
| from typing import Callable, List, Union | |
| from mmcv.transforms import BaseTransform, Compose | |
| from mmpretrain.registry import TRANSFORMS | |
# Define type of transform or transform config: either a config dict (as
# accepted by ``Compose``) or a callable that takes a results dict and
# returns a results dict.
Transform = Union[dict, Callable[[dict], dict]]
class MultiView(BaseTransform):
    """A transform wrapper for multiple views of an image.

    Each wrapped pipeline is applied ``num_views[i]`` times, every time to an
    independent deep copy of the input results. The ``'img'`` field of every
    output is collected, in order, into a list that replaces ``'img'`` in the
    input results. All other fields produced by the view pipelines are
    discarded.

    Args:
        transforms (list[list[dict | callable]]): A list of pipelines. Each
            pipeline is a sequence of transform objects or config dicts and
            is wrapped by ``Compose``.
        num_views (int | list[int]): Number of views generated by each
            pipeline. An int is treated as a one-element list, so it must be
            paired with exactly one pipeline in ``transforms``.

    Examples:
        >>> # Example 1: MultiViews 1 pipeline with 2 views
        >>> pipeline = [
        >>>     dict(type='MultiView',
        >>>         num_views=2,
        >>>         transforms=[
        >>>             [
        >>>                dict(type='Resize', scale=224)],
        >>>         ])
        >>> ]
        >>> # Example 2: MultiViews 2 pipelines, the first with 2 views,
        >>> # the second with 6 views
        >>> pipeline = [
        >>>     dict(type='MultiView',
        >>>         num_views=[2, 6],
        >>>         transforms=[
        >>>             [
        >>>                dict(type='Resize', scale=224)],
        >>>             [
        >>>                dict(type='Resize', scale=224),
        >>>                dict(type='RandomSolarize')],
        >>>         ])
        >>> ]
    """

    def __init__(self, transforms: List[List[Transform]],
                 num_views: Union[int, List[int]]) -> None:
        super().__init__()
        if isinstance(num_views, int):
            num_views = [num_views]
        assert isinstance(num_views, list), \
            f'num_views must be an int or a list of int, got {type(num_views)}'
        assert len(num_views) == len(transforms), (
            f'Got {len(num_views)} entries in num_views but '
            f'{len(transforms)} pipelines in transforms.')
        self.num_views = num_views
        self.pipelines = [Compose(trans) for trans in transforms]

        # Flatten to one entry per view. All views of a pipeline share the
        # same Compose object; each view is produced by a separate call.
        self.transforms = []
        for pipeline, n_views in zip(self.pipelines, num_views):
            self.transforms.extend([pipeline] * n_views)

    def transform(self, results: dict) -> dict:
        """Apply every view pipeline to a copy of ``results``.

        Args:
            results (dict): Result dict from previous pipelines.

        Returns:
            dict | None: The input ``results`` (updated in place) whose
            ``'img'`` field is the list of per-view images, or ``None`` if
            any view pipeline returned ``None``.
        """
        multi_views_outputs = dict(img=[])
        for trans in self.transforms:
            # Deep copy so that views never observe each other's changes.
            outputs = trans(copy.deepcopy(results))
            if outputs is None:
                # ``Compose`` signals "skip this sample" with ``None``;
                # propagate it instead of failing on ``None['img']``.
                return None
            multi_views_outputs['img'].append(outputs['img'])
        results.update(multi_views_outputs)
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__ + '('
        for i, (pipeline, n_views) in enumerate(
                zip(self.pipelines, self.num_views), start=1):
            repr_str += f'\nPipeline {i} with {n_views} views:\n'
            repr_str += str(pipeline)
        repr_str += ')'
        return repr_str
class ApplyToList(BaseTransform):
    """A transform wrapper to apply the wrapped transforms to a list of items.

    For example, to load and resize a list of images.

    If ``results[scatter_key]`` is a list, ``results`` is deep-copied once
    per item. The item replaces the list under ``scatter_key`` in that copy,
    and the wrapped transforms are applied to each copy. The outputs are then
    merged: every field in ``collate_keys`` (which always includes
    ``scatter_key``) becomes a list with one entry per item, and all other
    fields are taken from the first output. If the field is missing or is not
    a list, the wrapped transforms are applied to ``results`` directly.

    Args:
        transforms (list[dict | callable]): Sequence of transform config dicts
            or transform objects to be wrapped. Config dicts are built with
            the ``TRANSFORMS`` registry; other objects are used as-is.
        scatter_key (str): The key to scatter data dict. If the field is a
            list, scatter the list to multiple data dicts to do transformation.
        collate_keys (str | List[str]): The keys to collate from multiple data
            dicts. The fields in ``collate_keys`` will be composed into a list
            after transformation, and the other fields will be adopted from
            the first data dict. A single string is treated as one key.

    Raises:
        ValueError: From :meth:`transform`, if the field under ``scatter_key``
            is an empty list (there is no output to merge).
    """

    def __init__(self, transforms, scatter_key, collate_keys):
        super().__init__()
        # Only config dicts go through the registry; already-built transform
        # objects would be rejected by ``TRANSFORMS.build``.
        self.transforms = Compose([
            TRANSFORMS.build(t) if isinstance(t, dict) else t
            for t in transforms
        ])
        self.scatter_key = scatter_key
        if isinstance(collate_keys, str):
            # Avoid ``set('img')`` splitting a single key into characters.
            collate_keys = [collate_keys]
        self.collate_keys = set(collate_keys)
        self.collate_keys.add(self.scatter_key)

    def transform(self, results: dict):
        """Apply the wrapped transforms, scattering over a list field if any.

        Args:
            results (dict): Result dict from previous pipelines.

        Returns:
            dict | None: The transformed (and, if scattered, merged) results,
            or ``None`` if the wrapped transforms returned ``None``.
        """
        scatter_field = results.get(self.scatter_key)
        if not isinstance(scatter_field, list):
            return self.transforms(results)
        if not scatter_field:
            raise ValueError(
                f'Cannot scatter the field "{self.scatter_key}": it is an '
                'empty list.')

        scattered_results = []
        for item in scatter_field:
            single_results = copy.deepcopy(results)
            # The copied list is replaced by the original (un-copied) item.
            single_results[self.scatter_key] = item
            single_output = self.transforms(single_results)
            if single_output is None:
                # ``Compose`` signals "skip this sample" with ``None``.
                return None
            scattered_results.append(single_output)

        # Merge the outputs into the first one. Only existing keys are
        # reassigned, so iterating the dict while updating it is safe.
        final_output = scattered_results[0]
        for key in final_output:
            if key in self.collate_keys:
                final_output[key] = [
                    single[key] for single in scattered_results
                ]
        return final_output

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(transforms={self.transforms}, '
        repr_str += f'scatter_key={self.scatter_key!r}, '
        repr_str += f'collate_keys={sorted(self.collate_keys)})'
        return repr_str