File size: 7,936 Bytes
e7d0ead
6cc8631
 
2785b50
 
 
4a98e32
2785b50
 
 
 
 
 
 
 
 
 
 
 
85e680f
 
2785b50
85e680f
2785b50
f2ecb6e
 
4a98e32
 
f2ecb6e
2785b50
f2ecb6e
 
4a98e32
f2ecb6e
9e3ffca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a98e32
b6cc6ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a98e32
b6cc6ac
 
 
4a98e32
 
 
 
b6cc6ac
 
 
 
 
 
 
4a98e32
 
 
 
b6cc6ac
 
980dcf2
2785b50
f117ae1
 
 
2785b50
f117ae1
2785b50
f117ae1
2785b50
f117ae1
 
 
 
 
9078685
4a98e32
f117ae1
 
9078685
4a98e32
b6cc6ac
 
 
4a98e32
 
9078685
f117ae1
4a98e32
d04bf8d
 
2785b50
 
 
 
 
980dcf2
f2ecb6e
2785b50
 
f2ecb6e
 
 
 
 
 
2785b50
1a38424
f2ecb6e
9078685
 
f2ecb6e
 
9078685
 
4a98e32
 
 
 
c7f40c9
f2ecb6e
 
4a98e32
 
 
d27a60f
980dcf2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import logging

import gradio as gr
import librosa
import torch
from transformers import (
    Wav2Vec2ForCTC, Wav2Vec2Processor,
    MarianMTModel, MarianTokenizer,
    BertForSequenceClassification, AutoModel, AutoTokenizer
)

# Detect device: use the GPU when PyTorch can see one, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

### 🔹 Topic label sets ###
# Defined before the classifiers below so that each classification head can be
# sized from its label list (num_labels=len(...)).

# Darija labels (Arabic script, with a Latin-script transliteration in parentheses)
darija_topic_labels = [
    "مشكيل ف الشبكة (Mochkil f réseau)",        # Network problem
    "مشكيل ف الانترنت (Mochkil f internet)",      # Internet problem
    "مشكيل ف الفاتورة (Mochkil f l'factura)",     # Billing and payment problem
    "مشكيل ف التعبئة (Mochkil f l'recharge)",     # Recharge and plan problem
    "مشكيل ف التطبيق (Mochkil f l'application)",  # Problem with the app (Orange et Moi...)
    "مشكيل ف بطاقة SIM (Mochkil f carte SIM)",    # Problem with the SIM card
    "مساعدة تقنية (Mosa3ada technique)",         # Technical support
    "العروض والتخفيضات (Offres w promotions)",   # Offers and promotions
    "طلب معلومات (Talab l'ma3loumat)",           # Information request
    "شكاية (Chikaya)",                            # Complaint
    "حاجة أخرى (Chi haja okhra)"                 # Other
]

# English labels, index-aligned with darija_topic_labels
english_topic_labels = [
    "Network Issue",
    "Internet Issue",
    "Billing & Payment Issue",
    "Recharge & Plans",
    "App Issue",
    "SIM Card Issue",
    "Technical Support",
    "Offers & Promotions",
    "General Inquiry",
    "Complaint",
    "Other"
]

### 🔹 Load Models & Tokenizers Once ###
# Wav2Vec2 for Darija transcription
wav2vec_model_name = "boumehdi/wav2vec2-large-xlsr-moroccan-darija"
processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_name)
wav2vec_model = Wav2Vec2ForCTC.from_pretrained(wav2vec_model_name).to(device)

# MarianMT for translation (Arabic → English)
translation_model_name = "Helsinki-NLP/opus-mt-ar-en"
translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)
translation_model = MarianMTModel.from_pretrained(translation_model_name).to(device)

