roychao19477 committed · Commit 138c0ad · 1 Parent(s): f9f66c4
Update module

app.py CHANGED
@@ -3,6 +3,8 @@ import subprocess
 import spaces
 import torch
 import os
+import shutil
+import glob
 import gradio as gr
 
 # install packages for mamba
@@ -16,6 +18,15 @@ def clone_github():
         "git", "clone",
         f"https://RoyChao19477:{os.environ['GITHUB_TOKEN']}@github.com/RoyChao19477/for_HF_AVSEMamba.git"
     ])
+    # move all files except README.md
+    for item in glob.glob("tmp_repo/*"):
+        if os.path.basename(item) != "README.md":
+            if os.path.isdir(item):
+                shutil.move(item, ".")
+            else:
+                shutil.move(item, os.path.join(".", os.path.basename(item)))
+
+    shutil.rmtree("tmp_repo")
 
 install_mamba()
 clone_github()
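Note on this hunk: the new loop reads from tmp_repo/*, so the git clone call (whose opening lines fall outside the hunk context) presumably passes tmp_repo as the clone target. Also, shutil.move raises if the destination already exists, which matters when the Space restarts with files left over from a previous run. A more defensive variant of the move step might look like this (a sketch, not the committed code; flatten_cloned_repo is a hypothetical helper name):

import glob
import os
import shutil

def flatten_cloned_repo(src="tmp_repo", dst="."):
    # Hypothetical helper: move everything except README.md from the
    # cloned repo into dst, removing any stale copy from a previous run.
    for item in glob.glob(os.path.join(src, "*")):
        name = os.path.basename(item)
        if name == "README.md":
            continue
        target = os.path.join(dst, name)
        if os.path.isdir(target):
            shutil.rmtree(target)   # clear a stale directory
        elif os.path.exists(target):
            os.remove(target)       # clear a stale file
        shutil.move(item, target)
    shutil.rmtree(src)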
@@ -54,6 +65,46 @@ from moviepy import ImageSequenceClip
 # Load face detector
 model = YOLO("yolov8n-face.pt").cuda() # assumes CUDA available
 
+
+from decord import VideoReader, cpu
+from model import AVSEModule
+from config import sampling_rate
+import spaces
+
+# Load model once globally
+ckpt_path = "ckpts/ep215_0906.oat.ckpt"
+model = AVSEModule.load_from_checkpoint(ckpt_path)
+model.to("cuda")
+model.eval()
+
+@spaces.GPU
+def run_avse_inference(video_path, audio_path):
+    # Load audio
+    noisy, _ = sf.read(audio_path, dtype='float32')  # (N, )
+    noisy = torch.tensor(noisy).unsqueeze(0)  # (1, N)
+
+    # Load grayscale video
+    vr = VideoReader(video_path, ctx=cpu(0))
+    frames = vr.get_batch(list(range(len(vr)))).asnumpy()
+    bg_frames = np.array([cv2.cvtColor(f, cv2.COLOR_RGB2GRAY) for f in frames]).astype(np.float32) / 255.0
+    bg_frames = torch.tensor(bg_frames).unsqueeze(0).unsqueeze(0)  # (1, 1, T, H, W)
+
+    # Combine into input dict (match what model.enhance expects)
+    data = {
+        "noisy_audio": noisy,
+        "video_frames": bg_frames
+    }
+
+    with torch.no_grad():
+        estimated = model.enhance(data).reshape(-1).cpu().numpy()
+
+    # Save result
+    tmp_wav = audio_path.replace(".wav", "_enhanced.wav")
+    sf.write(tmp_wav, estimated, samplerate=sampling_rate)
+
+    return tmp_wav
+
+
 def extract_resampled_audio(video_path, target_sr=16000):
     # Step 1: extract audio via torchaudio
     # (moviepy will still extract it to wav temp file)
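Two things to watch in this hunk: `model = AVSEModule.load_from_checkpoint(ckpt_path)` reuses the module-level name `model`, which previously held the YOLO face detector, so later code that still expects the detector under that name will silently get the AVSE model instead (and the repeated `import spaces` is redundant, since it is already imported at line 3). Both input tensors are also created on the CPU while the checkpoint is moved to CUDA; whether `model.enhance` handles the device transfer internally depends on `AVSEModule`, which is outside this diff. If it does not, an explicit transfer would be needed. A minimal sketch for the inference section, assuming `enhance` accepts tensors placed on the model's own device:

with torch.no_grad():
    device = next(model.parameters()).device   # wherever the checkpoint landed
    data = {
        "noisy_audio": noisy.to(device),       # (1, N) waveform
        "video_frames": bg_frames.to(device),  # (1, 1, T, H, W) grayscale frames
    }
    estimated = model.enhance(data).reshape(-1).cpu().numpy()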
@@ -127,14 +178,32 @@ def extract_faces(video_file):
     ).run(overwrite_output=True)
 
 
-
+
+
+    # ------------------------------- #
+    # AVSE models
+    noisy = self.load_wav(audio_path)
+
+    vr = VideoReader(output_path, ctx=cpu(0))
+    frames = vr.get_batch(list(range(len(vr)))).asnumpy()
+    bg_frames = np.array([
+        cv2.cvtColor(frames[i], cv2.COLOR_RGB2GRAY) for i in range(len(frames))
+    ]).astype(np.float32)
+    bg_frames /= 255.0
+
+    enhanced_audio_path = run_avse_inference(output_path, audio_path)
+
+
+    return output_path, enhanced_audio_path
+    #return output_path, audio_path
 
 iface = gr.Interface(
     fn=extract_faces,
     inputs=gr.Video(label="Upload or record your video"),
     outputs=[
         gr.Video(label="Detected Face Only Video"),
-        gr.Audio(label="Extracted Audio (16kHz)", type="filepath"),
+        #gr.Audio(label="Extracted Audio (16kHz)", type="filepath"),
+        gr.Audio(label="Enhanced Audio", type="filepath")
     ],
     title="Face Detector",
     description="Upload or record a video. We'll crop face regions and return a face-only video and its 16kHz audio."
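Two likely bugs in this hunk: `noisy = self.load_wav(audio_path)` references `self` inside `extract_faces`, a plain function, so it will raise NameError when this block runs; and the `bg_frames` computed here are never used, since `run_avse_inference` reloads both the audio and the video itself. A minimal cleanup sketch (not the committed code) would drop the dead preprocessing:

    # ------------------------------- #
    # AVSE model: run_avse_inference reloads audio and video itself,
    # so no preprocessing is needed here.
    enhanced_audio_path = run_avse_inference(output_path, audio_path)
    return output_path, enhanced_audio_path

Separately, the Interface description still promises raw 16kHz audio even though the audio output is now the enhanced track; updating that string would keep the UI copy accurate.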