Spaces:
Running
Running
File size: 4,134 Bytes
e7d0ead 6cc8631 2785b50 85e680f 2785b50 85e680f 2785b50 f2ecb6e 2785b50 f2ecb6e 2785b50 f2ecb6e 2785b50 f2ecb6e 2785b50 f117ae1 980dcf2 2785b50 f117ae1 2785b50 f117ae1 2785b50 f117ae1 2785b50 f117ae1 9078685 2785b50 f117ae1 9078685 f117ae1 9078685 f117ae1 2785b50 d04bf8d 2785b50 980dcf2 f2ecb6e 2785b50 f2ecb6e 2785b50 1a38424 f2ecb6e 9078685 f2ecb6e 9078685 f2ecb6e c7f40c9 f2ecb6e d27a60f 980dcf2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import gradio as gr
import librosa
import torch
from transformers import (
Wav2Vec2ForCTC, Wav2Vec2Processor,
MarianMTModel, MarianTokenizer,
BertForSequenceClassification, AutoModel, AutoTokenizer
)
# Detect device: use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Topic labels. Defined before the classifiers so that each classification
# head can be sized from the label list it is paired with.
darija_topic_labels = ["Customer Service", "Retention Service", "Billing Issue"]
english_topic_labels = ["Support Request", "Subscription Issue", "Payment Dispute"]

### 🔹 Load Models & Tokenizers Once ###
# Everything below runs at import time, so the models are loaded a single
# time and then shared by every request.

# Wav2Vec2 for Darija transcription
wav2vec_model_name = "boumehdi/wav2vec2-large-xlsr-moroccan-darija"
processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_name)
wav2vec_model = Wav2Vec2ForCTC.from_pretrained(wav2vec_model_name).to(device)

# MarianMT for translation (Arabic → English)
translation_model_name = "Helsinki-NLP/opus-mt-ar-en"
translation_tokenizer = MarianTokenizer.from_pretrained(translation_model_name)
translation_model = MarianMTModel.from_pretrained(translation_model_name).to(device)

# AraBERT for Darija topic classification.
# Loaded with a sequence-classification head because classify_topic() reads
# `outputs.logits`; the bare encoder returned by AutoModel has no such field,
# which made every classification call fail.
# NOTE(review): the checkpoint name suggests a pretrained encoder only, so the
# classification head is presumably freshly initialized and needs fine-tuning
# before its predictions are meaningful - confirm.
arabert_model_name = "aubmindlab/bert-base-arabert"
arabert_tokenizer = AutoTokenizer.from_pretrained(arabert_model_name)
arabert_model = BertForSequenceClassification.from_pretrained(
    arabert_model_name, num_labels=len(darija_topic_labels)
).to(device)

# BERT for English topic classification.
# NOTE(review): same caveat as above - "bert-base-uncased" presumably ships
# without a trained classification head; confirm.
bert_model_name = "bert-base-uncased"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = BertForSequenceClassification.from_pretrained(
    bert_model_name, num_labels=len(english_topic_labels)
).to(device)
def transcribe_audio(audio):
    """Transcribe an audio file, translate the transcript, and classify its topic.

    Args:
        audio: Path to the audio file to process (the Gradio audio input is
            configured with ``type="filepath"``). ``None`` or an empty path
            means nothing was uploaded or recorded.

    Returns:
        A 4-tuple of strings ``(transcription, translation, darija_topic,
        english_topic)``. On failure the first element is an
        ``"Error processing audio: ..."`` message and the other three are
        empty strings, so the tuple always matches the four output boxes.
    """
    # Nothing uploaded or recorded: say so directly instead of letting the
    # audio loader fail with an opaque message.
    if not audio:
        return "Error processing audio: no audio provided", "", "", ""

    target_rate = 16000  # one rate for both loading and feature extraction

    try:
        # Load the audio resampled to the target rate and turn it into model inputs.
        audio_array, _ = librosa.load(audio, sr=target_rate)
        input_values = processor(
            audio_array, sampling_rate=target_rate, return_tensors="pt", padding=True
        ).input_values.to(device)

        # Transcription (Darija): pick the highest-scoring token per frame and
        # let the processor decode the resulting id sequence.
        with torch.no_grad():
            logits = wav2vec_model(input_values).logits
        tokens = torch.argmax(logits, dim=-1)
        transcription = processor.decode(tokens[0])

        # Translate to English
        translation = translate_text(transcription)

        # Classify topics: the Darija transcript and its English translation
        # each go through their own model and label set.
        darija_topic = classify_topic(
            transcription, arabert_tokenizer, arabert_model, darija_topic_labels
        )
        english_topic = classify_topic(
            translation, bert_tokenizer, bert_model, english_topic_labels
        )

        return transcription, translation, darija_topic, english_topic
    except Exception as e:
        # UI boundary: show the failure in the first output box rather than
        # raising out of the click handler.
        return f"Error processing audio: {e}", "", "", ""
def translate_text(text):
    """Run *text* through the module-level MarianMT model and return the decoded output.

    The text is tokenized (with padding and truncation enabled), moved to the
    active device, passed to ``generate`` without gradient tracking, and the
    first generated sequence is decoded with special tokens removed.
    """
    encoded = translation_tokenizer(
        text, return_tensors="pt", padding=True, truncation=True
    )
    encoded = encoded.to(device)

    with torch.no_grad():
        generated = translation_model.generate(**encoded)

    first_sequence = generated[0]
    return translation_tokenizer.decode(first_sequence, skip_special_tokens=True)
def classify_topic(text, tokenizer, model, topic_labels):
    """Return the label from *topic_labels* that *model* scores highest for *text*.

    Args:
        text: Text to classify.
        tokenizer: Tokenizer used to encode *text* (truncated to 512 tokens).
        model: Model whose output exposes a ``logits`` attribute.
        topic_labels: Labels indexed by predicted class id.

    Returns:
        The label at the predicted class index, or ``"Other"`` when that index
        is not below ``len(topic_labels)``.
    """
    encoded = tokenizer(
        text, return_tensors="pt", truncation=True, padding=True, max_length=512
    )
    encoded = encoded.to(device)

    with torch.no_grad():
        logits = model(**encoded).logits

    label_index = torch.argmax(logits, dim=1).item()
    if label_index >= len(topic_labels):
        return "Other"
    return topic_labels[label_index]
# 🔹 Gradio Interface
# One audio input, a button, and four text boxes; clicking the button runs
# transcribe_audio() and fills the boxes with its four return values.
with gr.Blocks() as demo:
    gr.Markdown("# 🎙️ Speech-to-Text, Translation & Topic Classification")

    # type="filepath" hands the handler a path on disk rather than raw samples.
    audio_input = gr.Audio(type="filepath", label="Upload Audio or Record")
    submit_button = gr.Button("Process")

    transcription_output = gr.Textbox(label="Transcription (Darija)")
    translation_output = gr.Textbox(label="Translation (English)")
    darija_topic_output = gr.Textbox(label="Darija Topic Classification")
    english_topic_output = gr.Textbox(label="English Topic Classification")

    # The order of `outputs` must match the order of the tuple returned by
    # transcribe_audio().
    submit_button.click(
        transcribe_audio,
        inputs=[audio_input],
        outputs=[
            transcription_output,
            translation_output,
            darija_topic_output,
            english_topic_output,
        ],
    )

demo.launch()
|