Manyue-DataScientist commited on
Commit
caa4c85
·
verified ·
1 Parent(s): 61e4a9c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -55
app.py CHANGED
@@ -1,11 +1,12 @@
1
  import streamlit as st
2
  from pyannote.audio import Pipeline
3
- import whisper
4
  import tempfile
5
  import os
6
  import torch
7
  from transformers import pipeline as tf_pipeline
8
  from pydub import AudioSegment
 
9
 
10
  @st.cache_resource
11
  def load_models():
@@ -14,15 +15,12 @@ def load_models():
14
  "pyannote/speaker-diarization",
15
  use_auth_token=st.secrets["hf_token"]
16
  )
17
-
18
- transcriber = whisper.load_model("turbo")
19
-
20
  summarizer = tf_pipeline(
21
- "summarization",
22
  model="facebook/bart-large-cnn",
23
  device=0 if torch.cuda.is_available() else -1
24
  )
25
-
26
  return diarization, transcriber, summarizer
27
  except Exception as e:
28
  st.error(f"Error loading models: {str(e)}")
@@ -30,44 +28,58 @@ def load_models():
30
 
31
  def process_audio(audio_file, max_duration=600): # limit to 5 minutes initially
32
  try:
 
 
33
 
34
-
35
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
36
- # Convert MP3 to WAV if needed
37
- if audio_file.name.endswith('.mp3'):
38
- audio = AudioSegment.from_mp3(audio_file)
39
- else:
40
- audio = AudioSegment.from_wav(audio_file)
41
-
42
- # Export as WAV
43
- audio.export(tmp.name, format="wav")
44
- tmp_path = tmp.name
45
-
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- # Get cached models
48
- diarization, transcriber, summarizer = load_models()
49
- if not all([diarization, transcriber, summarizer]):
50
- return "Model loading failed"
51
 
52
- # Process with progress bar
53
- with st.spinner("Identifying speakers..."):
54
- diarization_result = diarization(tmp_path)
55
-
56
- with st.spinner("Transcribing audio..."):
57
- transcription = transcriber.transcribe(tmp_path)
58
 
59
- with st.spinner("Generating summary..."):
60
- summary = summarizer(transcription["text"], max_length=130, min_length=30)
 
 
 
61
 
62
- # Cleanup
63
- os.unlink(tmp_path)
64
-
65
- return {
66
- "diarization": diarization_result,
67
- "transcription": transcription["text"],
68
- "summary": summary[0]["summary_text"]
69
- }
70
-
71
  except Exception as e:
72
  st.error(f"Error processing audio: {str(e)}")
73
  return None
@@ -79,27 +91,35 @@ def main():
79
  uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"])
80
 
81
  if uploaded_file:
 
 
 
 
 
82
  st.audio(uploaded_file, format='audio/wav')
83
 
84
  if st.button("Analyze Audio"):
85
- results = process_audio(uploaded_file)
86
-
87
- if results:
88
- # Display results in tabs
89
- tab1, tab2, tab3 = st.tabs(["Speakers", "Transcription", "Summary"])
90
-
91
- with tab1:
92
- st.write("Speaker Segments:")
93
- for turn, _, speaker in results["diarization"].itertracks(yield_label=True):
94
- st.write(f"{speaker}: {turn.start:.1f}s → {turn.end:.1f}s")
95
-
96
- with tab2:
97
- st.write("Transcription:")
98
- st.write(results["transcription"])
99
 
100
- with tab3:
101
- st.write("Summary:")
102
- st.write(results["summary"])
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
  if __name__ == "__main__":
105
  main()
 
1
  import streamlit as st
2
  from pyannote.audio import Pipeline
3
+ import whisper
4
  import tempfile
5
  import os
6
  import torch
7
  from transformers import pipeline as tf_pipeline
8
  from pydub import AudioSegment
9
+ import io
10
 
11
  @st.cache_resource
12
  def load_models():
 
15
  "pyannote/speaker-diarization",
16
  use_auth_token=st.secrets["hf_token"]
