File size: 7,301 Bytes
2a6784d
 
caa4c85
b3635dd
 
 
 
853df82
caa4c85
2a6784d
b3635dd
 
 
 
 
 
 
caa4c85
b3635dd
caa4c85
b3635dd
 
 
 
 
 
 
2a6784d
f1a85dc
b3635dd
caa4c85
 
853df82
b3635dd
caa4c85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a6784d
caa4c85
 
 
 
2a6784d
caa4c85
 
 
b3635dd
caa4c85
 
 
 
 
2a6784d
caa4c85
 
 
 
 
 
 
 
 
b3635dd
 
 
2a6784d
e0d61c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3635dd
 
 
2a6784d
b3635dd
2a6784d
b3635dd
caa4c85
e0d61c7
caa4c85
 
 
b3635dd
 
 
caa4c85
 
 
 
b3635dd
caa4c85
 
 
 
e0d61c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
caa4c85
e0d61c7
caa4c85
 
 
 
 
 
 
2a6784d
b3635dd
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import streamlit as st
from pyannote.audio import Pipeline
import whisper
import tempfile
import os
import torch
from transformers import pipeline as tf_pipeline
from pydub import AudioSegment
import io

@st.cache_resource
def _load_models_cached():
    """Load the diarization, transcription and summarization models.

    Any exception propagates to the caller. st.cache_resource only stores
    successful return values, so raising here means a failed load is retried
    on the next call instead of being cached for the whole server lifetime.
    """
    diarization = Pipeline.from_pretrained(
        "pyannote/speaker-diarization",
        use_auth_token=st.secrets["hf_token"],
    )
    # "base" rather than "turbo": chosen previously as the more stable option.
    transcriber = whisper.load_model("base")
    summarizer = tf_pipeline(
        "summarization",
        model="facebook/bart-large-cnn",
        device=0 if torch.cuda.is_available() else -1,  # GPU 0 if present, else CPU
    )
    return diarization, transcriber, summarizer


def load_models():
    """Return the cached ``(diarization, transcriber, summarizer)`` models.

    Returns:
        A 3-tuple of loaded models on success. On failure, the error is shown
        in the UI via ``st.error`` and ``(None, None, None)`` is returned. The
        failure is not cached, so a later call tries loading again.
    """
    try:
        return _load_models_cached()
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        return None, None, None

def _convert_to_wav(audio_file, dest_path, max_duration=None):
    """Decode *audio_file* and write it to *dest_path* as 16 kHz mono 16-bit WAV.

    Files whose name ends in ``.mp3`` are decoded as MP3, everything else as
    WAV. Audio longer than *max_duration* seconds is truncated, with a UI
    warning; ``None`` disables the limit. Decode/encode errors propagate.
    """
    audio_bytes = io.BytesIO(audio_file.getvalue())
    if audio_file.name.lower().endswith('.mp3'):
        audio = AudioSegment.from_mp3(audio_bytes)
    else:
        audio = AudioSegment.from_wav(audio_bytes)

    # pydub lengths and slice indices are in milliseconds.
    if max_duration is not None and len(audio) > max_duration * 1000:
        st.warning(
            f"Audio is longer than {max_duration} seconds; "
            f"only the first {max_duration} seconds will be analyzed."
        )
        audio = audio[: int(max_duration * 1000)]

    # Standardize: 16 kHz sample rate, mono, 16-bit samples.
    audio = audio.set_frame_rate(16000).set_channels(1).set_sample_width(2)
    audio.export(
        dest_path,
        format="wav",
        parameters=["-ac", "1", "-ar", "16000"],
    )


def process_audio(audio_file, max_duration=600):  # 600 s = 10 minutes
    """Convert, diarize, transcribe and summarize an uploaded audio file.

    Args:
        audio_file: Uploaded file object exposing ``.getvalue()`` (bytes) and
            ``.name``; ``.mp3`` names are decoded as MP3, others as WAV.
        max_duration: Maximum number of seconds to analyze; longer audio is
            truncated. ``None`` disables the limit.

    Returns:
        On success, a dict with keys "diarization" (the diarization pipeline
        result), "transcription" (str) and "summary" (str; empty when nothing
        was transcribed). On any failure, the error is reported via
        ``st.error`` and ``None`` is returned.
    """
    tmp_path = None
    try:
        # Reserve a temp path and close the handle right away. ffmpeg, pyannote
        # and whisper reopen the file by name, which Windows forbids while the
        # handle is still open. Cleanup happens in the `finally` below.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp_path = tmp.name

        try:
            _convert_to_wav(audio_file, tmp_path, max_duration)
        except Exception as e:
            st.error(f"Error converting audio: {str(e)}")
            return None

        diarization, transcriber, summarizer = load_models()
        if not all([diarization, transcriber, summarizer]):
            # Return None (not a truthy message string) so callers that test
            # `if results:` do not try to index into an error string.
            st.error("Model loading failed")
            return None

        with st.spinner("Identifying speakers..."):
            diarization_result = diarization(tmp_path)

        with st.spinner("Transcribing audio..."):
            transcription = transcriber.transcribe(tmp_path)
        text = transcription["text"]

        if text.strip():
            with st.spinner("Generating summary..."):
                # BART has a bounded input length (1024 tokens for
                # bart-large-cnn). Truncate long transcripts instead of failing.
                summary = summarizer(
                    text, max_length=130, min_length=30, truncation=True
                )[0]["summary_text"]
        else:
            summary = ""  # nothing was transcribed, so there is nothing to summarize

        return {
            "diarization": diarization_result,
            "transcription": text,
            "summary": summary,
        }

    except Exception as e:
        st.error(f"Error processing audio: {str(e)}")
        return None
    finally:
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass  # best-effort cleanup; the file may already be gone

