roychao19477 committed on
Commit 138c0ad · 1 Parent(s): f9f66c4

Update module

Files changed (1): app.py (+71 -2)
app.py CHANGED
@@ -3,6 +3,8 @@ import subprocess
 import spaces
 import torch
 import os
+import shutil
+import glob
 import gradio as gr
 
 # install packages for mamba
@@ -16,6 +18,15 @@ def clone_github():
         "git", "clone",
         f"https://RoyChao19477:{os.environ['GITHUB_TOKEN']}@github.com/RoyChao19477/for_HF_AVSEMamba.git"
     ])
+    # move all files except README.md
+    for item in glob.glob("tmp_repo/*"):
+        if os.path.basename(item) != "README.md":
+            if os.path.isdir(item):
+                shutil.move(item, ".")
+            else:
+                shutil.move(item, os.path.join(".", os.path.basename(item)))
+
+    shutil.rmtree("tmp_repo")
 
 install_mamba()
 clone_github()
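Two notes on the move loop above. glob.glob("tmp_repo/*") does not match dotfiles, so hidden entries such as .git are never moved and are simply discarded by the final shutil.rmtree. Also, the git clone call in the hunk's context shows no target-directory argument (the hunk starts mid-call), while the loop reads from tmp_repo, so the target is presumably passed outside the shown context. A minimal sketch of the clone with an explicit target, consistent with the loop; the "tmp_repo" argument is an assumption, not visible in the diff:

import os
import subprocess

def clone_github():
    # Assumed target directory "tmp_repo"; the diff's context cuts off
    # before any such argument, so this is a sketch, not the repo's code.
    subprocess.run([
        "git", "clone",
        f"https://RoyChao19477:{os.environ['GITHUB_TOKEN']}@github.com/RoyChao19477/for_HF_AVSEMamba.git",
        "tmp_repo",
    ], check=True)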
@@ -54,6 +65,46 @@ from moviepy import ImageSequenceClip
 # Load face detector
 model = YOLO("yolov8n-face.pt").cuda() # assumes CUDA available
 
+
+from decord import VideoReader, cpu
+from model import AVSEModule
+from config import sampling_rate
+import spaces
+
+# Load model once globally
+ckpt_path = "ckpts/ep215_0906.oat.ckpt"
+model = AVSEModule.load_from_checkpoint(ckpt_path)
+model.to("cuda")
+model.eval()
+
+@spaces.GPU
+def run_avse_inference(video_path, audio_path):
+    # Load audio
+    noisy, _ = sf.read(audio_path, dtype='float32')  # (N, )
+    noisy = torch.tensor(noisy).unsqueeze(0)  # (1, N)
+
+    # Load grayscale video
+    vr = VideoReader(video_path, ctx=cpu(0))
+    frames = vr.get_batch(list(range(len(vr)))).asnumpy()
+    bg_frames = np.array([cv2.cvtColor(f, cv2.COLOR_RGB2GRAY) for f in frames]).astype(np.float32) / 255.0
+    bg_frames = torch.tensor(bg_frames).unsqueeze(0).unsqueeze(0)  # (1, 1, T, H, W)
+
+    # Combine into input dict (match what model.enhance expects)
+    data = {
+        "noisy_audio": noisy,
+        "video_frames": bg_frames
+    }
+
+    with torch.no_grad():
+        estimated = model.enhance(data).reshape(-1).cpu().numpy()
+
+    # Save result
+    tmp_wav = audio_path.replace(".wav", "_enhanced.wav")
+    sf.write(tmp_wav, estimated, samplerate=sampling_rate)
+
+    return tmp_wav
+
+
 def extract_resampled_audio(video_path, target_sr=16000):
     # Step 1: extract audio via torchaudio
     # (moviepy will still extract it to wav temp file)
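Notes on the added block: sf (soundfile), np (numpy), and cv2 are used but not imported in the hunks shown, so they presumably come from imports earlier in app.py; the @spaces.GPU decorator requests GPU hardware for the call on ZeroGPU Spaces; and model = AVSEModule.load_from_checkpoint(ckpt_path) rebinds the module-level name model that the context line above bound to the YOLO face detector, so any later code that reads model for face detection will now see the AVSE model instead. A hypothetical usage sketch, assuming a face-only MP4 and a 16 kHz mono WAV are already on disk; the file names are illustrative, not from the repo:

# Illustrative inputs; in this app, extract_faces() is what produces them.
enhanced_wav = run_avse_inference("face_only.mp4", "audio_16k.wav")
print(enhanced_wav)  # e.g. "audio_16k_enhanced.wav"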
@@ -127,14 +178,32 @@ def extract_faces(video_file):
     ).run(overwrite_output=True)
 
 
-    return output_path, audio_path
+
+
+    # ------------------------------- #
+    # AVSE models
+    noisy = self.load_wav(audio_path)
+
+    vr = VideoReader(output_path, ctx=cpu(0))
+    frames = vr.get_batch(list(range(len(vr)))).asnumpy()
+    bg_frames = np.array([
+        cv2.cvtColor(frames[i], cv2.COLOR_RGB2GRAY) for i in range(len(frames))
+    ]).astype(np.float32)
+    bg_frames /= 255.0
+
+    enhanced_audio_path = run_avse_inference(output_path, audio_path)
+
+
+    return output_path, enhanced_audio_path
+    #return output_path, audio_path
 
 iface = gr.Interface(
     fn=extract_faces,
     inputs=gr.Video(label="Upload or record your video"),
     outputs=[
         gr.Video(label="Detected Face Only Video"),
-        gr.Audio(label="Extracted Audio (16kHz)", type="filepath"),
+        #gr.Audio(label="Extracted Audio (16kHz)", type="filepath"),
+        gr.Audio(label="Enhanced Audio", type="filepath")
     ],
     title="Face Detector",
     description="Upload or record a video. We'll crop face regions and return a face-only video and its 16kHz audio."
 