# app.py — retrieved from the Hugging Face file view (raw / history / blame).
# Last commit: "Update app.py" (f1a85dc, verified) by Manyue-DataScientist.
# File size: 3.15 kB.
import streamlit as st
from pyannote.audio import Pipeline
import whisper # Changed import
import tempfile
import os
import torch
from transformers import pipeline as tf_pipeline
@st.cache_resource
def _load_models_cached():
    """Build the diarization, transcription and summarization models.

    This helper raises on failure instead of returning placeholders, so
    that ``st.cache_resource`` only ever stores a successful load and a
    transient failure can be retried on the next call.

    Returns:
        A ``(diarization, transcriber, summarizer)`` tuple.

    Raises:
        RuntimeError: If the diarization pipeline comes back as ``None``.
        Exception: Whatever the underlying model loaders raise.
    """
    diarization = Pipeline.from_pretrained(
        "pyannote/speaker-diarization",
        use_auth_token=st.secrets["hf_token"],
    )
    # NOTE(review): some pyannote.audio releases appear to return None
    # (rather than raise) when the gated model cannot be downloaded --
    # confirm against the installed version.
    if diarization is None:
        raise RuntimeError(
            "Speaker diarization pipeline could not be loaded; "
            "check the Hugging Face token and model access"
        )

    transcriber = whisper.load_model("turbo")
    summarizer = tf_pipeline(
        "summarization",
        model="facebook/bart-large-cnn",
        # transformers convention: 0 selects the first GPU, -1 the CPU.
        device=0 if torch.cuda.is_available() else -1,
    )
    return diarization, transcriber, summarizer


def load_models():
    """Return the cached models, or a tuple of ``None`` values on failure.

    The heavy lifting is delegated to the cached helper. Errors are caught
    here, outside the cache, so a failed attempt is shown to the user but
    is not remembered: the next call tries to load the models again.

    Returns:
        ``(diarization, transcriber, summarizer)`` on success, otherwise
        ``(None, None, None)`` after displaying the error with ``st.error``.
    """
    try:
        return _load_models_cached()
    except Exception as e:  # UI boundary: report instead of crashing the app
        st.error(f"Error loading models: {str(e)}")
        return None, None, None


# Keep ``load_models.clear()`` working for callers that used it when the
# cache decorator sat directly on ``load_models``.
load_models.clear = _load_models_cached.clear
def process_audio(audio_file, max_duration=600):
    """Diarize, transcribe and summarize an uploaded audio file.

    The upload is written to a temporary file because the diarization and
    transcription models are called with a file path. The temporary file is
    always removed, whether processing succeeds or fails.

    Args:
        audio_file: Uploaded file object; must provide ``getvalue()`` and
            may provide ``name`` (used to pick the temp-file extension).
        max_duration: Intended duration limit in seconds. It is accepted
            for interface compatibility but is not enforced here.

    Returns:
        A dict with keys ``"diarization"``, ``"transcription"`` and
        ``"summary"``, or ``None`` if the models could not be loaded or
        processing failed (the error is reported through ``st.error``).
    """
    tmp_path = None
    try:
        # Get cached models; load_models() has already reported any failure.
        diarization, transcriber, summarizer = load_models()
        if any(model is None for model in (diarization, transcriber, summarizer)):
            return None

        # Keep the upload's own extension so an MP3 is not written out under
        # a misleading ".wav" name; fall back to ".wav" if it has none.
        upload_name = getattr(audio_file, "name", "") or ""
        suffix = os.path.splitext(upload_name)[1] or ".wav"
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            tmp.write(audio_file.getvalue())
            tmp_path = tmp.name

        with st.spinner("Identifying speakers..."):
            diarization_result = diarization(tmp_path)

        with st.spinner("Transcribing audio..."):
            transcription = transcriber.transcribe(tmp_path)
        text = transcription["text"]

        with st.spinner("Generating summary..."):
            if text.strip():
                # truncation=True keeps transcripts longer than the model's
                # input limit from raising inside the summarizer.
                summary = summarizer(
                    text,
                    max_length=130,
                    min_length=30,
                    truncation=True,
                )
                summary_text = summary[0]["summary_text"]
            else:
                # Nothing was transcribed, so there is nothing to summarize.
                summary_text = ""

        return {
            "diarization": diarization_result,
            "transcription": text,
            "summary": summary_text,
        }
    except Exception as e:  # UI boundary: report instead of crashing the app
        st.error(f"Error processing audio: {str(e)}")
        return None
    finally:
        # Cleanup runs on every exit path so temp files never accumulate.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)
def main():
    """Render the Streamlit page: upload, playback, analysis and results.

    Nothing is analyzed until a file has been uploaded and the
    "Analyze Audio" button is pressed. Results are shown in three tabs:
    speaker segments, transcription and summary.
    """
    st.title("Multi-Speaker Audio Analyzer")
    st.write("Upload an audio file (MP3/WAV) up to 5 minutes long for best performance")

    uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"])
    if not uploaded_file:
        return

    # Use the upload's own MIME type so MP3 files are not labelled as WAV.
    st.audio(uploaded_file, format=getattr(uploaded_file, "type", None) or "audio/wav")

    if not st.button("Analyze Audio"):
        return

    results = process_audio(uploaded_file)
    if isinstance(results, str):
        # A plain string is a failure message, not a result set.
        st.error(results)
        return
    if not isinstance(results, dict):
        return

    # Display results in tabs
    tab1, tab2, tab3 = st.tabs(["Speakers", "Transcription", "Summary"])

    with tab1:
        st.write("Speaker Segments:")
        for turn, _, speaker in results["diarization"].itertracks(yield_label=True):
            st.write(f"{speaker}: {turn.start:.1f}s → {turn.end:.1f}s")

    with tab2:
        st.write("Transcription:")
        st.write(results["transcription"])

    with tab3:
        st.write("Summary:")
        st.write(results["summary"])
# Script entry point: build the UI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()