# WhisperX Gradio Space: audio transcription, translation, summarization, and QA.
import gradio as gr
from audio_processing import process_audio, load_models
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering, pipeline
import spaces
import torch
import logging

# NOTE(review): `spaces` is imported for ZeroGPU support, but no function here is
# decorated with @spaces.GPU — confirm GPU allocation actually happens on Spaces.

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Pick the compute device once; all models below are moved to it.
cuda_available = torch.cuda.is_available()
device = "cuda" if cuda_available else "cpu"
logger.info(f"Using device: {device}")

# Lazily-loaded model handles; populated on first use by the load_* helpers.
summarizer_model = None
summarizer_tokenizer = None
qa_model = None
qa_tokenizer = None

# Load the Whisper model eagerly at startup so the first request is not slow;
# a failure here is fatal (re-raised) because the app is useless without it.
print("Loading Whisper model...")
try:
    load_models()  # Load Whisper model
except Exception as e:
    logger.error(f"Error loading Whisper model: {str(e)}")
    raise
print("Whisper model loaded successfully.")
def load_summarization_model():
    """Lazily load the DistilBART CNN summarization model and tokenizer.

    Populates the module-level ``summarizer_model``/``summarizer_tokenizer``
    globals on first call (moving the model to ``device``); later calls are
    no-ops.
    """
    global summarizer_model, summarizer_tokenizer
    if summarizer_model is None:
        logger.info("Loading summarization model...")
        summarizer_model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distilbart-cnn-12-6").to(device)
        summarizer_tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-cnn-12-6")
        logger.info("Summarization model loaded.")
def load_qa_model():
    """Lazily load the DistilBERT extractive question-answering model.

    Populates the module-level ``qa_model``/``qa_tokenizer`` globals on first
    call (moving the model to ``device``); later calls are no-ops.
    """
    global qa_model, qa_tokenizer
    if qa_model is None:
        logger.info("Loading QA model...")
        qa_model = AutoModelForQuestionAnswering.from_pretrained("distilbert-base-cased-distilled-squad").to(device)
        qa_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased-distilled-squad")
        logger.info("QA model loaded.")
def transcribe_audio(audio_file, translate, model_size, use_diarization):
    """Transcribe (and optionally translate/diarize) an audio file.

    Delegates to ``audio_processing.process_audio`` and formats its segments.

    Args:
        audio_file: Path to the audio file (Gradio ``type="filepath"``).
        translate: If True, include translated text per segment.
        model_size: Whisper model size name (e.g. "small", "large-v2").
        use_diarization: If True, include per-segment speaker labels.

    Returns:
        A ``(formatted_report, full_text)`` tuple: a human-readable segment
        listing, and the concatenated plain text (translated text when
        ``translate`` is on, original text otherwise).
    """
    language_segments, final_segments = process_audio(audio_file, translate=translate, model_size=model_size, use_diarization=use_diarization)

    output = "Detected language changes:\n\n"
    for segment in language_segments:
        output += f"Language: {segment['language']}\n"
        output += f"Time: {segment['start']:.2f}s - {segment['end']:.2f}s\n\n"

    # Header for the per-segment transcription section.
    output += f"Transcription with language detection {'and speaker diarization' if use_diarization else ''} (using {model_size} model):\n\n"

    full_text = ""
    for segment in final_segments:
        output += f"[{segment['start']:.2f}s - {segment['end']:.2f}s] ({segment['language']})"
        if use_diarization:
            output += f" {segment['speaker']}:"
        output += f"\nOriginal: {segment['text']}\n"
        if translate:
            output += f"Translated: {segment['translated']}\n"
            full_text += segment['translated'] + " "
        else:
            full_text += segment['text'] + " "
        output += "\n"
    return output, full_text
