roychao19477 committed on
Commit
3a0e329
·
1 Parent(s): 0d22451

Update module

Browse files
Files changed (1) hide show
  1. app.py +5 -6
app.py CHANGED
@@ -83,19 +83,19 @@ avse_model.eval()
83
  @spaces.GPU
84
  def run_avse_inference(video_path, audio_path):
85
  # Load audio
86
- noisy, _ = sf.read(audio_path, dtype='float32') # (N, )
87
- noisy = torch.tensor(noisy).unsqueeze(0) # (1, N)
 
88
 
89
  # Load grayscale video
90
  vr = VideoReader(video_path, ctx=cpu(0))
91
  frames = vr.get_batch(list(range(len(vr)))).asnumpy()
92
  bg_frames = np.array([cv2.cvtColor(f, cv2.COLOR_RGB2GRAY) for f in frames]).astype(np.float32) / 255.0
93
- bg_frames = torch.tensor(bg_frames).unsqueeze(0).unsqueeze(0) # (1, 1, T, H, W)
94
 
95
  # Combine into input dict (match what model.enhance expects)
96
  data = {
97
- "noisy_audio": noisy.numpy(),
98
- "video_frames": bg_frames.numpy()
99
  }
100
 
101
  with torch.no_grad():
@@ -189,7 +189,6 @@ def extract_faces(video_file):
189
 
190
  # ------------------------------- #
191
  # AVSE models
192
- noisy = wavfile.read(audio_path)[1].astype(np.float32) / (2 ** 15)
193
 
194
  vr = VideoReader(output_path, ctx=cpu(0))
195
  frames = vr.get_batch(list(range(len(vr)))).asnumpy()
 
83
  @spaces.GPU
84
  def run_avse_inference(video_path, audio_path):
85
  # Load audio
86
+ #noisy, _ = sf.read(audio_path, dtype='float32') # (N, )
87
+ #noisy = torch.tensor(noisy).unsqueeze(0) # (1, N)
88
+ noisy = wavfile.read(audio_path)[1].astype(np.float32) / (2 ** 15)
89
 
90
  # Load grayscale video
91
  vr = VideoReader(video_path, ctx=cpu(0))
92
  frames = vr.get_batch(list(range(len(vr)))).asnumpy()
93
  bg_frames = np.array([cv2.cvtColor(f, cv2.COLOR_RGB2GRAY) for f in frames]).astype(np.float32) / 255.0
 
94
 
95
  # Combine into input dict (match what model.enhance expects)
96
  data = {
97
+ "noisy_audio": noisy,
98
+ "video_frames": bg_frames
99
  }
100
 
101
  with torch.no_grad():
 
189
 
190
  # ------------------------------- #
191
  # AVSE models
 
192
 
193
  vr = VideoReader(output_path, ctx=cpu(0))
194
  frames = vr.get_batch(list(range(len(vr)))).asnumpy()