|
import gradio as gr |
|
import torch |
|
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, pipeline |
|
import soundfile as sf |
|
|
|
|
|
|
|
# Checkpoint shared by the speech processor and the CTC acoustic model.
_DARIJA_ASR_CHECKPOINT = "boumehdi/wav2vec2-large-xlsr-moroccan-darija"

# Both objects are created once at import time and reused by every request.
processor = Wav2Vec2Processor.from_pretrained(_DARIJA_ASR_CHECKPOINT)
transcription_model = Wav2Vec2ForCTC.from_pretrained(_DARIJA_ASR_CHECKPOINT)

# Summarization pipeline applied to the transcription text.
summarizer = pipeline(task="summarization", model="facebook/bart-large-cnn")
|
|
|
|
|
def transcribe_audio(audio_path):
    """Transcribe a 16 kHz audio file with the Wav2Vec2 CTC model.

    Args:
        audio_path: Path to an audio file readable by ``soundfile``.

    Returns:
        The decoded transcription string for the recording.

    Raises:
        ValueError: If no file is given or its sample rate is not 16 kHz.
    """
    # The UI can submit without an upload; fail with a clear message instead
    # of an opaque error from the audio reader.
    if not audio_path:
        raise ValueError("No audio file was provided.")

    audio_input, sample_rate = sf.read(audio_path)
    if sample_rate != 16000:
        raise ValueError("Audio must be sampled at 16kHz.")

    # soundfile yields a (frames, channels) array for multi-channel files;
    # the processor expects a single 1-D waveform, so downmix to mono.
    if audio_input.ndim > 1:
        audio_input = audio_input.mean(axis=1)

    inputs = processor(
        audio_input,
        sampling_rate=sample_rate,
        return_tensors="pt",
        padding=True,
    )

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = transcription_model(**inputs).logits

    # Greedy CTC decoding: take the most likely token at every frame.
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(predicted_ids)[0]
|
|
|
|
|
def filter_text_by_keywords(text, keywords):
    """Keep only the sentences of *text* that mention one of *keywords*.

    Args:
        text: Text to filter; sentences are delimited by ``". "``.
        keywords: Comma-separated keywords, matched case-insensitively as
            substrings. Blank entries (e.g. from a trailing comma) and
            ``None`` are ignored.

    Returns:
        The matching sentences re-joined with ``". "``, or *text* unchanged
        when there are no usable keywords or no sentence matches.
    """
    # Normalize once, outside the per-sentence loop. Blank entries are
    # dropped because "" is a substring of every sentence and would
    # silently turn the filter into a no-op.
    needles = [kw.strip().lower() for kw in (keywords or "").split(",")]
    needles = [kw for kw in needles if kw]
    if not needles:
        return text

    matches = [
        sentence
        for sentence in text.split(". ")
        if any(needle in sentence.lower() for needle in needles)
    ]
    return ". ".join(matches) if matches else text
|
|
|
|
|
def transcribe_and_summarize(audio_file, keywords):
    """Transcribe an audio file and summarize its keyword-filtered text.

    Args:
        audio_file: Path to the audio file to transcribe.
        keywords: Comma-separated keywords used to filter the transcription
            before it is summarized.

    Returns:
        A ``(transcription, summary)`` tuple of strings. The summary is an
        empty string when there is no text to summarize.
    """
    transcription = transcribe_audio(audio_file)
    filtered_text = filter_text_by_keywords(transcription, keywords)

    # Nothing was recognized: skip the summarizer rather than feed it "".
    if not filtered_text.strip():
        return transcription, ""

    # truncation=True clips inputs that exceed the model's maximum input
    # length, so long recordings do not make the summarizer fail.
    result = summarizer(
        filtered_text,
        max_length=100,
        min_length=30,
        do_sample=False,
        truncation=True,
    )
    return transcription, result[0]["summary_text"]
|
|
|
|
|
# Input widgets: the recording to transcribe and the keyword filter.
_audio_upload = gr.Audio(type="filepath", label="Upload your audio file")
_keyword_box = gr.Textbox(
    label="Enter Keywords (comma-separated)",
    placeholder="e.g., customer, service, retention",
)
inputs = [_audio_upload, _keyword_box]

# Output widgets, in the order returned by transcribe_and_summarize.
outputs = [gr.Textbox(label=caption) for caption in ("Transcription", "Summary")]

_APP_DESCRIPTION = (
    "Upload an audio file in Moroccan Darija to get its transcription and a summarized version. "
    "Specify relevant keywords (comma-separated) to filter the transcription before summarization."
)

# Gradio interface wiring the widgets to the transcription/summary function.
app = gr.Interface(
    fn=transcribe_and_summarize,
    inputs=inputs,
    outputs=outputs,
    title="Moroccan Darija Audio Transcription and Summarization",
    description=_APP_DESCRIPTION,
)


# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    app.launch()
|
|