# Hugging Face Space (page status when captured: Sleeping)
import logging

import gradio as gr
import librosa
import torch
from transformers import (
    AutoModel,
    AutoTokenizer,
    BertForSequenceClassification,
    MarianMTModel,
    MarianTokenizer,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
)
# Detect device: use the GPU when CUDA is available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

### 🔹 Topic labels ###
# The two lists are index-aligned: entry i of the Darija list and entry i of
# the English list describe the same topic. They are defined before the
# classifiers so that each model's num_labels can be derived from them.

# Darija labels (Arabic script followed by a Latin transliteration).
darija_topic_labels = [
    "مشكيل ف الشبكة (Mochkil f réseau)",  # Network problem
    "مشكيل ف الانترنت (Mochkil f internet)",  # Internet problem
    "مشكيل ف الفاتورة (Mochkil f l'factura)",  # Billing and payment problem
    "مشكيل ف التعبئة (Mochkil f l'recharge)",  # Recharge and plans problem
    "مشكيل ف التطبيق (Mochkil f l'application)",  # Problem with the app (Orange et Moi...)
    "مشكيل ف بطاقة SIM (Mochkil f carte SIM)",  # SIM card problem
    "مساعدة تقنية (Mosa3ada technique)",  # Technical support
    "العروض والتخفيضات (Offres w promotions)",  # Offers and promotions
    "طلب معلومات (Talab l'ma3loumat)",  # Information request
    "شكاية (Chikaya)",  # Complaint
    "حاجة أخرى (Chi haja okhra)",  # Other
]

# English labels, in the same order as darija_topic_labels.
english_topic_labels = [
    "Network Issue",
    "Internet Issue",
    "Billing & Payment Issue",
    "Recharge & Plans",
    "App Issue",
    "SIM Card Issue",
    "Technical Support",
    "Offers & Promotions",
    "General Inquiry",
    "Complaint",
    "Other",
]

### 🔹 Load Models & Tokenizers Once ###
# Wav2Vec2 (CTC) for Darija speech transcription
wav2vec_model_name = "boumehdi/wav2vec2-large-xlsr-moroccan-darija"
processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_name)
wav2vec_model = Wav2Vec2ForCTC.from_pretrained(wav2vec_model_name).to(device)

# MarianMT for translation (Arabic -> English)
translation_model_name = "Helsinki-NLP/opus-mt-ar-en"
translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)
translation_model = MarianMTModel.from_pretrained(translation_model_name).to(device)

# NOTE: the two classifier checkpoints below are base (pre-trained, not
# fine-tuned) BERT models. from_pretrained() attaches a freshly initialised
# sequence-classification head, so their topic predictions are not meaningful
# until the models are fine-tuned on labelled call-centre data.
# num_labels is tied to the label lists so that every topic is a reachable
# output. Previously it was 2 and 3, which left most labels unreachable.

# AraBERT for Darija topic classification
arabert_model_name = "aubmindlab/bert-base-arabert"
arabert_tokenizer = AutoTokenizer.from_pretrained(arabert_model_name)
arabert_model = BertForSequenceClassification.from_pretrained(
    arabert_model_name, num_labels=len(darija_topic_labels)
).to(device)

