KingNish committed on
Commit
c26a9f5
·
verified ·
1 Parent(s): 7d83b5a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -3
app.py CHANGED
@@ -246,10 +246,19 @@ def stage2_inference(model_stage2, stage1_output_set, stage2_output_dir, batch_s
246
  prompt = prompt[np.newaxis, :]
247
  print(f"Loaded prompt from {path} with shape: {prompt.shape}")
248
 
249
- # Compute output duration: tokens per second assumed to be 50, only full 6-second segments.
250
- output_duration = (prompt.shape[-1] // 50) // 6 * 6
 
 
 
 
 
 
 
 
251
  if output_duration == 0:
252
  raise ValueError(f"Output duration computed as 0 for {path}. Prompt length: {prompt.shape[-1]} tokens")
 
253
  num_batch = output_duration // 6
254
 
255
  # Process prompt in batches
@@ -280,7 +289,7 @@ def stage2_inference(model_stage2, stage1_output_set, stage2_output_dir, batch_s
280
  # Convert Stage2 output tokens back to numpy using Stage2’s codec manipulator.
281
  output = codectool_stage2.ids2npy(output)
282
 
283
- # Fix any invalid codes
284
  fixed_output = copy.deepcopy(output)
285
  for i, line in enumerate(output):
286
  for j, element in enumerate(line):
 
246
  prompt = prompt[np.newaxis, :]
247
  print(f"Loaded prompt from {path} with shape: {prompt.shape}")
248
 
249
+ # Compute total duration in seconds (assuming 50 tokens per second)
250
+ total_duration_sec = prompt.shape[-1] // 50
251
+ if total_duration_sec < 6:
252
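+ # Not enough tokens for a full 6-sec segment; use the prompt's whole-second duration instead. NOTE(review): num_batch = output_duration // 6 is then 0 — confirm the batching code below still processes the prompt.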
253
+ output_duration = total_duration_sec
254
+ print(f"Prompt too short for 6-sec segmentation. Using full duration: {output_duration} seconds.")
255
+ else:
256
+ output_duration = (total_duration_sec // 6) * 6
257
+
258
+ # If output_duration is still zero (i.e. the prompt has fewer than 50 tokens), raise an error.
259
  if output_duration == 0:
260
  raise ValueError(f"Output duration computed as 0 for {path}. Prompt length: {prompt.shape[-1]} tokens")
261
+
262
  num_batch = output_duration // 6
263
 
264
  # Process prompt in batches
 
289
  # Convert Stage2 output tokens back to numpy using Stage2’s codec manipulator.
290
  output = codectool_stage2.ids2npy(output)
291
 
292
+ # Fix any invalid codes (if needed)
293
  fixed_output = copy.deepcopy(output)
294
  for i, line in enumerate(output):
295
  for j, element in enumerate(line):