File size: 3,454 Bytes
2a6784d
 
853df82
b3635dd
 
 
 
853df82
2a6784d
b3635dd
 
 
 
 
 
 
 
f1a85dc
b3635dd
 
 
 
 
 
 
 
 
 
 
2a6784d
f1a85dc
b3635dd
853df82
 
b3635dd
853df82
 
 
 
 
 
 
 
b3635dd
853df82
2a6784d
b3635dd
 
 
 
2a6784d
b3635dd
 
 
 
 
 
 
 
 
2a6784d
b3635dd
 
 
 
 
 
 
 
 
 
 
 
2a6784d
b3635dd
 
 
2a6784d
b3635dd
2a6784d
b3635dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a6784d
b3635dd
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import streamlit as st
from pyannote.audio import Pipeline
import whisper  
import tempfile
import os
import torch
from transformers import pipeline as tf_pipeline
from pydub import AudioSegment

@st.cache_resource
def _load_models_cached():
    """Load the three models once per process; raise on any failure.

    Kept separate from ``load_models`` so that only a *successful* load is
    stored by ``st.cache_resource``. An exception propagating out of a cached
    function is not cached, which lets a later call retry after a transient
    failure (network hiccup, missing token, ...).
    """
    diarization = Pipeline.from_pretrained(
        "pyannote/speaker-diarization",
        use_auth_token=st.secrets["hf_token"],
    )
    # NOTE(review): pyannote's from_pretrained is reported to return None
    # (instead of raising) when a gated model cannot be downloaded; treat
    # that as a failure so it is never cached as a valid result.
    if diarization is None:
        raise RuntimeError("speaker diarization pipeline could not be loaded")

    transcriber = whisper.load_model("turbo")

    summarizer = tf_pipeline(
        "summarization",
        model="facebook/bart-large-cnn",
        # transformers convention: 0 = first CUDA device, -1 = CPU.
        device=0 if torch.cuda.is_available() else -1,
    )

    return diarization, transcriber, summarizer


def load_models():
    """Return ``(diarization, transcriber, summarizer)``, loading them on first use.

    Returns:
        A 3-tuple of the pyannote diarization pipeline, the Whisper model and
        the transformers summarization pipeline. If loading fails, the error
        is shown with ``st.error`` and ``(None, None, None)`` is returned.
        Failures are not cached, so the next call tries again.
    """
    try:
        return _load_models_cached()
    except Exception as e:  # UI boundary: report the problem instead of crashing
        st.error(f"Error loading models: {str(e)}")
        return None, None, None

def process_audio(audio_file, max_duration=600):
    """Diarize, transcribe and summarize an uploaded audio file.

    Args:
        audio_file: File-like object with a ``name`` attribute (e.g. the value
            returned by ``st.file_uploader``). Names ending in ``.mp3``
            (case-insensitive) are decoded as MP3; everything else as WAV.
        max_duration: Maximum number of seconds of audio to analyze
            (default 600 s, i.e. 10 minutes). Longer recordings are truncated
            to this length and a warning is shown. Pass ``None`` to disable
            the limit.

    Returns:
        A dict with keys ``"diarization"`` (the diarization result object),
        ``"transcription"`` (str) and ``"summary"`` (str), or ``None`` if
        anything failed. Failures are reported to the user via ``st.error``.
    """
    tmp_path = None
    try:
        # The upload may already have been read elsewhere; start from the
        # beginning of the stream.
        audio_file.seek(0)

        # Decode according to the file extension (".MP3" is still an MP3).
        if audio_file.name.lower().endswith(".mp3"):
            audio = AudioSegment.from_mp3(audio_file)
        else:
            audio = AudioSegment.from_wav(audio_file)

        # Enforce the duration limit. pydub measures and slices in milliseconds.
        if max_duration is not None:
            limit_ms = int(max_duration * 1000)
            if len(audio) > limit_ms:
                st.warning(
                    f"Audio is longer than {max_duration} seconds; "
                    "only the beginning will be analyzed."
                )
                audio = audio[:limit_ms]

        # Reserve a temp path, close the handle, then export a WAV copy that
        # the models can open by file name.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp_path = tmp.name
        audio.export(tmp_path, format="wav")

        # Get cached models
        diarization, transcriber, summarizer = load_models()
        if any(model is None for model in (diarization, transcriber, summarizer)):
            # Return None (not a truthy string) so callers that test the
            # result before indexing it do not crash.
            st.error("Model loading failed")
            return None

        with st.spinner("Identifying speakers..."):
            diarization_result = diarization(tmp_path)

        with st.spinner("Transcribing audio..."):
            transcription = transcriber.transcribe(tmp_path)

        with st.spinner("Generating summary..."):
            # truncation=True keeps long transcripts within the summarization
            # model's maximum input length instead of failing.
            summary = summarizer(
                transcription["text"],
                max_length=130,
                min_length=30,
                truncation=True,
            )

        return {
            "diarization": diarization_result,
            "transcription": transcription["text"],
            "summary": summary[0]["summary_text"],
        }

    except Exception as e:  # UI boundary: surface the error in the page
        st.error(f"Error processing audio: {str(e)}")
        return None
    finally:
        # Always remove the temp WAV, including on errors and early returns.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)

def main():
    """Render the Streamlit page: upload, playback, analysis and result tabs.

    Shows a file uploader for MP3/WAV files. Once a file is chosen it is made
    playable, and pressing "Analyze Audio" runs ``process_audio`` and displays
    speaker segments, the transcription and the summary in three tabs.
    """
    st.title("Multi-Speaker Audio Analyzer")
    st.write("Upload an audio file (MP3/WAV) up to 5 minutes long for best performance")

    uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"])
    if not uploaded_file:
        return

    # Use the MIME type reported for the upload so MP3 files are not labeled
    # as WAV; fall back to WAV when no type is available.
    mime_type = getattr(uploaded_file, "type", None) or "audio/wav"
    st.audio(uploaded_file, format=mime_type)

    if not st.button("Analyze Audio"):
        return

    results = process_audio(uploaded_file)

    # Only a dict is a successful result. A plain string is an error message
    # to show; anything else (e.g. None) means the error was already reported.
    if not isinstance(results, dict):
        if isinstance(results, str):
            st.error(results)
        return

    # Display results in tabs
    speakers_tab, transcript_tab, summary_tab = st.tabs(
        ["Speakers", "Transcription", "Summary"]
    )

    with speakers_tab:
        st.write("Speaker Segments:")
        for turn, _, speaker in results["diarization"].itertracks(yield_label=True):
            st.write(f"{speaker}: {turn.start:.1f}s → {turn.end:.1f}s")

    with transcript_tab:
        st.write("Transcription:")
        st.write(results["transcription"])

    with summary_tab:
        st.write("Summary:")
        st.write(results["summary"])

# Build the page only when this file is executed as the main module,
# not when it is imported.
if __name__ == "__main__":
    main()