Spaces:
Sleeping
Sleeping
File size: 5,308 Bytes
e7d0ead 6cc8631 2785b50 85e680f 2785b50 85e680f 2785b50 f2ecb6e 2855355 f2ecb6e 2785b50 f2ecb6e 2785b50 f2ecb6e 9e3ffca f117ae1 980dcf2 2785b50 f117ae1 2785b50 f117ae1 2785b50 f117ae1 2785b50 f117ae1 9078685 2785b50 f117ae1 9078685 f117ae1 9078685 f117ae1 2785b50 d04bf8d 2785b50 980dcf2 f2ecb6e 2785b50 f2ecb6e 2785b50 1a38424 f2ecb6e 9078685 f2ecb6e 9078685 f2ecb6e c7f40c9 f2ecb6e d27a60f 980dcf2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import logging

import gradio as gr
import librosa
import torch
from transformers import (
    Wav2Vec2ForCTC, Wav2Vec2Processor,
    MarianMTModel, MarianTokenizer,
    BertForSequenceClassification, AutoModel, AutoTokenizer
)
# Detect device
device = "cuda" if torch.cuda.is_available() else "cpu"

### 🔹 Topic labels ###
# Darija labels (Arabic script, with a Latin transliteration in parentheses).
# Order matters: classify_topic returns topic_labels[i] for predicted class id i,
# so this list also fixes the size of the AraBERT classification head below.
darija_topic_labels = [
    "مشكيل ف الشبكة (Mochkil f réseau)",  # Network problem
    "مشكيل ف الانترنت (Mochkil f internet)",  # Internet problem
    "مشكيل ف الفاتورة (Mochkil f l'factura)",  # Billing and payment problem
    "مشكيل ف التعبئة (Mochkil f l'recharge)",  # Top-up and plans problem
    "مشكيل ف التطبيق (Mochkil f l'application)",  # Problem with the app (Orange et Moi...)
    "مشكيل ف بطاقة SIM (Mochkil f carte SIM)",  # SIM card problem
    "مساعدة تقنية (Mosa3ada technique)",  # Technical support
    "العروض والتخفيضات (Offres w promotions)",  # Offers and promotions
    "طلب معلومات (Talab l'ma3loumat)",  # Information request
    "شكاية (Chikaya)",  # Complaint
    "حاجة أخرى (Chi haja okhra)",  # Other
]

# English labels, in the same order as darija_topic_labels; this list fixes
# the size of the English BERT classification head below.
english_topic_labels = [
    "Network Issue",
    "Internet Issue",
    "Billing & Payment Issue",
    "Recharge & Plans",
    "App Issue",
    "SIM Card Issue",
    "Technical Support",
    "Offers & Promotions",
    "General Inquiry",
    "Complaint",
    "Other",
]

### 🔹 Load Models & Tokenizers Once ###
# Wav2Vec2 for Darija transcription
wav2vec_model_name = "boumehdi/wav2vec2-large-xlsr-moroccan-darija"
processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_name)
wav2vec_model = Wav2Vec2ForCTC.from_pretrained(wav2vec_model_name).to(device)

# MarianMT for translation (Arabic → English)
translation_model_name = "Helsinki-NLP/opus-mt-ar-en"
translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)
translation_model = MarianMTModel.from_pretrained(translation_model_name).to(device)

# NOTE(review): "aubmindlab/bert-base-arabert" and "bert-base-uncased" look like
# base (not topic-fine-tuned) checkpoints. If so, the classification heads are
# freshly initialised and the predicted topics are not meaningful until the
# models are fine-tuned on labelled data. Confirm.

# AraBERT for Darija topic classification. num_labels is derived from the label
# list; it was previously hard-coded to 2, so 9 of the 11 topics were unreachable.
arabert_model_name = "aubmindlab/bert-base-arabert"
arabert_tokenizer = AutoTokenizer.from_pretrained(arabert_model_name)
arabert_model = BertForSequenceClassification.from_pretrained(
    arabert_model_name, num_labels=len(darija_topic_labels)
).to(device)

