YinuoGuo27 commited on
Commit
64dbe94
·
verified ·
1 Parent(s): 61496e9

Update difpoint/src/models/predictor.py

Browse files
Files changed (1) hide show
  1. difpoint/src/models/predictor.py +2 -1
difpoint/src/models/predictor.py CHANGED
@@ -10,6 +10,7 @@ import torch
10
  from torch.cuda import nvtx
11
  from collections import OrderedDict
12
  import platform
 
13
 
14
  try:
15
  import tensorrt as trt
@@ -176,7 +177,7 @@ class OnnxRuntimePredictor:
176
  """
177
  OnnxRuntime Prediction
178
  """
179
-
180
  def __init__(self, **kwargs):
181
  model_path = kwargs.get("model_path", "") # 用模型路径区分是否是一样的实例
182
  assert os.path.exists(model_path), "model path must exist!"
 
10
  from torch.cuda import nvtx
11
  from collections import OrderedDict
12
  import platform
13
+ import spaces
14
 
15
  try:
16
  import tensorrt as trt
 
177
  """
178
  OnnxRuntime Prediction
179
  """
180
+ @spaces.GPU
181
  def __init__(self, **kwargs):
182
  model_path = kwargs.get("model_path", "") # 用模型路径区分是否是一样的实例
183
  assert os.path.exists(model_path), "model path must exist!"