MyIVR / app.py
JabriA's picture
Add Darija transcription and topic extraction app8
98899e8
raw
history blame
3.71 kB
import gradio as gr
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, pipeline
from transformers import BertTokenizer, BertForSequenceClassification
import librosa
import os
# Route outbound HTTP(S) traffic through the internal corporate proxy so
# that model downloads succeed from inside the test network.
_PROXY_URL = "http://meditelproxy.meditel.int:80"
os.environ["HTTP_PROXY"] = _PROXY_URL
os.environ["HTTPS_PROXY"] = _PROXY_URL
# ---- Model setup ----
# Wav2Vec2 checkpoint fine-tuned for Moroccan Darija speech recognition;
# the same id provides both the feature-extractor/tokenizer and the model.
_DARIJA_ASR_ID = "boumehdi/wav2vec2-large-xlsr-moroccan-darija"
processor = Wav2Vec2Processor.from_pretrained(_DARIJA_ASR_ID)
transcription_model = Wav2Vec2ForCTC.from_pretrained(_DARIJA_ASR_ID)

# mBART many-to-many checkpoint driven through the summarization pipeline
# (the target language is requested per call via tgt_lang).
summarizer = pipeline("summarization", model="facebook/mbart-large-50-many-to-many-mmt")

# NOTE(review): bert-base-uncased ships with a randomly initialised
# classification head — topic predictions are placeholders until a
# fine-tuned checkpoint replaces this example model.
topic_model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
topic_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# Function to resample audio to 16kHz if necessary
def resample_audio(audio_path, target_sr=16000):
    """Load *audio_path* and return ``(samples, target_sr)``.

    The file is decoded at its native sampling rate; when that rate already
    matches ``target_sr`` the samples are returned as-is, otherwise they are
    resampled so downstream models always see 16 kHz audio by default.
    """
    samples, native_sr = librosa.load(audio_path, sr=None)
    if native_sr == target_sr:
        return samples, target_sr
    resampled = librosa.resample(samples, orig_sr=native_sr, target_sr=target_sr)
    return resampled, target_sr
# Function to transcribe audio using Wav2Vec2
def transcribe_audio(audio_path):
    """Transcribe a Darija audio file with the Wav2Vec2 CTC model.

    Resamples the input to 16 kHz, runs the acoustic model without gradient
    tracking, and greedily decodes the per-frame argmax token ids.
    """
    samples, sr = resample_audio(audio_path)
    features = processor(samples, sampling_rate=sr, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = transcription_model(**features).logits
    # Greedy CTC decoding: pick the most likely token per frame, then let
    # batch_decode collapse repeats/blanks into text.
    best_ids = logits.argmax(dim=-1)
    return processor.batch_decode(best_ids)[0]
# Function to classify the transcription into topics
def classify_topic(transcription):
    """Map *transcription* to a coarse topic label via the BERT classifier.

    Returns "Customer Service" for class 0, "Retention Service" for class 1,
    and "Other" for anything else.
    """
    encoded = topic_tokenizer(
        transcription,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    with torch.no_grad():
        logits = topic_model(**encoded).logits
    label_id = torch.argmax(logits, dim=1).item()
    # Known class ids; any unexpected id falls through to "Other".
    labels = {0: "Customer Service", 1: "Retention Service"}
    return labels.get(label_id, "Other")
# Function to transcribe, summarize in French, and classify topic
def transcribe_and_summarize(audio_file):
    """Return ``(transcription, french_summary, topic)`` for an audio file.

    Chains the three pipeline stages: Wav2Vec2 transcription, mBART
    summarization targeting French (fr_XX), and BERT topic classification.
    """
    transcription = transcribe_audio(audio_file)
    summary_kwargs = {
        "max_length": 50,
        "min_length": 10,
        "do_sample": False,
        "tgt_lang": "fr_XX",  # request French output from the mBART pipeline
    }
    summary = summarizer(transcription, **summary_kwargs)[0]["summary_text"]
    topic = classify_topic(transcription)
    return transcription, summary, topic
# ---- Gradio UI ----
# One audio upload in; three text boxes out (transcription, summary, topic).
inputs = gr.Audio(type="filepath", label="Upload your audio file")
outputs = [
    gr.Textbox(label=box_label)
    for box_label in ("Transcription", "Résumé (en Français)", "Topic")
]
app = gr.Interface(
    fn=transcribe_and_summarize,
    inputs=inputs,
    outputs=outputs,
    title="Moroccan Darija Audio Transcription, Résumé, and Topic Classification",
    description="Upload an audio file in Moroccan Darija to get its transcription, a summarized version in French, and the detected topic.",
)
# Start the Gradio server only when executed as a script (importing this
# module still loads the models above but does not launch the UI).
if __name__ == "__main__":
app.launch()