File size: 4,676 Bytes
2a6784d
 
caa4c85
b3635dd
 
 
 
853df82
caa4c85
2a6784d
b3635dd
 
 
 
 
 
 
caa4c85
b3635dd
caa4c85
b3635dd
 
 
 
 
 
 
2a6784d
f1a85dc
b3635dd
caa4c85
 
853df82
b3635dd
caa4c85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a6784d
caa4c85
 
 
 
2a6784d
caa4c85
 
 
b3635dd
caa4c85
 
 
 
 
2a6784d
caa4c85
 
 
 
 
 
 
 
 
b3635dd
 
 
2a6784d
b3635dd
 
 
2a6784d
b3635dd
2a6784d
b3635dd
caa4c85
 
 
 
 
b3635dd
 
 
caa4c85
 
 
 
b3635dd
caa4c85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a6784d
b3635dd
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import streamlit as st
from pyannote.audio import Pipeline
import whisper
import tempfile
import os
import torch
from transformers import pipeline as tf_pipeline
from pydub import AudioSegment
import io

@st.cache_resource
def _load_models_cached():
    """Build the diarization, transcription and summarization models once.

    Failures are raised rather than returned: Streamlit only stores the
    result of a call that completes, so raising here means a failed load
    (network error, missing ``hf_token`` secret, denied model access) is
    retried on the next call instead of being memoized forever.

    Returns:
        tuple: ``(diarization, transcriber, summarizer)``.

    Raises:
        RuntimeError: if the pyannote pipeline could not be obtained.
        Exception: anything raised by the underlying model loaders.
    """
    diarization = Pipeline.from_pretrained(
        "pyannote/speaker-diarization",
        use_auth_token=st.secrets["hf_token"],
    )
    # NOTE(review): pyannote's from_pretrained may return None instead of
    # raising (e.g. gated model not accepted); treat that as a failure so the
    # None is not cached.
    if diarization is None:
        raise RuntimeError(
            "Could not load 'pyannote/speaker-diarization' "
            "(check the hf_token secret and model access)."
        )

    # "base" is used instead of "turbo" because it is more stable.
    transcriber = whisper.load_model("base")

    summarizer = tf_pipeline(
        "summarization",
        model="facebook/bart-large-cnn",
        # First CUDA device when available, otherwise CPU (-1).
        device=0 if torch.cuda.is_available() else -1,
    )
    return diarization, transcriber, summarizer


def load_models():
    """Return the cached ``(diarization, transcriber, summarizer)`` models.

    On any loading error the message is shown in the UI with ``st.error``
    and ``(None, None, None)`` is returned; the failure is not cached, so a
    later call tries again.
    """
    try:
        return _load_models_cached()
    except Exception as e:  # UI boundary: report the error instead of crashing the page
        st.error(f"Error loading models: {str(e)}")
        return None, None, None

def _convert_to_standard_wav(audio_file, dest_path, max_duration=None):
    """Decode an uploaded MP3/WAV file and write it to *dest_path* as 16 kHz mono 16-bit WAV.

    Args:
        audio_file: the Streamlit upload; its ``name`` extension selects the decoder.
        dest_path: path of the WAV file to write.
        max_duration: optional limit in seconds; longer audio is truncated to
            this length and a warning is shown. ``None`` disables the limit.
    """
    audio_bytes = io.BytesIO(audio_file.getvalue())
    if audio_file.name.lower().endswith('.mp3'):
        audio = AudioSegment.from_mp3(audio_bytes)
    else:
        audio = AudioSegment.from_wav(audio_bytes)

    # pydub lengths and slice indices are in milliseconds.
    if max_duration is not None and len(audio) > max_duration * 1000:
        st.warning(
            f"Audio is {len(audio) / 1000:.0f}s long; only the first "
            f"{max_duration}s will be analyzed."
        )
        audio = audio[: int(max_duration * 1000)]

    # Standardize: 16 kHz sample rate, mono, 16-bit samples.
    audio = audio.set_frame_rate(16000).set_channels(1).set_sample_width(2)

    # export() returns the open output handle; close it right away so the
    # models can read the file and it can be deleted (required on Windows).
    audio.export(
        dest_path,
        format="wav",
        parameters=["-ac", "1", "-ar", "16000"],
    ).close()


