Spaces:
Running
on
Zero
Running
on
Zero
| import copy | |
| import torch | |
| from .predictor import get_predictor | |
class BaseModel:
    """Base class for model prediction.

    Builds a predictor from the constructor keyword arguments and records
    the current CUDA device/stream.  Subclasses are expected to override
    ``input_process``, ``output_process`` and ``predict``; the versions
    defined here do nothing and return ``None``.

    Attributes set by ``__init__``:
        kwargs: Deep copy of the keyword arguments passed to the constructor.
        predictor: Object returned by ``get_predictor(**kwargs)``; may be
            ``None``.
        device: Result of ``torch.cuda.current_device()``.
        cudaStream: ``cuda_stream`` attribute of ``torch.cuda.current_stream()``.
        predict_type: Value of the ``predict_type`` keyword, ``"trt"`` when
            the keyword is absent.
        input_shapes / output_shapes: Results of ``predictor.input_spec()``
            and ``predictor.output_spec()``.  Only assigned when
            ``predictor`` is not ``None``.
    """

    def __init__(self, **kwargs):
        """Create the predictor and capture the current CUDA context.

        :param kwargs: Forwarded (as a deep copy) to ``get_predictor``.
            The optional ``predict_type`` key is also stored on the instance.
        """
        # Deep-copy so later mutation of the caller's objects cannot affect
        # the configuration stored on this instance.
        self.kwargs = copy.deepcopy(kwargs)
        self.predictor = get_predictor(**self.kwargs)
        self.device = torch.cuda.current_device()
        self.cudaStream = torch.cuda.current_stream().cuda_stream
        self.predict_type = kwargs.get("predict_type", "trt")
        if self.predictor is not None:
            self.input_shapes = self.predictor.input_spec()
            self.output_shapes = self.predictor.output_spec()

    def input_process(self, *data):
        """Pre-process the inputs.

        No-op in the base class; intended to be overridden.

        :return: ``None``
        """
        pass

    def output_process(self, *data):
        """Post-process the outputs.

        No-op in the base class; intended to be overridden.

        :return: ``None``
        """
        pass

    def predict(self, *data):
        """Run prediction.

        No-op in the base class; intended to be overridden.

        :return: ``None``
        """
        pass

    def __del__(self):
        """Drop the reference to the predictor when the instance is finalized.

        :return: ``None``
        """
        # ``__del__`` also runs on partially constructed instances (e.g. when
        # ``__init__`` raised before ``self.predictor`` was assigned), so the
        # attribute may be missing; look it up defensively instead of raising
        # AttributeError inside the finalizer.
        if getattr(self, "predictor", None) is not None:
            del self.predictor