roychao19477 committed
Commit 1728184 · 1 Parent(s): a45a351

Upload model

Files changed (1)
  1. app_v1.py +358 -0
app_v1.py ADDED
@@ -0,0 +1,358 @@
+ import shlex
+ import subprocess
+ import spaces
+ import torch
+ import os
+ import shutil
+ import glob
+ import gradio as gr
+
+ # install packages for mamba
+ def install_mamba():
+     subprocess.run(shlex.split("pip install https://github.com/state-spaces/mamba/releases/download/v2.2.2/mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"))
+
+ def clone_github():
+     subprocess.run([
+         "git", "clone",
+         f"https://RoyChao19477:{os.environ['GITHUB_TOKEN']}@github.com/RoyChao19477/for_HF_AVSEMamba.git",
+     ])
+     # move all files except README.md
+     for item in glob.glob("for_HF_AVSEMamba/*"):
+         if os.path.basename(item) != "README.md":
+             if os.path.isdir(item):
+                 shutil.move(item, ".")
+             else:
+                 shutil.move(item, os.path.join(".", os.path.basename(item)))
+
+     #shutil.rmtree("tmp_repo")
+     #subprocess.run(["ls"], check=True)
+
+ install_mamba()
+ clone_github()
+
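+ # Both setup steps above run at import time on the Space: install_mamba() pulls a
+ # prebuilt mamba_ssm wheel (CUDA 12.2 / torch 2.3 / CPython 3.10, per its filename),
+ # and clone_github() reads a GITHUB_TOKEN from the environment (e.g. a Space secret)
+ # to clone the RoyChao19477/for_HF_AVSEMamba repo.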
+ ABOUT = """
+ # SEMamba: Speech Enhancement
+ A Mamba-based model that denoises real-world audio.
+ Upload or record a noisy clip and click **Enhance** to hear the result and view its spectrogram.
+ """
+
+
+ import ffmpeg
+ import torchaudio
+ import torchaudio.transforms as T
+ import yaml
+ import librosa
+ import librosa.display
+ import matplotlib
+ import numpy as np
+ import soundfile as sf
+ import matplotlib.pyplot as plt
+ from models.stfts import mag_phase_stft, mag_phase_istft
+ from models.generator import SEMamba
+ from models.pcs400 import cal_pcs
+ from ultralytics import YOLO
+ import supervision as sv
+
+ import cv2
+ import tempfile
+ from moviepy import ImageSequenceClip
+ from scipy.io import wavfile
+ from avse_code import run_avse
+
+ # Load face detector
+ model = YOLO("yolov8n-face.pt").cuda()  # assumes CUDA is available
+
+
+ from decord import VideoReader, cpu
+ from model import AVSEModule
+ from config import sampling_rate
+
+ # Load model once globally
+ #ckpt_path = "ckpts/ep215_0906.oat.ckpt"
+ #model = AVSEModule.load_from_checkpoint(ckpt_path)
+ avse_model = AVSEModule()
+ #avse_state_dict = torch.load("ckpts/ep215_0906.oat.ckpt")
+ avse_state_dict = torch.load("ckpts/ep220_0908.oat.ckpt")
+ avse_model.load_state_dict(avse_state_dict, strict=True)
+ avse_model.to("cuda")
+ avse_model.eval()
+
+ @spaces.GPU
+ def run_avse_inference(video_path, audio_path):
+     estimated = run_avse(video_path, audio_path)  # note: result is overwritten by the in-function inference below
+     # Load audio
+     #noisy, _ = sf.read(audio_path, dtype='float32') # (N, )
+     #noisy = torch.tensor(noisy).unsqueeze(0) # (1, N)
+     noisy = wavfile.read(audio_path)[1].astype(np.float32) / (2 ** 15)
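+     # assumes 16-bit PCM input; dividing by 2**15 rescales the int16 samples to float32 in [-1, 1)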
+
+     # Norm.
+     #noisy = noisy * (0.8 / np.max(np.abs(noisy)))
+
+     # Load grayscale video
+     vr = VideoReader(video_path, ctx=cpu(0))
+     frames = vr.get_batch(list(range(len(vr)))).asnumpy()
+     bg_frames = np.array([
+         cv2.cvtColor(frames[i], cv2.COLOR_RGB2GRAY) for i in range(len(frames))
+     ]).astype(np.float32)
+     bg_frames /= 255.0
+
+
+     # Combine into input dict (match what model.enhance expects)
+     data = {
+         "noisy_audio": noisy,
+         "video_frames": bg_frames[np.newaxis, ...]
+     }
+
+     with torch.no_grad():
+         estimated = avse_model.enhance(data).reshape(-1)
+
+     # Save result
+     tmp_wav = audio_path.replace(".wav", "_enhanced.wav")
+     sf.write(tmp_wav, estimated, samplerate=sampling_rate)
+
+     return tmp_wav
+
+
+ def extract_resampled_audio(video_path, target_sr=16000):
+     # Step 1: extract the audio track to a temporary 44.1kHz wav via ffmpeg
+     tmp_audio_path = tempfile.mktemp(suffix=".wav")
+     subprocess.run(["ffmpeg", "-y", "-i", video_path, "-vn", "-acodec", "pcm_s16le", "-ar", "44100", tmp_audio_path])
+
+     # Step 2: Load and resample
+     waveform, sr = torchaudio.load(tmp_audio_path)
+     if sr != target_sr:
+         resampler = T.Resample(orig_freq=sr, new_freq=target_sr)
+         waveform = resampler(waveform)
+
+     # Step 3: Save resampled audio
+     resampled_audio_path = tempfile.mktemp(suffix="_16k.wav")
+     torchaudio.save(resampled_audio_path, waveform, sample_rate=target_sr)
+     return resampled_audio_path
+
+ @spaces.GPU
+ def extract_faces(video_file):
+     cap = cv2.VideoCapture(video_file)
+     fps = cap.get(cv2.CAP_PROP_FPS)
+     frames = []
+
+     while True:
+         ret, frame = cap.read()
+         if not ret:
+             break
+
+         # Inference
+         results = model(frame, verbose=False)[0]
+         for box in results.boxes:
+             # version 1
+             # x1, y1, x2, y2 = map(int, box.xyxy[0])
+
+             # version 2
+             h, w, _ = frame.shape
+             x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
+             pad_ratio = 0.5  # 50% padding
+
+             dx = (x2 - x1) * pad_ratio
+             dy = (y2 - y1) * pad_ratio
+
+             x1 = int(max(0, x1 - dx))
+             y1 = int(max(0, y1 - dy))
+             x2 = int(min(w, x2 + dx))
+             y2 = int(min(h, y2 + dy))
+             # Added for v3
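+             # shift the padded box down by 10% of its height, presumably to keep the
+             # mouth region better centered for the audio-visual model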
+             shift_down = int(0.1 * (y2 - y1))
+             y1 = int(min(max(0, y1 + shift_down), h))
+             y2 = int(min(max(0, y2 + shift_down), h))
+             face_crop = frame[y1:y2, x1:x2]
+             if face_crop.size != 0:
+                 resized = cv2.resize(face_crop, (224, 224))
+                 frames.append(resized)
+
+                 #h_crop, w_crop = face_crop.shape[:2]
+                 #side = min(h_crop, w_crop)
+                 #start_y = (h_crop - side) // 2
+                 #start_x = (w_crop - side) // 2
+                 #square_crop = face_crop[start_y:start_y+side, start_x:start_x+side]
+                 #resized = cv2.resize(square_crop, (224, 224))
+                 #frames.append(resized)
+
+             break  # only one face per frame
+
+     cap.release()
+
+     # Save as video
+     tmpdir = tempfile.mkdtemp()
+     output_path = os.path.join(tmpdir, "face_only_video.mp4")
+     #clip = ImageSequenceClip([cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in frames], fps=25)
+     #clip = ImageSequenceClip([cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in frames], fps=fps)
+     clip = ImageSequenceClip(
+         [cv2.cvtColor(cv2.resize(f, (224, 224)), cv2.COLOR_BGR2RGB) for f in frames],
+         fps=fps
+     )
+     clip.write_videofile(output_path, codec="libx264", audio=False, fps=25)
+
+     # Save audio from original, resampled to 16kHz
+     audio_path = os.path.join(tmpdir, "audio_16k.wav")
+
+     # Extract audio using ffmpeg-python (more robust than moviepy)
+     ffmpeg.input(video_file).output(
+         audio_path,
+         ar=16000,   # resample to 16k
+         ac=1,       # mono
+         format='wav',
+         vn=None     # no video
+     ).run(overwrite_output=True)
+
+
+     # ------------------------------- #
+     # AVSE models
+
+     enhanced_audio_path = run_avse_inference(output_path, audio_path)
+
+     return output_path, enhanced_audio_path
+     #return output_path, audio_path
+
+ iface = gr.Interface(
+     fn=extract_faces,
+     inputs=gr.Video(label="Upload or record your video"),
+     outputs=[
+         gr.Video(label="Detected Face Only Video"),
+         #gr.Audio(label="Extracted Audio (16kHz)", type="filepath"),
+         gr.Audio(label="Enhanced Audio", type="filepath")
+     ],
+     title="Face Detector",
+     description="Upload or record a video. We'll crop the face region and return a face-only video along with its enhanced 16kHz audio."
+ )
+
+ iface.launch()
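+ # Note: when this file runs as a script, launch() typically blocks here, so the
+ # SEMamba audio-only demo defined below only starts after this interface is closed.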
+
+
+
+ ckpt = "ckpts/SEMamba_advanced.pth"
+ cfg_f = "recipes/SEMamba_advanced.yaml"
+
+ # load config
+ with open(cfg_f, 'r') as f:
+     cfg = yaml.safe_load(f)
+
+
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ device = "cuda"
+ model = SEMamba(cfg).to(device)
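+ # Note: this rebinds the module-level name `model`, which previously held the YOLO
+ # face detector used by extract_faces() above.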
+ #sdict = torch.load(ckpt, map_location=device)
+ #model.load_state_dict(sdict["generator"])
+ #model.eval()
+
+ @spaces.GPU
+ def enhance(filepath, model_name):
+     # Load model based on selection
+     ckpt_path = {
+         "VCTK-Demand": "ckpts/SEMamba_advanced.pth",
+         "VCTK+DNS": "ckpts/vd.pth"
+     }[model_name]
+
+     print("Loading:", ckpt_path)
+     model.load_state_dict(torch.load(ckpt_path, map_location=device)["generator"])
+     model.eval()
+     with torch.no_grad():
+         # load & resample
+         wav, orig_sr = librosa.load(filepath, sr=None)
+         noisy_wav = wav.copy()
+         if orig_sr != 16000:
+             wav = librosa.resample(wav, orig_sr=orig_sr, target_sr=16000)
+         x = torch.from_numpy(wav).float().to(device)
+         norm = torch.sqrt(len(x)/torch.sum(x**2))
+         #x = (x * norm).unsqueeze(0)
+         x = (x * norm)
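+         # norm scales the signal to unit RMS (sum of squares equals the sample count);
+         # the inverse scaling is applied to each enhanced chunk before concatenation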
+
+         # split into 4s segments (64000 samples)
+         segment_len = 4 * 16000
+         chunks = x.split(segment_len)
+         enhanced_chunks = []
+
+         for chunk in chunks:
+             if len(chunk) < segment_len:
+                 #pad = torch.zeros(segment_len - len(chunk), device=chunk.device)
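+                 # pad with very low-level noise rather than zeros, presumably to avoid
+                 # degenerate all-zero STFT frames; the padded tail is trimmed off below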
+                 pad = (torch.randn(segment_len - len(chunk), device=chunk.device) * 1e-4)
+                 chunk = torch.cat([chunk, pad])
+             chunk = chunk.unsqueeze(0)
+
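+             # STFT/ISTFT with n_fft=400, hop=100, win=400 and a 0.3 magnitude-compression
+             # factor (argument order assumed from models.stfts), i.e. 25 ms windows with a
+             # 6.25 ms hop at 16 kHz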
+             amp, pha, _ = mag_phase_stft(chunk, 400, 100, 400, 0.3)
+             amp2, pha2, _ = model(amp, pha)
+             out = mag_phase_istft(amp2, pha2, 400, 100, 400, 0.3)
+             out = (out / norm).squeeze(0)
+             enhanced_chunks.append(out)
+
+         out = torch.cat(enhanced_chunks)[:len(x)].cpu().numpy()  # trim padding
+
+         # back to original rate
+         if orig_sr != 16000:
+             out = librosa.resample(out, orig_sr=16000, target_sr=orig_sr)
+
+         # Normalize: scale the peak to 0.85, skipping near-silent outputs
+         peak = np.max(np.abs(out))
+         if peak > 0.05:
+             out = out / peak * 0.85
+
+         # write file
+         sf.write("enhanced.wav", out, orig_sr)
+
+         # spectrograms
+         fig, axs = plt.subplots(1, 2, figsize=(16, 4))
+
+         # noisy
+         D_noisy = librosa.stft(noisy_wav, n_fft=512, hop_length=256)
+         S_noisy = librosa.amplitude_to_db(np.abs(D_noisy), ref=np.max)
+         librosa.display.specshow(S_noisy, sr=orig_sr, hop_length=256, x_axis="time", y_axis="hz", ax=axs[0], vmax=0)
+         axs[0].set_title("Noisy Spectrogram")
+
+         # enhanced
+         D_clean = librosa.stft(out, n_fft=512, hop_length=256)
+         S_clean = librosa.amplitude_to_db(np.abs(D_clean), ref=np.max)
+         librosa.display.specshow(S_clean, sr=orig_sr, hop_length=256, x_axis="time", y_axis="hz", ax=axs[1], vmax=0)
+         #librosa.display.specshow(S_clean, sr=16000, hop_length=512, x_axis="time", y_axis="hz", ax=axs[1], vmax=0)
+         axs[1].set_title("Enhanced Spectrogram")
+
+         plt.tight_layout()
+
+         return "enhanced.wav", fig
+
+ #with gr.Blocks() as demo:
+ #    gr.Markdown(ABOUT)
+ #    input_audio = gr.Audio(label="Input Audio", type="filepath", interactive=True)
+ #    enhance_btn = gr.Button("Enhance")
+ #    output_audio = gr.Audio(label="Enhanced Audio", type="filepath")
+ #    plot_output = gr.Plot(label="Spectrograms")
+ #
+ #    enhance_btn.click(fn=enhance, inputs=input_audio, outputs=[output_audio, plot_output])
+ #
+ #demo.queue().launch()
+
+ with gr.Blocks() as demo:
+     gr.Markdown(ABOUT)
+     input_audio = gr.Audio(label="Input Audio", type="filepath", interactive=True)
+     model_choice = gr.Radio(
+         label="Choose Model (VCTK+DNS is recommended)",
+         choices=["VCTK-Demand", "VCTK+DNS"],
+         value="VCTK-Demand"
+     )
+     enhance_btn = gr.Button("Enhance")
+     output_audio = gr.Audio(label="Enhanced Audio", type="filepath")
+     plot_output = gr.Plot(label="Spectrograms")
+
+     enhance_btn.click(
+         fn=enhance,
+         inputs=[input_audio, model_choice],
+         outputs=[output_audio, plot_output]
+     )
+     gr.Markdown("**Note**: The current models are trained on 16kHz audio, so any input not sampled at 16kHz is automatically resampled before enhancement.")
+
+ demo.queue().launch()