File size: 6,217 Bytes
c569b48
a314490
6db9237
aaac499
 
3a346c4
 
 
 
aaac499
 
 
9148c64
3a346c4
aaac499
6db9237
 
 
 
 
3a346c4
6db9237
 
3a346c4
6db9237
3a346c4
6db9237
3a346c4
 
6db9237
 
 
 
 
 
 
 
 
aaac499
6db9237
 
 
 
 
 
 
aaac499
64f2bf5
6db9237
 
81e4ee2
 
 
 
 
f36e52e
6db9237
aaac499
81e4ee2
6db9237
 
 
 
aaac499
 
 
 
 
 
 
 
 
64f2bf5
aaac499
6db9237
9148c64
 
 
aaac499
 
64f2bf5
aaac499
6db9237
9148c64
 
 
 
 
 
aaac499
64f2bf5
6db9237
 
aaac499
 
 
64f2bf5
6db9237
 
aaac499
 
 
 
 
 
 
 
 
 
 
6db9237
aaac499
 
 
 
 
 
6db9237
aaac499
 
 
 
 
 
 
6db9237
aaac499
 
 
 
 
 
6db9237
aaac499
 
 
 
3a346c4
 
 
 
 
aaac499
 
3a346c4
aaac499
 
f36e52e
81e4ee2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import gradio as gr
from audio_processing import process_audio, load_models
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering, pipeline
import spaces
import torch
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Pick the compute device once at import time; all models are moved to it.
cuda_available = torch.cuda.is_available()
device = "cuda" if cuda_available else "cpu"
logger.info(f"Using device: {device}")

# Lazily-loaded model handles, populated on first use by the
# load_*_model() helpers below (keeps startup light on ZeroGPU).
summarizer_model = None
summarizer_tokenizer = None
qa_model = None
qa_tokenizer = None

# Load the Whisper model eagerly at startup so the first transcription
# request doesn't pay the load cost. Re-raise on failure: the app is
# unusable without it.
# NOTE: use logger (not print) so startup messages share one log stream.
logger.info("Loading Whisper model...")
try:
    load_models()  # Load Whisper model
except Exception as e:
    logger.error(f"Error loading Whisper model: {str(e)}")
    raise

logger.info("Whisper model loaded successfully.")

def load_summarization_model():
    """Lazily load the DistilBART summarization model/tokenizer onto `device`.

    No-op if the model is already loaded; mutates the module-level handles.
    """
    global summarizer_model, summarizer_tokenizer
    if summarizer_model is not None:
        return
    logger.info("Loading summarization model...")
    checkpoint = "sshleifer/distilbart-cnn-12-6"
    summarizer_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to(device)
    summarizer_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    logger.info("Summarization model loaded.")

def load_qa_model():
    """Lazily load the DistilBERT extractive-QA model/tokenizer onto `device`.

    No-op if the model is already loaded; mutates the module-level handles.
    """
    global qa_model, qa_tokenizer
    if qa_model is not None:
        return
    logger.info("Loading QA model...")
    checkpoint = "distilbert-base-cased-distilled-squad"
    qa_model = AutoModelForQuestionAnswering.from_pretrained(checkpoint).to(device)
    qa_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    logger.info("QA model loaded.")

@spaces.GPU
def transcribe_audio(audio_file, translate, model_size, use_diarization):
    """Transcribe an audio file, with optional translation and diarization.

    Args:
        audio_file: path to the audio file (Gradio `filepath` input).
        translate: if True, include and collect translated text.
        model_size: Whisper model size name (e.g. "small", "large-v2").
        use_diarization: if True, annotate segments with speaker labels.

    Returns:
        (report, full_text): a human-readable formatted report and the
        plain concatenated transcript (translated text when `translate`).
    """
    language_segments, final_segments = process_audio(
        audio_file, translate=translate, model_size=model_size, use_diarization=use_diarization
    )

    output = "Detected language changes:\n\n"
    for segment in language_segments:
        output += f"Language: {segment['language']}\n"
        output += f"Time: {segment['start']:.2f}s - {segment['end']:.2f}s\n\n"

    # Fix: the original embedded a nested f-string with no placeholders and
    # left a double space in the header when diarization was disabled.
    diarization_note = " and speaker diarization" if use_diarization else ""
    output += f"Transcription with language detection{diarization_note} (using {model_size} model):\n\n"
    full_text = ""
    for segment in final_segments:
        output += f"[{segment['start']:.2f}s - {segment['end']:.2f}s] ({segment['language']})"
        if use_diarization:
            output += f" {segment['speaker']}:"
        output += f"\nOriginal: {segment['text']}\n"
        if translate:
            output += f"Translated: {segment['translated']}\n"
            full_text += segment['translated'] + " "
        else:
            full_text += segment['text'] + " "
        output += "\n"

    return output, full_text

