federicocosta1989 committed on
Commit da53b6a · verified · 1 Parent(s): f8893d8

Upload 4 files

Added dev scripts

Files changed (4)
  1. 1s_audio.wav +0 -0
  2. requirements_dev.txt +17 -0
  3. settings.py +7 -0
  4. whisper_cs_dev.py +318 -0
1s_audio.wav ADDED
Binary file (33.4 kB).
 
requirements_dev.txt ADDED
@@ -0,0 +1,17 @@
+ git+https://github.com/huggingface/transformers
+ numpy<2
+ hf_transfer
+ torch
+ pyannote.audio
+ yt-dlp
+ gradio==5.15.0
+ torchaudio==2.2.1
+ librosa==0.10.1
+ ffmpeg-python==0.2.0
+ aina-gradio-theme==2.3
+ spaces
+ peft==0.11.1
+ whisper_timestamped
+ typing
+ faster_whisper
+ ctranslate2==4.4.0
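Note: these pins mix exact versions (e.g. gradio==5.15.0, ctranslate2==4.4.0) with unpinned packages (e.g. torch, faster_whisper), so installing into a fresh virtual environment via the usual `pip install -r requirements_dev.txt` is the safest route. Also, pydub, which whisper_cs_dev.py imports, does not appear in this list; unless it arrives as a transitive dependency, it would presumably need to be added.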
settings.py ADDED
@@ -0,0 +1,7 @@
+ DEBUG_MODE = True
+ MODEL_PATH_V2 = "langtech-veu/whisper-timestamped-cs"
+ MODEL_PATH_V2_FAST = "langtech-veu/faster-whisper-timestamped-cs"
+ LEFT_CHANNEL_TEMP_PATH = "temp_mono_speaker2.wav"
+ RIGHT_CHANNEL_TEMP_PATH = "temp_mono_speaker1.wav"
+ RESAMPLING_FREQ = 16000
+ FAKE_AUDIO_PATH = "1s_audio.wav"
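Most functions in whisper_cs_dev.py guard their logging with this DEBUG_MODE flag. A possible follow-up (hypothetical, not part of this commit) would be to factor the repeated `if DEBUG_MODE: print(...)` pattern into a small helper:

from settings import DEBUG_MODE

def debug_log(message: str) -> None:
    # Hypothetical helper: a single gate for all debug logging.
    if DEBUG_MODE:
        print(message)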
whisper_cs_dev.py ADDED
@@ -0,0 +1,318 @@
+ from pydub import AudioSegment
+ import os
+ import torchaudio
+ import torch
+ import re
+ import whisper_timestamped as whisper_ts
+ from faster_whisper import WhisperModel
+ from settings import DEBUG_MODE, MODEL_PATH_V2_FAST, MODEL_PATH_V2, LEFT_CHANNEL_TEMP_PATH, RIGHT_CHANNEL_TEMP_PATH, FAKE_AUDIO_PATH, RESAMPLING_FREQ
+
+
+ def get_settings():
+
+     if DEBUG_MODE: print("Entering get_settings function...")
+
+     # Use float16 on GPU, int8 quantization on CPU.
+     is_cuda_available = torch.cuda.is_available()
+     if is_cuda_available:
+         device = "cuda"
+         compute_type = "float16"
+     else:
+         device = "cpu"
+         compute_type = "int8"
+     if DEBUG_MODE: print(f"is_cuda_available: {is_cuda_available}")
+     if DEBUG_MODE: print(f"device: {device}")
+     if DEBUG_MODE: print(f"compute_type: {compute_type}")
+
+     if DEBUG_MODE: print("Exited get_settings function.")
+
+     return device, compute_type
+
+
+ def load_model(use_v2_fast, device, compute_type):
+
+     if DEBUG_MODE: print("Entering load_model function...")
+
+     if DEBUG_MODE: print(f"use_v2_fast: {use_v2_fast}")
+
+     if use_v2_fast:
+         if DEBUG_MODE: print(f"Loading {MODEL_PATH_V2_FAST} using {device} with {compute_type}...")
+         model = WhisperModel(
+             MODEL_PATH_V2_FAST,
+             device=device,
+             compute_type=compute_type,
+         )
+     else:
+         if DEBUG_MODE: print(f"Loading {MODEL_PATH_V2} using {device} with {compute_type}...")
+         # TODO add compute_type to load_model
+         model = whisper_ts.load_model(
+             MODEL_PATH_V2,
+             device=device,
+         )
+
+     # HACK: loading this extra whisper_timestamped model works around a crash.
+     # Without it we get: "Could not load library libcudnn_ops_infer.so.8. Error:
+     # libcudnn_ops_infer.so.8: cannot open shared object file: No such file or directory"
+     fake_model = whisper_ts.load_model(MODEL_PATH_V2, device=device)
+
+     if DEBUG_MODE: print("Exited load_model function.")
+
+     return model, fake_model
+
+
+ def split_input_stereo_channels(audio_path):
+
+     if DEBUG_MODE: print("Entering split_input_stereo_channels function...")
+
+     ext = os.path.splitext(audio_path)[1].lower()
+
+     if ext == ".wav":
+         audio = AudioSegment.from_wav(audio_path)
+     elif ext == ".mp3":
+         audio = AudioSegment.from_file(audio_path, format="mp3")
+     else:
+         raise ValueError(f"Unsupported file format for: {audio_path}")
+
+     channels = audio.split_to_mono()
+
+     if len(channels) != 2:
+         raise ValueError(f"Audio {audio_path} has {len(channels)} channels (instead of 2).")
+
+     # NOTE: pydub's split_to_mono() returns [left, right]; the original mapping of
+     # channel 0 to RIGHT_CHANNEL_TEMP_PATH and channel 1 to LEFT_CHANNEL_TEMP_PATH
+     # is kept as-is here.
+     channels[0].export(RIGHT_CHANNEL_TEMP_PATH, format="wav")
+     channels[1].export(LEFT_CHANNEL_TEMP_PATH, format="wav")
+
+     if DEBUG_MODE: print("Exited split_input_stereo_channels function.")
+
+
+ def format_audio(audio_path):
+
+     if DEBUG_MODE: print("Entering format_audio function...")
+
+     input_audio, sample_rate = torchaudio.load(audio_path)
+
+     # Downmix to mono if the file is still stereo, then resample to RESAMPLING_FREQ.
+     if input_audio.shape[0] == 2:
+         input_audio = torch.mean(input_audio, dim=0, keepdim=True)
+
+     resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=RESAMPLING_FREQ)
+     input_audio = resampler(input_audio)
+     input_audio = input_audio.squeeze()
+
+     if DEBUG_MODE: print("Exited format_audio function.")
+
+     return input_audio, RESAMPLING_FREQ
+
+
+ def process_waveforms():
+
+     if DEBUG_MODE: print("Entering process_waveforms function...")
+
+     left_waveform, _ = format_audio(LEFT_CHANNEL_TEMP_PATH)
+     right_waveform, _ = format_audio(RIGHT_CHANNEL_TEMP_PATH)
+
+     # NOTE: the audio arrays stay float32 regardless of compute_type;
+     # compute_type only controls the model's inference precision.
+     left_waveform = left_waveform.numpy().astype("float32")
+     right_waveform = right_waveform.numpy().astype("float32")
+
+     if DEBUG_MODE: print("Exited process_waveforms function.")
+
+     return left_waveform, right_waveform
+
+
+ def transcribe_audio_no_fast_model(model, audio):
+
+     if DEBUG_MODE: print("Entering transcribe_audio_no_fast_model function...")
+
+     # `audio` may be a file path or a 1-D float32 waveform;
+     # whisper_timestamped accepts both.
+     result = whisper_ts.transcribe(
+         model,
+         audio,
+         beam_size=5,
+         best_of=5,
+         temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
+         vad=False,
+         detect_disfluencies=True,
+     )
+
+     words = []
+     for segment in result.get('segments', []):
+         for word in segment.get('words', []):
+             word_text = word.get('word', '').strip()
+             words.append({
+                 'word': word_text,
+                 'start': word.get('start', 0),
+                 'end': word.get('end', 0),
+                 'confidence': word.get('confidence', 0)
+             })
+
+     if DEBUG_MODE: print("Exited transcribe_audio_no_fast_model function.")
+
+     return {
+         'audio': audio,
+         'text': result['text'].strip(),
+         'segments': result.get('segments', []),
+         'words': words,
+         'duration': result.get('duration', 0),
+         'success': True
+     }
+
+
+ def transcribe_channels(left_waveform, right_waveform, model, use_v2_fast, fake_model):
+
+     if DEBUG_MODE: print("Entering transcribe_channels function...")
+
+     # HACK: running a throwaway transcription first works around a crash.
+     # Without it we get: "Could not load library libcudnn_ops_infer.so.8. Error:
+     # libcudnn_ops_infer.so.8: cannot open shared object file: No such file or directory"
+     fake_result = whisper_ts.transcribe(
+         fake_model,
+         FAKE_AUDIO_PATH,
+         beam_size=1,
+     )
+
+     if use_v2_fast:
+         left_result, _ = model.transcribe(left_waveform, beam_size=5, task="transcribe")
+         right_result, _ = model.transcribe(right_waveform, beam_size=5, task="transcribe")
+         # faster-whisper returns a generator of segments; materialize it for reuse.
+         left_result = list(left_result)
+         right_result = list(right_result)
+     else:
+         left_result = transcribe_audio_no_fast_model(model, left_waveform)
+         right_result = transcribe_audio_no_fast_model(model, right_waveform)
+
+     if DEBUG_MODE: print("Exited transcribe_channels function.")
+
+     return left_result, right_result
+
+
+ # TODO refactor and rename this function
+ def post_process_transcription(transcription, max_repeats=2):
+
+     tokens = re.findall(r'\b\w+\'?\w*\b[.,!?]?', transcription)
+
+     cleaned_tokens = []
+     repetition_count = 0
+     previous_token = None
+
+     for token in tokens:
+         # Collapse in-token stutters (e.g. "hahaha" -> "ha") to a single occurrence.
+         reduced_token = re.sub(r"(\w{1,3})(\1{2,})", r"\1", token)
+
+         # Keep at most `max_repeats` consecutive copies of the same token.
+         if reduced_token == previous_token:
+             repetition_count += 1
+             if repetition_count <= max_repeats:
+                 cleaned_tokens.append(reduced_token)
+         else:
+             repetition_count = 1
+             cleaned_tokens.append(reduced_token)
+
+         previous_token = reduced_token
+
+     cleaned_transcription = " ".join(cleaned_tokens)
+     cleaned_transcription = re.sub(r'\s+', ' ', cleaned_transcription).strip()
+
+     return cleaned_transcription
+
+ # TODO not used right now, decide to use it or not
+ def post_merge_consecutive_segments_from_text(transcription_text: str) -> str:
+     # Merge consecutive segments that share the same [SPEAKER_XX] tag.
+     segments = re.split(r'(\[SPEAKER_\d{2}\])', transcription_text)
+     merged_transcription = ''
+     current_speaker = None
+     current_segment = []
+
+     for i in range(1, len(segments) - 1, 2):
+         speaker_tag = segments[i]
+         text = segments[i + 1].strip()
+
+         speaker = re.search(r'\d{2}', speaker_tag).group()
+
+         if speaker == current_speaker:
+             current_segment.append(text)
+         else:
+             if current_speaker is not None:
+                 merged_transcription += f'[SPEAKER_{current_speaker}] {" ".join(current_segment)}\n'
+             current_speaker = speaker
+             current_segment = [text]
+
+     if current_speaker is not None:
+         merged_transcription += f'[SPEAKER_{current_speaker}] {" ".join(current_segment)}\n'
+
+     return merged_transcription.strip()
+
+
+ def get_segments(result, speaker_label, use_v2_fast):
+
+     if DEBUG_MODE: print("Entering get_segments function...")
+
+     if use_v2_fast:
+         # faster-whisper yields segment objects with .start/.end/.text attributes.
+         segments = result
+         final_segments = [
+             (seg.start, seg.end, speaker_label, post_process_transcription(seg.text.strip()))
+             for seg in segments if seg.text
+         ]
+     else:
+         # whisper_timestamped returns dict segments; the comprehension
+         # already yields [] when the list is empty.
+         segments = result.get("segments", [])
+         final_segments = [
+             (seg.get("start", 0.0), seg.get("end", 0.0), speaker_label,
+              post_process_transcription(seg.get("text", "").strip()))
+             for seg in segments if seg.get("text")
+         ]
+
+     if DEBUG_MODE: print("Exited get_segments function.")
+
+     return final_segments
+
+
+ def post_process_transcripts(left_result, right_result, use_v2_fast):
+
+     if DEBUG_MODE: print("Entering post_process_transcripts function...")
+
+     left_segs = get_segments(left_result, "Speaker 1", use_v2_fast)
+     right_segs = get_segments(right_result, "Speaker 2", use_v2_fast)
+
+     # Interleave both channels chronologically by segment start time.
+     merged_transcript = sorted(
+         left_segs + right_segs,
+         key=lambda x: float(x[0]) if x[0] is not None else float("inf")
+     )
+
+     clean_output = ""
+     for start, end, speaker, text in merged_transcript:
+         clean_output += f"[{speaker}]: {text}\n"
+     clean_output = clean_output.strip()
+
+     if DEBUG_MODE: print("Exited post_process_transcripts function.")
+
+     return clean_output
+
+
+ def cleanup_temp_files(*file_paths):
+
+     if DEBUG_MODE: print("Entering cleanup_temp_files function...")
+
+     if DEBUG_MODE: print(f"File paths to remove: {file_paths}")
+
+     for path in file_paths:
+         if path and os.path.exists(path):
+             if DEBUG_MODE: print(f"Removing path: {path}")
+             os.remove(path)
+
+     if DEBUG_MODE: print("Exited cleanup_temp_files function.")
+
+
+ def generate(audio_path, use_v2_fast):
+
+     if DEBUG_MODE: print("Entering generate function...")
+
+     # Full pipeline: pick device/precision, load the model, split the stereo
+     # input, transcribe each channel, merge the transcripts, clean up.
+     device, compute_type = get_settings()
+     model, fake_model = load_model(use_v2_fast, device, compute_type)
+     split_input_stereo_channels(audio_path)
+     left_waveform, right_waveform = process_waveforms()
+     left_result, right_result = transcribe_channels(left_waveform, right_waveform, model, use_v2_fast, fake_model)
+     output = post_process_transcripts(left_result, right_result, use_v2_fast)
+     cleanup_temp_files(LEFT_CHANNEL_TEMP_PATH, RIGHT_CHANNEL_TEMP_PATH)
+
+     if DEBUG_MODE: print("Exited generate function.")
+
+     return output
+
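For reference, a minimal usage sketch of the new script (the audio file name is illustrative; generate expects a two-channel .wav or .mp3 call recording):

from whisper_cs_dev import generate

# Transcribe a stereo recording with the faster-whisper variant of the model.
transcript = generate("stereo_call.wav", use_v2_fast=True)
print(transcript)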