import os
import sys
import logging

import torch
import gradio as gr
from transformers import pipeline

import spaces
from audio_processing import AudioProcessor
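
# Gradio app: transcribe audio with per-segment language detection (via
# AudioProcessor), optionally translate each segment to English, then
# summarize the transcript or answer questions about it with LLM pipelines.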

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)


def load_qa_model():
    """Load the instruction-tuned LLM used for Q&A (a text-generation pipeline)."""
    try:
        # Meta-Llama-3-8B-Instruct is a gated checkpoint: access must be granted
        # on the Hub, and HF_TOKEN must be set in the environment.
        qa_pipeline = pipeline(
            "text-generation",
            model="meta-llama/Meta-Llama-3-8B-Instruct",
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
            token=os.getenv("HF_TOKEN"),  # `use_auth_token` is deprecated in recent transformers
        )
        return qa_pipeline
    except Exception as e:
        logger.error(f"Failed to load Q&A model: {str(e)}")
        return None


def load_summarization_model():
    """Load the summarization model."""
    try:
        summarizer = pipeline(
            "summarization",
            model="sshleifer/distilbart-cnn-12-6",
            device=0 if torch.cuda.is_available() else -1  # GPU if present, else CPU
        )
        return summarizer
    except Exception as e:
        logger.error(f"Failed to load summarization model: {str(e)}")
        return None
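

# `spaces.GPU` targets Hugging Face ZeroGPU Spaces: each decorated call is
# scheduled onto a GPU for at most `duration` seconds, then the GPU is released.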
@spaces.GPU(duration=60)
def process_audio(audio_file, translate=False):
    """Transcribe an audio file and optionally translate each segment."""
    try:
        processor = AudioProcessor()
        language_segments, final_segments = processor.process_audio(audio_file, translate)

        transcription = ""
        full_text = ""

        # Header: which languages were detected, and where.
        for segment in language_segments:
            transcription += f"Language: {segment['language']}\n"
            transcription += f"Time: {segment['start']:.2f}s - {segment['end']:.2f}s\n\n"

        # Body: per-segment text, plus the translation when one was produced.
        transcription += "Transcription with language detection:\n\n"
        for segment in final_segments:
            transcription += f"[{segment['start']:.2f}s - {segment['end']:.2f}s] ({segment['language']}):\n"
            transcription += f"Original: {segment['text']}\n"
            if translate and 'translated' in segment:
                transcription += f"Translated: {segment['translated']}\n"
                full_text += segment['translated'] + " "
            else:
                full_text += segment['text'] + " "
            transcription += "\n"

        return transcription, full_text

    except Exception as e:
        logger.error(f"Audio processing failed: {str(e)}")
        raise gr.Error(f"Processing failed: {str(e)}")
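

# Example (hypothetical) direct call, bypassing the UI:
#   report, plain_text = process_audio("sample.wav", translate=True)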


@spaces.GPU(duration=60)
def summarize_text(text):
    """Summarize text."""
    try:
        summarizer = load_summarization_model()
        if summarizer is None:
            return "Summarization model could not be loaded."

        # Truncate inputs that exceed the model's context window rather than erroring.
        summary = summarizer(text, max_length=150, min_length=50, do_sample=False,
                             truncation=True)[0]['summary_text']
        return summary
    except Exception as e:
        logger.error(f"Summarization failed: {str(e)}")
        return "Error occurred during summarization."


@spaces.GPU(duration=60)
def answer_question(context, question):
    """Answer questions about the text."""
    try:
        qa_pipeline = load_qa_model()
        if qa_pipeline is None:
            return "Q&A model could not be loaded."

        messages = [
            {"role": "system", "content": "You are a helpful assistant who can answer questions based on the given context."},
            {"role": "user", "content": f"Context: {context}\n\nQuestion: {question}"}
        ]

        # With chat-format input, `generated_text` holds the whole conversation,
        # including the new assistant turn; return only that reply.
        response = qa_pipeline(messages, max_new_tokens=256)[0]['generated_text'][-1]['content']
        return response
    except Exception as e:
        logger.error(f"Q&A failed: {str(e)}")
        return f"Error occurred during Q&A process: {str(e)}"
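

# UI: audio in on the left, transcription out on the right; the "Full Text"
# box feeds both the summarizer and the Q&A handler below.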
with gr.Blocks() as iface:
    gr.Markdown("# Automatic Speech Recognition for Indic Languages")

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(type="filepath")
            translate_checkbox = gr.Checkbox(label="Enable Translation")
            process_button = gr.Button("Process Audio")

        with gr.Column():
            transcription_output = gr.Textbox(label="Transcription/Translation", lines=10)
            full_text_output = gr.Textbox(label="Full Text", lines=5)

    with gr.Row():
        with gr.Column():
            summarize_button = gr.Button("Summarize")
            summary_output = gr.Textbox(label="Summary", lines=3)

        with gr.Column():
            question_input = gr.Textbox(label="Ask a question about the transcription")
            answer_button = gr.Button("Get Answer")
            answer_output = gr.Textbox(label="Answer", lines=3)
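
    # Wire each button to its handler; summarization and Q&A read whatever is
    # currently in the "Full Text" box.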
    process_button.click(
        process_audio,
        inputs=[audio_input, translate_checkbox],
        outputs=[transcription_output, full_text_output]
    )

    summarize_button.click(
        summarize_text,
        inputs=[full_text_output],
        outputs=[summary_output]
    )

    answer_button.click(
        answer_question,
        inputs=[full_text_output, question_input],
        outputs=[answer_output]
    )
gr.Markdown(f""" |
|
## System Information |
|
- Device: {"CUDA" if torch.cuda.is_available() else "CPU"} |
|
- CUDA Available: {"Yes" if torch.cuda.is_available() else "No"} |
|
|
|
## Features |
|
- Automatic language detection |
|
- High-quality transcription using MMS |
|
- Optional translation to English |
|
- Text summarization |
|
- Question answering |
|
""") |
if __name__ == "__main__":
    iface.launch()