roychao19477 committed · Commit 2cb0aee · 1 Parent(s): 0b3d66c

Test on lengths
app.py CHANGED
@@ -7,6 +7,8 @@ import shutil
 import glob
 import gradio as gr
 
+os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+
 # install packages for mamba
 def install_mamba():
     subprocess.run(shlex.split("pip install https://github.com/state-spaces/mamba/releases/download/v2.2.2/mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"))
@@ -63,8 +65,6 @@ from moviepy import ImageSequenceClip
 from scipy.io import wavfile
 from avse_code import run_avse
 
-# Load face detector
-model = YOLO("yolov8n-face.pt").cuda()  # assumes CUDA available
 
 
 from decord import VideoReader, cpu
@@ -75,15 +75,18 @@ import spaces
 # Load model once globally
 #ckpt_path = "ckpts/ep215_0906.oat.ckpt"
 #model = AVSEModule.load_from_checkpoint(ckpt_path)
-avse_model = AVSEModule()
 #avse_state_dict = torch.load("ckpts/ep215_0906.oat.ckpt")
-
-
-
-avse_model.eval()
+
+CHUNK_SIZE_AUDIO = 48000  # 3 sec at 16kHz
+CHUNK_SIZE_VIDEO = 75  # 25fps × 3 sec
 
 @spaces.GPU
 def run_avse_inference(video_path, audio_path):
+    avse_model = AVSEModule()
+    avse_state_dict = torch.load("ckpts/ep220_0908.oat.ckpt")
+    avse_model.load_state_dict(avse_state_dict, strict=True)
+    avse_model.to("cuda")
+    avse_model.eval()
     estimated = run_avse(video_path, audio_path)
     # Load audio
     #noisy, _ = sf.read(audio_path, dtype='float32')  # (N, )
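The two new constants describe the same 3-second window in each modality, which is what keeps the audio and video chunks aligned later in the diff: 48000 samples / 16000 Hz = 3 s and 75 frames / 25 fps = 3 s. A quick sanity check of that relationship, using the sample rate and frame rate stated in the diff's own comments:

    SAMPLE_RATE = 16000   # Hz, per the "3 sec at 16kHz" comment
    FPS = 25              # frames per second, per the "25fps × 3 sec" comment
    WINDOW_SECONDS = 3

    CHUNK_SIZE_AUDIO = SAMPLE_RATE * WINDOW_SECONDS  # 48000
    CHUNK_SIZE_VIDEO = FPS * WINDOW_SECONDS          # 75
    assert CHUNK_SIZE_AUDIO / SAMPLE_RATE == CHUNK_SIZE_VIDEO / FPS == WINDOW_SECONDS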
@@ -101,15 +104,39 @@ def run_avse_inference(video_path, audio_path):
     ]).astype(np.float32)
     bg_frames /= 255.0
 
+    audio_chunks = [
+        noisy[i:i + CHUNK_SIZE_AUDIO]
+        for i in range(0, len(noisy), CHUNK_SIZE_AUDIO)
+    ]
+
+    video_chunks = [
+        bg_frames[i:i + CHUNK_SIZE_VIDEO]
+        for i in range(0, len(bg_frames), CHUNK_SIZE_VIDEO)
+    ]
+
+    min_len = min(len(audio_chunks), len(video_chunks))  # sync length
+
 
     # Combine into input dict (match what model.enhance expects)
-    data = {
-        "noisy_audio": noisy,
-        "video_frames": bg_frames[np.newaxis, ...]
-    }
+    #data = {
+    #    "noisy_audio": noisy,
+    #    "video_frames": bg_frames[np.newaxis, ...]
+    #}
+
+    #with torch.no_grad():
+    #    estimated = avse_model.enhance(data).reshape(-1)
+    estimated_chunks = []
 
     with torch.no_grad():
-        estimated = avse_model.enhance(data).reshape(-1)
+        for i in range(min_len):
+            chunk_data = {
+                "noisy_audio": audio_chunks[i],
+                "video_frames": video_chunks[i][np.newaxis, ...]
+            }
+            est = avse_model.enhance(chunk_data).reshape(-1)
+            estimated_chunks.append(est)
+
+        estimated = np.concatenate(estimated_chunks, axis=0)
 
     # Save result
     tmp_wav = audio_path.replace(".wav", "_enhanced.wav")
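The core of this "Test on lengths" change is the pattern above: split arbitrary-length inputs into fixed 3-second windows, enhance each window, and concatenate the results. A minimal, self-contained sketch of that pattern (the enhance_chunk callable stands in for avse_model.enhance and is not part of the commit; as in the committed code, the trailing window may be shorter than a full chunk):

    import numpy as np

    CHUNK_SIZE_AUDIO = 48000  # 3 sec at 16 kHz
    CHUNK_SIZE_VIDEO = 75     # 25 fps x 3 sec

    def enhance_in_chunks(noisy, bg_frames, enhance_chunk):
        """Enhance aligned 3-second audio/video windows one at a time."""
        audio_chunks = [noisy[i:i + CHUNK_SIZE_AUDIO]
                        for i in range(0, len(noisy), CHUNK_SIZE_AUDIO)]
        video_chunks = [bg_frames[i:i + CHUNK_SIZE_VIDEO]
                        for i in range(0, len(bg_frames), CHUNK_SIZE_VIDEO)]
        # Only windows present in both streams are processed, mirroring min_len above.
        n_windows = min(len(audio_chunks), len(video_chunks))
        outputs = [enhance_chunk(audio_chunks[i], video_chunks[i]) for i in range(n_windows)]
        return np.concatenate(outputs, axis=0)

One consequence of taking min_len is that any tail window present in only one stream is dropped, so the enhanced output can come out slightly shorter than the input audio.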
@@ -135,9 +162,32 @@ def extract_resampled_audio(video_path, target_sr=16000):
     torchaudio.save(resampled_audio_path, waveform, sample_rate=target_sr)
     return resampled_audio_path
 
+@spaces.GPU
+def yolo_detection(frame, verbose=False):
+    # Load face detector
+    model = YOLO("yolov8n-face.pt").cuda()  # assumes CUDA available
+    return model(frame, verbose=verbose)[0]
 
 @spaces.GPU
 def extract_faces(video_file):
+    if isinstance(video_input, dict):
+        video_path = video_input.get("path") or video_input.get("url")
+        if video_path.startswith("http"):
+            # download video
+            tmpdir = tempfile.mkdtemp()
+            ext = os.path.splitext(urlparse(video_path).path)[1]
+            local_path = os.path.join(tmpdir, "input_video" + ext)
+            with open(local_path, "wb") as f:
+                f.write(requests.get(video_path).content)
+            video_file = local_path
+        else:
+            video_file = video_path
+    else:
+        video_file = video_input  # string path from UI
+
+
+
+
     cap = cv2.VideoCapture(video_file)
     fps = cap.get(cv2.CAP_PROP_FPS)
     frames = []
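The new preamble of extract_faces accepts either a plain file path or a Gradio-style dict carrying a "path" or "url" key, downloading remote URLs into a temp directory first. Note that the added lines read video_input while the parameter is still named video_file, so the two names would have to be reconciled for the dict branch to run. A self-contained restatement of that normalisation as a helper (the helper name is illustrative, not part of the commit):

    import os
    import tempfile
    from urllib.parse import urlparse

    import requests

    def resolve_video_path(video_input):
        """Return a local file path for a str path or a Gradio file dict."""
        if not isinstance(video_input, dict):
            return video_input  # plain path coming from the UI
        video_path = video_input.get("path") or video_input.get("url")
        if not video_path.startswith("http"):
            return video_path
        # Download the remote file so OpenCV can read it locally.
        ext = os.path.splitext(urlparse(video_path).path)[1]
        local_path = os.path.join(tempfile.mkdtemp(), "input_video" + ext)
        with open(local_path, "wb") as f:
            f.write(requests.get(video_path).content)
        return local_path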
@@ -148,7 +198,8 @@ def extract_faces(video_file):
             break
 
         # Inference
-        results = model(frame, verbose=False)[0]
+        #results = model(frame, verbose=False)[0]
+        results = yolo_detection(frame, verbose=False)
         for box in results.boxes:
             # version 1
             # x1, y1, x2, y2 = map(int, box.xyxy[0])
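With this hunk the per-frame detection goes through yolo_detection, which, as defined above, constructs a fresh YOLO("yolov8n-face.pt") on every call, i.e. once per video frame. If that reload cost matters, one common alternative is to cache the detector in a module-level variable; this is a sketch of that variant, not what the commit does, and whether a cached CUDA model survives across @spaces.GPU calls depends on the Space's hardware setup:

    from ultralytics import YOLO

    _face_model = None  # created lazily so importing the module stays CPU-only

    def yolo_detection(frame, verbose=False):
        global _face_model
        if _face_model is None:
            _face_model = YOLO("yolov8n-face.pt").cuda()
        return _face_model(frame, verbose=verbose)[0]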
@@ -218,14 +269,7 @@ def extract_faces(video_file):
     enhanced_audio_path = run_avse_inference(output_path, audio_path)
 
 
-
-    flipped_output_path = os.path.join(tmpdir, "face_only_video_flipped.mp4")
-    flipped_clip = VideoFileClip(output_path, fps=25)
-    flipped_clip = flipped_clip.fx(vfx.mirror_y)
-    flipped_clip.write_videofile(flipped_output_path, codec="libx264", audio=False, fps=25)
-
-    return flipped_output_path, enhanced_audio_path
-    #return output_path, enhanced_audio_path
+    return output_path, enhanced_audio_path
     #return output_path, audio_path
 
 iface = gr.Interface(
@@ -237,7 +281,9 @@ iface = gr.Interface(
         gr.Audio(label="Enhanced Audio", type="filepath")
     ],
     title="Face Detector",
-    description="Upload or record a video. We'll crop face regions and return a face-only video and its 16kHz audio."
+    description="Upload or record a video. We'll crop face regions and return a face-only video and its 16kHz audio.",
+    api_name="/predict"
 )
 
 iface.launch()
+
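With api_name="/predict" set on the Interface, the endpoint can be called programmatically as well as from the UI. A sketch using gradio_client, assuming the two outputs above; the Space id and the input file name are placeholders:

    from gradio_client import Client, handle_file

    client = Client("<user>/<space-name>")   # placeholder Space id
    face_video, enhanced_audio = client.predict(
        handle_file("my_clip.mp4"),          # the single video input
        api_name="/predict",
    )
    print(face_video, enhanced_audio)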