def format_speaker_segments(diarization_result, min_duration=0.3):
    """Turn a diarization result into a list of merged speaker segments.

    Turns shorter than *min_duration* seconds are dropped. Consecutive kept
    turns by the same speaker are merged into a single segment.

    Args:
        diarization_result: Object whose ``itertracks(yield_label=True)``
            yields ``(turn, track, label)`` tuples, where ``turn`` has
            ``start``/``end`` times in seconds.
        min_duration: Minimum turn length, in seconds, to keep. Defaults to 0.3.

    Returns:
        A list of ``{'speaker', 'start', 'end'}`` dicts in iteration order.
    """
    formatted_segments = []

    for turn, _, speaker in diarization_result.itertracks(yield_label=True):
        if turn.end - turn.start < min_duration:
            continue  # drop very short turns (likely noise)

        if formatted_segments and formatted_segments[-1]['speaker'] == speaker:
            # Same speaker as the previous kept segment: extend it. max() keeps
            # the end from moving backwards when turns overlap or are nested.
            last = formatted_segments[-1]
            last['end'] = max(last['end'], turn.end)
        else:
            formatted_segments.append({
                'speaker': speaker,
                'start': turn.start,
                'end': turn.end,
            })

    return formatted_segments

# Emoji markers cycled through to give each speaker a distinct colour.
_SPEAKER_COLORS = ['🔵', '🔴', '🟢', '🟡', '🟣']


def _format_timestamp(seconds):
    """Format a time offset in seconds as ``MM:SS.s`` (e.g. 75.0 -> "01:15.0")."""
    # Round first so that e.g. 59.96 s renders as "01:00.0" rather than "00:60.0".
    minutes, secs = divmod(round(seconds, 1), 60)
    return f"{int(minutes):02d}:{secs:04.1f}"


def _speaker_color_map(segments):
    """Map each distinct speaker label in *segments* to an emoji colour.

    Labels are sorted so the mapping is deterministic. Colours repeat once
    there are more speakers than entries in ``_SPEAKER_COLORS``.
    """
    speakers = sorted({segment['speaker'] for segment in segments})
    return {
        speaker: _SPEAKER_COLORS[i % len(_SPEAKER_COLORS)]
        for i, speaker in enumerate(speakers)
    }


def main():
    """Render the Streamlit UI: upload, play back, and analyze an audio file."""
    st.title("Multi-Speaker Audio Analyzer")
    st.write("Upload an audio file (MP3/WAV) up to 5 minutes long for best performance")

    uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"])
    if not uploaded_file:
        return

    file_size = len(uploaded_file.getvalue()) / (1024 * 1024)
    st.write(f"File size: {file_size:.2f} MB")

    # Give the player the real MIME type; an MP3 is not audio/wav.
    is_mp3 = uploaded_file.name.lower().endswith('.mp3')
    st.audio(uploaded_file, format='audio/mpeg' if is_mp3 else 'audio/wav')

    if not st.button("Analyze Audio"):
        return
    if file_size > 200:
        st.error("File size exceeds 200MB limit")
        return

    results = process_audio(uploaded_file)
    if not results:
        return  # process_audio has already reported the error

    tab1, tab2, tab3 = st.tabs(["Speakers", "Transcription", "Summary"])

    with tab1:
        st.write("Speaker Timeline:")

        segments = format_speaker_segments(results["diarization"])
        # One mapping, shared by the timeline and the legend so they agree.
        # Built from the actual labels, so non-contiguous or non-numeric
        # labels and more than five speakers are all handled.
        colors = _speaker_color_map(segments)

        for segment in segments:
            col1, col2, col3 = st.columns([2, 1, 6])
            with col1:
                st.write(f"{colors[segment['speaker']]} {segment['speaker']}")
            with col2:
                st.write(f"{_format_timestamp(segment['start'])} →")
            with col3:
                st.write(_format_timestamp(segment['end']))
            st.markdown("---")

        st.write("\nSpeaker Legend:")
        for speaker, color in colors.items():
            st.write(f"{color} {speaker}")

    with tab2:
        st.write("Transcription:")
        st.write(results["transcription"])

    with tab3:
        st.write("Summary:")
        st.write(results["summary"])


if __name__ == "__main__":
    main()