wli3221134 commited on
Commit
af54972
verified
1 Parent(s): 4151dd8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -64,7 +64,10 @@ def detect_on_gpu(audio_path):
64
  with torch.no_grad():
65
  for batch_idx, batch in enumerate(audio_dataset):
66
  print(f"\n澶勭悊鎵规 {batch_idx + 1}")
67
- print('waveforms shape:', batch['waveforms'].shape)
 
 
 
68
  waveforms = batch['waveforms'].numpy() # [B, T]
69
  features = feature_extractor(waveforms, sampling_rate=16000, return_attention_mask=True, padding_value=0, return_tensors="pt").to(device)
70
  outputs = model(features)
 
64
  with torch.no_grad():
65
  for batch_idx, batch in enumerate(audio_dataset):
66
  print(f"\n澶勭悊鎵规 {batch_idx + 1}")
67
+ if len(batch['waveforms'].shape) == 1:
68
+ batch['waveforms'] = batch['waveforms'].unsqueeze(0)
69
+
70
+ print('shape:', batch['waveforms'].shape)
71
  waveforms = batch['waveforms'].numpy() # [B, T]
72
  features = feature_extractor(waveforms, sampling_rate=16000, return_attention_mask=True, padding_value=0, return_tensors="pt").to(device)
73
  outputs = model(features)