Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -10,186 +10,185 @@ import io | |
| 10 |  | 
| 11 | 
             
            @st.cache_resource
         | 
| 12 | 
             
            def load_models():
         | 
| 13 | 
            -
             | 
| 14 | 
            -
             | 
| 15 | 
            -
             | 
| 16 | 
            -
             | 
| 17 | 
            -
             | 
| 18 | 
            -
             | 
| 19 | 
            -
             | 
| 20 | 
            -
             | 
| 21 | 
            -
             | 
| 22 | 
            -
             | 
| 23 | 
            -
             | 
| 24 | 
            -
             | 
| 25 | 
            -
             | 
| 26 | 
            -
             | 
| 27 | 
            -
             | 
| 28 | 
            -
             | 
| 29 | 
            -
             | 
| 30 | 
            -
             | 
| 31 | 
            -
             | 
| 32 | 
            -
             | 
| 33 | 
            -
             | 
| 34 | 
            -
             | 
| 35 |  | 
| 36 | 
             
            def process_audio(audio_file, max_duration=600):
         | 
| 37 | 
            -
             | 
| 38 | 
            -
             | 
| 39 | 
            -
             | 
| 40 | 
            -
             | 
| 41 | 
            -
             | 
| 42 | 
            -
             | 
| 43 | 
            -
             | 
| 44 | 
            -
             | 
| 45 | 
            -
             | 
| 46 | 
            -
             | 
| 47 | 
            -
             | 
| 48 | 
            -
             | 
| 49 | 
            -
             | 
| 50 | 
            -
             | 
| 51 | 
            -
             | 
| 52 | 
            -
             | 
| 53 | 
            -
             | 
| 54 | 
            -
             | 
| 55 | 
            -
             | 
| 56 | 
            -
             | 
| 57 | 
            -
             | 
| 58 | 
            -
             | 
| 59 | 
            -
             | 
| 60 | 
            -
             | 
|  | |
| 61 |  | 
| 62 | 
            -
             | 
| 63 | 
            -
             | 
| 64 | 
            -
             | 
| 65 |  | 
| 66 | 
            -
             | 
| 67 | 
            -
             | 
| 68 | 
            -
             | 
| 69 | 
            -
             | 
| 70 | 
            -
             | 
| 71 | 
            -
             | 
| 72 | 
            -
             | 
| 73 | 
            -
             | 
| 74 |  | 
| 75 | 
            -
             | 
| 76 | 
            -
             | 
| 77 | 
            -
             | 
| 78 | 
            -
             | 
| 79 | 
            -
             | 
| 80 | 
            -
             | 
| 81 | 
            -
             | 
| 82 | 
            -
             | 
| 83 | 
            -
             | 
| 84 | 
            -
             | 
| 85 | 
            -
             | 
| 86 |  | 
| 87 | 
             
            def format_speaker_segments(diarization_result, transcription):
         | 
| 88 | 
            -
             | 
| 89 | 
            -
             | 
| 90 | 
            -
             | 
| 91 | 
            -
             | 
| 92 | 
            -
             | 
| 93 | 
            -
             | 
| 94 | 
            -
             | 
| 95 | 
            -
             | 
| 96 | 
            -
             | 
| 97 | 
            -
             | 
| 98 | 
            -
             | 
| 99 | 
            -
             | 
| 100 | 
            -
             | 
| 101 | 
            -
             | 
| 102 | 
            -
             | 
| 103 | 
            -
             | 
| 104 | 
            -
             | 
| 105 | 
            -
             | 
| 106 | 
            -
             | 
| 107 | 
            -
             | 
| 108 | 
            -
             | 
| 109 | 
            -
             | 
| 110 | 
            -
             | 
| 111 | 
            -
             | 
| 112 | 
            -
             | 
| 113 | 
            -
             | 
| 114 | 
            -
             | 
| 115 | 
            -
             | 
| 116 | 
            -
             | 
| 117 | 
            -
             | 
| 118 | 
            -
             | 
| 119 | 
            -
               cleaned_segments = []
         | 
| 120 | 
            -
               for i, segment in enumerate(formatted_segments):
         | 
| 121 | 
            -
                   # Skip if this segment overlaps with previous one
         | 
| 122 | 
            -
                   if i > 0 and segment['start'] < cleaned_segments[-1]['end']:
         | 
