jamesliu1217 commited on
Commit
b95773f
·
verified ·
1 Parent(s): 45f491c
Files changed (1) hide show
  1. src/pipeline.py +24 -1
src/pipeline.py CHANGED
@@ -656,7 +656,30 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
656
  guidance = guidance.expand(latents.shape[0])
657
  else:
658
  guidance = None
659
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
660
  # 6. Denoising loop
661
  with self.progress_bar(total=num_inference_steps) as progress_bar:
662
  for i, t in enumerate(timesteps):
 
656
  guidance = guidance.expand(latents.shape[0])
657
  else:
658
  guidance = None
659
+
660
+ ## Caching conditions
661
+ # clean the cache
662
+ for name, attn_processor in self.transformer.attn_processors.items():
663
+ attn_processor.bank_kv.clear()
664
+ # cache with warmup latents
665
+ start_idx = latents.shape[1] - 32
666
+ warmup_latents = latents[:, start_idx:, :]
667
+ warmup_latent_ids = latent_image_ids[start_idx:, :]
668
+ t = torch.tensor([timesteps[0]], device=device)
669
+ timestep = t.expand(warmup_latents.shape[0]).to(latents.dtype)
670
+ _ = self.transformer(
671
+ hidden_states=warmup_latents,
672
+ cond_hidden_states=cond_latents,
673
+ timestep=timestep/ 1000,
674
+ guidance=guidance,
675
+ pooled_projections=pooled_prompt_embeds,
676
+ encoder_hidden_states=prompt_embeds,
677
+ txt_ids=text_ids,
678
+ img_ids=warmup_latent_ids,
679
+ joint_attention_kwargs=self.joint_attention_kwargs,
680
+ return_dict=False,
681
+ )[0]
682
+
683
  # 6. Denoising loop
684
  with self.progress_bar(total=num_inference_steps) as progress_bar:
685
  for i, t in enumerate(timesteps):