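"""First Block Cache for WaveSpeed.

Speeds up diffusion sampling by caching the output of the deep
transformer/UNet blocks between denoising steps.  Each step runs only the
first block; if its output residual is close enough (relative L1 distance
below a threshold) to the residual recorded on the previous step, the
cached residuals of the remaining blocks are reapplied instead of
recomputing them.
"""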
import contextlib
import dataclasses
import unittest.mock
from collections import defaultdict
from typing import DefaultDict, Dict
import torch
from modules.AutoEncoders.ResBlock import forward_timestep_embed1
from modules.NeuralNetwork.unet import apply_control1
from modules.sample.sampling_util import timestep_embedding
@dataclasses.dataclass
class CacheContext:
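    """Per-generation store of cached tensors, keyed by buffer name."""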
buffers: Dict[str, torch.Tensor] = dataclasses.field(default_factory=dict)
incremental_name_counters: DefaultDict[str, int] = dataclasses.field(
default_factory=lambda: defaultdict(int))
def get_incremental_name(self, name=None):
if name is None:
name = "default"
idx = self.incremental_name_counters[name]
self.incremental_name_counters[name] += 1
return f"{name}_{idx}"
def reset_incremental_names(self):
self.incremental_name_counters.clear()
@torch.compiler.disable()
def get_buffer(self, name):
return self.buffers.get(name)
@torch.compiler.disable()
def set_buffer(self, name, buffer):
self.buffers[name] = buffer
def clear_buffers(self):
self.buffers.clear()
@torch.compiler.disable()
def get_buffer(name):
cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before calling get_buffer"
return cache_context.get_buffer(name)
@torch.compiler.disable()
def set_buffer(name, buffer):
cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before calling set_buffer"
cache_context.set_buffer(name, buffer)
_current_cache_context = None
def create_cache_context():
return CacheContext()
def get_current_cache_context():
return _current_cache_context
def set_current_cache_context(cache_context=None):
global _current_cache_context
_current_cache_context = cache_context
@contextlib.contextmanager
def cache_context(cache_context):
global _current_cache_context
old_cache_context = _current_cache_context
_current_cache_context = cache_context
try:
yield
finally:
_current_cache_context = old_cache_context
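# A minimal usage sketch (the sampler loop shown here is illustrative,
# not part of this module):
#
#     ctx = create_cache_context()
#     with cache_context(ctx):
#         for step in sampler_steps:
#             out = model(x, t)  # patched forward consults the cache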
# def patch_get_output_data():
# import execution
# get_output_data = getattr(execution, "get_output_data", None)
# if get_output_data is None:
# return
# if getattr(get_output_data, "_patched", False):
# return
# def new_get_output_data(*args, **kwargs):
# out = get_output_data(*args, **kwargs)
# cache_context = get_current_cache_context()
# if cache_context is not None:
# cache_context.clear_buffers()
# set_current_cache_context(None)
# return out
# new_get_output_data._patched = True
# execution.get_output_data = new_get_output_data
@torch.compiler.disable()
def are_two_tensors_similar(t1, t2, *, threshold):
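    """Relative-L1 similarity test that drives the cache decision.

    Computes mean(|t1 - t2|) / mean(|t1|) and returns True when it is
    below ``threshold``.  Tensors with mismatched shapes are never
    considered similar.
    """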
if t1.shape != t2.shape:
return False
mean_diff = (t1 - t2).abs().mean()
mean_t1 = t1.abs().mean()
diff = mean_diff / mean_t1
return diff.item() < threshold
@torch.compiler.disable()
def apply_prev_hidden_states_residual(hidden_states,
encoder_hidden_states=None):
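    """Reapply the cached deep-block residuals to the first-block output.

    Returns only ``hidden_states`` when ``encoder_hidden_states`` is None,
    otherwise a ``(hidden_states, encoder_hidden_states)`` pair.
    """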
hidden_states_residual = get_buffer("hidden_states_residual")
    assert hidden_states_residual is not None, "hidden_states_residual must be cached before it can be applied"
hidden_states = hidden_states_residual + hidden_states
hidden_states = hidden_states.contiguous()
if encoder_hidden_states is None:
return hidden_states
encoder_hidden_states_residual = get_buffer(
"encoder_hidden_states_residual")
if encoder_hidden_states_residual is None:
encoder_hidden_states = None
else:
encoder_hidden_states = encoder_hidden_states_residual + encoder_hidden_states
encoder_hidden_states = encoder_hidden_states.contiguous()
return hidden_states, encoder_hidden_states
@torch.compiler.disable()
def get_can_use_cache(first_hidden_states_residual,
threshold,
parallelized=False):
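    """Decide whether the cached deep-block output can be reused.

    True when a first-block residual from a previous step exists and is
    similar (see are_two_tensors_similar) to the incoming one.  The
    ``parallelized`` flag is accepted for signature compatibility but is
    unused here.
    """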
prev_first_hidden_states_residual = get_buffer(
"first_hidden_states_residual")
can_use_cache = prev_first_hidden_states_residual is not None and are_two_tensors_similar(
prev_first_hidden_states_residual,
first_hidden_states_residual,
threshold=threshold,
)
return can_use_cache
class CachedTransformerBlocks(torch.nn.Module):
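    """Drop-in wrapper that applies First Block Cache to a block list.

    Each forward pass always executes ``transformer_blocks[0]``.  The
    residual it produces is compared against the one recorded on the
    previous step; on a hit, the cached residuals of the remaining blocks
    (and ``single_transformer_blocks``, if given) are reapplied instead of
    running them.  The ``*_first`` flags adapt the wrapper to backbones
    that differ in argument order, return order, and concatenation order
    of the image and text streams.
    """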
def __init__(
self,
transformer_blocks,
single_transformer_blocks=None,
*,
residual_diff_threshold,
validate_can_use_cache_function=None,
return_hidden_states_first=True,
accept_hidden_states_first=True,
cat_hidden_states_first=False,
return_hidden_states_only=False,
clone_original_hidden_states=False,
):
super().__init__()
self.transformer_blocks = transformer_blocks
self.single_transformer_blocks = single_transformer_blocks
self.residual_diff_threshold = residual_diff_threshold
self.validate_can_use_cache_function = validate_can_use_cache_function
self.return_hidden_states_first = return_hidden_states_first
self.accept_hidden_states_first = accept_hidden_states_first
self.cat_hidden_states_first = cat_hidden_states_first
self.return_hidden_states_only = return_hidden_states_only
self.clone_original_hidden_states = clone_original_hidden_states
def forward(self, *args, **kwargs):
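        # Backbones disagree on the names and order of the image/text
        # streams, so normalize the inputs to (img, txt) before anything
        # else.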
img_arg_name = None
if "img" in kwargs:
img_arg_name = "img"
elif "hidden_states" in kwargs:
img_arg_name = "hidden_states"
txt_arg_name = None
if "txt" in kwargs:
txt_arg_name = "txt"
elif "context" in kwargs:
txt_arg_name = "context"
elif "encoder_hidden_states" in kwargs:
txt_arg_name = "encoder_hidden_states"
if self.accept_hidden_states_first:
if args:
img = args[0]
args = args[1:]
else:
img = kwargs.pop(img_arg_name)
if args:
txt = args[0]
args = args[1:]
else:
txt = kwargs.pop(txt_arg_name)
else:
if args:
txt = args[0]
args = args[1:]
else:
txt = kwargs.pop(txt_arg_name)
if args:
img = args[0]
args = args[1:]
else:
img = kwargs.pop(img_arg_name)
hidden_states = img
encoder_hidden_states = txt
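        # A non-positive threshold disables caching: fall through to a
        # plain sequential pass over all blocks.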
if self.residual_diff_threshold <= 0.0:
for block in self.transformer_blocks:
if txt_arg_name == "encoder_hidden_states":
hidden_states = block(
hidden_states,
*args,
encoder_hidden_states=encoder_hidden_states,
**kwargs)
else:
if self.accept_hidden_states_first:
hidden_states = block(hidden_states,
encoder_hidden_states, *args,
**kwargs)
else:
hidden_states = block(encoder_hidden_states,
hidden_states, *args, **kwargs)
if not self.return_hidden_states_only:
hidden_states, encoder_hidden_states = hidden_states
if not self.return_hidden_states_first:
hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
if self.single_transformer_blocks is not None:
hidden_states = torch.cat(
[hidden_states, encoder_hidden_states]
if self.cat_hidden_states_first else
[encoder_hidden_states, hidden_states],
dim=1)
for block in self.single_transformer_blocks:
hidden_states = block(hidden_states, *args, **kwargs)
                if self.cat_hidden_states_first:
                    hidden_states = hidden_states[
                        :, :-encoder_hidden_states.shape[1]]
                else:
                    hidden_states = hidden_states[
                        :, encoder_hidden_states.shape[1]:]
if self.return_hidden_states_only:
return hidden_states
else:
return ((hidden_states, encoder_hidden_states)
if self.return_hidden_states_first else
(encoder_hidden_states, hidden_states))
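        # Caching path: run only the first block and use its residual to
        # decide whether the cached output of the remaining blocks applies.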
original_hidden_states = hidden_states
if self.clone_original_hidden_states:
original_hidden_states = original_hidden_states.clone()
first_transformer_block = self.transformer_blocks[0]
if txt_arg_name == "encoder_hidden_states":
hidden_states = first_transformer_block(
hidden_states,
*args,
encoder_hidden_states=encoder_hidden_states,
**kwargs)
else:
if self.accept_hidden_states_first:
hidden_states = first_transformer_block(
hidden_states, encoder_hidden_states, *args, **kwargs)
else:
hidden_states = first_transformer_block(
encoder_hidden_states, hidden_states, *args, **kwargs)
if not self.return_hidden_states_only:
hidden_states, encoder_hidden_states = hidden_states
if not self.return_hidden_states_first:
hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
first_hidden_states_residual = hidden_states - original_hidden_states
del original_hidden_states
can_use_cache = get_can_use_cache(
first_hidden_states_residual,
threshold=self.residual_diff_threshold,
)
if self.validate_can_use_cache_function is not None:
can_use_cache = self.validate_can_use_cache_function(can_use_cache)
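        # The branch below is data dependent, so break the dynamo graph
        # here instead of letting torch.compile trace through it.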
torch._dynamo.graph_break()
if can_use_cache:
del first_hidden_states_residual
hidden_states, encoder_hidden_states = apply_prev_hidden_states_residual(
hidden_states, encoder_hidden_states)
else:
set_buffer("first_hidden_states_residual",
first_hidden_states_residual)
del first_hidden_states_residual
(
hidden_states,
encoder_hidden_states,
hidden_states_residual,
encoder_hidden_states_residual,
) = self.call_remaining_transformer_blocks(
hidden_states,
encoder_hidden_states,
*args,
txt_arg_name=txt_arg_name,
**kwargs)
set_buffer("hidden_states_residual", hidden_states_residual)
if encoder_hidden_states_residual is not None:
set_buffer("encoder_hidden_states_residual",
encoder_hidden_states_residual)
torch._dynamo.graph_break()
if self.return_hidden_states_only:
return hidden_states
else:
return ((hidden_states, encoder_hidden_states)
if self.return_hidden_states_first else
(encoder_hidden_states, hidden_states))
def call_remaining_transformer_blocks(self,
hidden_states,
encoder_hidden_states,
*args,
txt_arg_name=None,
**kwargs):
original_hidden_states = hidden_states
original_encoder_hidden_states = encoder_hidden_states
if self.clone_original_hidden_states:
original_hidden_states = original_hidden_states.clone()
original_encoder_hidden_states = original_encoder_hidden_states.clone(
)
for block in self.transformer_blocks[1:]:
if txt_arg_name == "encoder_hidden_states":
hidden_states = block(
hidden_states,
*args,
encoder_hidden_states=encoder_hidden_states,
**kwargs)
else:
if self.accept_hidden_states_first:
hidden_states = block(hidden_states, encoder_hidden_states,
*args, **kwargs)
else:
hidden_states = block(encoder_hidden_states, hidden_states,
*args, **kwargs)
if not self.return_hidden_states_only:
hidden_states, encoder_hidden_states = hidden_states
if not self.return_hidden_states_first:
hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
if self.single_transformer_blocks is not None:
hidden_states = torch.cat([hidden_states, encoder_hidden_states]
if self.cat_hidden_states_first else
[encoder_hidden_states, hidden_states],
dim=1)
for block in self.single_transformer_blocks:
hidden_states = block(hidden_states, *args, **kwargs)
if self.cat_hidden_states_first:
hidden_states, encoder_hidden_states = hidden_states.split(
[
hidden_states.shape[1] -
encoder_hidden_states.shape[1],
encoder_hidden_states.shape[1]
],
dim=1)
else:
encoder_hidden_states, hidden_states = hidden_states.split(
[
encoder_hidden_states.shape[1],
hidden_states.shape[1] - encoder_hidden_states.shape[1]
],
dim=1)
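        # Normalize both streams to contiguous memory before computing the
        # residuals that get stored in the cache.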
hidden_states_shape = hidden_states.shape
hidden_states = hidden_states.flatten().contiguous().reshape(
hidden_states_shape)
if encoder_hidden_states is not None:
encoder_hidden_states_shape = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.flatten().contiguous(
).reshape(encoder_hidden_states_shape)
hidden_states_residual = hidden_states - original_hidden_states
if encoder_hidden_states is None:
encoder_hidden_states_residual = None
else:
encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states
return hidden_states, encoder_hidden_states, hidden_states_residual, encoder_hidden_states_residual
# Based on 90f349f93df3083a507854d7fc7c3e1bb9014e24
def create_patch_unet_model__forward(model,
*,
residual_diff_threshold,
validate_can_use_cache_function=None):
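    """Build a context-manager patch for a UNet's ``_forward``.

    The patched forward always runs input blocks 0 and 1; block 1's
    output residual keys the cache.  On a hit the remaining input,
    middle, and output blocks are skipped and the cached residual is
    reapplied.
    """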
def call_remaining_blocks(self, transformer_options, control,
transformer_patches, hs, h, *args, **kwargs):
original_hidden_states = h
for id, module in enumerate(self.input_blocks):
if id < 2:
continue
transformer_options["block"] = ("input", id)
h = forward_timestep_embed1(module, h, *args, **kwargs)
h = apply_control1(h, control, 'input')
if "input_block_patch" in transformer_patches:
patch = transformer_patches["input_block_patch"]
for p in patch:
h = p(h, transformer_options)
hs.append(h)
if "input_block_patch_after_skip" in transformer_patches:
patch = transformer_patches["input_block_patch_after_skip"]
for p in patch:
h = p(h, transformer_options)
transformer_options["block"] = ("middle", 0)
if self.middle_block is not None:
h = forward_timestep_embed1(self.middle_block, h, *args, **kwargs)
h = apply_control1(h, control, 'middle')
for id, module in enumerate(self.output_blocks):
transformer_options["block"] = ("output", id)
hsp = hs.pop()
hsp = apply_control1(hsp, control, 'output')
if "output_block_patch" in transformer_patches:
patch = transformer_patches["output_block_patch"]
for p in patch:
h, hsp = p(h, hsp, transformer_options)
h = torch.cat([h, hsp], dim=1)
del hsp
if len(hs) > 0:
output_shape = hs[-1].shape
else:
output_shape = None
h = forward_timestep_embed1(module, h, *args, output_shape,
**kwargs)
hidden_states_residual = h - original_hidden_states
return h, hidden_states_residual
def unet_model__forward(self,
x,
timesteps=None,
context=None,
y=None,
control=None,
transformer_options={},
**kwargs):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param context: conditioning plugged in via crossattn
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs.
"""
transformer_options["original_shape"] = list(x.shape)
transformer_options["transformer_index"] = 0
transformer_patches = transformer_options.get("patches", {})
num_video_frames = kwargs.get("num_video_frames",
self.default_num_video_frames)
image_only_indicator = kwargs.get("image_only_indicator", None)
time_context = kwargs.get("time_context", None)
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
hs = []
t_emb = timestep_embedding(timesteps,
self.model_channels,
repeat_only=False).to(x.dtype)
emb = self.time_embed(t_emb)
if "emb_patch" in transformer_patches:
patch = transformer_patches["emb_patch"]
for p in patch:
emb = p(emb, self.model_channels, transformer_options)
if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)
can_use_cache = False
h = x
for id, module in enumerate(self.input_blocks):
if id >= 2:
break
transformer_options["block"] = ("input", id)
if id == 1:
original_h = h
h = forward_timestep_embed1(
module,
h,
emb,
context,
transformer_options,
time_context=time_context,
num_video_frames=num_video_frames,
image_only_indicator=image_only_indicator)
h = apply_control1(h, control, 'input')
if "input_block_patch" in transformer_patches:
patch = transformer_patches["input_block_patch"]
for p in patch:
h = p(h, transformer_options)
hs.append(h)
if "input_block_patch_after_skip" in transformer_patches:
patch = transformer_patches["input_block_patch_after_skip"]
for p in patch:
h = p(h, transformer_options)
if id == 1:
first_hidden_states_residual = h - original_h
can_use_cache = get_can_use_cache(
first_hidden_states_residual,
threshold=residual_diff_threshold,
)
if validate_can_use_cache_function is not None:
can_use_cache = validate_can_use_cache_function(
can_use_cache)
if not can_use_cache:
set_buffer("first_hidden_states_residual",
first_hidden_states_residual)
del first_hidden_states_residual
torch._dynamo.graph_break()
if can_use_cache:
h = apply_prev_hidden_states_residual(h)
else:
h, hidden_states_residual = call_remaining_blocks(
self,
transformer_options,
control,
transformer_patches,
hs,
h,
emb,
context,
transformer_options,
time_context=time_context,
num_video_frames=num_video_frames,
image_only_indicator=image_only_indicator)
set_buffer("hidden_states_residual", hidden_states_residual)
torch._dynamo.graph_break()
h = h.type(x.dtype)
if self.predict_codebook_ids:
return self.id_predictor(h)
else:
return self.out(h)
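    # Bind the replacement function as a method of this model instance;
    # the context manager below swaps it in only while active.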
new__forward = unet_model__forward.__get__(model)
@contextlib.contextmanager
def patch__forward():
with unittest.mock.patch.object(model, "_forward", new__forward):
yield
return patch__forward
# Based on 90f349f93df3083a507854d7fc7c3e1bb9014e24
def create_patch_flux_forward_orig(model,
*,
residual_diff_threshold,
validate_can_use_cache_function=None):
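    """Build a context-manager patch for a Flux model's ``forward_orig``.

    Double block 0 always runs and its output keys the cache; on a hit
    the remaining double and single blocks are skipped.
    """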
from torch import Tensor
def call_remaining_blocks(self, blocks_replace, control, img, txt, vec, pe,
attn_mask, ca_idx, timesteps, transformer_options):
original_hidden_states = img
extra_block_forward_kwargs = {}
if attn_mask is not None:
extra_block_forward_kwargs["attn_mask"] = attn_mask
for i, block in enumerate(self.double_blocks):
if i < 1:
continue
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(
img=args["img"],
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
**extra_block_forward_kwargs)
return out
out = blocks_replace[("double_block",
i)]({
"img": img,
"txt": txt,
"vec": vec,
"pe": pe,
**extra_block_forward_kwargs
}, {
"original_block": block_wrap,
"transformer_options": transformer_options
})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img,
txt=txt,
vec=vec,
pe=pe,
**extra_block_forward_kwargs)
if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
img += add
# PuLID attention
if getattr(self, "pulid_data", {}):
if i % self.pulid_double_interval == 0:
# Will calculate influence of all pulid nodes at once
for _, node_data in self.pulid_data.items():
if torch.any((node_data['sigma_start'] >= timesteps)
& (timesteps >= node_data['sigma_end'])):
img = img + node_data['weight'] * self.pulid_ca[
ca_idx](node_data['embedding'], img)
ca_idx += 1
img = torch.cat((txt, img), 1)
for i, block in enumerate(self.single_blocks):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
**extra_block_forward_kwargs)
return out
out = blocks_replace[("single_block",
i)]({
"img": img,
"vec": vec,
"pe": pe,
**extra_block_forward_kwargs
}, {
"original_block": block_wrap,
"transformer_options": transformer_options
})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, **extra_block_forward_kwargs)
if control is not None: # Controlnet
control_o = control.get("output")
if i < len(control_o):
add = control_o[i]
if add is not None:
img[:, txt.shape[1]:, ...] += add
# PuLID attention
if getattr(self, "pulid_data", {}):
                real_img, txt = (img[:, txt.shape[1]:, ...],
                                 img[:, :txt.shape[1], ...])
if i % self.pulid_single_interval == 0:
# Will calculate influence of all nodes at once
for _, node_data in self.pulid_data.items():
if torch.any((node_data['sigma_start'] >= timesteps)
& (timesteps >= node_data['sigma_end'])):
real_img = real_img + node_data[
'weight'] * self.pulid_ca[ca_idx](
node_data['embedding'], real_img)
ca_idx += 1
img = torch.cat((txt, real_img), 1)
img = img[:, txt.shape[1]:, ...]
img = img.contiguous()
hidden_states_residual = img - original_hidden_states
return img, hidden_states_residual
def forward_orig(
self,
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
control=None,
transformer_options={},
attn_mask: Tensor = None,
) -> Tensor:
patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3:
raise ValueError(
"Input img and txt tensors must have 3 dimensions.")
        # running on token sequences
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
if self.params.guidance_embed:
if guidance is None:
raise ValueError(
"Didn't get guidance strength for guidance distilled model."
)
vec = vec + self.guidance_in(
timestep_embedding(guidance, 256).to(img.dtype))
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
ca_idx = 0
extra_block_forward_kwargs = {}
if attn_mask is not None:
extra_block_forward_kwargs["attn_mask"] = attn_mask
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks):
if i >= 1:
break
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(
img=args["img"],
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
**extra_block_forward_kwargs)
return out
out = blocks_replace[("double_block",
i)]({
"img": img,
"txt": txt,
"vec": vec,
"pe": pe,
**extra_block_forward_kwargs
}, {
"original_block": block_wrap,
"transformer_options": transformer_options
})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img,
txt=txt,
vec=vec,
pe=pe,
**extra_block_forward_kwargs)
if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
img += add
# PuLID attention
if getattr(self, "pulid_data", {}):
if i % self.pulid_double_interval == 0:
# Will calculate influence of all pulid nodes at once
for _, node_data in self.pulid_data.items():
if torch.any((node_data['sigma_start'] >= timesteps)
& (timesteps >= node_data['sigma_end'])):
img = img + node_data['weight'] * self.pulid_ca[
ca_idx](node_data['embedding'], img)
ca_idx += 1
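            # Unlike the UNet patch, the first double block's raw output
            # (not an input/output difference) serves as the cache key.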
if i == 0:
first_hidden_states_residual = img
can_use_cache = get_can_use_cache(
first_hidden_states_residual,
threshold=residual_diff_threshold,
)
if validate_can_use_cache_function is not None:
can_use_cache = validate_can_use_cache_function(
can_use_cache)
if not can_use_cache:
set_buffer("first_hidden_states_residual",
first_hidden_states_residual)
del first_hidden_states_residual
torch._dynamo.graph_break()
if can_use_cache:
img = apply_prev_hidden_states_residual(img)
else:
img, hidden_states_residual = call_remaining_blocks(
self,
blocks_replace,
control,
img,
txt,
vec,
pe,
attn_mask,
ca_idx,
timesteps,
transformer_options,
)
set_buffer("hidden_states_residual", hidden_states_residual)
torch._dynamo.graph_break()
img = self.final_layer(img,
vec) # (N, T, patch_size ** 2 * out_channels)
return img
new_forward_orig = forward_orig.__get__(model)
@contextlib.contextmanager
def patch_forward_orig():
with unittest.mock.patch.object(model, "forward_orig",
new_forward_orig):
yield
return patch_forward_orig
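# Typical integration sketch (model objects and threshold are
# illustrative):
#
#     patch = create_patch_flux_forward_orig(
#         flux_model, residual_diff_threshold=0.1)
#     with cache_context(create_cache_context()), patch():
#         out = flux_model.forward_orig(img, img_ids, txt, txt_ids,
#                                       timesteps, y)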