File size: 4,272 Bytes
487ed33
108107c
487ed33
 
f4c725a
108107c
487ed33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c3b43a8
 
 
 
487ed33
c3b43a8
 
487ed33
f4c725a
 
 
487ed33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108107c
487ed33
 
 
 
 
 
 
 
c3b43a8
487ed33
 
 
 
 
 
 
 
 
 
 
 
 
 
f4c725a
487ed33
 
 
 
176e214
 
487ed33
c3b43a8
487ed33
 
 
 
 
 
108107c
487ed33
 
 
 
 
 
 
 
 
 
108107c
 
487ed33
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import spaces
import gradio as gr
import logging
from pathlib import Path
import base64

from model import (
    MODEL_ID as WHISPER_MODEL_ID,
    PHI_MODEL_ID,
    transcribe_audio_local,
    transcribe_audio_phi,
    preload_models,
)

# Set up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Constants
# Directory holding the bundled example audio clips.
EXAMPLES_DIR = Path("./examples")
# Maps full model IDs (imported from model.py) to human-readable display names.
# Dict order also fixes the ordering of the dropdown choices in the UI.
MODEL_CHOICES = {
    PHI_MODEL_ID: "Phi-4 Model",
    WHISPER_MODEL_ID: "Whisper Model",
}
# Rows for gr.Examples: [audio file path, model ID] pairs.
# NOTE(review): assumes audio1.mp3/audio2.mp3 exist under ./examples — verify at deploy time.
EXAMPLE_FILES = [
    [str(EXAMPLES_DIR / "audio1.mp3"), PHI_MODEL_ID],
    [str(EXAMPLES_DIR / "audio2.mp3"), PHI_MODEL_ID],
]


def read_file_as_base64(file_path: str) -> str:
    """
    Return the contents of a file encoded as a base64 string.

    Args:
        file_path: Path to the file to read

    Returns:
        Base64 encoded string of file contents

    Raises:
        Exception: Re-raised after logging if the file cannot be read.
    """
    try:
        raw = Path(file_path).read_bytes()
        return base64.b64encode(raw).decode()
    except Exception as exc:
        # Log with the offending path before propagating to the caller.
        logger.error(f"Failed to read file {file_path}: {str(exc)}")
        raise


def combined_transcription(audio: str, model_choice: str) -> str:
    """
    Transcribe an audio file with whichever backend the user selected.

    Args:
        audio: Path to audio file
        model_choice: Full model ID to use for transcription

    Returns:
        Transcription text, or a human-readable error message on failure.
    """
    # Guard clause: nothing to transcribe.
    if not audio:
        return "Please provide an audio file to transcribe."

    try:
        # Dispatch table keeps the model-ID -> backend mapping in one place.
        backends = {
            PHI_MODEL_ID: transcribe_audio_phi,
            WHISPER_MODEL_ID: transcribe_audio_local,
        }
        backend = backends.get(model_choice)
        if backend is None:
            logger.error(f"Unknown model choice: {model_choice}")
            return f"Error: Unknown model {model_choice}"
        return backend(audio)
    except Exception as e:
        logger.error(f"Transcription failed: {str(e)}")
        return f"Error during transcription: {str(e)}"


def create_demo() -> gr.Blocks:
    """Build and return the Gradio Blocks interface for the ASR app."""
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        # Header and usage notes.
        gr.Markdown("# TWASR: Chinese (Taiwan) Automatic Speech Recognition")
        gr.Markdown(
            "Upload an audio file or record your voice to transcribe it to text."
        )
        gr.Markdown(
            "⚠️ First load may take a while to initialize the model, following requests will be faster."
        )

        # Input widgets: audio on the left, model picker + output on the right.
        with gr.Row():
            audio_in = gr.Audio(
                label="Audio Input", type="filepath", show_download_button=True
            )
            with gr.Column():
                model_dropdown = gr.Dropdown(
                    label="Select Model",
                    choices=list(MODEL_CHOICES),
                    value=PHI_MODEL_ID,
                    info="Select the model for transcription",
                )
                transcript_box = gr.Textbox(label="Transcription Output", lines=5)

        with gr.Row():
            run_btn = gr.Button("🎯 Transcribe", variant="primary")
            reset_btn = gr.Button("🧹 Clear")

        run_btn.click(
            fn=combined_transcription,
            inputs=[audio_in, model_dropdown],
            outputs=[transcript_box],
            show_progress=True,
        )

        def _reset():
            # Clear both the audio widget and the transcript text.
            return None, ""

        reset_btn.click(fn=_reset, inputs=[], outputs=[audio_in, transcript_box])

        # Clickable cached examples that run the same transcription function.
        gr.Examples(
            examples=EXAMPLE_FILES,
            inputs=[audio_in, model_dropdown],
            outputs=[transcript_box],
            fn=combined_transcription,
            cache_examples=True,
            cache_mode="lazy",
            run_on_click=True,
        )

        # Footer: link each model ID to its Hugging Face page.
        gr.Markdown("### Model Information")
        with gr.Accordion("Model Details", open=False):
            for model_id, display_name in MODEL_CHOICES.items():
                gr.Markdown(
                    f"**{display_name}:** [{model_id}](https://huggingface.co/{model_id})"
                )

    return demo


if __name__ == "__main__":
    # Preload models before starting the app to reduce cold start time
    logger.info("Preloading models to reduce cold start time")
    preload_models()

    demo = create_demo()
    # share=False: do not create a public Gradio share link; serve locally.
    demo.launch(share=False)