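"""Gradio demo: transcribe Moroccan Darija speech with Wav2Vec2, translate it to English
with MarianMT, and classify the caller's topic with BERT-based models and a simple
keyword-matching classifier."""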
import gradio as gr
import librosa
import torch
from transformers import (
Wav2Vec2ForCTC, Wav2Vec2Processor,
MarianMTModel, MarianTokenizer,
BertForSequenceClassification, AutoModel, AutoTokenizer
)
# Detect device
device = "cuda" if torch.cuda.is_available() else "cpu"
### 🔹 Load Models & Tokenizers Once ###
# Wav2Vec2 for Darija transcription
wav2vec_model_name = "boumehdi/wav2vec2-large-xlsr-moroccan-darija"
processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_name)
wav2vec_model = Wav2Vec2ForCTC.from_pretrained(wav2vec_model_name).to(device)
# MarianMT for translation (Arabic → English)
translation_model_name = "Helsinki-NLP/opus-mt-ar-en"
translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)
translation_model = MarianMTModel.from_pretrained(translation_model_name).to(device)
# AraBERT for Darija topic classification
arabert_model_name = "aubmindlab/bert-base-arabert"
arabert_tokenizer = AutoTokenizer.from_pretrained(arabert_model_name)
arabert_model = BertForSequenceClassification.from_pretrained(arabert_model_name, num_labels=11).to(device)  # 11 classes, matching darija_topic_labels below
# BERT for English topic classification
bert_model_name = "bert-base-uncased"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=11).to(device)  # 11 classes, matching english_topic_labels below
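# Note: from_pretrained initializes the sequence-classification heads of both models
# randomly, so they will not give meaningful topic predictions until fine-tuned on
# labelled data; the keyword-based classifier below does not have this limitation.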
# Topic labels in Darija (Arabic and Latin script)
darija_topic_labels = [
    "مشكيل ف الشبكة (Mochkil f réseau)",  # Network problem
    "مشكيل ف الانترنت (Mochkil f internet)",  # Internet problem
    "مشكيل ف الفاتورة (Mochkil f l'factura)",  # Billing and payment problem
    "مشكيل ف التعبئة (Mochkil f l'recharge)",  # Recharge and plans problem
    "مشكيل ف التطبيق (Mochkil f l'application)",  # Problem with the app (Orange et Moi...)
    "مشكيل ف بطاقة SIM (Mochkil f carte SIM)",  # Problem with the SIM card
    "مساعدة تقنية (Mosa3ada technique)",  # Technical assistance
    "العروض والتخفيضات (Offres w promotions)",  # Offers and promotions
    "طلب معلومات (Talab l'ma3loumat)",  # Information request
    "شكاية (Chikaya)",  # Complaint
    "حاجة أخرى (Chi haja okhra)"  # Other
]
# Topic labels in English
english_topic_labels = [
"Network Issue",
"Internet Issue",
"Billing & Payment Issue",
"Recharge & Plans",
"App Issue",
"SIM Card Issue",
"Technical Support",
"Offers & Promotions",
"General Inquiry",
"Complaint",
"Other"
]
# Keyword-based topic classification
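# Note: this rule-based classifier uses its own coarser topic set (Customer Service,
# Retention Service, Billing Issue, Other), independent of the BERT label lists above.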
def classify_topic_by_keywords(text, language='ar'):
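    """Classify a transcript into a coarse topic by counting keyword hits.

    `language` selects the keyword dictionary: 'ar' for Arabic/Darija text,
    'en' for English text. Returns the best-scoring topic, or "Other" when
    no keyword is found.
    """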
# Arabic keywords for each topic
arabic_keywords = {
"Customer Service": ["خدمة", "استفسار", "مساعدة", "دعم", "سؤال", "استفسار"],
"Retention Service": ["احتفاظ", "تجديد", "خصم", "عرض", "العرض"],
"Billing Issue": ["فاتورة", "دفع", "مشكلة", "خطأ", "مبلغ"],
"Other": ["شيء آخر", "غير ذلك", "أخرى"]
}
# English keywords for each topic
english_keywords = {
"Customer Service": ["service", "inquiry", "help", "support", "question", "assistance"],
"Retention Service": ["retain", "cut", "discount", "offer", "promotion","stop"],
"Billing Issue": ["bill", "payment", "problem", "error", "amount"],
"Other": ["other", "none of the above", "something else"]
}
# Select the appropriate keywords based on the language
if language == 'ar':
keywords = arabic_keywords
elif language == 'en':
keywords = english_keywords
else:
raise ValueError("Invalid language specified. Use 'ar' for Arabic or 'en' for English.")
    # Lowercase the text so English keyword matching is case-insensitive (no effect on Arabic)
    text = text.lower()
# Check for keywords in the text and calculate the topic scores
topic_scores = {topic: 0 for topic in keywords} # Initialize topic scores
for topic, words in keywords.items():
for word in words:
if word in text:
topic_scores[topic] += 1 # Increment score for each keyword found
    # If no keyword matched, fall back to "Other"
if all(score == 0 for score in topic_scores.values()):
return "Other"
# Return the topic with the highest score
best_topic = max(topic_scores, key=topic_scores.get)
return best_topic
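# Illustrative usage (not executed at import time):
#   classify_topic_by_keywords("there is a problem with my bill payment", language='en')
#   returns "Billing Issue" (three keyword hits: "bill", "payment", "problem")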
def transcribe_audio(audio):
"""Convert audio to text, translate it, and classify topics in both Darija and English."""
try:
# Load and preprocess audio
audio_array, sr = librosa.load(audio, sr=16000)
input_values = processor(audio_array, sampling_rate=16000, return_tensors="pt", padding=True).input_values.to(device)
# Transcription (Darija)
with torch.no_grad():
            logits = wav2vec_model(input_values).logits
            tokens = torch.argmax(logits, dim=-1)
            transcription = processor.decode(tokens[0])
# Translate to English
translation = translate_text(transcription)
# Classify topics using BERT models
darija_topic = classify_topic(transcription, arabert_tokenizer, arabert_model, darija_topic_labels)
english_topic = classify_topic(translation, bert_tokenizer, bert_model, english_topic_labels)
        # Classify topics using keyword matching (Darija keywords on the transcription,
        # English keywords on the English translation)
        darija_keyword_topic = classify_topic_by_keywords(transcription, language='ar')
        english_keyword_topic = classify_topic_by_keywords(translation, language='en')
return transcription, translation, darija_topic, english_topic, darija_keyword_topic, english_keyword_topic
except Exception as e:
return f"Error processing audio: {str(e)}", "", "", "", "", ""
def translate_text(text):
"""Translate Arabic text to English."""
inputs = translation_tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
with torch.no_grad():
translated_tokens = translation_model.generate(**inputs)
return translation_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
def classify_topic(text, tokenizer, model, topic_labels):
"""Classify topic using BERT-based models."""
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512).to(device)
with torch.no_grad():
outputs = model(**inputs)
predicted_class = torch.argmax(outputs.logits, dim=1).item()
return topic_labels[predicted_class] if predicted_class < len(topic_labels) else "Other"
# 🔹 Gradio Interface
with gr.Blocks() as demo:
gr.Markdown("# 🎙️ Speech-to-Text, Translation & Topic Classification")
audio_input = gr.Audio(type="filepath", label="Upload Audio or Record")
submit_button = gr.Button("Process")
transcription_output = gr.Textbox(label="Transcription (Darija)")
translation_output = gr.Textbox(label="Translation (English)")
darija_topic_output = gr.Textbox(label="Darija Topic Classification (BERT)")
english_topic_output = gr.Textbox(label="English Topic Classification (BERT)")
darija_keyword_topic_output = gr.Textbox(label="Darija Topic Classification (Keywords)")
english_keyword_topic_output = gr.Textbox(label="English Topic Classification (Keywords)")
submit_button.click(transcribe_audio,
inputs=[audio_input],
outputs=[transcription_output, translation_output,
darija_topic_output, english_topic_output,
darija_keyword_topic_output, english_keyword_topic_output])
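# launch() starts the Gradio server; on Hugging Face Spaces this serves the app directly
# (pass share=True to launch() for a temporary public link when running locally).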
demo.launch()