KingNish committed on
Commit
7d83b5a
·
verified ·
1 Parent(s): 06ded55

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -4
app.py CHANGED
@@ -158,6 +158,7 @@ def stage2_generate(model_stage2, prompt, batch_size=16):
158
  Given a prompt (a numpy array of raw codec ids), upsample using the Stage2 model.
159
  """
160
  # Unflatten prompt: assume prompt shape (1, T) and then reformat.
 
161
  codec_ids = codectool.unflatten(prompt, n_quantizer=1)
162
  codec_ids = codectool.offset_tok_ids(
163
  codec_ids,
@@ -238,10 +239,20 @@ def stage2_inference(model_stage2, stage1_output_set, stage2_output_dir, batch_s
238
  print(f"{output_filename} already processed.")
239
  stage2_result.append(output_filename)
240
  continue
 
241
  prompt = np.load(path).astype(np.int32)
242
- # Only process multiples of 6 seconds; here 50 tokens per second.
 
 
 
 
 
243
  output_duration = (prompt.shape[-1] // 50) // 6 * 6
 
 
244
  num_batch = output_duration // 6
 
 
245
  if num_batch <= batch_size:
246
  output = stage2_generate(model_stage2, prompt[:, :output_duration*50], batch_size=num_batch)
247
  else:
@@ -251,16 +262,25 @@ def stage2_inference(model_stage2, stage1_output_set, stage2_output_dir, batch_s
251
  start_idx = seg * batch_size * 300
252
  end_idx = min((seg + 1) * batch_size * 300, output_duration * 50)
253
  current_batch = batch_size if (seg != num_segments - 1 or num_batch % batch_size == 0) else num_batch % batch_size
254
- segment = stage2_generate(model_stage2, prompt[:, start_idx:end_idx], batch_size=current_batch)
 
 
 
 
255
  segments.append(segment)
 
 
256
  output = np.concatenate(segments, axis=0)
 
257
  # Process any remaining tokens if prompt length not fully used.
258
  if output_duration * 50 != prompt.shape[-1]:
259
  ending = stage2_generate(model_stage2, prompt[:, output_duration * 50:], batch_size=1)
260
  output = np.concatenate([output, ending], axis=0)
261
- # Convert Stage2 output tokens back to numpy array using stage2’s codec manipulator.
 
262
  output = codectool_stage2.ids2npy(output)
263
- # Fix any invalid codes (if needed)
 
264
  fixed_output = copy.deepcopy(output)
265
  for i, line in enumerate(output):
266
  for j, element in enumerate(line):
@@ -268,6 +288,7 @@ def stage2_inference(model_stage2, stage1_output_set, stage2_output_dir, batch_s
268
  counter = Counter(line)
269
  most_common = sorted(counter.items(), key=lambda x: x[1], reverse=True)[0][0]
270
  fixed_output[i, j] = most_common
 
271
  np.save(output_filename, fixed_output)
272
  stage2_result.append(output_filename)
273
  return stage2_result
 
158
  Given a prompt (a numpy array of raw codec ids), upsample using the Stage2 model.
159
  """
160
  # Unflatten prompt: assume prompt shape (1, T) and then reformat.
161
+ print(f"stage2_generate: received prompt with shape: {prompt.shape}")
162
  codec_ids = codectool.unflatten(prompt, n_quantizer=1)
163
  codec_ids = codectool.offset_tok_ids(
164
  codec_ids,
 
239
  print(f"{output_filename} already processed.")
240
  stage2_result.append(output_filename)
241
  continue
242
+
243
  prompt = np.load(path).astype(np.int32)
244
+ # Ensure prompt is 2D.
245
+ if prompt.ndim == 1:
246
+ prompt = prompt[np.newaxis, :]
247
+ print(f"Loaded prompt from {path} with shape: {prompt.shape}")
248
+
249
+ # Compute output duration: tokens per second assumed to be 50, only full 6-second segments.
250
  output_duration = (prompt.shape[-1] // 50) // 6 * 6
251
+ if output_duration == 0:
252
+ raise ValueError(f"Output duration computed as 0 for {path}. Prompt length: {prompt.shape[-1]} tokens")
253
  num_batch = output_duration // 6
254
+
255
+ # Process prompt in batches
256
  if num_batch <= batch_size:
257
  output = stage2_generate(model_stage2, prompt[:, :output_duration*50], batch_size=num_batch)
258
  else:
 
262
  start_idx = seg * batch_size * 300
263
  end_idx = min((seg + 1) * batch_size * 300, output_duration * 50)
264
  current_batch = batch_size if (seg != num_segments - 1 or num_batch % batch_size == 0) else num_batch % batch_size
265
+ segment_prompt = prompt[:, start_idx:end_idx]
266
+ if segment_prompt.shape[-1] == 0:
267
+ print(f"Warning: empty segment detected for seg {seg}, start {start_idx}, end {end_idx}. Skipping this segment.")
268
+ continue
269
+ segment = stage2_generate(model_stage2, segment_prompt, batch_size=current_batch)
270
  segments.append(segment)
271
+ if len(segments) == 0:
272
+ raise ValueError(f"No valid segments produced for {path}.")
273
  output = np.concatenate(segments, axis=0)
274
+
275
  # Process any remaining tokens if prompt length not fully used.
276
  if output_duration * 50 != prompt.shape[-1]:
277
  ending = stage2_generate(model_stage2, prompt[:, output_duration * 50:], batch_size=1)
278
  output = np.concatenate([output, ending], axis=0)
279
+
280
+ # Convert Stage2 output tokens back to numpy using Stage2’s codec manipulator.
281
  output = codectool_stage2.ids2npy(output)
282
+
283
+ # Fix any invalid codes
284
  fixed_output = copy.deepcopy(output)
285
  for i, line in enumerate(output):
286
  for j, element in enumerate(line):
 
288
  counter = Counter(line)
289
  most_common = sorted(counter.items(), key=lambda x: x[1], reverse=True)[0][0]
290
  fixed_output[i, j] = most_common
291
+
292
  np.save(output_filename, fixed_output)
293
  stage2_result.append(output_filename)
294
  return stage2_result