import io
import os
import tempfile

import streamlit as st
import torch
import whisper
from pydub import AudioSegment
from pyannote.audio import Pipeline
from transformers import pipeline as tf_pipeline

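# NOTE: "pyannote/speaker-diarization" is a gated model on the Hugging Face
# Hub, so loading it needs an access token. With Streamlit, one way to supply
# it (an assumption about your deployment) is via .streamlit/secrets.toml:
#
#   hf_token = "hf_..."
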
# Cache the heavy models across Streamlit reruns so they load once per process.
@st.cache_resource
def load_models():
    try:
        diarization = Pipeline.from_pretrained(
            "pyannote/speaker-diarization",
            use_auth_token=st.secrets["hf_token"]
        )
        transcriber = whisper.load_model("base")
        summarizer = tf_pipeline(
            "summarization",
            model="facebook/bart-large-cnn",
            device=0 if torch.cuda.is_available() else -1  # GPU if available
        )
        return diarization, transcriber, summarizer
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        return None, None, None

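# End-to-end pipeline: decode and normalize the upload, then run diarization,
# transcription, and summarization over the same 16 kHz mono WAV file.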
def process_audio(audio_file, max_duration=600):
    try:
        audio_bytes = io.BytesIO(audio_file.getvalue())

        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            try:
                # Decode the upload with pydub (MP3 or WAV).
                if audio_file.name.lower().endswith('.mp3'):
                    audio = AudioSegment.from_mp3(audio_bytes)
                else:
                    audio = AudioSegment.from_wav(audio_bytes)

                # Enforce the max_duration cap (pydub slices in milliseconds).
                if len(audio) > max_duration * 1000:
                    st.warning(f"Audio is longer than {max_duration}s; only the first {max_duration}s will be analyzed.")
                    audio = audio[:max_duration * 1000]

                # Normalize to 16 kHz, mono, 16-bit: the format both whisper
                # and pyannote expect.
                audio = audio.set_frame_rate(16000)
                audio = audio.set_channels(1)
                audio = audio.set_sample_width(2)

                audio.export(tmp.name, format="wav")
                tmp_path = tmp.name

            except Exception as e:
                st.error(f"Error converting audio: {str(e)}")
                return None

        diarization, transcriber, summarizer = load_models()
        if not all([diarization, transcriber, summarizer]):
            # load_models has already shown the error; returning None (rather
            # than a truthy string) keeps the caller's `if results:` check safe.
            return None

        try:
            with st.spinner("Identifying speakers..."):
                diarization_result = diarization(tmp_path)

            with st.spinner("Transcribing audio..."):
                transcription = transcriber.transcribe(tmp_path)

            with st.spinner("Generating summary..."):
                # BART's encoder only sees ~1024 tokens; truncate longer
                # transcripts instead of raising an error.
                summary = summarizer(
                    transcription["text"],
                    max_length=130,
                    min_length=30,
                    truncation=True
                )
        finally:
            # Clean up the temp file even if a model step fails.
            os.unlink(tmp_path)

        return {
            "diarization": diarization_result,
            "transcription": transcription["text"],
            "summary": summary[0]["summary_text"]
        }

    except Exception as e:
        st.error(f"Error processing audio: {str(e)}")
        return None

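# The "diarization" value above is a pyannote.core.Annotation; its
# itertracks(yield_label=True) yields (segment, track, label) triples, which
# format_speaker_segments() below flattens into plain dicts.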
def format_speaker_segments(diarization_result):
    """Drop very short segments and merge consecutive turns by the same speaker."""
    formatted_segments = []
    min_duration = 0.3  # seconds; shorter turns are usually diarization noise

    for turn, _, speaker in diarization_result.itertracks(yield_label=True):
        duration = turn.end - turn.start

        # Skip blips shorter than min_duration.
        if duration < min_duration:
            continue

        # New speaker: open a new segment; same speaker: extend the last one.
        if not formatted_segments or formatted_segments[-1]['speaker'] != speaker:
            formatted_segments.append({
                'speaker': speaker,
                'start': turn.start,
                'end': turn.end
            })
        else:
            formatted_segments[-1]['end'] = turn.end

    return formatted_segments

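# A hypothetical format_speaker_segments() result for a short two-speaker
# exchange, to illustrate the shape main() renders (times in seconds):
#   [{'speaker': 'SPEAKER_00', 'start': 0.0, 'end': 4.2},
#    {'speaker': 'SPEAKER_01', 'start': 4.6, 'end': 9.8}]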
def main():
    st.title("Multi-Speaker Audio Analyzer")
    st.write("Upload an audio file (MP3/WAV). Up to 10 minutes of audio is analyzed; shorter clips process faster.")

    uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"])

    if uploaded_file:
        file_size = len(uploaded_file.getvalue()) / (1024 * 1024)
        st.write(f"File size: {file_size:.2f} MB")

        # Match the player's MIME type to the uploaded format.
        audio_format = "audio/mpeg" if uploaded_file.name.lower().endswith(".mp3") else "audio/wav"
        st.audio(uploaded_file, format=audio_format)

        if st.button("Analyze Audio"):
            if file_size > 200:
                st.error("File size exceeds 200MB limit")
            else:
                results = process_audio(uploaded_file)

                if results:
                    tab1, tab2, tab3 = st.tabs(["Speakers", "Transcription", "Summary"])

                    with tab1:
                        st.write("Speaker Timeline:")
                        segments = format_speaker_segments(results["diarization"])
                        # Defined before the loop so the legend below can reuse it.
                        colors = ['🔵', '🔴', '🟢', '🟡', '🟣']

                        for segment in segments:
                            col1, col2, col3 = st.columns([2, 1, 6])

                            with col1:
                                # pyannote labels speakers "SPEAKER_00", "SPEAKER_01", ...
                                speaker_num = int(segment['speaker'].split('_')[1])
                                speaker_color = colors[speaker_num % len(colors)]
                                st.write(f"{speaker_color} {segment['speaker']}")

                            with col2:
                                # mm:ss.s (minutes need a floor division, not int(seconds))
                                start_time = f"{int(segment['start'] // 60):02d}:{segment['start'] % 60:04.1f}"
                                st.write(f"{start_time} →")

                            with col3:
                                end_time = f"{int(segment['end'] // 60):02d}:{segment['end'] % 60:04.1f}"
                                st.write(end_time)

                        st.markdown("---")

                        st.write("Speaker Legend:")
                        for speaker in sorted(set(s['speaker'] for s in segments)):
                            num = int(speaker.split('_')[1])
                            st.write(f"{colors[num % len(colors)]} {speaker}")

                    with tab2:
                        st.write("Transcription:")
                        st.write(results["transcription"])

                    with tab3:
                        st.write("Summary:")
                        st.write(results["summary"])

if __name__ == "__main__":
    main()
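
# To launch locally (assuming this file is saved as app.py):
#   streamlit run app.py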