import gradio as gr
import librosa
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, MarianMTModel, MarianTokenizer, BertForSequenceClassification, AutoModelForSequenceClassification, AutoTokenizer
# Load the Darija transcription model
model = Wav2Vec2ForCTC.from_pretrained("boumehdi/wav2vec2-large-xlsr-moroccan-darija")
processor = Wav2Vec2Processor.from_pretrained("boumehdi/wav2vec2-large-xlsr-moroccan-darija")
# Load the Arabic -> English translation model
translation_model_name = "Helsinki-NLP/opus-mt-ar-en"
translation_model = MarianMTModel.from_pretrained(translation_model_name)
translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)
# Load AraBERT with a sequence-classification head for Darija topic classification
arabert_model_name = "aubmindlab/bert-base-arabert"
arabert_tokenizer = AutoTokenizer.from_pretrained(arabert_model_name)
arabert_model = AutoModelForSequenceClassification.from_pretrained(arabert_model_name, num_labels=3)  # Adjust labels as needed
# Load BERT for English topic classification
bert_model_name = "bert-base-uncased"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=3) # Adjust labels as needed
darija_topic_labels = ["Customer Service", "Retention Service", "Billing Issue"] # Adjust for Darija topics
english_topic_labels = ["Support Request", "Subscription Issue", "Payment Dispute"] # Adjust for English topics
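# Note: both classification heads start from generic pretrained weights; fine-tuning on labeled
# examples of the topics above is required for the predictions to be meaningful.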
def transcribe_audio(audio):
    """Convert audio to text, translate it, and classify topics in both Darija and English"""
    try:
        # Load and preprocess audio
        audio_array, sr = librosa.load(audio, sr=16000)
        # Ensure correct sampling rate
        input_values = processor(audio_array, sampling_rate=16000, return_tensors="pt", padding=True).input_values
        # Move to GPU if available
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        input_values = input_values.to(device)
        # Get predictions from the Wav2Vec2 model
        with torch.no_grad():
            logits = model(input_values).logits
            tokens = torch.argmax(logits, dim=-1)
        # Decode transcription (Darija)
        transcription = processor.decode(tokens[0])
        # Translate to English
        translation = translate_text(transcription)
        # Classify topics for Darija and English
        darija_topic = classify_topic(transcription, arabert_tokenizer, arabert_model, darija_topic_labels)
        english_topic = classify_topic(translation, bert_tokenizer, bert_model, english_topic_labels)
        return transcription, translation, darija_topic, english_topic
    except Exception as e:
        print(f"Error in transcription: {e}")
        return "Error processing audio", "", "", ""
def translate_text(text):
    """Translate text from Arabic to English"""
    inputs = translation_tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    translated_tokens = translation_model.generate(**inputs)
    translated_text = translation_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
    return translated_text
def classify_topic(text, tokenizer, model, topic_labels):
    """Classify topic using BERT-based models"""
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_class = torch.argmax(outputs.logits, dim=1).item()
    return topic_labels[predicted_class] if predicted_class < len(topic_labels) else "Other"
# Gradio user interface
with gr.Blocks() as demo:
    gr.Markdown("# 🎙️ Speech-to-Text, Translation & Topic Classification")
    audio_input = gr.Audio(type="filepath", label="Upload Audio or Record")
    submit_button = gr.Button("Process")
    transcription_output = gr.Textbox(label="Transcription (Darija)")
    translation_output = gr.Textbox(label="Translation (English)")
    darija_topic_output = gr.Textbox(label="Darija Topic Classification")
    english_topic_output = gr.Textbox(label="English Topic Classification")
    submit_button.click(transcribe_audio,
                        inputs=[audio_input],
                        outputs=[transcription_output, translation_output, darija_topic_output, english_topic_output])
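# Launch the Gradio app; on Hugging Face Spaces, launch() serves the interface automatically.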
demo.launch()