tsqn commited on
Commit
a79a4d5
·
verified ·
1 Parent(s): 487bc0b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -16
app.py CHANGED
@@ -159,25 +159,25 @@ pipe.to("cuda")
159
  # pipe.load_lora_weights("TODO/TODO", adapter_name="ltx-lora")
160
  # pipe.set_adapters(["lrx-lora"], adapter_weights=[1.0])
161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  @spaces.GPU(duration=120)
163
  @torch.inference_mode()
164
  def generate_video(prompt, negative_prompt, height, width, num_frames, num_inference_steps, fps, seed, progress=gr.Progress(track_tqdm=True)):
 
165
  progress_steps = []
166
-
167
- def setup_progressbar_length(_num_steps=num_inference_steps):
168
- for _step in bytes(range(_num_steps)):
169
- progress_steps.append(_step) # one step one byte - fq the logic
170
-
171
- def progress_step():
172
- if len(progress_steps) == 0:
173
- return
174
- for done_step in tqdm(enumerate(range(len(progress_steps)))):
175
- progress_steps.pop()
176
- if len(progress_steps) == 0:
177
- tqdm.close()
178
- break
179
-
180
- setup_progressbar_length()
181
 
182
  # Randomize seed if seed is 0
183
  if seed == 0:
@@ -194,7 +194,7 @@ def generate_video(prompt, negative_prompt, height, width, num_frames, num_infer
194
  num_frames=num_frames,
195
  num_inference_steps=num_inference_steps,
196
  generator=torch.Generator(device='cuda').manual_seed(seed),
197
- callback_on_step_end=progress_step
198
  ).frames[0]
199
 
200
  # Create output filename based on prompt and timestamp
@@ -233,6 +233,7 @@ with gr.Blocks() as demo:
233
 
234
  output_video = gr.Video(label="Generated Video", show_label=True)
235
  generate_button = gr.Button("Generate Video")
 
236
  save_state_button = gr.Button("Save State")
237
 
238
  random_seed_button.click(randomize_seed, outputs=seed)
@@ -241,6 +242,10 @@ with gr.Blocks() as demo:
241
  inputs=[prompt, negative_prompt, height, width, num_frames, num_inference_steps, fps, seed],
242
  outputs=output_video
243
  )
 
 
 
 
244
  save_state_button.click(
245
  save_ui_state,
246
  inputs=[prompt, negative_prompt, height, width, num_frames, num_inference_steps, fps, seed],
 
159
  # pipe.load_lora_weights("TODO/TODO", adapter_name="ltx-lora")
160
  # pipe.set_adapters(["lrx-lora"], adapter_weights=[1.0])
161
 
162
+ INTERRUPT_PIPELINE = False
163
+
164
+ def interrupt_inference():
165
+ INTERRUPT_PIPELINE = True
166
+
167
+ def interrupt_callback(pipeline, i, t, callback_kwargs):
168
+ stop_idx = 19
169
+ if i >= stop_idx:
170
+ pipeline._interrupt = False
171
+ return callback_kwargs
172
+
173
+ pipeline._interrupt = INTERRUPT_PIPELINE
174
+ return callback_kwargs
175
+
176
  @spaces.GPU(duration=120)
177
  @torch.inference_mode()
178
  def generate_video(prompt, negative_prompt, height, width, num_frames, num_inference_steps, fps, seed, progress=gr.Progress(track_tqdm=True)):
179
+ INTERRUPT_PIPELINE = False
180
  progress_steps = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
  # Randomize seed if seed is 0
183
  if seed == 0:
 
194
  num_frames=num_frames,
195
  num_inference_steps=num_inference_steps,
196
  generator=torch.Generator(device='cuda').manual_seed(seed),
197
+ callback_on_step_end=interrupt_callback
198
  ).frames[0]
199
 
200
  # Create output filename based on prompt and timestamp
 
233
 
234
  output_video = gr.Video(label="Generated Video", show_label=True)
235
  generate_button = gr.Button("Generate Video")
236
+ cancel_button = gr.Button("Cancel")
237
  save_state_button = gr.Button("Save State")
238
 
239
  random_seed_button.click(randomize_seed, outputs=seed)
 
242
  inputs=[prompt, negative_prompt, height, width, num_frames, num_inference_steps, fps, seed],
243
  outputs=output_video
244
  )
245
+ cancel_button.click(
246
+ interrupt_inference,
247
+ outputs=gr.Text(label="Interrupted.")
248
+ )
249
  save_state_button.click(
250
  save_ui_state,
251
  inputs=[prompt, negative_prompt, height, width, num_frames, num_inference_steps, fps, seed],