lj1995 commited on
Commit
79fea0b
·
verified ·
1 Parent(s): 53b26a5

Update text/g2pw/onnx_api.py

Browse files
Files changed (1) hide show
  1. text/g2pw/onnx_api.py +15 -11
text/g2pw/onnx_api.py CHANGED
@@ -81,25 +81,29 @@ class G2PWOnnxConverter:
81
  model_source: str=None,
82
  enable_non_tradional_chinese: bool=False):
83
  uncompress_path = download_and_decompress(model_dir)
84
-
85
  sess_options = onnxruntime.SessionOptions()
 
86
  sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
 
87
  sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
 
88
  sess_options.intra_op_num_threads = 2
 
89
  try:
90
  self.session_g2pW = onnxruntime.InferenceSession(os.path.join(uncompress_path, 'g2pW.onnx'),sess_options=sess_options, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
91
  except:
92
  self.session_g2pW = onnxruntime.InferenceSession(os.path.join(uncompress_path, 'g2pW.onnx'),sess_options=sess_options, providers=['CPUExecutionProvider'])
93
-
94
  self.config = load_config(
95
  config_path=os.path.join(uncompress_path, 'config.py'),
96
  use_default=True)
97
 
98
  self.model_source = model_source if model_source else self.config.model_source
99
  self.enable_opencc = enable_non_tradional_chinese
100
-
101
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_source)
102
-
103
  polyphonic_chars_path = os.path.join(uncompress_path,
104
  'POLYPHONIC_CHARS.txt')
105
  monophonic_chars_path = os.path.join(uncompress_path,
@@ -123,14 +127,14 @@ class G2PWOnnxConverter:
123
  polyphonic_chars=self.polyphonic_chars
124
  ) if self.config.use_char_phoneme else get_phoneme_labels(
125
  polyphonic_chars=self.polyphonic_chars)
126
-
127
  self.chars = sorted(list(self.char2phonemes.keys()))
128
 
129
  self.polyphonic_chars_new = set(self.chars)
130
  for char in self.non_polyphonic:
131
  if char in self.polyphonic_chars_new:
132
  self.polyphonic_chars_new.remove(char)
133
-
134
  self.monophonic_chars_dict = {
135
  char: phoneme
136
  for char, phoneme in self.monophonic_chars
@@ -138,11 +142,11 @@ class G2PWOnnxConverter:
138
  for char in self.non_monophonic:
139
  if char in self.monophonic_chars_dict:
140
  self.monophonic_chars_dict.pop(char)
141
-
142
  self.pos_tags = [
143
  'UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI'
144
  ]
145
-
146
  with open(
147
  os.path.join(uncompress_path,
148
  'bopomofo_to_pinyin_wo_tune_dict.json'),
@@ -153,16 +157,16 @@ class G2PWOnnxConverter:
153
  'bopomofo': lambda x: x,
154
  'pinyin': self._convert_bopomofo_to_pinyin,
155
  }[style]
156
-
157
  with open(
158
  os.path.join(uncompress_path, 'char_bopomofo_dict.json'),
159
  'r',
160
  encoding='utf-8') as fr:
161
  self.char_bopomofo_dict = json.load(fr)
162
-
163
  if self.enable_opencc:
164
  self.cc = OpenCC('s2tw')
165
-
166
  def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str:
167
  tone = bopomofo[-1]
168
  assert tone in '12345'
 
81
  model_source: str=None,
82
  enable_non_tradional_chinese: bool=False):
83
  uncompress_path = download_and_decompress(model_dir)
84
+ print(":::1")
85
  sess_options = onnxruntime.SessionOptions()
86
+ print(":::2")
87
  sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
88
+ print(":::3")
89
  sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
90
+ print(":::4")
91
  sess_options.intra_op_num_threads = 2
92
+ print(":::5")
93
  try:
94
  self.session_g2pW = onnxruntime.InferenceSession(os.path.join(uncompress_path, 'g2pW.onnx'),sess_options=sess_options, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
95
  except:
96
  self.session_g2pW = onnxruntime.InferenceSession(os.path.join(uncompress_path, 'g2pW.onnx'),sess_options=sess_options, providers=['CPUExecutionProvider'])
97
+ print(":::6")
98
  self.config = load_config(
99
  config_path=os.path.join(uncompress_path, 'config.py'),
100
  use_default=True)
101
 
102
  self.model_source = model_source if model_source else self.config.model_source
103
  self.enable_opencc = enable_non_tradional_chinese
104
+ print(":::7")
105
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_source)
106
+ print(":::8")
107
  polyphonic_chars_path = os.path.join(uncompress_path,
108
  'POLYPHONIC_CHARS.txt')
109
  monophonic_chars_path = os.path.join(uncompress_path,
 
127
  polyphonic_chars=self.polyphonic_chars
128
  ) if self.config.use_char_phoneme else get_phoneme_labels(
129
  polyphonic_chars=self.polyphonic_chars)
130
+ print(":::9")
131
  self.chars = sorted(list(self.char2phonemes.keys()))
132
 
133
  self.polyphonic_chars_new = set(self.chars)
134
  for char in self.non_polyphonic:
135
  if char in self.polyphonic_chars_new:
136
  self.polyphonic_chars_new.remove(char)
137
+ print(":::10")
138
  self.monophonic_chars_dict = {
139
  char: phoneme
140
  for char, phoneme in self.monophonic_chars
 
142
  for char in self.non_monophonic:
143
  if char in self.monophonic_chars_dict:
144
  self.monophonic_chars_dict.pop(char)
145
+ print(":::11")
146
  self.pos_tags = [
147
  'UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI'
148
  ]
149
+ print(":::12")
150
  with open(
151
  os.path.join(uncompress_path,
152
  'bopomofo_to_pinyin_wo_tune_dict.json'),
 
157
  'bopomofo': lambda x: x,
158
  'pinyin': self._convert_bopomofo_to_pinyin,
159
  }[style]
160
+ print(":::13")
161
  with open(
162
  os.path.join(uncompress_path, 'char_bopomofo_dict.json'),
163
  'r',
164
  encoding='utf-8') as fr:
165
  self.char_bopomofo_dict = json.load(fr)
166
+ print(":::14")
167
  if self.enable_opencc:
168
  self.cc = OpenCC('s2tw')
169
+ print(":::15")
170
  def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str:
171
  tone = bopomofo[-1]
172
  assert tone in '12345'