import io

import torch
import torchaudio
from torchaudio.transforms import Resample
from transformers import Speech2TextProcessor, Speech2TextForConditionalGeneration
from audio_recorder_streamlit import audio_recorder
import streamlit as st
def preprocess_audio(audio_bytes, target_sample_rate=16000):
    # torchaudio.load expects a path or file-like object, so wrap the raw WAV bytes
    waveform, orig_sample_rate = torchaudio.load(io.BytesIO(audio_bytes), normalize=True)
    # Convert to mono if necessary
    if waveform.size(0) > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    # Resample if the recording rate differs from the rate the model expects
    if orig_sample_rate != target_sample_rate:
        resampler = Resample(orig_freq=orig_sample_rate, new_freq=target_sample_rate)
        waveform = resampler(waveform)
    return waveform
def transcribe_audio(audio_bytes):
    # Note: this checkpoint translates English speech into French text
    model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-mustc-en-fr-st")
    processor = Speech2TextProcessor.from_pretrained("facebook/s2t-small-mustc-en-fr-st")
    # Preprocess audio and extract input features with the processor
    waveform = preprocess_audio(audio_bytes)
    inputs = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
    # Generate output token ids and decode them to text
    generated_ids = model.generate(inputs["input_features"], attention_mask=inputs["attention_mask"])
    translation = processor.batch_decode(generated_ids, skip_special_tokens=True)
    return translation
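
# Optional sketch (not in the original file): loading the model and processor inside
# transcribe_audio re-initializes them on every Streamlit rerun. A common pattern is
# to cache them with st.cache_resource and call this helper from transcribe_audio instead.
@st.cache_resource
def load_model_and_processor():
    model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-mustc-en-fr-st")
    processor = Speech2TextProcessor.from_pretrained("facebook/s2t-small-mustc-en-fr-st")
    return model, processor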
st.title("Audio to Text Transcription")
audio_bytes = audio_recorder(pause_threshold=3.0, sample_rate=16000)
if audio_bytes:
    st.audio(audio_bytes, format="audio/wav")
    transcription = transcribe_audio(audio_bytes)
    if transcription:
        st.write("Transcription:")
        st.write(transcription)
    else:
        st.write("Error: Failed to transcribe audio.")
else:
    st.write("No audio recorded.")