Manyue-DataScientist committed on
Commit
d6e0f11
·
verified ·
1 Parent(s): 83bc687

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +148 -124
app.py CHANGED
@@ -10,142 +10,166 @@ import io
10
 
11
@st.cache_resource
def load_models():
    """Load and cache the three models used by the app.

    Returns:
        Tuple of (diarization pipeline, whisper transcriber, summarizer),
        or (None, None, None) when any load fails; the failure is surfaced
        to the UI via st.error.
    """
    try:
        # Speaker diarization (pyannote); requires a valid HF access token.
        diarization = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=st.secrets["hf_token"],
        )
        # Whisper speech-to-text, "small" checkpoint.
        transcriber = whisper.load_model("small")
        # BART summarizer; runs on GPU 0 when CUDA is available, else CPU.
        summarizer = tf_pipeline(
            "summarization",
            model="facebook/bart-large-cnn",
            device=0 if torch.cuda.is_available() else -1,
        )
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        return None, None, None
    return diarization, transcriber, summarizer
 
 
 
 
 
 
 
30
 
31
def process_audio(audio_file, max_duration=600):
    """Convert an uploaded MP3/WAV to 16 kHz mono WAV and run the full
    pipeline: speaker diarization, transcription, and summarization.

    Args:
        audio_file: Streamlit UploadedFile (.mp3 or .wav).
        max_duration: intended duration cap in seconds.
            NOTE(review): currently unused — TODO enforce or remove.

    Returns:
        dict with keys "diarization", "transcription", "summary" on
        success, or None on any failure (errors are shown via st.error).
        Fix: previously a model-loading failure returned the truthy string
        "Model loading failed", which crashed callers doing
        `if results: results["..."]`.
    """
    tmp_path = None
    try:
        audio_bytes = io.BytesIO(audio_file.getvalue())

        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            try:
                if audio_file.name.lower().endswith('.mp3'):
                    audio = AudioSegment.from_mp3(audio_bytes)
                else:
                    audio = AudioSegment.from_wav(audio_bytes)

                # Standardize to 16 kHz mono 16-bit — the format the models expect.
                audio = audio.set_frame_rate(16000)
                audio = audio.set_channels(1)
                audio = audio.set_sample_width(2)

                audio.export(
                    tmp.name,
                    format="wav",
                    parameters=["-ac", "1", "-ar", "16000"],
                )
                tmp_path = tmp.name
            except Exception as e:
                st.error(f"Error converting audio: {str(e)}")
                return None

        diarization, transcriber, summarizer = load_models()
        if not all([diarization, transcriber, summarizer]):
            # Return None (not a truthy sentinel string) so callers'
            # `if results:` checks behave correctly.
            return None

        with st.spinner("Identifying speakers..."):
            diarization_result = diarization(tmp_path)

        with st.spinner("Transcribing audio..."):
            transcription = transcriber.transcribe(tmp_path)

        with st.spinner("Generating summary..."):
            summary = summarizer(transcription["text"], max_length=130, min_length=30)

        return {
            "diarization": diarization_result,
            "transcription": transcription,
            "summary": summary[0]["summary_text"],
        }
    except Exception as e:
        st.error(f"Error processing audio: {str(e)}")
        return None
    finally:
        # Always delete the temp WAV — the original only unlinked on the
        # success path, leaking a file on every failure.
        if tmp_path and os.path.exists(tmp_path):
            os.unlink(tmp_path)
82
 
83
def format_speaker_segments(diarization_result):
    """Flatten a pyannote diarization result into a list of dicts.

    Each dict has keys 'speaker', 'start', 'end' (times as floats).
    Turns with an open start or end boundary (None) are skipped.
    """
    segments = []
    for turn, _, label in diarization_result.itertracks(yield_label=True):
        # Skip open-ended turns: both boundaries must be present.
        if turn.start is None or turn.end is None:
            continue
        segments.append({
            'speaker': label,
            'start': float(turn.start),
            'end': float(turn.end),
        })
    return segments
 
 
 
 
 
95
 
96
def format_timestamp(seconds):
    """Format a duration in seconds as "MM:SS.ss".

    Fix: round to centiseconds BEFORE splitting into minutes/seconds.
    The original formatted the raw remainder with %05.2f, so e.g.
    59.999 rendered as "00:60.00" instead of rolling over to "01:00.00".
    """
    total_cs = round(seconds * 100)          # whole centiseconds
    minutes, cs = divmod(total_cs, 6000)     # 6000 cs per minute
    return f"{int(minutes):02d}:{cs / 100:05.2f}"
100
 
