File size: 4,367 Bytes
e7d0ead
6cc8631
 
f2ecb6e
980dcf2
9078685
980dcf2
 
 
85e680f
 
 
 
 
f2ecb6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f117ae1
 
980dcf2
f2ecb6e
f117ae1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9078685
f117ae1
 
 
9078685
f117ae1
9078685
f117ae1
 
 
f2ecb6e
 
d04bf8d
 
9078685
 
 
 
d04bf8d
980dcf2
f2ecb6e
 
 
 
 
 
 
 
 
 
9078685
1a38424
f2ecb6e
9078685
 
f2ecb6e
 
9078685
 
f2ecb6e
 
c7f40c9
f2ecb6e
 
 
d27a60f
980dcf2
f2ecb6e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import logging

import gradio as gr
import librosa
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, MarianMTModel, MarianTokenizer, BertForSequenceClassification, AutoModel, AutoTokenizer

# Load the Darija speech-recognition model (Wav2Vec2 with a CTC head) and its processor.
model = Wav2Vec2ForCTC.from_pretrained("boumehdi/wav2vec2-large-xlsr-moroccan-darija")
processor = Wav2Vec2Processor.from_pretrained("boumehdi/wav2vec2-large-xlsr-moroccan-darija")

# Load the Arabic -> English translation model.
translation_model_name = "Helsinki-NLP/opus-mt-ar-en"
translation_model = MarianMTModel.from_pretrained(translation_model_name)
translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)

# Topic label sets. Each classifier's output size below is derived from its
# list, so adding or removing a label cannot leave the model head out of sync.
darija_topic_labels = ["Customer Service", "Retention Service", "Billing Issue"]  # Adjust for Darija topics
english_topic_labels = ["Support Request", "Subscription Issue", "Payment Dispute"]  # Adjust for English topics

# NOTE(review): these base checkpoints presumably ship without a fine-tuned
# classification head. transformers would then initialize the head randomly,
# and the predicted topics are arbitrary until the models are fine-tuned.
# Confirm before relying on the outputs.

# AraBERT for Darija topic classification. It must be loaded with a
# sequence-classification head. A bare AutoModel returns only hidden states,
# with no `.logits` for the topic classifier to read. AraBERT uses the BERT
# architecture, so BertForSequenceClassification applies.
arabert_model_name = "aubmindlab/bert-base-arabert"
arabert_tokenizer = AutoTokenizer.from_pretrained(arabert_model_name)
arabert_model = BertForSequenceClassification.from_pretrained(
    arabert_model_name, num_labels=len(darija_topic_labels)
)

# BERT for English topic classification.
bert_model_name = "bert-base-uncased"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = BertForSequenceClassification.from_pretrained(
    bert_model_name, num_labels=len(english_topic_labels)
)


import torch

def transcribe_audio(audio):
    """Transcribe Darija speech, translate it to English, and classify both texts.

    Args:
        audio: Path to an audio file, as produced by ``gr.Audio(type="filepath")``.
            Gradio passes ``None`` when Process is clicked with no upload or recording.

    Returns:
        A 4-tuple ``(transcription, translation, darija_topic, english_topic)``.
        When there is no audio or processing fails, the first element is a
        message and the other three are empty strings. The four Gradio output
        boxes therefore always receive a value.
    """
    # Report a missing input directly instead of letting librosa fail on None.
    if audio is None:
        return "No audio provided", "", "", ""

    try:
        # Load the file resampled to 16 kHz, the rate declared to the processor.
        audio_array, _ = librosa.load(audio, sr=16000)
        input_values = processor(
            audio_array, sampling_rate=16000, return_tensors="pt", padding=True
        ).input_values

        # Run the speech model on the GPU when one is available.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        input_values = input_values.to(device)

        with torch.no_grad():
            logits = model(input_values).logits
        # Greedy CTC decoding: take the most likely token at every frame.
        tokens = torch.argmax(logits, dim=-1)
        transcription = processor.decode(tokens[0])

        translation = translate_text(transcription)

        darija_topic = classify_topic(transcription, arabert_tokenizer, arabert_model, darija_topic_labels)
        english_topic = classify_topic(translation, bert_tokenizer, bert_model, english_topic_labels)

        return transcription, translation, darija_topic, english_topic

    except Exception:
        # UI boundary: log the full traceback for the operator, and return a
        # displayable message rather than failing the Gradio handler.
        logging.getLogger(__name__).exception("Error in transcription")
        return "Error processing audio", "", "", ""



def translate_text(text):
    """Translate Arabic/Darija text to English with the MarianMT model.

    Args:
        text: Source text to translate (in this app, the speech transcription).

    Returns:
        The English translation. Returns ``""`` when ``text`` is ``None``,
        empty, or whitespace only.
    """
    # Silent or unintelligible audio can yield an empty transcription. Passing
    # it to the seq2seq model would still run generation, which may produce
    # spurious text, so short-circuit instead.
    if not text or not text.strip():
        return ""
    inputs = translation_tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    translated_tokens = translation_model.generate(**inputs)
    # One input string yields one generated sequence.
    return translation_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)

def classify_topic(text, tokenizer, model, topic_labels):
    """Return the topic label that a sequence-classification model assigns to *text*.

    Args:
        text: Input string to classify.
        tokenizer: Tokenizer called with ``return_tensors="pt"``; its output is
            passed to ``model`` as keyword arguments.
        model: Model whose output exposes ``.logits`` (scores per class).
        topic_labels: Label names indexed by class id.

    Returns:
        The label of the highest-scoring class, or ``"Other"`` when that class
        id has no entry in ``topic_labels``.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        scores = model(**encoded).logits
        best = torch.argmax(scores, dim=1).item()

    # Guard against a model head wider than the configured label list.
    if best >= len(topic_labels):
        return "Other"
    return topic_labels[best]


# Gradio user interface: one audio input, a Process button, and four text outputs.
with gr.Blocks() as demo:
    gr.Markdown("# 🎙️ Speech-to-Text, Translation & Topic Classification")

    audio_input = gr.Audio(type="filepath", label="Upload Audio or Record")
    submit_button = gr.Button("Process")

    # Output boxes are listed in the order of the 4-tuple returned by transcribe_audio.
    transcription_output = gr.Textbox(label="Transcription (Darija)")
    translation_output = gr.Textbox(label="Translation (English)")
    darija_topic_output = gr.Textbox(label="Darija Topic Classification")
    english_topic_output = gr.Textbox(label="English Topic Classification")
    result_outputs = [
        transcription_output,
        translation_output,
        darija_topic_output,
        english_topic_output,
    ]

    submit_button.click(
        fn=transcribe_audio,
        inputs=[audio_input],
        outputs=result_outputs,
    )

demo.launch()