import contextlib
import dataclasses
import unittest.mock  # the patch helpers below use unittest.mock.patch.object
from collections import defaultdict
from typing import DefaultDict, Dict

import torch

from modules.AutoEncoders.ResBlock import forward_timestep_embed1
from modules.NeuralNetwork.unet import apply_control1
from modules.sample.sampling_util import timestep_embedding


@dataclasses.dataclass
class CacheContext:
    buffers: Dict[str, torch.Tensor] = dataclasses.field(default_factory=dict)
    incremental_name_counters: DefaultDict[str, int] = dataclasses.field(
        default_factory=lambda: defaultdict(int))

    def get_incremental_name(self, name=None):
        if name is None:
            name = "default"
        idx = self.incremental_name_counters[name]
        self.incremental_name_counters[name] += 1
        return f"{name}_{idx}"

    def reset_incremental_names(self):
        self.incremental_name_counters.clear()

    def get_buffer(self, name):
        return self.buffers.get(name)

    def set_buffer(self, name, buffer):
        self.buffers[name] = buffer

    def clear_buffers(self):
        self.buffers.clear()

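
# Illustrative note (these calls are not made elsewhere in this module): the
# incremental name counters give stable, per-name buffer keys when several
# call sites share one context, e.g.
#
#     ctx = CacheContext()
#     ctx.get_incremental_name("block")  # -> "block_0"
#     ctx.get_incremental_name("block")  # -> "block_1"
#     ctx.reset_incremental_names()      # counting starts again at 0
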
def get_buffer(name):
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before calling get_buffer"
    return cache_context.get_buffer(name)


def set_buffer(name, buffer):
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before calling set_buffer"
    cache_context.set_buffer(name, buffer)

_current_cache_context = None


def create_cache_context():
    return CacheContext()


def get_current_cache_context():
    return _current_cache_context


def set_current_cache_context(cache_context=None):
    global _current_cache_context
    _current_cache_context = cache_context

@contextlib.contextmanager
def cache_context(cache_context):
    global _current_cache_context
    old_cache_context = _current_cache_context
    _current_cache_context = cache_context
    try:
        yield
    finally:
        _current_cache_context = old_cache_context

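
# Illustrative sketch of how the helpers above fit together (the surrounding
# model/sampler objects are assumed application code, not defined here):
#
#     cache_ctx = create_cache_context()
#     with cache_context(cache_ctx):
#         # while the context is active, patched forwards read and write
#         # cached tensors through get_buffer()/set_buffer()
#         out = model(x, timesteps)
#     cache_ctx.clear_buffers()
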
# def patch_get_output_data():
#     import execution
#
#     get_output_data = getattr(execution, "get_output_data", None)
#     if get_output_data is None:
#         return
#     if getattr(get_output_data, "_patched", False):
#         return
#
#     def new_get_output_data(*args, **kwargs):
#         out = get_output_data(*args, **kwargs)
#         cache_context = get_current_cache_context()
#         if cache_context is not None:
#             cache_context.clear_buffers()
#             set_current_cache_context(None)
#         return out
#
#     new_get_output_data._patched = True
#     execution.get_output_data = new_get_output_data


def are_two_tensors_similar(t1, t2, *, threshold):
    if t1.shape != t2.shape:
        return False
    # Relative change: mean absolute difference, normalized by the mean
    # absolute value of the reference tensor t1.
    mean_diff = (t1 - t2).abs().mean()
    mean_t1 = t1.abs().mean()
    diff = mean_diff / mean_t1
    return diff.item() < threshold

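
# Worked example with made-up numbers: if the stored reference residual t1 has
# mean |t1| = 1.0 and the new residual t2 differs from it by mean |t1 - t2| =
# 0.08, the relative difference is 0.08; with threshold = 0.1 the tensors
# count as similar, so the cached residuals may be reused for this step.
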
def apply_prev_hidden_states_residual(hidden_states,
                                      encoder_hidden_states=None):
    hidden_states_residual = get_buffer("hidden_states_residual")
    assert hidden_states_residual is not None, "hidden_states_residual must be set before it can be applied"
    hidden_states = hidden_states_residual + hidden_states
    hidden_states = hidden_states.contiguous()

    if encoder_hidden_states is None:
        return hidden_states

    encoder_hidden_states_residual = get_buffer("encoder_hidden_states_residual")
    if encoder_hidden_states_residual is None:
        encoder_hidden_states = None
    else:
        encoder_hidden_states = encoder_hidden_states_residual + encoder_hidden_states
        encoder_hidden_states = encoder_hidden_states.contiguous()

    return hidden_states, encoder_hidden_states

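
# Note on the return type: callers that track a single stream (e.g. the UNet
# patch below) call apply_prev_hidden_states_residual(h) and get one tensor
# back, while CachedTransformerBlocks also passes encoder_hidden_states and
# unpacks a (hidden_states, encoder_hidden_states) tuple.
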
def get_can_use_cache(first_hidden_states_residual,
                      threshold,
                      parallelized=False):
    prev_first_hidden_states_residual = get_buffer("first_hidden_states_residual")
    # On the first step, or right after the buffers were cleared, there is
    # nothing to compare against, so the cache cannot be used yet.
    can_use_cache = prev_first_hidden_states_residual is not None and are_two_tensors_similar(
        prev_first_hidden_states_residual,
        first_hidden_states_residual,
        threshold=threshold,
    )
    return can_use_cache

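
# Sketch of the first-block-cache pattern built from the two helpers above
# (this is the shape of the logic used by the classes and patches below):
#
#     can_use_cache = get_can_use_cache(first_block_residual, threshold=0.1)
#     if can_use_cache:
#         hidden_states = apply_prev_hidden_states_residual(hidden_states)
#     else:
#         set_buffer("first_hidden_states_residual", first_block_residual)
#         # ...run the remaining blocks...
#         set_buffer("hidden_states_residual", new_residual)
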
class CachedTransformerBlocks(torch.nn.Module):

    def __init__(
        self,
        transformer_blocks,
        single_transformer_blocks=None,
        *,
        residual_diff_threshold,
        validate_can_use_cache_function=None,
        return_hidden_states_first=True,
        accept_hidden_states_first=True,
        cat_hidden_states_first=False,
        return_hidden_states_only=False,
        clone_original_hidden_states=False,
    ):
        super().__init__()
        self.transformer_blocks = transformer_blocks
        self.single_transformer_blocks = single_transformer_blocks
        self.residual_diff_threshold = residual_diff_threshold
        self.validate_can_use_cache_function = validate_can_use_cache_function
        self.return_hidden_states_first = return_hidden_states_first
        self.accept_hidden_states_first = accept_hidden_states_first
        self.cat_hidden_states_first = cat_hidden_states_first
        self.return_hidden_states_only = return_hidden_states_only
        self.clone_original_hidden_states = clone_original_hidden_states

    def forward(self, *args, **kwargs):
        # Work out which keyword names this model uses for the image and text
        # streams, then pull both out of args/kwargs in the order the wrapped
        # blocks expect them.
        img_arg_name = None
        if "img" in kwargs:
            img_arg_name = "img"
        elif "hidden_states" in kwargs:
            img_arg_name = "hidden_states"
        txt_arg_name = None
        if "txt" in kwargs:
            txt_arg_name = "txt"
        elif "context" in kwargs:
            txt_arg_name = "context"
        elif "encoder_hidden_states" in kwargs:
            txt_arg_name = "encoder_hidden_states"

        if self.accept_hidden_states_first:
            if args:
                img = args[0]
                args = args[1:]
            else:
                img = kwargs.pop(img_arg_name)
            if args:
                txt = args[0]
                args = args[1:]
            else:
                txt = kwargs.pop(txt_arg_name)
        else:
            if args:
                txt = args[0]
                args = args[1:]
            else:
                txt = kwargs.pop(txt_arg_name)
            if args:
                img = args[0]
                args = args[1:]
            else:
                img = kwargs.pop(img_arg_name)

        hidden_states = img
        encoder_hidden_states = txt
        if self.residual_diff_threshold <= 0.0:
            # Caching disabled: run every block as usual.
            for block in self.transformer_blocks:
                if txt_arg_name == "encoder_hidden_states":
                    hidden_states = block(
                        hidden_states,
                        *args,
                        encoder_hidden_states=encoder_hidden_states,
                        **kwargs)
                else:
                    if self.accept_hidden_states_first:
                        hidden_states = block(hidden_states,
                                              encoder_hidden_states, *args,
                                              **kwargs)
                    else:
                        hidden_states = block(encoder_hidden_states,
                                              hidden_states, *args, **kwargs)
                if not self.return_hidden_states_only:
                    hidden_states, encoder_hidden_states = hidden_states
                    if not self.return_hidden_states_first:
                        hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
            if self.single_transformer_blocks is not None:
                hidden_states = torch.cat(
                    [hidden_states, encoder_hidden_states]
                    if self.cat_hidden_states_first else
                    [encoder_hidden_states, hidden_states],
                    dim=1)
                for block in self.single_transformer_blocks:
                    hidden_states = block(hidden_states, *args, **kwargs)
                hidden_states = hidden_states[:, encoder_hidden_states.shape[1]:]
            if self.return_hidden_states_only:
                return hidden_states
            else:
                return ((hidden_states, encoder_hidden_states)
                        if self.return_hidden_states_first else
                        (encoder_hidden_states, hidden_states))
        original_hidden_states = hidden_states
        if self.clone_original_hidden_states:
            original_hidden_states = original_hidden_states.clone()
        # Run only the first transformer block, then use the change it makes
        # to hidden_states as the indicator for the cache decision.
        first_transformer_block = self.transformer_blocks[0]
        if txt_arg_name == "encoder_hidden_states":
            hidden_states = first_transformer_block(
                hidden_states,
                *args,
                encoder_hidden_states=encoder_hidden_states,
                **kwargs)
        else:
            if self.accept_hidden_states_first:
                hidden_states = first_transformer_block(
                    hidden_states, encoder_hidden_states, *args, **kwargs)
            else:
                hidden_states = first_transformer_block(
                    encoder_hidden_states, hidden_states, *args, **kwargs)
        if not self.return_hidden_states_only:
            hidden_states, encoder_hidden_states = hidden_states
            if not self.return_hidden_states_first:
                hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
        first_hidden_states_residual = hidden_states - original_hidden_states
        del original_hidden_states

        can_use_cache = get_can_use_cache(
            first_hidden_states_residual,
            threshold=self.residual_diff_threshold,
        )
        if self.validate_can_use_cache_function is not None:
            can_use_cache = self.validate_can_use_cache_function(can_use_cache)

        torch._dynamo.graph_break()
        if can_use_cache:
            del first_hidden_states_residual
            hidden_states, encoder_hidden_states = apply_prev_hidden_states_residual(
                hidden_states, encoder_hidden_states)
        else:
            set_buffer("first_hidden_states_residual",
                       first_hidden_states_residual)
            del first_hidden_states_residual
            (
                hidden_states,
                encoder_hidden_states,
                hidden_states_residual,
                encoder_hidden_states_residual,
            ) = self.call_remaining_transformer_blocks(
                hidden_states,
                encoder_hidden_states,
                *args,
                txt_arg_name=txt_arg_name,
                **kwargs)
            set_buffer("hidden_states_residual", hidden_states_residual)
            if encoder_hidden_states_residual is not None:
                set_buffer("encoder_hidden_states_residual",
                           encoder_hidden_states_residual)
        torch._dynamo.graph_break()

        if self.return_hidden_states_only:
            return hidden_states
        else:
            return ((hidden_states, encoder_hidden_states)
                    if self.return_hidden_states_first else
                    (encoder_hidden_states, hidden_states))

    def call_remaining_transformer_blocks(self,
                                          hidden_states,
                                          encoder_hidden_states,
                                          *args,
                                          txt_arg_name=None,
                                          **kwargs):
        original_hidden_states = hidden_states
        original_encoder_hidden_states = encoder_hidden_states
        if self.clone_original_hidden_states:
            original_hidden_states = original_hidden_states.clone()
            original_encoder_hidden_states = original_encoder_hidden_states.clone()
        for block in self.transformer_blocks[1:]:
            if txt_arg_name == "encoder_hidden_states":
                hidden_states = block(
                    hidden_states,
                    *args,
                    encoder_hidden_states=encoder_hidden_states,
                    **kwargs)
            else:
                if self.accept_hidden_states_first:
                    hidden_states = block(hidden_states, encoder_hidden_states,
                                          *args, **kwargs)
                else:
                    hidden_states = block(encoder_hidden_states, hidden_states,
                                          *args, **kwargs)
            if not self.return_hidden_states_only:
                hidden_states, encoder_hidden_states = hidden_states
                if not self.return_hidden_states_first:
                    hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
        if self.single_transformer_blocks is not None:
            hidden_states = torch.cat([hidden_states, encoder_hidden_states]
                                      if self.cat_hidden_states_first else
                                      [encoder_hidden_states, hidden_states],
                                      dim=1)
            for block in self.single_transformer_blocks:
                hidden_states = block(hidden_states, *args, **kwargs)
            if self.cat_hidden_states_first:
                hidden_states, encoder_hidden_states = hidden_states.split(
                    [
                        hidden_states.shape[1] -
                        encoder_hidden_states.shape[1],
                        encoder_hidden_states.shape[1]
                    ],
                    dim=1)
            else:
                encoder_hidden_states, hidden_states = hidden_states.split(
                    [
                        encoder_hidden_states.shape[1],
                        hidden_states.shape[1] - encoder_hidden_states.shape[1]
                    ],
                    dim=1)

        # Re-materialize both outputs into contiguous memory before computing
        # the residuals that get cached for later steps.
        hidden_states_shape = hidden_states.shape
        hidden_states = hidden_states.flatten().contiguous().reshape(
            hidden_states_shape)
        if encoder_hidden_states is not None:
            encoder_hidden_states_shape = encoder_hidden_states.shape
            encoder_hidden_states = encoder_hidden_states.flatten().contiguous(
            ).reshape(encoder_hidden_states_shape)

        hidden_states_residual = hidden_states - original_hidden_states
        if encoder_hidden_states is None:
            encoder_hidden_states_residual = None
        else:
            encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states
        return hidden_states, encoder_hidden_states, hidden_states_residual, encoder_hidden_states_residual

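
# Illustrative sketch (the attribute names on `model` are hypothetical): a
# two-stream diffusion transformer can be wrapped so that all of its blocks
# are replaced by a single cached module:
#
#     cached = CachedTransformerBlocks(
#         model.double_blocks,
#         model.single_blocks,
#         residual_diff_threshold=0.12,
#         return_hidden_states_first=False,
#     )
#     model.double_blocks = torch.nn.ModuleList([cached])
#     model.single_blocks = torch.nn.ModuleList()
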
# Based on 90f349f93df3083a507854d7fc7c3e1bb9014e24
def create_patch_unet_model__forward(model,
                                     *,
                                     residual_diff_threshold,
                                     validate_can_use_cache_function=None):

    def call_remaining_blocks(self, transformer_options, control,
                              transformer_patches, hs, h, *args, **kwargs):
        original_hidden_states = h

        for id, module in enumerate(self.input_blocks):
            if id < 2:
                continue
            transformer_options["block"] = ("input", id)
            h = forward_timestep_embed1(module, h, *args, **kwargs)
            h = apply_control1(h, control, 'input')
            if "input_block_patch" in transformer_patches:
                patch = transformer_patches["input_block_patch"]
                for p in patch:
                    h = p(h, transformer_options)

            hs.append(h)
            if "input_block_patch_after_skip" in transformer_patches:
                patch = transformer_patches["input_block_patch_after_skip"]
                for p in patch:
                    h = p(h, transformer_options)

        transformer_options["block"] = ("middle", 0)
        if self.middle_block is not None:
            h = forward_timestep_embed1(self.middle_block, h, *args, **kwargs)
        h = apply_control1(h, control, 'middle')

        for id, module in enumerate(self.output_blocks):
            transformer_options["block"] = ("output", id)
            hsp = hs.pop()
            hsp = apply_control1(hsp, control, 'output')

            if "output_block_patch" in transformer_patches:
                patch = transformer_patches["output_block_patch"]
                for p in patch:
                    h, hsp = p(h, hsp, transformer_options)

            h = torch.cat([h, hsp], dim=1)
            del hsp
            if len(hs) > 0:
                output_shape = hs[-1].shape
            else:
                output_shape = None
            h = forward_timestep_embed1(module, h, *args, output_shape,
                                        **kwargs)

        hidden_states_residual = h - original_hidden_states
        return h, hidden_states_residual

    def unet_model__forward(self,
                            x,
                            timesteps=None,
                            context=None,
                            y=None,
                            control=None,
                            transformer_options={},
                            **kwargs):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn.
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        transformer_options["original_shape"] = list(x.shape)
        transformer_options["transformer_index"] = 0
        transformer_patches = transformer_options.get("patches", {})

        num_video_frames = kwargs.get("num_video_frames",
                                      self.default_num_video_frames)
        image_only_indicator = kwargs.get("image_only_indicator", None)
        time_context = kwargs.get("time_context", None)

        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps,
                                   self.model_channels,
                                   repeat_only=False).to(x.dtype)
        emb = self.time_embed(t_emb)

        if "emb_patch" in transformer_patches:
            patch = transformer_patches["emb_patch"]
            for p in patch:
                emb = p(emb, self.model_channels, transformer_options)

        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)
        can_use_cache = False

        h = x
        for id, module in enumerate(self.input_blocks):
            if id >= 2:
                break
            transformer_options["block"] = ("input", id)

            if id == 1:
                # Keep the input to the second block so the "first block
                # residual" used as the cache indicator can be computed below.
                original_h = h

            h = forward_timestep_embed1(
                module,
                h,
                emb,
                context,
                transformer_options,
                time_context=time_context,
                num_video_frames=num_video_frames,
                image_only_indicator=image_only_indicator)
            h = apply_control1(h, control, 'input')
            if "input_block_patch" in transformer_patches:
                patch = transformer_patches["input_block_patch"]
                for p in patch:
                    h = p(h, transformer_options)

            hs.append(h)
            if "input_block_patch_after_skip" in transformer_patches:
                patch = transformer_patches["input_block_patch_after_skip"]
                for p in patch:
                    h = p(h, transformer_options)

            if id == 1:
                first_hidden_states_residual = h - original_h
                can_use_cache = get_can_use_cache(
                    first_hidden_states_residual,
                    threshold=residual_diff_threshold,
                )
                if validate_can_use_cache_function is not None:
                    can_use_cache = validate_can_use_cache_function(
                        can_use_cache)
                if not can_use_cache:
                    set_buffer("first_hidden_states_residual",
                               first_hidden_states_residual)
                del first_hidden_states_residual
        torch._dynamo.graph_break()
        if can_use_cache:
            h = apply_prev_hidden_states_residual(h)
        else:
            h, hidden_states_residual = call_remaining_blocks(
                self,
                transformer_options,
                control,
                transformer_patches,
                hs,
                h,
                emb,
                context,
                transformer_options,
                time_context=time_context,
                num_video_frames=num_video_frames,
                image_only_indicator=image_only_indicator)
            set_buffer("hidden_states_residual", hidden_states_residual)
        torch._dynamo.graph_break()

        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        else:
            return self.out(h)

    new__forward = unet_model__forward.__get__(model)

    @contextlib.contextmanager
    def patch__forward():
        with unittest.mock.patch.object(model, "_forward", new__forward):
            yield

    return patch__forward

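
# Illustrative sketch (the sampler call is assumed application code): the
# factory returns a context manager that swaps in the cached `_forward` only
# while it is active:
#
#     patch = create_patch_unet_model__forward(
#         unet_model, residual_diff_threshold=0.1)
#     with cache_context(create_cache_context()):
#         with patch():
#             samples = run_sampler(unet_model, noise, sigmas)  # hypothetical
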
# Based on 90f349f93df3083a507854d7fc7c3e1bb9014e24
def create_patch_flux_forward_orig(model,
                                   *,
                                   residual_diff_threshold,
                                   validate_can_use_cache_function=None):
    from torch import Tensor

    def call_remaining_blocks(self, blocks_replace, control, img, txt, vec, pe,
                              attn_mask, ca_idx, timesteps,
                              transformer_options):
        original_hidden_states = img

        extra_block_forward_kwargs = {}
        if attn_mask is not None:
            extra_block_forward_kwargs["attn_mask"] = attn_mask
        for i, block in enumerate(self.double_blocks):
            if i < 1:
                continue
            if ("double_block", i) in blocks_replace:

                def block_wrap(args):
                    out = {}
                    out["img"], out["txt"] = block(
                        img=args["img"],
                        txt=args["txt"],
                        vec=args["vec"],
                        pe=args["pe"],
                        **extra_block_forward_kwargs)
                    return out

                out = blocks_replace[("double_block", i)](
                    {
                        "img": img,
                        "txt": txt,
                        "vec": vec,
                        "pe": pe,
                        **extra_block_forward_kwargs
                    }, {
                        "original_block": block_wrap,
                        "transformer_options": transformer_options
                    })
                txt = out["txt"]
                img = out["img"]
            else:
                img, txt = block(img=img,
                                 txt=txt,
                                 vec=vec,
                                 pe=pe,
                                 **extra_block_forward_kwargs)

            if control is not None:  # Controlnet
                control_i = control.get("input")
                if i < len(control_i):
                    add = control_i[i]
                    if add is not None:
                        img += add

            # PuLID attention
            if getattr(self, "pulid_data", {}):
                if i % self.pulid_double_interval == 0:
                    # Will calculate influence of all pulid nodes at once
                    for _, node_data in self.pulid_data.items():
                        if torch.any((node_data['sigma_start'] >= timesteps)
                                     & (timesteps >= node_data['sigma_end'])):
                            img = img + node_data['weight'] * self.pulid_ca[
                                ca_idx](node_data['embedding'], img)
                    ca_idx += 1

        img = torch.cat((txt, img), 1)
        for i, block in enumerate(self.single_blocks):
            if ("single_block", i) in blocks_replace:

                def block_wrap(args):
                    out = {}
                    out["img"] = block(args["img"],
                                       vec=args["vec"],
                                       pe=args["pe"],
                                       **extra_block_forward_kwargs)
                    return out

                out = blocks_replace[("single_block", i)](
                    {
                        "img": img,
                        "vec": vec,
                        "pe": pe,
                        **extra_block_forward_kwargs
                    }, {
                        "original_block": block_wrap,
                        "transformer_options": transformer_options
                    })
                img = out["img"]
            else:
                img = block(img, vec=vec, pe=pe, **extra_block_forward_kwargs)

            if control is not None:  # Controlnet
                control_o = control.get("output")
                if i < len(control_o):
                    add = control_o[i]
                    if add is not None:
                        img[:, txt.shape[1]:, ...] += add

            # PuLID attention
            if getattr(self, "pulid_data", {}):
                real_img, txt = (img[:, txt.shape[1]:, ...],
                                 img[:, :txt.shape[1], ...])
                if i % self.pulid_single_interval == 0:
                    # Will calculate influence of all nodes at once
                    for _, node_data in self.pulid_data.items():
                        if torch.any((node_data['sigma_start'] >= timesteps)
                                     & (timesteps >= node_data['sigma_end'])):
                            real_img = real_img + node_data[
                                'weight'] * self.pulid_ca[ca_idx](
                                    node_data['embedding'], real_img)
                    ca_idx += 1
                img = torch.cat((txt, real_img), 1)

        img = img[:, txt.shape[1]:, ...]

        img = img.contiguous()
        hidden_states_residual = img - original_hidden_states
        return img, hidden_states_residual

    def forward_orig(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor = None,
        control=None,
        transformer_options={},
        attn_mask: Tensor = None,
    ) -> Tensor:
        patches_replace = transformer_options.get("patches_replace", {})

        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError(
                "Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError(
                    "Didn't get guidance strength for guidance distilled model."
                )
            vec = vec + self.guidance_in(
                timestep_embedding(guidance, 256).to(img.dtype))

        vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        ca_idx = 0
        extra_block_forward_kwargs = {}
        if attn_mask is not None:
            extra_block_forward_kwargs["attn_mask"] = attn_mask

        blocks_replace = patches_replace.get("dit", {})
        for i, block in enumerate(self.double_blocks):
            if i >= 1:
                break
            if ("double_block", i) in blocks_replace:

                def block_wrap(args):
                    out = {}
                    out["img"], out["txt"] = block(
                        img=args["img"],
                        txt=args["txt"],
                        vec=args["vec"],
                        pe=args["pe"],
                        **extra_block_forward_kwargs)
                    return out

                out = blocks_replace[("double_block", i)](
                    {
                        "img": img,
                        "txt": txt,
                        "vec": vec,
                        "pe": pe,
                        **extra_block_forward_kwargs
                    }, {
                        "original_block": block_wrap,
                        "transformer_options": transformer_options
                    })
                txt = out["txt"]
                img = out["img"]
            else:
                img, txt = block(img=img,
                                 txt=txt,
                                 vec=vec,
                                 pe=pe,
                                 **extra_block_forward_kwargs)

            if control is not None:  # Controlnet
                control_i = control.get("input")
                if i < len(control_i):
                    add = control_i[i]
                    if add is not None:
                        img += add

            # PuLID attention
            if getattr(self, "pulid_data", {}):
                if i % self.pulid_double_interval == 0:
                    # Will calculate influence of all pulid nodes at once
                    for _, node_data in self.pulid_data.items():
                        if torch.any((node_data['sigma_start'] >= timesteps)
                                     & (timesteps >= node_data['sigma_end'])):
                            img = img + node_data['weight'] * self.pulid_ca[
                                ca_idx](node_data['embedding'], img)
                    ca_idx += 1
            if i == 0:
                # The output of the first double block is used as the
                # indicator for deciding whether the cached residuals can be
                # reused for this step.
                first_hidden_states_residual = img
                can_use_cache = get_can_use_cache(
                    first_hidden_states_residual,
                    threshold=residual_diff_threshold,
                )
                if validate_can_use_cache_function is not None:
                    can_use_cache = validate_can_use_cache_function(
                        can_use_cache)
                if not can_use_cache:
                    set_buffer("first_hidden_states_residual",
                               first_hidden_states_residual)
                del first_hidden_states_residual

        torch._dynamo.graph_break()
        if can_use_cache:
            img = apply_prev_hidden_states_residual(img)
        else:
            img, hidden_states_residual = call_remaining_blocks(
                self,
                blocks_replace,
                control,
                img,
                txt,
                vec,
                pe,
                attn_mask,
                ca_idx,
                timesteps,
                transformer_options,
            )
            set_buffer("hidden_states_residual", hidden_states_residual)
        torch._dynamo.graph_break()

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img

    new_forward_orig = forward_orig.__get__(model)

    @contextlib.contextmanager
    def patch_forward_orig():
        with unittest.mock.patch.object(model, "forward_orig",
                                        new_forward_orig):
            yield

    return patch_forward_orig
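
# Illustrative sketch, mirroring the UNet helper above (the outer model object
# is an assumption): patch a Flux-style model's forward_orig for the duration
# of one sampling run:
#
#     patch = create_patch_flux_forward_orig(
#         flux_model, residual_diff_threshold=0.12)
#     with cache_context(create_cache_context()):
#         with patch():
#             out = flux_model.forward_orig(img, img_ids, txt, txt_ids,
#                                           timesteps, y)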