File size: 3,310 Bytes
2a6784d
 
 
b3635dd
 
 
 
2a6784d
b3635dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a6784d
b3635dd
 
 
 
 
2a6784d
b3635dd
 
 
 
2a6784d
b3635dd
 
 
 
 
 
 
 
 
2a6784d
b3635dd
 
 
 
 
 
 
 
 
 
 
 
2a6784d
b3635dd
 
 
2a6784d
b3635dd
2a6784d
b3635dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a6784d
b3635dd
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import streamlit as st
from pyannote.audio import Pipeline
import whisper
import tempfile
import os
import torch
from transformers import pipeline as tf_pipeline

# Cache the model loading using streamlit
@st.cache_resource
def load_models():
    """Load and cache the diarization, transcription, and summarization models.

    Cached by Streamlit, so the heavy model loading runs once per process.

    Returns:
        A ``(diarization, transcriber, summarizer)`` triple on success, or
        ``(None, None, None)`` if any model fails to load (the error is shown
        to the user via ``st.error``).
    """
    try:
        # Speaker-diarization pipeline; the model is gated on the Hub, so an
        # auth token from Streamlit secrets is required.
        hf_token = st.secrets["hf_token"]
        diarization = Pipeline.from_pretrained(
            "pyannote/speaker-diarization",
            use_auth_token=hf_token,
        )

        # "base" Whisper checkpoint: smaller model for faster interactive use.
        transcriber = whisper.load_model("base")

        # BART summarizer; run on GPU (device 0) when available, else CPU (-1).
        summarizer_device = 0 if torch.cuda.is_available() else -1
        summarizer = tf_pipeline(
            "summarization",
            model="facebook/bart-large-cnn",
            device=summarizer_device,
        )
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        return None, None, None

    return diarization, transcriber, summarizer

def process_audio(audio_file, max_duration=300):  # limit to 5 minutes initially
    """Run speaker diarization, transcription, and summarization on uploaded audio.

    Args:
        audio_file: Uploaded file object exposing ``.getvalue()`` (e.g. a
            Streamlit ``UploadedFile``) holding MP3/WAV bytes.
        max_duration: Intended duration cap in seconds. NOTE(review): currently
            unused — the full audio is processed regardless; confirm whether
            trimming should be implemented.

    Returns:
        dict with keys "diarization", "transcription", and "summary" on
        success, or ``None`` on any failure (errors are shown via ``st.error``).
    """
    tmp_path = None
    try:
        # The models expect a file path, not an in-memory buffer, so persist
        # the upload to a temporary file first.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp.write(audio_file.getvalue())
            tmp_path = tmp.name

        # Cached via st.cache_resource, so this is cheap after the first call.
        diarization, transcriber, summarizer = load_models()
        if not all([diarization, transcriber, summarizer]):
            # BUGFIX: previously returned the string "Model loading failed",
            # which is truthy — the caller's `if results:` check then indexed
            # it like a dict and crashed. Return None so the caller skips
            # rendering; load_models() already surfaced the error to the user.
            return None

        # Each stage gets its own spinner so the user sees progress.
        with st.spinner("Identifying speakers..."):
            diarization_result = diarization(tmp_path)

        with st.spinner("Transcribing audio..."):
            transcription = transcriber.transcribe(tmp_path)

        with st.spinner("Generating summary..."):
            summary = summarizer(transcription["text"], max_length=130, min_length=30)

        return {
            "diarization": diarization_result,
            "transcription": transcription["text"],
            "summary": summary[0]["summary_text"],
        }

    except Exception as e:
        st.error(f"Error processing audio: {str(e)}")
        return None
    finally:
        # BUGFIX: cleanup previously ran only on the success path, leaking the
        # temp file whenever any stage raised. Always remove it once created.
        if tmp_path and os.path.exists(tmp_path):
            os.unlink(tmp_path)

def main():
    """Streamlit entry point: upload UI plus tabbed display of analysis results."""
    st.title("Multi-Speaker Audio Analyzer")
    st.write("Upload an audio file (MP3/WAV) up to 5 minutes long for best performance")

    uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"])
    if not uploaded_file:
        # Nothing uploaded yet — nothing more to render this run.
        return

    # Let the user preview the upload before kicking off analysis.
    st.audio(uploaded_file, format='audio/wav')

    if not st.button("Analyze Audio"):
        return

    results = process_audio(uploaded_file)
    if not results:
        # process_audio already reported the failure via st.error.
        return

    # One tab per result type.
    speakers_tab, transcript_tab, summary_tab = st.tabs(
        ["Speakers", "Transcription", "Summary"]
    )

    with speakers_tab:
        st.write("Speaker Segments:")
        tracks = results["diarization"].itertracks(yield_label=True)
        for turn, _, speaker in tracks:
            st.write(f"{speaker}: {turn.start:.1f}s → {turn.end:.1f}s")

    with transcript_tab:
        st.write("Transcription:")
        st.write(results["transcription"])

    with summary_tab:
        st.write("Summary:")
        st.write(results["summary"])

# Run the Streamlit app only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()