roychao19477 committed
Commit de425e9 · 1 Parent(s): 6792f52

Test on lengths
Files changed (1)
  1. app.py +38 -6
app.py CHANGED
@@ -82,6 +82,9 @@ avse_model.load_state_dict(avse_state_dict, strict=True)
 avse_model.to("cuda")
 avse_model.eval()
 
+CHUNK_SIZE_AUDIO = 48000 # 3 sec at 16kHz
+CHUNK_SIZE_VIDEO = 75 # 25fps × 3 sec
+
 @spaces.GPU
 def run_avse_inference(video_path, audio_path):
     estimated = run_avse(video_path, audio_path)
@@ -101,15 +104,39 @@ def run_avse_inference(video_path, audio_path):
     ]).astype(np.float32)
     bg_frames /= 255.0
 
+    audio_chunks = [
+        noisy[i:i + CHUNK_SIZE_AUDIO]
+        for i in range(0, len(noisy), CHUNK_SIZE_AUDIO)
+    ]
+
+    video_chunks = [
+        bg_frames[i:i + CHUNK_SIZE_VIDEO]
+        for i in range(0, len(bg_frames), CHUNK_SIZE_VIDEO)
+    ]
+
+    min_len = min(len(audio_chunks), len(video_chunks)) # sync length
+
 
     # Combine into input dict (match what model.enhance expects)
-    data = {
-        "noisy_audio": noisy,
-        "video_frames": bg_frames[np.newaxis, ...]
-    }
+    #data = {
+    #    "noisy_audio": noisy,
+    #    "video_frames": bg_frames[np.newaxis, ...]
+    #}
+
+    #with torch.no_grad():
+    #    estimated = avse_model.enhance(data).reshape(-1)
+    estimated_chunks = []
 
     with torch.no_grad():
-        estimated = avse_model.enhance(data).reshape(-1)
+        for i in range(min_len):
+            chunk_data = {
+                "noisy_audio": audio_chunks[i],
+                "video_frames": video_chunks[i][np.newaxis, ...]
+            }
+            est = avse_model.enhance(chunk_data).reshape(-1)
+            estimated_chunks.append(est)
+
+        estimated = np.concatenate(estimated_chunks, axis=0)
 
     # Save result
     tmp_wav = audio_path.replace(".wav", "_enhanced.wav")
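Note on the hunk above: enhancement now runs in fixed 3-second windows instead of a single pass over the whole clip. Below is a minimal standalone sketch of that chunking logic; make_synced_chunks, the 88×88 frame size, and the toy 7-second input are illustrative assumptions, not code from app.py. The shorter of the two streams decides how many chunk pairs are enhanced, so a short trailing chunk is still passed to the model while any unmatched tail is dropped.

import numpy as np

CHUNK_SIZE_AUDIO = 48000  # 3 sec of audio at 16 kHz
CHUNK_SIZE_VIDEO = 75     # 3 sec of video at 25 fps

def make_synced_chunks(noisy, bg_frames):
    # Cut the waveform and the frame stack into aligned 3-second pieces.
    audio_chunks = [noisy[i:i + CHUNK_SIZE_AUDIO]
                    for i in range(0, len(noisy), CHUNK_SIZE_AUDIO)]
    video_chunks = [bg_frames[i:i + CHUNK_SIZE_VIDEO]
                    for i in range(0, len(bg_frames), CHUNK_SIZE_VIDEO)]
    # The shorter stream limits how many pairs are processed (the diff's min_len).
    n = min(len(audio_chunks), len(video_chunks))
    return list(zip(audio_chunks[:n], video_chunks[:n]))

# Toy input: 7 s of audio and video -> 3 pairs, the last one shorter than 3 s.
noisy = np.zeros(7 * 16000, dtype=np.float32)
bg_frames = np.zeros((7 * 25, 88, 88), dtype=np.float32)  # frame size is a placeholder
pairs = make_synced_chunks(noisy, bg_frames)
print(len(pairs), pairs[-1][0].shape, pairs[-1][1].shape)  # 3 (16000,) (25, 88, 88)

Since the chunks are enhanced independently and simply concatenated, there is no overlap or cross-fade at the 3-second boundaries.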
 
@@ -135,6 +162,10 @@ def extract_resampled_audio(video_path, target_sr=16000):
     torchaudio.save(resampled_audio_path, waveform, sample_rate=target_sr)
     return resampled_audio_path
 
+@spaces.GPU
+def yolo_detection(frame, verbose=False):
+    return model(frame, verbose=verbose)[0]
+
 @spaces.GPU
 def extract_faces(video_file):
     cap = cv2.VideoCapture(video_file)
@@ -147,7 +178,8 @@ def extract_faces(video_file):
             break
 
         # Inference
-        results = model(frame, verbose=False)[0]
+        #results = model(frame, verbose=False)[0]
+        results = yolo_detection(frame, verbose=False)
         for box in results.boxes:
             # version 1
             # x1, y1, x2, y2 = map(int, box.xyxy[0])
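The last two hunks move the Ultralytics call into its own @spaces.GPU-decorated helper, presumably so the detector call itself is what gets GPU time on a ZeroGPU Space. A rough usage sketch of the same pattern in isolation; the yolov8n.pt weights and the dummy frame are placeholders, since app.py builds its own model object earlier in the file.

import cv2
import numpy as np
import spaces
from ultralytics import YOLO

model = YOLO("yolov8n.pt")  # placeholder weights; app.py loads its own detector

@spaces.GPU
def yolo_detection(frame, verbose=False):
    # Same wrapper as in the diff: only this decorated call runs on the GPU worker.
    return model(frame, verbose=verbose)[0]

frame = np.zeros((480, 640, 3), dtype=np.uint8)  # stand-in for a cv2.VideoCapture frame
results = yolo_detection(frame, verbose=False)
for box in results.boxes:  # empty on the dummy frame, real detections otherwise
    x1, y1, x2, y2 = map(int, box.xyxy[0])
    cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)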