Manyue-DataScientist commited on
Commit
7f8e922
·
verified ·
1 Parent(s): 1a56382

Update app.py

Browse files

changed as per the trained model.

Files changed (1) hide show
  1. app.py +115 -152
app.py CHANGED
@@ -7,163 +7,126 @@ Date: January 2025
7
  """
8
 
9
  import streamlit as st
10
- from src.models.diarization import SpeakerDiarizer
11
- from src.models.transcription import Transcriber
12
- from src.models.summarization import Summarizer
13
- from src.utils.audio_processor import AudioProcessor
14
- from src.utils.formatter import TimeFormatter
15
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- # Cache for model loading
18
  @st.cache_resource
19
  def load_models():
20
- """
21
- Load and cache all required models.
22
-
23
- Returns:
24
- tuple: (diarizer, transcriber, summarizer) or (None, None, None) if loading fails
25
- """
26
- try:
27
- diarizer = SpeakerDiarizer(st.secrets["hf_token"])
28
- diarizer_model = diarizer.load_model()
29
-
30
- transcriber = Transcriber()
31
- transcriber_model = transcriber.load_model()
32
-
33
- summarizer = Summarizer()
34
- summarizer_model = summarizer.load_model()
35
-
36
- if not all([diarizer_model, transcriber_model, summarizer_model]):
37
- raise ValueError("One or more models failed to load")
38
-
39
- return diarizer, transcriber, summarizer
40
- except Exception as e:
41
- st.error(f"Error loading models: {str(e)}")
42
- st.error("Debug info: Check if HF token is valid and has necessary permissions")
43
- return None, None, None
44
-
45
- def process_audio(audio_file, max_duration=600):
46
- """
47
- Process the uploaded audio file through all models.
48
-
49
- Args:
50
- audio_file: Uploaded audio file
51
- max_duration (int): Maximum duration in seconds
52
-
53
- Returns:
54
- dict: Processing results containing diarization, transcription, and summary
55
- """
56
- try:
57
- # Process audio file
58
- audio_processor = AudioProcessor()
59
- tmp_path = audio_processor.standardize_audio(audio_file)
60
-
61
- # Load models
62
- diarizer, transcriber, summarizer = load_models()
63
- if not all([diarizer, transcriber, summarizer]):
64
- return "Model loading failed"
65
-
66
- # Process with each model
67
- with st.spinner("Identifying speakers..."):
68
- diarization_result = diarizer.process(tmp_path)
69
-
70
- with st.spinner("Transcribing audio..."):
71
- transcription = transcriber.process(tmp_path)
72
-
73
- with st.spinner("Generating summary..."):
74
- summary = summarizer.process(transcription["text"])
75
-
76
- # Cleanup
77
- os.unlink(tmp_path)
78
-
79
- return {
80
- "diarization": diarization_result,
81
- "transcription": transcription,
82
- "summary": summary[0]["summary_text"]
83
- }
84
-
85
- except Exception as e:
86
- st.error(f"Error processing audio: {str(e)}")
87
- return None
88
 
89
  def main():
90
- """Main application function."""
91
- st.title("Multi-Speaker Audio Analyzer")
92
- st.write("Upload an audio file (MP3/WAV) up to 5 minutes long for best performance")
93
-
94
- uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"])
95
-
96
- if uploaded_file:
97
- file_size = len(uploaded_file.getvalue()) / (1024 * 1024)
98
- st.write(f"File size: {file_size:.2f} MB")
99
-
100
- st.audio(uploaded_file, format='audio/wav')
101
-
102
- if st.button("Analyze Audio"):
103
- if file_size > 200:
104
- st.error("File size exceeds 200MB limit")
105
- else:
106
- results = process_audio(uploaded_file)
107
-
108
- if results:
109
- tab1, tab2, tab3 = st.tabs(["Speakers", "Transcription", "Summary"])
110
-
111
- # Display speaker timeline
112
- with tab1:
113
- display_speaker_timeline(results)
114
-
115
- # Display transcription
116
- with tab2:
117
- display_transcription(results)
118
-
119
- # Display summary
120
- with tab3:
121
- display_summary(results)
122
-
123
- def display_speaker_timeline(results):
124
- """Display speaker diarization results in a timeline format."""
125
- st.write("Speaker Timeline:")
126
- segments = TimeFormatter.format_speaker_segments(
127
- results["diarization"],
128
- results["transcription"]
129
- )
130
-
131
- if segments:
132
- for segment in segments:
133
- col1, col2, col3 = st.columns([2,3,5])
134
-
135
- with col1:
136
- display_speaker_info(segment)
137
-
138
- with col2:
139
- display_timestamp(segment)
140
-
141
- with col3:
142
- display_text(segment)
143
-
144
- st.markdown("---")
145
- else:
146
- st.warning("No speaker segments detected")
147
-
148
- def display_speaker_info(segment):
149
- """Display speaker information with color coding."""
150
- speaker_num = int(segment['speaker'].split('_')[1])
151
- colors = ['🔵', '🔴']
152
- speaker_color = colors[speaker_num % len(colors)]
153
- st.write(f"{speaker_color} {segment['speaker']}")
154
-
155
- def display_timestamp(segment):
156
- """Display formatted timestamps."""
157
- start_time = TimeFormatter.format_timestamp(segment['start'])
158
- end_time = TimeFormatter.format_timestamp(segment['end'])
159
- st.write(f"{start_time} → {end_time}")
160
-
161
- def display_text(segment):
162
- """Display speaker's text."""
163
- if segment['text']:
164
- st.write(f"\"{segment['text']}\"")
165
- else:
166
- st.write("(no speech detected)")
167
 
