roychao19477 committed
Commit dba6227 · 1 Parent(s): 5a2e862

Upload to debug

Files changed (1)
  1. app.py +31 -8
app.py CHANGED
@@ -82,6 +82,9 @@ avse_model.load_state_dict(avse_state_dict, strict=True)
 avse_model.to("cuda")
 avse_model.eval()
 
+CHUNK_SIZE_AUDIO = 48000  # 3 sec at 16kHz
+CHUNK_SIZE_VIDEO = 75     # 25fps × 3 sec
+
 @spaces.GPU
 def run_avse_inference(video_path, audio_path):
     estimated = run_avse(video_path, audio_path)
@@ -101,19 +104,39 @@ def run_avse_inference(video_path, audio_path):
     ]).astype(np.float32)
     bg_frames /= 255.0
 
-    print(noisy.shape)
-    print(bg_frames.shape)
-    fesfse
+    audio_chunks = [
+        noisy[i:i + CHUNK_SIZE_AUDIO]
+        for i in range(0, len(noisy), CHUNK_SIZE_AUDIO)
+    ]
+
+    video_chunks = [
+        bg_frames[i:i + CHUNK_SIZE_VIDEO]
+        for i in range(0, len(bg_frames), CHUNK_SIZE_VIDEO)
+    ]
+
+    min_len = min(len(audio_chunks), len(video_chunks))  # sync length
 
 
     # Combine into input dict (match what model.enhance expects)
-    data = {
-        "noisy_audio": noisy,
-        "video_frames": bg_frames[np.newaxis, ...]
-    }
+    #data = {
+    #    "noisy_audio": noisy,
+    #    "video_frames": bg_frames[np.newaxis, ...]
+    #}
+
+    #with torch.no_grad():
+    #    estimated = avse_model.enhance(data).reshape(-1)
+    estimated_chunks = []
 
     with torch.no_grad():
-        estimated = avse_model.enhance(data).reshape(-1)
+        for i in range(min_len):
+            chunk_data = {
+                "noisy_audio": audio_chunks[i],
+                "video_frames": video_chunks[i][np.newaxis, ...]
+            }
+            est = avse_model.enhance(chunk_data).reshape(-1)
+            estimated_chunks.append(est)
+
+    estimated = torch.cat(estimated_chunks).cpu().numpy()
 
     # Save result
     tmp_wav = audio_path.replace(".wav", "_enhanced.wav")
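For context, here is a minimal standalone sketch of the chunking scheme this commit introduces. The sample rate (16 kHz) and frame rate (25 fps) are assumptions taken from the inline comments, and the 88×88 frame size and 10-second dummy inputs are hypothetical; the model call is omitted so only the audio/video chunk alignment is shown.

import numpy as np

SR = 16000                  # assumed audio sample rate, per the "3 sec at 16kHz" comment
FPS = 25                    # assumed video frame rate, per the "25fps × 3 sec" comment
CHUNK_SIZE_AUDIO = 3 * SR   # 48000 samples = 3 seconds of audio
CHUNK_SIZE_VIDEO = 3 * FPS  # 75 frames = 3 seconds of video

# Dummy 10-second inputs standing in for `noisy` and `bg_frames` in app.py
# (the 88x88 frame size is a placeholder, not taken from the repo).
noisy = np.zeros(10 * SR, dtype=np.float32)
bg_frames = np.zeros((10 * FPS, 88, 88), dtype=np.float32)

audio_chunks = [noisy[i:i + CHUNK_SIZE_AUDIO]
                for i in range(0, len(noisy), CHUNK_SIZE_AUDIO)]
video_chunks = [bg_frames[i:i + CHUNK_SIZE_VIDEO]
                for i in range(0, len(bg_frames), CHUNK_SIZE_VIDEO)]

# zip() stops at the shorter list, which is what min_len does in the commit:
# chunk i of audio and chunk i of video cover the same 3-second window, and a
# trailing chunk present in only one stream is dropped.
for a, v in zip(audio_chunks, video_chunks):
    print(a.shape, v.shape)  # e.g. (48000,) and (75, 88, 88); the last pair may be shorter

The chunks are cut on hard 3-second boundaries with no overlap, so the final pair can be shorter than 3 seconds, and the closing torch.cat in the commit assumes avse_model.enhance returns a torch tensor for each chunk.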