#!/usr/bin/env python3
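"""Gradio demo: Tibetan speech-to-text with a fine-tuned Wav2Vec2 model.

Long recordings are split into fixed 30-second chunks, each chunk is
transcribed independently with greedy CTC decoding, and the chunk
transcriptions are joined into the final text.

Run with (assuming this file is saved as app.py):
    python app.py
"""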

import gradio as gr
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import logging

# Constants and Configuration
SAMPLE_RATE = 16000
CHUNK_SECONDS = 30  # Split audio into 30-second chunks
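# 16,000 samples/s * 30 s = 480,000 samples per chunk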
CHUNK_SAMPLES = SAMPLE_RATE * CHUNK_SECONDS
MODEL_NAME = "openpecha/general_stt_base_model"

title = "# Tibetan Speech-to-Text"

description = """
This application transcribes Tibetan audio files using:
- Wav2Vec2 model fine-tuned on Garchen Rinpoche's teachings
- 30-second fixed chunking for long audio processing
"""

# Initialize model
def init_model():
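    """Load the Wav2Vec2 CTC model and its processor from the Hugging Face
    Hub and put the model in eval mode. Returns a (model, processor) tuple."""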
    # Load Wav2Vec2 model
    model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME)
    processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
    model.eval()
    
    return model, processor

# Initialize model globally
model, processor = init_model()

def process_audio(audio_path: str):
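    """Transcribe an uploaded audio file.

    The file is resampled to 16 kHz mono, split into 30-second chunks, and
    each chunk is transcribed independently; chunks shorter than 0.5 s are
    skipped. Returns the joined transcription, or an error message suitable
    for display in the UI.
    """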
    if not audio_path:
        return "Please upload an audio file first"

    logging.info(f"Processing audio file: {audio_path}")

    try:
        # Load and resample audio to 16kHz mono
        wav, sr = torchaudio.load(audio_path)
        if sr != SAMPLE_RATE:
            wav = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(wav)
        wav = wav.mean(dim=0)  # convert to mono

        # Split audio into 30-second chunks
        audio_length = wav.shape[0]
        transcriptions = []
        
        for start_sample in range(0, audio_length, CHUNK_SAMPLES):
            end_sample = min(start_sample + CHUNK_SAMPLES, audio_length)
            
            # Extract chunk
            chunk = wav[start_sample:end_sample]
            
            # Skip processing if chunk is too short (less than 0.5 seconds)
            if chunk.shape[0] < 0.5 * SAMPLE_RATE:
                continue
                
            # Process chunk through model
            # The feature extractor expects array-like input; pass the chunk as a NumPy array
            inputs = processor(chunk.numpy(), sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True)
            with torch.no_grad():
                logits = model(**inputs).logits
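            # Greedy CTC decoding: take the highest-scoring token per frame;
            # processor.decode collapses repeats and strips the CTC blank token.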
            predicted_ids = torch.argmax(logits, dim=-1)
            transcription = processor.decode(predicted_ids[0])
            
            # Skip empty transcriptions
            if transcription.strip():
                transcriptions.append(transcription)

        if not transcriptions:
            return "No speech detected or recognized"

        # Join all transcriptions
        all_text = " ".join(transcriptions)
        return all_text
        
    except Exception as e:
        logging.error(f"Error processing audio: {str(e)}")
        return f"Error processing audio: {str(e)}"

with gr.Blocks() as demo:
    gr.Markdown(title)

    with gr.Row():
        audio_input = gr.Audio(
            sources=["upload"],
            type="filepath",
            label="Upload audio file",
        )
    
    process_button = gr.Button("Transcribe Audio")
    
    with gr.Row():
        text_output = gr.Textbox(
            label="Transcription",
            placeholder="Transcribed text will appear here...",
            lines=8
        )

    process_button.click(
        process_audio,
        inputs=[audio_input],
        outputs=[text_output],
    )

    gr.Markdown(description)

if __name__ == "__main__":
    log_format = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)
    demo.launch(share=True)