Update text/g2pw/onnx_api.py
Browse files- text/g2pw/onnx_api.py +5 -4
text/g2pw/onnx_api.py
CHANGED
@@ -86,10 +86,11 @@ class G2PWOnnxConverter:
|
|
86 |
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
87 |
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
|
88 |
sess_options.intra_op_num_threads = 2
|
89 |
-
|
90 |
-
os.path.join(uncompress_path, 'g2pW.onnx'),
|
91 |
-
|
92 |
-
|
|
|
93 |
self.config = load_config(
|
94 |
config_path=os.path.join(uncompress_path, 'config.py'),
|
95 |
use_default=True)
|
|
|
86 |
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
87 |
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
|
88 |
sess_options.intra_op_num_threads = 2
|
89 |
+
try:
|
90 |
+
self.session_g2pW = onnxruntime.InferenceSession(os.path.join(uncompress_path, 'g2pW.onnx'),sess_options=sess_options, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
91 |
+
except:
|
92 |
+
self.session_g2pW = onnxruntime.InferenceSession(os.path.join(uncompress_path, 'g2pW.onnx'),sess_options=sess_options, providers=['CPUExecutionProvider'])
|
93 |
+
|
94 |
self.config = load_config(
|
95 |
config_path=os.path.join(uncompress_path, 'config.py'),
|
96 |
use_default=True)
|