Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import json | |
| import os.path as osp | |
| import tempfile | |
| import torch | |
| from mmocr.datasets.ner_dataset import NerDataset | |
| from mmocr.models.ner.convertors.ner_convertor import NerConvertor | |
| from mmocr.utils import list_to_file | |
def _create_dummy_ann_file(ann_file):
    """Write a one-line NER annotation file in the CLUENER-style JSON format.

    The single sample carries two labelled entities ('address' and 'name')
    with inclusive character spans into ``text``.
    """
    sample = dict(
        text='彭小军认为,国内银行现在走的是台湾的发卡模式',
        label=dict(
            address={'台湾': [[15, 16]]},
            name={'彭小军': [[0, 2]]},
        ),
    )
    # ensure_ascii=False keeps the CJK characters readable in the file.
    list_to_file(ann_file, [json.dumps(sample, ensure_ascii=False)])
def _create_dummy_vocab_file(vocab_file):
    """Write a dummy vocabulary file containing the characters a-z.

    Fix: the original looped over the alphabet and called ``list_to_file``
    once per character.  ``list_to_file`` opens its target in write mode
    (truncating on every call), so only the last character ('z') survived.
    Building all lines first and writing them in a single call keeps the
    full 26-entry vocabulary.
    """
    chars = map(chr, range(ord('a'), ord('z') + 1))
    list_to_file(
        vocab_file,
        [json.dumps(char + '\n', ensure_ascii=False) for char in chars])
| def _create_dummy_loader(): | |
| loader = dict( | |
| type='HardDiskLoader', | |
| repeat=1, | |
| parser=dict(type='LineJsonParser', keys=['text', 'label'])) | |
| return loader | |
def test_ner_dataset():
    """Smoke test for NerDataset and NerConvertor.

    Covers dataset construction, ``pre_pipeline``, ``prepare_train_img``,
    ``evaluate`` and ``NerConvertor.convert_pred2entities`` on dummy data.
    """
    loader = _create_dummy_loader()
    categories = [
        'address', 'book', 'company', 'game', 'government', 'movie', 'name',
        'organization', 'position', 'scene'
    ]
    max_len = 128

    # Dummy annotation/vocab files live in a throwaway directory.
    with tempfile.TemporaryDirectory() as tmp_dir_name:
        ann_file = osp.join(tmp_dir_name, 'fake_data.txt')
        vocab_file = osp.join(tmp_dir_name, 'fake_vocab.txt')
        _create_dummy_ann_file(ann_file)
        _create_dummy_vocab_file(vocab_file)

        convertor_cfg = dict(
            type='NerConvertor',
            annotation_type='bio',
            vocab_file=vocab_file,
            categories=categories,
            max_len=max_len)
        pipeline = [
            dict(
                type='NerTransform',
                label_convertor=convertor_cfg,
                max_len=max_len),
            dict(type='ToTensorNER')
        ]
        dataset = NerDataset(ann_file, loader, pipeline=pipeline)

        # pre_pipeline must accept a results dict built from data_infos.
        results = dict(img_info=dataset.data_infos[0])
        dataset.pre_pipeline(results)

        # prepare_train_img runs the full pipeline on sample 0.
        dataset.prepare_train_img(0)

        # evaluate with entities matching the dummy annotation.
        dataset.evaluate([[['address', 15, 16], ['name', 0, 2]]])

        # Raw label-id sequence fed to convert_pred2entities (truncated to
        # max_len below).
        pred = [
            21, 7, 17, 17, 21, 21, 21, 21, 21, 21, 13, 21, 21, 21, 21, 21, 1,
            11, 21, 21, 7, 17, 17, 21, 21, 21, 21, 21, 21, 13, 21, 21, 21, 21,
            21, 1, 11, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
            21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
            21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
            21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
            21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
            21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21,
            21, 21, 21, 21, 21, 21, 21, 21, 1, 21, 21, 21, 21, 21, 21, 21, 21,
            21, 21, 21, 21, 21, 21, 21, 21, 21, 1, 21, 21, 21, 21, 21, 21
        ]
        preds = [pred[:max_len]]
        # Only the first 10 positions are valid (mask == 1).
        mask = [1] * 10 + [0] * (max_len - 10)
        assert len(preds[0]) == len(mask)
        masks = torch.tensor([mask])

        convertor = NerConvertor(
            annotation_type='bio',
            vocab_file=vocab_file,
            categories=categories,
            max_len=max_len)
        all_entities = convertor.convert_pred2entities(preds=preds,
                                                      masks=masks)
        # Each decoded entity is a [label, start, end] triple.
        assert len(all_entities[0][0]) == 3