17
  )
18
+ transcriber = whisper.load_model("base") # Changed from turbo to base as it's more stable
 
 
19
  summarizer = tf_pipeline(
20
+ "summarization",
21
  model="facebook/bart-large-cnn",
22
  device=0 if torch.cuda.is_available() else -1
23
  )
 
24
  return diarization, transcriber, summarizer
25
  except Exception as e:
26
  st.error(f"Error loading models: {str(e)}")
 
28
 
29
  def process_audio(audio_file, max_duration=600): # limit to 5 minutes initially
30
  try:
31
+ # First, read the uploaded file into BytesIO
32
+ audio_bytes = io.BytesIO(audio_file.getvalue())
33
 
 
34
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
35
+ try:
36
+ # Convert audio to standard format
37
+ if audio_file.name.lower().endswith('.mp3'):
38
+ audio = AudioSegment.from_mp3(audio_bytes)
39
+ else:
40
+ audio = AudioSegment.from_wav(audio_bytes)
41
+
42
+ # Standardize audio format
43
+ audio = audio.set_frame_rate(16000) # Set sample rate to 16kHz
44
+ audio = audio.set_channels(1) # Convert to mono
45
+ audio = audio.set_sample_width(2) # Set to 16-bit
46
+
47
+ # Export with specific parameters
48
+ audio.export(
49
+ tmp.name,
50
+ format="wav",
51
+ parameters=["-ac", "1", "-ar", "16000"]
52
+ )
53
+ tmp_path = tmp.name
54
+
55
+ except Exception as e:
56
+ st.error(f"Error converting audio: {str(e)}")
57
+ return None
58
 
59
+ # Get cached models
60
+ diarization, transcriber, summarizer = load_models()
61
+ if not all([diarization, transcriber, summarizer]):
62
+ return "Model loading failed"
63
 
64
+ # Process with progress bar
65
+ with st.spinner("Identifying speakers..."):
66
+ diarization_result = diarization(tmp_path)
 
 
 
67
 
68
+ with st.spinner("Transcribing audio..."):
69
+ transcription = transcriber.transcribe(tmp_path)
70
+
71
+ with st.spinner("Generating summary..."):
72
+ summary = summarizer(transcription["text"], max_length=130, min_length=30)
73
 
74
+ # Cleanup
75
+ os.unlink(tmp_path)
76
+
77
+ return {
78
+ "diarization": diarization_result,
79
+ "transcription": transcription["text"],
80
+ "summary": summary[0]["summary_text"]
81
+ }
82
+
83
  except Exception as e:
84
  st.error(f"Error processing audio: {str(e)}")
85
  return None
 
91
  uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"])
92
 
93
  if uploaded_file:
94
+ # Display file info
95
+ file_size = len(uploaded_file.getvalue()) / (1024 * 1024) # Convert to MB
96
+ st.write(f"File size: {file_size:.2f} MB")
97
+
98
+ # Display audio player
99
  st.audio(uploaded_file, format='audio/wav')
100
 
101
  if st.button("Analyze Audio"):
102
+ if file_size > 200:
103
+ st.error("File size exceeds 200MB limit")
104
+ else:
105
+ results = process_audio(uploaded_file)
 
 
 
 
 
 
 
 
 
 
106
 
107
+ if results:
108
+ # Display results in tabs
109
+ tab1, tab2, tab3 = st.tabs(["Speakers", "Transcription", "Summary"])
110
+
111
+ with tab1:
112
+ st.write("Speaker Segments:")
113
+ for turn, _, speaker in results["diarization"].itertracks(yield_label=True):
114
+ st.write(f"{speaker}: {turn.start:.1f}s → {turn.end:.1f}s")
115
+
116
+ with tab2:
117
+ st.write("Transcription:")
118
+ st.write(results["transcription"])
119
+
120
+ with tab3:
121
+ st.write("Summary:")
122
+ st.write(results["summary"])
123
 
124
  if __name__ == "__main__":
125
  main()