# STTDARIJAAPI / app.py
# NOTE: the following header was scrape residue from the Hugging Face file
# page (breadcrumb/author/commit metadata), not part of the program:
#   Mohssinibra — f117ae1 verified — raw — history blame — 4.37 kB
import gradio as gr
import librosa
import torch
from transformers import (
    AutoModel,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BertForSequenceClassification,
    MarianMTModel,
    MarianTokenizer,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
)
# --- Darija speech-to-text: Wav2Vec2 fine-tuned on Moroccan Darija ---
model = Wav2Vec2ForCTC.from_pretrained("boumehdi/wav2vec2-large-xlsr-moroccan-darija")
processor = Wav2Vec2Processor.from_pretrained("boumehdi/wav2vec2-large-xlsr-moroccan-darija")

# --- Arabic -> English translation (MarianMT) ---
translation_model_name = "Helsinki-NLP/opus-mt-ar-en"
translation_model = MarianMTModel.from_pretrained(translation_model_name)
translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)

# --- AraBERT for Darija topic classification ---
# FIX: this previously used AutoModel, which returns a bare encoder whose
# output has no `.logits`; classify_topic() reads `outputs.logits`, so the
# Darija classification crashed with AttributeError on every call. Load a
# sequence-classification head instead (randomly initialised, like the
# English BERT below — fine-tune before trusting the predicted labels).
arabert_model_name = "aubmindlab/bert-base-arabert"
arabert_tokenizer = AutoTokenizer.from_pretrained(arabert_model_name)
arabert_model = AutoModelForSequenceClassification.from_pretrained(
    arabert_model_name, num_labels=3
)

# --- English BERT for topic classification ---
bert_model_name = "bert-base-uncased"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=3)  # Adjust labels as needed

# Label sets mapped onto the classifiers' argmax indices.
darija_topic_labels = ["Customer Service", "Retention Service", "Billing Issue"]  # Adjust for Darija topics
english_topic_labels = ["Support Request", "Subscription Issue", "Payment Dispute"]  # Adjust for English topics
# (redundant duplicate `import torch` removed — torch is imported at the top of the file)
def transcribe_audio(audio):
    """Transcribe a Darija audio file, translate it, and classify topics.

    Args:
        audio: Filesystem path to the recording (from the Gradio Audio
            component with type="filepath").

    Returns:
        Tuple of four strings: (darija_transcription, english_translation,
        darija_topic, english_topic). On any failure, returns a placeholder
        error message plus three empty strings so all four Gradio outputs
        are always populated.
    """
    try:
        # Resample to 16 kHz, the rate the Wav2Vec2 checkpoint expects.
        # (second return value — the sample rate — is unused, hence `_`)
        audio_array, _ = librosa.load(audio, sr=16000)
        input_values = processor(
            audio_array, sampling_rate=16000, return_tensors="pt", padding=True
        ).input_values

        # Run inference on GPU when available.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        input_values = input_values.to(device)

        # Greedy CTC decoding: most likely token id per frame.
        with torch.no_grad():
            logits = model(input_values).logits
        tokens = torch.argmax(logits, dim=-1)  # `dim=` is the torch-native kwarg (was numpy-style `axis=`)
        transcription = processor.decode(tokens[0])

        # Translate the Darija transcription into English.
        translation = translate_text(transcription)

        # Classify the topic of each text with its language-matched model.
        darija_topic = classify_topic(transcription, arabert_tokenizer, arabert_model, darija_topic_labels)
        english_topic = classify_topic(translation, bert_tokenizer, bert_model, english_topic_labels)

        return transcription, translation, darija_topic, english_topic
    except Exception as e:
        # NOTE(review): broad catch is deliberate — this is the Gradio
        # callback boundary and must always yield four output values.
        print(f"Error in transcription: {e}")
        return "Error processing audio", "", "", ""
def translate_text(text):
    """Translate Arabic/Darija text into English via the MarianMT model.

    Uses the module-level `translation_tokenizer` / `translation_model`
    (Helsinki-NLP/opus-mt-ar-en) and returns the decoded English string.
    """
    batch = translation_tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    output_ids = translation_model.generate(**batch)
    return translation_tokenizer.decode(output_ids[0], skip_special_tokens=True)
def classify_topic(text, tokenizer, model, topic_labels):
    """Assign a topic label to *text* with a sequence-classification model.

    The text is tokenized (padded, truncated to 512 tokens), scored under
    `torch.no_grad()`, and the arg-max class index is mapped onto
    *topic_labels*; an index beyond the label list falls back to "Other".
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        prediction = model(**encoded)
    best = torch.argmax(prediction.logits, dim=1).item()
    if best < len(topic_labels):
        return topic_labels[best]
    return "Other"
# --- Gradio user interface ---
with gr.Blocks() as demo:
    gr.Markdown("# 🎙️ Speech-to-Text, Translation & Topic Classification")

    # Input side: recorded/uploaded audio plus a trigger button.
    audio_in = gr.Audio(type="filepath", label="Upload Audio or Record")
    process_btn = gr.Button("Process")

    # Output side: one textbox per value returned by transcribe_audio.
    out_transcription = gr.Textbox(label="Transcription (Darija)")
    out_translation = gr.Textbox(label="Translation (English)")
    out_darija_topic = gr.Textbox(label="Darija Topic Classification")
    out_english_topic = gr.Textbox(label="English Topic Classification")

    process_btn.click(
        transcribe_audio,
        inputs=[audio_in],
        outputs=[out_transcription, out_translation, out_darija_topic, out_english_topic],
    )

demo.launch()