File size: 4,134 Bytes
e7d0ead
6cc8631
 
2785b50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85e680f
 
2785b50
85e680f
2785b50
f2ecb6e
 
2785b50
f2ecb6e
2785b50
f2ecb6e
 
2785b50
f2ecb6e
2785b50
 
 
f117ae1
980dcf2
2785b50
f117ae1
 
 
2785b50
f117ae1
2785b50
f117ae1
2785b50
f117ae1
 
 
 
 
9078685
2785b50
f117ae1
 
9078685
f117ae1
9078685
f117ae1
2785b50
d04bf8d
 
2785b50
 
 
 
 
980dcf2
f2ecb6e
2785b50
 
f2ecb6e
 
 
 
 
 
2785b50
1a38424
f2ecb6e
9078685
 
f2ecb6e
 
9078685
 
f2ecb6e
 
c7f40c9
f2ecb6e
 
 
d27a60f
980dcf2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import gradio as gr
import librosa
import torch
from transformers import (
    Wav2Vec2ForCTC, Wav2Vec2Processor, 
    MarianMTModel, MarianTokenizer, 
    BertForSequenceClassification, AutoModel, AutoTokenizer
)

# Detect device: run every model on the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

### 🔹 Load Models & Tokenizers Once ###
# Everything below is loaded at import time so that each request reuses the
# same model instances instead of reloading them.

# Wav2Vec2 for Darija transcription
wav2vec_model_name = "boumehdi/wav2vec2-large-xlsr-moroccan-darija"
processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_name)
wav2vec_model = Wav2Vec2ForCTC.from_pretrained(wav2vec_model_name).to(device)

# MarianMT for translation (Arabic → English)
translation_model_name = "Helsinki-NLP/opus-mt-ar-en"
translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)
translation_model = MarianMTModel.from_pretrained(translation_model_name).to(device)

# Define Topic Labels (declared before the classifiers so that the size of
# each classification head is derived from its label list).
darija_topic_labels = ["Customer Service", "Retention Service", "Billing Issue"]
english_topic_labels = ["Support Request", "Subscription Issue", "Payment Dispute"]

# AraBERT for Darija topic classification.
# Loaded with a sequence-classification head: the bare `AutoModel` encoder
# returns hidden states only, so its output has no `.logits` for
# classify_topic() to read.
# NOTE(review): both classifier checkpoints look like base (not fine-tuned)
# models, so their classification heads are presumably freshly initialised;
# confirm that fine-tuned weights are intended before trusting predictions.
arabert_model_name = "aubmindlab/bert-base-arabert"
arabert_tokenizer = AutoTokenizer.from_pretrained(arabert_model_name)
arabert_model = BertForSequenceClassification.from_pretrained(
    arabert_model_name, num_labels=len(darija_topic_labels)
).to(device)

# BERT for English topic classification
bert_model_name = "bert-base-uncased"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = BertForSequenceClassification.from_pretrained(
    bert_model_name, num_labels=len(english_topic_labels)
).to(device)

def transcribe_audio(audio):
    """Convert audio to text, translate it, and classify topics in both Darija and English.

    Args:
        audio: Path of the audio file to process, or ``None`` when no audio
            was uploaded or recorded.

    Returns:
        A 4-tuple of strings ``(transcription, translation, darija_topic,
        english_topic)``. On failure the first element is an
        ``"Error processing audio: ..."`` message and the other three are
        empty. When nothing is recognised, the empty transcription is
        returned and the other three fields are left empty.
    """
    # Nothing to process: report it plainly instead of surfacing the audio
    # loader's internal error for a missing path.
    if audio is None:
        return "Error processing audio: no audio provided", "", "", ""

    sample_rate = 16000

    try:
        # Load the audio resampled to the rate handed to the processor.
        audio_array, _ = librosa.load(audio, sr=sample_rate)
        input_values = processor(
            audio_array, sampling_rate=sample_rate, return_tensors="pt", padding=True
        ).input_values.to(device)

        # Transcription (Darija): pick the highest-scoring token per frame
        # and let the processor decode the resulting id sequence.
        with torch.no_grad():
            logits = wav2vec_model(input_values).logits
        tokens = torch.argmax(logits, dim=-1)
        transcription = processor.decode(tokens[0])

        # No recognised speech: translating or classifying an empty string
        # would only produce output unrelated to the recording.
        if not transcription.strip():
            return transcription, "", "", ""

        # Translate to English
        translation = translate_text(transcription)

        # Classify topics
        darija_topic = classify_topic(transcription, arabert_tokenizer, arabert_model, darija_topic_labels)
        english_topic = classify_topic(translation, bert_tokenizer, bert_model, english_topic_labels)

        return transcription, translation, darija_topic, english_topic

    except Exception as e:
        # UI boundary: turn any failure into a message shown in the first
        # output box rather than letting the handler raise.
        return f"Error processing audio: {e}", "", "", ""

def translate_text(text):
    """Translate Arabic text to English.

    Args:
        text: The Arabic text to translate.

    Returns:
        The decoded English translation with special tokens removed, or an
        empty string when ``text`` is empty or whitespace-only.
    """
    # Nothing to translate: skip generation, which could only produce text
    # unrelated to the (absent) input.
    if not text or not text.strip():
        return ""

    inputs = translation_tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        translated_tokens = translation_model.generate(**inputs)
    # A single input string yields a single generated sequence.
    return translation_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)

def classify_topic(text, tokenizer, model, topic_labels):
    """Classify topic using BERT-based models.

    Args:
        text: The text to classify.
        tokenizer: Tokenizer matching ``model``; called on ``text`` to build
            the model inputs.
        model: Model whose output exposes ``.logits`` with one score per
            topic class.
        topic_labels: Label names indexed by predicted class id.

    Returns:
        The label at the predicted class index, or ``"Other"`` when the
        index falls outside ``topic_labels`` or ``text`` is empty or
        whitespace-only.
    """
    # Empty input carries no topic; avoid reporting an arbitrary label.
    if not text or not text.strip():
        return "Other"

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Single input text, so the arg-max over the class axis is one value.
    predicted_class = torch.argmax(logits, dim=-1).item()

    if predicted_class < len(topic_labels):
        return topic_labels[predicted_class]
    return "Other"

# 🔹 Gradio Interface
# One audio input and a "Process" button feeding four read-out text boxes.
with gr.Blocks() as demo:
    gr.Markdown("# 🎙️ Speech-to-Text, Translation & Topic Classification")

    audio_input = gr.Audio(type="filepath", label="Upload Audio or Record")
    submit_button = gr.Button("Process")

    transcription_output = gr.Textbox(label="Transcription (Darija)")
    translation_output = gr.Textbox(label="Translation (English)")
    darija_topic_output = gr.Textbox(label="Darija Topic Classification")
    english_topic_output = gr.Textbox(label="English Topic Classification")

    # The output order must match the 4-tuple returned by transcribe_audio.
    submit_button.click(
        transcribe_audio,
        inputs=[audio_input],
        outputs=[
            transcription_output,
            translation_output,
            darija_topic_output,
            english_topic_output,
        ],
    )

demo.launch()