|
""" |
|
Multi-Speaker Audio Analyzer |
|
A Streamlit application that performs speaker diarization, transcription, and summarization on audio files. |
|
|
|
Author: [Your Name] |
|
Date: January 2025 |
|
""" |
|
|
|
import streamlit as st |
|
from pyannote.audio import Pipeline |
|
import whisper |
|
import tempfile |
|
import os |
|
import torch |
|
from transformers import pipeline as tf_pipeline, BartTokenizer |
|
from pydub import AudioSegment |
|
import io |
|
import pickle |
|
|
|
class SpeakerDiarizer:
    """Thin wrapper around the pyannote speaker-diarization pipeline."""

    def __init__(self, token):
        """Download and load the pretrained diarization pipeline.

        token: Hugging Face access token for the gated pyannote model.
        """
        model_id = "pyannote/speaker-diarization"
        self.pipeline = Pipeline.from_pretrained(model_id, use_auth_token=token)

    def process(self, audio_file):
        """Diarize *audio_file* (a filesystem path) and return the annotation."""
        return self.pipeline(audio_file)
|
|
|
class Transcriber:
    """Speech-to-text using OpenAI Whisper."""

    def __init__(self):
        # "base" model — presumably chosen to balance speed and accuracy
        # for short clips; TODO confirm size is adequate for target audio.
        self.model = whisper.load_model("base")

    def process(self, audio_file):
        """Transcribe *audio_file* (a filesystem path); return only the text."""
        result = self.model.transcribe(audio_file)
        return result["text"]
|
|
|
class Summarizer:
    """Text summarization with a BART model loaded from a local pickle."""

    def __init__(self, model_path='bart_ami_finetuned.pkl'):
        """Load the BART tokenizer and a pickled fine-tuned model.

        model_path: path to the pickled model (presumably BART fine-tuned
        on the AMI meeting corpus, per the filename — TODO confirm).
        """
        self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
        # NOTE(review): pickle.load executes arbitrary code embedded in the
        # file — only load model checkpoints from a trusted source.
        with open(model_path, 'rb') as f:
            self.model = pickle.load(f)

    def process(self, text):
        """Summarize *text*; input is truncated to 1024 tokens (BART's limit)."""
        inputs = self.tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)
        # Bound the summary to roughly one short paragraph (40–150 tokens).
        summary_ids = self.model.generate(inputs["input_ids"], max_length=150, min_length=40)
        return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
|
|
@st.cache_resource
def load_models():
    """Build all three model wrappers once per session (Streamlit-cached).

    Returns:
        (diarizer, transcriber, summarizer) on success, or
        (None, None, None) if any model fails to load; the failure is
        reported to the UI via st.error.
    """
    try:
        models = (
            SpeakerDiarizer(st.secrets["hf_token"]),
            Transcriber(),
            Summarizer(),
        )
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        return None, None, None
    return models
|
|
|
def process_audio(audio_file):
    """Convert an uploaded MP3/WAV to 16 kHz mono WAV and run the pipeline.

    Args:
        audio_file: a Streamlit UploadedFile (has .getvalue() and .name).

    Returns:
        dict with keys "diarization", "transcription", "summary" on
        success; None on any failure (the error is shown via st.error).
    """
    tmp_path = None
    try:
        audio_bytes = io.BytesIO(audio_file.getvalue())

        # Normalize to 16 kHz mono 16-bit WAV — the format the models expect.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            if audio_file.name.lower().endswith('.mp3'):
                audio = AudioSegment.from_mp3(audio_bytes)
            else:
                audio = AudioSegment.from_wav(audio_bytes)

            audio = audio.set_frame_rate(16000).set_channels(1).set_sample_width(2)
            audio.export(tmp.name, format="wav", parameters=["-ac", "1", "-ar", "16000"])
            tmp_path = tmp.name

        diarizer, transcriber, summarizer = load_models()
        if not all([diarizer, transcriber, summarizer]):
            # Bug fix: previously returned the truthy string
            # "Model loading failed", which the caller passed its
            # `if results:` check and then indexed like a dict,
            # raising TypeError. Return None so callers' checks work.
            st.error("Model loading failed")
            return None

        with st.spinner("Processing..."):
            diarization = diarizer.process(tmp_path)
            transcription = transcriber.process(tmp_path)
            summary = summarizer.process(transcription)

        return {
            "diarization": diarization,
            "transcription": transcription,
            "summary": summary
        }

    except Exception as e:
        st.error(f"Error: {str(e)}")
        return None
    finally:
        # Bug fix: always remove the temp WAV. The original only unlinked
        # on the success path, leaking the file whenever diarization,
        # transcription, or summarization raised.
        if tmp_path and os.path.exists(tmp_path):
            os.unlink(tmp_path)
|
|
|
def format_timestamp(seconds):
    """Render a time offset in seconds as MM:SS (e.g. 65.3 -> "01:05").

    Bug fix: main() called format_timestamp for the speaker timeline, but
    the function was never defined anywhere in the file, so opening the
    Speakers tab raised NameError at runtime.
    """
    total = int(seconds)
    return f"{total // 60:02d}:{total % 60:02d}"


def main():
    """Streamlit entry point: upload widget, size guard, and result tabs."""
    st.title("Multi-Speaker Audio Analyzer")
    st.write("Upload an audio file (MP3/WAV) up to 5 minutes long")

    uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"])
    if not uploaded_file:
        return

    file_size = len(uploaded_file.getvalue()) / (1024 * 1024)
    st.write(f"File size: {file_size:.2f} MB")

    st.audio(uploaded_file, format='audio/wav')

    if not st.button("Analyze Audio"):
        return

    if file_size > 200:
        st.error("File size exceeds 200MB limit")
        return

    results = process_audio(uploaded_file)
    if not results:
        return

    tab1, tab2, tab3 = st.tabs(["Speakers", "Transcription", "Summary"])

    with tab1:
        st.write("Speaker Timeline:")
        for turn, _, speaker in results["diarization"].itertracks(yield_label=True):
            col1, col2, col3 = st.columns([2, 3, 5])
            with col1:
                # Labels look like "SPEAKER_00"; alternate two marker colors
                # by the numeric suffix.
                speaker_num = int(speaker.split('_')[1])
                colors = ['🔵', '🔴']
                st.write(f"{colors[speaker_num % 2]} {speaker}")
            with col2:
                st.write(f"{format_timestamp(turn.start)} → {format_timestamp(turn.end)}")

    with tab2:
        st.write("Transcription:")
        st.write(results["transcription"])

    with tab3:
        st.write("Summary:")
        st.write(results["summary"])
|
|
|
# Entry point when executed directly (e.g. `streamlit run <file>`).
if __name__ == "__main__":
    main()