Spaces:
Paused
Paused
Update text/g2pw/onnx_api.py
Browse files- text/g2pw/onnx_api.py +12 -16
text/g2pw/onnx_api.py
CHANGED
|
@@ -81,26 +81,22 @@ class G2PWOnnxConverter:
|
|
| 81 |
model_source: str=None,
|
| 82 |
enable_non_tradional_chinese: bool=False):
|
| 83 |
uncompress_path = download_and_decompress(model_dir)
|
| 84 |
-
|
| 85 |
sess_options = onnxruntime.SessionOptions()
|
| 86 |
-
print(":::2")
|
| 87 |
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
| 88 |
-
print(":::3")
|
| 89 |
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
|
| 90 |
-
print(":::4")
|
| 91 |
sess_options.intra_op_num_threads = 2
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
print(":::6")
|
| 95 |
self.config = load_config(
|
| 96 |
config_path=os.path.join(uncompress_path, 'config.py'),
|
| 97 |
use_default=True)
|
| 98 |
|
| 99 |
self.model_source = model_source if model_source else self.config.model_source
|
| 100 |
self.enable_opencc = enable_non_tradional_chinese
|
| 101 |
-
|
| 102 |
self.tokenizer = AutoTokenizer.from_pretrained(self.model_source)
|
| 103 |
-
|
| 104 |
polyphonic_chars_path = os.path.join(uncompress_path,
|
| 105 |
'POLYPHONIC_CHARS.txt')
|
| 106 |
monophonic_chars_path = os.path.join(uncompress_path,
|
|
@@ -124,14 +120,14 @@ class G2PWOnnxConverter:
|
|
| 124 |
polyphonic_chars=self.polyphonic_chars
|
| 125 |
) if self.config.use_char_phoneme else get_phoneme_labels(
|
| 126 |
polyphonic_chars=self.polyphonic_chars)
|
| 127 |
-
|
| 128 |
self.chars = sorted(list(self.char2phonemes.keys()))
|
| 129 |
|
| 130 |
self.polyphonic_chars_new = set(self.chars)
|
| 131 |
for char in self.non_polyphonic:
|
| 132 |
if char in self.polyphonic_chars_new:
|
| 133 |
self.polyphonic_chars_new.remove(char)
|
| 134 |
-
|
| 135 |
self.monophonic_chars_dict = {
|
| 136 |
char: phoneme
|
| 137 |
for char, phoneme in self.monophonic_chars
|
|
@@ -139,11 +135,11 @@ class G2PWOnnxConverter:
|
|
| 139 |
for char in self.non_monophonic:
|
| 140 |
if char in self.monophonic_chars_dict:
|
| 141 |
self.monophonic_chars_dict.pop(char)
|
| 142 |
-
|
| 143 |
self.pos_tags = [
|
| 144 |
'UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI'
|
| 145 |
]
|
| 146 |
-
|
| 147 |
with open(
|
| 148 |
os.path.join(uncompress_path,
|
| 149 |
'bopomofo_to_pinyin_wo_tune_dict.json'),
|
|
@@ -154,16 +150,16 @@ class G2PWOnnxConverter:
|
|
| 154 |
'bopomofo': lambda x: x,
|
| 155 |
'pinyin': self._convert_bopomofo_to_pinyin,
|
| 156 |
}[style]
|
| 157 |
-
|
| 158 |
with open(
|
| 159 |
os.path.join(uncompress_path, 'char_bopomofo_dict.json'),
|
| 160 |
'r',
|
| 161 |
encoding='utf-8') as fr:
|
| 162 |
self.char_bopomofo_dict = json.load(fr)
|
| 163 |
-
|
| 164 |
if self.enable_opencc:
|
| 165 |
self.cc = OpenCC('s2tw')
|
| 166 |
-
|
| 167 |
def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str:
|
| 168 |
tone = bopomofo[-1]
|
| 169 |
assert tone in '12345'
|
|
|
|
| 81 |
model_source: str=None,
|
| 82 |
enable_non_tradional_chinese: bool=False):
|
| 83 |
uncompress_path = download_and_decompress(model_dir)
|
| 84 |
+
|
| 85 |
sess_options = onnxruntime.SessionOptions()
|
|
|
|
| 86 |
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
|
|
|
| 87 |
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
|
|
|
|
| 88 |
sess_options.intra_op_num_threads = 2
|
| 89 |
+
self.session_g2pW = onnxruntime.InferenceSession(os.path.join(uncompress_path, 'g2pW.onnx'), sess_options=sess_options, providers=['CPUExecutionProvider'])
|
| 90 |
+
|
|
|
|
| 91 |
self.config = load_config(
|
| 92 |
config_path=os.path.join(uncompress_path, 'config.py'),
|
| 93 |
use_default=True)
|
| 94 |
|
| 95 |
self.model_source = model_source if model_source else self.config.model_source
|
| 96 |
self.enable_opencc = enable_non_tradional_chinese
|
| 97 |
+
|
| 98 |
self.tokenizer = AutoTokenizer.from_pretrained(self.model_source)
|
| 99 |
+
|
| 100 |
polyphonic_chars_path = os.path.join(uncompress_path,
|
| 101 |
'POLYPHONIC_CHARS.txt')
|
| 102 |
monophonic_chars_path = os.path.join(uncompress_path,
|
|
|
|
| 120 |
polyphonic_chars=self.polyphonic_chars
|
| 121 |
) if self.config.use_char_phoneme else get_phoneme_labels(
|
| 122 |
polyphonic_chars=self.polyphonic_chars)
|
| 123 |
+
|
| 124 |
self.chars = sorted(list(self.char2phonemes.keys()))
|
| 125 |
|
| 126 |
self.polyphonic_chars_new = set(self.chars)
|
| 127 |
for char in self.non_polyphonic:
|
| 128 |
if char in self.polyphonic_chars_new:
|
| 129 |
self.polyphonic_chars_new.remove(char)
|
| 130 |
+
|
| 131 |
self.monophonic_chars_dict = {
|
| 132 |
char: phoneme
|
| 133 |
for char, phoneme in self.monophonic_chars
|
|
|
|
| 135 |
for char in self.non_monophonic:
|
| 136 |
if char in self.monophonic_chars_dict:
|
| 137 |
self.monophonic_chars_dict.pop(char)
|
| 138 |
+
|
| 139 |
self.pos_tags = [
|
| 140 |
'UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI'
|
| 141 |
]
|
| 142 |
+
|
| 143 |
with open(
|
| 144 |
os.path.join(uncompress_path,
|
| 145 |
'bopomofo_to_pinyin_wo_tune_dict.json'),
|
|
|
|
| 150 |
'bopomofo': lambda x: x,
|
| 151 |
'pinyin': self._convert_bopomofo_to_pinyin,
|
| 152 |
}[style]
|
| 153 |
+
|
| 154 |
with open(
|
| 155 |
os.path.join(uncompress_path, 'char_bopomofo_dict.json'),
|
| 156 |
'r',
|
| 157 |
encoding='utf-8') as fr:
|
| 158 |
self.char_bopomofo_dict = json.load(fr)
|
| 159 |
+
|
| 160 |
if self.enable_opencc:
|
| 161 |
self.cc = OpenCC('s2tw')
|
| 162 |
+
|
| 163 |
def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str:
|
| 164 |
tone = bopomofo[-1]
|
| 165 |
assert tone in '12345'
|