Spaces:
Running
Running
File size: 7,052 Bytes
e7d0ead 6cc8631 2785b50 4a98e32 2785b50 85e680f 2785b50 85e680f 2785b50 f2ecb6e 4a98e32 f2ecb6e 2785b50 f2ecb6e 4a98e32 f2ecb6e 9e3ffca 4a98e32 980dcf2 2785b50 f117ae1 2785b50 f117ae1 2785b50 f117ae1 2785b50 f117ae1 9078685 4a98e32 f117ae1 9078685 4a98e32 9078685 f117ae1 4a98e32 d04bf8d 2785b50 980dcf2 f2ecb6e 2785b50 f2ecb6e 2785b50 1a38424 f2ecb6e 9078685 f2ecb6e 9078685 4a98e32 c7f40c9 f2ecb6e 4a98e32 d27a60f 980dcf2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import gradio as gr
import librosa
import torch
from transformers import (
Wav2Vec2ForCTC, Wav2Vec2Processor,
MarianMTModel, MarianTokenizer,
BertForSequenceClassification, AutoModel, AutoTokenizer
)
# Detect device: use the GPU when one is available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

### 🔹 Topic labels ###
# Darija labels (Arabic script followed by a Latin transliteration).
# Entry i here describes the same topic as entry i in english_topic_labels;
# keep the two lists aligned and the same length.
darija_topic_labels = [
    "مشكيل ف الشبكة (Mochkil f réseau)",  # Network problem
    "مشكيل ف الانترنت (Mochkil f internet)",  # Internet problem
    "مشكيل ف الفاتورة (Mochkil f l'factura)",  # Billing and payment problem
    "مشكيل ف التعبئة (Mochkil f l'recharge)",  # Top-up and plans problem
    "مشكيل ف التطبيق (Mochkil f l'application)",  # Problem with the app (e.g. Orange et Moi)
    "مشكيل ف بطاقة SIM (Mochkil f carte SIM)",  # SIM card problem
    "مساعدة تقنية (Mosa3ada technique)",  # Technical support
    "العروض والتخفيضات (Offres w promotions)",  # Offers and promotions
    "طلب معلومات (Talab l'ma3loumat)",  # Information request
    "شكاية (Chikaya)",  # Complaint
    "حاجة أخرى (Chi haja okhra)",  # Other
]

# English labels, in the same order as darija_topic_labels.
english_topic_labels = [
    "Network Issue",
    "Internet Issue",
    "Billing & Payment Issue",
    "Recharge & Plans",
    "App Issue",
    "SIM Card Issue",
    "Technical Support",
    "Offers & Promotions",
    "General Inquiry",
    "Complaint",
    "Other",
]

### 🔹 Load Models & Tokenizers Once ###
# Wav2Vec2 for Darija transcription
wav2vec_model_name = "boumehdi/wav2vec2-large-xlsr-moroccan-darija"
processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_name)
wav2vec_model = Wav2Vec2ForCTC.from_pretrained(wav2vec_model_name).to(device)

# MarianMT for translation (Arabic → English)
translation_model_name = "Helsinki-NLP/opus-mt-ar-en"
translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)
translation_model = MarianMTModel.from_pretrained(translation_model_name).to(device)

# NOTE(review): both classification checkpoints below look like base
# pretrained models, so the sequence-classification head is presumably
# freshly (randomly) initialised by from_pretrained. Predictions are then not
# meaningful until the heads are fine-tuned on labelled data. TODO confirm.
# Each head is sized from its label list so every topic can be predicted;
# previously the hard-coded 2 / 3 outputs could only ever select the first
# 2 / 3 labels.

# AraBERT for Darija topic classification
arabert_model_name = "aubmindlab/bert-base-arabert"
arabert_tokenizer = AutoTokenizer.from_pretrained(arabert_model_name)
arabert_model = BertForSequenceClassification.from_pretrained(
    arabert_model_name, num_labels=len(darija_topic_labels)
).to(device)