| 123 | 
            -
                       continue
         | 
| 124 | 
            -
                   cleaned_segments.append(segment)
         | 
| 125 | 
            -
               
         | 
| 126 | 
            -
               return cleaned_segments
         | 
| 127 |  | 
| 128 | 
             
            def format_timestamp(seconds):
         | 
| 129 | 
            -
             | 
| 130 | 
            -
             | 
| 131 | 
            -
             | 
| 132 |  | 
| 133 | 
             
            def main():
         | 
| 134 | 
            -
             | 
| 135 | 
            -
             | 
| 136 |  | 
| 137 | 
            -
             | 
| 138 |  | 
| 139 | 
            -
             | 
| 140 | 
            -
             | 
| 141 | 
            -
             | 
| 142 | 
            -
             | 
| 143 | 
            -
             | 
| 144 | 
            -
             | 
| 145 | 
            -
             | 
| 146 | 
            -
             | 
| 147 | 
            -
             | 
| 148 | 
            -
             | 
| 149 | 
            -
             | 
| 150 | 
            -
             | 
| 151 | 
            -
             | 
| 152 | 
            -
             | 
| 153 | 
            -
             | 
| 154 | 
            -
             | 
| 155 | 
            -
             | 
| 156 | 
            -
             | 
| 157 | 
            -
             | 
| 158 | 
            -
             | 
| 159 | 
            -
             | 
| 160 | 
            -
             | 
| 161 | 
            -
             | 
| 162 | 
            -
             | 
| 163 | 
            -
             | 
| 164 | 
            -
             | 
| 165 | 
            -
             | 
| 166 | 
            -
             | 
| 167 | 
            -
             | 
| 168 | 
            -
             | 
| 169 | 
            -
             | 
| 170 | 
            -
             | 
| 171 | 
            -
             | 
| 172 | 
            -
             | 
| 173 | 
            -
             | 
| 174 | 
            -
             | 
| 175 | 
            -
             | 
| 176 | 
            -
             | 
| 177 | 
            -
             | 
| 178 | 
            -
             | 
| 179 | 
            -
             | 
| 180 | 
            -
             | 
| 181 | 
            -
             | 
| 182 | 
            -
             | 
| 183 | 
            -
             | 
| 184 | 
            -
             | 
| 185 | 
            -
             | 
| 186 | 
            -
             | 
| 187 | 
            -
             | 
| 188 | 
            -
             | 
| 189 | 
            -
             | 
| 190 | 
            -
             | 
| 191 | 
            -
             | 
| 192 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 193 |  | 
| 194 | 
             
            if __name__ == "__main__":
         | 
| 195 | 
            -
             | 
|  | |
| 10 |  | 
| 11 | 
             
            @st.cache_resource
            def _load_models_cached():
                """Load the diarization, transcription and summarization models.

                This helper raises on failure instead of returning a failure marker.
                Streamlit's resource cache stores return values, so returning
                ``(None, None, None)`` from the cached function would pin a transient
                failure (for example a network error) until the cache is cleared.

                Returns:
                    A ``(diarization, transcriber, summarizer)`` tuple.

                Raises:
                    ValueError: If any loader hands back ``None``.
                    Exception: Whatever the underlying loaders raise.
                """
                diarization = Pipeline.from_pretrained(
                    "pyannote/speaker-diarization",
                    use_auth_token=st.secrets["hf_token"],
                )

                transcriber = whisper.load_model("base")

                summarizer = tf_pipeline(
                    "summarization",
                    model="facebook/bart-large-cnn",
                    # Use the first GPU when available, otherwise run on CPU.
                    device=0 if torch.cuda.is_available() else -1,
                )

                # Compare against None explicitly: model objects are not guaranteed to
                # define a meaningful truth value.
                if diarization is None or transcriber is None or summarizer is None:
                    raise ValueError("One or more models failed to load")

                return diarization, transcriber, summarizer


            def load_models():
                """Return the cached models, reporting any loading failure in the UI.

                Returns:
                    A ``(diarization, transcriber, summarizer)`` tuple on success, or
                    ``(None, None, None)`` when loading failed. Failures are shown with
                    ``st.error`` and are not cached, so the next call retries.
                """
                try:
                    return _load_models_cached()
                except Exception as e:
                    st.error(f"Error loading models: {str(e)}")
                    st.error("Debug info: Check if HF token is valid and has necessary permissions")
                    return None, None, None
         | 
