Manyue-DataScientist commited on
Commit
0d5109a
·
verified ·
1 Parent(s): 08d05f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +152 -115
app.py CHANGED
@@ -7,126 +7,163 @@ Date: January 2025
7
  """
8
 
9
  import streamlit as st
10
- from pyannote.audio import Pipeline
11
- import whisper
12
- import tempfile
 
 
13
  import os
14
- import torch
15
- from transformers import pipeline as tf_pipeline, BartTokenizer
16
- from pydub import AudioSegment
17
- import io
18
- import pickle
19
-
20
- class SpeakerDiarizer:
21
- def __init__(self, token):
22
- self.pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=token)
23
-
24
- def process(self, audio_file):
25
- return self.pipeline(audio_file)
26
-
27
- class Transcriber:
28
- def __init__(self):
29
- self.model = whisper.load_model("base")
30
-
31
- def process(self, audio_file):
32
- return self.model.transcribe(audio_file)["text"]
33
-
34
- class Summarizer:
35
- def __init__(self, model_path='bart_ami_finetuned.pkl'):
36
- self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
37
- with open(model_path, 'rb') as f:
38
- self.model = pickle.load(f)
39
-
40
- def process(self, text):
41
- inputs = self.tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)
42
- summary_ids = self.model.generate(inputs["input_ids"], max_length=150, min_length=40)
43
- return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
44
 
 
45
  @st.cache_resource
46
  def load_models():
47
- try:
48
- diarizer = SpeakerDiarizer(st.secrets["hf_token"])
49
- transcriber = Transcriber()
50
- summarizer = Summarizer()
51
- return diarizer, transcriber, summarizer
52
- except Exception as e:
53
- st.error(f"Error loading models: {str(e)}")
54
- return None, None, None
55
-
56
- def process_audio(audio_file):
57
- try:
58
- audio_bytes = io.BytesIO(audio_file.getvalue())
59
-
60
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
61
- if audio_file.name.lower().endswith('.mp3'):
62
- audio = AudioSegment.from_mp3(audio_bytes)
63
- else:
64
- audio = AudioSegment.from_wav(audio_bytes)
65
-
66
- audio = audio.set_frame_rate(16000).set_channels(1).set_sample_width(2)
67
- audio.export(tmp.name, format="wav", parameters=["-ac", "1", "-ar", "16000"])
68
- tmp_path = tmp.name
69
-
70
- diarizer, transcriber, summarizer = load_models()
71
- if not all([diarizer, transcriber, summarizer]):
72
- return "Model loading failed"
73
-
74
- with st.spinner("Processing..."):
75
- diarization = diarizer.process(tmp_path)
76
- transcription = transcriber.process(tmp_path)
77
- summary = summarizer.process(transcription)
78
-
79
- os.unlink(tmp_path)
80
-
81
- return {
82
- "diarization": diarization,
83
- "transcription": transcription,
84
- "summary": summary
85
- }
86
-
87
- except Exception as e:
88
- st.error(f"Error: {str(e)}")
89
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  def main():
92
- st.title("Multi-Speaker Audio Analyzer")
93
- st.write("Upload an audio file (MP3/WAV) up to 5 minutes long")
94
-
95
- uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"])
96
-
97
- if uploaded_file:
98
- file_size = len(uploaded_file.getvalue()) / (1024 * 1024)
99
- st.write(f"File size: {file_size:.2f} MB")
100
-
101
- st.audio(uploaded_file, format='audio/wav')
102
-
103
- if st.button("Analyze Audio"):
104
- if file_size > 200:
105
- st.error("File size exceeds 200MB limit")
106
- else:
107
- results = process_audio(uploaded_file)
108
-
109
- if results:
110
- tab1, tab2, tab3 = st.tabs(["Speakers", "Transcription", "Summary"])
111
-
112
- with tab1:
113
- st.write("Speaker Timeline:")
114
- for turn, _, speaker in results["diarization"].itertracks(yield_label=True):
115
- col1, col2, col3 = st.columns([2,3,5])
116
- with col1:
117
- speaker_num = int(speaker.split('_')[1])
118
- colors = ['🔵', '🔴']
119
- st.write(f"{colors[speaker_num % 2]} {speaker}")
120
- with col2:
121
- st.write(f"{format_timestamp(turn.start)} {format_timestamp(turn.end)}")
122
-
123
- with tab2:
124
- st.write("Transcription:")
125
- st.write(results["transcription"])
126
-
127
- with tab3:
128
- st.write("Summary:")
129
- st.write(results["summary"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  if __name__ == "__main__":
132
- main()
 
7
  """
8
 
9
  import streamlit as st
10
+ from src.models.diarization import SpeakerDiarizer
11
+ from src.models.transcription import Transcriber
12
+ from src.models.summarization import Summarizer
13
+ from src.utils.audio_processor import AudioProcessor
14
+ from src.utils.formatter import TimeFormatter
15
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ # Cache for model loading
18
  @st.cache_resource
19
  def load_models():
