roychao19477 committed
Commit 2cb0aee · 1 Parent(s): 0b3d66c

Test on lengths

Files changed (1)
  1. app.py +68 -22
app.py CHANGED
@@ -7,6 +7,8 @@ import shutil
 import glob
 import gradio as gr
 
+os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+
 # install packages for mamba
 def install_mamba():
     subprocess.run(shlex.split("pip install https://github.com/state-spaces/mamba/releases/download/v2.2.2/mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"))
@@ -63,8 +65,6 @@ from moviepy import ImageSequenceClip
 from scipy.io import wavfile
 from avse_code import run_avse
 
-# Load face detector
-model = YOLO("yolov8n-face.pt").cuda()  # assumes CUDA available
 
 
 from decord import VideoReader, cpu
@@ -75,15 +75,18 @@ import spaces
 # Load model once globally
 #ckpt_path = "ckpts/ep215_0906.oat.ckpt"
 #model = AVSEModule.load_from_checkpoint(ckpt_path)
-avse_model = AVSEModule()
 #avse_state_dict = torch.load("ckpts/ep215_0906.oat.ckpt")
-avse_state_dict = torch.load("ckpts/ep220_0908.oat.ckpt")
-avse_model.load_state_dict(avse_state_dict, strict=True)
-avse_model.to("cuda")
-avse_model.eval()
+
+CHUNK_SIZE_AUDIO = 48000  # 3 sec at 16kHz
+CHUNK_SIZE_VIDEO = 75  # 25fps × 3 sec
 
 @spaces.GPU
 def run_avse_inference(video_path, audio_path):
+    avse_model = AVSEModule()
+    avse_state_dict = torch.load("ckpts/ep220_0908.oat.ckpt")
+    avse_model.load_state_dict(avse_state_dict, strict=True)
+    avse_model.to("cuda")
+    avse_model.eval()
     estimated = run_avse(video_path, audio_path)
     # Load audio
     #noisy, _ = sf.read(audio_path, dtype='float32')  # (N, )
@@ -101,15 +104,39 @@ def run_avse_inference(video_path, audio_path):
     ]).astype(np.float32)
     bg_frames /= 255.0
 
+    audio_chunks = [
+        noisy[i:i + CHUNK_SIZE_AUDIO]
+        for i in range(0, len(noisy), CHUNK_SIZE_AUDIO)
+    ]
+
+    video_chunks = [
+        bg_frames[i:i + CHUNK_SIZE_VIDEO]
+        for i in range(0, len(bg_frames), CHUNK_SIZE_VIDEO)
+    ]
+
+    min_len = min(len(audio_chunks), len(video_chunks))  # sync length
+
 
     # Combine into input dict (match what model.enhance expects)
-    data = {
-        "noisy_audio": noisy,
-        "video_frames": bg_frames[np.newaxis, ...]
-    }
+    #data = {
+    #    "noisy_audio": noisy,
+    #    "video_frames": bg_frames[np.newaxis, ...]
+    #}
+
+    #with torch.no_grad():
+    #    estimated = avse_model.enhance(data).reshape(-1)
+    estimated_chunks = []
 
     with torch.no_grad():
-        estimated = avse_model.enhance(data).reshape(-1)
+        for i in range(min_len):
+            chunk_data = {
+                "noisy_audio": audio_chunks[i],
+                "video_frames": video_chunks[i][np.newaxis, ...]
+            }
+            est = avse_model.enhance(chunk_data).reshape(-1)
+            estimated_chunks.append(est)
+
+        estimated = np.concatenate(estimated_chunks, axis=0)
 
     # Save result
     tmp_wav = audio_path.replace(".wav", "_enhanced.wav")
@@ -135,9 +162,32 @@ def extract_resampled_audio(video_path, target_sr=16000):
     torchaudio.save(resampled_audio_path, waveform, sample_rate=target_sr)
     return resampled_audio_path
 
+@spaces.GPU
+def yolo_detection(frame, verbose=False):
+    # Load face detector
+    model = YOLO("yolov8n-face.pt").cuda()  # assumes CUDA available
+    return model(frame, verbose=verbose)[0]
 
 @spaces.GPU
 def extract_faces(video_file):
+    if isinstance(video_input, dict):
+        video_path = video_input.get("path") or video_input.get("url")
+        if video_path.startswith("http"):
+            # download video
+            tmpdir = tempfile.mkdtemp()
+            ext = os.path.splitext(urlparse(video_path).path)[1]
+            local_path = os.path.join(tmpdir, "input_video" + ext)
+            with open(local_path, "wb") as f:
+                f.write(requests.get(video_path).content)
+            video_file = local_path
+        else:
+            video_file = video_path
+    else:
+        video_file = video_input  # string path from UI
+
+
+
+
     cap = cv2.VideoCapture(video_file)
     fps = cap.get(cv2.CAP_PROP_FPS)
     frames = []
@@ -148,7 +198,8 @@ def extract_faces(video_file):
             break
 
         # Inference
-        results = model(frame, verbose=False)[0]
+        #results = model(frame, verbose=False)[0]
+        results = yolo_detection(frame, verbose=False)
         for box in results.boxes:
            # version 1
            # x1, y1, x2, y2 = map(int, box.xyxy[0])
@@ -218,14 +269,7 @@ def extract_faces(video_file):
     enhanced_audio_path = run_avse_inference(output_path, audio_path)
 
 
-    from moviepy import VideoFileClip
-    flipped_output_path = os.path.join(tmpdir, "face_only_video_flipped.mp4")
-    flipped_clip = VideoFileClip(output_path, fps=25)
-    flipped_clip = flipped_clip.fx(vfx.mirror_y)
-    flipped_clip.write_videofile(flipped_output_path, codec="libx264", audio=False, fps=25)
-
-    return flipped_output_path, enhanced_audio_path
-    #return output_path, enhanced_audio_path
+    return output_path, enhanced_audio_path
     #return output_path, audio_path
 
 iface = gr.Interface(
@@ -237,7 +281,9 @@ iface = gr.Interface(
         gr.Audio(label="Enhanced Audio", type="filepath")
     ],
     title="Face Detector",
-    description="Upload or record a video. We'll crop face regions and return a face-only video and its 16kHz audio."
+    description="Upload or record a video. We'll crop face regions and return a face-only video and its 16kHz audio.",
+    api_name="/predict"
 )
 
 iface.launch()
+
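Note on the core change: inference now runs over fixed 3-second windows instead of the whole clip, and the chunk sizes are chosen so audio and video stay aligned (48000 samples at 16 kHz and 75 frames at 25 fps both cover 3 seconds). The snippet below is a minimal, self-contained sketch of that chunking scheme; the helper name make_chunks and the example shapes are illustrative only, since the committed code keeps this logic inline in run_avse_inference.

import numpy as np

# Assumed rates from the commit: 16 kHz audio, 25 fps video, 3-second chunks.
SR = 16000
FPS = 25
CHUNK_SIZE_AUDIO = SR * 3    # 48000 samples
CHUNK_SIZE_VIDEO = FPS * 3   # 75 frames

def make_chunks(noisy, bg_frames):
    """Split audio (N,) and video (T, H, W) into aligned 3-second chunks."""
    audio_chunks = [noisy[i:i + CHUNK_SIZE_AUDIO]
                    for i in range(0, len(noisy), CHUNK_SIZE_AUDIO)]
    video_chunks = [bg_frames[i:i + CHUNK_SIZE_VIDEO]
                    for i in range(0, len(bg_frames), CHUNK_SIZE_VIDEO)]
    # The two streams may end on different partial chunks; keep only the
    # overlap, mirroring the min() used in the committed code.
    n = min(len(audio_chunks), len(video_chunks))
    return audio_chunks[:n], video_chunks[:n]

# Example: 10 s of audio and 10 s of video -> 4 chunk pairs (3 + 3 + 3 + 1 s).
a, v = make_chunks(np.zeros(160000, dtype=np.float32),
                   np.zeros((250, 88, 88), dtype=np.float32))
assert len(a) == len(v) == 4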
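With api_name="/predict" now set on the interface, the Space can also be called programmatically. Below is a hedged usage sketch using gradio_client; the Space id and the video file name are placeholders, and the exact wrapping of the video argument can differ slightly between Gradio versions.

from gradio_client import Client, handle_file

# Placeholder Space id; substitute the actual "user/space" of this demo.
client = Client("user/space-name")

# Send a local video to the endpoint named in this commit and collect the
# face-only video and enhanced 16 kHz audio returned by the interface.
face_video, enhanced_audio = client.predict(
    handle_file("sample_video.mp4"),  # placeholder local file
    api_name="/predict",
)
print(face_video, enhanced_audio)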