Spaces:
Sleeping
Sleeping
Update difpoint/src/models/predictor.py
Browse files
difpoint/src/models/predictor.py
CHANGED
@@ -10,7 +10,8 @@ import torch
|
|
10 |
from torch.cuda import nvtx
|
11 |
from collections import OrderedDict
|
12 |
import platform
|
13 |
-
|
|
|
14 |
|
15 |
try:
|
16 |
import tensorrt as trt
|
@@ -40,15 +41,15 @@ class TensorRTPredictor:
|
|
40 |
"""
|
41 |
Implements inference for the EfficientDet TensorRT engine.
|
42 |
"""
|
43 |
-
|
44 |
def __init__(self, **kwargs):
|
45 |
"""
|
46 |
:param engine_path: The path to the serialized engine to load from disk.
|
47 |
"""
|
48 |
if platform.system().lower() == 'linux':
|
49 |
-
ctypes.CDLL("./
|
50 |
else:
|
51 |
-
ctypes.CDLL("./
|
52 |
# Load TRT engine
|
53 |
self.logger = trt.Logger(trt.Logger.VERBOSE)
|
54 |
trt.init_libnvinfer_plugins(self.logger, "")
|
@@ -172,7 +173,6 @@ class TensorRTPredictor:
|
|
172 |
del self.outputs
|
173 |
del self.tensors
|
174 |
|
175 |
-
|
176 |
class OnnxRuntimePredictor:
|
177 |
"""
|
178 |
OnnxRuntime Prediction
|
@@ -190,7 +190,10 @@ class OnnxRuntimePredictor:
|
|
190 |
# opts.inter_op_num_threads = kwargs.get("num_threads", 4)
|
191 |
# opts.intra_op_num_threads = kwargs.get("num_threads", 4)
|
192 |
# opts.log_severity_level = 3
|
193 |
-
|
|
|
|
|
|
|
194 |
self.inputs = self.onnx_model.get_inputs()
|
195 |
self.outputs = self.onnx_model.get_outputs()
|
196 |
|
|
|
10 |
from torch.cuda import nvtx
|
11 |
from collections import OrderedDict
|
12 |
import platform
|
13 |
+
|
14 |
+
import spaces
|
15 |
|
16 |
try:
|
17 |
import tensorrt as trt
|
|
|
41 |
"""
|
42 |
Implements inference for the EfficientDet TensorRT engine.
|
43 |
"""
|
44 |
+
@spaces.GPU
|
45 |
def __init__(self, **kwargs):
|
46 |
"""
|
47 |
:param engine_path: The path to the serialized engine to load from disk.
|
48 |
"""
|
49 |
if platform.system().lower() == 'linux':
|
50 |
+
ctypes.CDLL("./difpoint/checkpoints/liveportrait_onnx/libgrid_sample_3d_plugin.so", mode=ctypes.RTLD_GLOBAL)
|
51 |
else:
|
52 |
+
ctypes.CDLL("./difpoint/checkpoints/liveportrait_onnx/grid_sample_3d_plugin.dll", mode=ctypes.RTLD_GLOBAL)
|
53 |
# Load TRT engine
|
54 |
self.logger = trt.Logger(trt.Logger.VERBOSE)
|
55 |
trt.init_libnvinfer_plugins(self.logger, "")
|
|
|
173 |
del self.outputs
|
174 |
del self.tensors
|
175 |
|
|
|
176 |
class OnnxRuntimePredictor:
|
177 |
"""
|
178 |
OnnxRuntime Prediction
|
|
|
190 |
# opts.inter_op_num_threads = kwargs.get("num_threads", 4)
|
191 |
# opts.intra_op_num_threads = kwargs.get("num_threads", 4)
|
192 |
# opts.log_severity_level = 3
|
193 |
+
def recreate_session():
|
194 |
+
return onnxruntime.InferenceSession(model_path, providers=providers, sess_options=opts)
|
195 |
+
|
196 |
+
self.onnx_model = recreate_session()
|
197 |
self.inputs = self.onnx_model.get_inputs()
|
198 |
self.outputs = self.onnx_model.get_outputs()
|
199 |
|