"""Gradio demo: Darija speech-to-text, Arabic->English translation, topic classification."""

import logging

import gradio as gr
import librosa
import torch
from transformers import (
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    MarianMTModel,
    MarianTokenizer,
    BertForSequenceClassification,
    AutoModel,
    AutoTokenizer,
)

logger = logging.getLogger(__name__)

# Detect device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Topic labels are declared before the models so that each classification head
# can be sized to match its label list.

# Labels in Darija (Arabic and Latin script)
darija_topic_labels = [
    "مشكيل ف الشبكة (Mochkil f réseau)",  # Network problem
    "مشكيل ف الانترنت (Mochkil f internet)",  # Internet problem
    "مشكيل ف الفاتورة (Mochkil f l'factura)",  # Billing and payment problem
    "مشكيل ف التعبئة (Mochkil f l'recharge)",  # Recharge and plans problem
    "مشكيل ف التطبيق (Mochkil f l'application)",  # Problem with the app (Orange et Moi...)
    "مشكيل ف بطاقة SIM (Mochkil f carte SIM)",  # Problem with the SIM card
    "مساعدة تقنية (Mosa3ada technique)",  # Technical support
    "العروض والتخفيضات (Offres w promotions)",  # Offers and promotions
    "طلب معلومات (Talab l'ma3loumat)",  # Information request
    "شكاية (Chikaya)",  # Complaint
    "حاجة أخرى (Chi haja okhra)",  # Other
]

# Labels in English
english_topic_labels = [
    "Network Issue",
    "Internet Issue",
    "Billing & Payment Issue",
    "Recharge & Plans",
    "App Issue",
    "SIM Card Issue",
    "Technical Support",
    "Offers & Promotions",
    "General Inquiry",
    "Complaint",
    "Other",
]

### 🔹 Load Models & Tokenizers Once ###

# Wav2Vec2 for Darija transcription
wav2vec_model_name = "boumehdi/wav2vec2-large-xlsr-moroccan-darija"
processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_name)
wav2vec_model = Wav2Vec2ForCTC.from_pretrained(wav2vec_model_name).to(device)

# MarianMT for translation (Arabic → English)
translation_model_name = "Helsinki-NLP/opus-mt-ar-en"
translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)
translation_model = MarianMTModel.from_pretrained(translation_model_name).to(device)

# NOTE(review): both classifiers below are loaded from what look like base
# (not fine-tuned) checkpoints, so their classification heads are presumably
# freshly initialised -- confirm whether fine-tuned checkpoints are intended.
# The head size must equal the number of labels; otherwise argmax can never
# select the labels beyond the head's output width.

# AraBERT for Darija topic classification
arabert_model_name = "aubmindlab/bert-base-arabert"
arabert_tokenizer = AutoTokenizer.from_pretrained(arabert_model_name)
arabert_model = BertForSequenceClassification.from_pretrained(
    arabert_model_name, num_labels=len(darija_topic_labels)
).to(device)

# BERT for English topic classification
bert_model_name = "bert-base-uncased"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = BertForSequenceClassification.from_pretrained(
    bert_model_name, num_labels=len(english_topic_labels)
).to(device)


def transcribe_audio(audio):
    """Convert audio to text, translate it, and classify topics in both Darija and English.

    Args:
        audio: Path to the audio file supplied by the Gradio ``Audio``
            component (configured with ``type="filepath"``), or ``None``
            when the user submitted without providing audio.

    Returns:
        A 4-tuple of strings ``(transcription, translation, darija_topic,
        english_topic)``. On failure the first element carries an
        ``"Error processing audio: ..."`` message and the others are empty.
    """
    # Submitting with no upload/recording yields no file path; report it
    # directly instead of surfacing an opaque loader exception.
    if not audio:
        return "Error processing audio: no audio provided", "", "", ""

    try:
        # Load and preprocess audio (resampled to 16 kHz on load)
        audio_array, _ = librosa.load(audio, sr=16000)
        input_values = processor(
            audio_array, sampling_rate=16000, return_tensors="pt", padding=True
        ).input_values.to(device)

        # Transcription (Darija): greedy decoding over the CTC logits
        with torch.no_grad():
            logits = wav2vec_model(input_values).logits
        tokens = torch.argmax(logits, dim=-1)
        transcription = processor.decode(tokens[0])

        # Nothing was recognised: translating or classifying an empty string
        # would only produce arbitrary output.
        if not transcription.strip():
            return transcription, "", "", ""

        # Translate to English
        translation = translate_text(transcription)

        # Classify topics
        darija_topic = classify_topic(
            transcription, arabert_tokenizer, arabert_model, darija_topic_labels
        )
        english_topic = classify_topic(
            translation, bert_tokenizer, bert_model, english_topic_labels
        )

        return transcription, translation, darija_topic, english_topic

    except Exception as e:
        # UI boundary: keep the app responsive, but record the traceback.
        logger.exception("Audio processing failed")
        return f"Error processing audio: {str(e)}", "", "", ""


def translate_text(text: str) -> str:
    """Translate Arabic text to English.

    Args:
        text: Source text to translate.

    Returns:
        The decoded first generated sequence, with special tokens removed.
    """
    inputs = translation_tokenizer(
        text, return_tensors="pt", padding=True, truncation=True
    ).to(device)
    with torch.no_grad():
        translated_tokens = translation_model.generate(**inputs)
    return translation_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)


def classify_topic(text: str, tokenizer, model, topic_labels) -> str:
    """Classify topic using BERT-based models.

    Args:
        text: Text to classify.
        tokenizer: Tokenizer matching ``model``.
        model: Sequence-classification model whose output exposes ``logits``.
        topic_labels: Labels indexed by predicted class id.

    Returns:
        The label at the arg-max class index, or ``"Other"`` when that index
        falls outside ``topic_labels``.
    """
    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, padding=True, max_length=512
    ).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_class = torch.argmax(outputs.logits, dim=1).item()
    return topic_labels[predicted_class] if predicted_class < len(topic_labels) else "Other"


# 🔹 Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# 🎙️ Speech-to-Text, Translation & Topic Classification")

    audio_input = gr.Audio(type="filepath", label="Upload Audio or Record")
    submit_button = gr.Button("Process")

    transcription_output = gr.Textbox(label="Transcription (Darija)")
    translation_output = gr.Textbox(label="Translation (English)")
    darija_topic_output = gr.Textbox(label="Darija Topic Classification")
    english_topic_output = gr.Textbox(label="English Topic Classification")

    submit_button.click(
        transcribe_audio,
        inputs=[audio_input],
        outputs=[
            transcription_output,
            translation_output,
            darija_topic_output,
            english_topic_output,
        ],
    )

demo.launch()