20
+ """
21
+ Load and cache all required models.
22
+
23
+ Returns:
24
+ tuple: (diarizer, transcriber, summarizer) or (None, None, None) if loading fails
25
+ """
26
+ try:
27
+ diarizer = SpeakerDiarizer(st.secrets["hf_token"])
28
+ diarizer_model = diarizer.load_model()
29
+
30
+ transcriber = Transcriber()
31
+ transcriber_model = transcriber.load_model()
32
+
33
+ summarizer = Summarizer()
34
+ summarizer_model = summarizer.load_model()
35
+
36
+ if not all([diarizer_model, transcriber_model, summarizer_model]):
37
+ raise ValueError("One or more models failed to load")
38
+
39
+ return diarizer, transcriber, summarizer
40
+ except Exception as e:
41
+ st.error(f"Error loading models: {str(e)}")
42
+ st.error("Debug info: Check if HF token is valid and has necessary permissions")
43
+ return None, None, None
44
+
45
+ def process_audio(audio_file, max_duration=600):
46
+ """
47
+ Process the uploaded audio file through all models.
48
+
49
+ Args:
50
+ audio_file: Uploaded audio file
51
+ max_duration (int): Maximum duration in seconds
52
+
53
+ Returns:
54
+ dict: Processing results containing diarization, transcription, and summary
55
+ """
56
+ try:
57
+ # Process audio file
58
+ audio_processor = AudioProcessor()
59
+ tmp_path = audio_processor.standardize_audio(audio_file)
60
+
61
+ # Load models
62
+ diarizer, transcriber, summarizer = load_models()
63
+ if not all([diarizer, transcriber, summarizer]):
64
+ return "Model loading failed"
65
+
66
+ # Process with each model
67
+ with st.spinner("Identifying speakers..."):
68
+ diarization_result = diarizer.process(tmp_path)
69
+
70
+ with st.spinner("Transcribing audio..."):
71
+ transcription = transcriber.process(tmp_path)
72
+
73
+ with st.spinner("Generating summary..."):
74
+ summary = summarizer.process(transcription["text"])
75
+
76
+ # Cleanup
77
+ os.unlink(tmp_path)
78
+
79
+ return {
80
+ "diarization": diarization_result,
81
+ "transcription": transcription,
82
+ "summary": summary[0]["summary_text"]
83
+ }
84
+
85
+ except Exception as e:
86
+ st.error(f"Error processing audio: {str(e)}")
87
+ return None
88
 
89
  def main():
90
+ """Main application function."""
91
+ st.title("Multi-Speaker Audio Analyzer")
92
+ st.write("Upload an audio file (MP3/WAV) up to 5 minutes long for best performance")
93
+
94
+ uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"])
95
+
96
+ if uploaded_file:
97
+ file_size = len(uploaded_file.getvalue()) / (1024 * 1024)
98
+ st.write(f"File size: {file_size:.2f} MB")
99
+
100
+ st.audio(uploaded_file, format='audio/wav')
101
+
102
+ if st.button("Analyze Audio"):
103
+ if file_size > 200:
104
+ st.error("File size exceeds 200MB limit")
105
+ else:
106
+ results = process_audio(uploaded_file)
107
+
108
+ if results:
109
+ tab1, tab2, tab3 = st.tabs(["Speakers", "Transcription", "Summary"])
110
+
111
+ # Display speaker timeline
112
+ with tab1:
113
+ display_speaker_timeline(results)
114
+
115
+ # Display transcription
116
+ with tab2:
117
+ display_transcription(results)
118
+
119
+ # Display summary
120
+ with tab3:
121
+ display_summary(results)
122
+
123
+ def display_speaker_timeline(results):
124
+ """Display speaker diarization results in a timeline format."""
125
+ st.write("Speaker Timeline:")
126
+ segments = TimeFormatter.format_speaker_segments(
127
+ results["diarization"],
128
+ results["transcription"]
129
+ )
130
+
131
+ if segments:
132
+ for segment in segments:
133
+ col1, col2, col3 = st.columns([2,3,5])
134
+
135
+ with col1:
136
+ display_speaker_info(segment)
137
+
138
+ with col2:
139
+ display_timestamp(segment)
140
+
141
+ with col3:
142
+ display_text(segment)
143
+
144
+ st.markdown("---")
145
+ else:
146
+ st.warning("No speaker segments detected")
147
+
148
+ def display_speaker_info(segment):
149
+ """Display speaker information with color coding."""
150
+ speaker_num = int(segment['speaker'].split('_')[1])
151
+ colors = ['🔵', '🔴']
152
+ speaker_color = colors[speaker_num % len(colors)]
153
+ st.write(f"{speaker_color} {segment['speaker']}")
154
+
155
+ def display_timestamp(segment):
156
+ """Display formatted timestamps."""
157
+ start_time = TimeFormatter.format_timestamp(segment['start'])
158
+ end_time = TimeFormatter.format_timestamp(segment['end'])
159
+ st.write(f"{start_time} → {end_time}")
160
+
161
+ def display_text(segment):
162
+ """Display speaker's text."""
163
+ if segment['text']:
164
+ st.write(f"\"{segment['text']}\"")
165
+ else:
166
+ st.write("(no speech detected)")
167
 
168
  if __name__ == "__main__":
169
+ main()