| 35 |  | 
| 36 | 
             
            def process_audio(audio_file, max_duration=600):
                """Convert an uploaded audio file and run diarization, transcription and summary.

                Args:
                    audio_file: Uploaded file object exposing ``getvalue()`` and ``name``.
                    max_duration: Kept for interface compatibility; this function does
                        not currently enforce it.

                Returns:
                    A dict with the keys ``"diarization"``, ``"transcription"`` and
                    ``"summary"`` on success, or ``None`` on any failure. Failures are
                    reported to the user through ``st.error``.
                """
                tmp_path = None
                try:
                    audio_bytes = io.BytesIO(audio_file.getvalue())

                    try:
                        if audio_file.name.lower().endswith('.mp3'):
                            audio = AudioSegment.from_mp3(audio_bytes)
                        else:
                            audio = AudioSegment.from_wav(audio_bytes)

                        # Standardize format
                        audio = audio.set_frame_rate(16000)
                        audio = audio.set_channels(1)
                        audio = audio.set_sample_width(2)

                        # Reserve a path and close the handle before exporting, so the
                        # export does not write to a file this process still holds open.
                        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                            tmp_path = tmp.name

                        audio.export(
                            tmp_path,
                            format="wav",
                            parameters=["-ac", "1", "-ar", "16000"],
                        )
                    except Exception as e:
                        st.error(f"Error converting audio: {str(e)}")
                        return None

                    diarization, transcriber, summarizer = load_models()
                    if diarization is None or transcriber is None or summarizer is None:
                        # Report and return None rather than a truthy string, which the
                        # caller would try to index like a result dict.
                        st.error("Model loading failed")
                        return None

                    with st.spinner("Identifying speakers..."):
                        diarization_result = diarization(tmp_path)

                    with st.spinner("Transcribing audio..."):
                        transcription = transcriber.transcribe(tmp_path)

                    with st.spinner("Generating summary..."):
                        text = (transcription.get("text") or "").strip()
                        if text:
                            # truncation=True keeps long transcripts within the
                            # summarization model's input limit instead of failing.
                            summary = summarizer(
                                text,
                                max_length=130,
                                min_length=30,
                                truncation=True,
                            )
                            summary_text = summary[0]["summary_text"]
                        else:
                            # Nothing was transcribed, so there is nothing to summarize.
                            summary_text = ""

                    return {
                        "diarization": diarization_result,
                        "transcription": transcription,
                        "summary": summary_text,
                    }

                except Exception as e:
                    st.error(f"Error processing audio: {str(e)}")
                    return None
                finally:
                    # Always remove the temporary WAV, including on errors and early
                    # returns; previously it was only removed on the success path.
                    if tmp_path is not None:
                        try:
                            os.unlink(tmp_path)
                        except OSError:
                            # Best-effort cleanup: the file may already be gone.
                            pass
         | 
| 87 |  | 
| 88 | 
             
            def format_speaker_segments(diarization_result, transcription):
                """Pair each diarized speaker turn with the transcript text spoken during it.

                Args:
                    diarization_result: Object whose ``itertracks(yield_label=True)``
                        yields ``(turn, track, speaker)`` tuples, where ``turn`` has
                        ``start`` and ``end`` attributes. ``None`` yields ``[]``.
                    transcription: Mapping with an optional ``'segments'`` list; each
                        entry has ``'start'``, ``'end'`` and ``'text'``. ``None`` is
                        treated as having no segments.

                Returns:
                    A list of dicts with the keys ``'speaker'``, ``'start'``, ``'end'``
                    and ``'text'``, one per speaker turn in iteration order. Returns
                    ``[]`` if formatting fails; the error is shown with ``st.error``.
                """
                if diarization_result is None:
                    return []

                whisper_segments = (transcription or {}).get('segments', [])
                formatted_segments = []

                try:
                    for turn, _, speaker in diarization_result.itertracks(yield_label=True):
                        turn_start = float(turn.start)
                        turn_end = float(turn.end)

                        # Two intervals overlap when each starts before the other ends.
                        # This also catches a transcript segment that fully contains a
                        # short speaker turn, which a start-or-end-inside test misses.
                        pieces = [
                            w_segment['text'].strip()
                            for w_segment in whisper_segments
                            if float(w_segment['start']) < turn_end
                            and float(w_segment['end']) > turn_start
                        ]

                        formatted_segments.append({
                            'speaker': str(speaker),
                            'start': turn_start,
                            'end': turn_end,
                            'text': " ".join(pieces).strip(),
                        })

                except Exception as e:
                    st.error(f"Error formatting segments: {str(e)}")
                    return []

                return formatted_segments
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 120 |  | 
| 121 | 
             
            def format_timestamp(seconds):
                """Format a duration in seconds as ``MM:SS.ss``.

                Args:
                    seconds: Non-negative number of seconds (int or float).

                Returns:
                    A string such as ``"01:05.50"``. Minutes are not wrapped into
                    hours, so 3600 seconds formats as ``"60:00.00"``.
                """
                minutes, remainder = divmod(seconds, 60)
                minutes = int(minutes)

                # Round before formatting so a value like 59.999 carries into the
                # minutes field instead of rendering as "00:60.00".
                remainder = round(remainder, 2)
                if remainder >= 60:
                    minutes += 1
                    remainder -= 60

                return f"{minutes:02d}:{remainder:05.2f}"
         | 
