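"""Multi-Speaker Audio Analyzer.

A Streamlit app that identifies who spoke when (pyannote.audio), transcribes
the audio (OpenAI Whisper), and summarizes the transcript
(facebook/bart-large-cnn).
"""
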
import streamlit as st
from pyannote.audio import Pipeline
import whisper
import tempfile
import os
import torch
from transformers import pipeline as tf_pipeline
from pydub import AudioSegment
import io

@st.cache_resource
def load_models():
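    """Load the diarization, transcription, and summarization models;
    @st.cache_resource keeps them in memory across Streamlit reruns."""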
    try:
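        # pyannote/speaker-diarization-3.1 is a gated model; the token in
        # Streamlit secrets must belong to an account that accepted its terms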
        diarization = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=st.secrets["hf_token"]
        )
        
        # "small" trades some accuracy for speed; larger checkpoints do better
        transcriber = whisper.load_model("small")
        
        # BART summarizer; device=0 uses the first GPU, -1 falls back to CPU
        summarizer = tf_pipeline(
            "summarization",
            model="facebook/bart-large-cnn",
            device=0 if torch.cuda.is_available() else -1
        )
        return diarization, transcriber, summarizer
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        return None, None, None

def process_audio(audio_file, max_duration=600):
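    """Convert the upload to 16 kHz mono WAV, then diarize, transcribe,
    and summarize it; returns a result dict, or None on failure."""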
    try:
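        # Wrap the raw upload bytes so pydub can read them like a file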
        audio_bytes = io.BytesIO(audio_file.getvalue())
        
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            try:
                if audio_file.name.lower().endswith('.mp3'):
                    audio = AudioSegment.from_mp3(audio_bytes)
                else:
                    audio = AudioSegment.from_wav(audio_bytes)
                
                # Standardize to 16 kHz mono 16-bit PCM for Whisper/pyannote
                audio = audio.set_frame_rate(16000).set_channels(1).set_sample_width(2)

                # Enforce max_duration (pydub lengths are in milliseconds)
                if len(audio) > max_duration * 1000:
                    audio = audio[:max_duration * 1000]
                    st.warning(f"Audio truncated to the first {max_duration // 60} minutes")
                
                # Channels and sample rate were already set above, so no
                # extra ffmpeg parameters are needed on export
                audio.export(tmp.name, format="wav")
                tmp_path = tmp.name
                
            except Exception as e:
                st.error(f"Error converting audio: {str(e)}")
                return None

            diarization, transcriber, summarizer = load_models()
            if not all([diarization, transcriber, summarizer]):
                return None

            # Make sure the temp file is removed even if a step fails
            try:
                with st.spinner("Identifying speakers..."):
                    diarization_result = diarization(tmp_path)

                with st.spinner("Transcribing audio..."):
                    transcription = transcriber.transcribe(tmp_path)

                with st.spinner("Generating summary..."):
                    # truncation=True keeps long transcripts inside
                    # BART's 1024-token input window
                    summary = summarizer(
                        transcription["text"],
                        max_length=130,
                        min_length=30,
                        truncation=True
                    )
            finally:
                os.unlink(tmp_path)
            
            return {
                "diarization": diarization_result,
                "transcription": transcription,
                "summary": summary[0]["summary_text"]
            }
            
    except Exception as e:
        st.error(f"Error processing audio: {str(e)}")
        return None

def format_speaker_segments(diarization_result):
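    """Flatten pyannote's diarization output into {speaker, start, end}
    dicts; itertracks(yield_label=True) yields (segment, track, label)."""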
    formatted_segments = []
    
    for turn, _, speaker in diarization_result.itertracks(yield_label=True):
        if turn.start is not None and turn.end is not None:
            formatted_segments.append({
                'speaker': speaker,
                'start': float(turn.start),
                'end': float(turn.end)
            })
    
    return formatted_segments

def format_timestamp(seconds):
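    """Render a time in seconds as MM:SS.ss for the timeline display."""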
    minutes = int(seconds // 60)
    seconds = seconds % 60
    return f"{minutes:02d}:{seconds:05.2f}"

def main():
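    """Streamlit UI: upload an audio file, preview it, and show the speaker
    timeline, transcription, and summary in tabs."""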
    st.title("Multi-Speaker Audio Analyzer")
    st.write("Upload an audio file (MP3/WAV) up to 5 minutes long for best performance")

    uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"])

    if uploaded_file:
        file_size = len(uploaded_file.getvalue()) / (1024 * 1024)
        st.write(f"File size: {file_size:.2f} MB")
        
        mime = "audio/mpeg" if uploaded_file.name.lower().endswith(".mp3") else "audio/wav"
        st.audio(uploaded_file, format=mime)
        
        if st.button("Analyze Audio"):
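            # Streamlit's default server.maxUploadSize is 200 MB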
            if file_size > 200:
                st.error("File size exceeds 200MB limit")
            else:
                results = process_audio(uploaded_file)
                
                if results:
                    tab1, tab2, tab3 = st.tabs(["Speakers", "Transcription", "Summary"])
                    
                    with tab1:
                        st.write("Speaker Timeline:")
                        segments = format_speaker_segments(results["diarization"])
                        
                        for segment in segments:
                            col1, col2 = st.columns([2,8])
                            
                            with col1:
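                                # pyannote labels speakers "SPEAKER_00",
                                # "SPEAKER_01", ...; key the color off the suffix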
                                speaker_num = int(segment['speaker'].split('_')[1])
                                colors = ['πŸ”΅', 'πŸ”΄']  # Two colors for alternating speakers
                                speaker_color = colors[speaker_num % len(colors)]
                                st.write(f"{speaker_color} {segment['speaker']}")
                            
                            with col2:
                                start_time = format_timestamp(segment['start'])
                                end_time = format_timestamp(segment['end'])
                                st.write(f"{start_time} β†’ {end_time}")
                            
                            st.markdown("---")
                    
                    with tab2:
                        st.write("Transcription:")
                        st.write(results["transcription"]["text"])
                    
                    with tab3:
                        st.write("Summary:")
                        st.write(results["summary"])

if __name__ == "__main__":
    main()
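
# To run locally you need ffmpeg on PATH (for pydub's MP3 decoding) and a
# Hugging Face token stored as hf_token in .streamlit/secrets.toml; then,
# assuming this file is saved as app.py:
#   streamlit run app.py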