ford442 committed · Commit 00651e4 · verified · 1 Parent(s): bdb5517

Update app.py

Files changed (1): app.py (+26 −4)
app.py CHANGED
@@ -316,7 +316,13 @@ def scheduler_swap_callback(pipeline, step_index, timestep, callback_kwargs):
     #pipe.scheduler.set_timesteps(num_inference_steps*.70)
     # print(f"-- setting step {pipeline.num_timesteps * 0.9} --")
     # pipeline.scheduler._step_index = pipeline.num_timesteps * 0.9
-    return {"latents": callback_kwargs["latents"]}
+    if step_index == int(pipeline.num_timesteps * self.config.cutoff_step_ratio):
+        prompt_embeds = callback_kwargs["prompt_embeds"]
+        prompt_embeds = prompt_embeds.chunk(2)[-1]
+        # update guidance_scale and prompt_embeds
+        pipeline._guidance_scale = new_scale
+        callback_kwargs["prompt_embeds"] = prompt_embeds
+    return callback_kwargs

 def upload_to_ftp(filename):
     try:
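
As committed, this callback body would fail at runtime: it is a plain module-level function, so `self.config.cutoff_step_ratio` and the unbound `new_scale` both raise NameError, and the call sites below invoke it as a factory, `pyx.scheduler_swap_callback(cutoff_step_ratio=guidance_cutoff)`. A minimal sketch of a working equivalent, modeled on the diffusers step-end-callback pattern for dynamic classifier-free guidance; the `new_scale=0.0` default is an assumption, not taken from this commit:

    def scheduler_swap_callback(cutoff_step_ratio=1.0, new_scale=0.0):
        # Factory returning a diffusers-style step-end callback.
        def _callback(pipeline, step_index, timestep, callback_kwargs):
            cutoff_step = int(pipeline.num_timesteps * cutoff_step_ratio)
            if step_index == cutoff_step:
                # Keep only the conditional half of the CFG batch ...
                prompt_embeds = callback_kwargs["prompt_embeds"]
                callback_kwargs["prompt_embeds"] = prompt_embeds.chunk(2)[-1]
                # ... and lower the guidance scale for the remaining steps.
                pipeline._guidance_scale = new_scale
            return callback_kwargs
        return _callback

Note that `step_index` runs from 0 to `num_timesteps - 1`, so with the slider default of 1.0 the cutoff equals `num_timesteps` and the swap never fires; ratios below 1.0 cut guidance off partway through sampling.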
@@ -359,6 +365,7 @@ def generate_30(
     height: int = 768,
     guidance_scale: float = 4,
+    guidance_cutoff: float = 1.0,
     num_inference_steps: int = 125,
     use_resolution_binning: bool = True,
     progress=gr.Progress(track_tqdm=True)  # Add progress as a keyword argument
 ):
@@ -374,7 +381,8 @@ def generate_30(
         "num_inference_steps": num_inference_steps,
         "generator": generator,
         "output_type": "pil",
-        "callback_on_step_end": pyx.scheduler_swap_callback
+        "callback_on_step_end": pyx.scheduler_swap_callback(cutoff_step_ratio=guidance_cutoff),
+        "callback_on_step_end_tensor_inputs": ['prompt_embeds'],
     }
     if use_resolution_binning:
         options["use_resolution_binning"] = True
@@ -412,6 +420,7 @@ def generate_60(
     height: int = 768,
     guidance_scale: float = 4,
+    guidance_cutoff: float = 1.0,
     num_inference_steps: int = 125,
     use_resolution_binning: bool = True,
     progress=gr.Progress(track_tqdm=True)  # Add progress as a keyword argument
 ):
@@ -427,7 +436,8 @@ def generate_60(
         "num_inference_steps": num_inference_steps,
         "generator": generator,
         "output_type": "pil",
-        "callback_on_step_end": scheduler_swap_callback
+        "callback_on_step_end": pyx.scheduler_swap_callback(cutoff_step_ratio=guidance_cutoff),
+        "callback_on_step_end_tensor_inputs": ['prompt_embeds'],
     }
     if use_resolution_binning:
         options["use_resolution_binning"] = True
@@ -455,6 +465,7 @@ def generate_90(
     height: int = 768,
     guidance_scale: float = 4,
+    guidance_cutoff: float = 1.0,
     num_inference_steps: int = 125,
     use_resolution_binning: bool = True,
     progress=gr.Progress(track_tqdm=True)  # Add progress as a keyword argument
 ):
@@ -470,7 +481,8 @@ def generate_90(
         "num_inference_steps": num_inference_steps,
         "generator": generator,
         "output_type": "pil",
-        "callback_on_step_end": scheduler_swap_callback
+        "callback_on_step_end": pyx.scheduler_swap_callback(cutoff_step_ratio=guidance_cutoff),
+        "callback_on_step_end_tensor_inputs": ['prompt_embeds'],
     }
     if use_resolution_binning:
         options["use_resolution_binning"] = True
@@ -575,6 +587,13 @@ with gr.Blocks(theme=gr.themes.Origin(),css=css) as demo:
         step=0.1,
         value=3.8,
     )
+    guidance_cutoff = gr.Slider(
+        label="Guidance Scale Cutoff",
+        minimum=0.01,
+        maximum=1.0,
+        step=0.01,
+        value=1.0,
+    )
     num_inference_steps = gr.Slider(
         label="Number of inference steps",
         minimum=10,
@@ -610,6 +629,7 @@ with gr.Blocks(theme=gr.themes.Origin(),css=css) as demo:
             width,
             height,
             guidance_scale,
+            guidance_cutoff,
             num_inference_steps,
         ],
         outputs=[result],
@@ -629,6 +649,7 @@ with gr.Blocks(theme=gr.themes.Origin(),css=css) as demo:
             width,
             height,
             guidance_scale,
+            guidance_cutoff,
             num_inference_steps,
         ],
         outputs=[result],
@@ -648,6 +669,7 @@ with gr.Blocks(theme=gr.themes.Origin(),css=css) as demo:
             width,
             height,
             guidance_scale,
+            guidance_cutoff,
             num_inference_steps,
         ],
         outputs=[result],
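
Gradio binds the `inputs` list to the handler's parameters positionally, which is why `guidance_cutoff` must occupy the same slot in the list and in each `generate_*` signature (the signatures above are ordered to match). A self-contained sketch of the wiring pattern, with hypothetical component names and a placeholder handler:

    import gradio as gr

    def generate(prompt, guidance_scale, guidance_cutoff, num_inference_steps):
        # placeholder body; the real app calls the diffusers pipeline here
        return f"scale={guidance_scale} cutoff={guidance_cutoff} steps={num_inference_steps}"

    with gr.Blocks() as demo:
        prompt = gr.Text(label="Prompt")
        guidance_scale = gr.Slider(0.0, 20.0, value=3.8, step=0.1, label="Guidance Scale")
        guidance_cutoff = gr.Slider(0.01, 1.0, value=1.0, step=0.01, label="Guidance Scale Cutoff")
        num_inference_steps = gr.Slider(10, 250, value=125, step=1, label="Number of inference steps")
        result = gr.Text(label="Result")
        run = gr.Button("Generate")
        run.click(
            generate,
            inputs=[prompt, guidance_scale, guidance_cutoff, num_inference_steps],
            outputs=[result],
        )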
 