YinuoGuo27 committed on
Commit
263b2b0
·
verified ·
1 Parent(s): 305c702

Update difpoint/src/models/predictor.py

Browse files
Files changed (1) hide show
  1. difpoint/src/models/predictor.py +15 -6
difpoint/src/models/predictor.py CHANGED
@@ -190,13 +190,21 @@ class OnnxRuntimePredictor:
190
  # opts.inter_op_num_threads = kwargs.get("num_threads", 4)
191
  # opts.intra_op_num_threads = kwargs.get("num_threads", 4)
192
  # opts.log_severity_level = 3
193
- def recreate_session():
194
- return onnxruntime.InferenceSession(model_path, providers=providers, sess_options=opts)
195
-
196
- self.onnx_model = recreate_session()
197
- self.inputs = self.onnx_model.get_inputs()
198
- self.outputs = self.onnx_model.get_outputs()
199
 
 
 
 
 
 
 
 
 
 
 
200
  def input_spec(self):
201
  """
202
  Get the specs for the input tensor of the network. Useful to prepare memory allocations.
@@ -222,6 +230,7 @@ class OnnxRuntimePredictor:
222
  return specs
223
 
224
  def predict(self, *data):
 
225
  input_feeds = {}
226
  for i in range(len(data)):
227
  if self.inputs[i].type == 'tensor(float16)':
 
190
  # opts.inter_op_num_threads = kwargs.get("num_threads", 4)
191
  # opts.intra_op_num_threads = kwargs.get("num_threads", 4)
192
  # opts.log_severity_level = 3
193
+ #self.onnx_model = onnxruntime.InferenceSession(model_path, providers=providers, sess_options=opts)
194
+ #self.inputs = self.onnx_model.get_inputs()
195
+ #self.outputs = self.onnx_model.get_outputs()
196
+ self.onnx_model = None
 
 
197
 
198
+ def _load_model(self):
199
+ """Lazy initialization of the ONNX model (only when needed)."""
200
+ if self.onnx_model is None:
201
+ providers = ['CUDAExecutionProvider', 'CoreMLExecutionProvider', 'CPUExecutionProvider']
202
+ print(f"OnnxRuntime use {providers}")
203
+ opts = onnxruntime.SessionOptions()
204
+ self.onnx_model = onnxruntime.InferenceSession(self.model_path, providers=providers, sess_options=opts)
205
+ self.inputs = self.onnx_model.get_inputs()
206
+ self.outputs = self.onnx_model.get_outputs()
207
+
208
  def input_spec(self):
209
  """
210
  Get the specs for the input tensor of the network. Useful to prepare memory allocations.
 
230
  return specs
231
 
232
  def predict(self, *data):
233
+ self._load_model()
234
  input_feeds = {}
235
  for i in range(len(data)):
236
  if self.inputs[i].type == 'tensor(float16)':