# BERT for English topic classification
bert_model_name = "bert-base-uncased"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = BertForSequenceClassification.from_pretrained(
    bert_model_name, num_labels=len(english_topic_labels)
).to(device)
# Keyword-based topic classification (no model involved).
def classify_topic_by_keywords(text, topic_labels):
    """Pick a topic for *text* by counting keyword hits.

    Keywords are grouped by topic *position*, following the shared order of
    this file's Darija and English label lists: network, internet, billing,
    recharge/plans, app, SIM card, technical support, offers, information
    request, complaint, other. Each group mixes Arabic-script Darija, French
    and English terms, so one table serves both the Darija transcription and
    its English translation.

    Args:
        text: Transcribed or translated text to classify (``None`` is treated
            as empty).
        topic_labels: Candidate labels, in the shared topic order above.
            Labels beyond the known keyword groups never score.

    Returns:
        The label whose keyword group has the most matching keywords in
        *text*; ties go to the earliest label. When nothing matches, the last
        label is returned (the "Other" entry in this file's label lists), or
        the string "Other" when *topic_labels* is empty.
    """
    import re

    # One keyword group per topic index; index 10 ("Other") is the fallback.
    keywords_by_index = (
        # 0 - network
        ("شبكة", "ريزو", "تغطية", "réseau", "reseau", "network", "signal", "coverage"),
        # 1 - internet
        ("انترنت", "أنترنت", "internet", "connexion", "connection", "wifi", "4g"),
        # 2 - billing & payment
        ("فاتورة", "فاتورا", "دفع", "خلص", "مبلغ", "facture", "factura", "paiement",
         "bill", "billing", "invoice", "payment"),
        # 3 - recharge & plans
        ("تعبئة", "رصيد", "شارج", "recharge", "forfait", "plan", "balance", "top-up", "top up"),
        # 4 - app
        ("تطبيق", "application", "app", "orange et moi"),
        # 5 - SIM card
        ("بطاقة", "sim", "puce"),
        # 6 - technical support
        ("مساعدة", "دعم", "تقني", "technique", "technical", "support", "help"),
        # 7 - offers & promotions
        ("عرض", "عروض", "تخفيض", "خصم", "offre", "promotion", "promo", "offer", "discount"),
        # 8 - information request
        ("معلومات", "استفسار", "سؤال", "information", "inquiry", "question"),
        # 9 - complaint
        ("شكاية", "شكوى", "réclamation", "reclamation", "complaint", "complain"),
        # 10 - other: no keywords
        (),
    )

    if not topic_labels:
        return "Other"

    # casefold() is a stronger lower() for case-insensitive matching.
    normalized = (text or "").casefold()

    def _matches(word):
        if re.search(r"[\u0600-\u06FF]", word):
            # Arabic script glues prefixes such as "ال" or "ف" onto the word,
            # so a plain substring test is the useful check here.
            return word in normalized
        # Latin keywords must be whole words (optionally pluralised) so that
        # e.g. "sim" does not fire on "simple" or "app" on "happy".
        pattern = rf"(?<!\w){re.escape(word)}(?:s|es)?(?!\w)"
        return re.search(pattern, normalized) is not None

    # zip-style slicing: only score labels that have a keyword group.
    scores = [
        sum(1 for word in group if _matches(word))
        for group in keywords_by_index[: len(topic_labels)]
    ]

    # max() keeps the first maximal index, so ties favour earlier topics.
    best_index = max(range(len(scores)), key=scores.__getitem__)
    if scores[best_index] == 0:
        return topic_labels[-1]
    return topic_labels[best_index]