168
  if __name__ == "__main__":
169
- main()
 
7
  """
8
 
9
  import streamlit as st
10
+ from pyannote.audio import Pipeline
11
+ import whisper
12
+ import tempfile
 
 
13
  import os
14
+ import torch
15
+ from transformers import pipeline as tf_pipeline, BartTokenizer
16
+ from pydub import AudioSegment
17
+ import io
18
+ import pickle
19
+
20
+ class SpeakerDiarizer:
21
+ def __init__(self, token):
22
+ self.pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=token)
23
+
24
+ def process(self, audio_file):
25
+ return self.pipeline(audio_file)
26
+
27
+ class Transcriber:
28
+ def __init__(self):
29
+ self.model = whisper.load_model("base")
30
+
31
+ def process(self, audio_file):
32
+ return self.model.transcribe(audio_file)["text"]
33
+
34
+ class Summarizer:
35
+ def __init__(self, model_path='bart_ami_finetuned.pkl'):
36
+ self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
37
+ with open(model_path, 'rb') as f:
38
+ self.model = pickle.load(f)
39
+
40
+ def process(self, text):
41
+ inputs = self.tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)
42
+ summary_ids = self.model.generate(inputs["input_ids"], max_length=150, min_length=40)
43
+ return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
44
 
 
45
  @st.cache_resource
46
  def load_models():
47
+ try:
48
+ diarizer = SpeakerDiarizer(st.secrets["hf_token"])
49
+ transcriber = Transcriber()
50
+ summarizer = Summarizer()
51
+ return diarizer, transcriber, summarizer
52
+ except Exception as e:
53
+ st.error(f"Error loading models: {str(e)}")
54
+ return None, None, None
55
+
56
+ def process_audio(audio_file):
57
+ try:
58
+ audio_bytes = io.BytesIO(audio_file.getvalue())
59
+
60
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
61
+ if audio_file.name.lower().endswith('.mp3'):
62
+ audio = AudioSegment.from_mp3(audio_bytes)
63
+ else:
64
+ audio = AudioSegment.from_wav(audio_bytes)
65
+
66
+ audio = audio.set_frame_rate(16000).set_channels(1).set_sample_width(2)
67
+ audio.export(tmp.name, format="wav", parameters=["-ac", "1", "-ar", "16000"])
68
+ tmp_path = tmp.name
69
+
70
+ diarizer, transcriber, summarizer = load_models()
71
+ if not all([diarizer, transcriber, summarizer]):
72
+ return "Model loading failed"
73
+
74
+ with st.spinner("Processing..."):
75
+ diarization = diarizer.process(tmp_path)
76
+ transcription = transcriber.process(tmp_path)
77
+ summary = summarizer.process(transcription)
78
+
79
+ os.unlink(tmp_path)
80
+
81
+ return {
82
+ "diarization": diarization,
83
+ "transcription": transcription,
84
+ "summary": summary
85
+ }
86
+
87
+ except Exception as e:
88
+ st.error(f"Error: {str(e)}")
89
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  def main():
92
+ st.title("Multi-Speaker Audio Analyzer")
93
+ st.write("Upload an audio file (MP3/WAV) up to 5 minutes long")
94
+
95
+ uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"])
96
+
97
+ if uploaded_file:
98
+ file_size = len(uploaded_file.getvalue()) / (1024 * 1024)
99
+ st.write(f"File size: {file_size:.2f} MB")
100
+
101
+ st.audio(uploaded_file, format='audio/wav')
102
+
103
+ if st.button("Analyze Audio"):
104
+ if file_size > 200:
105
+ st.error("File size exceeds 200MB limit")
106
+ else:
107
+ results = process_audio(uploaded_file)
108
+
109
+ if results:
110
+ tab1, tab2, tab3 = st.tabs(["Speakers", "Transcription", "Summary"])
111
+
112
+ with tab1:
113
+ st.write("Speaker Timeline:")
114
+ for turn, _, speaker in results["diarization"].itertracks(yield_label=True):
115
+ col1, col2, col3 = st.columns([2,3,5])
116
+ with col1:
117
+ speaker_num = int(speaker.split('_')[1])
118
+ colors = ['🔵', '🔴']
119
+ st.write(f"{colors[speaker_num % 2]} {speaker}")
120
+ with col2:
121
+ st.write(f"{format_timestamp(turn.start)} {format_timestamp(turn.end)}")
122
+
123
+ with tab2:
124
+ st.write("Transcription:")
125
+ st.write(results["transcription"])
126
+
127
+ with tab3:
128
+ st.write("Summary:")
129
+ st.write(results["summary"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  if __name__ == "__main__":
132
+ main()