Spaces:
Running
Running
File size: 3,753 Bytes
e7d0ead 6cc8631 f2ecb6e 980dcf2 9078685 980dcf2 85e680f f2ecb6e 980dcf2 f2ecb6e 980dcf2 9078685 e7d0ead 9078685 980dcf2 d04bf8d 9078685 f2ecb6e d04bf8d 9078685 d04bf8d 980dcf2 f2ecb6e 9078685 1a38424 f2ecb6e 9078685 f2ecb6e 9078685 f2ecb6e c7f40c9 f2ecb6e d27a60f 980dcf2 f2ecb6e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import gradio as gr
import librosa
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, MarianMTModel, MarianTokenizer, BertForSequenceClassification, AutoModel, AutoTokenizer
# Topic label sets. The classification heads below are sized from these lists
# so the number of model outputs always matches the number of labels.
darija_topic_labels = ["Customer Service", "Retention Service", "Billing Issue"] # Adjust for Darija topics
english_topic_labels = ["Support Request", "Subscription Issue", "Payment Dispute"] # Adjust for English topics

# Load the speech-to-text model for Moroccan Darija
model = Wav2Vec2ForCTC.from_pretrained("boumehdi/wav2vec2-large-xlsr-moroccan-darija")
processor = Wav2Vec2Processor.from_pretrained("boumehdi/wav2vec2-large-xlsr-moroccan-darija")

# Load the Arabic -> English translation model
translation_model_name = "Helsinki-NLP/opus-mt-ar-en"
translation_model = MarianMTModel.from_pretrained(translation_model_name)
translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)

# Load AraBERT for Darija topic classification.
# It must carry a sequence-classification head: classify_topic() reads
# `outputs.logits`, which the bare encoder returned by AutoModel does not expose.
arabert_model_name = "aubmindlab/bert-base-arabert"
arabert_tokenizer = AutoTokenizer.from_pretrained(arabert_model_name)
arabert_model = BertForSequenceClassification.from_pretrained(
    arabert_model_name,
    num_labels=len(darija_topic_labels),
)

# Load BERT for English topic classification
bert_model_name = "bert-base-uncased"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = BertForSequenceClassification.from_pretrained(
    bert_model_name,
    num_labels=len(english_topic_labels),
)

# NOTE(review): both checkpoints above are base (pre-trained only) models, so
# their classification heads are presumably freshly initialised and the
# predicted topics are not meaningful until the models are fine-tuned on
# labelled data -- confirm and swap in fine-tuned checkpoints.
def transcribe_audio(audio):
    """Transcribe audio, translate the text and classify its topic.

    Args:
        audio: Path of the audio file to process. Falsy (``None`` or ``""``)
            when nothing was uploaded or recorded.

    Returns:
        A 4-tuple of strings
        ``(transcription, translation, darija_topic, english_topic)``.
        All four are empty when no audio was supplied; the last three are
        empty when the transcription is blank.
    """
    # Clicking "Process" without any audio passes a falsy value; loading it
    # would raise, so answer with empty fields instead.
    if not audio:
        return "", "", "", ""

    sample_rate = 16000
    # librosa resamples to `sample_rate` on load; the same rate is passed to
    # the processor so it can check it against the rate it was configured for.
    audio_array, _ = librosa.load(audio, sr=sample_rate)
    input_values = processor(
        audio_array,
        sampling_rate=sample_rate,
        return_tensors="pt",
        padding=True,
    ).input_values

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(input_values).logits
    tokens = torch.argmax(logits, dim=-1)
    transcription = processor.decode(tokens[0])

    # Nothing was recognised: there is no text to translate or classify.
    if not transcription.strip():
        return transcription, "", "", ""

    translation = translate_text(transcription)

    # Classify topics for both the Darija text and its English translation
    darija_topic = classify_topic(transcription, arabert_tokenizer, arabert_model, darija_topic_labels)
    english_topic = classify_topic(translation, bert_tokenizer, bert_model, english_topic_labels)
    return transcription, translation, darija_topic, english_topic
def translate_text(text):
    """Translate text from Arabic to English.

    Args:
        text: The Arabic text to translate. May be empty or ``None``.

    Returns:
        The English translation as a string, or ``""`` when ``text`` is
        empty, whitespace-only or ``None``.
    """
    # There is nothing to translate in a blank input, so skip the model call.
    # NOTE(review): translation models may emit spurious text for empty
    # input -- this guard also avoids that, confirm against the checkpoint.
    if not text or not text.strip():
        return ""

    inputs = translation_tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    translated_tokens = translation_model.generate(**inputs)
    # A single text was passed in, so only the first generated sequence is decoded.
    return translation_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
def classify_topic(text, tokenizer, model, topic_labels):
    """Return the topic label the given model predicts for ``text``.

    Args:
        text: The text to classify.
        tokenizer: Callable that turns ``text`` into the keyword inputs of
            ``model``.
        model: Callable whose output exposes a ``logits`` tensor.
        topic_labels: Labels indexed by predicted class.

    Returns:
        The label at the predicted class index, or ``"Other"`` when that
        index falls outside ``topic_labels``.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**encoded).logits

    # .item() requires a single prediction, i.e. one text per call.
    label_index = logits.argmax(dim=1).item()
    if label_index >= len(topic_labels):
        return "Other"
    return topic_labels[label_index]
# Gradio user interface: one audio input, a button, and four text outputs.
with gr.Blocks() as demo:
    gr.Markdown("# 🎙️ Speech-to-Text, Translation & Topic Classification")

    audio_input = gr.Audio(type="filepath", label="Upload Audio or Record")
    submit_button = gr.Button("Process")

    transcription_output = gr.Textbox(label="Transcription (Darija)")
    translation_output = gr.Textbox(label="Translation (English)")
    darija_topic_output = gr.Textbox(label="Darija Topic Classification")
    english_topic_output = gr.Textbox(label="English Topic Classification")

    # Order matches the 4-tuple returned by transcribe_audio.
    result_boxes = [
        transcription_output,
        translation_output,
        darija_topic_output,
        english_topic_output,
    ]
    submit_button.click(transcribe_audio, inputs=[audio_input], outputs=result_boxes)

demo.launch()
|