pratikshahp committed on
Commit f013d22 · verified · 1 Parent(s): 16d11ec

Update app.py

Files changed (1): app.py +18 -4
app.py CHANGED
@@ -1,15 +1,29 @@
  import torch
+ import torchaudio
+ from torchaudio.transforms import Resample
  from transformers import Speech2TextProcessor, Speech2TextForConditionalGeneration
  from audio_recorder_streamlit import audio_recorder
- import numpy as np
  import streamlit as st

+ def preprocess_audio(audio_bytes, sample_rate=16000):
+     # Load audio and convert to mono if necessary
+     waveform, _ = torchaudio.load(audio_bytes, normalize=True)
+     if waveform.size(0) > 1:
+         waveform = torch.mean(waveform, dim=0, keepdim=True)
+
+     # Resample if needed
+     if waveform.shape[1] != sample_rate:
+         resampler = Resample(orig_freq=waveform.shape[1], new_freq=sample_rate)
+         waveform = resampler(waveform)
+
+     return waveform
+
  def transcribe_audio(audio_bytes):
      model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-mustc-en-fr-st")
      processor = Speech2TextProcessor.from_pretrained("facebook/s2t-small-mustc-en-fr-st")

-     # Convert audio bytes to tensors
-     input_features = torch.tensor(audio_bytes).unsqueeze(0)  # Assuming audio_bytes is numpy array
+     # Preprocess audio
+     input_features = preprocess_audio(audio_bytes)

      # Generate transcription
      generated_ids = model.generate(input_features)
@@ -18,7 +32,7 @@ def transcribe_audio(audio_bytes):
      return translation

  st.title("Audio to Text Transcription..")
- audio_bytes = audio_recorder(pause_threshold=3.0, sample_rate=16_000)
+ audio_bytes = audio_recorder(pause_threshold=3.0, sample_rate=16000)
  if audio_bytes:
      st.audio(audio_bytes, format="audio/wav")
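For context, a minimal sketch of how the new preprocess_audio and transcribe_audio could be wired end to end. It assumes the recorder's WAV bytes are wrapped in io.BytesIO before torchaudio.load (which expects a path or file-like object), that the resampling check uses the sample rate returned by torchaudio.load, and that Speech2TextProcessor turns the waveform into the input_features the model consumes; the batch_decode call stands in for the decoding step that falls outside this diff. These details are illustrative assumptions, not part of the commit.

import io

import torch
import torchaudio
from torchaudio.transforms import Resample
from transformers import Speech2TextProcessor, Speech2TextForConditionalGeneration

def preprocess_audio(audio_bytes, sample_rate=16000):
    # Assumption: audio_bytes is the raw WAV data returned by audio_recorder,
    # so wrap it in a file-like object for torchaudio.load.
    waveform, orig_freq = torchaudio.load(io.BytesIO(audio_bytes), normalize=True)

    # Downmix to mono if the recording has more than one channel.
    if waveform.size(0) > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Resample when the recorded rate differs from the model's expected rate.
    if orig_freq != sample_rate:
        waveform = Resample(orig_freq=orig_freq, new_freq=sample_rate)(waveform)

    return waveform.squeeze(0)  # 1-D tensor of samples

def transcribe_audio(audio_bytes):
    model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-mustc-en-fr-st")
    processor = Speech2TextProcessor.from_pretrained("facebook/s2t-small-mustc-en-fr-st")

    waveform = preprocess_audio(audio_bytes)

    # Let the processor build the log-mel input features the model expects.
    inputs = processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt")
    generated_ids = model.generate(inputs["input_features"],
                                   attention_mask=inputs["attention_mask"])
    translation = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return translation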