def process_audio(audio_file, max_duration=600):  # limit to 10 minutes (600 s); None disables
    """Run speaker diarization, transcription and summarization on an upload.

    Args:
        audio_file: Streamlit ``UploadedFile`` holding MP3 or WAV data.
        max_duration: maximum number of seconds to analyze; longer audio is
            truncated with a warning. ``None`` analyzes the whole file.

    Returns:
        dict with keys ``"diarization"`` (pyannote result), ``"transcription"``
        (str) and ``"summary"`` (str), or ``None`` on any failure. Failures
        are reported to the UI with ``st.error``.
    """
    tmp_path = None
    try:
        # mkstemp + immediate close: the path is then written by pydub and
        # read by the models with no other handle open on it.
        fd, tmp_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)

        try:
            _convert_to_standard_wav(audio_file, tmp_path, max_duration)
        except Exception as e:
            st.error(f"Error converting audio: {str(e)}")
            return None

        # Get cached models
        diarization, transcriber, summarizer = load_models()
        # Check for None explicitly; model objects are not guaranteed to be truthy.
        if any(model is None for model in (diarization, transcriber, summarizer)):
            # Previously a bare string was returned here, which callers
            # then tried to index like the results dict.
            st.error("Model loading failed")
            return None

        with st.spinner("Identifying speakers..."):
            diarization_result = diarization(tmp_path)

        with st.spinner("Transcribing audio..."):
            transcription = transcriber.transcribe(tmp_path)
        text = transcription["text"]

        with st.spinner("Generating summary..."):
            if text.strip():
                # truncation=True clips transcripts longer than the model's
                # maximum input length instead of failing.
                summary = summarizer(
                    text, max_length=130, min_length=30, truncation=True
                )[0]["summary_text"]
            else:
                # Nothing was transcribed (e.g. silence): nothing to summarize.
                summary = ""

        return {
            "diarization": diarization_result,
            "transcription": text,
            "summary": summary,
        }

    except Exception as e:
        st.error(f"Error processing audio: {str(e)}")
        return None
    finally:
        # Always remove the temp file, on success, early return, or error.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass  # best-effort cleanup; never mask the real result/error

def _render_results(results):
    """Show the speakers, transcription and summary from *results* in three tabs."""
    tab1, tab2, tab3 = st.tabs(["Speakers", "Transcription", "Summary"])

    with tab1:
        st.write("Speaker Segments:")
        for turn, _, speaker in results["diarization"].itertracks(yield_label=True):
            st.write(f"{speaker}: {turn.start:.1f}s → {turn.end:.1f}s")

    with tab2:
        st.write("Transcription:")
        st.write(results["transcription"])

    with tab3:
        st.write("Summary:")
        st.write(results["summary"])


def main():
    """Render the upload page and, on request, analyze the uploaded audio."""
    st.title("Multi-Speaker Audio Analyzer")
    st.write("Upload an audio file (MP3/WAV) up to 5 minutes long for best performance")

    uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"])
    if not uploaded_file:
        return

    # Display file info
    file_size = uploaded_file.size / (1024 * 1024)  # bytes -> MB
    st.write(f"File size: {file_size:.2f} MB")

    # Give the player the MIME type that matches the upload; MP3s were
    # previously announced as WAV.
    audio_mime = "audio/mpeg" if uploaded_file.name.lower().endswith(".mp3") else "audio/wav"
    st.audio(uploaded_file, format=audio_mime)

    if not st.button("Analyze Audio"):
        return
    if file_size > 200:
        st.error("File size exceeds 200MB limit")
        return

    results = process_audio(uploaded_file)
    # Only a results dict can be rendered. Any other value means processing
    # failed; the failure is reported via st.error on the processing path.
    if isinstance(results, dict):
        _render_results(results)


if __name__ == "__main__":
    main()