add cache

src/pipeline.py CHANGED (+24 -1)
@@ -656,7 +656,30 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
             guidance = guidance.expand(latents.shape[0])
         else:
             guidance = None
-
+
+        ## Caching conditions
+        # clean the cache
+        for name, attn_processor in self.transformer.attn_processors.items():
+            attn_processor.bank_kv.clear()
+        # cache with warmup latents
+        start_idx = latents.shape[1] - 32
+        warmup_latents = latents[:, start_idx:, :]
+        warmup_latent_ids = latent_image_ids[start_idx:, :]
+        t = torch.tensor([timesteps[0]], device=device)
+        timestep = t.expand(warmup_latents.shape[0]).to(latents.dtype)
+        _ = self.transformer(
+            hidden_states=warmup_latents,
+            cond_hidden_states=cond_latents,
+            timestep=timestep / 1000,
+            guidance=guidance,
+            pooled_projections=pooled_prompt_embeds,
+            encoder_hidden_states=prompt_embeds,
+            txt_ids=text_ids,
+            img_ids=warmup_latent_ids,
+            joint_attention_kwargs=self.joint_attention_kwargs,
+            return_dict=False,
+        )[0]
+
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
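The added block does two things before the denoising loop: it empties each attention processor's bank_kv, then runs one forward pass over the last 32 latent tokens at the first timestep so the processors can re-fill the bank with the condition keys/values. The processor class itself is not part of this diff, so the following is a minimal sketch, assuming bank_kv is a plain list that is populated on the first call and reused afterwards; the class name, tensor shapes, and call signature are illustrative, not the Space's actual attention-processor API.

import torch

class KVBankAttnProcessor:
    """Hypothetical processor with a clearable K/V bank (illustrative only)."""

    def __init__(self):
        # filled on the warmup call, reused on every denoising step,
        # and cleared by the pipeline before the next generation
        self.bank_kv = []

    def __call__(self, query, key, value):
        if not self.bank_kv:
            # warmup pass: cache the condition key/value projections
            self.bank_kv.extend([key, value])
        else:
            # denoising steps: reuse the cached condition K/V
            key, value = self.bank_kv
        scores = query @ key.transpose(-1, -2) / key.shape[-1] ** 0.5
        return torch.softmax(scores, dim=-1) @ value

# toy usage: the first call fills the bank, later calls hit the cache
proc = KVBankAttnProcessor()
q = torch.randn(1, 4, 8)
_ = proc(q, torch.randn(1, 4, 8), torch.randn(1, 4, 8))    # warmup: caches K/V
out = proc(q, torch.randn(1, 4, 8), torch.randn(1, 4, 8))  # reuses cached K/V
proc.bank_kv.clear()  # what the pipeline's "clean the cache" loop does

Under this assumption, the condition K/V is computed once during the warmup call rather than at every denoising step, which is presumably why the new block runs before "# 6. Denoising loop".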