Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import os | |
| import os.path as osp | |
| import shutil | |
| import warnings | |
| import mmcv | |
| from mmocr import digit_version | |
| from mmocr.utils import list_from_file | |
class LmdbAnnFileBackend:
    """Lmdb storage backend for annotation file.

    The database is expected to hold the line count under the key
    ``'total_number'`` and each annotation line under the key
    ``str(index)``.

    Args:
        lmdb_path (str): Lmdb file path.
        encoding (str): Encoding used to decode the stored values.
            Defaults to ``'utf8'``.
    """

    def __init__(self, lmdb_path, encoding='utf8'):
        self.lmdb_path = lmdb_path
        self.encoding = encoding
        # This env is only used to read the line count and is not stored on
        # ``self``, so it must be closed here or it leaks.
        env = self._get_env()
        try:
            with env.begin(write=False) as txn:
                self.total_number = int(
                    txn.get('total_number'.encode('utf-8')).decode(
                        self.encoding))
        finally:
            env.close()

    def __getitem__(self, index):
        """Retrieve one line from lmdb file by index.

        Raises:
            IndexError: If no line is stored under ``str(index)``. Raising
                IndexError (rather than failing on ``None.decode``) also
                lets ``for line in backend`` stop cleanly at the end.
        """
        # only attach env to self when __getitem__ is called
        # because env object cannot be pickled
        if not hasattr(self, 'env'):
            self.env = self._get_env()
        with self.env.begin(write=False) as txn:
            value = txn.get(str(index).encode('utf-8'))
            if value is None:
                raise IndexError(
                    f'index {index} not found in lmdb {self.lmdb_path}')
            return value.decode(self.encoding)

    def __len__(self):
        return self.total_number

    def _get_env(self):
        """Open ``self.lmdb_path`` read-only with lmdb (imported lazily)."""
        try:
            import lmdb
        except ImportError:
            raise ImportError(
                'Please install lmdb to enable LmdbAnnFileBackend.')
        return lmdb.open(
            self.lmdb_path,
            max_readers=1,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

    def close(self):
        """Close the env opened by ``__getitem__``, if there is one.

        Safe to call before any line has been read and safe to call twice.
        The env is detached, so a later ``__getitem__`` opens a fresh one
        instead of using a closed handle.
        """
        if hasattr(self, 'env'):
            self.env.close()
            del self.env
class HardDiskAnnFileBackend:
    """Annotation loader for files stored on the local disk.

    Args:
        file_format (str): ``'txt'`` to read the file with
            ``list_from_file``, or ``'lmdb'`` to wrap the path in an
            ``LmdbAnnFileBackend``. Defaults to ``'txt'``.
    """

    def __init__(self, file_format='txt'):
        supported = ('txt', 'lmdb')
        assert file_format in supported
        self.file_format = file_format

    def __call__(self, ann_file):
        """Load the annotations found at ``ann_file``."""
        is_lmdb = self.file_format == 'lmdb'
        if not is_lmdb:
            return list_from_file(ann_file)
        return LmdbAnnFileBackend(ann_file)
class PetrelAnnFileBackend:
    """Load annotation file with petrel storage backend.

    Args:
        file_format (str): Annotation format, ``'txt'`` or ``'lmdb'``.
            Defaults to ``'txt'``.
        save_dir (str): Local root directory into which remote lmdb
            annotation directories are copied before being opened.
            Defaults to ``'tmp_dir'``.
    """

    def __init__(self, file_format='txt', save_dir='tmp_dir'):
        assert file_format in ['txt', 'lmdb']
        self.file_format = file_format
        self.save_dir = save_dir

    def __call__(self, ann_file):
        """Load ``ann_file`` through mmcv's ``'petrel'`` FileClient.

        Args:
            ann_file (str): Remote path of the annotation file (``'txt'``)
                or of the lmdb directory (``'lmdb'``).

        Returns:
            LmdbAnnFileBackend | list[str]: For ``'lmdb'``, a backend over a
            local copy of the directory. For ``'txt'``, the file's lines
            split on newlines, with whitespace-only lines dropped.
        """
        file_client = mmcv.FileClient(backend='petrel')
        if self.file_format == 'lmdb':
            local_dir = self._fetch_lmdb_dir(file_client, ann_file)
            return LmdbAnnFileBackend(local_dir)
        lines = str(file_client.get(ann_file), encoding='utf-8').split('\n')
        return [x for x in lines if x.strip() != '']

    def _fetch_lmdb_dir(self, file_client, ann_file):
        """Copy the remote lmdb directory ``ann_file`` under ``save_dir``.

        Returns:
            str: Path of the local copy. An existing local copy is reused
            as-is (with a warning) instead of being downloaded again.
        """
        mmcv_version = digit_version(mmcv.__version__)
        if mmcv_version < digit_version('1.3.16'):
            raise Exception('Please update mmcv to 1.3.16 or higher '
                            'to enable "get_local_path" of "FileClient".')
        assert file_client.isdir(ann_file)

        # Mirror the remote layout (text after the last 's3://') locally.
        ann_file_rel_path = ann_file.split('s3://')[-1]
        ann_file_dir = osp.dirname(ann_file_rel_path)
        ann_file_name = osp.basename(ann_file_rel_path)
        local_dir = osp.join(self.save_dir, ann_file_dir, ann_file_name)
        if osp.exists(local_dir):
            warnings.warn(
                f'local_ann_file: {local_dir} is already existed and '
                'will be used. If it is not the correct ann_file '
                f'corresponding to {ann_file}, please remove it or '
                'change "save_dir" first then try again.')
            return local_dir

        os.makedirs(local_dir, exist_ok=True)
        print(f'Fetching {ann_file} to {local_dir}...')
        try:
            for each_file in file_client.list_dir_or_file(ann_file):
                tmp_file_path = file_client.join_path(ann_file, each_file)
                with file_client.get_local_path(tmp_file_path) as local_path:
                    shutil.copy(local_path, osp.join(local_dir, each_file))
        except BaseException:
            # The osp.exists check above would silently reuse a half-copied
            # directory on the next run, so remove it before re-raising
            # (including on KeyboardInterrupt).
            shutil.rmtree(local_dir, ignore_errors=True)
            raise
        return local_dir
class HTTPAnnFileBackend:
    """Annotation loader that fetches text files via mmcv's ``'http'`` client.

    Args:
        file_format (str): ``'txt'`` or ``'lmdb'``. ``'lmdb'`` passes
            validation, but loading it raises ``NotImplementedError``.
            Defaults to ``'txt'``.
    """

    def __init__(self, file_format='txt'):
        assert file_format in ('txt', 'lmdb')
        self.file_format = file_format

    def __call__(self, ann_file):
        """Download ``ann_file`` and return its lines that are not blank.

        Raises:
            NotImplementedError: If ``file_format`` is ``'lmdb'``.
        """
        client = mmcv.FileClient(backend='http')
        if self.file_format == 'lmdb':
            raise NotImplementedError(
                'Loading lmdb file on http is not supported yet.')
        text = str(client.get(ann_file), encoding='utf-8')
        return [line for line in text.split('\n') if line.strip()]