""" Multi-Speaker Audio Analyzer A Streamlit application that performs speaker diarization, transcription, and summarization on audio files. Author: [Your Name] Date: January 2025 """ import streamlit as st from pyannote.audio import Pipeline import whisper import tempfile import os import torch from transformers import pipeline as tf_pipeline, BartTokenizer from pydub import AudioSegment import io import pickle class SpeakerDiarizer: def __init__(self, token): self.pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=token) def process(self, audio_file): return self.pipeline(audio_file) class Transcriber: def __init__(self): self.model = whisper.load_model("base") def process(self, audio_file): return self.model.transcribe(audio_file)["text"] class Summarizer: def __init__(self, model_path='bart_ami_finetuned.pkl'): self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-base') with open(model_path, 'rb') as f: self.model = pickle.load(f) def process(self, text): inputs = self.tokenizer(text, return_tensors="pt", max_length=1024, truncation=True) summary_ids = self.model.generate(inputs["input_ids"], max_length=150, min_length=40) return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True) @st.cache_resource def load_models(): try: diarizer = SpeakerDiarizer(st.secrets["hf_token"]) transcriber = Transcriber() summarizer = Summarizer() return diarizer, transcriber, summarizer except Exception as e: st.error(f"Error loading models: {str(e)}") return None, None, None def process_audio(audio_file): try: audio_bytes = io.BytesIO(audio_file.getvalue()) with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp: if audio_file.name.lower().endswith('.mp3'): audio = AudioSegment.from_mp3(audio_bytes) else: audio = AudioSegment.from_wav(audio_bytes) audio = audio.set_frame_rate(16000).set_channels(1).set_sample_width(2) audio.export(tmp.name, format="wav", parameters=["-ac", "1", "-ar", "16000"]) tmp_path = tmp.name diarizer, transcriber, summarizer = load_models() if not all([diarizer, transcriber, summarizer]): return "Model loading failed" with st.spinner("Processing..."): diarization = diarizer.process(tmp_path) transcription = transcriber.process(tmp_path) summary = summarizer.process(transcription) os.unlink(tmp_path) return { "diarization": diarization, "transcription": transcription, "summary": summary } except Exception as e: st.error(f"Error: {str(e)}") return None def main(): st.title("Multi-Speaker Audio Analyzer") st.write("Upload an audio file (MP3/WAV) up to 5 minutes long") uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"]) if uploaded_file: file_size = len(uploaded_file.getvalue()) / (1024 * 1024) st.write(f"File size: {file_size:.2f} MB") st.audio(uploaded_file, format='audio/wav') if st.button("Analyze Audio"): if file_size > 200: st.error("File size exceeds 200MB limit") else: results = process_audio(uploaded_file) if results: tab1, tab2, tab3 = st.tabs(["Speakers", "Transcription", "Summary"]) with tab1: st.write("Speaker Timeline:") for turn, _, speaker in results["diarization"].itertracks(yield_label=True): col1, col2, col3 = st.columns([2,3,5]) with col1: speaker_num = int(speaker.split('_')[1]) colors = ['🔵', '🔴'] st.write(f"{colors[speaker_num % 2]} {speaker}") with col2: st.write(f"{format_timestamp(turn.start)} → {format_timestamp(turn.end)}") with tab2: st.write("Transcription:") st.write(results["transcription"]) with tab3: st.write("Summary:") st.write(results["summary"]) if __name__ == "__main__": main()