101
def main():
    """Streamlit entry point: upload an audio file, run the analysis
    pipeline, and render speaker timeline / transcription / summary tabs.
    """
    st.title("Multi-Speaker Audio Analyzer")
    st.write("Upload an audio file (MP3/WAV) up to 5 minutes long for best performance")

    uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"])

    if uploaded_file:
        file_size = len(uploaded_file.getvalue()) / (1024 * 1024)
        st.write(f"File size: {file_size:.2f} MB")

        st.audio(uploaded_file, format='audio/wav')

        if st.button("Analyze Audio"):
            if file_size > 200:
                st.error("File size exceeds 200MB limit")
            else:
                results = process_audio(uploaded_file)

                # Guard with isinstance: process_audio can return a non-dict
                # failure marker (e.g. a string), which is truthy and would
                # crash the indexing below under a bare `if results:`.
                if isinstance(results, dict):
                    tab1, tab2, tab3 = st.tabs(["Speakers", "Transcription", "Summary"])

                    with tab1:
                        st.write("Speaker Timeline:")
                        segments = format_speaker_segments(results["diarization"])

                        for segment in segments:
                            col1, col2 = st.columns([2, 8])

                            with col1:
                                # Labels look like "SPEAKER_00"; fall back to a
                                # neutral marker on anything unexpected instead
                                # of crashing with IndexError/ValueError.
                                try:
                                    speaker_num = int(segment['speaker'].split('_')[1])
                                    colors = ['🔵', '🔴']  # Two colors for alternating speakers
                                    speaker_color = colors[speaker_num % len(colors)]
                                    st.write(f"{speaker_color} {segment['speaker']}")
                                except (IndexError, ValueError):
                                    st.write(f"⚪ {segment['speaker']}")

                            with col2:
                                start_time = format_timestamp(segment['start'])
                                end_time = format_timestamp(segment['end'])
                                st.write(f"{start_time} → {end_time}")

                            st.markdown("---")

                    with tab2:
                        st.write("Transcription:")
                        st.write(results["transcription"]["text"])

                    with tab3:
                        st.write("Summary:")
                        st.write(results["summary"])


if __name__ == "__main__":
    main()
 
10
 
11
@st.cache_resource
def load_models():
    """Load and cache the diarization, transcription and summarization models.

    Returns:
        (diarization, transcriber, summarizer) on success, or
        (None, None, None) on failure; errors are reported via st.error.
    """
    try:
        # Back to original model name
        # Speaker diarization (pyannote); requires a valid HF access token.
        diarization = Pipeline.from_pretrained(
            "pyannote/speaker-diarization",  # Original model name
            use_auth_token=st.secrets["hf_token"],
        )

        # Whisper speech-to-text, "base" checkpoint.
        transcriber = whisper.load_model("base")

        # BART summarizer; GPU 0 when CUDA is available, else CPU (-1).
        summarizer = tf_pipeline(
            "summarization",
            model="facebook/bart-large-cnn",
            device=0 if torch.cuda.is_available() else -1,
        )

        # Validate models loaded correctly
        if not (diarization and transcriber and summarizer):
            raise ValueError("One or more models failed to load")
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        st.error("Debug info: Check if HF token is valid and has necessary permissions")
        return None, None, None
    return diarization, transcriber, summarizer
37
 
38
def process_audio(audio_file, max_duration=600):
    """Convert an uploaded MP3/WAV to 16 kHz mono WAV and run the full
    pipeline: speaker diarization, transcription, and summarization.

    Args:
        audio_file: Streamlit UploadedFile (.mp3 or .wav).
        max_duration: intended duration cap in seconds.
            NOTE(review): currently unused — TODO enforce or remove.

    Returns:
        dict with keys "diarization", "transcription", "summary" on
        success, or None on any failure (errors are shown via st.error).
        Fix: previously a model-loading failure returned the truthy string
        "Model loading failed", which crashed callers doing
        `if results: results["..."]`.
    """
    tmp_path = None
    try:
        audio_bytes = io.BytesIO(audio_file.getvalue())

        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            try:
                if audio_file.name.lower().endswith('.mp3'):
                    audio = AudioSegment.from_mp3(audio_bytes)
                else:
                    audio = AudioSegment.from_wav(audio_bytes)

                # Standardize to 16 kHz mono 16-bit — the format the models expect.
                audio = audio.set_frame_rate(16000)
                audio = audio.set_channels(1)
                audio = audio.set_sample_width(2)

                audio.export(
                    tmp.name,
                    format="wav",
                    parameters=["-ac", "1", "-ar", "16000"],
                )
                tmp_path = tmp.name
            except Exception as e:
                st.error(f"Error converting audio: {str(e)}")
                return None

        diarization, transcriber, summarizer = load_models()
        if not all([diarization, transcriber, summarizer]):
            # Return None (not a truthy sentinel string) so callers'
            # `if results:` checks behave correctly.
            return None

        with st.spinner("Identifying speakers..."):
            diarization_result = diarization(tmp_path)

        with st.spinner("Transcribing audio..."):
            transcription = transcriber.transcribe(tmp_path)

        with st.spinner("Generating summary..."):
            summary = summarizer(transcription["text"], max_length=130, min_length=30)

        return {
            "diarization": diarization_result,
            "transcription": transcription,
            "summary": summary[0]["summary_text"],
        }
    except Exception as e:
        st.error(f"Error processing audio: {str(e)}")
        return None
    finally:
        # Always delete the temp WAV — the original only unlinked on the
        # success path, leaking a file on every failure.
        if tmp_path and os.path.exists(tmp_path):
            os.unlink(tmp_path)
