roychao19477 committed
Commit 7d86927 · 1 Parent(s): f20206d

Test on lengths

Files changed (1)
  1. app.py +128 -37
app.py CHANGED
@@ -82,9 +82,6 @@ avse_model.load_state_dict(avse_state_dict, strict=True)
 avse_model.to("cuda")
 avse_model.eval()
 
-CHUNK_SIZE_AUDIO = 128000  # 3 sec at 16kHz
-CHUNK_SIZE_VIDEO = 200     # 25fps × 3 sec
-
 @spaces.GPU
 def run_avse_inference(video_path, audio_path):
     estimated = run_avse(video_path, audio_path)
@@ -104,39 +101,15 @@ def run_avse_inference(video_path, audio_path):
     ]).astype(np.float32)
     bg_frames /= 255.0
 
-    audio_chunks = [
-        noisy[i:i + CHUNK_SIZE_AUDIO]
-        for i in range(0, len(noisy), CHUNK_SIZE_AUDIO)
-    ]
-
-    video_chunks = [
-        bg_frames[i:i + CHUNK_SIZE_VIDEO]
-        for i in range(0, len(bg_frames), CHUNK_SIZE_VIDEO)
-    ]
-
-    min_len = min(len(audio_chunks), len(video_chunks))  # sync length
-
 
     # Combine into input dict (match what model.enhance expects)
-    #data = {
-    #    "noisy_audio": noisy,
-    #    "video_frames": bg_frames[np.newaxis, ...]
-    #}
-
-    #with torch.no_grad():
-    #    estimated = avse_model.enhance(data).reshape(-1)
-    estimated_chunks = []
+    data = {
+        "noisy_audio": noisy,
+        "video_frames": bg_frames[np.newaxis, ...]
+    }
 
     with torch.no_grad():
-        for i in range(min_len):
-            chunk_data = {
-                "noisy_audio": audio_chunks[i],
-                "video_frames": video_chunks[i][np.newaxis, ...]
-            }
-            est = avse_model.enhance(chunk_data).reshape(-1)
-            estimated_chunks.append(est)
-
-        estimated = np.concatenate(estimated_chunks, axis=0)
+        estimated = avse_model.enhance(data).reshape(-1)
 
     # Save result
     tmp_wav = audio_path.replace(".wav", "_enhanced.wav")
@@ -162,10 +135,6 @@ def extract_resampled_audio(video_path, target_sr=16000):
     torchaudio.save(resampled_audio_path, waveform, sample_rate=target_sr)
     return resampled_audio_path
 
-#@spaces.GPU
-#def yolo_detection(frame, verbose=False):
-#    return model(frame, verbose=verbose)[0]
-
 @spaces.GPU
 def extract_faces(video_file):
     cap = cv2.VideoCapture(video_file)
@@ -179,7 +148,6 @@ def extract_faces(video_file):
 
         # Inference
         results = model(frame, verbose=False)[0]
-        #results = yolo_detection(frame, verbose=False)
         for box in results.boxes:
             # version 1
             # x1, y1, x2, y2 = map(int, box.xyxy[0])
@@ -265,3 +233,126 @@ iface = gr.Interface(
 )
 
 iface.launch()
+
+
+
+ckpt = "ckpts/SEMamba_advanced.pth"
+cfg_f = "recipes/SEMamba_advanced.yaml"
+
+# load config
+with open(cfg_f, 'r') as f:
+    cfg = yaml.safe_load(f)
+
+
+# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = "cuda"
+model = SEMamba(cfg).to(device)
+#sdict = torch.load(ckpt, map_location=device)
+#model.load_state_dict(sdict["generator"])
+#model.eval()
+
+@spaces.GPU
+def enhance(filepath, model_name):
+    # Load model based on selection
+    ckpt_path = {
+        "VCTK-Demand": "ckpts/SEMamba_advanced.pth",
+        "VCTK+DNS": "ckpts/vd.pth"
+    }[model_name]
+
+    print("Loading:", ckpt_path)
+    model.load_state_dict(torch.load(ckpt_path, map_location=device)["generator"])
+    model.eval()
+    with torch.no_grad():
+        # load & resample
+        wav, orig_sr = librosa.load(filepath, sr=None)
+        noisy_wav = wav.copy()
+        if orig_sr != 16000:
+            wav = librosa.resample(wav, orig_sr=orig_sr, target_sr=16000)
+        x = torch.from_numpy(wav).float().to(device)
+        norm = torch.sqrt(len(x) / torch.sum(x ** 2))
+        #x = (x * norm).unsqueeze(0)
+        x = (x * norm)
+
+        # split into 4s segments (64000 samples)
+        segment_len = 4 * 16000
+        chunks = x.split(segment_len)
+        enhanced_chunks = []
+
+        for chunk in chunks:
+            if len(chunk) < segment_len:
+                #pad = torch.zeros(segment_len - len(chunk), device=chunk.device)
+                pad = (torch.randn(segment_len - len(chunk), device=chunk.device) * 1e-4)
+                chunk = torch.cat([chunk, pad])
+            chunk = chunk.unsqueeze(0)
+
+            amp, pha, _ = mag_phase_stft(chunk, 400, 100, 400, 0.3)
+            amp2, pha2, _ = model(amp, pha)
+            out = mag_phase_istft(amp2, pha2, 400, 100, 400, 0.3)
+            out = (out / norm).squeeze(0)
+            enhanced_chunks.append(out)
+
+        out = torch.cat(enhanced_chunks)[:len(x)].cpu().numpy()  # trim padding
+
+        # back to original rate
+        if orig_sr != 16000:
+            out = librosa.resample(out, orig_sr=16000, target_sr=orig_sr)
+
+        # Normalize
+        peak = np.max(np.abs(out))
+        if peak > 0.05:
+            out = out / peak * 0.85
+
+        # write file
+        sf.write("enhanced.wav", out, orig_sr)
+
+        # spectrograms
+        fig, axs = plt.subplots(1, 2, figsize=(16, 4))
+
+        # noisy
+        D_noisy = librosa.stft(noisy_wav, n_fft=512, hop_length=256)
+        S_noisy = librosa.amplitude_to_db(np.abs(D_noisy), ref=np.max)
+        librosa.display.specshow(S_noisy, sr=orig_sr, hop_length=256, x_axis="time", y_axis="hz", ax=axs[0], vmax=0)
+        axs[0].set_title("Noisy Spectrogram")
+
+        # enhanced
+        D_clean = librosa.stft(out, n_fft=512, hop_length=256)
+        S_clean = librosa.amplitude_to_db(np.abs(D_clean), ref=np.max)
+        librosa.display.specshow(S_clean, sr=orig_sr, hop_length=256, x_axis="time", y_axis="hz", ax=axs[1], vmax=0)
+        #librosa.display.specshow(S_clean, sr=16000, hop_length=512, x_axis="time", y_axis="hz", ax=axs[1], vmax=0)
+        axs[1].set_title("Enhanced Spectrogram")
+
+        plt.tight_layout()
+
+        return "enhanced.wav", fig
+
+#with gr.Blocks() as demo:
+#    gr.Markdown(ABOUT)
+#    input_audio = gr.Audio(label="Input Audio", type="filepath", interactive=True)
+#    enhance_btn = gr.Button("Enhance")
+#    output_audio = gr.Audio(label="Enhanced Audio", type="filepath")
+#    plot_output = gr.Plot(label="Spectrograms")
+#
+#    enhance_btn.click(fn=enhance, inputs=input_audio, outputs=[output_audio, plot_output])
+#
+#demo.queue().launch()
+
+with gr.Blocks() as demo:
+    gr.Markdown(ABOUT)
+    input_audio = gr.Audio(label="Input Audio", type="filepath", interactive=True)
+    model_choice = gr.Radio(
+        label="Choose Model (The use of VCTK+DNS is recommended)",
+        choices=["VCTK-Demand", "VCTK+DNS"],
+        value="VCTK-Demand"
+    )
+    enhance_btn = gr.Button("Enhance")
+    output_audio = gr.Audio(label="Enhanced Audio", type="filepath")
+    plot_output = gr.Plot(label="Spectrograms")
+
+    enhance_btn.click(
+        fn=enhance,
+        inputs=[input_audio, model_choice],
+        outputs=[output_audio, plot_output]
+    )
+    gr.Markdown("**Note**: The current models are trained on 16kHz audio. Therefore, any input audio not sampled at 16kHz will be automatically resampled before enhancement.")
+
+demo.queue().launch()
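
Editor's note: the core of this commit ("Test on lengths") is how arbitrary-length audio is handled. The fixed-size AVSE chunking is removed in favor of enhancing the full recording in one pass, while the new SEMamba enhance() path splits the signal into 4-second segments, pads the last one, and trims the result back to the input length. Below is a minimal sketch of that segment-wise pattern under stated assumptions: enhance_segment is a hypothetical stand-in for the mag_phase_stft -> model -> mag_phase_istft chain in app.py, not the actual model call.

import torch

SEGMENT_LEN = 4 * 16000  # 4-second segments at 16 kHz, as in enhance()

def enhance_segment(chunk: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in: app.py runs mag_phase_stft -> model -> mag_phase_istft here.
    return chunk

def enhance_waveform(x: torch.Tensor) -> torch.Tensor:
    enhanced = []
    for chunk in x.split(SEGMENT_LEN):
        if len(chunk) < SEGMENT_LEN:
            # Pad the final segment with very quiet noise (the zero-padding
            # alternative is left commented out in the commit).
            pad = torch.randn(SEGMENT_LEN - len(chunk)) * 1e-4
            chunk = torch.cat([chunk, pad])
        enhanced.append(enhance_segment(chunk))
    # Concatenate the enhanced segments and trim the padding so the output
    # length matches the input length.
    return torch.cat(enhanced)[: len(x)]

wav = torch.randn(10 * 16000)  # 10 s of dummy audio
assert enhance_waveform(wav).shape == wav.shape

Padding the last segment with low-amplitude noise rather than zeros is the choice visible in the diff; presumably it keeps the trailing STFT frames from being exactly silent.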