# Copyright header and imports follow.
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import numpy as np | |
| from mmcv.utils import print_log | |
| from mmdet.datasets.builder import DATASETS | |
| from mmdet.datasets.pipelines import Compose | |
| from torch.utils.data import Dataset | |
| from mmocr.datasets.builder import build_loader | |
class BaseDataset(Dataset):
    """Custom dataset for text detection, text recognition, and their
    downstream tasks.

    1. The text detection annotation format is as follows:
       The `annotations` field is optional for testing
       (this is one line of anno_file, with line-json-str
       converted to dict for visualizing only).

        {
            "file_name": "sample.jpg",
            "height": 1080,
            "width": 960,
            "annotations":
                [
                    {
                        "iscrowd": 0,
                        "category_id": 1,
                        "bbox": [357.0, 667.0, 804.0, 100.0],
                        "segmentation": [[361, 667, 710, 670,
                                          72, 767, 357, 763]]
                    }
                ]
        }

    2. The two text recognition annotation formats are as follows:
       The `x1,y1,x2,y2,x3,y3,x4,y4` field is used for online crop
       augmentation during training.

        format1: sample.jpg hello
        format2: sample.jpg 20 20 100 20 100 40 20 40 hello

    Args:
        ann_file (str): Annotation file path.
        loader (dict): Dictionary to construct loader
            to load annotation infos.
        pipeline (list[dict]): Processing pipeline.
        img_prefix (str, optional): Image prefix to generate full
            image path.
        test_mode (bool, optional): If set True, try...except will
            be turned off in __getitem__.
    """

    def __init__(self,
                 ann_file,
                 loader,
                 pipeline,
                 img_prefix='',
                 test_mode=False):
        super().__init__()
        self.test_mode = test_mode
        self.img_prefix = img_prefix
        self.ann_file = ann_file
        # Load annotation infos through the configured loader; the
        # ann_file path is injected into the loader config here.
        loader.update(ann_file=ann_file)
        self.data_infos = build_loader(loader)
        # Build the data-processing pipeline.
        self.pipeline = Compose(pipeline)
        # Set group flag and class, no meaning for text detect and
        # recognize -- they only exist to satisfy mmdet's sampler /
        # runner interfaces, which expect `flag` and `CLASSES`.
        self._set_group_flag()
        self.CLASSES = 0

    def __len__(self):
        """Return the number of loaded annotation infos."""
        return len(self.data_infos)

    def _set_group_flag(self):
        """Set flag.

        All samples share group 0; grouping carries no meaning for OCR
        tasks but is required by mmdet's GroupSampler.
        """
        self.flag = np.zeros(len(self), dtype=np.uint8)

    def pre_pipeline(self, results):
        """Prepare results dict for pipeline.

        Mutates ``results`` in place (adds ``img_prefix``); returns None.
        """
        results['img_prefix'] = self.img_prefix

    def prepare_train_img(self, index):
        """Get training data and annotations from pipeline.

        Args:
            index (int): Index of data.

        Returns:
            dict: Training data and annotation after pipeline with new keys
            introduced by pipeline.
        """
        img_info = self.data_infos[index]
        results = dict(img_info=img_info)
        self.pre_pipeline(results)
        return self.pipeline(results)

    def prepare_test_img(self, img_info):
        """Get testing data from pipeline.

        Args:
            img_info (int): Index of data. Despite its name, this argument
                is an integer index forwarded to ``prepare_train_img``.

        Returns:
            dict: Testing data after pipeline with new keys introduced by
            pipeline.
        """
        return self.prepare_train_img(img_info)

    def _log_error_index(self, index):
        """Logging data info of bad index.

        Best-effort: if even reading the data info fails, log the
        secondary error instead of raising.
        """
        try:
            data_info = self.data_infos[index]
            img_prefix = self.img_prefix
            print_log(f'Warning: skip broken file {data_info} '
                      f'with img_prefix {img_prefix}')
        except Exception as e:
            print_log(f'load index {index} with error {e}')

    def _get_next_index(self, index):
        """Get next index from dataset, logging the broken one.

        Wraps around to 0 at the end of the dataset.
        """
        self._log_error_index(index)
        index = (index + 1) % len(self)
        return index

    def __getitem__(self, index):
        """Get training/test data from pipeline.

        In training mode, broken samples are skipped (with a log message)
        and the next index is tried until a sample loads successfully.
        In test mode no such retrying is performed.

        Args:
            index (int): Index of data.

        Returns:
            dict: Training/test data.
        """
        if self.test_mode:
            return self.prepare_test_img(index)

        while True:
            try:
                data = self.prepare_train_img(index)
                if data is None:
                    raise Exception('prepared train data empty')
                break
            except Exception as e:
                print_log(f'prepare index {index} with error {e}')
                index = self._get_next_index(index)
        return data

    def format_results(self, results, **kwargs):
        """Placeholder to format result to dataset-specific output."""
        pass

    def evaluate(self, results, metric=None, logger=None, **kwargs):
        """Evaluate the dataset.

        Args:
            results (list): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.

        Returns:
            dict[str: float]
        """
        raise NotImplementedError