File size: 6,354 Bytes
2a6784d
 
caa4c85
b3635dd
 
 
 
853df82
caa4c85
2a6784d
b3635dd
 
 
935113b
b3635dd
935113b
b3635dd
935113b
 
 
 
 
 
 
 
 
b3635dd
caa4c85
b3635dd
 
 
 
 
 
 
2a6784d
935113b
b3635dd
caa4c85
853df82
b3635dd
caa4c85
 
 
 
 
 
935113b
 
 
 
caa4c85
 
 
 
 
 
 
 
 
 
 
2a6784d
caa4c85
 
 
2a6784d
caa4c85
 
b3635dd
caa4c85
 
 
 
 
2a6784d
caa4c85
 
 
 
935113b
caa4c85
 
 
b3635dd
 
 
2a6784d
935113b
e0d61c7
935113b
e0d61c7
 
935113b
 
e0d61c7
 
935113b
 
e0d61c7
 
 
935113b
 
e0d61c7
 
 
 
b3635dd
 
 
2a6784d
b3635dd
2a6784d
b3635dd
e0d61c7
caa4c85
 
b3635dd
 
 
caa4c85
 
 
 
b3635dd
caa4c85
 
 
 
e0d61c7
 
935113b
 
 
 
e0d61c7
935113b
e0d61c7
935113b
e0d61c7
 
 
935113b
e0d61c7
 
 
 
935113b
 
 
 
 
 
 
e0d61c7
 
caa4c85
 
 
935113b
caa4c85
 
 
 
2a6784d
b3635dd
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import streamlit as st
from pyannote.audio import Pipeline
import whisper
import tempfile
import os
import torch
from transformers import pipeline as tf_pipeline
from pydub import AudioSegment
import io

@st.cache_resource
def _load_models_cached():
    """Load the diarization, transcription and summarization models once.

    Failures are raised rather than turned into a placeholder value:
    ``st.cache_resource`` only stores results of calls that return normally,
    so a transient problem (bad token, network error) is retried on the next
    rerun instead of being cached for the lifetime of the server.

    Returns:
        ``(diarization, transcriber, summarizer)``.

    Raises:
        RuntimeError: if the pyannote pipeline could not be downloaded.
        Exception: anything raised by the underlying model loaders.
    """
    use_cuda = torch.cuda.is_available()

    # Hub id of the pyannote 3.1 speaker-diarization pipeline.
    diarization = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1",
        use_auth_token=st.secrets["hf_token"],
    )
    if diarization is None:
        # from_pretrained can return None instead of raising when the gated
        # repository is not accessible (bad token, conditions not accepted).
        raise RuntimeError(
            "could not download pyannote/speaker-diarization-3.1; "
            "check the hf_token secret and the model's access conditions"
        )
    # For the 3.1 pipeline the segmentation step exposes min_duration_off
    # (onset/offset/min_duration_on belong to older pyannote VAD pipelines),
    # so tune that one; clustering keeps its pretrained defaults.
    diarization.instantiate({"segmentation": {"min_duration_off": 0.1}})
    if use_cuda:
        # pyannote pipelines stay on CPU unless moved explicitly; match the
        # summarizer's device choice below.
        diarization.to(torch.device("cuda"))

    transcriber = whisper.load_model("base")

    summarizer = tf_pipeline(
        "summarization",
        model="facebook/bart-large-cnn",
        device=0 if use_cuda else -1,
    )
    return diarization, transcriber, summarizer


def load_models():
    """Return ``(diarization, transcriber, summarizer)``, or three ``None``s.

    On failure the error is shown in the app with ``st.error`` and
    ``(None, None, None)`` is returned, so callers can test with ``all(...)``.
    Successful loads are cached; failed ones are retried on the next call.
    """
    try:
        return _load_models_cached()
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        return None, None, None

def _convert_to_wav(audio_file, wav_path):
    """Decode an uploaded MP3/WAV and write it to *wav_path* as 16 kHz mono 16-bit WAV.

    The decoder is chosen by extension: a name ending in ``.mp3``
    (case-insensitive) is decoded as MP3, anything else as WAV.

    Returns:
        Duration of the decoded audio in seconds.
    """
    audio_bytes = io.BytesIO(audio_file.getvalue())
    if audio_file.name.lower().endswith('.mp3'):
        audio = AudioSegment.from_mp3(audio_bytes)
    else:
        audio = AudioSegment.from_wav(audio_bytes)

    # Standardize format: 16 kHz, mono, 16-bit samples
    audio = audio.set_frame_rate(16000).set_channels(1).set_sample_width(2)
    audio.export(
        wav_path,
        format="wav",
        parameters=["-ac", "1", "-ar", "16000"],
    )
    return audio.duration_seconds


