roychao19477 commited on
Commit
678d466
·
1 Parent(s): 6e97a1b

Version revise

Browse files
Files changed (1) hide show
  1. app.py +5 -38
app.py CHANGED
@@ -103,46 +103,13 @@ def run_avse_inference(video_path, audio_path):
103
 
104
 
105
  # Combine into input dict (match what model.enhance expects)
106
- #data = {
107
- # "noisy_audio": noisy,
108
- # "video_frames": bg_frames[np.newaxis, ...]
109
- #}
110
 
111
  with torch.no_grad():
112
- # Version 1
113
- #estimated = avse_model.enhance(data).reshape(-1)
114
- # Version 2
115
- chunk_sec = 6
116
- sr = 16000
117
- audio_chunk_len = chunk_sec * sr # 48000
118
- video_chunk_len = chunk_sec * 25 # 75
119
-
120
- estimated_chunks = []
121
-
122
- for i in range(0, len(noisy), audio_chunk_len):
123
- audio_chunk = noisy[i:i+audio_chunk_len]
124
- if len(audio_chunk) < audio_chunk_len:
125
- pad = np.zeros(audio_chunk_len - len(audio_chunk), dtype=audio_chunk.dtype)
126
- audio_chunk = np.concatenate([audio_chunk, pad])
127
-
128
- vid_idx = i // sr * 25 # convert audio index to video frame index
129
- #video_chunk = bg_frames[0, vid_idx:vid_idx+video_chunk_len, :, :]
130
- video_chunk = bg_frames[vid_idx:vid_idx+video_chunk_len, :, :]
131
- if video_chunk.shape[0] < video_chunk_len:
132
- pad = np.zeros((video_chunk_len - video_chunk.shape[0], *video_chunk.shape[1:]), dtype=video_chunk.dtype)
133
- video_chunk = np.concatenate([video_chunk, pad], axis=0)
134
-
135
- data = {
136
- "noisy_audio": audio_chunk,
137
- "video_frames": video_chunk[np.newaxis, ...]
138
- }
139
-
140
- with torch.no_grad():
141
- out = avse_model.enhance(data).reshape(-1)
142
- estimated_chunks.append(out)
143
-
144
- estimated = np.concatenate(estimated_chunks)[:len(noisy)]
145
-
146
 
147
  # Save result
148
  tmp_wav = audio_path.replace(".wav", "_enhanced.wav")
 
103
 
104
 
105
  # Combine into input dict (match what model.enhance expects)
106
+ data = {
107
+ "noisy_audio": noisy,
108
+ "video_frames": bg_frames[np.newaxis, ...]
109
+ }
110
 
111
  with torch.no_grad():
112
+ estimated = avse_model.enhance(data).reshape(-1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
  # Save result
115
  tmp_wav = audio_path.replace(".wav", "_enhanced.wav")