# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn.functional as F

import mmocr.utils as utils
from mmocr.models.builder import CONVERTORS
from .base import BaseConvertor


@CONVERTORS.register_module()
class CTCConvertor(BaseConvertor):
    """Convert between text, index and tensor for CTC loss-based pipeline.

    Args:
        dict_type (str): Type of dict, should be either 'DICT36' or 'DICT90'.
        dict_file (None|str): Character dict file path. If not none, the file
            is of higher priority than dict_type.
        dict_list (None|list[str]): Character list. If not none, the list
            is of higher priority than dict_type, but lower than dict_file.
        with_unknown (bool): If True, add `UKN` token to class.
        lower (bool): If True, convert original string to lower case.
    """

    def __init__(self,
                 dict_type='DICT90',
                 dict_file=None,
                 dict_list=None,
                 with_unknown=True,
                 lower=False,
                 **kwargs):
        super().__init__(dict_type, dict_file, dict_list)
        assert isinstance(with_unknown, bool)
        assert isinstance(lower, bool)

        self.with_unknown = with_unknown
        self.lower = lower
        self.update_dict()

    def update_dict(self):
        # CTC-blank
        blank_token = '<BLK>'
        self.blank_idx = 0
        self.idx2char.insert(0, blank_token)

        # unknown
        self.unknown_idx = None
        if self.with_unknown:
            self.idx2char.append('<UKN>')
            self.unknown_idx = len(self.idx2char) - 1

        # update char2idx
        self.char2idx = {}
        for idx, char in enumerate(self.idx2char):
            self.char2idx[char] = idx

    def str2tensor(self, strings):
        """Convert text-string to ctc-loss input tensor.

        Args:
            strings (list[str]): ['hello', 'world'].
        Returns:
            dict (str: tensor | list[tensor]):
                targets (list[tensor]): [torch.Tensor([1,2,3,3,4]),
                    torch.Tensor([5,4,6,3,7])].
                flatten_targets (tensor): torch.Tensor([1,2,3,3,4,5,4,6,3,7]).
                target_lengths (tensor): torch.IntTensor([5,5]).
        """
        assert utils.is_type_list(strings, str)

        tensors = []
        indexes = self.str2idx(strings)
        for index in indexes:
            tensor = torch.IntTensor(index)
            tensors.append(tensor)
        target_lengths = torch.IntTensor([len(t) for t in tensors])
        flatten_target = torch.cat(tensors)

        return {
            'targets': tensors,
            'flatten_targets': flatten_target,
            'target_lengths': target_lengths
        }
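
    # A worked illustration of str2tensor's output (the index values below are
    # illustrative placeholders, not real DICT36/DICT90 indices): for
    # strings=['cat', 'dog'] it might return
    #     targets:         [IntTensor([3, 1, 20]), IntTensor([4, 15, 7])]
    #     flatten_targets: IntTensor([3, 1, 20, 4, 15, 7])
    #     target_lengths:  IntTensor([3, 3])
    # i.e. the concatenated-targets / target-lengths layout accepted by
    # torch.nn.functional.ctc_loss.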

    def tensor2idx(self, output, img_metas, topk=1, return_topk=False):
        """Convert model output tensor to index-list.

        Args:
            output (tensor): The model outputs with size: N * T * C.
            img_metas (list[dict]): Each dict contains one image info.
            topk (int): The highest k classes to be returned.
            return_topk (bool): Whether to return topk or just top1.
        Returns:
            indexes (list[list[int]]): [[1,2,3,3,4], [5,4,6,3,7]].
            scores (list[list[float]]): [[0.9,0.8,0.95,0.97,0.94],
                [0.9,0.9,0.98,0.97,0.96]].
            If return_topk is True, return instead:
                indexes_topk (list[list[list[int]]]): topk indexes per step.
                scores_topk (list[list[list[float]]]): topk scores per step.
        """
        assert utils.is_type_list(img_metas, dict)
        assert len(img_metas) == output.size(0)
        assert isinstance(topk, int)
        assert topk >= 1

        valid_ratios = [
            img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
        ]

        batch_size = output.size(0)
        output = F.softmax(output, dim=2)
        output = output.cpu().detach()
        batch_topk_value, batch_topk_idx = output.topk(topk, dim=2)
        batch_max_idx = batch_topk_idx[:, :, 0]
        scores_topk, indexes_topk = [], []
        scores, indexes = [], []
        feat_len = output.size(1)
        for b in range(batch_size):
            valid_ratio = valid_ratios[b]
            decode_len = min(feat_len, math.ceil(feat_len * valid_ratio))
            pred = batch_max_idx[b, :]
            select_idx = []
            prev_idx = self.blank_idx
            # CTC greedy decoding: keep a time step only if its top-1 class
            # is neither the blank token nor a repeat of the previous step.
            for t in range(decode_len):
                tmp_value = pred[t].item()
                if tmp_value not in (prev_idx, self.blank_idx):
                    select_idx.append(t)
                prev_idx = tmp_value
            select_idx = torch.LongTensor(select_idx)
            topk_value = torch.index_select(batch_topk_value[b, :, :], 0,
                                            select_idx)  # valid_seqlen * topk
            topk_idx = torch.index_select(batch_topk_idx[b, :, :], 0,
                                          select_idx)
            topk_idx_list = topk_idx.numpy().tolist()
            topk_value_list = topk_value.numpy().tolist()
            indexes_topk.append(topk_idx_list)
            scores_topk.append(topk_value_list)

            indexes.append([x[0] for x in topk_idx_list])
            scores.append([x[0] for x in topk_value_list])

        if return_topk:
            return indexes_topk, scores_topk

        return indexes, scores
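

# ---------------------------------------------------------------------------
# A minimal usage sketch (not part of the original module). It assumes mmocr
# is installed so that BaseConvertor can resolve the built-in DICT36 character
# list; the random logits stand in for a real recognition head's output, and
# str2idx/idx2str are inherited from BaseConvertor. Because of the relative
# import above, treat this as an illustration of the call pattern rather than
# a script to run directly; in practice the convertor is built from a config
# via the CONVERTORS registry.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    converter = CTCConvertor(
        dict_type='DICT36', with_unknown=False, lower=True)

    # Encode ground-truth strings into CTC targets.
    targets_dict = converter.str2tensor(['hello', 'world'])
    print(targets_dict['target_lengths'])  # tensor([5, 5], dtype=torch.int32)

    # Greedy-decode dummy logits of shape N * T * C back to character indices.
    num_classes = len(converter.idx2char)
    dummy_logits = torch.randn(2, 26, num_classes)
    img_metas = [dict(valid_ratio=1.0), dict(valid_ratio=1.0)]
    indexes, scores = converter.tensor2idx(dummy_logits, img_metas)
    print(converter.idx2str(indexes))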