# AraBERT for Darija topic classification.
# The head has one output per Darija label so that every label is a possible
# prediction.
# NOTE(review): "aubmindlab/bert-base-arabert" appears to be a base checkpoint;
# if it ships no fine-tuned classification head, transformers initializes one
# randomly and predictions are arbitrary until the model is fine-tuned. Verify.
arabert_model_name = "aubmindlab/bert-base-arabert"
arabert_tokenizer = AutoTokenizer.from_pretrained(arabert_model_name)
arabert_model = BertForSequenceClassification.from_pretrained(
    arabert_model_name, num_labels=len(darija_topic_labels)
).to(device)

# BERT for English topic classification, with one output per English label.
# NOTE(review): "bert-base-uncased" is presumably a base checkpoint without a
# trained classification head; the same caveat as above applies. Verify.
bert_model_name = "bert-base-uncased"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = BertForSequenceClassification.from_pretrained(
    bert_model_name, num_labels=len(english_topic_labels)
).to(device)

# Keyword-based topic classification.
# Keyword lists per topic are matched as substrings of the lower-cased text.
# They are built once at import time rather than on every call.
_ARABIC_TOPIC_KEYWORDS = {
    "Customer Service": ("خدمة", "استفسار", "مساعدة", "دعم", "سؤال"),
    # NOTE(review): "العرض" contains "عرض", so text with the definite form
    # scores 2 for this topic. Confirm this weighting is intended.
    "Retention Service": ("احتفاظ", "تجديد", "خصم", "عرض", "العرض"),
    "Billing Issue": ("فاتورة", "دفع", "مشكلة", "خطأ", "مبلغ"),
    "Other": ("شيء آخر", "غير ذلك", "أخرى"),
}

_ENGLISH_TOPIC_KEYWORDS = {
    "Customer Service": ("service", "inquiry", "help", "support", "question", "assistance"),
    "Retention Service": ("retain", "cut", "discount", "offer", "promotion", "stop"),
    "Billing Issue": ("bill", "payment", "problem", "error", "amount"),
    "Other": ("other", "none of the above", "something else"),
}


def classify_topic_by_keywords(text, language='ar'):
    """Return the topic whose keywords occur most often in *text*.

    Each topic's score is the number of its distinct keywords that appear as
    substrings of the lower-cased text. Substring matching means, for example,
    "cut" also matches inside "execute".

    Args:
        text: Text to classify. ``None`` or "" is treated as having no keywords.
        language: ``'ar'`` to use the Arabic keyword lists, ``'en'`` for English.

    Returns:
        The highest-scoring topic name. On a tie, the topic listed first wins.
        Returns "Other" when no keyword matches.

    Raises:
        ValueError: If *language* is neither ``'ar'`` nor ``'en'``.
    """
    if language == 'ar':
        keywords = _ARABIC_TOPIC_KEYWORDS
    elif language == 'en':
        keywords = _ENGLISH_TOPIC_KEYWORDS
    else:
        raise ValueError("Invalid language specified. Use 'ar' for Arabic or 'en' for English.")

    # An empty transcription/translation cannot contain any keyword.
    if not text:
        return "Other"

    # Lower-case so English keywords match regardless of capitalization.
    lowered = text.lower()

    topic_scores = {
        topic: sum(word in lowered for word in words)
        for topic, words in keywords.items()
    }

    if not any(topic_scores.values()):
        return "Other"

    # max() keeps the first maximal key in dict order, so ties favour earlier topics.
    return max(topic_scores, key=topic_scores.get)




