Spaces: Running on Zero
Update text/g2pw/onnx_api.py
Browse files - text/g2pw/onnx_api.py +15 -11
text/g2pw/onnx_api.py
CHANGED
@@ -81,25 +81,29 @@ class G2PWOnnxConverter:
|
|
81 |
model_source: str=None,
|
82 |
enable_non_tradional_chinese: bool=False):
|
83 |
uncompress_path = download_and_decompress(model_dir)
|
84 |
-
|
85 |
sess_options = onnxruntime.SessionOptions()
|
|
|
86 |
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
|
|
87 |
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
|
|
|
88 |
sess_options.intra_op_num_threads = 2
|
|
|
89 |
try:
|
90 |
self.session_g2pW = onnxruntime.InferenceSession(os.path.join(uncompress_path, 'g2pW.onnx'),sess_options=sess_options, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
91 |
except:
|
92 |
self.session_g2pW = onnxruntime.InferenceSession(os.path.join(uncompress_path, 'g2pW.onnx'),sess_options=sess_options, providers=['CPUExecutionProvider'])
|
93 |
-
|
94 |
self.config = load_config(
|
95 |
config_path=os.path.join(uncompress_path, 'config.py'),
|
96 |
use_default=True)
|
97 |
|
98 |
self.model_source = model_source if model_source else self.config.model_source
|
99 |
self.enable_opencc = enable_non_tradional_chinese
|
100 |
-
|
101 |
self.tokenizer = AutoTokenizer.from_pretrained(self.model_source)
|
102 |
-
|
103 |
polyphonic_chars_path = os.path.join(uncompress_path,
|
104 |
'POLYPHONIC_CHARS.txt')
|
105 |
monophonic_chars_path = os.path.join(uncompress_path,
|
@@ -123,14 +127,14 @@ class G2PWOnnxConverter:
|
|
123 |
polyphonic_chars=self.polyphonic_chars
|
124 |
) if self.config.use_char_phoneme else get_phoneme_labels(
|
125 |
polyphonic_chars=self.polyphonic_chars)
|
126 |
-
|
127 |
self.chars = sorted(list(self.char2phonemes.keys()))
|
128 |
|
129 |
self.polyphonic_chars_new = set(self.chars)
|
130 |
for char in self.non_polyphonic:
|
131 |
if char in self.polyphonic_chars_new:
|
132 |
self.polyphonic_chars_new.remove(char)
|
133 |
-
|
134 |
self.monophonic_chars_dict = {
|
135 |
char: phoneme
|
136 |
for char, phoneme in self.monophonic_chars
|
@@ -138,11 +142,11 @@ class G2PWOnnxConverter:
|
|
138 |
for char in self.non_monophonic:
|
139 |
if char in self.monophonic_chars_dict:
|
140 |
self.monophonic_chars_dict.pop(char)
|
141 |
-
|
142 |
self.pos_tags = [
|
143 |
'UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI'
|
144 |
]
|
145 |
-
|
146 |
with open(
|
147 |
os.path.join(uncompress_path,
|
148 |
'bopomofo_to_pinyin_wo_tune_dict.json'),
|
@@ -153,16 +157,16 @@ class G2PWOnnxConverter:
|
|
153 |
'bopomofo': lambda x: x,
|
154 |
'pinyin': self._convert_bopomofo_to_pinyin,
|
155 |
}[style]
|
156 |
-
|
157 |
with open(
|
158 |
os.path.join(uncompress_path, 'char_bopomofo_dict.json'),
|
159 |
'r',
|
160 |
encoding='utf-8') as fr:
|
161 |
self.char_bopomofo_dict = json.load(fr)
|
162 |
-
|
163 |
if self.enable_opencc:
|
164 |
self.cc = OpenCC('s2tw')
|
165 |
-
|
166 |
def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str:
|
167 |
tone = bopomofo[-1]
|
168 |
assert tone in '12345'
|
|
|
81 |
model_source: str=None,
|
82 |
enable_non_tradional_chinese: bool=False):
|
83 |
uncompress_path = download_and_decompress(model_dir)
|
84 |
+
print(":::1")
|
85 |
sess_options = onnxruntime.SessionOptions()
|
86 |
+
print(":::2")
|
87 |
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
88 |
+
print(":::3")
|
89 |
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
|
90 |
+
print(":::4")
|
91 |
sess_options.intra_op_num_threads = 2
|
92 |
+
print(":::5")
|
93 |
try:
|
94 |
self.session_g2pW = onnxruntime.InferenceSession(os.path.join(uncompress_path, 'g2pW.onnx'),sess_options=sess_options, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
95 |
except:
|
96 |
self.session_g2pW = onnxruntime.InferenceSession(os.path.join(uncompress_path, 'g2pW.onnx'),sess_options=sess_options, providers=['CPUExecutionProvider'])
|
97 |
+
print(":::6")
|
98 |
self.config = load_config(
|
99 |
config_path=os.path.join(uncompress_path, 'config.py'),
|
100 |
use_default=True)
|
101 |
|
102 |
self.model_source = model_source if model_source else self.config.model_source
|
103 |
self.enable_opencc = enable_non_tradional_chinese
|
104 |
+
print(":::7")
|
105 |
self.tokenizer = AutoTokenizer.from_pretrained(self.model_source)
|
106 |
+
print(":::8")
|
107 |
polyphonic_chars_path = os.path.join(uncompress_path,
|
108 |
'POLYPHONIC_CHARS.txt')
|
109 |
monophonic_chars_path = os.path.join(uncompress_path,
|
|
|
127 |
polyphonic_chars=self.polyphonic_chars
|
128 |
) if self.config.use_char_phoneme else get_phoneme_labels(
|
129 |
polyphonic_chars=self.polyphonic_chars)
|
130 |
+
print(":::9")
|
131 |
self.chars = sorted(list(self.char2phonemes.keys()))
|
132 |
|
133 |
self.polyphonic_chars_new = set(self.chars)
|
134 |
for char in self.non_polyphonic:
|
135 |
if char in self.polyphonic_chars_new:
|
136 |
self.polyphonic_chars_new.remove(char)
|
137 |
+
print(":::10")
|
138 |
self.monophonic_chars_dict = {
|
139 |
char: phoneme
|
140 |
for char, phoneme in self.monophonic_chars
|
|
|
142 |
for char in self.non_monophonic:
|
143 |
if char in self.monophonic_chars_dict:
|
144 |
self.monophonic_chars_dict.pop(char)
|
145 |
+
print(":::11")
|
146 |
self.pos_tags = [
|
147 |
'UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI'
|
148 |
]
|
149 |
+
print(":::12")
|
150 |
with open(
|
151 |
os.path.join(uncompress_path,
|
152 |
'bopomofo_to_pinyin_wo_tune_dict.json'),
|
|
|
157 |
'bopomofo': lambda x: x,
|
158 |
'pinyin': self._convert_bopomofo_to_pinyin,
|
159 |
}[style]
|
160 |
+
print(":::13")
|
161 |
with open(
|
162 |
os.path.join(uncompress_path, 'char_bopomofo_dict.json'),
|
163 |
'r',
|
164 |
encoding='utf-8') as fr:
|
165 |
self.char_bopomofo_dict = json.load(fr)
|
166 |
+
print(":::14")
|
167 |
if self.enable_opencc:
|
168 |
self.cc = OpenCC('s2tw')
|
169 |
+
print(":::15")
|
170 |
def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str:
|
171 |
tone = bopomofo[-1]
|
172 |
assert tone in '12345'
|