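# NOTE: the three lines below were copied from the Hugging Face Spaces page
# header (the Space's status banner); they are not part of the program.
# Spaces:
# Sleeping
# Sleeping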
import logging

import gradio as gr
import librosa
import torch
from transformers import (
    AutoModel,
    AutoTokenizer,
    BertForSequenceClassification,
    MarianMTModel,
    MarianTokenizer,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
)
# Topic label sets. The classification heads below are sized from these lists,
# so editing a list keeps its head's number of outputs in sync with the labels.
darija_topic_labels = ["Customer Service", "Retention Service", "Billing Issue"]  # Adjust for Darija topics
english_topic_labels = ["Support Request", "Subscription Issue", "Payment Dispute"]  # Adjust for English topics

# Speech-recognition model for Moroccan Darija (Wav2Vec2 with a CTC head).
asr_model_name = "boumehdi/wav2vec2-large-xlsr-moroccan-darija"
model = Wav2Vec2ForCTC.from_pretrained(asr_model_name)
processor = Wav2Vec2Processor.from_pretrained(asr_model_name)

# Arabic -> English translation model.
translation_model_name = "Helsinki-NLP/opus-mt-ar-en"
translation_model = MarianMTModel.from_pretrained(translation_model_name)
translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)

# NOTE(review): both topic classifiers below start from base (pre-training)
# checkpoints, so their classification heads are freshly initialised and the
# predicted topics are effectively arbitrary until the heads are fine-tuned
# on labelled data for these label sets.

# AraBERT for Darija topic classification. A sequence-classification head is
# required: AutoModel would load the bare encoder, whose output has no
# `.logits`, and classify_topic() would fail on every request.
arabert_model_name = "aubmindlab/bert-base-arabert"
arabert_tokenizer = AutoTokenizer.from_pretrained(arabert_model_name)
arabert_model = BertForSequenceClassification.from_pretrained(
    arabert_model_name, num_labels=len(darija_topic_labels)
)

# BERT for English topic classification.
bert_model_name = "bert-base-uncased"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = BertForSequenceClassification.from_pretrained(
    bert_model_name, num_labels=len(english_topic_labels)
)

import torch  # NOTE(review): redundant; torch is already imported at the top of the file.
def _transcribe_darija(audio):
    """Transcribe one audio file to Darija text with the Wav2Vec2 CTC model.

    Args:
        audio: Path to an audio file readable by ``librosa.load``.

    Returns:
        The decoded transcription string.
    """
    # Resample to 16 kHz, the rate passed to the processor below.
    audio_array, _ = librosa.load(audio, sr=16000)
    input_values = processor(
        audio_array, sampling_rate=16000, return_tensors="pt", padding=True
    ).input_values

    # Run the ASR model on the GPU when one is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    input_values = input_values.to(device)

    with torch.no_grad():
        logits = model(input_values).logits
    # Greedy decoding: take the most likely token id at every frame.
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.decode(predicted_ids[0])


def transcribe_audio(audio):
    """Transcribe Darija audio, translate it to English and classify both texts.

    Args:
        audio: Path to the audio file (the value of a Gradio
            ``Audio(type="filepath")`` component), or ``None``/empty when the
            user submitted without audio.

    Returns:
        A 4-tuple ``(transcription, translation, darija_topic, english_topic)``.
        When no audio is given the first element is ``"No audio provided"``;
        on any processing error it is ``"Error processing audio"``. In both
        cases the other three elements are ``""``.
    """
    # Report missing input directly instead of letting librosa raise and
    # logging a spurious error traceback.
    if not audio:
        return "No audio provided", "", "", ""
    try:
        transcription = _transcribe_darija(audio)
        translation = translate_text(transcription)
        darija_topic = classify_topic(
            transcription, arabert_tokenizer, arabert_model, darija_topic_labels
        )
        english_topic = classify_topic(
            translation, bert_tokenizer, bert_model, english_topic_labels
        )
    except Exception:
        # Top-level UI handler: keep the app responsive, but record the full
        # traceback (print() only showed the message).
        logging.getLogger(__name__).exception("Error in transcription")
        return "Error processing audio", "", "", ""
    return transcription, translation, darija_topic, english_topic
def translate_text(text):
    """Translate Arabic (Darija) text into English with the MarianMT model.

    Args:
        text: Source-language text to translate.

    Returns:
        The decoded English text of the first generated sequence, with
        special tokens removed.
    """
    encoded = translation_tokenizer(
        text, return_tensors="pt", padding=True, truncation=True
    )
    generated = translation_model.generate(**encoded)
    return translation_tokenizer.decode(generated[0], skip_special_tokens=True)
def classify_topic(text, tokenizer, model, topic_labels):
    """Predict a topic label for ``text`` with a sequence-classification model.

    Args:
        text: Text to classify.
        tokenizer: Tokenizer paired with ``model``; called with PyTorch
            tensor output, truncation/padding and ``max_length=512``.
        model: Model whose output exposes a ``.logits`` tensor.
        topic_labels: Label names, indexed by predicted class id.

    Returns:
        ``topic_labels[i]`` for the highest-scoring class ``i``, or ``"Other"``
        when ``i`` has no entry in ``topic_labels``.
    """
    encoded = tokenizer(
        text, return_tensors="pt", truncation=True, padding=True, max_length=512
    )
    with torch.no_grad():
        scores = model(**encoded).logits
    best = scores.argmax(dim=1).item()
    if best >= len(topic_labels):
        return "Other"
    return topic_labels[best]
# Gradio user interface.
with gr.Blocks() as demo:
    gr.Markdown("# 🎙️ Speech-to-Text, Translation & Topic Classification")
    audio_input = gr.Audio(type="filepath", label="Upload Audio or Record")
    submit_button = gr.Button("Process")
    transcription_output = gr.Textbox(label="Transcription (Darija)")
    translation_output = gr.Textbox(label="Translation (English)")
    darija_topic_output = gr.Textbox(label="Darija Topic Classification")
    english_topic_output = gr.Textbox(label="English Topic Classification")
    # transcribe_audio returns a 4-tuple, mapped positionally onto these outputs.
    submit_button.click(
        transcribe_audio,
        inputs=[audio_input],
        outputs=[
            transcription_output,
            translation_output,
            darija_topic_output,
            english_topic_output,
        ],
    )

# Start the server only when run as a script, so importing this module does
# not launch the app as a side effect.
if __name__ == "__main__":
    demo.launch()