# Manyue-DataScientist's picture
# Update app.py
# 7f8e922 verified
# raw
# history blame
# 4.69 kB
"""
Multi-Speaker Audio Analyzer
A Streamlit application that performs speaker diarization, transcription, and summarization on audio files.
Author: [Your Name]
Date: January 2025
"""
import streamlit as st
from pyannote.audio import Pipeline
import whisper
import tempfile
import os
import torch
from transformers import pipeline as tf_pipeline, BartTokenizer
from pydub import AudioSegment
import io
import pickle
class SpeakerDiarizer:
    """Thin wrapper around the pretrained pyannote speaker-diarization pipeline.

    Args:
        token: Hugging Face access token, forwarded to
            ``Pipeline.from_pretrained`` as ``use_auth_token``.

    Raises:
        RuntimeError: If ``Pipeline.from_pretrained`` yields no pipeline.
    """

    def __init__(self, token):
        pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization", use_auth_token=token
        )
        # pyannote signals an inaccessible (gated / unauthorised) model by
        # returning None rather than raising; fail here with a clear message
        # instead of a later "'NoneType' object is not callable".
        if pipeline is None:
            raise RuntimeError(
                "Could not load 'pyannote/speaker-diarization'. Check that the "
                "Hugging Face token is valid and that the model's user "
                "conditions have been accepted."
            )
        self.pipeline = pipeline

    def process(self, audio_file):
        """Run the pipeline on ``audio_file`` and return its result unchanged."""
        return self.pipeline(audio_file)
class Transcriber:
    """Speech-to-text wrapper around a Whisper model.

    Args:
        model_name: Name passed to ``whisper.load_model``. Defaults to
            ``"base"``, the checkpoint this app has always used.
    """

    def __init__(self, model_name="base"):
        self.model = whisper.load_model(model_name)

    def process(self, audio_file):
        """Transcribe ``audio_file`` and return the ``"text"`` field of the result."""
        result = self.model.transcribe(audio_file)
        return result["text"]
class Summarizer:
    """Abstractive summarizer: a pickled model plus the BART base tokenizer.

    Args:
        model_path: Path to the pickled model object. The loaded object must
            expose a ``generate`` method.
    """

    def __init__(self, model_path='bart_ami_finetuned.pkl'):
        self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
        # SECURITY: pickle.load executes arbitrary code embedded in the file.
        # Only load model files that ship with the app / come from a trusted
        # source; never point model_path at user-supplied data.
        with open(model_path, 'rb') as f:
            self.model = pickle.load(f)

    def process(self, text):
        """Return a summary of ``text``.

        Args:
            text: Transcript to summarise. Input is truncated to 1024 tokens.

        Returns:
            The decoded summary, or ``""`` when ``text`` is empty or only
            whitespace (otherwise ``min_length=40`` would force the model to
            produce a summary of nothing).
        """
        if not text or not text.strip():
            return ""
        inputs = self.tokenizer(
            text, return_tensors="pt", max_length=1024, truncation=True
        )
        summary_ids = self.model.generate(
            inputs["input_ids"],
            # Pass the mask explicitly so generate() does not have to guess
            # which positions are padding.
            attention_mask=inputs["attention_mask"],
            max_length=150,
            min_length=40,
        )
        return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
@st.cache_resource
def _load_models_cached():
    """Build the three model wrappers once per process.

    Exceptions are deliberately NOT caught here: Streamlit only caches a
    value that was returned, so letting errors propagate means a failed load
    is retried on the next call instead of being cached.
    """
    return (
        SpeakerDiarizer(st.secrets["hf_token"]),
        Transcriber(),
        Summarizer(),
    )


def load_models():
    """Return ``(diarizer, transcriber, summarizer)``, loading them on first use.

    Returns:
        The three model wrappers, or ``(None, None, None)`` if anything went
        wrong while loading; in that case the error is also shown with
        ``st.error``.
    """
    try:
        return _load_models_cached()
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        return None, None, None


# Keep `load_models.clear()` working for callers that used the cache API of
# the previously decorated function.
load_models.clear = _load_models_cached.clear
def _remove_quietly(path):
    """Delete ``path``; cleanup is best-effort, so OS errors are ignored."""
    try:
        os.unlink(path)
    except OSError:
        pass


