Update app.py
Browse files
app.py
CHANGED
@@ -221,6 +221,10 @@ def stage2_inference(model, stage1_output_set, stage2_output_dir, batch_size=1):
|
|
221 |
output_duration = prompt.shape[-1] // 50 // 6 * 6
|
222 |
num_batch = output_duration // 6
|
223 |
|
|
|
|
|
|
|
|
|
224 |
if num_batch <= batch_size:
|
225 |
# If num_batch is less than or equal to batch_size, we can infer the entire prompt at once
|
226 |
output = stage2_generate(model, prompt[:, :output_duration*50], batch_size=num_batch)
|
|
|
221 |
output_duration = prompt.shape[-1] // 50 // 6 * 6
|
222 |
num_batch = output_duration // 6
|
223 |
|
224 |
+
if output_duration <= 0:
|
225 |
+
print(f'{output_filename} stage1 output is too short, skipping stage2.')
|
226 |
+
continue
|
227 |
+
|
228 |
if num_batch <= batch_size:
|
229 |
# If num_batch is less than or equal to batch_size, we can infer the entire prompt at once
|
230 |
output = stage2_generate(model, prompt[:, :output_duration*50], batch_size=num_batch)
|