File size: 7,052 Bytes
e7d0ead
6cc8631
 
2785b50
 
 
4a98e32
2785b50
 
 
 
 
 
 
 
 
 
 
 
85e680f
 
2785b50
85e680f
2785b50
f2ecb6e
 
4a98e32
 
f2ecb6e
2785b50
f2ecb6e
 
4a98e32
f2ecb6e
9e3ffca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a98e32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
980dcf2
2785b50
f117ae1
 
 
2785b50
f117ae1
2785b50
f117ae1
2785b50
f117ae1
 
 
 
 
9078685
4a98e32
f117ae1
 
9078685
4a98e32
 
 
 
 
9078685
f117ae1
4a98e32
d04bf8d
 
2785b50
 
 
 
 
980dcf2
f2ecb6e
2785b50
 
f2ecb6e
 
 
 
 
 
2785b50
1a38424
f2ecb6e
9078685
 
f2ecb6e
 
9078685
 
4a98e32
 
 
 
c7f40c9
f2ecb6e
 
4a98e32
 
 
d27a60f
980dcf2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import gradio as gr
import librosa
import torch
from transformers import (
    Wav2Vec2ForCTC, Wav2Vec2Processor, 
    MarianMTModel, MarianTokenizer, 
    BertForSequenceClassification, AutoModel, AutoTokenizer
)

# Detect device: use the GPU when CUDA is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

### 🔹 Load Models & Tokenizers Once ###
# Wav2Vec2 acoustic model fine-tuned for Moroccan Darija transcription.
wav2vec_model_name = "boumehdi/wav2vec2-large-xlsr-moroccan-darija"
processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_name)
wav2vec_model = Wav2Vec2ForCTC.from_pretrained(wav2vec_model_name).to(device)

# MarianMT for translation (Arabic → English).
translation_model_name = "Helsinki-NLP/opus-mt-ar-en"
translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)
translation_model = MarianMTModel.from_pretrained(translation_model_name).to(device)

# AraBERT for Darija topic classification.
# NOTE(review): num_labels=2 here, but darija_topic_labels below has 11 entries,
# so most labels are unreachable — confirm the intended classification head size.
arabert_model_name = "aubmindlab/bert-base-arabert"
arabert_tokenizer = AutoTokenizer.from_pretrained(arabert_model_name)
arabert_model = BertForSequenceClassification.from_pretrained(arabert_model_name, num_labels=2).to(device)


# BERT for English topic classification.
# NOTE(review): num_labels=3 vs 11 entries in english_topic_labels — same mismatch as above.
bert_model_name = "bert-base-uncased"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=3).to(device)  

# Darija topic labels (Arabic script with a Latin transliteration in parentheses).
darija_topic_labels = [
    "مشكيل ف الشبكة (Mochkil f réseau)",        # Network problem
    "مشكيل ف الانترنت (Mochkil f internet)",      # Internet problem
    "مشكيل ف الفاتورة (Mochkil f l'factura)",     # Billing and payment problem
    "مشكيل ف التعبئة (Mochkil f l'recharge)",     # Recharge and plans problem
    "مشكيل ف التطبيق (Mochkil f l'application)",  # Problem with the app (Orange et Moi...)
    "مشكيل ف بطاقة SIM (Mochkil f carte SIM)",    # SIM card problem
    "مساعدة تقنية (Mosa3ada technique)",         # Technical assistance
    "العروض والتخفيضات (Offres w promotions)",   # Offers and promotions
    "طلب معلومات (Talab l'ma3loumat)",           # Information request
    "شكاية (Chikaya)",                            # Complaint
    "حاجة أخرى (Chi haja okhra)"                 # Other
]

# English topic labels (parallel to the Darija list above).
english_topic_labels = [
    "Network Issue",
    "Internet Issue",
    "Billing & Payment Issue",
    "Recharge & Plans",
    "App Issue",
    "SIM Card Issue",
    "Technical Support",
    "Offers & Promotions",
    "General Inquiry",
    "Complaint",
    "Other"
]

