Update app.py
app.py CHANGED
@@ -159,25 +159,25 @@ pipe.to("cuda")
 # pipe.load_lora_weights("TODO/TODO", adapter_name="ltx-lora")
 # pipe.set_adapters(["lrx-lora"], adapter_weights=[1.0])
 
+INTERRUPT_PIPELINE = False
+
+def interrupt_inference():
+    INTERRUPT_PIPELINE = True
+
+def interrupt_callback(pipeline, i, t, callback_kwargs):
+    stop_idx = 19
+    if i >= stop_idx:
+        pipeline._interrupt = False
+        return callback_kwargs
+
+    pipeline._interrupt = INTERRUPT_PIPELINE
+    return callback_kwargs
+
 @spaces.GPU(duration=120)
 @torch.inference_mode()
 def generate_video(prompt, negative_prompt, height, width, num_frames, num_inference_steps, fps, seed, progress=gr.Progress(track_tqdm=True)):
+    INTERRUPT_PIPELINE = False
     progress_steps = []
-
-    def setup_progressbar_length(_num_steps=num_inference_steps):
-        for _step in bytes(range(_num_steps)):
-            progress_steps.append(_step) # one step one byte - fq the logic
-
-    def progress_step():
-        if len(progress_steps) == 0:
-            return
-        for done_step in tqdm(enumerate(range(len(progress_steps)))):
-            progress_steps.pop()
-            if len(progress_steps) == 0:
-                tqdm.close()
-                break
-
-    setup_progressbar_length()
 
     # Randomize seed if seed is 0
     if seed == 0:
@@ -194,7 +194,7 @@ def generate_video(prompt, negative_prompt, height, width, num_frames, num_infer
         num_frames=num_frames,
         num_inference_steps=num_inference_steps,
         generator=torch.Generator(device='cuda').manual_seed(seed),
-        callback_on_step_end=
+        callback_on_step_end=interrupt_callback
     ).frames[0]
 
     # Create output filename based on prompt and timestamp
@@ -233,6 +233,7 @@ with gr.Blocks() as demo:
 
     output_video = gr.Video(label="Generated Video", show_label=True)
     generate_button = gr.Button("Generate Video")
+    cancel_button = gr.Button("Cancel")
     save_state_button = gr.Button("Save State")
 
    random_seed_button.click(randomize_seed, outputs=seed)
@@ -241,6 +242,10 @@ with gr.Blocks() as demo:
         inputs=[prompt, negative_prompt, height, width, num_frames, num_inference_steps, fps, seed],
         outputs=output_video
     )
+    cancel_button.click(
+        interrupt_inference,
+        outputs=gr.Text(label="Interrupted.")
+    )
     save_state_button.click(
         save_ui_state,
         inputs=[prompt, negative_prompt, height, width, num_frames, num_inference_steps, fps, seed],
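A note on the new interrupt flag: as committed, `interrupt_inference` assigns to a local `INTERRUPT_PIPELINE`, so the module-level flag is never modified and the callback always reads `False`; the `INTERRUPT_PIPELINE = False` reset inside `generate_video` likewise only shadows the global. Below is a minimal sketch of the same flag handling with the missing `global` declarations added and without the hard-coded `stop_idx = 19` guard (which force-clears the interrupt after step 19). It follows the interruption pattern from the diffusers callback docs, which relies on the pipeline's denoising loop checking `self._interrupt`; whether the LTX pipeline honors that attribute is an assumption worth verifying.

INTERRUPT_PIPELINE = False

def interrupt_inference():
    # Mutate the module-level flag; without `global`, this assignment
    # would only bind a local name and the flag would stay False.
    global INTERRUPT_PIPELINE
    INTERRUPT_PIPELINE = True

def interrupt_callback(pipeline, i, t, callback_kwargs):
    # diffusers invokes callback_on_step_end after each denoising step.
    # Propagating the flag to pipeline._interrupt lets a pipeline that
    # checks self._interrupt skip its remaining steps.
    pipeline._interrupt = INTERRUPT_PIPELINE
    return callback_kwargs

`generate_video` would need the same `global INTERRUPT_PIPELINE` declaration before its reset for the clear-on-start behavior to take effect.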
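On the UI side, `cancel_button.click(interrupt_inference, outputs=gr.Text(label="Interrupted."))` constructs a new `gr.Text` inline and wires it to a handler that returns `None`, so nothing visible updates when Cancel is pressed. Below is a sketch of one alternative: declare the status box in the layout, return a string from the handler, and also pass Gradio's `cancels=` so the queued generate event itself is aborted (cancellation requires the queue, which recent Gradio enables by default). `fake_generate` is a hypothetical stand-in for `generate_video` so the sketch runs on its own.

import time
import gradio as gr

INTERRUPT_PIPELINE = False

def interrupt_inference():
    global INTERRUPT_PIPELINE
    INTERRUPT_PIPELINE = True
    return "Interrupted."  # becomes the status box's new value

def fake_generate(prompt):
    # Hypothetical stand-in for generate_video: loops so Cancel has
    # something to stop, checking the flag like the step callback would.
    for _ in range(100):
        if INTERRUPT_PIPELINE:
            break
        time.sleep(0.1)
    return None

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    output_video = gr.Video(label="Generated Video")
    generate_button = gr.Button("Generate Video")
    cancel_button = gr.Button("Cancel")
    status = gr.Text(label="Status")  # declared once in the layout

    # Keep a handle on the generate event so Cancel can also abort it
    # at the queue level, independent of the in-loop flag.
    generate_event = generate_button.click(
        fake_generate, inputs=prompt, outputs=output_video
    )
    cancel_button.click(
        interrupt_inference, outputs=status, cancels=[generate_event]
    )

demo.launch()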