# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings
from os import path as osp

import numpy as np
import torch
from mmdet.datasets.builder import DATASETS

from mmocr.core import compute_f1_score
from mmocr.datasets.base_dataset import BaseDataset
from mmocr.datasets.pipelines import sort_vertex8
from mmocr.utils import is_type_list, list_from_file

@DATASETS.register_module()
class KIEDataset(BaseDataset):
    """Dataset for key information extraction (KIE).

    Args:
        ann_file (str): Annotation file path.
        pipeline (list[dict]): Processing pipeline.
        loader (dict): Dictionary to construct loader
            to load annotation infos.
        img_prefix (str, optional): Image prefix to generate full
            image path.
        test_mode (bool, optional): If True, try...except will
            be turned off in __getitem__.
        dict_file (str): Character dict file path.
        norm (float): Norm to map value from one range to another.
        directed (bool): Whether the edges of the ground-truth graph
            are directed.
    """
    def __init__(self,
                 ann_file=None,
                 loader=None,
                 dict_file=None,
                 img_prefix='',
                 pipeline=None,
                 norm=10.,
                 directed=False,
                 test_mode=True,
                 **kwargs):
        if ann_file is None and loader is None:
            warnings.warn(
                'KIEDataset is only initialized as a downstream demo task '
                'of text detection and recognition '
                'without an annotation file.', UserWarning)
        else:
            super().__init__(
                ann_file,
                loader,
                pipeline,
                img_prefix=img_prefix,
                test_mode=test_mode)

        # dict_file is required in both branches because self.dict is
        # built unconditionally below.
        assert osp.exists(dict_file)

        self.norm = norm
        self.directed = directed
        self.dict = {
            '': 0,
            **{
                line.rstrip('\r\n'): ind
                for ind, line in enumerate(list_from_file(dict_file), 1)
            }
        }
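        # Illustration: with a hypothetical dict_file containing the lines
        # "a", "b" and "c", self.dict becomes
        # {'': 0, 'a': 1, 'b': 2, 'c': 3}; index 0 is reserved for the
        # empty string.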

    def pre_pipeline(self, results):
        results['img_prefix'] = self.img_prefix
        results['bbox_fields'] = []
        results['ori_texts'] = results['ann_info']['ori_texts']
        results['filename'] = osp.join(self.img_prefix,
                                       results['img_info']['filename'])
        results['ori_filename'] = results['img_info']['filename']
        # A zero-size dummy image: downstream transforms only need the
        # 'img' key to exist.
        results['img'] = np.zeros((0, 0, 0), dtype=np.uint8)

    def _parse_anno_info(self, annotations):
        """Parse annotations of boxes, texts and labels for one image.

        Args:
            annotations (list[dict]): Annotations of one image, where
                each dict is for one text box.

        Returns:
            dict: A dict containing the following keys:

                - bboxes (np.ndarray): Bboxes in one image with shape:
                  box_num * 4. They are sorted clockwise when loading.
                - relations (np.ndarray): Relations between bboxes with
                  shape: box_num * box_num * D.
                - texts (np.ndarray): Text indices with shape:
                  box_num * text_max_len.
                - labels (np.ndarray): Box labels with shape:
                  box_num * (box_num + 1).
        """
        assert is_type_list(annotations, dict)
        assert len(annotations) > 0, 'Please remove data with empty annotation'
        assert 'box' in annotations[0]
        assert 'text' in annotations[0]

        boxes, texts, text_inds, labels, edges = [], [], [], [], []
        for ann in annotations:
            box = ann['box']
            sorted_box = sort_vertex8(box[:8])
            boxes.append(sorted_box)
            text = ann['text']
            texts.append(text)
            text_ind = [self.dict[c] for c in text if c in self.dict]
            text_inds.append(text_ind)
            labels.append(ann.get('label', 0))
            edges.append(ann.get('edge', 0))

        ann_infos = dict(
            boxes=boxes,
            texts=texts,
            text_inds=text_inds,
            edges=edges,
            labels=labels)

        return self.list_to_numpy(ann_infos)
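    # Illustration (made-up sample): each annotation dict looks like
    #   {'box': [x1, y1, x2, y2, x3, y3, x4, y4], 'text': 'Total',
    #    'label': 2, 'edge': 1}
    # where 'label' is the node class and 'edge' is a grouping id; boxes
    # sharing the same 'edge' value are linked in the ground-truth graph
    # (see list_to_numpy below).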

    def prepare_train_img(self, index):
        """Get training data and annotations from pipeline.

        Args:
            index (int): Index of data.

        Returns:
            dict: Training data and annotation after pipeline with new keys
                introduced by pipeline.
        """
        img_ann_info = self.data_infos[index]
        img_info = {
            'filename': img_ann_info['file_name'],
            'height': img_ann_info['height'],
            'width': img_ann_info['width']
        }
        ann_info = self._parse_anno_info(img_ann_info['annotations'])
        results = dict(img_info=img_info, ann_info=ann_info)

        self.pre_pipeline(results)

        return self.pipeline(results)
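    # Note (hedged): in the MM-style base dataset this class inherits from,
    # ``__getitem__`` is expected to route to ``prepare_train_img`` during
    # training, so the dict returned above is what a dataloader collates.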

    def evaluate(self,
                 results,
                 metric='macro_f1',
                 metric_options=dict(macro_f1=dict(ignores=[])),
                 **kwargs):
        # Allow some kwargs to pass through.
        assert set(kwargs).issubset(['logger'])

        # Protect ``metric_options`` since it uses a mutable value as
        # its default.
        metric_options = copy.deepcopy(metric_options)

        metrics = metric if isinstance(metric, list) else [metric]
        allowed_metrics = ['macro_f1']
        for m in metrics:
            if m not in allowed_metrics:
                raise KeyError(f'metric {m} is not supported')

        return self.compute_macro_f1(results, **metric_options['macro_f1'])

    def compute_macro_f1(self, results, ignores=[]):
        node_preds = []
        node_gts = []
        for idx, result in enumerate(results):
            node_preds.append(result['nodes'].cpu())
            box_ann_infos = self.data_infos[idx]['annotations']
            node_gt = [
                box_ann_info['label'] for box_ann_info in box_ann_infos
            ]
            node_gts.append(torch.Tensor(node_gt))

        node_preds = torch.cat(node_preds)
        node_gts = torch.cat(node_gts).int()

        node_f1s = compute_f1_score(node_preds, node_gts, ignores)

        return {
            'macro_f1': node_f1s.mean(),
        }
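    # Example call (hedged sketch; ``results`` is a list of per-image dicts
    # from a trained model, each holding a 'nodes' score tensor):
    #   dataset.evaluate(results, metric='macro_f1',
    #                    metric_options=dict(macro_f1=dict(ignores=[0])))
    # where ``ignores`` lists label ids excluded from the macro-averaged F1.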

    def list_to_numpy(self, ann_infos):
        """Convert bboxes, relations, texts and labels to ndarray."""
        boxes, text_inds = ann_infos['boxes'], ann_infos['text_inds']
        texts = ann_infos['texts']
        boxes = np.array(boxes, np.int32)
        relations, bboxes = self.compute_relation(boxes)

        labels = ann_infos.get('labels', None)
        if labels is not None:
            labels = np.array(labels, np.int32)
            edges = ann_infos.get('edges', None)
            if edges is not None:
                labels = labels[:, None]
                edges = np.array(edges)
                # Nodes sharing the same edge id are linked.
                edges = (edges[:, None] == edges[None, :]).astype(np.int32)
                if self.directed:
                    # Note: ``&`` binds tighter than ``==``, so this
                    # evaluates as ``(edges & labels) == 1``.
                    edges = (edges & labels == 1).astype(np.int32)
                np.fill_diagonal(edges, -1)
                labels = np.concatenate([labels, edges], -1)

        padded_text_inds = self.pad_text_indices(text_inds)

        return dict(
            bboxes=bboxes,
            relations=relations,
            texts=padded_text_inds,
            ori_texts=texts,
            labels=labels)
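    # Shape walk-through for N boxes: 'labels' starts as (N,), becomes
    # (N, 1) after ``labels[:, None]``, and the (N, N) edge matrix is
    # concatenated column-wise, giving the (N, N + 1) layout documented
    # in ``_parse_anno_info``.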

    def pad_text_indices(self, text_inds):
        """Pad text indices to the same length."""
        max_len = max([len(text_ind) for text_ind in text_inds])
        padded_text_inds = -np.ones((len(text_inds), max_len), np.int32)
        for idx, text_ind in enumerate(text_inds):
            padded_text_inds[idx, :len(text_ind)] = np.array(text_ind)
        return padded_text_inds
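    # Illustration: pad_text_indices([[1, 2], [3]]) returns
    #   array([[ 1,  2],
    #          [ 3, -1]], dtype=int32)
    # so -1 marks positions beyond each text's real length.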

    def compute_relation(self, boxes):
        """Compute the relation between every two boxes."""
        # Get minimal axis-aligned bounding boxes for each of the boxes.
        # yapf: disable
        bboxes = np.concatenate(
            [boxes[:, 0::2].min(axis=1, keepdims=True),
             boxes[:, 1::2].min(axis=1, keepdims=True),
             boxes[:, 0::2].max(axis=1, keepdims=True),
             boxes[:, 1::2].max(axis=1, keepdims=True)],
            axis=1).astype(np.float32)
        # yapf: enable
        x1, y1 = bboxes[:, 0:1], bboxes[:, 1:2]
        x2, y2 = bboxes[:, 2:3], bboxes[:, 3:4]
        w, h = np.maximum(x2 - x1 + 1, 1), np.maximum(y2 - y1 + 1, 1)
        dx = (x1.T - x1) / self.norm
        dy = (y1.T - y1) / self.norm
        xhh, xwh = h.T / h, w.T / h
        whs = w / h + np.zeros_like(xhh)
        relation = np.stack([dx, dy, whs, xhh, xwh], -1).astype(np.float32)
        return relation, bboxes
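

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the upstream file). It shows how the
# dataset can be built in the annotation-free "demo" mode and what
# ``compute_relation`` produces; the character dict written below is a
# made-up placeholder, and running this assumes mmocr/mmdet are installed.
if __name__ == '__main__':
    import tempfile

    # KIEDataset requires an existing dict_file even without annotations,
    # so write a tiny hypothetical one.
    with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as f:
        f.write('a\nb\nc')
        dict_path = f.name

    dataset = KIEDataset(dict_file=dict_path)  # warns: demo mode

    # Two clockwise 8-point quadrilaterals, purely for illustration.
    boxes = np.array([[0, 0, 10, 0, 10, 10, 0, 10],
                      [20, 0, 40, 0, 40, 10, 20, 10]], dtype=np.int32)
    relation, bboxes = dataset.compute_relation(boxes)

    print(bboxes)          # axis-aligned [x1, y1, x2, y2] per box
    print(relation.shape)  # (2, 2, 5): [dx, dy, w/h, h'/h, w'/h] per pair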