#!/usr/bin/env python3
import gradio as gr
import torch
import torchaudio
import numpy as np
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import logging
# Constants and Configuration
SAMPLE_RATE = 16000
CHUNK_SECONDS = 30 # Split audio into 30-second chunks
CHUNK_SAMPLES = SAMPLE_RATE * CHUNK_SECONDS
MODEL_NAME = "openpecha/general_stt_base_model"
title = "# Tibetan Speech-to-Text"
description = """
This application transcribes Tibetan audio files using:
- Wav2Vec2 model fine-tuned on Garchen Rinpoche's teachings
- 30-second fixed chunking for long audio processing
"""
# Initialize model
def init_model():
    # Load Wav2Vec2 model and its processor
    model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME)
    processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
    model.eval()
    return model, processor
# Initialize model globally
model, processor = init_model()
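
# Optional extension (a sketch, not part of the original app): move the model to
# a GPU when one is available. The inputs built in process_audio would then also
# need to be moved to the same device before the forward pass.
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model = model.to(device)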
def process_audio(audio_path: str):
    if audio_path is None or audio_path == "":
        return "Please upload an audio file first"

    logging.info(f"Processing audio file: {audio_path}")

    try:
        # Load and resample audio to 16kHz mono
        wav, sr = torchaudio.load(audio_path)
        if sr != SAMPLE_RATE:
            wav = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(wav)
        wav = wav.mean(dim=0)  # convert to mono

        # Split audio into 30-second chunks
        audio_length = wav.shape[0]
        transcriptions = []
        for start_sample in range(0, audio_length, CHUNK_SAMPLES):
            end_sample = min(start_sample + CHUNK_SAMPLES, audio_length)

            # Extract chunk
            chunk = wav[start_sample:end_sample]

            # Skip processing if chunk is too short (less than 0.5 seconds)
            if chunk.shape[0] < 0.5 * SAMPLE_RATE:
                continue

            # Process chunk through the model
            inputs = processor(chunk, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True)
            with torch.no_grad():
                logits = model(**inputs).logits
            # Greedy CTC decoding: take the most likely token per frame, then let
            # the processor collapse repeats and strip blank tokens
            predicted_ids = torch.argmax(logits, dim=-1)
            transcription = processor.decode(predicted_ids[0])

            # Skip empty transcriptions
            if transcription.strip():
                transcriptions.append(transcription)

        if not transcriptions:
            return "No speech detected or recognized"

        # Join all chunk transcriptions into one text
        all_text = " ".join(transcriptions)
        return all_text

    except Exception as e:
        logging.error(f"Error processing audio: {str(e)}")
        return f"Error processing audio: {str(e)}"
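
# Quick way to exercise the transcription function without the Gradio UI
# (a hedged example; "sample.wav" is a placeholder path, not a file that ships
# with this Space):
#
#     print(process_audio("sample.wav"))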
demo = gr.Blocks()
with demo:
    gr.Markdown(title)

    with gr.Row():
        audio_input = gr.Audio(
            sources=["upload"],
            type="filepath",
            label="Upload audio file",
        )

    process_button = gr.Button("Transcribe Audio")

    with gr.Row():
        text_output = gr.Textbox(
            label="Transcription",
            placeholder="Transcribed text will appear here...",
            lines=8,
        )

    process_button.click(
        process_audio,
        inputs=[audio_input],
        outputs=[text_output],
    )

    gr.Markdown(description)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
demo.launch(share=True)