def summarize_text(text):
    """Summarize ``text`` with the lazily-loaded DistilBART model.

    Input is truncated to the model's 1024-token limit; generation is greedy
    (``do_sample=False``) with a 50–150 token summary length.

    Returns:
        The decoded summary string.
    """
    load_summarization_model()
    inputs = summarizer_tokenizer(text, max_length=1024, truncation=True, return_tensors="pt").to(device)
    summary_ids = summarizer_model.generate(inputs["input_ids"], max_length=150, min_length=50, do_sample=False)
    summary = summarizer_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary
def answer_question(context, question):
    """Answer ``question`` from ``context`` with the extractive QA model.

    Picks the argmax start/end logits and decodes that token span from the
    input — the answer is always a verbatim excerpt of ``context`` (or empty
    if the predicted span is degenerate, i.e. end <= start).

    Returns:
        The decoded answer span as a string.
    """
    load_qa_model()
    inputs = qa_tokenizer(question, context, return_tensors="pt").to(device)
    outputs = qa_model(**inputs)
    answer_start = torch.argmax(outputs.start_logits)
    answer_end = torch.argmax(outputs.end_logits) + 1  # +1: slice end is exclusive
    answer = qa_tokenizer.decode(inputs["input_ids"][0][answer_start:answer_end])
    return answer
def process_and_summarize(audio_file, translate, model_size, use_diarization):
    """Transcribe an audio file and summarize the resulting text.

    Returns:
        ``(transcription_report, summary)`` for the two Gradio outputs.
    """
    transcription, full_text = transcribe_audio(audio_file, translate, model_size, use_diarization)
    summary = summarize_text(full_text)
    return transcription, summary
def qa_interface(audio_file, translate, model_size, use_diarization, question):
    """Transcribe an audio file, then answer ``question`` from the transcript.

    The formatted transcription report is discarded; only the plain full text
    is used as QA context.

    Returns:
        The answer string for the Gradio output textbox.
    """
    _, full_text = transcribe_audio(audio_file, translate, model_size, use_diarization)
    answer = answer_question(full_text, question)
    return answer
# Main interface: two tabs (transcribe+summarize, question answering) plus a
# system-information footer, then a blocking launch.
with gr.Blocks() as iface:
    gr.Markdown("# WhisperX Audio Transcription, Translation, Summarization, and QA (with ZeroGPU support)")

    with gr.Tab("Transcribe and Summarize"):
        audio_input = gr.Audio(type="filepath")
        translate_checkbox = gr.Checkbox(label="Enable Translation")
        model_dropdown = gr.Dropdown(choices=["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"], label="Whisper Model Size", value="small")
        diarization_checkbox = gr.Checkbox(label="Enable Speaker Diarization")
        transcribe_button = gr.Button("Transcribe and Summarize")
        transcription_output = gr.Textbox(label="Transcription")
        summary_output = gr.Textbox(label="Summary")
        transcribe_button.click(
            process_and_summarize,
            inputs=[audio_input, translate_checkbox, model_dropdown, diarization_checkbox],
            outputs=[transcription_output, summary_output]
        )

    with gr.Tab("Question Answering"):
        qa_audio_input = gr.Audio(type="filepath")
        qa_translate_checkbox = gr.Checkbox(label="Enable Translation")
        qa_model_dropdown = gr.Dropdown(choices=["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"], label="Whisper Model Size", value="small")
        qa_diarization_checkbox = gr.Checkbox(label="Enable Speaker Diarization")
        question_input = gr.Textbox(label="Ask a question about the audio")
        qa_button = gr.Button("Get Answer")
        answer_output = gr.Textbox(label="Answer")
        qa_button.click(
            qa_interface,
            inputs=[qa_audio_input, qa_translate_checkbox, qa_model_dropdown, qa_diarization_checkbox, question_input],
            outputs=answer_output
        )

    gr.Markdown(
        f"""
        ## System Information
        - Device: {device}
        - CUDA Available: {"Yes" if cuda_available else "No"}
        ## ZeroGPU Support
        This application supports ZeroGPU for Hugging Face Spaces pro users.
        GPU-intensive tasks are automatically optimized for better performance when available.
        """
    )

iface.launch()