# Keyword-based topic classification (complements the BERT classifiers).
def classify_topic_by_keywords(text, topic_labels):
    """Classify *text* by counting keyword occurrences per topic.

    Bug fix vs. the original: scores were initialised with keys from
    *topic_labels* but incremented with keys from the internal ``keywords``
    dict, so any keyword hit raised ``KeyError``. Scores are now kept per
    keyword topic. When no keyword matches, the last entry of
    *topic_labels* (the "Other" bucket) is returned instead of an
    arbitrary first label.

    Args:
        text: Input text (Darija/Arabic or English).
        topic_labels: Candidate labels; the last one is the fallback.

    Returns:
        The best-matching topic string.
    """
    # Keyword lexicon: topic name -> indicative words.
    keywords = {
        "خدمة العملاء": ["خدمة", "استفسار", "مساعدة", "دعم", "سؤال", "استفسار"],
        "خدمة الاحتفاظ": ["احتفاظ", "تجديد", "خصم", "عرض", "العرض"],
        "مشكلة في الفاتورة": ["فاتورة", "دفع", "مشكلة", "خطأ", "مبلغ"]
    }

    # Lowercase for case-insensitive matching (no-op for Arabic script).
    text = text.lower()

    # Count how many of each topic's keywords appear in the text.
    topic_scores = {
        topic: sum(word in text for word in words)
        for topic, words in keywords.items()
    }

    best_topic, best_score = max(topic_scores.items(), key=lambda item: item[1])
    if best_score == 0:
        # Nothing matched: report the explicit "Other" bucket when available.
        return topic_labels[-1] if topic_labels else best_topic
    return best_topic


def transcribe_audio(audio):
    """Transcribe Darija audio, translate it to English, and classify its
    topic with both BERT-based and keyword-based classifiers.

    Returns a 6-tuple: (transcription, translation, darija_topic,
    english_topic, darija_keyword_topic, english_keyword_topic). On any
    failure the first element carries the error message and the remaining
    five are empty strings.
    """
    try:
        # Resample to the 16 kHz rate the Wav2Vec2 processor expects.
        waveform, _ = librosa.load(audio, sr=16000)
        features = processor(
            waveform, sampling_rate=16000, return_tensors="pt", padding=True
        ).input_values.to(device)

        # Greedy CTC decoding of the acoustic model's output.
        with torch.no_grad():
            logits = wav2vec_model(features).logits
        predicted_ids = torch.argmax(logits, axis=-1)
        transcription = processor.decode(predicted_ids[0])

        # Arabic → English translation of the transcript.
        translation = translate_text(transcription)

        # BERT-based topic predictions in both languages.
        darija_topic = classify_topic(
            transcription, arabert_tokenizer, arabert_model, darija_topic_labels
        )
        english_topic = classify_topic(
            translation, bert_tokenizer, bert_model, english_topic_labels
        )

        # Keyword-based topic predictions in both languages.
        darija_keyword_topic = classify_topic_by_keywords(transcription, darija_topic_labels)
        english_keyword_topic = classify_topic_by_keywords(translation, english_topic_labels)

        return (transcription, translation,
                darija_topic, english_topic,
                darija_keyword_topic, english_keyword_topic)

    except Exception as e:
        return f"Error processing audio: {str(e)}", "", "", "", "", ""

def translate_text(text):
    """Translate Arabic/Darija text into English with the MarianMT model."""
    encoded = translation_tokenizer(
        text, return_tensors="pt", padding=True, truncation=True
    ).to(device)
    with torch.no_grad():
        output_ids = translation_model.generate(**encoded)
    # Decode the first (and only) sequence in the batch.
    return translation_tokenizer.decode(output_ids[0], skip_special_tokens=True)

def classify_topic(text, tokenizer, model, topic_labels):
    """Predict a topic label for *text* with a BERT-style classifier.

    Returns the label at the predicted class index, or "Other" when the
    index falls outside *topic_labels*.
    """
    encoded = tokenizer(
        text, return_tensors="pt", truncation=True, padding=True, max_length=512
    ).to(device)
    with torch.no_grad():
        class_logits = model(**encoded).logits
        predicted_idx = torch.argmax(class_logits, dim=1).item()

    if predicted_idx < len(topic_labels):
        return topic_labels[predicted_idx]
    return "Other"

# 🔹 Gradio Interface: one audio input, six text outputs matching the
# 6-tuple returned by transcribe_audio.
with gr.Blocks() as demo:
    gr.Markdown("# 🎙️ Speech-to-Text, Translation & Topic Classification")

    audio_input = gr.Audio(type="filepath", label="Upload Audio or Record")
    submit_button = gr.Button("Process")

    transcription_output = gr.Textbox(label="Transcription (Darija)")
    translation_output = gr.Textbox(label="Translation (English)")
    darija_topic_output = gr.Textbox(label="Darija Topic Classification (BERT)")
    english_topic_output = gr.Textbox(label="English Topic Classification (BERT)")
    darija_keyword_topic_output = gr.Textbox(label="Darija Topic Classification (Keywords)")
    english_keyword_topic_output = gr.Textbox(label="English Topic Classification (Keywords)")

    # Output order must match the tuple order of transcribe_audio's return.
    result_boxes = [
        transcription_output,
        translation_output,
        darija_topic_output,
        english_topic_output,
        darija_keyword_topic_output,
        english_keyword_topic_output,
    ]
    submit_button.click(transcribe_audio, inputs=[audio_input], outputs=result_boxes)

demo.launch()