| 125 |  | 
| 126 | 
             
            def _speaker_marker(label):
                """Return the colored marker used for a diarization speaker label.

                Labels of the form ``PREFIX_<number>`` alternate between the available
                colors; any other label falls back to the first color instead of
                raising.
                """
                colors = ['🔵', '🔴']
                try:
                    speaker_num = int(label.split('_')[1])
                except (IndexError, ValueError):
                    speaker_num = 0
                return colors[speaker_num % len(colors)]


            def _render_speaker_timeline(segments):
                """Render one row (speaker, time range, text) per speaker segment."""
                for segment in segments:
                    col1, col2, col3 = st.columns([2, 3, 5])

                    with col1:
                        st.write(f"{_speaker_marker(segment['speaker'])} {segment['speaker']}")

                    with col2:
                        start_time = format_timestamp(segment['start'])
                        end_time = format_timestamp(segment['end'])
                        st.write(f"{start_time} → {end_time}")

                    with col3:
                        if segment['text']:
                            st.write(f"\"{segment['text']}\"")
                        else:
                            st.write("(no speech detected)")

                    st.markdown("---")


            def _render_results(results):
                """Render the Speakers, Transcription and Summary tabs for a result dict."""
                tab1, tab2, tab3 = st.tabs(["Speakers", "Transcription", "Summary"])

                with tab1:
                    st.write("Speaker Timeline:")
                    segments = format_speaker_segments(
                        results["diarization"],
                        results["transcription"],
                    )
                    if segments:
                        _render_speaker_timeline(segments)
                    else:
                        st.warning("No speaker segments detected")

                with tab2:
                    st.write("Transcription:")
                    if "text" in results["transcription"]:
                        st.write(results["transcription"]["text"])
                    else:
                        st.warning("No transcription available")

                with tab3:
                    st.write("Summary:")
                    if results["summary"]:
                        st.write(results["summary"])
                    else:
                        st.warning("No summary available")


            def main():
                """Run the Streamlit page: upload a file, analyze it, show the results."""
                st.title("Multi-Speaker Audio Analyzer")
                st.write("Upload an audio file (MP3/WAV) up to 5 minutes long for best performance")

                uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"])
                if not uploaded_file:
                    return

                file_size = len(uploaded_file.getvalue()) / (1024 * 1024)
                st.write(f"File size: {file_size:.2f} MB")

                # Use the upload's own MIME type so MP3 files are not announced as WAV;
                # fall back to the previous hard-coded value when it is unavailable.
                st.audio(uploaded_file, format=getattr(uploaded_file, "type", None) or 'audio/wav')

                if not st.button("Analyze Audio"):
                    return

                if file_size > 200:
                    st.error("File size exceeds 200MB limit")
                    return

                results = process_audio(uploaded_file)

                if isinstance(results, str):
                    # process_audio may signal failure with a message string; show it
                    # instead of indexing it like a result dict.
                    st.error(results)
                    return

                if results:
                    _render_results(results)


            # Run the app only when this file is executed as the entry script.
            if __name__ == "__main__":
                main()
         | 
