Spaces:
Sleeping
Sleeping
Upload 14 files
Browse files- app.py +45 -0
- cttPunctuator.py +63 -0
- cttpunctuator/__init__.py +5 -0
- cttpunctuator/__pycache__/__init__.cpython-310.pyc +0 -0
- cttpunctuator/src/__pycache__/punctuator.cpython-310.pyc +0 -0
- cttpunctuator/src/onnx/configuration.json +20 -0
- cttpunctuator/src/onnx/punc.bin +3 -0
- cttpunctuator/src/onnx/punc.onnx +3 -0
- cttpunctuator/src/punctuator.py +312 -0
- cttpunctuator/src/utils/OrtInferSession.py +103 -0
- cttpunctuator/src/utils/__pycache__/OrtInferSession.cpython-310.pyc +0 -0
- cttpunctuator/src/utils/__pycache__/text_post_process.cpython-310.pyc +0 -0
- cttpunctuator/src/utils/text_post_process.py +85 -0
- requirements.txt +0 -0
app.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from cttPunctuator import CttPunctuator
|
3 |
+
|
4 |
+
punc = CttPunctuator()
|
5 |
+
|
6 |
+
|
7 |
+
def punctuate(text):
|
8 |
+
# 使用模型生成标点润饰的文本
|
9 |
+
return punc.punctuate(text)[0]
|
10 |
+
|
11 |
+
|
12 |
+
def clear_text():
|
13 |
+
return "", ""
|
14 |
+
|
15 |
+
|
16 |
+
with gr.Blocks() as iface:
|
17 |
+
gr.Markdown("""
|
18 |
+
# 中英文标点润饰工具
|
19 |
+
|
20 |
+
这个工具可以帮助你自动为文本添加适当的标点符号。
|
21 |
+
基于项目:https://github.com/lovemefan/CT-Transformer-punctuation
|
22 |
+
|
23 |
+
使用说明:
|
24 |
+
1. 在左侧的输入框中粘贴或输入你的文本。
|
25 |
+
2. 点击"润饰"按钮。
|
26 |
+
3. 查看右侧输出框中的结果。可以使用输出框右上角复制按钮快速复制结果。
|
27 |
+
4. 如需清空所有内容,点击"清空"按钮。
|
28 |
+
""")
|
29 |
+
|
30 |
+
with gr.Row():
|
31 |
+
with gr.Column(scale=1):
|
32 |
+
input_text = gr.Textbox(lines=10, label="输入文本")
|
33 |
+
|
34 |
+
with gr.Column(scale=1):
|
35 |
+
output_text = gr.Textbox(lines=10, label="结果", show_copy_button=True)
|
36 |
+
|
37 |
+
with gr.Row():
|
38 |
+
punctuate_button = gr.Button("润饰")
|
39 |
+
clear_button = gr.Button("清空")
|
40 |
+
|
41 |
+
punctuate_button.click(fn=punctuate, inputs=input_text, outputs=output_text)
|
42 |
+
clear_button.click(fn=clear_text, inputs=None, outputs=[input_text, output_text])
|
43 |
+
|
44 |
+
# 启动Gradio应用
|
45 |
+
iface.launch()
|
cttPunctuator.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding:utf-8 -*-
|
2 |
+
# @FileName :ctt-punctuator.py
|
3 |
+
# @Time :2023/4/13 15:03
|
4 |
+
# @Author :lovemefan
|
5 |
+
# @Email :[email protected]
|
6 |
+
|
7 |
+
|
8 |
+
__author__ = "lovemefan"
|
9 |
+
__copyright__ = "Copyright (C) 2023 lovemefan"
|
10 |
+
__license__ = "MIT"
|
11 |
+
__version__ = "v0.0.1"
|
12 |
+
|
13 |
+
import logging
|
14 |
+
import threading
|
15 |
+
|
16 |
+
from cttpunctuator.src.punctuator import CT_Transformer, CT_Transformer_VadRealtime
|
17 |
+
|
18 |
+
logging.basicConfig(
|
19 |
+
level=logging.INFO,
|
20 |
+
format="[%(asctime)s %(levelname)s] [%(filename)s:%(lineno)d %(module)s.%(funcName)s] %(message)s",
|
21 |
+
)
|
22 |
+
|
23 |
+
lock = threading.RLock()
|
24 |
+
|
25 |
+
|
26 |
+
class CttPunctuator:
|
27 |
+
_offline_model = None
|
28 |
+
_online_model = None
|
29 |
+
|
30 |
+
def __init__(self, online: bool = False):
|
31 |
+
"""
|
32 |
+
punctuator with singleton pattern
|
33 |
+
:param online:
|
34 |
+
"""
|
35 |
+
self.online = online
|
36 |
+
|
37 |
+
if online:
|
38 |
+
if CttPunctuator._online_model is None:
|
39 |
+
with lock:
|
40 |
+
if CttPunctuator._online_model is None:
|
41 |
+
logging.info("Initializing punctuator model with online mode.")
|
42 |
+
CttPunctuator._online_model = CT_Transformer_VadRealtime()
|
43 |
+
self.param_dict = {"cache": []}
|
44 |
+
logging.info("Online model initialized.")
|
45 |
+
self.model = CttPunctuator._online_model
|
46 |
+
|
47 |
+
else:
|
48 |
+
if CttPunctuator._offline_model is None:
|
49 |
+
with lock:
|
50 |
+
if CttPunctuator._offline_model is None:
|
51 |
+
logging.info("Initializing punctuator model with offline mode.")
|
52 |
+
CttPunctuator._offline_model = CT_Transformer()
|
53 |
+
logging.info("Offline model initialized.")
|
54 |
+
self.model = CttPunctuator._offline_model
|
55 |
+
|
56 |
+
logging.info("Model initialized.")
|
57 |
+
|
58 |
+
def punctuate(self, text: str, param_dict=None):
|
59 |
+
if self.online:
|
60 |
+
param_dict = param_dict or self.param_dict
|
61 |
+
return self.model(text, self.param_dict)
|
62 |
+
else:
|
63 |
+
return self.model(text)
|
cttpunctuator/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding:utf-8 -*-
|
2 |
+
# @FileName :__init__.py.py
|
3 |
+
# @Time :2023/4/13 14:58
|
4 |
+
# @Author :lovemefan
|
5 |
+
# @Email :[email protected]
|
cttpunctuator/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (146 Bytes). View file
|
|
cttpunctuator/src/__pycache__/punctuator.cpython-310.pyc
ADDED
Binary file (8.14 kB). View file
|
|
cttpunctuator/src/onnx/configuration.json
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"framework": "onnx",
|
3 |
+
"task" : "punctuation",
|
4 |
+
"model" : {
|
5 |
+
"type" : "generic-punc",
|
6 |
+
"punc_model_name" : "punc.pb",
|
7 |
+
"punc_model_config" : {
|
8 |
+
"type": "pytorch",
|
9 |
+
"code_base": "funasr",
|
10 |
+
"mode": "punc",
|
11 |
+
"lang": "zh-cn",
|
12 |
+
"batch_size": 1,
|
13 |
+
"punc_config": "punc.yaml",
|
14 |
+
"model": "damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
|
15 |
+
}
|
16 |
+
},
|
17 |
+
"pipeline": {
|
18 |
+
"type":"punc-inference"
|
19 |
+
}
|
20 |
+
}
|
cttpunctuator/src/onnx/punc.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:85a6f2ec7cfa74c1ec932223425a35ce801bb0171571330a94d5d78f9ba2e245
|
3 |
+
size 2848807
|
cttpunctuator/src/onnx/punc.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ed5318d91ff9520a03a5b5a8dba264b76858931db7d914b0de6ec9e4ad35970e
|
3 |
+
size 292001778
|
cttpunctuator/src/punctuator.py
ADDED
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os.path
|
3 |
+
import pickle
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import Tuple, Union
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from cttpunctuator.src.utils.OrtInferSession import ONNXRuntimeError, OrtInferSession
|
10 |
+
from cttpunctuator.src.utils.text_post_process import (
|
11 |
+
TokenIDConverter,
|
12 |
+
code_mix_split_words,
|
13 |
+
split_to_mini_sentence,
|
14 |
+
)
|
15 |
+
|
16 |
+
|
17 |
+
class CT_Transformer:
|
18 |
+
"""
|
19 |
+
Author: Speech Lab, Alibaba Group, China
|
20 |
+
CT-Transformer: Controllable time-delay transformer
|
21 |
+
for real-time punctuation prediction and disfluency detection
|
22 |
+
https://arxiv.org/pdf/2003.01309.pdf
|
23 |
+
"""
|
24 |
+
|
25 |
+
def __init__(
|
26 |
+
self,
|
27 |
+
model_dir: Union[str, Path] = None,
|
28 |
+
batch_size: int = 1,
|
29 |
+
device_id: Union[str, int] = "-1",
|
30 |
+
quantize: bool = False,
|
31 |
+
intra_op_num_threads: int = 4,
|
32 |
+
):
|
33 |
+
model_dir = model_dir or os.path.join(os.path.dirname(__file__), "onnx")
|
34 |
+
if model_dir is None or not Path(model_dir).exists():
|
35 |
+
raise FileNotFoundError(f"{model_dir} does not exist.")
|
36 |
+
|
37 |
+
model_file = os.path.join(model_dir, "punc.onnx")
|
38 |
+
if quantize:
|
39 |
+
model_file = os.path.join(model_dir, "model_quant.onnx")
|
40 |
+
config_file = os.path.join(model_dir, "punc.bin")
|
41 |
+
with open(config_file, "rb") as file:
|
42 |
+
config = pickle.load(file)
|
43 |
+
|
44 |
+
self.converter = TokenIDConverter(config["token_list"])
|
45 |
+
self.ort_infer = OrtInferSession(
|
46 |
+
model_file, device_id, intra_op_num_threads=intra_op_num_threads
|
47 |
+
)
|
48 |
+
self.batch_size = 1
|
49 |
+
self.punc_list = config["punc_list"]
|
50 |
+
self.period = 0
|
51 |
+
for i in range(len(self.punc_list)):
|
52 |
+
if self.punc_list[i] == ",":
|
53 |
+
self.punc_list[i] = ","
|
54 |
+
elif self.punc_list[i] == "?":
|
55 |
+
self.punc_list[i] = "?"
|
56 |
+
elif self.punc_list[i] == "。":
|
57 |
+
self.period = i
|
58 |
+
|
59 |
+
def __call__(self, text: Union[list, str], split_size=20):
|
60 |
+
split_text = code_mix_split_words(text)
|
61 |
+
split_text_id = self.converter.tokens2ids(split_text)
|
62 |
+
mini_sentences = split_to_mini_sentence(split_text, split_size)
|
63 |
+
mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
|
64 |
+
assert len(mini_sentences) == len(mini_sentences_id)
|
65 |
+
cache_sent = []
|
66 |
+
cache_sent_id = []
|
67 |
+
new_mini_sentence = ""
|
68 |
+
new_mini_sentence_punc = []
|
69 |
+
cache_pop_trigger_limit = 200
|
70 |
+
for mini_sentence_i in range(len(mini_sentences)):
|
71 |
+
mini_sentence = mini_sentences[mini_sentence_i]
|
72 |
+
mini_sentence_id = mini_sentences_id[mini_sentence_i]
|
73 |
+
mini_sentence = cache_sent + mini_sentence
|
74 |
+
|
75 |
+
mini_sentence_id = np.array(cache_sent_id + mini_sentence_id, dtype="int64")
|
76 |
+
text_lengths = np.array([len(mini_sentence)], dtype="int32")
|
77 |
+
|
78 |
+
data = {
|
79 |
+
"text": mini_sentence_id[None, :],
|
80 |
+
"text_lengths": text_lengths,
|
81 |
+
}
|
82 |
+
try:
|
83 |
+
outputs = self.infer(data["text"], data["text_lengths"])
|
84 |
+
y = outputs[0]
|
85 |
+
punctuations = np.argmax(y, axis=-1)[0]
|
86 |
+
assert punctuations.size == len(mini_sentence)
|
87 |
+
except ONNXRuntimeError as e:
|
88 |
+
logging.exception(e)
|
89 |
+
|
90 |
+
# Search for the last Period/QuestionMark as cache
|
91 |
+
if mini_sentence_i < len(mini_sentences) - 1:
|
92 |
+
sentenceEnd = -1
|
93 |
+
last_comma_index = -1
|
94 |
+
for i in range(len(punctuations) - 2, 1, -1):
|
95 |
+
if (
|
96 |
+
self.punc_list[punctuations[i]] == "。"
|
97 |
+
or self.punc_list[punctuations[i]] == "?"
|
98 |
+
):
|
99 |
+
sentenceEnd = i
|
100 |
+
break
|
101 |
+
if last_comma_index < 0 and self.punc_list[punctuations[i]] == ",":
|
102 |
+
last_comma_index = i
|
103 |
+
|
104 |
+
if (
|
105 |
+
sentenceEnd < 0
|
106 |
+
and len(mini_sentence) > cache_pop_trigger_limit
|
107 |
+
and last_comma_index >= 0
|
108 |
+
):
|
109 |
+
# The sentence it too long, cut off at a comma.
|
110 |
+
sentenceEnd = last_comma_index
|
111 |
+
punctuations[sentenceEnd] = self.period
|
112 |
+
cache_sent = mini_sentence[sentenceEnd + 1 :]
|
113 |
+
cache_sent_id = mini_sentence_id[sentenceEnd + 1 :].tolist()
|
114 |
+
mini_sentence = mini_sentence[0 : sentenceEnd + 1]
|
115 |
+
punctuations = punctuations[0 : sentenceEnd + 1]
|
116 |
+
|
117 |
+
new_mini_sentence_punc += [int(x) for x in punctuations]
|
118 |
+
words_with_punc = []
|
119 |
+
for i in range(len(mini_sentence)):
|
120 |
+
if i > 0:
|
121 |
+
if (
|
122 |
+
len(mini_sentence[i][0].encode()) == 1
|
123 |
+
and len(mini_sentence[i - 1][0].encode()) == 1
|
124 |
+
):
|
125 |
+
mini_sentence[i] = " " + mini_sentence[i]
|
126 |
+
words_with_punc.append(mini_sentence[i])
|
127 |
+
if self.punc_list[punctuations[i]] != "_":
|
128 |
+
words_with_punc.append(self.punc_list[punctuations[i]])
|
129 |
+
new_mini_sentence += "".join(words_with_punc)
|
130 |
+
# Add Period for the end of the sentence
|
131 |
+
new_mini_sentence_out = new_mini_sentence
|
132 |
+
new_mini_sentence_punc_out = new_mini_sentence_punc
|
133 |
+
if mini_sentence_i == len(mini_sentences) - 1:
|
134 |
+
if new_mini_sentence[-1] == "," or new_mini_sentence[-1] == "、":
|
135 |
+
new_mini_sentence_out = new_mini_sentence[:-1] + "。"
|
136 |
+
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [
|
137 |
+
self.period
|
138 |
+
]
|
139 |
+
elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "?":
|
140 |
+
new_mini_sentence_out = new_mini_sentence + "。"
|
141 |
+
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [
|
142 |
+
self.period
|
143 |
+
]
|
144 |
+
return new_mini_sentence_out, new_mini_sentence_punc_out
|
145 |
+
|
146 |
+
def infer(
|
147 |
+
self, feats: np.ndarray, feats_len: np.ndarray
|
148 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
149 |
+
outputs = self.ort_infer([feats, feats_len])
|
150 |
+
return outputs
|
151 |
+
|
152 |
+
|
153 |
+
class CT_Transformer_VadRealtime(CT_Transformer):
|
154 |
+
"""
|
155 |
+
Author: Speech Lab, Alibaba Group, China
|
156 |
+
CT-Transformer: Controllable time-delay transformer for
|
157 |
+
real-time punctuation prediction and disfluency detection
|
158 |
+
https://arxiv.org/pdf/2003.01309.pdf
|
159 |
+
"""
|
160 |
+
|
161 |
+
def __init__(
|
162 |
+
self,
|
163 |
+
model_dir: Union[str, Path] = None,
|
164 |
+
batch_size: int = 1,
|
165 |
+
device_id: Union[str, int] = "-1",
|
166 |
+
quantize: bool = False,
|
167 |
+
intra_op_num_threads: int = 4,
|
168 |
+
):
|
169 |
+
super(CT_Transformer_VadRealtime, self).__init__(
|
170 |
+
model_dir, batch_size, device_id, quantize, intra_op_num_threads
|
171 |
+
)
|
172 |
+
|
173 |
+
def __call__(self, text: str, param_dict: map, split_size=20):
|
174 |
+
cache_key = "cache"
|
175 |
+
assert cache_key in param_dict
|
176 |
+
cache = param_dict[cache_key]
|
177 |
+
if cache is not None and len(cache) > 0:
|
178 |
+
precache = "".join(cache)
|
179 |
+
else:
|
180 |
+
precache = ""
|
181 |
+
cache = []
|
182 |
+
full_text = precache + text
|
183 |
+
split_text = code_mix_split_words(full_text)
|
184 |
+
split_text_id = self.converter.tokens2ids(split_text)
|
185 |
+
mini_sentences = split_to_mini_sentence(split_text, split_size)
|
186 |
+
mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
|
187 |
+
new_mini_sentence_punc = []
|
188 |
+
assert len(mini_sentences) == len(mini_sentences_id)
|
189 |
+
|
190 |
+
cache_sent = []
|
191 |
+
cache_sent_id = np.array([], dtype="int32")
|
192 |
+
sentence_punc_list = []
|
193 |
+
sentence_words_list = []
|
194 |
+
cache_pop_trigger_limit = 200
|
195 |
+
skip_num = 0
|
196 |
+
for mini_sentence_i in range(len(mini_sentences)):
|
197 |
+
mini_sentence = mini_sentences[mini_sentence_i]
|
198 |
+
mini_sentence_id = mini_sentences_id[mini_sentence_i]
|
199 |
+
mini_sentence = cache_sent + mini_sentence
|
200 |
+
mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
|
201 |
+
text_length = len(mini_sentence_id)
|
202 |
+
data = {
|
203 |
+
"input": np.array(mini_sentence_id[None, :], dtype="int64"),
|
204 |
+
"text_lengths": np.array([text_length], dtype="int32"),
|
205 |
+
"vad_mask": self.vad_mask(text_length, len(cache))[
|
206 |
+
None, None, :, :
|
207 |
+
].astype(np.float32),
|
208 |
+
"sub_masks": np.tril(
|
209 |
+
np.ones((text_length, text_length), dtype=np.float32)
|
210 |
+
)[None, None, :, :].astype(np.float32),
|
211 |
+
}
|
212 |
+
try:
|
213 |
+
outputs = self.infer(
|
214 |
+
data["input"],
|
215 |
+
data["text_lengths"],
|
216 |
+
data["vad_mask"],
|
217 |
+
data["sub_masks"],
|
218 |
+
)
|
219 |
+
y = outputs[0]
|
220 |
+
punctuations = np.argmax(y, axis=-1)[0]
|
221 |
+
assert punctuations.size == len(mini_sentence)
|
222 |
+
except ONNXRuntimeError as e:
|
223 |
+
logging.exception(e)
|
224 |
+
|
225 |
+
# Search for the last Period/QuestionMark as cache
|
226 |
+
if mini_sentence_i < len(mini_sentences) - 1:
|
227 |
+
sentenceEnd = -1
|
228 |
+
last_comma_index = -1
|
229 |
+
for i in range(len(punctuations) - 2, 1, -1):
|
230 |
+
if (
|
231 |
+
self.punc_list[punctuations[i]] == "。"
|
232 |
+
or self.punc_list[punctuations[i]] == "?"
|
233 |
+
):
|
234 |
+
sentenceEnd = i
|
235 |
+
break
|
236 |
+
if last_comma_index < 0 and self.punc_list[punctuations[i]] == ",":
|
237 |
+
last_comma_index = i
|
238 |
+
|
239 |
+
if (
|
240 |
+
sentenceEnd < 0
|
241 |
+
and len(mini_sentence) > cache_pop_trigger_limit
|
242 |
+
and last_comma_index >= 0
|
243 |
+
):
|
244 |
+
# The sentence it too long, cut off at a comma.
|
245 |
+
sentenceEnd = last_comma_index
|
246 |
+
punctuations[sentenceEnd] = self.period
|
247 |
+
cache_sent = mini_sentence[sentenceEnd + 1 :]
|
248 |
+
cache_sent_id = mini_sentence_id[sentenceEnd + 1 :]
|
249 |
+
mini_sentence = mini_sentence[0 : sentenceEnd + 1]
|
250 |
+
punctuations = punctuations[0 : sentenceEnd + 1]
|
251 |
+
|
252 |
+
punctuations_np = [int(x) for x in punctuations]
|
253 |
+
new_mini_sentence_punc += punctuations_np
|
254 |
+
sentence_punc_list += [self.punc_list[int(x)] for x in punctuations_np]
|
255 |
+
sentence_words_list += mini_sentence
|
256 |
+
|
257 |
+
assert len(sentence_punc_list) == len(sentence_words_list)
|
258 |
+
words_with_punc = []
|
259 |
+
sentence_punc_list_out = []
|
260 |
+
for i in range(0, len(sentence_words_list)):
|
261 |
+
if i > 0:
|
262 |
+
if (
|
263 |
+
len(sentence_words_list[i][0].encode()) == 1
|
264 |
+
and len(sentence_words_list[i - 1][-1].encode()) == 1
|
265 |
+
):
|
266 |
+
sentence_words_list[i] = " " + sentence_words_list[i]
|
267 |
+
if skip_num < len(cache):
|
268 |
+
skip_num += 1
|
269 |
+
else:
|
270 |
+
words_with_punc.append(sentence_words_list[i])
|
271 |
+
if skip_num >= len(cache):
|
272 |
+
sentence_punc_list_out.append(sentence_punc_list[i])
|
273 |
+
if sentence_punc_list[i] != "_":
|
274 |
+
words_with_punc.append(sentence_punc_list[i])
|
275 |
+
sentence_out = "".join(words_with_punc)
|
276 |
+
|
277 |
+
sentenceEnd = -1
|
278 |
+
for i in range(len(sentence_punc_list) - 2, 1, -1):
|
279 |
+
if sentence_punc_list[i] == "。" or sentence_punc_list[i] == "?":
|
280 |
+
sentenceEnd = i
|
281 |
+
break
|
282 |
+
cache_out = sentence_words_list[sentenceEnd + 1 :]
|
283 |
+
if sentence_out[-1] in self.punc_list:
|
284 |
+
sentence_out = sentence_out[:-1]
|
285 |
+
sentence_punc_list_out[-1] = "_"
|
286 |
+
param_dict[cache_key] = cache_out
|
287 |
+
return sentence_out, sentence_punc_list_out, cache_out
|
288 |
+
|
289 |
+
def vad_mask(self, size, vad_pos, dtype=np.bool_):
|
290 |
+
"""Create mask for decoder self-attention.
|
291 |
+
|
292 |
+
:param int size: size of mask
|
293 |
+
:param int vad_pos: index of vad index
|
294 |
+
:param torch.dtype dtype: result dtype
|
295 |
+
:rtype: torch.Tensor (B, Lmax, Lmax)
|
296 |
+
"""
|
297 |
+
ret = np.ones((size, size), dtype=dtype)
|
298 |
+
if vad_pos <= 0 or vad_pos >= size:
|
299 |
+
return ret
|
300 |
+
sub_corner = np.zeros((vad_pos - 1, size - vad_pos), dtype=dtype)
|
301 |
+
ret[0 : vad_pos - 1, vad_pos:] = sub_corner
|
302 |
+
return ret
|
303 |
+
|
304 |
+
def infer(
|
305 |
+
self,
|
306 |
+
feats: np.ndarray,
|
307 |
+
feats_len: np.ndarray,
|
308 |
+
vad_mask: np.ndarray,
|
309 |
+
sub_masks: np.ndarray,
|
310 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
311 |
+
outputs = self.ort_infer([feats, feats_len, vad_mask, sub_masks])
|
312 |
+
return outputs
|
cttpunctuator/src/utils/OrtInferSession.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding:utf-8 -*-
|
2 |
+
# @FileName :OrtInferSession.py
|
3 |
+
# @Time :2023/4/13 15:13
|
4 |
+
# @Author :lovemefan
|
5 |
+
# @Email :[email protected]
|
6 |
+
import logging
|
7 |
+
from pathlib import Path
|
8 |
+
from typing import List, Union
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
from onnxruntime import (
|
12 |
+
GraphOptimizationLevel,
|
13 |
+
InferenceSession,
|
14 |
+
SessionOptions,
|
15 |
+
get_available_providers,
|
16 |
+
get_device,
|
17 |
+
)
|
18 |
+
|
19 |
+
|
20 |
+
class ONNXRuntimeError(Exception):
|
21 |
+
pass
|
22 |
+
|
23 |
+
|
24 |
+
class OrtInferSession:
|
25 |
+
def __init__(self, model_file, device_id=-1, intra_op_num_threads=4):
|
26 |
+
device_id = str(device_id)
|
27 |
+
sess_opt = SessionOptions()
|
28 |
+
sess_opt.intra_op_num_threads = intra_op_num_threads
|
29 |
+
sess_opt.log_severity_level = 4
|
30 |
+
sess_opt.enable_cpu_mem_arena = False
|
31 |
+
sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
|
32 |
+
|
33 |
+
cuda_ep = "CUDAExecutionProvider"
|
34 |
+
cuda_provider_options = {
|
35 |
+
"device_id": device_id,
|
36 |
+
"arena_extend_strategy": "kNextPowerOfTwo",
|
37 |
+
"cudnn_conv_algo_search": "EXHAUSTIVE",
|
38 |
+
"do_copy_in_default_stream": "true",
|
39 |
+
}
|
40 |
+
cpu_ep = "CPUExecutionProvider"
|
41 |
+
cpu_provider_options = {
|
42 |
+
"arena_extend_strategy": "kSameAsRequested",
|
43 |
+
}
|
44 |
+
|
45 |
+
EP_list = []
|
46 |
+
if (
|
47 |
+
device_id != "-1"
|
48 |
+
and get_device() == "GPU"
|
49 |
+
and cuda_ep in get_available_providers()
|
50 |
+
):
|
51 |
+
EP_list = [(cuda_ep, cuda_provider_options)]
|
52 |
+
EP_list.append((cpu_ep, cpu_provider_options))
|
53 |
+
|
54 |
+
self._verify_model(model_file)
|
55 |
+
self.session = InferenceSession(
|
56 |
+
model_file, sess_options=sess_opt, providers=EP_list
|
57 |
+
)
|
58 |
+
|
59 |
+
if device_id != "-1" and cuda_ep not in self.session.get_providers():
|
60 |
+
logging.warnings.warn(
|
61 |
+
f"{cuda_ep} is not avaiable for current env, "
|
62 |
+
f"the inference part is automatically shifted to be executed under {cpu_ep}.\n"
|
63 |
+
"Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, "
|
64 |
+
"you can check their relations from the offical web site: "
|
65 |
+
"https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html",
|
66 |
+
RuntimeWarning,
|
67 |
+
)
|
68 |
+
|
69 |
+
def __call__(
|
70 |
+
self, input_content: List[Union[np.ndarray, np.ndarray]]
|
71 |
+
) -> np.ndarray:
|
72 |
+
input_dict = dict(zip(self.get_input_names(), input_content))
|
73 |
+
try:
|
74 |
+
return self.session.run(self.get_output_names(), input_dict)
|
75 |
+
except Exception as e:
|
76 |
+
raise ONNXRuntimeError("ONNXRuntime inferece failed.") from e
|
77 |
+
|
78 |
+
def get_input_names(
|
79 |
+
self,
|
80 |
+
):
|
81 |
+
return [v.name for v in self.session.get_inputs()]
|
82 |
+
|
83 |
+
def get_output_names(
|
84 |
+
self,
|
85 |
+
):
|
86 |
+
return [v.name for v in self.session.get_outputs()]
|
87 |
+
|
88 |
+
def get_character_list(self, key: str = "character"):
|
89 |
+
return self.meta_dict[key].splitlines()
|
90 |
+
|
91 |
+
def have_key(self, key: str = "character") -> bool:
|
92 |
+
self.meta_dict = self.session.get_modelmeta().custom_metadata_map
|
93 |
+
if key in self.meta_dict.keys():
|
94 |
+
return True
|
95 |
+
return False
|
96 |
+
|
97 |
+
@staticmethod
|
98 |
+
def _verify_model(model_path):
|
99 |
+
model_path = Path(model_path)
|
100 |
+
if not model_path.exists():
|
101 |
+
raise FileNotFoundError(f"{model_path} does not exists.")
|
102 |
+
if not model_path.is_file():
|
103 |
+
raise FileExistsError(f"{model_path} is not a file.")
|
cttpunctuator/src/utils/__pycache__/OrtInferSession.cpython-310.pyc
ADDED
Binary file (3.82 kB). View file
|
|
cttpunctuator/src/utils/__pycache__/text_post_process.cpython-310.pyc
ADDED
Binary file (3.26 kB). View file
|
|
cttpunctuator/src/utils/text_post_process.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding:utf-8 -*-
|
2 |
+
# @FileName :text_post_process.py
|
3 |
+
# @Time :2023/4/13 15:09
|
4 |
+
# @Author :lovemefan
|
5 |
+
# @Email :[email protected]
|
6 |
+
from pathlib import Path
|
7 |
+
from typing import Dict, Iterable, List, Union
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import yaml
|
11 |
+
from typeguard import check_argument_types
|
12 |
+
|
13 |
+
|
14 |
+
class TokenIDConverterError(Exception):
|
15 |
+
pass
|
16 |
+
|
17 |
+
|
18 |
+
class TokenIDConverter:
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
token_list: Union[List, str],
|
22 |
+
):
|
23 |
+
check_argument_types()
|
24 |
+
|
25 |
+
self.token_list = token_list
|
26 |
+
self.unk_symbol = token_list[-1]
|
27 |
+
self.token2id = {v: i for i, v in enumerate(self.token_list)}
|
28 |
+
self.unk_id = self.token2id[self.unk_symbol]
|
29 |
+
|
30 |
+
def get_num_vocabulary_size(self) -> int:
|
31 |
+
return len(self.token_list)
|
32 |
+
|
33 |
+
def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
|
34 |
+
if isinstance(integers, np.ndarray) and integers.ndim != 1:
|
35 |
+
raise TokenIDConverterError(
|
36 |
+
f"Must be 1 dim ndarray, but got {integers.ndim}"
|
37 |
+
)
|
38 |
+
return [self.token_list[i] for i in integers]
|
39 |
+
|
40 |
+
def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
|
41 |
+
return [self.token2id.get(i, self.unk_id) for i in tokens]
|
42 |
+
|
43 |
+
|
44 |
+
def split_to_mini_sentence(words: list, word_limit: int = 20):
|
45 |
+
assert word_limit > 1
|
46 |
+
if len(words) <= word_limit:
|
47 |
+
return [words]
|
48 |
+
sentences = []
|
49 |
+
length = len(words)
|
50 |
+
sentence_len = length // word_limit
|
51 |
+
for i in range(sentence_len):
|
52 |
+
sentences.append(words[i * word_limit : (i + 1) * word_limit])
|
53 |
+
if length % word_limit > 0:
|
54 |
+
sentences.append(words[sentence_len * word_limit :])
|
55 |
+
return sentences
|
56 |
+
|
57 |
+
|
58 |
+
def code_mix_split_words(text: str):
|
59 |
+
words = []
|
60 |
+
segs = text.split()
|
61 |
+
for seg in segs:
|
62 |
+
# There is no space in seg.
|
63 |
+
current_word = ""
|
64 |
+
for c in seg:
|
65 |
+
if len(c.encode()) == 1:
|
66 |
+
# This is an ASCII char.
|
67 |
+
current_word += c
|
68 |
+
else:
|
69 |
+
# This is a Chinese char.
|
70 |
+
if len(current_word) > 0:
|
71 |
+
words.append(current_word)
|
72 |
+
current_word = ""
|
73 |
+
words.append(c)
|
74 |
+
if len(current_word) > 0:
|
75 |
+
words.append(current_word)
|
76 |
+
return words
|
77 |
+
|
78 |
+
|
79 |
+
def read_yaml(yaml_path: Union[str, Path]) -> Dict:
|
80 |
+
if not Path(yaml_path).exists():
|
81 |
+
raise FileExistsError(f"The {yaml_path} does not exist.")
|
82 |
+
|
83 |
+
with open(str(yaml_path), "rb") as f:
|
84 |
+
data = yaml.load(f, Loader=yaml.Loader)
|
85 |
+
return data
|
requirements.txt
ADDED
Binary file (5.01 kB). View file
|
|