Spaces:
Running
Running
import copy | |
# Declared public API of this module: only the factory function is exported.
__all__ = ["build_post_process"]
from .abinet_postprocess import ABINetLabelDecode | |
from .ar_postprocess import ARLabelDecode | |
from .ce_postprocess import CELabelDecode | |
from .char_postprocess import CharLabelDecode | |
from .cppd_postprocess import CPPDLabelDecode | |
from .ctc_postprocess import CTCLabelDecode | |
from .igtr_postprocess import IGTRLabelDecode | |
from .lister_postprocess import LISTERLabelDecode | |
from .mgp_postprocess import MPGLabelDecode | |
from .nrtr_postprocess import NRTRLabelDecode | |
from .smtr_postprocess import SMTRLabelDecode | |
from .srn_postprocess import SRNLabelDecode | |
from .visionlan_postprocess import VisionLANLabelDecode | |
# Names that build_post_process() accepts in config['name'].
# Despite the identifier, this is a list; its order is what appears in the
# "post process only support ..." error message.
support_dict = [
    "CTCLabelDecode",
    "CharLabelDecode",
    "CELabelDecode",
    "CPPDLabelDecode",
    "NRTRLabelDecode",
    "ABINetLabelDecode",
    "ARLabelDecode",
    "IGTRLabelDecode",
    "VisionLANLabelDecode",
    "SMTRLabelDecode",
    "SRNLabelDecode",
    "LISTERLabelDecode",
    "GTCLabelDecode",
    "MPGLabelDecode",
]
def build_post_process(config, global_config=None):
    """Instantiate a post-process (label decode) object from a config dict.

    Args:
        config: Mapping with a ``'name'`` key naming one of the classes listed
            in ``support_dict``. Every other key is passed to that class's
            constructor as a keyword argument. The mapping is deep-copied, so
            the caller's object is not modified.
        global_config: Optional mapping merged on top of ``config`` before
            construction. On key collisions its values win.

    Returns:
        A new instance of the named class.

    Raises:
        KeyError: if ``config`` has no ``'name'`` key.
        AssertionError: if the name is not listed in ``support_dict``.
    """
    config = copy.deepcopy(config)
    module_name = config.pop('name')
    if global_config is not None:
        config.update(global_config)
    # Explicit check instead of an ``assert`` statement. Asserts are stripped
    # under ``python -O``, which would let an arbitrary config string through.
    # AssertionError is kept so existing callers see the same exception type.
    if module_name not in support_dict:
        raise AssertionError(
            'post process only support {}'.format(support_dict))
    # Resolve the class with a plain name lookup in this module's namespace
    # instead of eval(), so a config value can never be executed as code.
    instance = globals()[module_name](**config)
    return instance
class GTCLabelDecode(object):
    """Convert between text-label and text-index.

    Wraps two decoders that share one character dictionary:

    * a decoder built by ``build_post_process`` from the ``gtc_label_decode``
      sub-config, applied to ``preds['gtc_pred']``;
    * a ``CTCLabelDecode``, applied to ``preds['ctc_pred']``.
    """

    def __init__(self,
                 gtc_label_decode=None,
                 character_dict_path=None,
                 use_space_char=True,
                 only_gtc=False,
                 with_ratio=False,
                 **kwargs):
        """Build both decoders.

        Args:
            gtc_label_decode: Config dict for the GTC-branch decoder. It must
                contain a ``'name'`` accepted by ``build_post_process``. It is
                copied, not modified in place.
            character_dict_path: Forwarded to both decoders.
            use_space_char: Forwarded to both decoders.
            only_gtc: If true, ``__call__`` returns only the GTC-branch result.
            with_ratio: If true, ``__call__`` drops the last element of
                ``batch`` before splitting it between the two decoders.
            **kwargs: Accepted and ignored (for example, extra keys merged in
                from a global config by ``build_post_process``).

        Raises:
            TypeError: if ``gtc_label_decode`` is None.
        """
        if gtc_label_decode is None:
            raise TypeError(
                'GTCLabelDecode requires a gtc_label_decode config dict')
        # Shallow copy so the caller's config dict is not mutated.
        gtc_config = dict(gtc_label_decode)
        gtc_config['character_dict_path'] = character_dict_path
        gtc_config['use_space_char'] = use_space_char
        self.gtc_label_decode = build_post_process(gtc_config)
        self.ctc_label_decode = CTCLabelDecode(
            character_dict_path=character_dict_path,
            use_space_char=use_space_char)
        self.gtc_character = self.gtc_label_decode.character
        self.ctc_character = self.ctc_label_decode.character
        self.only_gtc = only_gtc
        self.with_ratio = with_ratio

    def get_character_num(self):
        """Return ``[len(gtc_character), len(ctc_character)]``."""
        return [len(self.gtc_character), len(self.ctc_character)]

    def __call__(self, preds, batch=None, *args, **kwargs):
        """Decode both prediction branches.

        Args:
            preds: Mapping with a ``'gtc_pred'`` key and, unless ``only_gtc``
                is set, a ``'ctc_pred'`` key.
            batch: Optional sequence. When ``with_ratio`` is set its last
                element is dropped first. The GTC decoder then receives
                everything except the last two entries. The CTC decoder
                receives ``[None]`` followed by those last two entries.
                NOTE(review): the leading ``None`` presumably fills the slot
                the CTC decoder expects for the image; confirm against
                CTCLabelDecode.

        Returns:
            The GTC result alone if ``only_gtc`` is set, otherwise
            ``[gtc_result, ctc_result]``.
        """
        # Guard on batch: at inference time batch is None, and slicing it
        # unconditionally used to raise TypeError when with_ratio was set.
        if batch is not None and self.with_ratio:
            batch = batch[:-1]
        gtc_batch = batch[:-2] if batch is not None else None
        gtc = self.gtc_label_decode(preds['gtc_pred'], gtc_batch)
        if self.only_gtc:
            return gtc
        # list(...) so a tuple batch works too; [None] + tuple raises TypeError.
        ctc_batch = [None] + list(batch[-2:]) if batch is not None else None
        ctc = self.ctc_label_decode(preds['ctc_pred'], ctc_batch)
        return [gtc, ctc]