# TWASR / app.py
# Gradio demo app for Chinese (Taiwan) automatic speech recognition.
# (Header reconstructed from Hugging Face Space page metadata; original
# commit: "fix: reorder model choices and update example files to use PHI model")
import spaces
import gradio as gr
import logging
from pathlib import Path
import base64
from model import (
MODEL_ID as WHISPER_MODEL_ID,
PHI_MODEL_ID,
transcribe_audio_local,
transcribe_audio_phi,
preload_models,
)
# Set up logging: timestamped, INFO-level records for the whole process.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)  # module-level logger, standard convention
# Constants
# Directory holding the bundled example audio clips shown in the UI.
EXAMPLES_DIR = Path("./examples")
# Maps full model IDs (imported from model.py) to human-readable labels.
# Insertion order matters: PHI first makes it the first dropdown choice.
MODEL_CHOICES = {
    PHI_MODEL_ID: "Phi-4 Model",
    WHISPER_MODEL_ID: "Whisper Model",
}
# [audio path, model id] pairs fed to gr.Examples; both default to the PHI model.
EXAMPLE_FILES = [
    [str(EXAMPLES_DIR / "audio1.mp3"), PHI_MODEL_ID],
    [str(EXAMPLES_DIR / "audio2.mp3"), PHI_MODEL_ID],
]
def read_file_as_base64(file_path: str) -> str:
    """Encode the contents of a file as a base64 ASCII string.

    Args:
        file_path: Path to the file to read.

    Returns:
        Base64 encoded string of the file's raw bytes.

    Raises:
        OSError: Propagated (after logging) when the file cannot be read.
    """
    # NOTE(review): not referenced elsewhere in this file — presumably kept
    # for external callers; confirm before removing.
    try:
        with open(file_path, "rb") as handle:
            encoded = base64.b64encode(handle.read())
        return encoded.decode()
    except Exception as err:
        logger.error(f"Failed to read file {file_path}: {str(err)}")
        raise
def combined_transcription(audio: str, model_choice: str) -> str:
    """
    Transcribe audio using the selected model.

    Args:
        audio: Path to the audio file (None/"" when no input was provided).
        model_choice: Full model ID to use for transcription.

    Returns:
        Transcription text, or a human-readable error message. This function
        never raises, so the Gradio UI always receives a displayable string.
    """
    # Guard clause: Gradio passes None/"" when no audio was recorded/uploaded.
    if not audio:
        return "Please provide an audio file to transcribe."
    # Dispatch table keeps model routing in one place, mirroring MODEL_CHOICES.
    transcribers = {
        PHI_MODEL_ID: transcribe_audio_phi,
        WHISPER_MODEL_ID: transcribe_audio_local,
    }
    transcribe = transcribers.get(model_choice)
    if transcribe is None:
        logger.error("Unknown model choice: %s", model_choice)
        return f"Error: Unknown model {model_choice}"
    try:
        return transcribe(audio)
    except Exception as e:
        # logger.exception records the full traceback; the previous
        # logger.error(f"...") discarded it, hiding the failure's origin.
        logger.exception("Transcription failed: %s", e)
        return f"Error during transcription: {str(e)}"
def create_demo() -> gr.Blocks:
    """Create and configure the Gradio demo interface.

    Layout (top to bottom): title/usage markdown, an audio-input row with a
    model dropdown, the transcription output box, action buttons, clickable
    examples, and a collapsible model-info section.

    Returns:
        The assembled gr.Blocks app, ready for .launch().
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# TWASR: Chinese (Taiwan) Automatic Speech Recognition")
        gr.Markdown(
            "Upload an audio file or record your voice to transcribe it to text."
        )
        gr.Markdown(
            "⚠️ First load may take a while to initialize the model, following requests will be faster."
        )
        with gr.Row():
            # type="filepath" hands the handler a temp-file path, matching
            # combined_transcription's `audio: str` parameter.
            audio_input = gr.Audio(
                label="Audio Input", type="filepath", show_download_button=True
            )
            with gr.Column():
                # Choices are full model IDs (dict insertion order: PHI first);
                # PHI is also the preselected default.
                model_choice = gr.Dropdown(
                    label="Select Model",
                    choices=list(MODEL_CHOICES.keys()),
                    value=PHI_MODEL_ID,
                    info="Select the model for transcription",
                )
        text_output = gr.Textbox(label="Transcription Output", lines=5)
        with gr.Row():
            transcribe_button = gr.Button("🎯 Transcribe", variant="primary")
            clear_button = gr.Button("🧹 Clear")
        transcribe_button.click(
            fn=combined_transcription,
            inputs=[audio_input, model_choice],
            outputs=[text_output],
            show_progress=True,
        )
        # Reset both the audio input and the output text in one click.
        clear_button.click(
            fn=lambda: (None, ""),
            inputs=[],
            outputs=[audio_input, text_output],
        )
        # Clicking an example runs the transcription immediately;
        # cache_mode="lazy" defers computing cached outputs until first use
        # (see Gradio Examples docs — confirm exact caching semantics).
        gr.Examples(
            examples=EXAMPLE_FILES,
            inputs=[audio_input, model_choice],
            outputs=[text_output],
            fn=combined_transcription,
            cache_examples=True,
            cache_mode="lazy",
            run_on_click=True,
        )
        gr.Markdown("### Model Information")
        with gr.Accordion("Model Details", open=False):
            # One markdown line per model, linking its Hugging Face model page.
            for model_id, model_name in MODEL_CHOICES.items():
                gr.Markdown(
                    f"**{model_name}:** [{model_id}](https://huggingface.co/{model_id})"
                )
    return demo
if __name__ == "__main__":
    # Preload models before starting the app to reduce cold start time
    logger.info("Preloading models to reduce cold start time")
    preload_models()
    demo = create_demo()
    # share=False: serve locally only, no public Gradio share link.
    demo.launch(share=False)