def transcribe_audio(audio):
    """Convert audio to text, translate it, and classify topics in both Darija and English.

    This is the Gradio click handler. Processing errors are reported in the
    first output textbox rather than propagated.

    Args:
        audio: Filesystem path of the audio to process (from
            ``gr.Audio(type="filepath")``), or ``None`` when the user clicks
            "Process" without providing audio.

    Returns:
        A 6-tuple of strings in the order of the Gradio outputs:
        ``(transcription, translation, darija_topic, english_topic,
        darija_keyword_topic, english_keyword_topic)``. On failure, the first
        element holds an "Error processing audio: ..." message and the rest
        are "".
    """
    if audio is None:
        return "Error processing audio: no audio provided", "", "", "", "", ""

    try:
        # Load and resample to 16 kHz, the rate also given to the processor.
        audio_array, _ = librosa.load(audio, sr=16000)
        input_values = processor(audio_array, sampling_rate=16000, return_tensors="pt", padding=True).input_values.to(device)

        # Transcription (Darija): greedy CTC decoding of the arg-max token per frame.
        with torch.no_grad():
            logits = wav2vec_model(input_values).logits
        tokens = torch.argmax(logits, dim=-1)
        transcription = processor.decode(tokens[0])

        # Translate to English
        translation = translate_text(transcription)

        # Classify topics using BERT models
        darija_topic = classify_topic(transcription, arabert_tokenizer, arabert_model, darija_topic_labels)
        english_topic = classify_topic(translation, bert_tokenizer, bert_model, english_topic_labels)

        # Keyword classification: Arabic keywords on the Darija transcription,
        # English keywords on the English translation. The transcription is
        # expected to be in Arabic script, so English keywords would not match it.
        darija_keyword_topic = classify_topic_by_keywords(transcription, language='ar')
        english_keyword_topic = classify_topic_by_keywords(translation, language='en')

        return transcription, translation, darija_topic, english_topic, darija_keyword_topic, english_keyword_topic

    except Exception as e:
        # UI boundary: keep the traceback in the server log, and show a short
        # message to the user.
        logging.getLogger(__name__).exception("Failed to process audio %r", audio)
        return f"Error processing audio: {str(e)}", "", "", "", "", ""

def translate_text(text):
    """Translate Arabic text to English with the MarianMT ar→en model.

    Args:
        text: Source text (in this app, the Darija transcription).

    Returns:
        The decoded English translation with special tokens removed. Returns
        "" when *text* is ``None``, empty, or whitespace-only.
    """
    # Nothing to translate (e.g. a silent recording). Skip generate(), which
    # would otherwise still run and may emit a non-empty string.
    if not text or not text.strip():
        return ""

    # truncation=True caps overly long input at the tokenizer's maximum length.
    inputs = translation_tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        translated_tokens = translation_model.generate(**inputs)
    return translation_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)

def classify_topic(text, tokenizer, model, topic_labels):
    """Pick a topic label for *text* using a sequence-classification model.

    Args:
        text: Input text. It is truncated to at most 512 tokens.
        tokenizer: Tokenizer paired with *model*.
        model: Classification model whose output exposes ``.logits``.
        topic_labels: Labels indexed by the model's output class id.

    Returns:
        The label at the arg-max logit index. Returns "Other" when that index
        has no corresponding entry in *topic_labels*.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    ).to(device)

    with torch.no_grad():
        scores = model(**encoded).logits
        best_index = scores.argmax(dim=1).item()

    # The model may have more output classes than there are labels.
    if best_index >= len(topic_labels):
        return "Other"
    return topic_labels[best_index]

# 🔹 Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# 🎙️ Speech-to-Text, Translation & Topic Classification")

    # Input: an audio file path, either uploaded or recorded in the browser.
    audio_input = gr.Audio(type="filepath", label="Upload Audio or Record")
    submit_button = gr.Button("Process")

    # One textbox per value returned by transcribe_audio, created (and therefore
    # rendered) in the same order as that return tuple.
    output_captions = (
        "Transcription (Darija)",
        "Translation (English)",
        "Darija Topic Classification (BERT)",
        "English Topic Classification (BERT)",
        "Darija Topic Classification (Keywords)",
        "English Topic Classification (Keywords)",
    )
    result_boxes = [gr.Textbox(label=caption) for caption in output_captions]
    (
        transcription_output,
        translation_output,
        darija_topic_output,
        english_topic_output,
        darija_keyword_topic_output,
        english_keyword_topic_output,
    ) = result_boxes

    submit_button.click(transcribe_audio, inputs=[audio_input], outputs=result_boxes)

demo.launch()