# MyIVR / app.py: Darija transcription, summarization, and topic extraction app
import gradio as gr
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, pipeline
from transformers import BertTokenizer, BertForSequenceClassification
import librosa

# Load models

# Transcription model for Moroccan Darija
processor = Wav2Vec2Processor.from_pretrained("boumehdi/wav2vec2-large-xlsr-moroccan-darija")
transcription_model = Wav2Vec2ForCTC.from_pretrained("boumehdi/wav2vec2-large-xlsr-moroccan-darija")

# Summarization model
# NOTE: facebook/bart-large-cnn is trained on English news text; summary quality
# on Darija transcriptions will be limited unless the text is translated first.
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")

# Topic classification model
# NOTE: bert-base-uncased ships without a trained classification head, so the
# head loaded here is randomly initialized and its predictions are placeholders
# until the model is fine-tuned on labeled call transcripts. num_labels=3
# matches the three topics returned by classify_topic() below.
topic_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=3)
topic_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
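
# In a real deployment you would swap in a checkpoint fine-tuned on labeled IVR
# transcripts; the path below is a hypothetical example, not a published model:
# topic_model = BertForSequenceClassification.from_pretrained("your-org/bert-darija-ivr-topics", num_labels=3)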

# Function to resample audio to 16 kHz if necessary (Wav2Vec2 models expect 16 kHz input)
def resample_audio(audio_path, target_sr=16000):
    audio_input, original_sr = librosa.load(audio_path, sr=None)  # Load audio at its original sampling rate
    if original_sr != target_sr:
        audio_input = librosa.resample(audio_input, orig_sr=original_sr, target_sr=target_sr)  # Resample to 16 kHz
    return audio_input, target_sr

# Function to transcribe audio using Wav2Vec2
def transcribe_audio(audio_path):
    # Load and preprocess audio
    audio_input, sample_rate = resample_audio(audio_path)
    inputs = processor(audio_input, sampling_rate=sample_rate, return_tensors="pt", padding=True)
    # Get predictions
    with torch.no_grad():
        logits = transcription_model(**inputs).logits
    # Decode predictions
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)[0]
    return transcription
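
# A minimal sketch (not part of the original app) for long recordings: Wav2Vec2
# memory use grows with input length, so one option is to transcribe the call in
# ~30-second chunks and join the pieces. chunk_seconds is an assumed default.
def transcribe_long_audio(audio_path, chunk_seconds=30):
    audio_input, sample_rate = resample_audio(audio_path)
    chunk_size = chunk_seconds * sample_rate
    pieces = []
    for start in range(0, len(audio_input), chunk_size):
        chunk = audio_input[start:start + chunk_size]
        inputs = processor(chunk, sampling_rate=sample_rate, return_tensors="pt", padding=True)
        with torch.no_grad():
            logits = transcription_model(**inputs).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        pieces.append(processor.batch_decode(predicted_ids)[0])
    return " ".join(pieces)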

# Function to classify the transcription into topics
def classify_topic(transcription):
    # Tokenize the transcription and run it through the BERT classifier
    inputs = topic_tokenizer(transcription, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        outputs = topic_model(**inputs)
    # Get the predicted class index
    predicted_class = torch.argmax(outputs.logits, dim=1).item()
    # Map the class index to a topic label
    if predicted_class == 0:
        return "Customer Service"
    elif predicted_class == 1:
        return "Retention Service"
    else:
        return "Other"

# Function to transcribe, summarize, and classify topic
def transcribe_and_summarize(audio_file):
    # Transcription
    transcription = transcribe_audio(audio_file)
    # Summarization
    summary = summarizer(transcription, max_length=50, min_length=10, do_sample=False)[0]["summary_text"]
    # Topic classification
    topic = classify_topic(transcription)
    return transcription, summary, topic
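
# A minimal sketch (not part of the original app): facebook/bart-large-cnn accepts
# roughly 1024 input tokens, so very long transcriptions should be truncated (or
# summarized chunk by chunk) before being passed to the pipeline.
def summarize_safely(text, max_input_tokens=1024):
    tokenizer = summarizer.tokenizer
    token_ids = tokenizer.encode(text, truncation=True, max_length=max_input_tokens)
    truncated = tokenizer.decode(token_ids, skip_special_tokens=True)
    return summarizer(truncated, max_length=50, min_length=10, do_sample=False)[0]["summary_text"]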

# Gradio Interface
inputs = gr.Audio(type="filepath", label="Upload your audio file")
outputs = [
    gr.Textbox(label="Transcription"),
    gr.Textbox(label="Summary"),
    gr.Textbox(label="Topic"),
]
app = gr.Interface(
    fn=transcribe_and_summarize,
    inputs=inputs,
    outputs=outputs,
    title="Moroccan Darija Audio Transcription, Summarization, and Topic Classification",
    description="Upload an audio file in Moroccan Darija to get its transcription, a summarized version of the content, and the detected topic.",
)

# Launch the app
if __name__ == "__main__":
    app.launch()
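
# Quick sanity check without launching the UI (the file path is a hypothetical example):
# print(transcribe_and_summarize("sample_call.wav"))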