YinuoGuo27 commited on
Commit
2cdc0f6
·
verified ·
1 Parent(s): a4c475a

Update difpoint/src/models/predictor.py

Browse files
Files changed (1) hide show
  1. difpoint/src/models/predictor.py +9 -6
difpoint/src/models/predictor.py CHANGED
@@ -10,7 +10,8 @@ import torch
10
  from torch.cuda import nvtx
11
  from collections import OrderedDict
12
  import platform
13
- import spaces
 
14
 
15
  try:
16
  import tensorrt as trt
@@ -40,15 +41,15 @@ class TensorRTPredictor:
40
  """
41
  Implements inference for the EfficientDet TensorRT engine.
42
  """
43
-
44
  def __init__(self, **kwargs):
45
  """
46
  :param engine_path: The path to the serialized engine to load from disk.
47
  """
48
  if platform.system().lower() == 'linux':
49
- ctypes.CDLL("./downloaded_repo/pretrained_weights/liveportrait_onnx/libgrid_sample_3d_plugin.so", mode=ctypes.RTLD_GLOBAL)
50
  else:
51
- ctypes.CDLL("./downloaded_repo/pretrained_weights/liveportrait_onnx/grid_sample_3d_plugin.dll", mode=ctypes.RTLD_GLOBAL)
52
  # Load TRT engine
53
  self.logger = trt.Logger(trt.Logger.VERBOSE)
54
  trt.init_libnvinfer_plugins(self.logger, "")
@@ -172,7 +173,6 @@ class TensorRTPredictor:
172
  del self.outputs
173
  del self.tensors
174
 
175
-
176
  class OnnxRuntimePredictor:
177
  """
178
  OnnxRuntime Prediction
@@ -190,7 +190,10 @@ class OnnxRuntimePredictor:
190
  # opts.inter_op_num_threads = kwargs.get("num_threads", 4)
191
  # opts.intra_op_num_threads = kwargs.get("num_threads", 4)
192
  # opts.log_severity_level = 3
193
- self.onnx_model = onnxruntime.InferenceSession(model_path, providers=providers, sess_options=opts)
 
 
 
194
  self.inputs = self.onnx_model.get_inputs()
195
  self.outputs = self.onnx_model.get_outputs()
196
 
 
10
  from torch.cuda import nvtx
11
  from collections import OrderedDict
12
  import platform
13
+
14
+ import spaces
15
 
16
  try:
17
  import tensorrt as trt
 
41
  """
42
  Implements inference for the EfficientDet TensorRT engine.
43
  """
44
+ @spaces.GPU
45
  def __init__(self, **kwargs):
46
  """
47
  :param engine_path: The path to the serialized engine to load from disk.
48
  """
49
  if platform.system().lower() == 'linux':
50
+ ctypes.CDLL("./difpoint/checkpoints/liveportrait_onnx/libgrid_sample_3d_plugin.so", mode=ctypes.RTLD_GLOBAL)
51
  else:
52
+ ctypes.CDLL("./difpoint/checkpoints/liveportrait_onnx/grid_sample_3d_plugin.dll", mode=ctypes.RTLD_GLOBAL)
53
  # Load TRT engine
54
  self.logger = trt.Logger(trt.Logger.VERBOSE)
55
  trt.init_libnvinfer_plugins(self.logger, "")
 
173
  del self.outputs
174
  del self.tensors
175
 
 
176
  class OnnxRuntimePredictor:
177
  """
178
  OnnxRuntime Prediction
 
190
  # opts.inter_op_num_threads = kwargs.get("num_threads", 4)
191
  # opts.intra_op_num_threads = kwargs.get("num_threads", 4)
192
  # opts.log_severity_level = 3
193
+ def recreate_session():
194
+ return onnxruntime.InferenceSession(model_path, providers=providers, sess_options=opts)
195
+
196
+ self.onnx_model = recreate_session()
197
  self.inputs = self.onnx_model.get_inputs()
198
  self.outputs = self.onnx_model.get_outputs()
199