File size: 2,765 Bytes
29f689c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import copy

__all__ = ['build_post_process']

from .abinet_postprocess import ABINetLabelDecode
from .ar_postprocess import ARLabelDecode
from .ce_postprocess import CELabelDecode
from .char_postprocess import CharLabelDecode
from .cppd_postprocess import CPPDLabelDecode
from .ctc_postprocess import CTCLabelDecode
from .igtr_postprocess import IGTRLabelDecode
from .lister_postprocess import LISTERLabelDecode
from .mgp_postprocess import MPGLabelDecode
from .nrtr_postprocess import NRTRLabelDecode
from .smtr_postprocess import SMTRLabelDecode
from .srn_postprocess import SRNLabelDecode
from .visionlan_postprocess import VisionLANLabelDecode

# Allow-list of decoder class names that build_post_process may instantiate.
# Despite its name this is a list; each entry must resolve to a class in this
# module's namespace (imported above, or defined later such as GTCLabelDecode).
support_dict = [
    'CTCLabelDecode',
    'CharLabelDecode',
    'CELabelDecode',
    'CPPDLabelDecode',
    'NRTRLabelDecode',
    'ABINetLabelDecode',
    'ARLabelDecode',
    'IGTRLabelDecode',
    'VisionLANLabelDecode',
    'SMTRLabelDecode',
    'SRNLabelDecode',
    'LISTERLabelDecode',
    'GTCLabelDecode',
    'MPGLabelDecode',
]


def build_post_process(config, global_config=None):
    """Instantiate a post-processing (label decode) class from a config dict.

    Args:
        config: Mapping with a ``'name'`` key naming one of the classes listed
            in ``support_dict``; every other key is passed to that class's
            constructor as a keyword argument. The mapping is deep-copied, so
            the caller's object is never modified.
        global_config: Optional mapping merged into the copied ``config``
            before construction; its keys override same-named keys of
            ``config``.

    Returns:
        An instance of the requested class.

    Raises:
        KeyError: If ``config`` has no ``'name'`` key.
        AssertionError: If the requested name is not in ``support_dict``.
    """
    config = copy.deepcopy(config)
    module_name = config.pop('name')
    if global_config is not None:
        config.update(global_config)
    # Raise explicitly instead of using ``assert`` so the check survives
    # ``python -O``; AssertionError is kept for backward compatibility.
    if module_name not in support_dict:
        raise AssertionError(
            'post process only support {}'.format(support_dict))
    # Look the class up in the module namespace instead of ``eval``-ing a
    # config-supplied string. The lookup happens at call time, so classes
    # defined later in this module (e.g. GTCLabelDecode) are found as well.
    instance = globals()[module_name](**config)
    return instance


class GTCLabelDecode(object):
    """Decode predictions that carry both a GTC branch and a CTC branch.

    The GTC branch is decoded by a nested decoder built from the
    ``gtc_label_decode`` config via ``build_post_process``. The CTC branch is
    decoded by a ``CTCLabelDecode``. Both share the same
    ``character_dict_path`` and ``use_space_char`` settings.
    """

    def __init__(self,
                 gtc_label_decode=None,
                 character_dict_path=None,
                 use_space_char=True,
                 only_gtc=False,
                 with_ratio=False,
                 **kwargs):
        """Build the nested GTC decoder and the CTC decoder.

        Args:
            gtc_label_decode: Config dict for the GTC-branch decoder; it must
                contain a ``'name'`` accepted by ``build_post_process``.
                Required despite the ``None`` default. It is copied, not
                modified.
            character_dict_path: Forwarded to both decoders.
            use_space_char: Forwarded to both decoders.
            only_gtc: If True, ``__call__`` returns only the GTC result.
            with_ratio: If True, ``__call__`` drops the last element of
                ``batch`` before splitting it (presumably a ratio entry —
                confirm against the data loader).
            **kwargs: Ignored; absorbs extra keys that ``build_post_process``
                merges in from the global config.

        Raises:
            TypeError: If ``gtc_label_decode`` is None.
        """
        if gtc_label_decode is None:
            raise TypeError(
                'GTCLabelDecode requires a gtc_label_decode config dict')
        # Copy so that the caller's config dict is not mutated.
        gtc_config = dict(gtc_label_decode)
        gtc_config['character_dict_path'] = character_dict_path
        gtc_config['use_space_char'] = use_space_char
        self.gtc_label_decode = build_post_process(gtc_config)
        self.ctc_label_decode = CTCLabelDecode(
            character_dict_path=character_dict_path,
            use_space_char=use_space_char)
        self.gtc_character = self.gtc_label_decode.character
        self.ctc_character = self.ctc_label_decode.character
        self.only_gtc = only_gtc
        self.with_ratio = with_ratio

    def get_character_num(self):
        """Return ``[len(gtc_character), len(ctc_character)]``."""
        return [len(self.gtc_character), len(self.ctc_character)]

    def __call__(self, preds, batch=None, *args, **kwargs):
        """Decode the GTC (and optionally CTC) predictions.

        Args:
            preds: Mapping with a ``'gtc_pred'`` entry and, unless
                ``only_gtc`` is set, a ``'ctc_pred'`` entry.
            batch: Optional sequence. After the optional ``with_ratio`` trim,
                all but its last two elements are passed to the GTC decoder.
                Its last two elements, preceded by a ``None`` placeholder, are
                passed to the CTC decoder.
            *args, **kwargs: Ignored.

        Returns:
            The GTC decode result when ``only_gtc`` is set, otherwise the list
            ``[gtc_result, ctc_result]``.
        """
        if self.with_ratio and batch is not None:
            batch = batch[:-1]
        if batch is None:
            gtc_batch = ctc_batch = None
        else:
            gtc_batch = batch[:-2]
            # Build the list explicitly so tuple batches work too
            # (``[None] + tuple`` raises TypeError).
            ctc_batch = [None, *batch[-2:]]
        gtc = self.gtc_label_decode(preds['gtc_pred'], gtc_batch)
        if self.only_gtc:
            return gtc
        ctc = self.ctc_label_decode(preds['ctc_pred'], ctc_batch)
        return [gtc, ctc]