# Copyright (c) OpenMMLab. All rights reserved.
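"""Convert MMOCR ONNX models to TensorRT engines, with optional verification.

Example invocation (the config, image, and model paths below are only
illustrative; substitute your own):

    python onnx2tensorrt.py \
        configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py det \
        demo/demo_text_det.jpg dbnet.onnx --trt-file dbnet.trt --fp16
"""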
import argparse
import os
import os.path as osp
import warnings
from typing import Iterable

import cv2
import mmcv
import numpy as np
import torch
from mmcv.parallel import collate
from mmcv.tensorrt import is_tensorrt_plugin_loaded, onnx2trt, save_trt_engine
from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose

from mmocr.core.deployment import (ONNXRuntimeDetector, ONNXRuntimeRecognizer,
                                   TensorRTDetector, TensorRTRecognizer)
from mmocr.datasets.pipelines.crop import crop_img  # noqa: F401
from mmocr.utils import is_2dlist


def get_GiB(x: int):
    """Return ``x`` GiB expressed in bytes."""
    return x * (1 << 30)
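
# For instance, get_GiB(1) == 1 << 30 == 1073741824, i.e. one GiB in bytes,
# which is the unit the TensorRT builder expects for its maximum workspace
# size.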


def _prepare_input_img(imgs, test_pipeline: Iterable[dict]):
    """Prepare image(s) as collated model inputs.

    Args:
        imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
            Either image files or loaded images.
        test_pipeline (Iterable[dict]): Test pipeline from the config.

    Returns:
        data (dict): Collated inputs with ``img`` and ``img_metas`` keys.
    """
    if isinstance(imgs, (list, tuple)):
        if not isinstance(imgs[0], (np.ndarray, str)):
            raise AssertionError('imgs must be strings or numpy arrays')
    elif isinstance(imgs, (np.ndarray, str)):
        imgs = [imgs]
    else:
        raise AssertionError('imgs must be strings or numpy arrays')

    test_pipeline = replace_ImageToTensor(test_pipeline)
    test_pipeline = Compose(test_pipeline)
    data = []
    for img in imgs:
        # prepare data: add information into dict
        datum = dict(img_info=dict(filename=img), img_prefix=None)
        # build the data pipeline
        datum = test_pipeline(datum)
        # get tensor from list to stack for batch mode (text detection)
        data.append(datum)

    if isinstance(data[0]['img'], list) and len(data) > 1:
        raise Exception('aug test does not support '
                        f'inference with batch size '
                        f'{len(data)}')

    data = collate(data, samples_per_gpu=len(imgs))

    # process img_metas
    if isinstance(data['img_metas'], list):
        data['img_metas'] = [
            img_metas.data[0] for img_metas in data['img_metas']
        ]
    else:
        data['img_metas'] = data['img_metas'].data

    if isinstance(data['img'], list):
        data['img'] = [img.data for img in data['img']]
        if isinstance(data['img'][0], list):
            data['img'] = [img[0] for img in data['img']]
    else:
        data['img'] = data['img'].data

    return data
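
# For reference, the dict returned above roughly looks like the following for
# a single image (the exact meta keys depend on the configured test pipeline):
#   {
#       'img': Tensor of shape (N, C, H, W), or a list of such tensors
#              when the pipeline uses test-time augmentation,
#       'img_metas': [[{'filename': ..., 'ori_shape': ..., ...}]],
#   }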


def onnx2tensorrt(onnx_file: str,
                  model_type: str,
                  trt_file: str,
                  config: dict,
                  input_config: dict,
                  fp16: bool = False,
                  verify: bool = False,
                  show: bool = False,
                  workspace_size: int = 1,
                  verbose: bool = False):
    """Build a TensorRT engine from an ONNX model and save it to ``trt_file``.

    If ``verify`` is True, run the same input through both ONNXRuntime and
    TensorRT and report whether the outputs match; if ``show`` is also True,
    visualize both results.
    """
    import tensorrt as trt
    min_shape = input_config['min_shape']
    max_shape = input_config['max_shape']
    # create trt engine and wrapper
    opt_shape_dict = {'input': [min_shape, min_shape, max_shape]}
    max_workspace_size = get_GiB(workspace_size)
    trt_engine = onnx2trt(
        onnx_file,
        opt_shape_dict,
        log_level=trt.Logger.VERBOSE if verbose else trt.Logger.ERROR,
        fp16_mode=fp16,
        max_workspace_size=max_workspace_size)
    save_dir, _ = osp.split(trt_file)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    save_trt_engine(trt_engine, trt_file)
    print(f'Successfully created TensorRT engine: {trt_file}')
    if verify:
        mm_inputs = _prepare_input_img(input_config['input_path'],
                                       config.data.test.pipeline)

        imgs = mm_inputs.pop('img')
        img_metas = mm_inputs.pop('img_metas')
        if isinstance(imgs, list):
            imgs = imgs[0]
        img_list = [img[None, :] for img in imgs]

        # Get results from ONNXRuntime
        if model_type == 'det':
            onnx_model = ONNXRuntimeDetector(onnx_file, config, 0)
        else:
            onnx_model = ONNXRuntimeRecognizer(onnx_file, config, 0)
        onnx_out = onnx_model.simple_test(
            img_list[0], img_metas[0], rescale=True)

        # Get results from TensorRT
        if model_type == 'det':
            trt_model = TensorRTDetector(trt_file, config, 0)
        else:
            trt_model = TensorRTRecognizer(trt_file, config, 0)
        img_list[0] = img_list[0].to(torch.device('cuda:0'))
        trt_out = trt_model.simple_test(
            img_list[0], img_metas[0], rescale=True)

        # compare results
        same_diff = 'same'
        if model_type == 'recog':
            for onnx_result, trt_result in zip(onnx_out, trt_out):
                if onnx_result['text'] != trt_result['text'] or \
                        not np.allclose(
                            np.array(onnx_result['score']),
                            np.array(trt_result['score']),
                            rtol=1e-4,
                            atol=1e-4):
                    same_diff = 'different'
                    break
        else:
            for onnx_result, trt_result in zip(onnx_out[0]['boundary_result'],
                                               trt_out[0]['boundary_result']):
                if not np.allclose(
                        np.array(onnx_result),
                        np.array(trt_result),
                        rtol=1e-4,
                        atol=1e-4):
                    same_diff = 'different'
                    break
        print('The outputs are {} between TensorRT and ONNX'.format(same_diff))

        if show:
            onnx_img = onnx_model.show_result(
                input_config['input_path'],
                onnx_out[0],
                out_file='onnx.jpg',
                show=False)
            trt_img = trt_model.show_result(
                input_config['input_path'],
                trt_out[0],
                out_file='tensorrt.jpg',
                show=False)
            if onnx_img is None:
                onnx_img = cv2.imread(input_config['input_path'])
            if trt_img is None:
                trt_img = cv2.imread(input_config['input_path'])

            cv2.imshow('TensorRT', trt_img)
            cv2.imshow('ONNXRuntime', onnx_img)
            cv2.waitKey()
    return
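
# A minimal sketch of calling onnx2tensorrt() directly for a detection model;
# every path below is a placeholder, not a file shipped with this script:
#
#   cfg = mmcv.Config.fromfile(
#       'configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py')
#   onnx2tensorrt(
#       'dbnet.onnx', 'det', 'dbnet.trt', cfg,
#       input_config={
#           'min_shape': [1, 3, 400, 600],
#           'max_shape': [1, 3, 400, 600],
#           'input_path': 'demo/demo_text_det.jpg',
#       },
#       fp16=True, verify=True)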


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert MMOCR models from ONNX to TensorRT')
    parser.add_argument('model_config', help='Config file of the model')
    parser.add_argument(
        'model_type',
        type=str,
        help='Detection or recognition model to deploy.',
        choices=['recog', 'det'])
    parser.add_argument('image_path', type=str, help='Image for test')
    parser.add_argument('onnx_file', help='Path to the input ONNX model')
    parser.add_argument(
        '--trt-file',
        type=str,
        help='Path to the output TensorRT engine',
        default='tmp.trt')
    parser.add_argument(
        '--max-shape',
        type=int,
        nargs=4,
        default=[1, 3, 400, 600],
        help='Maximum shape of model input.')
    parser.add_argument(
        '--min-shape',
        type=int,
        nargs=4,
        default=[1, 3, 400, 600],
        help='Minimum shape of model input.')
    parser.add_argument(
        '--workspace-size',
        type=int,
        default=1,
        help='Max workspace size in GiB.')
    parser.add_argument('--fp16', action='store_true', help='Enable fp16 mode')
    parser.add_argument(
        '--verify',
        action='store_true',
        help='Whether to verify the outputs of ONNXRuntime and TensorRT.')
    parser.add_argument(
        '--show',
        action='store_true',
        help='Whether to visualize the outputs of ONNXRuntime and TensorRT.')
    parser.add_argument(
        '--verbose',
        action='store_true',
        help='Whether to enable verbose logging while creating the '
        'TensorRT engine.')
    args = parser.parse_args()
    return args
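
# Example invocation for a recognition model (paths are placeholders):
#   python onnx2tensorrt.py configs/textrecog/crnn/crnn_academic_dataset.py \
#       recog demo/demo_text_recog.jpg crnn.onnx --trt-file crnn.trt --verify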


if __name__ == '__main__':
    assert is_tensorrt_plugin_loaded(), 'TensorRT plugin should be compiled.'
    args = parse_args()

    # Following strings of text style are from colorama package
    bright_style, reset_style = '\x1b[1m', '\x1b[0m'
    red_text, blue_text = '\x1b[31m', '\x1b[34m'
    white_background = '\x1b[107m'

    msg = white_background + bright_style + red_text
    msg += 'DeprecationWarning: This tool will be deprecated in the future. '
    msg += blue_text + 'Welcome to use the unified model deployment toolbox '
    msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
    msg += reset_style
    warnings.warn(msg)

    # check arguments
    assert osp.exists(args.model_config), 'Config {} not found.'.format(
        args.model_config)
    assert osp.exists(args.onnx_file), \
        'ONNX model {} not found.'.format(args.onnx_file)
    assert args.workspace_size >= 0, 'Workspace size should be non-negative.'
    for max_value, min_value in zip(args.max_shape, args.min_shape):
        assert max_value >= min_value, \
            'max_shape should not be smaller than min_shape.'

    input_config = {
        'min_shape': args.min_shape,
        'max_shape': args.max_shape,
        'input_path': args.image_path
    }

    cfg = mmcv.Config.fromfile(args.model_config)
    if cfg.data.test.get('pipeline', None) is None:
        if is_2dlist(cfg.data.test.datasets):
            cfg.data.test.pipeline = \
                cfg.data.test.datasets[0][0].pipeline
        else:
            cfg.data.test.pipeline = \
                cfg.data.test['datasets'][0].pipeline
    if is_2dlist(cfg.data.test.pipeline):
        cfg.data.test.pipeline = cfg.data.test.pipeline[0]

    onnx2tensorrt(
        args.onnx_file,
        args.model_type,
        args.trt_file,
        cfg,
        input_config,
        fp16=args.fp16,
        verify=args.verify,
        show=args.show,
        workspace_size=args.workspace_size,
        verbose=args.verbose)