# BERT for English topic classification
bert_model_name = "bert-base-uncased"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = BertForSequenceClassification.from_pretrained(
    bert_model_name, num_labels=len(english_topic_labels)
).to(device)
def transcribe_audio(audio):
    """Transcribe a Darija audio clip, translate it, and classify its topic.

    Args:
        audio: Path to the audio file, as produced by
            ``gr.Audio(type="filepath")``. May be ``None`` when the user
            clicks *Process* without uploading or recording anything.

    Returns:
        A 4-tuple ``(transcription, translation, darija_topic, english_topic)``
        of strings, one per output textbox.
        - If processing fails, the first element holds an error message and
          the other three are empty.
        - If no speech is transcribed, only the (empty) transcription is
          returned and the other fields are empty.
    """
    # Report missing input directly instead of letting librosa fail on None.
    if audio is None:
        return "Error processing audio: no audio was provided", "", "", ""

    try:
        # Load the file resampled to 16 kHz, the rate handed to the processor.
        audio_array, _ = librosa.load(audio, sr=16000)
        input_values = processor(
            audio_array, sampling_rate=16000, return_tensors="pt", padding=True
        ).input_values.to(device)

        # Transcription (Darija): pick the most likely token per frame, then
        # let the processor turn the token ids into text.
        with torch.no_grad():
            logits = wav2vec_model(input_values).logits
        tokens = torch.argmax(logits, dim=-1)
        transcription = processor.decode(tokens[0])

        # Silence or unrecognised speech decodes to an empty/blank string.
        # Translating or classifying that would only produce noise.
        if not transcription.strip():
            return transcription, "", "", ""

        # Translate to English, then classify the original and the translation.
        translation = translate_text(transcription)
        darija_topic = classify_topic(
            transcription, arabert_tokenizer, arabert_model, darija_topic_labels
        )
        english_topic = classify_topic(
            translation, bert_tokenizer, bert_model, english_topic_labels
        )
        return transcription, translation, darija_topic, english_topic
    except Exception as e:
        # UI boundary: keep the app responsive and show a short message, but
        # log the full traceback server-side instead of silently dropping it.
        logging.getLogger(__name__).exception("Failed to process audio %r", audio)
        return f"Error processing audio: {e}", "", "", ""
def translate_text(text):
    """Translate Arabic-script text to English with the MarianMT model.

    Args:
        text: Source text. ``None``, empty, or whitespace-only input returns
            ``""`` without invoking the model.

    Returns:
        The decoded English translation, with special tokens removed.
    """
    # Nothing to translate: avoid running generate() on an empty sequence,
    # which would yield arbitrary output.
    if not text or not text.strip():
        return ""

    # truncation=True caps over-long inputs at the tokenizer's maximum length.
    inputs = translation_tokenizer(
        text, return_tensors="pt", padding=True, truncation=True
    ).to(device)
    with torch.no_grad():
        translated_tokens = translation_model.generate(**inputs)
    # A single input string produces a single output sequence at index 0.
    return translation_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
def classify_topic(text, tokenizer, model, topic_labels):
    """Predict a topic label for *text* with a sequence-classification model.

    Args:
        text: Text to classify.
        tokenizer: Tokenizer matching ``model``.
        model: Sequence-classification model whose logits are indexed by
            class id.
        topic_labels: Labels indexed by class id. Its last entry is the
            "other" label, used as the fallback.

    Returns:
        The label of the highest-scoring class. The fallback label (the last
        entry of ``topic_labels``, or ``"Other"`` if the list is empty) is
        returned when ``text`` is empty or blank, or when the predicted class
        id has no corresponding label.
    """
    # Use the list's own "other" entry, so a Darija label list yields a
    # Darija fallback instead of a hard-coded English "Other".
    fallback = topic_labels[-1] if topic_labels else "Other"
    if not text or not text.strip():
        return fallback

    # Send the inputs to wherever this particular model lives.
    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, padding=True, max_length=512
    ).to(model.device)
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_class = torch.argmax(outputs.logits, dim=1).item()
    # Guard against a model head with more classes than there are labels.
    return topic_labels[predicted_class] if predicted_class < len(topic_labels) else fallback
# 🔹 Gradio Interface
# Layout, top to bottom: title, audio input, "Process" button, then four
# read-only result boxes whose order matches transcribe_audio's return tuple.
with gr.Blocks() as demo:
    gr.Markdown("# 🎙️ Speech-to-Text, Translation & Topic Classification")

    audio_input = gr.Audio(type="filepath", label="Upload Audio or Record")
    submit_button = gr.Button("Process")

    result_boxes = [
        gr.Textbox(label=caption)
        for caption in (
            "Transcription (Darija)",
            "Translation (English)",
            "Darija Topic Classification",
            "English Topic Classification",
        )
    ]
    (
        transcription_output,
        translation_output,
        darija_topic_output,
        english_topic_output,
    ) = result_boxes

    # Clicking the button runs the full pipeline on the selected audio file.
    submit_button.click(
        fn=transcribe_audio,
        inputs=[audio_input],
        outputs=[transcription_output, translation_output, darija_topic_output, english_topic_output],
    )

demo.launch()