terry-li-hm committed on
Commit a76b03e · 1 Parent(s): a08029c
Files changed (1)
  1. sv.py +357 -0
sv.py ADDED
@@ -0,0 +1,357 @@
+ import datetime
+ import math
+ import os
+
+ import numpy as np
+ import torch
+ import torchaudio
+ from funasr import AutoModel
+ from pyannote.audio import Audio, Pipeline
+ from pyannote.core import Segment
+
+ # Load models
+ model = AutoModel(
+     model="FunAudioLLM/SenseVoiceSmall",
+     # vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+     # vad_kwargs={"max_single_segment_time": 30000},
+     hub="hf",
+     device="cuda" if torch.cuda.is_available() else "cpu",
+ )
+
+ pyannote_pipeline = Pipeline.from_pretrained(
+     "pyannote/speaker-diarization-3.1", use_auth_token=os.getenv("HF_TOKEN")
+ )
+ if torch.cuda.is_available():
+     pyannote_pipeline.to(torch.device("cuda"))
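+ # NOTE: generate_diarization() below loads its own copy of this pipeline, so this
+ # module-level instance is not referenced again.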
+
+ # Emoji dictionaries and formatting functions
+ emo_dict = {
+     "<|HAPPY|>": "😊",
+     "<|SAD|>": "😔",
+     "<|ANGRY|>": "😡",
+     "<|NEUTRAL|>": "",
+     "<|FEARFUL|>": "😰",
+     "<|DISGUSTED|>": "🤢",
+     "<|SURPRISED|>": "😮",
+ }
+
+ event_dict = {
+     "<|BGM|>": "🎼",
+     "<|Speech|>": "",
+     "<|Applause|>": "👏",
+     "<|Laughter|>": "😀",
+     "<|Cry|>": "😭",
+     "<|Sneeze|>": "🤧",
+     "<|Breath|>": "",
+     "<|Cough|>": "🤧",
+ }
+
+ emoji_dict = {
+     "<|nospeech|><|Event_UNK|>": "❓",
+     "<|zh|>": "",
+     "<|en|>": "",
+     "<|yue|>": "",
+     "<|ja|>": "",
+     "<|ko|>": "",
+     "<|nospeech|>": "",
+     "<|HAPPY|>": "😊",
+     "<|SAD|>": "😔",
+     "<|ANGRY|>": "😡",
+     "<|NEUTRAL|>": "",
+     "<|BGM|>": "🎼",
+     "<|Speech|>": "",
+     "<|Applause|>": "👏",
+     "<|Laughter|>": "😀",
+     "<|FEARFUL|>": "😰",
+     "<|DISGUSTED|>": "🤢",
+     "<|SURPRISED|>": "😮",
+     "<|Cry|>": "😭",
+     "<|EMO_UNKNOWN|>": "",
+     "<|Sneeze|>": "🤧",
+     "<|Breath|>": "",
+     "<|Cough|>": "😷",
+     "<|Sing|>": "",
+     "<|Speech_Noise|>": "",
+     "<|withitn|>": "",
+     "<|woitn|>": "",
+     "<|GBG|>": "",
+     "<|Event_UNK|>": "",
+ }
+
+ lang_dict = {
+     "<|zh|>": "<|lang|>",
+     "<|en|>": "<|lang|>",
+     "<|yue|>": "<|lang|>",
+     "<|ja|>": "<|lang|>",
+     "<|ko|>": "<|lang|>",
+     "<|nospeech|>": "<|lang|>",
+ }
+
+ emo_set = {"😊", "😔", "😡", "😰", "🤢", "😮"}
+ event_set = {"🎼", "👏", "😀", "😭", "🤧", "😷"}
+
+
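+ # Helpers for cleaning up SenseVoice output: they replace the model's special
+ # tokens (emotion, event, and language tags) with emoji or strip them entirely.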
+ def format_str(s):
+     for sptk in emoji_dict:
+         s = s.replace(sptk, emoji_dict[sptk])
+     return s
+
+
+ def format_str_v2(s):
+     sptk_dict = {}
+     for sptk in emoji_dict:
+         sptk_dict[sptk] = s.count(sptk)
+         s = s.replace(sptk, "")
+     emo = "<|NEUTRAL|>"
+     for e in emo_dict:
+         if sptk_dict[e] > sptk_dict[emo]:
+             emo = e
+     for e in event_dict:
+         if sptk_dict[e] > 0:
+             s = event_dict[e] + s
+     s = s + emo_dict[emo]
+
+     for emoji in emo_set.union(event_set):
+         s = s.replace(" " + emoji, emoji)
+         s = s.replace(emoji + " ", emoji)
+     return s.strip()
+
+
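+ # format_str_v3 splits the transcript on language tags, formats each chunk with
+ # format_str_v2, and merges the chunks while dropping emotion/event emoji that
+ # repeat across adjacent chunks.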
+ def format_str_v3(s):
+     def get_emo(s):
+         return s[-1] if s[-1] in emo_set else None
+
+     def get_event(s):
+         return s[0] if s[0] in event_set else None
+
+     s = s.replace("<|nospeech|><|Event_UNK|>", "❓")
+     for lang in lang_dict:
+         s = s.replace(lang, "<|lang|>")
+     s_list = [format_str_v2(s_i).strip(" ") for s_i in s.split("<|lang|>")]
+     new_s = " " + s_list[0]
+     cur_ent_event = get_event(new_s)
+     for i in range(1, len(s_list)):
+         if len(s_list[i]) == 0:
+             continue
+         if get_event(s_list[i]) == cur_ent_event and get_event(s_list[i]) is not None:
+             s_list[i] = s_list[i][1:]
+         # else:
+         cur_ent_event = get_event(s_list[i])
+         if get_emo(s_list[i]) is not None and get_emo(s_list[i]) == get_emo(new_s):
+             new_s = new_s[:-1]
+         new_s += s_list[i].strip().lstrip()
+     new_s = new_s.replace("The.", " ")
+     return new_s.strip()
+
+
+ def time_to_seconds(time_str):
+     h, m, s = time_str.split(":")
+     return round(int(h) * 3600 + int(m) * 60 + float(s), 9)
+
+
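+ # parse_time accepts "05.123s"-, "MM:SS.sss"-, or "HH:MM:SS.sss"-style timestamps
+ # and returns the value in seconds as a float.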
+ def parse_time(time_str):
+     # Remove 's' if present at the end of the string
+     time_str = time_str.rstrip("s")
+
+     # Split the time string into hours, minutes, and seconds
+     parts = time_str.split(":")
+
+     if len(parts) == 3:
+         h, m, s = parts
+     elif len(parts) == 2:
+         h = "0"
+         m, s = parts
+     else:
+         h = m = "0"
+         s = parts[0]
+
+     return int(h) * 3600 + int(m) * 60 + float(s)
+
+
+ def format_time(seconds, use_short_format=True):
+     if isinstance(seconds, datetime.timedelta):
+         seconds = seconds.total_seconds()
+
+     minutes, seconds = divmod(seconds, 60)
+     hours, minutes = divmod(int(minutes), 60)
+
+     # In the short format only the seconds within the current minute are shown,
+     # so values of one minute or more lose their minute/hour components.
+     if use_short_format or (hours == 0 and minutes == 0):
+         return f"{seconds:05.3f}s"
+     elif hours == 0:
+         return f"{minutes:02d}:{seconds:06.3f}"
+     else:
+         return f"{hours:02d}:{minutes:02d}:{seconds:06.3f}"
+
+
+ def format_time_with_leading_zeros(seconds):
+     formatted = f"{seconds:06.3f}s"
+     print(f"Debug: Input seconds: {seconds}, Formatted output: {formatted}")
+     return formatted
+
+
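+ # Run pyannote speaker diarization on the input file, write a human-readable
+ # list of speaker turns to mtr_dn.txt, and return a list of
+ # (start, end, duration, speaker) tuples in seconds.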
+ def generate_diarization(audio_path):
+     # Get the Hugging Face token from the environment variable
+     hf_token = os.environ.get("HF_TOKEN")
+     if not hf_token:
+         raise ValueError(
+             "HF_TOKEN environment variable is not set. Please set it with your Hugging Face token."
+         )
+
+     # Initialize the audio processor
+     audio = Audio(sample_rate=16000, mono=True)
+
+     # Load the pretrained pipeline
+     pipeline = Pipeline.from_pretrained(
+         "pyannote/speaker-diarization-3.1", use_auth_token=hf_token
+     )
+
+     # Send pipeline to GPU if available
+     if torch.cuda.is_available():
+         pipeline.to(torch.device("cuda"))
+
+     # Set the correct path for the audio file
+     script_dir = os.path.dirname(os.path.abspath(__file__))
+     possible_paths = [
+         os.path.join(script_dir, "example", "mtr.mp3"),
+         os.path.join(script_dir, "..", "example", "mtr.mp3"),
+         os.path.join(script_dir, "mtr.mp3"),
+         "mtr.mp3",
+         audio_path,  # Add the provided audio_path to the list of possible paths
+     ]
+
+     file_path = None
+     for path in possible_paths:
+         if os.path.exists(path):
+             file_path = path
+             break
+
+     if file_path is None:
+         print("Debugging information:")
+         print(f"Current working directory: {os.getcwd()}")
+         print(f"Script directory: {script_dir}")
+         print("Attempted paths:")
+         for path in possible_paths:
+             print(f"  {path}")
+         raise FileNotFoundError(
+             "Could not find the audio file. Please ensure it's in the correct location."
+         )
+
+     print(f"Using audio file: {file_path}")
+
+     # Process the audio file
+     waveform, sample_rate = audio(file_path)
+
+     # Create a dictionary with the audio information
+     file = {"waveform": waveform, "sample_rate": sample_rate, "uri": "mtr"}
+
+     # Run the diarization
+     output = pipeline(file)
+
+     # Save results in human-readable format
+     diarization_segments = []
+     txt_file = "mtr_dn.txt"
+     with open(txt_file, "w") as f:
+         for turn, _, speaker in output.itertracks(yield_label=True):
+             start_time = format_time(turn.start)
+             end_time = format_time(turn.end)
+             duration = format_time(turn.end - turn.start)
+             line = f"{start_time} - {end_time} ({duration}): {speaker}\n"
+             f.write(line)
+             print(line.strip())
+             # Keep the raw values in seconds; round-tripping through
+             # format_time()/parse_time() would drop the minutes for any turn
+             # that starts or ends past the one-minute mark.
+             diarization_segments.append(
+                 (turn.start, turn.end, turn.end - turn.start, speaker)
+             )
+
+     print(f"\nHuman-readable diarization results saved to {txt_file}")
+     return diarization_segments
+
+
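+ # Slice the audio according to the diarization segments and transcribe each
+ # chunk with SenseVoice, returning one "start - end (duration) Speaker N: text"
+ # line per segment.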
+ def process_audio(audio_path, language="yue", fs=16000):
+     # Generate diarization segments
+     diarization_segments = generate_diarization(audio_path)
+
+     # Load and preprocess audio
+     waveform, sample_rate = torchaudio.load(audio_path)
+     if sample_rate != fs:
+         resampler = torchaudio.transforms.Resample(sample_rate, fs)
+         waveform = resampler(waveform)
+
+     input_wav = waveform.mean(0).numpy()
+
+     # Determine if the audio is less than one minute (use_short_format is currently unused below)
+     total_duration = sum(duration for _, _, duration, _ in diarization_segments)
+     use_short_format = total_duration < 60
+
+     # Process the audio in chunks based on diarization segments
+     results = []
+     for start_time, end_time, duration, speaker in diarization_segments:
+         start_seconds = start_time
+         end_seconds = end_time
+
+         # Convert time to sample indices
+         start_sample = int(start_seconds * fs)
+         end_sample = int(end_seconds * fs)
+
+         chunk = input_wav[start_sample:end_sample]
+         try:
+             text = model.generate(
+                 input=chunk,
+                 cache={},
+                 language=language,
+                 use_itn=True,
+                 batch_size_s=500,
+                 merge_vad=True,
+             )
+             text = text[0]["text"]
+             text = format_str_v3(text)
+
+             # Handle empty transcriptions
+             if not text.strip():
+                 text = "[inaudible]"
+
+             results.append((speaker, start_time, end_time, duration, text))
+         except AssertionError as e:
+             if "choose a window size" in str(e):
+                 print(
+                     f"Warning: Audio segment too short to process. Skipping. Error: {e}"
+                 )
+                 results.append((speaker, start_time, end_time, duration, "[too short]"))
+             else:
+                 raise
+
+     # Format the results
+     formatted_text = ""
+     for speaker, start, end, duration, text in results:
+         start_str = format_time_with_leading_zeros(start)
+         end_str = format_time_with_leading_zeros(end)
+         duration_str = format_time_with_leading_zeros(duration)
+         # Map pyannote speaker labels to "Speaker 1"/"Speaker 2" (assumes at most two speakers)
+         speaker_num = "1" if speaker == "SPEAKER_00" else "2"
+         line = f"{start_str} - {end_str} ({duration_str}) Speaker {speaker_num}: {text}"
+         formatted_text += line + "\n"
+         print(f"Debug: Formatted line: {line}")
+
+     print("Debug: Full formatted text:")
+     print(formatted_text)
+     return formatted_text.strip()
+
+
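+ # Example usage: transcribe example/mtr.mp3 in Cantonese and save the result to
+ # mtr.txt. Requires the HF_TOKEN environment variable for the pyannote pipeline.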
+ if __name__ == "__main__":
+     audio_path = "example/mtr.mp3"  # Replace with your audio file path
+     language = "yue"  # Set language to Cantonese
+
+     result = process_audio(audio_path, language)
+
+     # Save the result to mtr.txt
+     output_path = "mtr.txt"
+     with open(output_path, "w", encoding="utf-8") as f:
+         f.write(result)
+
+     print(f"Diarization and transcription result has been saved to {output_path}")