YinuoGuo27 commited on
Commit
d0af744
·
verified ·
1 Parent(s): 2c1a720

Update difpoint/src/models/predictor.py

Browse files
Files changed (1) hide show
  1. difpoint/src/models/predictor.py +4 -1
difpoint/src/models/predictor.py CHANGED
@@ -227,7 +227,10 @@ class OnnxRuntimePredictor:
227
  if self.inputs[i].type == 'tensor(float16)':
228
  input_feeds[self.inputs[i].name] = data[i].astype(np.float16)
229
  else:
230
- input_feeds[self.inputs[i].name] = data[i].astype(np.float32)
 
 
 
231
  results = self.onnx_model.run(None, input_feeds)
232
  return results
233
 
 
227
  if self.inputs[i].type == 'tensor(float16)':
228
  input_feeds[self.inputs[i].name] = data[i].astype(np.float16)
229
  else:
230
+ try:
231
+ input_feeds[self.inputs[i].name] = data[i].astype(np.float32)
232
+ except:
233
+ input_feeds[self.inputs[i].name] = data[i].cpu().numpy().astype(np.float32)
234
  results = self.onnx_model.run(None, input_feeds)
235
  return results
236