wrice committed on
Commit
4f912e8
·
1 Parent(s): 316bc64

handle multi-channel audio

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -22,12 +22,13 @@ def denoise(model_name, inputs):
22
  sr, audio = inputs
23
  audio = torch.from_numpy(audio)[None]
24
  audio = audio / 32768.0
 
25
 
26
  print(f"Audio shape: {audio.shape}")
27
  print(f"Sample rate: {sr}")
28
 
29
- if audio.shape[-1] > 1:
30
- audio = audio.mean(-1, keepdim=True)
31
 
32
  print(f"Audio shape: {audio.shape}")
33
 
@@ -41,9 +42,9 @@ def denoise(model_name, inputs):
41
 
42
  clean = []
43
  for i in tqdm(range(0, padded.shape[-1], chunk_size)):
44
- audio_chunk = padded[:, i : i + chunk_size]
45
  with torch.no_grad():
46
- clean_chunk = model(audio_chunk[None]).logits
47
  clean.append(clean_chunk.squeeze(0))
48
 
49
  denoised = torch.concat(clean).flatten()[: audio.shape[-1]].clamp(-1.0, 1.0)
 
22
  sr, audio = inputs
23
  audio = torch.from_numpy(audio)[None]
24
  audio = audio / 32768.0
25
+ audio = audio.permute(0, 2, 1)
26
 
27
  print(f"Audio shape: {audio.shape}")
28
  print(f"Sample rate: {sr}")
29
 
30
+ if audio.shape[1] > 1:
31
+ audio = audio.mean(1, keepdim=True)
32
 
33
  print(f"Audio shape: {audio.shape}")
34
 
 
42
 
43
  clean = []
44
  for i in tqdm(range(0, padded.shape[-1], chunk_size)):
45
+ audio_chunk = padded[:, :, i : i + chunk_size]
46
  with torch.no_grad():
47
+ clean_chunk = model(audio_chunk).logits
48
  clean.append(clean_chunk.squeeze(0))
49
 
50
  denoised = torch.concat(clean).flatten()[: audio.shape[-1]].clamp(-1.0, 1.0)