89
 
90
def format_speaker_segments(diarization_result):
    """Flatten a pyannote diarization result into a list of dicts.

    Each dict has keys 'speaker' (str), 'start' and 'end' (floats; an
    open None boundary is coerced to 0.0). Returns [] when the result is
    None or when iterating it raises (error reported via st.error).
    """
    if diarization_result is None:
        return []

    segments = []
    try:
        for turn, _, label in diarization_result.itertracks(yield_label=True):
            start = 0.0 if turn.start is None else float(turn.start)
            end = 0.0 if turn.end is None else float(turn.end)
            segments.append({
                'speaker': str(label),  # Ensure string
                'start': start,
                'end': end,
            })
    except Exception as e:
        st.error(f"Error formatting segments: {str(e)}")
        return []

    return segments
107
 
108
def format_timestamp(seconds):
    """Format a duration in seconds as "MM:SS.ss".

    Fix: round to centiseconds BEFORE splitting into minutes/seconds.
    The original formatted the raw remainder with %05.2f, so e.g.
    59.999 rendered as "00:60.00" instead of rolling over to "01:00.00".
    """
    total_cs = round(seconds * 100)          # whole centiseconds
    minutes, cs = divmod(total_cs, 6000)     # 6000 cs per minute
    return f"{int(minutes):02d}:{cs / 100:05.2f}"
112
 
113
def main():
    """Streamlit entry point: upload an audio file, run the analysis
    pipeline, and render speaker timeline / transcription / summary tabs.
    """
    st.title("Multi-Speaker Audio Analyzer")
    st.write("Upload an audio file (MP3/WAV) up to 5 minutes long for best performance")

    uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"])

    if uploaded_file:
        file_size = len(uploaded_file.getvalue()) / (1024 * 1024)
        st.write(f"File size: {file_size:.2f} MB")

        st.audio(uploaded_file, format='audio/wav')

        if st.button("Analyze Audio"):
            if file_size > 200:
                st.error("File size exceeds 200MB limit")
            else:
                results = process_audio(uploaded_file)

                # Guard with isinstance: process_audio can return a non-dict
                # failure marker (e.g. the string "Model loading failed"),
                # which is truthy — a bare `if results:` would then crash on
                # the dict indexing below.
                if isinstance(results, dict):
                    tab1, tab2, tab3 = st.tabs(["Speakers", "Transcription", "Summary"])

                    with tab1:
                        st.write("Speaker Timeline:")
                        segments = format_speaker_segments(results["diarization"])

                        if segments:  # Only proceed if we have segments
                            for segment in segments:
                                col1, col2 = st.columns([2, 8])

                                with col1:
                                    # Labels look like "SPEAKER_00"; fall back
                                    # to a neutral marker on anything else.
                                    try:
                                        speaker_num = int(segment['speaker'].split('_')[1])
                                        colors = ['🔵', '🔴']  # Two colors for alternating speakers
                                        speaker_color = colors[speaker_num % len(colors)]
                                        st.write(f"{speaker_color} {segment['speaker']}")
                                    except (IndexError, ValueError):
                                        st.write(f"⚪ {segment['speaker']}")

                                with col2:
                                    start_time = format_timestamp(segment['start'])
                                    end_time = format_timestamp(segment['end'])
                                    st.write(f"{start_time} → {end_time}")

                                st.markdown("---")
                        else:
                            st.warning("No speaker segments detected")

                    with tab2:
                        st.write("Transcription:")
                        if "text" in results["transcription"]:
                            st.write(results["transcription"]["text"])
                        else:
                            st.warning("No transcription available")

                    with tab3:
                        st.write("Summary:")
                        if results["summary"]:
                            st.write(results["summary"])
                        else:
                            st.warning("No summary available")


if __name__ == "__main__":
    main()