roychao19477 committed on
Commit e4058c9 · 1 Parent(s): 68a2244

Add limitations

Files changed (1)
  1. app.py +1 -142
app.py CHANGED
@@ -136,28 +136,10 @@ def extract_resampled_audio(video_path, target_sr=16000):
      torchaudio.save(resampled_audio_path, waveform, sample_rate=target_sr)
      return resampled_audio_path

- import cv2
- import ffmpeg
-
- def downsample_if_needed(video_path):
-     cap = cv2.VideoCapture(video_path)
-     width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
-     height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
-     cap.release()
-
-     if max(width, height) > 720:
-         downsampled_path = f"/tmp/downsampled_720p.mp4"
-         ffmpeg.input(video_path).output(
-             downsampled_path, vf="scale='min(720,iw)':-2", **{"c:v": "libx264"}
-         ).overwrite_output().run()
-         return downsampled_path
-     return video_path

  @spaces.GPU
  def extract_faces(video_file):
-     video_path = downsample_if_needed(video_file)
-
-     cap = cv2.VideoCapture(video_path)
+     cap = cv2.VideoCapture(video_file)
      fps = cap.get(cv2.CAP_PROP_FPS)
      frames = []
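With downsample_if_needed removed, extract_faces now opens the uploaded file directly (the single "+" line above). For context, a minimal sketch of the decode loop implied by the surviving context lines (fps, frames); the read_frames name is illustrative, not part of app.py:

    import cv2

    def read_frames(video_file):
        # Open the uploaded video directly, as the new "+" line does.
        cap = cv2.VideoCapture(video_file)
        fps = cap.get(cv2.CAP_PROP_FPS)  # native frame rate, useful for timing downstream
        frames = []
        while True:
            ok, frame = cap.read()
            if not ok:  # end of stream or decode failure
                break
            frames.append(frame)
        cap.release()
        return fps, frames
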
@@ -253,126 +235,3 @@ iface = gr.Interface(
  )

  iface.launch()
-
-
-
- ckpt = "ckpts/SEMamba_advanced.pth"
- cfg_f = "recipes/SEMamba_advanced.yaml"
-
- # load config
- with open(cfg_f, 'r') as f:
-     cfg = yaml.safe_load(f)
-
-
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- device = "cuda"
- model = SEMamba(cfg).to(device)
- #sdict = torch.load(ckpt, map_location=device)
- #model.load_state_dict(sdict["generator"])
- #model.eval()
-
- @spaces.GPU
- def enhance(filepath, model_name):
-     # Load model based on selection
-     ckpt_path = {
-         "VCTK-Demand": "ckpts/SEMamba_advanced.pth",
-         "VCTK+DNS": "ckpts/vd.pth"
-     }[model_name]
-
-     print("Loading:", ckpt_path)
-     model.load_state_dict(torch.load(ckpt_path, map_location=device)["generator"])
-     model.eval()
-     with torch.no_grad():
-         # load & resample
-         wav, orig_sr = librosa.load(filepath, sr=None)
-         noisy_wav = wav.copy()
-         if orig_sr != 16000:
-             wav = librosa.resample(wav, orig_sr=orig_sr, target_sr=16000)
-         x = torch.from_numpy(wav).float().to(device)
-         norm = torch.sqrt(len(x)/torch.sum(x**2))
-         #x = (x * norm).unsqueeze(0)
-         x = (x * norm)
-
-         # split into 4s segments (64000 samples)
-         segment_len = 4 * 16000
-         chunks = x.split(segment_len)
-         enhanced_chunks = []
-
-         for chunk in chunks:
-             if len(chunk) < segment_len:
-                 #pad = torch.zeros(segment_len - len(chunk), device=chunk.device)
-                 pad = (torch.randn(segment_len - len(chunk), device=chunk.device) * 1e-4)
-                 chunk = torch.cat([chunk, pad])
-             chunk = chunk.unsqueeze(0)
-
-             amp, pha, _ = mag_phase_stft(chunk, 400, 100, 400, 0.3)
-             amp2, pha2, _ = model(amp, pha)
-             out = mag_phase_istft(amp2, pha2, 400, 100, 400, 0.3)
-             out = (out / norm).squeeze(0)
-             enhanced_chunks.append(out)
-
-         out = torch.cat(enhanced_chunks)[:len(x)].cpu().numpy()  # trim padding
-
-         # back to original rate
-         if orig_sr != 16000:
-             out = librosa.resample(out, orig_sr=16000, target_sr=orig_sr)
-
-         # Normalize
-         peak = np.max(np.abs(out))
-         if peak > 0.05:
-             out = out / peak * 0.85
-
-         # write file
-         sf.write("enhanced.wav", out, orig_sr)
-
-         # spectrograms
-         fig, axs = plt.subplots(1, 2, figsize=(16, 4))
-
-         # noisy
-         D_noisy = librosa.stft(noisy_wav, n_fft=512, hop_length=256)
-         S_noisy = librosa.amplitude_to_db(np.abs(D_noisy), ref=np.max)
-         librosa.display.specshow(S_noisy, sr=orig_sr, hop_length=256, x_axis="time", y_axis="hz", ax=axs[0], vmax=0)
-         axs[0].set_title("Noisy Spectrogram")
-
-         # enhanced
-         D_clean = librosa.stft(out, n_fft=512, hop_length=256)
-         S_clean = librosa.amplitude_to_db(np.abs(D_clean), ref=np.max)
-         librosa.display.specshow(S_clean, sr=orig_sr, hop_length=256, x_axis="time", y_axis="hz", ax=axs[1], vmax=0)
-         #librosa.display.specshow(S_clean, sr=16000, hop_length=512, x_axis="time", y_axis="hz", ax=axs[1], vmax=0)
-         axs[1].set_title("Enhanced Spectrogram")
-
-         plt.tight_layout()
-
-         return "enhanced.wav", fig
-
- #with gr.Blocks() as demo:
- #    gr.Markdown(ABOUT)
- #    input_audio = gr.Audio(label="Input Audio", type="filepath", interactive=True)
- #    enhance_btn = gr.Button("Enhance")
- #    output_audio = gr.Audio(label="Enhanced Audio", type="filepath")
- #    plot_output = gr.Plot(label="Spectrograms")
- #
- #    enhance_btn.click(fn=enhance, inputs=input_audio, outputs=[output_audio, plot_output])
- #
- #demo.queue().launch()
-
- with gr.Blocks() as demo:
-     gr.Markdown(ABOUT)
-     input_audio = gr.Audio(label="Input Audio", type="filepath", interactive=True)
-     model_choice = gr.Radio(
-         label="Choose Model (The use of VCTK+DNS is recommended)",
-         choices=["VCTK-Demand", "VCTK+DNS"],
-         value="VCTK-Demand"
-     )
-     enhance_btn = gr.Button("Enhance")
-     output_audio = gr.Audio(label="Enhanced Audio", type="filepath")
-     plot_output = gr.Plot(label="Spectrograms")
-
-     enhance_btn.click(
-         fn=enhance,
-         inputs=[input_audio, model_choice],
-         outputs=[output_audio, plot_output]
-     )
-     gr.Markdown("**Note**: The current models are trained on 16kHz audio. Therefore, any input audio not sampled at 16kHz will be automatically resampled before enhancement.")
-
- demo.queue().launch()
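For reference, the removed enhance() path normalized the waveform to unit average power (norm = sqrt(N / sum(x^2))), ran the model over fixed 4-second windows at 16 kHz, padded the final window with very low-level noise rather than zeros, and trimmed the padding after concatenation. A self-contained sketch of just that chunking pattern, using an identity stand-in for the model (chunk_process and model_fn are illustrative names, not part of app.py):

    import torch

    def chunk_process(x, model_fn, segment_len=4 * 16000):
        # Run model_fn over fixed-length segments, then trim the padding.
        outputs = []
        for chunk in x.split(segment_len):
            if len(chunk) < segment_len:
                # Low-level noise padding, as in the removed code; the
                # zero-padding variant was left commented out there.
                pad = torch.randn(segment_len - len(chunk), device=chunk.device) * 1e-4
                chunk = torch.cat([chunk, pad])
            outputs.append(model_fn(chunk.unsqueeze(0)).squeeze(0))
        return torch.cat(outputs)[:len(x)]  # drop the trailing padding

    # Identity "model" on 5 s of 16 kHz audio: output length matches input.
    wav = torch.randn(5 * 16000)
    assert chunk_process(wav, lambda t: t).shape == wav.shape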