@spaces.GPU
def summarize_text(text):
    """Summarize `text` with the lazily-loaded DistilBART model.

    Input is truncated to the model's 1024-token limit; generation is
    greedy (do_sample=False) with a 50–150 token summary length.
    """
    load_summarization_model()
    encoded = summarizer_tokenizer(
        text, max_length=1024, truncation=True, return_tensors="pt"
    ).to(device)
    generated_ids = summarizer_model.generate(
        encoded["input_ids"], max_length=150, min_length=50, do_sample=False
    )
    return summarizer_tokenizer.decode(generated_ids[0], skip_special_tokens=True)

@spaces.GPU
def answer_question(context, question):
    """Answer `question` from `context` via extractive QA (DistilBERT/SQuAD).

    Returns the decoded answer span; may be empty if the model picks an
    empty or inverted span.
    """
    load_qa_model()
    inputs = qa_tokenizer(question, context, return_tensors="pt").to(device)
    # Fix: run inference under no_grad — the original built an autograd
    # graph for every request, wasting GPU memory for no benefit.
    with torch.no_grad():
        outputs = qa_model(**inputs)
    # Most-likely start/end token positions; end is exclusive for slicing.
    answer_start = torch.argmax(outputs.start_logits)
    answer_end = torch.argmax(outputs.end_logits) + 1
    answer = qa_tokenizer.decode(inputs["input_ids"][0][answer_start:answer_end])
    return answer

@spaces.GPU
def process_and_summarize(audio_file, translate, model_size, use_diarization):
    """Transcribe the audio, then summarize the resulting transcript.

    Returns (formatted transcription report, summary string).
    """
    transcription, plain_text = transcribe_audio(
        audio_file, translate, model_size, use_diarization
    )
    return transcription, summarize_text(plain_text)

@spaces.GPU
def qa_interface(audio_file, translate, model_size, use_diarization, question):
    """Transcribe the audio and answer `question` against the transcript."""
    _, transcript = transcribe_audio(audio_file, translate, model_size, use_diarization)
    return answer_question(transcript, question)

# Main interface.
# Single source of truth for the Whisper size choices — the original
# duplicated this list verbatim in both tabs, risking drift.
WHISPER_MODEL_SIZES = ["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"]

with gr.Blocks() as iface:
    gr.Markdown("# WhisperX Audio Transcription, Translation, Summarization, and QA (with ZeroGPU support)")
    
    with gr.Tab("Transcribe and Summarize"):
        audio_input = gr.Audio(type="filepath")
        translate_checkbox = gr.Checkbox(label="Enable Translation")
        model_dropdown = gr.Dropdown(choices=WHISPER_MODEL_SIZES, label="Whisper Model Size", value="small")
        diarization_checkbox = gr.Checkbox(label="Enable Speaker Diarization")
        transcribe_button = gr.Button("Transcribe and Summarize")
        transcription_output = gr.Textbox(label="Transcription")
        summary_output = gr.Textbox(label="Summary")
        
        transcribe_button.click(
            process_and_summarize,
            inputs=[audio_input, translate_checkbox, model_dropdown, diarization_checkbox],
            outputs=[transcription_output, summary_output]
        )
    
    with gr.Tab("Question Answering"):
        qa_audio_input = gr.Audio(type="filepath")
        qa_translate_checkbox = gr.Checkbox(label="Enable Translation")
        qa_model_dropdown = gr.Dropdown(choices=WHISPER_MODEL_SIZES, label="Whisper Model Size", value="small")
        qa_diarization_checkbox = gr.Checkbox(label="Enable Speaker Diarization")
        question_input = gr.Textbox(label="Ask a question about the audio")
        qa_button = gr.Button("Get Answer")
        answer_output = gr.Textbox(label="Answer")
        
        qa_button.click(
            qa_interface,
            inputs=[qa_audio_input, qa_translate_checkbox, qa_model_dropdown, qa_diarization_checkbox, question_input],
            outputs=answer_output
        )

    gr.Markdown(
        f"""
        ## System Information
        - Device: {device}
        - CUDA Available: {"Yes" if cuda_available else "No"}
        
        ## ZeroGPU Support
        This application supports ZeroGPU for Hugging Face Spaces pro users. 
        GPU-intensive tasks are automatically optimized for better performance when available.
        """
    )

iface.launch()