Manyue-DataScientist commited on
Commit
935113b
Β·
verified Β·
1 Parent(s): e0d61c7

Update app.py

Browse files

Tried to fix and optimize the first part of the project, speaker diarization.

Files changed (1) hide show
  1. app.py +40 -52
app.py CHANGED
@@ -11,11 +11,19 @@ import io
11
  @st.cache_resource
12
  def load_models():
13
  try:
 
14
  diarization = Pipeline.from_pretrained(
15
- "pyannote/speaker-diarization",
16
  use_auth_token=st.secrets["hf_token"]
17
- )
18
- transcriber = whisper.load_model("base") # Changed from turbo to base as it's more stable
 
 
 
 
 
 
 
19
  summarizer = tf_pipeline(
20
  "summarization",
21
  model="facebook/bart-large-cnn",
@@ -26,25 +34,22 @@ def load_models():
26
  st.error(f"Error loading models: {str(e)}")
27
  return None, None, None
28
 
29
- def process_audio(audio_file, max_duration=600): # limit to 5 minutes initially
30
  try:
31
- # First, read the uploaded file into BytesIO
32
  audio_bytes = io.BytesIO(audio_file.getvalue())
33
 
34
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
35
  try:
36
- # Convert audio to standard format
37
  if audio_file.name.lower().endswith('.mp3'):
38
  audio = AudioSegment.from_mp3(audio_bytes)
39
  else:
40
  audio = AudioSegment.from_wav(audio_bytes)
41
 
42
- # Standardize audio format
43
- audio = audio.set_frame_rate(16000) # Set sample rate to 16kHz
44
- audio = audio.set_channels(1) # Convert to mono
45
- audio = audio.set_sample_width(2) # Set to 16-bit
46
 
47
- # Export with specific parameters
48
  audio.export(
49
  tmp.name,
50
  format="wav",
@@ -56,12 +61,10 @@ def process_audio(audio_file, max_duration=600): # limit to 5 minutes initially
56
  st.error(f"Error converting audio: {str(e)}")
57
  return None
58
 
59
- # Get cached models
60
  diarization, transcriber, summarizer = load_models()
61
  if not all([diarization, transcriber, summarizer]):
62
  return "Model loading failed"
63
 
64
- # Process with progress bar
65
  with st.spinner("Identifying speakers..."):
66
  diarization_result = diarization(tmp_path)
67
 
@@ -71,12 +74,11 @@ def process_audio(audio_file, max_duration=600): # limit to 5 minutes initially
71
  with st.spinner("Generating summary..."):
72
  summary = summarizer(transcription["text"], max_length=130, min_length=30)
73
 
74
- # Cleanup
75
  os.unlink(tmp_path)
76
 
77
  return {
78
  "diarization": diarization_result,
79
- "transcription": transcription["text"],
80
  "summary": summary[0]["summary_text"]
81
  }
82
 
@@ -84,28 +86,23 @@ def process_audio(audio_file, max_duration=600): # limit to 5 minutes initially
84
  st.error(f"Error processing audio: {str(e)}")
85
  return None
86
 
87
- def format_speaker_segments(diarization_result):
88
- """Process and format speaker segments by removing very short segments and merging consecutive ones"""
89
  formatted_segments = []
90
- min_duration = 0.3 # Minimum duration threshold in seconds
91
 
92
  for turn, _, speaker in diarization_result.itertracks(yield_label=True):
93
- duration = turn.end - turn.start
94
-
95
- # Skip very short segments
96
- if duration < min_duration:
97
  continue
98
 
99
- # Add segment if it's the first one or from a different speaker
100
- if not formatted_segments or formatted_segments[-1]['speaker'] != speaker:
101
  formatted_segments.append({
102
  'speaker': speaker,
103
  'start': turn.start,
104
- 'end': turn.end
 
105
  })
106
- # Extend the end time if it's the same speaker
107
- else:
108
- formatted_segments[-1]['end'] = turn.end
109
 
110
  return formatted_segments
111
 
@@ -116,11 +113,9 @@ def main():
116
  uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"])
117
 
118
  if uploaded_file:
119
- # Display file info
120
  file_size = len(uploaded_file.getvalue()) / (1024 * 1024)
121
  st.write(f"File size: {file_size:.2f} MB")
122
 
123
- # Display audio player
124
  st.audio(uploaded_file, format='audio/wav')
125
 
126
  if st.button("Analyze Audio"):
@@ -135,42 +130,35 @@ def main():
135
  with tab1:
136
  st.write("Speaker Timeline:")
137
 
138
- # Process speaker segments
139
- segments = format_speaker_segments(results["diarization"])
 
 
140
 
141
- # Display segments in a more organized way
142
  for segment in segments:
143
- # Create columns for better layout
144
- col1, col2, col3 = st.columns([2,1,6])
145
 
146
  with col1:
147
- # Show speaker with consistent color
148
  speaker_num = int(segment['speaker'].split('_')[1])
149
- colors = ['πŸ”΅', 'πŸ”΄', '🟒', '🟑', '🟣'] # Different colors for different speakers
150
  speaker_color = colors[speaker_num % len(colors)]
151
  st.write(f"{speaker_color} {segment['speaker']}")
152
 
153
  with col2:
154
- # Format time more cleanly
155
- start_time = f"{int(segment['start']):02d}:{(segment['start']%60):04.1f}"
156
- end_time = f"{int(segment['end']):02d}:{(segment['end']%60):04.1f}"
157
- st.write(f"{start_time} β†’")
158
-
159
- with col3:
160
- st.write(f"{end_time}")
161
 
162
- # Add a small separator
163
  st.markdown("---")
164
-
165
- # Add legend
166
- st.write("\nSpeaker Legend:")
167
- for i in range(len(set(s['speaker'] for s in segments))):
168
- st.write(f"{colors[i]} SPEAKER_{i:02d}")
169
 
170
- # Keep original transcription and summary tabs
171
  with tab2:
172
  st.write("Transcription:")
173
- st.write(results["transcription"])
174
 
175
  with tab3:
176
  st.write("Summary:")
 
11
  @st.cache_resource
12
  def load_models():
13
  try:
14
+ # Updated to 3.1 with parameters
15
  diarization = Pipeline.from_pretrained(
16
+ "pyannote/speaker-diarization@3.1",
17
  use_auth_token=st.secrets["hf_token"]
18
+ ).instantiate({
19
+ "onset": 0.3,
20
+ "offset": 0.3,
21
+ "min_duration_on": 0.1,
22
+ "min_duration_off": 0.1
23
+ })
24
+
25
+ transcriber = whisper.load_model("base")
26
+
27
  summarizer = tf_pipeline(
28
  "summarization",
29
  model="facebook/bart-large-cnn",
 
34
  st.error(f"Error loading models: {str(e)}")
35
  return None, None, None
36
 
37
+ def process_audio(audio_file, max_duration=600):
38
  try:
 
39
  audio_bytes = io.BytesIO(audio_file.getvalue())
40
 
41
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
42
  try:
 
43
  if audio_file.name.lower().endswith('.mp3'):
44
  audio = AudioSegment.from_mp3(audio_bytes)
45
  else:
46
  audio = AudioSegment.from_wav(audio_bytes)
47
 
48
+ # Standardize format
49
+ audio = audio.set_frame_rate(16000)
50
+ audio = audio.set_channels(1)
51
+ audio = audio.set_sample_width(2)
52
 
 
53
  audio.export(
54
  tmp.name,
55
  format="wav",
 
61
  st.error(f"Error converting audio: {str(e)}")
62
  return None
63
 
 
64
  diarization, transcriber, summarizer = load_models()
65
  if not all([diarization, transcriber, summarizer]):
66
  return "Model loading failed"
67
 
 
68
  with st.spinner("Identifying speakers..."):
69
  diarization_result = diarization(tmp_path)
70
 
 
74
  with st.spinner("Generating summary..."):
75
  summary = summarizer(transcription["text"], max_length=130, min_length=30)
76
 
 
77
  os.unlink(tmp_path)
78
 
79
  return {
80
  "diarization": diarization_result,
81
+ "transcription": transcription, # Return full transcription object
82
  "summary": summary[0]["summary_text"]
83
  }
84
 
 
86
  st.error(f"Error processing audio: {str(e)}")
87
  return None
88
 
89
+ def format_speaker_segments(diarization_result, transcription):
 
90
  formatted_segments = []
91
+ audio_duration = transcription.get('duration', 0)
92
 
93
  for turn, _, speaker in diarization_result.itertracks(yield_label=True):
94
+ # Skip invalid timestamps
95
+ if turn.start > audio_duration or turn.end > audio_duration:
 
 
96
  continue
97
 
98
+ # Only add segments with meaningful duration
99
+ if (turn.end - turn.start) >= 0.1: # 100ms minimum
100
  formatted_segments.append({
101
  'speaker': speaker,
102
  'start': turn.start,
103
+ 'end': turn.end,
104
+ 'duration': turn.end - turn.start
105
  })
 
 
 
106
 
107
  return formatted_segments
108
 
 
113
  uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"])
114
 
115
  if uploaded_file:
 
116
  file_size = len(uploaded_file.getvalue()) / (1024 * 1024)
117
  st.write(f"File size: {file_size:.2f} MB")
118
 
 
119
  st.audio(uploaded_file, format='audio/wav')
120
 
121
  if st.button("Analyze Audio"):
 
130
  with tab1:
131
  st.write("Speaker Timeline:")
132
 
133
+ segments = format_speaker_segments(
134
+ results["diarization"],
135
+ results["transcription"]
136
+ )
137
 
138
+ # Display segments with proper time formatting
139
  for segment in segments:
140
+ col1, col2 = st.columns([2,8])
 
141
 
142
  with col1:
 
143
  speaker_num = int(segment['speaker'].split('_')[1])
144
+ colors = ['πŸ”΅', 'πŸ”΄'] # Simplified to two colors
145
  speaker_color = colors[speaker_num % len(colors)]
146
  st.write(f"{speaker_color} {segment['speaker']}")
147
 
148
  with col2:
149
+ mm_start = int(segment['start'] // 60)
150
+ ss_start = segment['start'] % 60
151
+ mm_end = int(segment['end'] // 60)
152
+ ss_end = segment['end'] % 60
153
+
154
+ time_str = f"{mm_start:02d}:{ss_start:05.2f} β†’ {mm_end:02d}:{ss_end:05.2f}"
155
+ st.write(time_str)
156
 
 
157
  st.markdown("---")
 
 
 
 
 
158
 
 
159
  with tab2:
160
  st.write("Transcription:")
161
+ st.write(results["transcription"]["text"])
162
 
163
  with tab3:
164
  st.write("Summary:")