File size: 3,310 Bytes
2a6784d b3635dd 2a6784d b3635dd 2a6784d b3635dd 2a6784d b3635dd 2a6784d b3635dd 2a6784d b3635dd 2a6784d b3635dd 2a6784d b3635dd 2a6784d b3635dd 2a6784d b3635dd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
import streamlit as st
from pyannote.audio import Pipeline
import whisper
import tempfile
import os
import torch
from transformers import pipeline as tf_pipeline
@st.cache_resource
def _load_models_cached():
    """Load and cache the diarization, transcription and summarization models.

    Any failure is raised rather than returned. Streamlit does not store a
    result when the cached function raises, so a transient error (missing
    token, network hiccup) is retried on the next call instead of being
    cached forever.

    Returns:
        tuple: ``(diarization, transcriber, summarizer)``.

    Raises:
        KeyError: if ``hf_token`` is missing from ``st.secrets``.
        RuntimeError: if the pyannote pipeline could not be obtained.
    """
    diarization = Pipeline.from_pretrained(
        "pyannote/speaker-diarization",
        use_auth_token=st.secrets["hf_token"],
    )
    # pyannote may return None (after printing a message) instead of raising
    # when the gated model cannot be downloaded; treat that as a failure.
    if diarization is None:
        raise RuntimeError(
            "could not load 'pyannote/speaker-diarization' "
            "(check hf_token and model access on Hugging Face)"
        )

    # Smaller Whisper checkpoint for faster processing.
    transcriber = whisper.load_model("base")

    summarizer = tf_pipeline(
        "summarization",
        model="facebook/bart-large-cnn",
        # transformers convention: 0 = first GPU, -1 = CPU.
        device=0 if torch.cuda.is_available() else -1,
    )
    return diarization, transcriber, summarizer


def load_models():
    """Return the cached models, or ``(None, None, None)`` after showing an error.

    The return contract is unchanged from before. Only successful loads are
    cached (see ``_load_models_cached``).
    """
    try:
        return _load_models_cached()
    except Exception as e:  # UI boundary: report any loading failure to the user
        st.error(f"Error loading models: {e}")
        return None, None, None
def _save_upload_to_temp(audio_file):
    """Write the uploaded bytes to a named temp file and return its path.

    The upload's own extension is kept (falling back to ``.wav``) so the
    decoder is not misled by a wrong suffix on e.g. MP3 uploads.
    The caller is responsible for deleting the file.
    """
    suffix = os.path.splitext(getattr(audio_file, "name", "") or "")[1] or ".wav"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(audio_file.getvalue())
        return tmp.name


def _load_waveform(path, max_duration):
    """Decode *path* with Whisper's ffmpeg loader, optionally truncated.

    Returns ``(audio, sample_rate)``. ``audio`` is the 1-D array produced by
    ``whisper.load_audio`` and ``sample_rate`` is Whisper's ``SAMPLE_RATE``.
    When *max_duration* is truthy and the audio is longer, only the first
    *max_duration* seconds are kept and a warning is shown.
    """
    audio = whisper.load_audio(path)
    sample_rate = whisper.audio.SAMPLE_RATE
    if max_duration:
        max_samples = int(max_duration * sample_rate)
        if audio.shape[0] > max_samples:
            st.warning(
                f"Audio is longer than {max_duration} s; "
                f"only the first {max_duration} s will be analyzed."
            )
            audio = audio[:max_samples]
    return audio, sample_rate


def process_audio(audio_file, max_duration=300):
    """Diarize, transcribe and summarize an uploaded audio file.

    Args:
        audio_file: Uploaded file object exposing ``getvalue()`` (and
            optionally ``name``), e.g. what ``st.file_uploader`` returns.
        max_duration: Maximum number of seconds to analyze (default 5
            minutes). Longer audio is truncated with a warning. Pass
            ``None`` (or 0) to analyze the whole file.

    Returns:
        dict with keys ``"diarization"`` (raw diarization pipeline output),
        ``"transcription"`` (str) and ``"summary"`` (str). Returns ``None``
        on any failure, after reporting it with ``st.error``.
    """
    tmp_path = None
    try:
        diarization, transcriber, summarizer = load_models()
        # Compare against None explicitly: model objects' truthiness is not a
        # reliable "loaded" signal.
        if any(model is None for model in (diarization, transcriber, summarizer)):
            st.error("Model loading failed")
            return None

        tmp_path = _save_upload_to_temp(audio_file)
        # Decode once and feed the same (possibly truncated) samples to both
        # models so speaker turns and transcript cover the same span.
        audio, sample_rate = _load_waveform(tmp_path, max_duration)

        with st.spinner("Identifying speakers..."):
            diarization_result = diarization(
                {
                    # pyannote's in-memory input: (channel, time) float tensor.
                    "waveform": torch.from_numpy(audio).unsqueeze(0),
                    "sample_rate": sample_rate,
                }
            )
        with st.spinner("Transcribing audio..."):
            transcription = transcriber.transcribe(audio)
        text = transcription["text"].strip()

        with st.spinner("Generating summary..."):
            if text:
                # truncation=True: BART's encoder accepts at most 1024 tokens;
                # longer transcripts would otherwise raise inside the model.
                summary = summarizer(
                    text, max_length=130, min_length=30, truncation=True
                )[0]["summary_text"]
            else:
                summary = ""  # nothing was transcribed (e.g. silence)

        return {
            "diarization": diarization_result,
            "transcription": text,
            "summary": summary,
        }
    except Exception as e:  # UI boundary: surface any pipeline failure
        st.error(f"Error processing audio: {e}")
        return None
    finally:
        # Always remove the temp file, including on error paths.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
def _render_results(results):
    """Show speaker turns, transcript and summary from *results* in three tabs."""
    speakers_tab, transcript_tab, summary_tab = st.tabs(
        ["Speakers", "Transcription", "Summary"]
    )
    with speakers_tab:
        st.write("Speaker Segments:")
        for turn, _, speaker in results["diarization"].itertracks(yield_label=True):
            st.write(f"{speaker}: {turn.start:.1f}s → {turn.end:.1f}s")
    with transcript_tab:
        st.write("Transcription:")
        st.write(results["transcription"])
    with summary_tab:
        st.write("Summary:")
        st.write(results["summary"])


def main():
    """Streamlit entry point: upload, preview, then analyze an audio file."""
    st.title("Multi-Speaker Audio Analyzer")
    st.write("Upload an audio file (MP3/WAV) up to 5 minutes long for best performance")

    uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"])
    if not uploaded_file:
        return

    # Use the upload's reported MIME type so MP3 files aren't labelled audio/wav.
    st.audio(uploaded_file, format=getattr(uploaded_file, "type", None) or "audio/wav")

    if not st.button("Analyze Audio"):
        return

    results = process_audio(uploaded_file)
    # process_audio may signal failure with an error string rather than None;
    # indexing a str like a dict would crash, so display it instead.
    if isinstance(results, str):
        st.error(results)
    elif results:
        _render_results(results)
# Run the app when executed directly (e.g. via `streamlit run`).
if __name__ == "__main__":
    main()