def transcribe_audio(audio):
    """Transcribe Darija speech, translate it to English, and classify its topic.

    Args:
        audio: Path to the audio file (the Gradio audio input is created with
            ``type="filepath"``), or ``None`` when nothing was uploaded or
            recorded.

    Returns:
        Six strings, in the order the Gradio outputs expect: transcription,
        English translation, Darija topic (BERT), English topic (BERT),
        Darija topic (keywords), English topic (keywords). If processing
        fails, the first string is an error message and the other five are
        empty.
    """
    import logging

    # Clicking "Process" without any audio yields None; report that directly
    # instead of surfacing the error librosa raises for a None path.
    if audio is None:
        return "Error processing audio: no audio provided", "", "", "", "", ""

    try:
        # Load the file resampled to 16 kHz, the rate passed to the processor.
        audio_array, _ = librosa.load(audio, sr=16000)
        input_values = processor(
            audio_array, sampling_rate=16000, return_tensors="pt", padding=True
        ).input_values.to(device)

        # Transcription (Darija): greedy decoding. Take the highest-scoring
        # token id per frame and let the processor turn the ids into text.
        with torch.no_grad():
            logits = wav2vec_model(input_values).logits
        tokens = torch.argmax(logits, dim=-1)
        transcription = processor.decode(tokens[0])

        # Translate to English.
        translation = translate_text(transcription)

        # Topic classification with the BERT-based models.
        darija_topic = classify_topic(transcription, arabert_tokenizer, arabert_model, darija_topic_labels)
        english_topic = classify_topic(translation, bert_tokenizer, bert_model, english_topic_labels)

        # Topic classification by keyword matching.
        darija_keyword_topic = classify_topic_by_keywords(transcription, darija_topic_labels)
        english_keyword_topic = classify_topic_by_keywords(translation, english_topic_labels)
    except Exception as e:
        # UI boundary: keep the app responsive, but do not lose the traceback.
        logging.getLogger(__name__).exception("Failed to process audio %r", audio)
        return f"Error processing audio: {str(e)}", "", "", "", "", ""

    return transcription, translation, darija_topic, english_topic, darija_keyword_topic, english_keyword_topic
def translate_text(text):
    """Return the English MarianMT translation of an Arabic/Darija string.

    The input is truncated by the tokenizer if it is too long. Only the first
    generated sequence is decoded, with special tokens stripped.
    """
    encoded = translation_tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    encoded = encoded.to(device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        generated = translation_model.generate(**encoded)

    first_sequence = generated[0]
    return translation_tokenizer.decode(first_sequence, skip_special_tokens=True)
def classify_topic(text, tokenizer, model, topic_labels):
    """Predict a topic label for *text* with a sequence-classification model.

    Args:
        text: Text to classify (truncated to 512 tokens).
        tokenizer: Tokenizer matching *model*.
        model: Model whose output exposes ``logits``.
        topic_labels: Labels indexed by the model's output classes.

    Returns:
        The label at the arg-max class index, or "Other" when that index has
        no entry in *topic_labels*.
    """
    batch = tokenizer(
        text, return_tensors="pt", truncation=True, padding=True, max_length=512
    ).to(device)

    with torch.no_grad():
        scores = model(**batch).logits

    best_index = scores.argmax(dim=1).item()
    if best_index >= len(topic_labels):
        return "Other"
    return topic_labels[best_index]
# 🔹 Gradio Interface: upload or record audio, then display the transcription,
# translation and the four topic predictions produced by transcribe_audio.
with gr.Blocks() as demo:
    gr.Markdown("# 🎙️ Speech-to-Text, Translation & Topic Classification")
    audio_input = gr.Audio(type="filepath", label="Upload Audio or Record")
    submit_button = gr.Button("Process")

    # Boxes are created (and therefore laid out) in this order, which must
    # also match the order of the 6-tuple returned by transcribe_audio.
    result_boxes = [
        gr.Textbox(label=caption)
        for caption in (
            "Transcription (Darija)",
            "Translation (English)",
            "Darija Topic Classification (BERT)",
            "English Topic Classification (BERT)",
            "Darija Topic Classification (Keywords)",
            "English Topic Classification (Keywords)",
        )
    ]
    (
        transcription_output,
        translation_output,
        darija_topic_output,
        english_topic_output,
        darija_keyword_topic_output,
        english_keyword_topic_output,
    ) = result_boxes

    submit_button.click(
        transcribe_audio,
        inputs=[audio_input],
        outputs=list(result_boxes),
    )

demo.launch()
|