Update app.py
Browse files
app.py
CHANGED
@@ -246,10 +246,19 @@ def stage2_inference(model_stage2, stage1_output_set, stage2_output_dir, batch_s
|
|
246 |
prompt = prompt[np.newaxis, :]
|
247 |
print(f"Loaded prompt from {path} with shape: {prompt.shape}")
|
248 |
|
249 |
-
# Compute
|
250 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
251 |
if output_duration == 0:
|
252 |
raise ValueError(f"Output duration computed as 0 for {path}. Prompt length: {prompt.shape[-1]} tokens")
|
|
|
253 |
num_batch = output_duration // 6
|
254 |
|
255 |
# Process prompt in batches
|
@@ -280,7 +289,7 @@ def stage2_inference(model_stage2, stage1_output_set, stage2_output_dir, batch_s
|
|
280 |
# Convert Stage2 output tokens back to numpy using Stage2’s codec manipulator.
|
281 |
output = codectool_stage2.ids2npy(output)
|
282 |
|
283 |
-
# Fix any invalid codes
|
284 |
fixed_output = copy.deepcopy(output)
|
285 |
for i, line in enumerate(output):
|
286 |
for j, element in enumerate(line):
|
|
|
246 |
prompt = prompt[np.newaxis, :]
|
247 |
print(f"Loaded prompt from {path} with shape: {prompt.shape}")
|
248 |
|
249 |
+
# Compute total duration in seconds (assuming 50 tokens per second)
|
250 |
+
total_duration_sec = prompt.shape[-1] // 50
|
251 |
+
if total_duration_sec < 6:
|
252 |
+
# Not enough tokens for a full 6-sec segment; use the entire prompt.
|
253 |
+
output_duration = total_duration_sec
|
254 |
+
print(f"Prompt too short for 6-sec segmentation. Using full duration: {output_duration} seconds.")
|
255 |
+
else:
|
256 |
+
output_duration = (total_duration_sec // 6) * 6
|
257 |
+
|
258 |
+
# If after the above, output_duration is still zero, raise an error.
|
259 |
if output_duration == 0:
|
260 |
raise ValueError(f"Output duration computed as 0 for {path}. Prompt length: {prompt.shape[-1]} tokens")
|
261 |
+
|
262 |
num_batch = output_duration // 6
|
263 |
|
264 |
# Process prompt in batches
|
|
|
289 |
# Convert Stage2 output tokens back to numpy using Stage2’s codec manipulator.
|
290 |
output = codectool_stage2.ids2npy(output)
|
291 |
|
292 |
+
# Fix any invalid codes (if needed)
|
293 |
fixed_output = copy.deepcopy(output)
|
294 |
for i, line in enumerate(output):
|
295 |
for j, element in enumerate(line):
|