def _write_normalised_wav(audio_file):
    """Decode the upload and write it as a 16 kHz, mono, 16-bit WAV temp file.

    Returns the path of the temp file; the caller is responsible for
    deleting it.
    """
    audio_bytes = io.BytesIO(audio_file.getvalue())
    if audio_file.name.lower().endswith('.mp3'):
        audio = AudioSegment.from_mp3(audio_bytes)
    else:
        audio = AudioSegment.from_wav(audio_bytes)
    audio = audio.set_frame_rate(16000).set_channels(1).set_sample_width(2)

    # Only reserve a unique path here and close our handle straight away:
    # writing to a path that is still held open fails on some platforms.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp_path = tmp.name
    try:
        # export() hands back an open file object; close it so the handle
        # is not leaked and the file can be removed afterwards.
        audio.export(
            tmp_path, format="wav", parameters=["-ac", "1", "-ar", "16000"]
        ).close()
    except Exception:
        _remove_quietly(tmp_path)
        raise
    return tmp_path


def process_audio(audio_file):
    """Run diarization, transcription and summarization on an uploaded file.

    Args:
        audio_file: Uploaded file object exposing ``getvalue()`` and ``name``.
            Names ending in ``.mp3`` are decoded as MP3, everything else as
            WAV.

    Returns:
        A dict with keys ``"diarization"``, ``"transcription"`` and
        ``"summary"``, or ``None`` on any failure (the failure is reported
        through ``st.error``).
    """
    tmp_path = None
    try:
        diarizer, transcriber, summarizer = load_models()
        if any(model is None for model in (diarizer, transcriber, summarizer)):
            # Report and return None: a (truthy) message string would be
            # mistaken for a result dict by the caller.
            st.error("Model loading failed")
            return None

        tmp_path = _write_normalised_wav(audio_file)
        with st.spinner("Processing..."):
            diarization = diarizer.process(tmp_path)
            transcription = transcriber.process(tmp_path)
            summary = summarizer.process(transcription)
        return {
            "diarization": diarization,
            "transcription": transcription,
            "summary": summary
        }
    except Exception as e:
        st.error(f"Error: {str(e)}")
        return None
    finally:
        # Runs on success, early return and error alike, so the temp file
        # never outlives the call.
        if tmp_path is not None:
            _remove_quietly(tmp_path)
def format_timestamp(seconds):
    """Format a time offset in seconds as ``MM:SS`` (``H:MM:SS`` from one hour).

    Fractions of a second are dropped and negative values are clamped to zero.
    """
    total = max(0, int(seconds))
    minutes, secs = divmod(total, 60)
    hours, minutes = divmod(minutes, 60)
    if hours:
        return f"{hours}:{minutes:02d}:{secs:02d}"
    return f"{minutes:02d}:{secs:02d}"


def _speaker_index(speaker):
    """Return the numeric suffix of a label such as ``SPEAKER_01``, else 0."""
    suffix = str(speaker).rsplit('_', 1)[-1]
    try:
        return int(suffix)
    except ValueError:
        return 0


def main():
    """Render the Streamlit page: upload, analyse, and show results in tabs."""
    st.title("Multi-Speaker Audio Analyzer")
    st.write("Upload an audio file (MP3/WAV) up to 5 minutes long")
    uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"])
    if not uploaded_file:
        return

    file_size = len(uploaded_file.getvalue()) / (1024 * 1024)
    st.write(f"File size: {file_size:.2f} MB")
    # Give the player the MIME type that matches the upload.
    is_mp3 = uploaded_file.name.lower().endswith('.mp3')
    st.audio(uploaded_file, format='audio/mpeg' if is_mp3 else 'audio/wav')

    if not st.button("Analyze Audio"):
        return
    if file_size > 200:
        st.error("File size exceeds 200MB limit")
        return

    results = process_audio(uploaded_file)
    if isinstance(results, str):
        # A bare message instead of a result dict means processing failed.
        st.error(results)
        return
    if not results:
        return

    colors = ['🔵', '🔴']
    tab1, tab2, tab3 = st.tabs(["Speakers", "Transcription", "Summary"])
    with tab1:
        st.write("Speaker Timeline:")
        tracks = results["diarization"].itertracks(yield_label=True)
        for turn, _track, speaker in tracks:
            # Third column is intentionally left empty; it keeps the
            # original 2/3/5 layout.
            col1, col2, _col3 = st.columns([2, 3, 5])
            with col1:
                marker = colors[_speaker_index(speaker) % len(colors)]
                st.write(f"{marker} {speaker}")
            with col2:
                st.write(
                    f"{format_timestamp(turn.start)} → {format_timestamp(turn.end)}"
                )
    with tab2:
        st.write("Transcription:")
        st.write(results["transcription"])
    with tab3:
        st.write("Summary:")
        st.write(results["summary"])
# Script entry point: start the app only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()