def process_audio(audio_file, max_duration=600):
    """Run diarization, transcription and summarization on an uploaded file.

    Args:
        audio_file: uploaded file object exposing ``name`` and ``getvalue()``
            (e.g. a Streamlit ``UploadedFile``).
        max_duration: longest accepted audio, in seconds; ``None`` disables
            the check.

    Returns:
        dict with keys ``"diarization"`` (pipeline output), ``"transcription"``
        (the full Whisper result dict) and ``"summary"`` (str), or ``None`` on
        any failure -- the reason is reported to the user with ``st.error``.
    """
    tmp_path = None
    try:
        # Reserve a unique path, then close the handle before writing so the
        # exporter can open the file itself (an open handle blocks this on Windows).
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp_path = tmp.name

        try:
            duration = _convert_to_wav(audio_file, tmp_path)
        except Exception as e:
            st.error(f"Error converting audio: {str(e)}")
            return None

        if max_duration is not None and duration > max_duration:
            st.error(
                f"Audio is {duration:.0f} s long; the maximum is {max_duration} s"
            )
            return None

        diarization, transcriber, summarizer = load_models()
        if not all([diarization, transcriber, summarizer]):
            # Return None, not a message string: callers test truthiness and
            # would otherwise treat the message as a results dict.
            st.error("Model loading failed")
            return None

        with st.spinner("Identifying speakers..."):
            diarization_result = diarization(tmp_path)

        with st.spinner("Transcribing audio..."):
            transcription = transcriber.transcribe(tmp_path)

        text = transcription["text"]
        if text.strip():
            with st.spinner("Generating summary..."):
                # BART accepts at most 1024 input tokens; truncate longer
                # transcripts instead of failing on them.
                summary = summarizer(
                    text, max_length=130, min_length=30, truncation=True
                )
            summary_text = summary[0]["summary_text"]
        else:
            summary_text = ""  # nothing was transcribed, nothing to summarize

        return {
            "diarization": diarization_result,
            "transcription": transcription,  # Return full transcription object
            "summary": summary_text,
        }

    except Exception as e:
        st.error(f"Error processing audio: {str(e)}")
        return None
    finally:
        # Always remove the temporary WAV, including on every error path.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)

def format_speaker_segments(diarization_result, transcription, min_duration=0.1):
    """Flatten a diarization result into a list of speaker-turn dicts.

    Args:
        diarization_result: object with ``itertracks(yield_label=True)``
            yielding ``(turn, track, speaker)`` triples, where ``turn`` has
            ``start``/``end`` attributes in seconds.
        transcription: transcription dict. If it has a ``"duration"`` entry
            (seconds), turns that start or end past it are dropped; if not,
            no upper bound is applied.
        min_duration: shortest turn length, in seconds, to keep
            (default 0.1 s = 100 ms).

    Returns:
        List of dicts with keys ``speaker``, ``start``, ``end``, ``duration``,
        in the order yielded by ``itertracks``.
    """
    # Whisper's transcribe() result has no "duration" key; treating a missing
    # value as 0 would discard every turn, so filter only when it is known.
    audio_duration = transcription.get("duration")

    segments = []
    for turn, _, speaker in diarization_result.itertracks(yield_label=True):
        # Skip turns that fall outside the known audio length
        if audio_duration is not None and (
            turn.start > audio_duration or turn.end > audio_duration
        ):
            continue

        length = turn.end - turn.start
        if length >= min_duration:
            segments.append({
                'speaker': speaker,
                'start': turn.start,
                'end': turn.end,
                'duration': length,
            })

    return segments

def _format_timestamp(seconds):
    """Format *seconds* as ``MM:SS.ss``.

    Rounds to centiseconds before splitting into minutes, so 59.996 s renders
    as ``01:00.00`` instead of ``00:60.00``.
    """
    centis = int(round(seconds * 100))
    minutes, centis = divmod(centis, 6000)
    return f"{minutes:02d}:{centis / 100:05.2f}"


def main():
    """Render the app: upload a file, analyze it, and show the results in tabs."""
    st.title("Multi-Speaker Audio Analyzer")
    st.write("Upload an audio file (MP3/WAV) up to 5 minutes long for best performance")

    uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"])
    if not uploaded_file:
        return

    file_size = len(uploaded_file.getvalue()) / (1024 * 1024)
    st.write(f"File size: {file_size:.2f} MB")

    # Use the uploaded file's MIME type so MP3s are not labelled as WAV.
    st.audio(uploaded_file, format=uploaded_file.type or 'audio/wav')

    if not st.button("Analyze Audio"):
        return
    if file_size > 200:
        st.error("File size exceeds 200MB limit")
        return

    results = process_audio(uploaded_file)
    if isinstance(results, str):
        # Guard against a failure message string in place of a results dict.
        st.error(results)
        return
    if not results:
        return

    tab1, tab2, tab3 = st.tabs(["Speakers", "Transcription", "Summary"])

    with tab1:
        st.write("Speaker Timeline:")

        segments = format_speaker_segments(
            results["diarization"],
            results["transcription"]
        )
        if not segments:
            st.write("No speaker segments detected.")

        colors = ['🔵', '🔴']  # Simplified to two colors
        # Index speakers by sorted label instead of parsing "SPEAKER_NN", so
        # any label format works and distinct speakers alternate colors.
        speaker_index = {
            label: i
            for i, label in enumerate(sorted({seg['speaker'] for seg in segments}))
        }

        for segment in segments:
            col1, col2 = st.columns([2, 8])

            with col1:
                speaker_color = colors[speaker_index[segment['speaker']] % len(colors)]
                st.write(f"{speaker_color} {segment['speaker']}")

            with col2:
                start = _format_timestamp(segment['start'])
                end = _format_timestamp(segment['end'])
                st.write(f"{start} → {end}")

            st.markdown("---")

    with tab2:
        st.write("Transcription:")
        st.write(results["transcription"]["text"])

    with tab3:
        st.write("Summary:")
        st.write(results["summary"])


if __name__ == "__main__":
    main()