Update custom_interface_app.py
custom_interface_app.py  (+125 -128, CHANGED)
@@ -41,7 +41,7 @@ class ASR(Pretrained):
 41           # Forward encoder + decoder
 42           tokens = torch.tensor([[1, 1]]) * self.mods.whisper.config.decoder_start_token_id
 43           tokens = tokens.to(device)
 44 -         enc_out, logits, _ = self.mods.whisper(wavs
 45           log_probs = self.hparams.log_softmax(logits)
 46
 47           hyps, _, _, _ = self.hparams.test_search(enc_out.detach(), wav_lens)
@@ -128,30 +128,53 @@ class ASR(Pretrained):
128
129       def classify_file_w2v2(self, waveform, device):
130           # Load the audio file
131           # waveform, sr = librosa.load(path, sr=16000)
132
133           # Get audio length in seconds
134 -
135
136 -         if audio_length >=
137 -
138               segments = []
139 -
140 -
141 -
142 -
143 -
144 -
145 -
146 -                     end = len(waveform)
147                   segment_part = waveform[start:end]
148 -                 segment_len = len(segment_part) / 16000
149 -                 if segment_len < 1:
150 -                     continue
151 -                 segments.append(segment_part)
152 -                 start = end
153
154 -
155                   segment_tensor = torch.tensor(segment).to(device)
156
157                   # Fake a batch for the segment
@@ -159,171 +182,145 @@ class ASR(Pretrained):
159                   rel_length = torch.tensor([1.0]).to(device)  # Adjust if necessary
160
161                   # Pass the segment through the ASR model
162 -
163 -
164 -
165 -
166 -             segments = ""
167 -             for segment in all_segments:
168 -                 segment = segment[0]
169 -                 segments += segment + " "
170 -             return [segments]
171           else:
172               waveform = torch.tensor(waveform).to(device)
173               waveform = waveform.to(device)
174               # Fake a batch:
175               batch = waveform.unsqueeze(0)
176               rel_length = torch.tensor([1.0]).to(device)
177 -             outputs = self.encode_batch_w2v2(device, batch, rel_length)
178 -             return [
179 -
180
181       def classify_file_whisper_mkd(self, waveform, device):
182           # Load the audio file
183           # waveform, sr = librosa.load(path, sr=16000)
184
185           # Get audio length in seconds
186 -
187
188 -         if audio_length >=
189 -
190               segments = []
191 -
192 -
193 -
194 -             start = 0
195 -             for i in range(num_segments):
196 -                 end = start + max_duration
197 -                 if end > len(waveform):
198 -                     end = len(waveform)
199 -                 segment_part = waveform[start:end]
200 -                 segment_len = len(segment_part) / 16000
201 -                 if segment_len < 1:
202 -                     continue
203 -                 segments.append(segment_part)
204 -                 start = end
205
206 -             for
207 -
208
209 -             #
210 -
211 -
212
213 -
214 -
215 -                 # segment_output = [" ".join(segment) for segment in segment_output]
216 -                 all_segments.append(segment_output)
217 -
218 -             segments = ""
219 -             for segment in all_segments:
220 -                 segment = segment[0]
221 -                 segments += segment + " "
222 -             return [segments]
223 -         else:
224 -             waveform = torch.tensor(waveform).to(device)
225 -             waveform = waveform.to(device)
226 -             batch = waveform.unsqueeze(0)
227 -             rel_length = torch.tensor([1.0]).to(device)
228 -             outputs = self.encode_batch_whisper(device, batch, rel_length)
229 -             return outputs
230
231
232 -
233 -
234 -
235
236 -
237 -
238 -
239 -         if audio_length >= 30:
240 -             # split audio every 30 seconds
241 -             segments = []
242 -             max_duration = 30 * 16000  # Maximum segment duration in samples (20 seconds)
243 -             num_segments = int(np.ceil(len(waveform) / max_duration))
244 -             start = 0
245 -             for i in range(num_segments):
246 -                 end = start + max_duration
247 -                 if end > len(waveform):
248 -                     end = len(waveform)
249 -                 segment_part = waveform[start:end]
250 -                 segment_len = len(segment_part) / 16000
251 -                 if segment_len < 1:
252 -                     continue
253 -                 segments.append(segment_part)
254 -                 start = end
255
256 -             for segment in segments:
257                   segment_tensor = torch.tensor(segment).to(device)
258
259                   # Fake a batch for the segment
260                   batch = segment_tensor.unsqueeze(0).to(device)
261 -                 rel_length = torch.tensor([1.0]).to(device)
262
263                   # Pass the segment through the ASR model
264                   segment_output = self.encode_batch_whisper(device, batch, rel_length)
265 -
266           else:
267               waveform = torch.tensor(waveform).to(device)
268               waveform = waveform.to(device)
269               batch = waveform.unsqueeze(0)
270               rel_length = torch.tensor([1.0]).to(device)
271 -             outputs
272 -
273
274
275 -     def classify_file_whisper(self,
276 -
277           transcription = pipe(waveform, generate_kwargs={"language": "macedonian"})["text"]
278 -         return
279
280
281 -     def classify_file_mms(self,
282           # Load the audio file
283 -
284
285           # Get audio length in seconds
286 -         audio_length = len(waveform) /
287
288 -         if audio_length >=
289 -
290               segments = []
291 -
292 -
293 -
294 -
295 -
296 -
297 -
298 -                     end = len(waveform)
299                   segment_part = waveform[start:end]
300 -                 segment_len = len(segment_part) / 16000
301 -                 if segment_len < 1:
302 -                     continue
303 -                 segments.append(segment_part)
304 -                 start = end
305
306 -
307                   segment_tensor = torch.tensor(segment).to(device)
308
309                   # Pass the segment through the ASR model
310                   inputs = processor(segment_tensor, sampling_rate=16_000, return_tensors="pt").to(device)
311 -                 inputs['input_values'] = inputs['input_values']
312                   outputs = model(**inputs).logits
313                   ids = torch.argmax(outputs, dim=-1)[0]
314                   segment_output = processor.decode(ids)
315 -
316 -                 all_segments.append(segment_output)
317 -
318 -             segments = ""
319 -             for segment in all_segments:
320 -                 segments += segment + " "
321 -             return [segments]
322           else:
323               waveform = torch.tensor(waveform).to(device)
324               inputs = processor(waveform, sampling_rate=16_000, return_tensors="pt").to(device)
325 -             inputs['input_values'] = inputs['input_values']
326               outputs = model(**inputs).logits
327               ids = torch.argmax(outputs, dim=-1)[0]
328               transcription = processor.decode(ids)
329 -
 41           # Forward encoder + decoder
 42           tokens = torch.tensor([[1, 1]]) * self.mods.whisper.config.decoder_start_token_id
 43           tokens = tokens.to(device)
 44 +         enc_out, logits, _ = self.mods.whisper(wavs, tokens)
 45           log_probs = self.hparams.log_softmax(logits)
 46
 47           hyps, _, _, _ = self.hparams.test_search(enc_out.detach(), wav_lens)
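For context on the hunk above: `self.mods.whisper` is called on a padded waveform batch plus relative lengths (`wavs`, `wav_lens`). A minimal sketch of how such inputs are typically assembled for a SpeechBrain-style forward pass; the helper name and shapes are illustrative assumptions, not code from this file:

    import torch

    def make_batch(waveforms):
        # waveforms: list of 1-D float tensors sampled at 16 kHz
        lengths = torch.tensor([w.shape[0] for w in waveforms], dtype=torch.float)
        wavs = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True)
        wav_lens = lengths / lengths.max()  # relative lengths in [0, 1]
        return wavs, wav_lens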
128
129       def classify_file_w2v2(self, waveform, device):
130           # Load the audio file
131 +         # path = "long_sample.wav"
132           # waveform, sr = librosa.load(path, sr=16000)
133
134 +         # increase the volume if needed
135 +         # waveform = self.increase_volume(waveform)
136 +
137           # Get audio length in seconds
138 +         sr = 16000
139 +         audio_length = len(waveform) / sr
140
141 +         if audio_length >= 20:
142 +             print(f"Audio is too long ({audio_length:.2f} seconds), splitting into segments")
143 +             # Detect non-silent segments
144 +
145 +             non_silent_intervals = librosa.effects.split(waveform, top_db=20)  # Adjust top_db for sensitivity
146 +
147               segments = []
148 +             current_segment = []
149 +             current_length = 0
150 +             max_duration = 20 * sr  # Maximum segment duration in samples (20 seconds)
151 +
152 +
153 +             for interval in non_silent_intervals:
154 +                 start, end = interval
155                   segment_part = waveform[start:end]
156
157 +                 # If adding the next part exceeds max duration, store the segment and start a new one
158 +                 if current_length + len(segment_part) > max_duration:
159 +                     segments.append(np.concatenate(current_segment))
160 +                     current_segment = []
161 +                     current_length = 0
162 +
163 +                 current_segment.append(segment_part)
164 +                 current_length += len(segment_part)
165 +
166 +             # Append the last segment if it's not empty
167 +             if current_segment:
168 +                 segments.append(np.concatenate(current_segment))
169 +
170 +             # Process each segment
171 +             outputs = []
172 +             for i, segment in enumerate(segments):
173 +                 print(f"Processing segment {i + 1}/{len(segments)}, length: {len(segment) / sr:.2f} seconds")
174 +
175 +                 # import soundfile as sf
176 +                 # sf.write(f"outputs/segment_{i}.wav", segment, sr)
177 +
178                   segment_tensor = torch.tensor(segment).to(device)
179
180                   # Fake a batch for the segment
181                   batch = segment_tensor.unsqueeze(0).to(device)
182                   rel_length = torch.tensor([1.0]).to(device)  # Adjust if necessary
183
184                   # Pass the segment through the ASR model
185 +                 result = " ".join(self.encode_batch_w2v2(device, batch, rel_length)[0])
186 +                 outputs.append(result)
187 +             return outputs
188           else:
189               waveform = torch.tensor(waveform).to(device)
190               waveform = waveform.to(device)
191               # Fake a batch:
192               batch = waveform.unsqueeze(0)
193               rel_length = torch.tensor([1.0]).to(device)
194 +             outputs = " ".join(self.encode_batch_w2v2(device, batch, rel_length)[0])
195 +             return [outputs]
196 +
197
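The rewrite above replaces fixed-length windowing with silence-aware chunking: librosa.effects.split finds non-silent intervals, which are then packed greedily into chunks of at most 20 seconds. A standalone sketch of that strategy for illustration (function name and defaults are assumptions; the extra `current` guard only avoids calling np.concatenate on an empty list when a single non-silent interval already exceeds the cap):

    import librosa
    import numpy as np

    def split_on_silence(waveform, sr=16000, max_seconds=20, top_db=20):
        """Group non-silent intervals into chunks no longer than max_seconds."""
        intervals = librosa.effects.split(waveform, top_db=top_db)  # (start, end) sample indices
        max_samples = max_seconds * sr
        chunks, current, current_len = [], [], 0
        for start, end in intervals:
            part = waveform[start:end]
            if current and current_len + len(part) > max_samples:
                chunks.append(np.concatenate(current))
                current, current_len = [], 0
            current.append(part)
            current_len += len(part)
        if current:
            chunks.append(np.concatenate(current))
        return chunks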
198       def classify_file_whisper_mkd(self, waveform, device):
199           # Load the audio file
200 +         # path = "long_sample.wav"
201           # waveform, sr = librosa.load(path, sr=16000)
202
203 +         # increase the volume if needed
204 +         # waveform = self.increase_volume(waveform)
205 +
206           # Get audio length in seconds
207 +         sr = 16000
208 +         audio_length = len(waveform) / sr
209
210 +         if audio_length >= 20:
211 +             print(f"Audio is too long ({audio_length:.2f} seconds), splitting into segments")
212 +             # Detect non-silent segments
213 +
214 +             non_silent_intervals = librosa.effects.split(waveform, top_db=20)  # Adjust top_db for sensitivity
215 +
216               segments = []
217 +             current_segment = []
218 +             current_length = 0
219 +             max_duration = 20 * sr  # Maximum segment duration in samples (20 seconds)
220
221 +             for interval in non_silent_intervals:
222 +                 start, end = interval
223 +                 segment_part = waveform[start:end]
224
225 +                 # If adding the next part exceeds max duration, store the segment and start a new one
226 +                 if current_length + len(segment_part) > max_duration:
227 +                     segments.append(np.concatenate(current_segment))
228 +                     current_segment = []
229 +                     current_length = 0
230
231 +                 current_segment.append(segment_part)
232 +                 current_length += len(segment_part)
233
234 +             # Append the last segment if it's not empty
235 +             if current_segment:
236 +                 segments.append(np.concatenate(current_segment))
237
238 +             # Process each segment
239 +             outputs = []
240 +             for i, segment in enumerate(segments):
241 +                 print(f"Processing segment {i + 1}/{len(segments)}, length: {len(segment) / sr:.2f} seconds")
242
243 +                 # import soundfile as sf
244 +                 # sf.write(f"outputs/segment_{i}.wav", segment, sr)
245
246                   segment_tensor = torch.tensor(segment).to(device)
247
248                   # Fake a batch for the segment
249                   batch = segment_tensor.unsqueeze(0).to(device)
250 +                 rel_length = torch.tensor([1.0]).to(device)  # Adjust if necessary
251
252                   # Pass the segment through the ASR model
253                   segment_output = self.encode_batch_whisper(device, batch, rel_length)
254 +                 outputs.append(segment_output)
255 +             return outputs
256           else:
257               waveform = torch.tensor(waveform).to(device)
258               waveform = waveform.to(device)
259 +             # Fake a batch:
260               batch = waveform.unsqueeze(0)
261               rel_length = torch.tensor([1.0]).to(device)
262 +             outputs.append(self.encode_batch_whisper(device, batch, rel_length))
263 +             return outputs
264
265
266 +     def classify_file_whisper(self, path, pipe, device):
267 +         waveform, sr = librosa.load(path, sr=16000)
268           transcription = pipe(waveform, generate_kwargs={"language": "macedonian"})["text"]
269 +         return transcription
270
271
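classify_file_whisper expects a ready-made Hugging Face ASR pipeline as `pipe`. One way such a pipeline could be constructed; the checkpoint below is a placeholder, not necessarily the one this app loads:

    from transformers import pipeline

    pipe = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-large-v3",  # placeholder; substitute the app's Whisper checkpoint
        device=0,                         # GPU index, or -1 for CPU
    )
    text = pipe("long_sample.wav", generate_kwargs={"language": "macedonian"})["text"]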
272 +     def classify_file_mms(self, path, processor, model, device):
273           # Load the audio file
274 +         waveform, sr = librosa.load(path, sr=16000)
275
276           # Get audio length in seconds
277 +         audio_length = len(waveform) / sr
278
279 +         if audio_length >= 20:
280 +             print(f"MMS Audio is too long ({audio_length:.2f} seconds), splitting into segments")
281 +             # Detect non-silent segments
282 +             non_silent_intervals = librosa.effects.split(waveform, top_db=20)  # Adjust top_db for sensitivity
283 +
284               segments = []
285 +             current_segment = []
286 +             current_length = 0
287 +             max_duration = 20 * sr  # Maximum segment duration in samples (20 seconds)
288 +
289 +
290 +             for interval in non_silent_intervals:
291 +                 start, end = interval
292                   segment_part = waveform[start:end]
293
294 +                 # If adding the next part exceeds max duration, store the segment and start a new one
295 +                 if current_length + len(segment_part) > max_duration:
296 +                     segments.append(np.concatenate(current_segment))
297 +                     current_segment = []
298 +                     current_length = 0
299 +
300 +                 current_segment.append(segment_part)
301 +                 current_length += len(segment_part)
302 +
303 +             # Append the last segment if it's not empty
304 +             if current_segment:
305 +                 segments.append(np.concatenate(current_segment))
306 +
307 +             # Process each segment
308 +             outputs = []
309 +             for i, segment in enumerate(segments):
310 +                 print(f"MMS Processing segment {i + 1}/{len(segments)}, length: {len(segment) / sr:.2f} seconds")
311 +
312                   segment_tensor = torch.tensor(segment).to(device)
313
314                   # Pass the segment through the ASR model
315                   inputs = processor(segment_tensor, sampling_rate=16_000, return_tensors="pt").to(device)
316                   outputs = model(**inputs).logits
317                   ids = torch.argmax(outputs, dim=-1)[0]
318                   segment_output = processor.decode(ids)
319 +                 yield segment_output
320           else:
321               waveform = torch.tensor(waveform).to(device)
322               inputs = processor(waveform, sampling_rate=16_000, return_tensors="pt").to(device)
323               outputs = model(**inputs).logits
324               ids = torch.argmax(outputs, dim=-1)[0]
325               transcription = processor.decode(ids)
326 +             yield transcription