Spaces:
Running
Running
File size: 2,765 Bytes
29f689c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import copy
__all__ = ['build_post_process']
from .abinet_postprocess import ABINetLabelDecode
from .ar_postprocess import ARLabelDecode
from .ce_postprocess import CELabelDecode
from .char_postprocess import CharLabelDecode
from .cppd_postprocess import CPPDLabelDecode
from .ctc_postprocess import CTCLabelDecode
from .igtr_postprocess import IGTRLabelDecode
from .lister_postprocess import LISTERLabelDecode
from .mgp_postprocess import MPGLabelDecode
from .nrtr_postprocess import NRTRLabelDecode
from .smtr_postprocess import SMTRLabelDecode
from .srn_postprocess import SRNLabelDecode
from .visionlan_postprocess import VisionLANLabelDecode
# Names accepted by ``build_post_process``; each one is resolved by name in
# this module's namespace. Order matters only cosmetically: the whole list is
# echoed in the "unsupported post process" error message.
support_dict = [
    'CTCLabelDecode',
    'CharLabelDecode',
    'CELabelDecode',
    'CPPDLabelDecode',
    'NRTRLabelDecode',
    'ABINetLabelDecode',
    'ARLabelDecode',
    'IGTRLabelDecode',
    'VisionLANLabelDecode',
    'SMTRLabelDecode',
    'SRNLabelDecode',
    'LISTERLabelDecode',
    'GTCLabelDecode',
    'MPGLabelDecode',
]
def build_post_process(config, global_config=None):
    """Instantiate a label-decode post-processor from a config mapping.

    Args:
        config: Mapping with a ``'name'`` key naming one of the entries in
            ``support_dict``; every remaining key is passed to that class's
            constructor as a keyword argument. The mapping is deep-copied,
            so the caller's object is never modified.
        global_config: Optional mapping merged over ``config`` (its values
            win on key collisions) before construction.

    Returns:
        An instance of the named post-process class.

    Raises:
        KeyError: if ``config`` has no ``'name'`` key.
        AssertionError: if the name is not listed in ``support_dict``.
    """
    config = copy.deepcopy(config)
    module_name = config.pop('name')
    if global_config is not None:
        config.update(global_config)
    # Explicit check instead of ``assert``: assertions are removed under
    # ``python -O``, which would let an arbitrary configured name through.
    # The AssertionError type is kept so existing callers see no change.
    if module_name not in support_dict:
        raise AssertionError(
            'post process only support {}'.format(support_dict))
    # Resolve the class by name from this module's namespace instead of
    # ``eval``-ing a string that comes from configuration.
    module_class = globals()[module_name](**config)
    return module_class
class GTCLabelDecode(object):
    """Decoder that runs a GTC-branch decoder and a CTC decoder side by side.

    The GTC branch is built via ``build_post_process`` from the
    ``gtc_label_decode`` config; the CTC branch is always a
    ``CTCLabelDecode``. Both receive the same ``character_dict_path`` and
    ``use_space_char``.
    """

    def __init__(self,
                 gtc_label_decode=None,
                 character_dict_path=None,
                 use_space_char=True,
                 only_gtc=False,
                 with_ratio=False,
                 **kwargs):
        """Build both decoders.

        Args:
            gtc_label_decode: Config mapping for the GTC decoder; must carry a
                ``'name'`` key accepted by ``build_post_process``. Its
                ``character_dict_path`` / ``use_space_char`` entries are
                overridden by the arguments below. Required despite the
                ``None`` default. The caller's mapping is not modified.
            character_dict_path: Passed to both decoders.
            use_space_char: Passed to both decoders.
            only_gtc: If true, ``__call__`` returns only the GTC result.
            with_ratio: If true, ``__call__`` drops the last element of a
                non-None ``batch`` before splitting it between the decoders.
            **kwargs: Accepted and ignored.

        Raises:
            TypeError: if ``gtc_label_decode`` is None.
        """
        if gtc_label_decode is None:
            raise TypeError(
                'GTCLabelDecode requires a gtc_label_decode config')
        # Work on a copy so the overrides below do not leak into the
        # caller's config dict.
        gtc_config = dict(gtc_label_decode)
        gtc_config['character_dict_path'] = character_dict_path
        gtc_config['use_space_char'] = use_space_char
        self.gtc_label_decode = build_post_process(gtc_config)
        self.ctc_label_decode = CTCLabelDecode(
            character_dict_path=character_dict_path,
            use_space_char=use_space_char)
        self.gtc_character = self.gtc_label_decode.character
        self.ctc_character = self.ctc_label_decode.character
        self.only_gtc = only_gtc
        self.with_ratio = with_ratio

    def get_character_num(self):
        """Return ``[len(gtc character set), len(ctc character set)]``."""
        return [len(self.gtc_character), len(self.ctc_character)]

    def __call__(self, preds, batch=None, *args, **kwargs):
        """Decode the predictions of both heads.

        Args:
            preds: Mapping with a ``'gtc_pred'`` entry and, unless
                ``only_gtc`` is set, a ``'ctc_pred'`` entry.
            batch: Optional sequence. When given, all but its last two items
                go to the GTC decoder, and ``[None]`` followed by its last two
                items go to the CTC decoder. With ``with_ratio`` the final item
                is dropped first.

        Returns:
            The GTC result if ``only_gtc``; otherwise ``[gtc_result,
            ctc_result]``.
        """
        # batch may be None (handled below), so only trim it when present;
        # slicing None unconditionally raised TypeError.
        if self.with_ratio and batch is not None:
            batch = batch[:-1]
        gtc = self.gtc_label_decode(
            preds['gtc_pred'], batch[:-2] if batch is not None else None)
        if self.only_gtc:
            return gtc
        # Unpacking builds the same list as ``[None] + batch[-2:]`` for list
        # batches, and also works when batch is a tuple or other sequence.
        ctc = self.ctc_label_decode(
            preds['ctc_pred'],
            [None, *batch[-2:]] if batch is not None else None)
        return [gtc, ctc]
|