File size: 3,753 Bytes
e7d0ead
6cc8631
 
f2ecb6e
980dcf2
9078685
980dcf2
 
 
85e680f
 
 
 
 
f2ecb6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
980dcf2
f2ecb6e
980dcf2
 
9078685
e7d0ead
 
9078685
980dcf2
d04bf8d
9078685
f2ecb6e
 
 
 
 
 
d04bf8d
 
9078685
 
 
 
d04bf8d
980dcf2
f2ecb6e
 
 
 
 
 
 
 
 
 
9078685
1a38424
f2ecb6e
9078685
 
f2ecb6e
 
9078685
 
f2ecb6e
 
c7f40c9
f2ecb6e
 
 
d27a60f
980dcf2
f2ecb6e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import gradio as gr
import librosa
import torch
from transformers import (
    AutoModel,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BertForSequenceClassification,
    MarianMTModel,
    MarianTokenizer,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
)

# Load the Darija speech-recognition (CTC) model and its processor.
# The repo id is kept in one variable so the model and processor cannot drift apart.
asr_model_name = "boumehdi/wav2vec2-large-xlsr-moroccan-darija"
processor = Wav2Vec2Processor.from_pretrained(asr_model_name)
model = Wav2Vec2ForCTC.from_pretrained(asr_model_name)

# Load the Arabic -> English translation model and its tokenizer.
translation_model_name = "Helsinki-NLP/opus-mt-ar-en"
translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)
translation_model = MarianMTModel.from_pretrained(translation_model_name)



# Topic labels per language. Each classifier below is built with exactly
# len(labels) output classes so every predicted index maps to a label.
darija_topic_labels = ["Customer Service", "Retention Service", "Billing Issue"]  # Adjust for Darija topics
english_topic_labels = ["Support Request", "Subscription Issue", "Payment Dispute"]  # Adjust for English topics

# Load AraBERT for Darija topic classification.
# AutoModel would return the bare encoder, whose outputs have no `.logits`;
# classify_topic reads `.logits`, so a sequence-classification head is required.
arabert_model_name = "aubmindlab/bert-base-arabert"
arabert_tokenizer = AutoTokenizer.from_pretrained(arabert_model_name)
arabert_model = AutoModelForSequenceClassification.from_pretrained(
    arabert_model_name, num_labels=len(darija_topic_labels)
)

# Load BERT for English topic classification.
bert_model_name = "bert-base-uncased"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = BertForSequenceClassification.from_pretrained(
    bert_model_name, num_labels=len(english_topic_labels)
)

# NOTE(review): these appear to be base pretrained checkpoints without a head
# trained on the labels above, so transformers presumably initializes the
# classifier layers randomly. Until fine-tuned, predictions are arbitrary.
# Verify with the "newly initialized" warning at load time.


def transcribe_audio(audio):
    """Transcribe Darija audio, translate it to English, and classify topics.

    Args:
        audio: Path to the audio file, as produced by the
            ``gr.Audio(type="filepath")`` input. It is ``None`` when the user
            submits without uploading or recording anything.

    Returns:
        A 4-tuple ``(transcription, translation, darija_topic, english_topic)``.

    Raises:
        gr.Error: If no audio was provided.
    """
    if audio is None:
        # Show a readable message in the UI instead of a librosa load failure.
        raise gr.Error("Please upload or record an audio file first.")

    # Resample to 16 kHz. Passing the rate on to the processor lets it check the
    # rate against its feature extractor instead of warning that it is missing.
    audio_array, sr = librosa.load(audio, sr=16000)
    input_values = processor(
        audio_array, sampling_rate=sr, return_tensors="pt", padding=True
    ).input_values

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(input_values).logits

    # Greedy CTC decoding: most likely token id per frame, then decode to text.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.decode(predicted_ids[0])
    translation = translate_text(transcription)

    # Classify topics for both the Darija transcription and the English translation.
    darija_topic = classify_topic(transcription, arabert_tokenizer, arabert_model, darija_topic_labels)
    english_topic = classify_topic(translation, bert_tokenizer, bert_model, english_topic_labels)

    return transcription, translation, darija_topic, english_topic


def translate_text(text):
    """Translate Arabic *text* into English using the MarianMT model.

    The input is truncated to the tokenizer's maximum length. The first
    generated sequence is returned with special tokens stripped.
    """
    encoded = translation_tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    generated = translation_model.generate(**encoded)
    return translation_tokenizer.batch_decode(generated, skip_special_tokens=True)[0]

def classify_topic(text, tokenizer, model, topic_labels):
    """Return the topic label a sequence-classification model predicts for *text*.

    Args:
        text: Input string. It is truncated to at most 512 tokens.
        tokenizer: Callable tokenizer that returns PyTorch tensors.
        model: Model whose output exposes ``.logits`` with a leading batch dimension.
        topic_labels: Labels indexed by the model's class ids.

    Returns:
        The label at the argmax class index, or ``"Other"`` if that index has
        no corresponding entry in *topic_labels*.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        scores = model(**encoded).logits

    best_index = scores.argmax(dim=1).item()
    if best_index >= len(topic_labels):
        return "Other"
    return topic_labels[best_index]


# Gradio user interface: one audio input, one button, four text outputs.
with gr.Blocks() as demo:
    gr.Markdown("# 🎙️ Speech-to-Text, Translation & Topic Classification")

    audio_input = gr.Audio(type="filepath", label="Upload Audio or Record")
    submit_button = gr.Button("Process")

    transcription_output = gr.Textbox(label="Transcription (Darija)")
    translation_output = gr.Textbox(label="Translation (English)")
    darija_topic_output = gr.Textbox(label="Darija Topic Classification")
    english_topic_output = gr.Textbox(label="English Topic Classification")

    # Order must match the 4-tuple returned by transcribe_audio.
    result_outputs = [
        transcription_output,
        translation_output,
        darija_topic_output,
        english_topic_output,
    ]
    submit_button.click(fn=transcribe_audio, inputs=[audio_input], outputs=result_outputs)

demo.launch()