uto1125 committed on
Commit
db755d4
·
verified ·
1 Parent(s): 8b8629f

Upload auto_rerank.py

Browse files
Files changed (1) hide show
  1. tools/auto_rerank.py +159 -0
tools/auto_rerank.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ os.environ["MODELSCOPE_CACHE"] = ".cache/"
4
+
5
+ import string
6
+ import time
7
+ from threading import Lock
8
+
9
+ import librosa
10
+ import numpy as np
11
+ import opencc
12
+ import torch
13
+ from faster_whisper import WhisperModel
14
+
15
# OpenCC converter: Traditional -> Simplified Chinese ("t2s" profile).
t2s_converter = opencc.OpenCC("t2s")
16
+
17
+
18
def load_model(*, device="cuda"):
    """Load the "medium" faster-whisper ASR model.

    Keyword-only `device` defaults to "cuda". Weights are downloaded to /
    cached under ./faster_whisper, and float16 compute is used.
    """
    whisper = WhisperModel(
        "medium", device=device, compute_type="float16", download_root="faster_whisper"
    )
    print("faster_whisper loaded!")
    return whisper
27
+
28
+
29
@torch.no_grad()
def batch_asr_internal(model: WhisperModel, audios, sr):
    """Transcribe a batch of audio clips and flag suspicious silences.

    Args:
        model: loaded faster-whisper model (see load_model).
        audios: list of 1-D waveforms (np.ndarray or torch.Tensor), all at
            sample rate `sr`.
        sr: sample rate of every clip in `audios`.

    Returns:
        One dict per clip:
        {"text": simplified-Chinese transcript,
         "duration": clip length in milliseconds,
         "huge_gap": True when two adjacent segments are more than 3 s apart}.
    """
    # Whisper expects 16 kHz mono input.
    resampled_audios = []
    for audio in audios:
        if isinstance(audio, np.ndarray):
            audio = torch.from_numpy(audio).float()
        if audio.dim() > 1:
            audio = audio.squeeze()
        assert audio.dim() == 1
        resampled_audios.append(
            librosa.resample(audio.numpy(), orig_sr=sr, target_sr=16000)
        )

    trans_results = []
    for resampled_audio in resampled_audios:
        segments, _info = model.transcribe(
            resampled_audio,
            language=None,  # let Whisper auto-detect the language
            beam_size=5,
            initial_prompt="Punctuation is needed in any language.",
        )
        trans_results.append(list(segments))

    results = []
    for trans_res, audio in zip(trans_results, audios):
        duration = len(audio) / sr * 1000  # clip length in milliseconds
        huge_gap = False
        max_gap = 0.0
        # Fix: start from "" so clips with no recognized speech no longer
        # hand None to the OpenCC converter (which would raise). Gap
        # tracking keys off last_tr instead of tr.id, so it no longer
        # assumes segment ids start at 1.
        text = ""
        last_tr = None

        for tr in trans_res:
            delta = tr.text.strip()
            if last_tr is not None:
                # Largest silence between two adjacent segments so far.
                max_gap = max(tr.start - last_tr.end, max_gap)
            text += delta
            last_tr = tr
            if max_gap > 3.0:
                huge_gap = True
                break

        results.append(
            {
                "text": t2s_converter.convert(text),  # Traditional -> Simplified
                "duration": duration,
                "huge_gap": huge_gap,
            }
        )

    return results
89
+
90
+
91
# Module-level lock for callers that need to serialize ASR access.
global_lock = Lock()
92
+
93
+
94
def batch_asr(model, audios, sr):
    """Thread-safe front door for batch_asr_internal.

    Acquires `global_lock` so concurrent callers do not drive the shared
    Whisper model at the same time (the lock was defined above but never
    actually taken).  Arguments and return value are those of
    batch_asr_internal.
    """
    with global_lock:
        return batch_asr_internal(model, audios, sr)
96
+
97
+
98
def is_chinese(text):
    # NOTE(review): stub — unconditionally reports True and ignores `text`.
    # Presumably a placeholder for real CJK detection; confirm call sites
    # tolerate a genuine implementation before changing the return value.
    return True
100
+
101
+
102
def calculate_wer(text1, text2, debug=False):
    """Character error rate between two strings.

    Punctuation and whitespace are stripped first (remove_punctuation), then
    the Levenshtein distance is divided by the longer length.  Returns 0.0
    when both strings are empty after stripping (previously this raised
    ZeroDivisionError).

    Args:
        text1: ground-truth text.
        text2: predicted text.
        debug: when True, print the compared strings and the edit counts.
    """
    gt = remove_punctuation(text1)
    pred = remove_punctuation(text2)

    # Roll the DP over the shorter string so the rows stay small.
    chars1, chars2 = gt, pred
    m, n = len(chars1), len(chars2)
    if m > n:
        chars1, chars2 = chars2, chars1
        m, n = n, m

    prev = list(range(m + 1))  # row 0 distance: [0, 1, 2, ...]
    curr = [0] * (m + 1)

    for j in range(1, n + 1):
        curr[0] = j
        for i in range(1, m + 1):
            if chars1[i - 1] == chars2[j - 1]:
                curr[i] = prev[i - 1]
            else:
                curr[i] = min(prev[i], curr[i - 1], prev[i - 1]) + 1
        prev, curr = curr, prev

    edits = prev[m]
    tot = max(len(chars1), len(chars2))
    wer = edits / tot if tot else 0.0  # both empty -> identical

    if debug:
        # Fix: print the original gt/pred, not the possibly-swapped DP
        # operands, so the labels are always accurate.
        print("  gt: ", gt)
        print("  pred: ", pred)
        print("  edits/tot = wer: ", edits, "/", tot, "=", wer)

    return wer
134
+
135
+
136
def remove_punctuation(text):
    """Strip ASCII punctuation, CJK punctuation, and whitespace from `text`.

    Used to normalize transcripts before WER comparison.
    """
    # CJK punctuation, fullwidth forms, and whitespace; duplicates of
    # string.punctuation below are harmless to str.maketrans.  The "]" is
    # escaped properly — the original "\]" was an invalid escape sequence
    # (DeprecationWarning today, a SyntaxError in future Python).
    chinese_punctuation = (
        " \n\t”“!?。。\"#$%&'()*+,-/:;<=>@[\\]^_`{|}~⦅⦆「」、、〃《》「」『』【】"
        "〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‛“”„‟…‧﹏"
    )
    all_punctuation = string.punctuation + chinese_punctuation
    # Single C-level pass deletes every listed character.
    translator = str.maketrans("", "", all_punctuation)
    return text.translate(translator)
145
+
146
+
147
if __name__ == "__main__":
    # Smoke test: transcribe two local wav files once, then time 10 repeats.
    model = load_model()
    audios = [
        librosa.load("44100.wav", sr=44100)[0],
        librosa.load("lengyue.wav", sr=44100)[0],
    ]
    print(np.array(audios[0]))
    print(batch_asr(model, audios, 44100))

    t0 = time.time()
    for _ in range(10):
        print(batch_asr(model, audios, 44100))
    print("Time taken:", time.time() - t0)