import logging

import gradio as gr
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
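# Fixed chunking parameters: audio is resampled to 16 kHz and split into
# 30-second windows so long recordings fit the model.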
SAMPLE_RATE = 16000
CHUNK_SECONDS = 30
CHUNK_SAMPLES = SAMPLE_RATE * CHUNK_SECONDS
MODEL_NAME = "openpecha/general_stt_base_model"
|
title = "# Tibetan Speech-to-Text" |
|
description = """
This application transcribes Tibetan audio files using:
- a Wav2Vec2 model fine-tuned on Garchen Rinpoche's teachings
- fixed 30-second chunking so long recordings can be processed piece by piece
"""
|
|
def init_model(): |
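    """Load the fine-tuned Wav2Vec2 CTC model and its processor."""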
    model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME)
    processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
    model.eval()  # inference mode: disables dropout
    return model, processor
|
|
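# Load the model once at import time so every request reuses the same weights.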
model, processor = init_model() |
|
|
def process_audio(audio_path: str) -> str:
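    """Transcribe an uploaded audio file in 30-second chunks and join the text."""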
    if not audio_path:
        return "Please upload an audio file first"

    logging.info(f"Processing audio file: {audio_path}")

    try:
        # Load the file, resample to the model's 16 kHz rate, and downmix to mono.
        wav, sr = torchaudio.load(audio_path)
        if sr != SAMPLE_RATE:
            wav = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(wav)
        wav = wav.mean(dim=0)

        audio_length = wav.shape[0]
        transcriptions = []

        for start_sample in range(0, audio_length, CHUNK_SAMPLES):
            end_sample = min(start_sample + CHUNK_SAMPLES, audio_length)
            chunk = wav[start_sample:end_sample]

            # Skip trailing fragments shorter than half a second; they rarely
            # contain decodable speech.
            if chunk.shape[0] < 0.5 * SAMPLE_RATE:
                continue

            inputs = processor(
                chunk.numpy(), sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True
            )
            with torch.no_grad():
                logits = model(**inputs).logits
            # Greedy CTC decoding: pick the most likely token per frame.
            predicted_ids = torch.argmax(logits, dim=-1)
            transcription = processor.decode(predicted_ids[0])

            if transcription.strip():
                transcriptions.append(transcription)

        if not transcriptions:
            return "No speech detected or recognized"

        return " ".join(transcriptions)

    except Exception as e:
        logging.error(f"Error processing audio: {e}")
        return f"Error processing audio: {e}"
|
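# For batch use without the web UI, process_audio can be called directly;
# the path below is only a hypothetical example:
#   print(process_audio("samples/teaching_01.wav"))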
|
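# Gradio UI: audio upload in, transcribed text out.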
with gr.Blocks() as demo:
    gr.Markdown(title)

    with gr.Row():
        audio_input = gr.Audio(
            sources=["upload"],
            type="filepath",
            label="Upload audio file",
        )

    process_button = gr.Button("Transcribe Audio")

    with gr.Row():
        text_output = gr.Textbox(
            label="Transcription",
            placeholder="Transcribed text will appear here...",
            lines=8,
        )

    process_button.click(
        process_audio,
        inputs=[audio_input],
        outputs=[text_output],
    )

    gr.Markdown(description)
|
|
if __name__ == "__main__":
    log_format = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)
    demo.launch(share=True)