# BERT for English topic classification. num_labels is derived from the label
# list; it was previously hard-coded to 3, so 8 of the 11 topics were unreachable.
bert_model_name = "bert-base-uncased"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = BertForSequenceClassification.from_pretrained(
    bert_model_name, num_labels=len(english_topic_labels)
).to(device)
def transcribe_audio(audio):
    """Transcribe a Darija audio file, translate it, and classify its topic.

    Args:
        audio: Path to the audio file to process, or ``None`` when the user
            pressed "Process" without uploading or recording anything.

    Returns:
        A 4-tuple ``(transcription, translation, darija_topic, english_topic)``.
        On failure the first element holds an error message and the other
        three are empty strings, so the UI always receives four values.
    """
    # Gradio hands over None when no audio was provided. Answer with a clear
    # message instead of letting librosa fail on a None path.
    if audio is None:
        return "Error processing audio: no audio input was provided.", "", "", ""

    try:
        # Load and resample to 16 kHz, the rate passed to the Wav2Vec2 processor.
        audio_array, _ = librosa.load(audio, sr=16000)
        input_values = processor(
            audio_array, sampling_rate=16000, return_tensors="pt", padding=True
        ).input_values.to(device)

        # Transcription (Darija): greedy CTC decoding. Take the most likely
        # token per frame, then let the processor decode the id sequence.
        with torch.no_grad():
            logits = wav2vec_model(input_values).logits
        tokens = torch.argmax(logits, dim=-1)
        transcription = processor.decode(tokens[0])

        # Translate to English
        translation = translate_text(transcription)

        # Classify topics: Darija text with AraBERT, English text with BERT.
        darija_topic = classify_topic(
            transcription, arabert_tokenizer, arabert_model, darija_topic_labels
        )
        english_topic = classify_topic(
            translation, bert_tokenizer, bert_model, english_topic_labels
        )
        return transcription, translation, darija_topic, english_topic
    except Exception as e:
        # UI boundary: keep the app responsive, but record the traceback
        # instead of silently discarding it.
        logging.getLogger(__name__).exception("Failed to process audio %r", audio)
        return f"Error processing audio: {str(e)}", "", "", ""
def translate_text(text):
    """Return the English MarianMT translation of Arabic/Darija *text*.

    Uses the module-level ``translation_tokenizer`` and ``translation_model``.
    Only the first generated sequence is decoded, with special tokens removed.
    """
    encoded = translation_tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    encoded = encoded.to(device)

    with torch.no_grad():
        generated = translation_model.generate(**encoded)

    first_sequence = generated[0]
    return translation_tokenizer.decode(first_sequence, skip_special_tokens=True)
def classify_topic(text, tokenizer, model, topic_labels):
    """Classify *text* with a sequence-classification model and map it to a label.

    Args:
        text: Input string to classify.
        tokenizer: Tokenizer matching *model*; its encoded output must
            support ``.to(device)``.
        model: Classification model whose output exposes ``.logits``. Its
            ``.device`` attribute decides where the encoded inputs are placed.
        topic_labels: Labels indexed by the model's predicted class id.

    Returns:
        ``topic_labels[i]`` for the arg-max class id ``i``, or ``"Other"`` when
        the model predicts an id that has no corresponding label.
    """
    # Put the inputs on the model's own device rather than a module-level
    # setting, so the helper works wherever the model happens to live.
    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, padding=True, max_length=512
    ).to(model.device)
    with torch.no_grad():
        outputs = model(**inputs)

    # Batch of one: arg-max over the class dimension of the logits.
    predicted_class = torch.argmax(outputs.logits, dim=1).item()
    if predicted_class < len(topic_labels):
        return topic_labels[predicted_class]
    return "Other"
# 🔹 Gradio Interface
with gr.Blocks() as demo:
    # Page heading.
    gr.Markdown("# 🎙️ Speech-to-Text, Translation & Topic Classification")

    # Input: an uploaded or recorded clip, delivered to the callback as a file path.
    audio_input = gr.Audio(type="filepath", label="Upload Audio or Record")
    submit_button = gr.Button("Process")

    # Result boxes, listed in the order of the tuple returned by transcribe_audio.
    transcription_output = gr.Textbox(label="Transcription (Darija)")
    translation_output = gr.Textbox(label="Translation (English)")
    darija_topic_output = gr.Textbox(label="Darija Topic Classification")
    english_topic_output = gr.Textbox(label="English Topic Classification")
    result_boxes = [
        transcription_output,
        translation_output,
        darija_topic_output,
        english_topic_output,
    ]

    # Wire the button to the full pipeline.
    submit_button.click(
        fn=transcribe_audio,
        inputs=[audio_input],
        outputs=result_boxes,
    )

demo.launch()
|