import gradio as gr
import librosa
import torch
from transformers import (
    Wav2Vec2ForCTC, Wav2Vec2Processor, 
    MarianMTModel, MarianTokenizer, 
    BertForSequenceClassification, AutoModel, AutoTokenizer
)

# Detect device
device = "cuda" if torch.cuda.is_available() else "cpu"

### 🔹 Load Models & Tokenizers Once ###
# Wav2Vec2 for Darija transcription
wav2vec_model_name = "boumehdi/wav2vec2-large-xlsr-moroccan-darija"
processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_name)
wav2vec_model = Wav2Vec2ForCTC.from_pretrained(wav2vec_model_name).to(device)

# MarianMT for translation (Arabic → English)
translation_model_name = "Helsinki-NLP/opus-mt-ar-en"
translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)
translation_model = MarianMTModel.from_pretrained(translation_model_name).to(device)

# AraBERT for Darija topic classification
arabert_model_name = "aubmindlab/bert-base-arabert"
arabert_tokenizer = AutoTokenizer.from_pretrained(arabert_model_name)
arabert_model = BertForSequenceClassification.from_pretrained(arabert_model_name, num_labels=11).to(device)  # one logit per Darija topic label below


# BERT for English topic classification
bert_model_name = "bert-base-uncased"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=11).to(device)  # one logit per English topic label below
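# Note: both classification heads are newly (randomly) initialized by from_pretrained,
# since neither base checkpoint ships a sequence-classification head. They need
# fine-tuning on labeled Darija/English examples before the topic predictions are meaningful.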

# Topic labels in Darija (Arabic and Latin script)
darija_topic_labels = [
    "مشكيل ف الشبكة (Mochkil f réseau)",        # Network problem
    "مشكيل ف الانترنت (Mochkil f internet)",      # Internet problem
    "مشكيل ف الفاتورة (Mochkil f l'factura)",     # Billing and payment problem
    "مشكيل ف التعبئة (Mochkil f l'recharge)",     # Recharge and plans problem
    "مشكيل ف التطبيق (Mochkil f l'application)",  # Problem with the app (Orange et Moi...)
    "مشكيل ف بطاقة SIM (Mochkil f carte SIM)",    # SIM card problem
    "مساعدة تقنية (Mosa3ada technique)",         # Technical support
    "العروض والتخفيضات (Offres w promotions)",   # Offers and promotions
    "طلب معلومات (Talab l'ma3loumat)",           # Information request
    "شكاية (Chikaya)",                            # Complaint
    "حاجة أخرى (Chi haja okhra)"                 # Other
]

# Topic labels in English
english_topic_labels = [
    "Network Issue",
    "Internet Issue",
    "Billing & Payment Issue",
    "Recharge & Plans",
    "App Issue",
    "SIM Card Issue",
    "Technical Support",
    "Offers & Promotions",
    "General Inquiry",
    "Complaint",
    "Other"
]


def transcribe_audio(audio):
    """Convert audio to text, translate it, and classify topics in both Darija and English."""
    try:
        # Load and preprocess audio
        audio_array, sr = librosa.load(audio, sr=16000)
        input_values = processor(audio_array, sampling_rate=16000, return_tensors="pt", padding=True).input_values.to(device)

        # Transcription (Darija)
        with torch.no_grad():
            logits = wav2vec_model(input_values).logits
        tokens = torch.argmax(logits, dim=-1)
        transcription = processor.decode(tokens[0])

        # Translate to English
        translation = translate_text(transcription)

        # Classify topics
        darija_topic = classify_topic(transcription, arabert_tokenizer, arabert_model, darija_topic_labels)
        english_topic = classify_topic(translation, bert_tokenizer, bert_model, english_topic_labels)

        return transcription, translation, darija_topic, english_topic

    except Exception as e:
        return f"Error processing audio: {str(e)}", "", "", ""

def translate_text(text):
    """Translate Arabic text to English."""
    inputs = translation_tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        translated_tokens = translation_model.generate(**inputs)
    return translation_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)

def classify_topic(text, tokenizer, model, topic_labels):
    """Classify topic using BERT-based models."""
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
        predicted_class = torch.argmax(outputs.logits, dim=1).item()

    return topic_labels[predicted_class] if predicted_class < len(topic_labels) else "Other"

# 🔹 Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# 🎙️ Speech-to-Text, Translation & Topic Classification")
    
    audio_input = gr.Audio(type="filepath", label="Upload Audio or Record")
    submit_button = gr.Button("Process")

    transcription_output = gr.Textbox(label="Transcription (Darija)")
    translation_output = gr.Textbox(label="Translation (English)")
    darija_topic_output = gr.Textbox(label="Darija Topic Classification")
    english_topic_output = gr.Textbox(label="English Topic Classification")

    submit_button.click(transcribe_audio, 
                        inputs=[audio_input], 
                        outputs=[transcription_output, translation_output, darija_topic_output, english_topic_output])

demo.launch()