Spaces:
Runtime error
Runtime error
Upload auto_rerank.py
Browse files- tools/auto_rerank.py +159 -0
tools/auto_rerank.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
os.environ["MODELSCOPE_CACHE"] = ".cache/"
|
4 |
+
|
5 |
+
import string
|
6 |
+
import time
|
7 |
+
from threading import Lock
|
8 |
+
|
9 |
+
import librosa
|
10 |
+
import numpy as np
|
11 |
+
import opencc
|
12 |
+
import torch
|
13 |
+
from faster_whisper import WhisperModel
|
14 |
+
|
15 |
+
# OpenCC converter: Traditional -> Simplified Chinese, applied to every transcript.
t2s_converter = opencc.OpenCC("t2s")
|
16 |
+
|
17 |
+
|
18 |
+
def load_model(*, device="cuda"):
    """Load the medium faster-whisper ASR model.

    Args:
        device: Device string handed to ``WhisperModel`` (keyword-only,
            defaults to ``"cuda"``).

    Returns:
        The loaded ``WhisperModel`` instance (float16 compute, weights
        cached under ``faster_whisper/``).
    """
    asr_model = WhisperModel(
        "medium",
        device=device,
        compute_type="float16",
        download_root="faster_whisper",
    )
    print("faster_whisper loaded!")
    return asr_model
|
27 |
+
|
28 |
+
|
29 |
+
@torch.no_grad()
def batch_asr_internal(model: WhisperModel, audios, sr):
    """Transcribe a batch of audio clips with faster-whisper.

    Args:
        model: A loaded ``WhisperModel``.
        audios: Sequence of mono waveforms (``np.ndarray`` or
            ``torch.Tensor``), all sampled at ``sr``.
        sr: Sample rate shared by every clip in ``audios``.

    Returns:
        One dict per clip with keys:
            ``text``     – transcript converted to Simplified Chinese,
            ``duration`` – clip length in milliseconds,
            ``huge_gap`` – True when two consecutive segments are more
                           than 3 s apart (transcription is cut short then).
    """
    # Whisper expects 16 kHz input, so resample every clip first.
    resampled_audios = []
    for audio in audios:
        if isinstance(audio, np.ndarray):
            audio = torch.from_numpy(audio).float()
        if audio.dim() > 1:
            audio = audio.squeeze()
        assert audio.dim() == 1
        resampled_audios.append(
            librosa.resample(audio.numpy(), orig_sr=sr, target_sr=16000)
        )

    trans_results = []
    for resampled_audio in resampled_audios:
        # language=None lets Whisper auto-detect; the prompt nudges the
        # decoder toward emitting punctuation.
        segments, _info = model.transcribe(
            resampled_audio,
            language=None,
            beam_size=5,
            initial_prompt="Punctuation is needed in any language.",
        )
        trans_results.append(list(segments))

    results = []
    for trans_res, audio in zip(trans_results, audios):
        duration = len(audio) / sr * 1000
        huge_gap = False
        max_gap = 0.0

        # BUGFIX: was initialised to None, which crashed the OpenCC
        # conversion below whenever a clip produced zero segments.
        text = ""
        last_tr = None

        for tr in trans_res:
            delta = tr.text.strip()
            if tr.id > 1:
                # Track the largest silence between consecutive segments.
                max_gap = max(tr.start - last_tr.end, max_gap)
                text += delta
            else:
                text = delta

            last_tr = tr
            if max_gap > 3.0:
                huge_gap = True
                break

        sim_text = t2s_converter.convert(text)
        results.append(
            {
                "text": sim_text,
                "duration": duration,
                "huge_gap": huge_gap,
            }
        )
    return results
|
89 |
+
|
90 |
+
|
91 |
+
# Lock intended to serialize concurrent ASR calls.
# NOTE(review): currently unused — nothing in this file ever acquires it.
global_lock = Lock()
|
92 |
+
|
93 |
+
|
94 |
+
def batch_asr(model, audios, sr):
    """Public entry point: transcribe a batch of clips (see batch_asr_internal)."""
    # NOTE(review): the module-level `global_lock` is never acquired here —
    # presumably this wrapper was meant to serialize ASR calls; confirm
    # whether the call should be guarded by `with global_lock:`.
    return batch_asr_internal(model, audios, sr)
|
96 |
+
|
97 |
+
|
98 |
+
def is_chinese(text):
    """Stub language check that always reports *text* as Chinese.

    NOTE(review): returns True unconditionally, so callers treat every
    transcript as Chinese — replace with a real detector if mixed-language
    input matters.
    """
    return True
|
100 |
+
|
101 |
+
|
102 |
+
def calculate_wer(text1, text2, debug=False):
    """Character-level error rate between two strings, ignoring punctuation.

    Both inputs are stripped of ASCII/CJK punctuation and whitespace, then a
    two-row Levenshtein DP computes the edit distance over characters (so
    this is effectively a CER), normalised by the longer string's length.

    Args:
        text1: Reference string.
        text2: Hypothesis string.
        debug: When True, print the cleaned strings and the edit count.

    Returns:
        float: edits / max(len1, len2); 0.0 when both strings are empty
        after punctuation removal (previously raised ZeroDivisionError).
    """
    chars1 = remove_punctuation(text1)
    chars2 = remove_punctuation(text2)

    m, n = len(chars1), len(chars2)

    # BUGFIX: two empty (or punctuation-only) inputs divided by zero;
    # identical empty strings have zero error by definition.
    if m == 0 and n == 0:
        return 0.0

    # Keep chars1 as the shorter string so the DP rows stay small.
    if m > n:
        chars1, chars2 = chars2, chars1
        m, n = n, m

    prev = list(range(m + 1))  # row 0 distance: [0, 1, 2, ...]
    curr = [0] * (m + 1)

    for j in range(1, n + 1):
        curr[0] = j
        for i in range(1, m + 1):
            if chars1[i - 1] == chars2[j - 1]:
                curr[i] = prev[i - 1]
            else:
                curr[i] = min(prev[i], curr[i - 1], prev[i - 1]) + 1
        prev, curr = curr, prev

    edits = prev[m]
    tot = max(len(chars1), len(chars2))
    wer = edits / tot

    if debug:
        print(" gt: ", chars1)
        print(" pred: ", chars2)
        print(" edits/tot = wer: ", edits, "/", tot, "=", wer)

    return wer
|
134 |
+
|
135 |
+
|
136 |
+
def remove_punctuation(text):
    """Return *text* with whitespace, ASCII and CJK punctuation removed."""
    cjk_marks = (
        " \n\t”“!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—"
        '‛""„‟…‧﹏'
    )
    # Map every punctuation character (ASCII + CJK) to None, i.e. delete it.
    strip_table = str.maketrans(dict.fromkeys(string.punctuation + cjk_marks))
    return text.translate(strip_table)
|
145 |
+
|
146 |
+
|
147 |
+
if __name__ == "__main__":
    # Smoke test: transcribe two local 44.1 kHz wav files once, then time
    # ten repeated batched ASR passes over the same clips.
    model = load_model()
    audios = [
        librosa.load("44100.wav", sr=44100)[0],
        librosa.load("lengyue.wav", sr=44100)[0],
    ]
    print(np.array(audios[0]))
    print(batch_asr(model, audios, 44100))

    start_time = time.time()
    for _ in range(10):
        print(batch_asr(model, audios, 44100))
    print("Time taken:", time.time() - start_time)
|