LukeJacob2023 committed
Commit eac1a45 · verified · 1 Parent(s): de9456f

Upload 14 files

app.py ADDED
@@ -0,0 +1,45 @@
+ import gradio as gr
+ from cttPunctuator import CttPunctuator
+
+ punc = CttPunctuator()
+
+
+ def punctuate(text):
+     # Guard against empty input, which the model cannot handle
+     if not text or not text.strip():
+         return ""
+     # Use the model to add punctuation to the input text
+     return punc.punctuate(text)[0]
+
+
+ def clear_text():
+     return "", ""
+
+
+ with gr.Blocks() as iface:
+     gr.Markdown("""
+ # Chinese/English Punctuation Restoration Tool
+
+ This tool automatically adds appropriate punctuation to your text.
+ Based on the project: https://github.com/lovemefan/CT-Transformer-punctuation
+
+ How to use:
+ 1. Paste or type your text into the input box on the left.
+ 2. Click the "Punctuate" button.
+ 3. Read the result in the output box on the right; the copy button in its top-right corner copies it quickly.
+ 4. To clear everything, click the "Clear" button.
+ """)
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             input_text = gr.Textbox(lines=10, label="Input text")
+
+         with gr.Column(scale=1):
+             output_text = gr.Textbox(lines=10, label="Result", show_copy_button=True)
+
+     with gr.Row():
+         punctuate_button = gr.Button("Punctuate")
+         clear_button = gr.Button("Clear")
+
+     punctuate_button.click(fn=punctuate, inputs=input_text, outputs=output_text)
+     clear_button.click(fn=clear_text, inputs=None, outputs=[input_text, output_text])
+
+ # Launch the Gradio app
+ iface.launch()
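For a quick sanity check of the punctuate callback without starting the UI, a minimal sketch (the sample sentence is illustrative; the exact output depends on the model):

    from cttPunctuator import CttPunctuator

    punc = CttPunctuator()
    print(punc.punctuate("大家好今天天气不错我们一起去公园散步吧")[0])
    # expected to print something like: 大家好,今天天气不错,我们一起去公园散步吧。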
cttPunctuator.py ADDED
@@ -0,0 +1,63 @@
+ # -*- coding:utf-8 -*-
+ # @FileName :cttPunctuator.py
+ # @Time :2023/4/13 15:03
+ # @Author :lovemefan
+ # @Email :[email protected]
+
+
+ __author__ = "lovemefan"
+ __copyright__ = "Copyright (C) 2023 lovemefan"
+ __license__ = "MIT"
+ __version__ = "v0.0.1"
+
+ import logging
+ import threading
+
+ from cttpunctuator.src.punctuator import CT_Transformer, CT_Transformer_VadRealtime
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format="[%(asctime)s %(levelname)s] [%(filename)s:%(lineno)d %(module)s.%(funcName)s] %(message)s",
+ )
+
+ lock = threading.RLock()
+
+
+ class CttPunctuator:
+     _offline_model = None
+     _online_model = None
+
+     def __init__(self, online: bool = False):
+         """
+         Punctuator with lazily initialized, shared (singleton) models.
+         :param online: use the streaming (VAD realtime) model instead of the offline one
+         """
+         self.online = online
+
+         if online:
+             if CttPunctuator._online_model is None:
+                 with lock:
+                     if CttPunctuator._online_model is None:
+                         logging.info("Initializing punctuator model with online mode.")
+                         CttPunctuator._online_model = CT_Transformer_VadRealtime()
+                         logging.info("Online model initialized.")
+             self.model = CttPunctuator._online_model
+             # per-instance streaming cache, set even when the shared model already exists
+             self.param_dict = {"cache": []}
+
+         else:
+             if CttPunctuator._offline_model is None:
+                 with lock:
+                     if CttPunctuator._offline_model is None:
+                         logging.info("Initializing punctuator model with offline mode.")
+                         CttPunctuator._offline_model = CT_Transformer()
+                         logging.info("Offline model initialized.")
+             self.model = CttPunctuator._offline_model
+
+         logging.info("Model initialized.")
+
+     def punctuate(self, text: str, param_dict=None):
+         if self.online:
+             param_dict = param_dict or self.param_dict
+             return self.model(text, param_dict)
+         else:
+             return self.model(text)
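The online path keeps a per-instance cache that CT_Transformer_VadRealtime updates on every call; a minimal streaming sketch, with illustrative chunk strings:

    from cttPunctuator import CttPunctuator

    streamer = CttPunctuator(online=True)
    for chunk in ["明天的天气怎么样", "会不会下雨", "需要带伞吗"]:
        sentence, punc_tags, cache = streamer.punctuate(chunk)
        print(sentence)  # punctuated output for this chunk; the unfinished tail is carried in the cache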
cttpunctuator/__init__.py ADDED
@@ -0,0 +1,5 @@
+ # -*- coding:utf-8 -*-
+ # @FileName :__init__.py
+ # @Time :2023/4/13 14:58
+ # @Author :lovemefan
+ # @Email :[email protected]
cttpunctuator/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (146 Bytes).
 
cttpunctuator/src/__pycache__/punctuator.cpython-310.pyc ADDED
Binary file (8.14 kB).
 
cttpunctuator/src/onnx/configuration.json ADDED
@@ -0,0 +1,20 @@
+ {
+     "framework": "onnx",
+     "task": "punctuation",
+     "model": {
+         "type": "generic-punc",
+         "punc_model_name": "punc.pb",
+         "punc_model_config": {
+             "type": "pytorch",
+             "code_base": "funasr",
+             "mode": "punc",
+             "lang": "zh-cn",
+             "batch_size": 1,
+             "punc_config": "punc.yaml",
+             "model": "damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
+         }
+     },
+     "pipeline": {
+         "type": "punc-inference"
+     }
+ }
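A minimal sketch of inspecting this configuration from Python, assuming the path layout of this commit:

    import json
    from pathlib import Path

    cfg = json.loads(
        Path("cttpunctuator/src/onnx/configuration.json").read_text(encoding="utf-8")
    )
    print(cfg["task"])                                 # punctuation
    print(cfg["model"]["punc_model_config"]["model"])  # damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch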
cttpunctuator/src/onnx/punc.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:85a6f2ec7cfa74c1ec932223425a35ce801bb0171571330a94d5d78f9ba2e245
+ size 2848807
cttpunctuator/src/onnx/punc.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ed5318d91ff9520a03a5b5a8dba264b76858931db7d914b0de6ec9e4ad35970e
+ size 292001778
cttpunctuator/src/punctuator.py ADDED
@@ -0,0 +1,312 @@
+ import logging
+ import os.path
+ import pickle
+ from pathlib import Path
+ from typing import Tuple, Union
+
+ import numpy as np
+
+ from cttpunctuator.src.utils.OrtInferSession import ONNXRuntimeError, OrtInferSession
+ from cttpunctuator.src.utils.text_post_process import (
+     TokenIDConverter,
+     code_mix_split_words,
+     split_to_mini_sentence,
+ )
+
+
+ class CT_Transformer:
+     """
+     Author: Speech Lab, Alibaba Group, China
+     CT-Transformer: Controllable time-delay transformer
+     for real-time punctuation prediction and disfluency detection
+     https://arxiv.org/pdf/2003.01309.pdf
+     """
+
+     def __init__(
+         self,
+         model_dir: Union[str, Path] = None,
+         batch_size: int = 1,
+         device_id: Union[str, int] = "-1",
+         quantize: bool = False,
+         intra_op_num_threads: int = 4,
+     ):
+         model_dir = model_dir or os.path.join(os.path.dirname(__file__), "onnx")
+         if model_dir is None or not Path(model_dir).exists():
+             raise FileNotFoundError(f"{model_dir} does not exist.")
+
+         model_file = os.path.join(model_dir, "punc.onnx")
+         if quantize:
+             model_file = os.path.join(model_dir, "model_quant.onnx")
+         config_file = os.path.join(model_dir, "punc.bin")
+         with open(config_file, "rb") as file:
+             config = pickle.load(file)
+
+         self.converter = TokenIDConverter(config["token_list"])
+         self.ort_infer = OrtInferSession(
+             model_file, device_id, intra_op_num_threads=intra_op_num_threads
+         )
+         self.batch_size = 1
+         self.punc_list = config["punc_list"]
+         self.period = 0
+         for i in range(len(self.punc_list)):
+             if self.punc_list[i] == ",":
+                 self.punc_list[i] = ","
+             elif self.punc_list[i] == "?":
+                 self.punc_list[i] = "?"
+             elif self.punc_list[i] == "。":
+                 self.period = i
+
+     def __call__(self, text: Union[list, str], split_size=20):
+         split_text = code_mix_split_words(text)
+         split_text_id = self.converter.tokens2ids(split_text)
+         mini_sentences = split_to_mini_sentence(split_text, split_size)
+         mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
+         assert len(mini_sentences) == len(mini_sentences_id)
+         cache_sent = []
+         cache_sent_id = []
+         new_mini_sentence = ""
+         new_mini_sentence_punc = []
+         cache_pop_trigger_limit = 200
+         for mini_sentence_i in range(len(mini_sentences)):
+             mini_sentence = mini_sentences[mini_sentence_i]
+             mini_sentence_id = mini_sentences_id[mini_sentence_i]
+             mini_sentence = cache_sent + mini_sentence
+
+             mini_sentence_id = np.array(cache_sent_id + mini_sentence_id, dtype="int64")
+             text_lengths = np.array([len(mini_sentence)], dtype="int32")
+
+             data = {
+                 "text": mini_sentence_id[None, :],
+                 "text_lengths": text_lengths,
+             }
+             try:
+                 outputs = self.infer(data["text"], data["text_lengths"])
+                 y = outputs[0]
+                 punctuations = np.argmax(y, axis=-1)[0]
+                 assert punctuations.size == len(mini_sentence)
+             except ONNXRuntimeError as e:
+                 logging.exception(e)
+                 # re-raise so a failed inference does not leave `punctuations` undefined
+                 raise
+
+             # Search for the last Period/QuestionMark as cache
+             if mini_sentence_i < len(mini_sentences) - 1:
+                 sentenceEnd = -1
+                 last_comma_index = -1
+                 for i in range(len(punctuations) - 2, 1, -1):
+                     if (
+                         self.punc_list[punctuations[i]] == "。"
+                         or self.punc_list[punctuations[i]] == "?"
+                     ):
+                         sentenceEnd = i
+                         break
+                     if last_comma_index < 0 and self.punc_list[punctuations[i]] == ",":
+                         last_comma_index = i
+
+                 if (
+                     sentenceEnd < 0
+                     and len(mini_sentence) > cache_pop_trigger_limit
+                     and last_comma_index >= 0
+                 ):
+                     # The sentence is too long; cut it off at a comma.
+                     sentenceEnd = last_comma_index
+                     punctuations[sentenceEnd] = self.period
+                 cache_sent = mini_sentence[sentenceEnd + 1 :]
+                 cache_sent_id = mini_sentence_id[sentenceEnd + 1 :].tolist()
+                 mini_sentence = mini_sentence[0 : sentenceEnd + 1]
+                 punctuations = punctuations[0 : sentenceEnd + 1]
+
+             new_mini_sentence_punc += [int(x) for x in punctuations]
+             words_with_punc = []
+             for i in range(len(mini_sentence)):
+                 if i > 0:
+                     if (
+                         len(mini_sentence[i][0].encode()) == 1
+                         and len(mini_sentence[i - 1][0].encode()) == 1
+                     ):
+                         mini_sentence[i] = " " + mini_sentence[i]
+                 words_with_punc.append(mini_sentence[i])
+                 if self.punc_list[punctuations[i]] != "_":
+                     words_with_punc.append(self.punc_list[punctuations[i]])
+             new_mini_sentence += "".join(words_with_punc)
+             # Add a period at the end of the last mini-sentence
+             new_mini_sentence_out = new_mini_sentence
+             new_mini_sentence_punc_out = new_mini_sentence_punc
+             if mini_sentence_i == len(mini_sentences) - 1:
+                 if new_mini_sentence[-1] == "," or new_mini_sentence[-1] == "、":
+                     new_mini_sentence_out = new_mini_sentence[:-1] + "。"
+                     new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [
+                         self.period
+                     ]
+                 elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "?":
+                     new_mini_sentence_out = new_mini_sentence + "。"
+                     new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [
+                         self.period
+                     ]
+         return new_mini_sentence_out, new_mini_sentence_punc_out
+
+     def infer(
+         self, feats: np.ndarray, feats_len: np.ndarray
+     ) -> Tuple[np.ndarray, np.ndarray]:
+         outputs = self.ort_infer([feats, feats_len])
+         return outputs
+
+
+ class CT_Transformer_VadRealtime(CT_Transformer):
+     """
+     Author: Speech Lab, Alibaba Group, China
+     CT-Transformer: Controllable time-delay transformer for
+     real-time punctuation prediction and disfluency detection
+     https://arxiv.org/pdf/2003.01309.pdf
+     """
+
+     def __init__(
+         self,
+         model_dir: Union[str, Path] = None,
+         batch_size: int = 1,
+         device_id: Union[str, int] = "-1",
+         quantize: bool = False,
+         intra_op_num_threads: int = 4,
+     ):
+         super(CT_Transformer_VadRealtime, self).__init__(
+             model_dir, batch_size, device_id, quantize, intra_op_num_threads
+         )
+
+     def __call__(self, text: str, param_dict: dict, split_size=20):
+         cache_key = "cache"
+         assert cache_key in param_dict
+         cache = param_dict[cache_key]
+         if cache is not None and len(cache) > 0:
+             precache = "".join(cache)
+         else:
+             precache = ""
+             cache = []
+         full_text = precache + text
+         split_text = code_mix_split_words(full_text)
+         split_text_id = self.converter.tokens2ids(split_text)
+         mini_sentences = split_to_mini_sentence(split_text, split_size)
+         mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
+         new_mini_sentence_punc = []
+         assert len(mini_sentences) == len(mini_sentences_id)
+
+         cache_sent = []
+         cache_sent_id = np.array([], dtype="int32")
+         sentence_punc_list = []
+         sentence_words_list = []
+         cache_pop_trigger_limit = 200
+         skip_num = 0
+         for mini_sentence_i in range(len(mini_sentences)):
+             mini_sentence = mini_sentences[mini_sentence_i]
+             mini_sentence_id = mini_sentences_id[mini_sentence_i]
+             mini_sentence = cache_sent + mini_sentence
+             mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
+             text_length = len(mini_sentence_id)
+             data = {
+                 "input": np.array(mini_sentence_id[None, :], dtype="int64"),
+                 "text_lengths": np.array([text_length], dtype="int32"),
+                 "vad_mask": self.vad_mask(text_length, len(cache))[
+                     None, None, :, :
+                 ].astype(np.float32),
+                 "sub_masks": np.tril(
+                     np.ones((text_length, text_length), dtype=np.float32)
+                 )[None, None, :, :].astype(np.float32),
+             }
+             try:
+                 outputs = self.infer(
+                     data["input"],
+                     data["text_lengths"],
+                     data["vad_mask"],
+                     data["sub_masks"],
+                 )
+                 y = outputs[0]
+                 punctuations = np.argmax(y, axis=-1)[0]
+                 assert punctuations.size == len(mini_sentence)
+             except ONNXRuntimeError as e:
+                 logging.exception(e)
+                 # re-raise so a failed inference does not leave `punctuations` undefined
+                 raise
+
+             # Search for the last Period/QuestionMark as cache
+             if mini_sentence_i < len(mini_sentences) - 1:
+                 sentenceEnd = -1
+                 last_comma_index = -1
+                 for i in range(len(punctuations) - 2, 1, -1):
+                     if (
+                         self.punc_list[punctuations[i]] == "。"
+                         or self.punc_list[punctuations[i]] == "?"
+                     ):
+                         sentenceEnd = i
+                         break
+                     if last_comma_index < 0 and self.punc_list[punctuations[i]] == ",":
+                         last_comma_index = i
+
+                 if (
+                     sentenceEnd < 0
+                     and len(mini_sentence) > cache_pop_trigger_limit
+                     and last_comma_index >= 0
+                 ):
+                     # The sentence is too long; cut it off at a comma.
+                     sentenceEnd = last_comma_index
+                     punctuations[sentenceEnd] = self.period
+                 cache_sent = mini_sentence[sentenceEnd + 1 :]
+                 cache_sent_id = mini_sentence_id[sentenceEnd + 1 :]
+                 mini_sentence = mini_sentence[0 : sentenceEnd + 1]
+                 punctuations = punctuations[0 : sentenceEnd + 1]
+
+             punctuations_np = [int(x) for x in punctuations]
+             new_mini_sentence_punc += punctuations_np
+             sentence_punc_list += [self.punc_list[int(x)] for x in punctuations_np]
+             sentence_words_list += mini_sentence
+
+         assert len(sentence_punc_list) == len(sentence_words_list)
+         words_with_punc = []
+         sentence_punc_list_out = []
+         for i in range(0, len(sentence_words_list)):
+             if i > 0:
+                 if (
+                     len(sentence_words_list[i][0].encode()) == 1
+                     and len(sentence_words_list[i - 1][-1].encode()) == 1
+                 ):
+                     sentence_words_list[i] = " " + sentence_words_list[i]
+             if skip_num < len(cache):
+                 skip_num += 1
+             else:
+                 words_with_punc.append(sentence_words_list[i])
+             if skip_num >= len(cache):
+                 sentence_punc_list_out.append(sentence_punc_list[i])
+                 if sentence_punc_list[i] != "_":
+                     words_with_punc.append(sentence_punc_list[i])
+         sentence_out = "".join(words_with_punc)
+
+         sentenceEnd = -1
+         for i in range(len(sentence_punc_list) - 2, 1, -1):
+             if sentence_punc_list[i] == "。" or sentence_punc_list[i] == "?":
+                 sentenceEnd = i
+                 break
+         cache_out = sentence_words_list[sentenceEnd + 1 :]
+         if sentence_out[-1] in self.punc_list:
+             sentence_out = sentence_out[:-1]
+             sentence_punc_list_out[-1] = "_"
+         param_dict[cache_key] = cache_out
+         return sentence_out, sentence_punc_list_out, cache_out
+
+     def vad_mask(self, size, vad_pos, dtype=np.bool_):
+         """Create a mask for decoder self-attention.
+
+         :param int size: size of the mask
+         :param int vad_pos: index of the vad position
+         :param np.dtype dtype: result dtype
+         :rtype: np.ndarray of shape (size, size)
+         """
+         ret = np.ones((size, size), dtype=dtype)
+         if vad_pos <= 0 or vad_pos >= size:
+             return ret
+         sub_corner = np.zeros((vad_pos - 1, size - vad_pos), dtype=dtype)
+         ret[0 : vad_pos - 1, vad_pos:] = sub_corner
+         return ret
+
+     def infer(
+         self,
+         feats: np.ndarray,
+         feats_len: np.ndarray,
+         vad_mask: np.ndarray,
+         sub_masks: np.ndarray,
+     ) -> Tuple[np.ndarray, np.ndarray]:
+         outputs = self.ort_infer([feats, feats_len, vad_mask, sub_masks])
+         return outputs
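A minimal sketch of driving the realtime class directly, threading the cache dict between calls the same way cttPunctuator.py does; the segment strings are illustrative:

    from cttpunctuator.src.punctuator import CT_Transformer_VadRealtime

    model = CT_Transformer_VadRealtime()
    param_dict = {"cache": []}
    for segment in ["近年来语音识别技术发展很快", "标点预测也随之受到关注"]:
        sentence, punc_tags, cache = model(segment, param_dict)
        # each call emits punctuation for the new segment; the unfinished tail
        # after the last sentence boundary stays in param_dict["cache"]
        print(sentence)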
cttpunctuator/src/utils/OrtInferSession.py ADDED
@@ -0,0 +1,103 @@
+ # -*- coding:utf-8 -*-
+ # @FileName :OrtInferSession.py
+ # @Time :2023/4/13 15:13
+ # @Author :lovemefan
+ # @Email :[email protected]
+ import warnings
+ from pathlib import Path
+ from typing import List
+
+ import numpy as np
+ from onnxruntime import (
+     GraphOptimizationLevel,
+     InferenceSession,
+     SessionOptions,
+     get_available_providers,
+     get_device,
+ )
+
+
+ class ONNXRuntimeError(Exception):
+     pass
+
+
+ class OrtInferSession:
+     def __init__(self, model_file, device_id=-1, intra_op_num_threads=4):
+         device_id = str(device_id)
+         sess_opt = SessionOptions()
+         sess_opt.intra_op_num_threads = intra_op_num_threads
+         sess_opt.log_severity_level = 4
+         sess_opt.enable_cpu_mem_arena = False
+         sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
+
+         cuda_ep = "CUDAExecutionProvider"
+         cuda_provider_options = {
+             "device_id": device_id,
+             "arena_extend_strategy": "kNextPowerOfTwo",
+             "cudnn_conv_algo_search": "EXHAUSTIVE",
+             "do_copy_in_default_stream": "true",
+         }
+         cpu_ep = "CPUExecutionProvider"
+         cpu_provider_options = {
+             "arena_extend_strategy": "kSameAsRequested",
+         }
+
+         EP_list = []
+         if (
+             device_id != "-1"
+             and get_device() == "GPU"
+             and cuda_ep in get_available_providers()
+         ):
+             EP_list = [(cuda_ep, cuda_provider_options)]
+         EP_list.append((cpu_ep, cpu_provider_options))
+
+         self._verify_model(model_file)
+         self.session = InferenceSession(
+             model_file, sess_options=sess_opt, providers=EP_list
+         )
+
+         if device_id != "-1" and cuda_ep not in self.session.get_providers():
+             warnings.warn(
+                 f"{cuda_ep} is not available in the current environment, "
+                 f"so inference automatically falls back to {cpu_ep}.\n"
+                 "Please ensure the installed onnxruntime-gpu version matches your CUDA and cuDNN versions; "
+                 "you can check their compatibility on the official site: "
+                 "https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html",
+                 RuntimeWarning,
+             )
+
+     def __call__(self, input_content: List[np.ndarray]) -> np.ndarray:
+         input_dict = dict(zip(self.get_input_names(), input_content))
+         try:
+             return self.session.run(self.get_output_names(), input_dict)
+         except Exception as e:
+             raise ONNXRuntimeError("ONNXRuntime inference failed.") from e
+
+     def get_input_names(
+         self,
+     ):
+         return [v.name for v in self.session.get_inputs()]
+
+     def get_output_names(
+         self,
+     ):
+         return [v.name for v in self.session.get_outputs()]
+
+     def get_character_list(self, key: str = "character"):
+         return self.meta_dict[key].splitlines()
+
+     def have_key(self, key: str = "character") -> bool:
+         self.meta_dict = self.session.get_modelmeta().custom_metadata_map
+         if key in self.meta_dict.keys():
+             return True
+         return False
+
+     @staticmethod
+     def _verify_model(model_path):
+         model_path = Path(model_path)
+         if not model_path.exists():
+             raise FileNotFoundError(f"{model_path} does not exist.")
+         if not model_path.is_file():
+             raise FileExistsError(f"{model_path} is not a file.")
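A minimal sketch of driving OrtInferSession directly with the offline model from this commit; the token ids are made-up placeholders, and the input order mirrors what CT_Transformer.infer passes in:

    import numpy as np
    from cttpunctuator.src.utils.OrtInferSession import OrtInferSession

    sess = OrtInferSession("cttpunctuator/src/onnx/punc.onnx", device_id=-1)
    print(sess.get_input_names(), sess.get_output_names())

    token_ids = np.array([[101, 205, 309]], dtype="int64")  # placeholder ids
    lengths = np.array([3], dtype="int32")
    logits = sess([token_ids, lengths])[0]
    print(logits.shape)  # roughly (1, 3, number of punctuation classes)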
cttpunctuator/src/utils/__pycache__/OrtInferSession.cpython-310.pyc ADDED
Binary file (3.82 kB).
 
cttpunctuator/src/utils/__pycache__/text_post_process.cpython-310.pyc ADDED
Binary file (3.26 kB).
 
cttpunctuator/src/utils/text_post_process.py ADDED
@@ -0,0 +1,85 @@
+ # -*- coding:utf-8 -*-
+ # @FileName :text_post_process.py
+ # @Time :2023/4/13 15:09
+ # @Author :lovemefan
+ # @Email :[email protected]
+ from pathlib import Path
+ from typing import Dict, Iterable, List, Union
+
+ import numpy as np
+ import yaml
+ from typeguard import check_argument_types
+
+
+ class TokenIDConverterError(Exception):
+     pass
+
+
+ class TokenIDConverter:
+     def __init__(
+         self,
+         token_list: Union[List, str],
+     ):
+         check_argument_types()
+
+         self.token_list = token_list
+         self.unk_symbol = token_list[-1]
+         self.token2id = {v: i for i, v in enumerate(self.token_list)}
+         self.unk_id = self.token2id[self.unk_symbol]
+
+     def get_num_vocabulary_size(self) -> int:
+         return len(self.token_list)
+
+     def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
+         if isinstance(integers, np.ndarray) and integers.ndim != 1:
+             raise TokenIDConverterError(
+                 f"Must be 1 dim ndarray, but got {integers.ndim}"
+             )
+         return [self.token_list[i] for i in integers]
+
+     def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
+         return [self.token2id.get(i, self.unk_id) for i in tokens]
+
+
+ def split_to_mini_sentence(words: list, word_limit: int = 20):
+     assert word_limit > 1
+     if len(words) <= word_limit:
+         return [words]
+     sentences = []
+     length = len(words)
+     sentence_len = length // word_limit
+     for i in range(sentence_len):
+         sentences.append(words[i * word_limit : (i + 1) * word_limit])
+     if length % word_limit > 0:
+         sentences.append(words[sentence_len * word_limit :])
+     return sentences
+
+
+ def code_mix_split_words(text: str):
+     words = []
+     segs = text.split()
+     for seg in segs:
+         # There is no space in seg.
+         current_word = ""
+         for c in seg:
+             if len(c.encode()) == 1:
+                 # This is an ASCII char.
+                 current_word += c
+             else:
+                 # This is a Chinese char.
+                 if len(current_word) > 0:
+                     words.append(current_word)
+                     current_word = ""
+                 words.append(c)
+         if len(current_word) > 0:
+             words.append(current_word)
+     return words
+
+
+ def read_yaml(yaml_path: Union[str, Path]) -> Dict:
+     if not Path(yaml_path).exists():
+         raise FileNotFoundError(f"The {yaml_path} does not exist.")
+
+     with open(str(yaml_path), "rb") as f:
+         data = yaml.load(f, Loader=yaml.Loader)
+     return data
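A small sketch of what the two splitting helpers do on a mixed Chinese/English string (the input is illustrative):

    from cttpunctuator.src.utils.text_post_process import (
        code_mix_split_words,
        split_to_mini_sentence,
    )

    words = code_mix_split_words("我们是hello world朋友")
    print(words)  # ['我', '们', '是', 'hello', 'world', '朋', '友']
    print(split_to_mini_sentence(words, word_limit=3))
    # [['我', '们', '是'], ['hello', 'world', '朋'], ['友']]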
requirements.txt ADDED
Binary file (5.01 kB).