dwarkesh committed on
Commit 0fd2cd9 · 1 Parent(s): 73f2e8d
Files changed (1)
  1. scripts/transcript.py +92 -149
scripts/transcript.py CHANGED
@@ -4,18 +4,19 @@ from pathlib import Path
 import json
 import hashlib
 import os
-from typing import List, Optional
+from typing import List, Tuple
 import assemblyai as aai
 from google import generativeai
 from pydub import AudioSegment
 import asyncio
 import io
+from multiprocessing import Pool
+from functools import partial


 @dataclass
 class Utterance:
     """A single utterance from a speaker"""
-
     speaker: str
     text: str
     start: int  # timestamp in ms from AssemblyAI
@@ -24,11 +25,10 @@ class Utterance:
     @property
     def timestamp(self) -> str:
         """Format start time as HH:MM:SS"""
-        seconds = int(self.start // 1000)  # Convert ms to seconds
+        seconds = int(self.start // 1000)
         hours = seconds // 3600
         minutes = (seconds % 3600) // 60
         seconds = seconds % 60
-
         return f"{hours:02d}:{minutes:02d}:{seconds:02d}"
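The timestamp property is plain integer arithmetic on AssemblyAI's millisecond offsets. A quick worked check of the conversion (hypothetical values, not from the commit):

    # 3,725,000 ms -> 3725 s -> 1 h, 2 min, 5 s
    u = Utterance(speaker="A", text="hi", start=3_725_000, end=3_726_000)
    assert u.timestamp == "01:02:05"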
@@ -42,49 +42,33 @@ class Transcriber:

     def get_transcript(self, audio_path: Path) -> List[Utterance]:
         """Get transcript, using cache if available"""
-        cached = self._get_cached(audio_path)
-        if cached:
-            print("Using cached AssemblyAI transcript...")
-            return cached
-
-        print("Getting new transcript from AssemblyAI...")
-        return self._get_fresh(audio_path)
-
-    def _get_cached(self, audio_path: Path) -> Optional[List[Utterance]]:
-        """Try to get transcript from cache"""
         cache_file = self.cache_dir / f"{audio_path.stem}.json"
-        if not cache_file.exists():
-            return None
-
-        with open(cache_file) as f:
-            data = json.load(f)
-        if data["hash"] != self._get_file_hash(audio_path):
-            return None
-
-        return [Utterance(**u) for u in data["utterances"]]

-    def _get_fresh(self, audio_path: Path) -> List[Utterance]:
-        """Get new transcript from AssemblyAI"""
+        if cache_file.exists():
+            with open(cache_file) as f:
+                data = json.load(f)
+            if data["hash"] == self._get_file_hash(audio_path):
+                print("Using cached AssemblyAI transcript...")
+                return [Utterance(**u) for u in data["utterances"]]
+
+        print("Getting new transcript from AssemblyAI...")
         config = aai.TranscriptionConfig(speaker_labels=True, language_code="en")
         transcript = aai.Transcriber().transcribe(str(audio_path), config=config)

         utterances = [
             Utterance(speaker=u.speaker, text=u.text, start=u.start, end=u.end)
             for u in transcript.utterances
         ]

-        self._save_cache(audio_path, utterances)
-        return utterances
-
-    def _save_cache(self, audio_path: Path, utterances: List[Utterance]) -> None:
-        """Save transcript to cache"""
-        cache_file = self.cache_dir / f"{audio_path.stem}.json"
-        data = {
+        # Cache the result
+        cache_data = {
             "hash": self._get_file_hash(audio_path),
             "utterances": [vars(u) for u in utterances],
         }
         with open(cache_file, "w") as f:
-            json.dump(data, f, indent=2)
+            json.dump(cache_data, f, indent=2)
+
+        return utterances

     def _get_file_hash(self, file_path: Path) -> str:
         """Calculate MD5 hash of a file"""
@@ -101,70 +85,67 @@ class Enhancer:
     def __init__(self, api_key: str):
         generativeai.configure(api_key=api_key)
         self.model = generativeai.GenerativeModel("gemini-exp-1206")
-
-        # Update prompt path
-        prompt_path = Path("prompts/enhance.txt")
-        self.prompt = prompt_path.read_text()
+        self.prompt = Path("prompts/enhance.txt").read_text()

-    async def enhance_chunks(self, chunks: List[tuple[str, io.BytesIO]]) -> List[str]:
-        """Enhance multiple transcript chunks in parallel"""
+    async def enhance_chunks(self, chunks: List[Tuple[str, io.BytesIO]]) -> List[str]:
+        """Enhance multiple transcript chunks concurrently with concurrency control"""
         print(f"Enhancing {len(chunks)} chunks...")

-        async def process_chunk(chunk, index):
+        # Create a semaphore to limit concurrent requests
+        semaphore = asyncio.Semaphore(3)  # Allow up to 3 concurrent requests
+
+        async def process_chunk(i: int, chunk: Tuple[str, io.BytesIO]) -> str:
             text, audio = chunk
-            try:
-                result = await self._enhance_chunk_with_retry(text, audio)
-                print(f"Completed chunk {index + 1}/{len(chunks)}")
-                if result == text:  # Check if output matches input exactly
-                    print("WARNING: Enhanced text matches input exactly!")
-                return result
-            except Exception as e:
-                print(f"Error in chunk {index + 1}: {e}")
-                return None
-
-        # Create all tasks at once and wait for them all to complete
-        tasks = [process_chunk(chunk, i) for i, chunk in enumerate(chunks)]
-        results = await asyncio.gather(*tasks)
-
-        # Filter out failed chunks
-        return [r for r in results if r is not None]
-
-    async def _enhance_chunk_with_retry(self, text: str, audio: io.BytesIO, max_retries: int = 3) -> Optional[str]:
-        """Enhance a single chunk with retries"""
-        for attempt in range(max_retries):
-            try:
-                return await self._enhance_chunk(text, audio)
-            except Exception as e:
-                if attempt == max_retries - 1:
-                    print(f"Failed after {max_retries} attempts: {e}")
-                    return None
-                print(f"Attempt {attempt + 1} failed: {e}. Retrying...")
-                await asyncio.sleep(2 ** attempt)  # Exponential backoff
-
-    async def _enhance_chunk(self, text: str, audio: io.BytesIO) -> str:
-        """Enhance a single chunk"""
-        audio.seek(0)
-
-        response = await self.model.generate_content_async(
-            [self.prompt, text, {"mime_type": "audio/mp3", "data": audio.read()}]
-        )
-
-        return response.text
+            async with semaphore:
+                audio.seek(0)
+                response = await self.model.generate_content_async(
+                    [self.prompt, text, {"mime_type": "audio/mp3", "data": audio.read()}]
+                )
+                print(f"Completed chunk {i+1}/{len(chunks)}")
+                return response.text
+
+        # Create tasks for all chunks and run them concurrently
+        tasks = [
+            process_chunk(i, chunk)
+            for i, chunk in enumerate(chunks)
+        ]
+
+        # Wait for all tasks to complete
+        results = await asyncio.gather(*tasks)
+        return results


+def format_chunk(utterances: List[Utterance]) -> str:
+    """Format utterances into readable text with timestamps"""
+    sections = []
+    current_speaker = None
+    current_texts = []
+
+    for u in utterances:
+        if current_speaker != u.speaker:
+            if current_texts:
+                sections.append(f"Speaker {current_speaker} {utterances[len(sections)].timestamp}\n\n{''.join(current_texts)}")
+            current_speaker = u.speaker
+            current_texts = []
+        current_texts.append(u.text)
+
+    if current_texts:
+        sections.append(f"Speaker {current_speaker} {utterances[len(sections)].timestamp}\n\n{''.join(current_texts)}")
+
+    return "\n\n".join(sections)
+
+
-def prepare_audio_chunks(audio_path: Path, utterances: List[Utterance]) -> List[tuple[str, io.BytesIO]]:
+def prepare_audio_chunks(audio_path: Path, utterances: List[Utterance]) -> List[Tuple[str, io.BytesIO]]:
     """Prepare audio chunks and their corresponding text"""
-    def chunk_utterances(utterances: List[Utterance]) -> List[List[Utterance]]:
+    def chunk_utterances(utterances: List[Utterance], max_tokens: int = 8000) -> List[List[Utterance]]:
         chunks = []
         current = []
         text_length = 0

         for u in utterances:
-            # Check if adding this utterance would exceed token limit
             new_length = text_length + len(u.text)
-            if not current or new_length > 8000:  # ~2000 tokens
-                if current:
-                    chunks.append(current)
+            if current and new_length > max_tokens:
+                chunks.append(current)
                 current = [u]
                 text_length = len(u.text)
             else:
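The rewritten enhance_chunks drops the old retry-with-backoff and error-swallowing helpers in favor of a single asyncio.Semaphore that caps in-flight Gemini requests at three; a failed chunk now raises straight out of asyncio.gather and is caught by main()'s except block. The pattern in isolation, as a runnable sketch with a stand-in coroutine (all names here are illustrative, not from the commit):

    import asyncio

    async def work(item: int) -> int:
        await asyncio.sleep(0.1)  # stand-in for a network call
        return item * 2

    async def bounded_gather(items, limit: int = 3):
        # At most `limit` coroutines get past `async with` at a time; the
        # rest queue on the semaphore, so all tasks can still be created
        # up front without bursting the API.
        semaphore = asyncio.Semaphore(limit)

        async def run(item):
            async with semaphore:
                return await work(item)

        # gather preserves input order, which is what lets the caller join
        # enhanced chunks back together in transcript order.
        return await asyncio.gather(*(run(i) for i in items))

    print(asyncio.run(bounded_gather(range(10))))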
@@ -173,99 +154,61 @@ def prepare_audio_chunks(audio_path: Path, utterances: List[

         if current:
             chunks.append(current)
-
         return chunks

     # Split utterances into chunks
     chunks = chunk_utterances(utterances)
+    print(f"Preparing {len(chunks)} audio segments...")

-    # Load audio file once
+    # Load audio once
     audio = AudioSegment.from_file(audio_path)

-    # Prepare segments
-    print(f"Preparing {len(chunks)} audio segments...")
+    # Process each chunk
     prepared = []
     for chunk in chunks:
-        # Extract audio segment
-        start_ms = chunk[0].start
-        end_ms = chunk[-1].end
-        segment = audio[start_ms:end_ms]
-
-        # Export to buffer
+        # Extract just the needed segment
+        segment = audio[chunk[0].start:chunk[-1].end]
         buffer = io.BytesIO()
-        segment.export(buffer, format="mp3")
-
-        # Format text
-        text = format_transcript(chunk)
-
-        prepared.append((text, buffer))
-
+        # Use lower quality MP3 for faster processing
+        segment.export(buffer, format="mp3", parameters=["-q:a", "9"])
+        prepared.append((format_chunk(chunk), buffer))
+
     return prepared


-def format_transcript(utterances: List[Utterance]) -> str:
-    """Format utterances into readable text"""
-    sections = []
-    current_speaker = None
-    current_texts = []
-
-    for u in utterances:
-        # When speaker changes, output the accumulated text
-        if current_speaker != u.speaker:
-            if current_texts:  # Don't output empty sections
-                sections.append(f"Speaker {current_speaker} {utterances[len(sections)].timestamp}\n\n{''.join(current_texts)}")
-            current_speaker = u.speaker
-            current_texts = []
-        current_texts.append(u.text)
-
-    # Don't forget the last section
-    if current_texts:
-        sections.append(f"Speaker {current_speaker} {utterances[len(sections)].timestamp}\n\n{''.join(current_texts)}")
-
-    return "\n\n".join(sections)
-
-
 def main():
-    def setup_args() -> Path:
-        parser = argparse.ArgumentParser()
-        parser.add_argument("audio_file", help="Audio file to transcribe")
-        args = parser.parse_args()
-
-        audio_path = Path(args.audio_file)
-        if not audio_path.exists():
-            raise FileNotFoundError(f"File not found: {audio_path}")
-        return audio_path
-
-    def setup_output_dir() -> Path:
-        out_dir = Path("output/transcripts")
-        out_dir.mkdir(parents=True, exist_ok=True)
-        return out_dir
+    parser = argparse.ArgumentParser()
+    parser.add_argument("audio_file", help="Audio file to transcribe")
+    args = parser.parse_args()
+
+    audio_path = Path(args.audio_file)
+    if not audio_path.exists():
+        raise FileNotFoundError(f"File not found: {audio_path}")
+
+    out_dir = Path("output/transcripts")
+    out_dir.mkdir(parents=True, exist_ok=True)

     try:
-        # Setup
-        audio_path = setup_args()
-        out_dir = setup_output_dir()
-
-        # Initialize services
+        # Get transcript
         transcriber = Transcriber(os.getenv("ASSEMBLYAI_API_KEY"))
-        enhancer = Enhancer(os.getenv("GOOGLE_API_KEY"))
-
-        # Process
         utterances = transcriber.get_transcript(audio_path)
-        chunks = prepare_audio_chunks(audio_path, utterances)

         # Save original transcript
-        original = format_transcript(utterances)
+        original = format_chunk(utterances)
         (out_dir / "autogenerated-transcript.md").write_text(original)

-        # Enhance and save
+        # Enhance transcript
+        enhancer = Enhancer(os.getenv("GOOGLE_API_KEY"))
+        chunks = prepare_audio_chunks(audio_path, utterances)
         enhanced = asyncio.run(enhancer.enhance_chunks(chunks))
-        merged_transcript = "\n\n".join(chunk.strip() for chunk in enhanced)
-        (out_dir / "transcript.md").write_text(merged_transcript)
+
+        # Save enhanced transcript
+        merged = "\n\n".join(chunk.strip() for chunk in enhanced)
+        (out_dir / "transcript.md").write_text(merged)

         print("\nTranscripts saved to:")
-        print("- output/transcripts/autogenerated-transcript.md")
-        print("- output/transcripts/transcript.md")
+        print(f"- {out_dir}/autogenerated-transcript.md")
+        print(f"- {out_dir}/transcript.md")

     except Exception as e:
         print(f"Error: {e}")
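Two details in prepare_audio_chunks are easy to miss: pydub's AudioSegment slices by milliseconds, so AssemblyAI's start/end values index the audio directly, and the export's parameters list is passed straight through to ffmpeg, where -q:a 9 selects the lowest-quality (smallest) VBR MP3 to keep per-chunk uploads small. The slice-and-export step in isolation (file name hypothetical):

    import io
    from pydub import AudioSegment

    audio = AudioSegment.from_file("episode.mp3")  # hypothetical input
    segment = audio[15_000:45_000]                 # ms slice: 0:15 to 0:45
    buffer = io.BytesIO()
    segment.export(buffer, format="mp3", parameters=["-q:a", "9"])  # ffmpeg VBR q=9
    print(f"exported {buffer.getbuffer().nbytes} bytes")

Note also that the chunking threshold compares character counts against max_tokens=8000; the old code's comment pegged that at roughly 2,000 tokens, so the parameter is really a character budget, not a true token count.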