import contextlib
import unittest.mock
import torch
from . import first_block_cache
class ApplyFBCacheOnModel:
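    """Model-patch node that applies First Block Cache (FBCache) to a model.

    FBCache reuses a previous step's output when the residual change after the
    first transformer/UNet block is small enough, skipping the remaining blocks.
    """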
def patch(
self,
model,
object_to_patch,
residual_diff_threshold,
max_consecutive_cache_hits=-1,
start=0.0,
end=1.0,
):
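        """Clone `model` and patch `object_to_patch` on it to use FBCache.

        `residual_diff_threshold` controls how similar the first-block residual
        must be for a cache hit; `max_consecutive_cache_hits` caps back-to-back
        hits (negative means no cap); `start`/`end` restrict caching to that
        fraction of the sampling schedule. Returns a one-tuple containing the
        patched model clone.
        """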
        # `model` arrives wrapped in a tuple from the caller; unwrap it once so
        # the rest of this function works with the model patcher itself.
        model = model[0]
        if residual_diff_threshold <= 0.0 or max_consecutive_cache_hits == 0:
            # Caching is effectively disabled; return the model unpatched.
            return (model, )
# first_block_cache.patch_get_output_data()
using_validation = max_consecutive_cache_hits >= 0 or start > 0 or end < 1
if using_validation:
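            # percent_to_sigma maps a schedule fraction to a sigma; sigmas
            # decrease during sampling, so `start` yields the larger bound.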
model_sampling = model.get_model_object("model_sampling")
start_sigma, end_sigma = (float(
model_sampling.percent_to_sigma(pct)) for pct in (start, end))
del model_sampling
@torch.compiler.disable()
def validate_use_cache(use_cached):
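                # Allow a cache hit only while the current sigma lies inside
                # the [end_sigma, start_sigma] window and the cap on
                # consecutive hits (if any) has not been reached.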
nonlocal consecutive_cache_hits
use_cached = use_cached and end_sigma <= current_timestep <= start_sigma
use_cached = use_cached and (max_consecutive_cache_hits < 0
or consecutive_cache_hits
< max_consecutive_cache_hits)
consecutive_cache_hits = consecutive_cache_hits + 1 if use_cached else 0
return use_cached
else:
validate_use_cache = None
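
        # Shared closure state: the previous step's timestep and input
        # signature (used to detect a fresh sampling run) plus a counter of
        # back-to-back cache hits.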
prev_timestep = None
prev_input_state = None
current_timestep = None
consecutive_cache_hits = 0
def reset_cache_state():
# Resets the cache state and hits/time tracking variables.
nonlocal prev_input_state, prev_timestep, consecutive_cache_hits
prev_input_state = prev_timestep = None
consecutive_cache_hits = 0
first_block_cache.set_current_cache_context(
first_block_cache.create_cache_context())
def ensure_cache_state(model_input: torch.Tensor, timestep: float):
# Validates the current cache state and hits/time tracking variables
# and triggers a reset if necessary. Also updates current_timestep.
nonlocal current_timestep
input_state = (model_input.shape, model_input.dtype, model_input.device)
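            # Reset when sampling just started, the input signature changed,
            # the cache context is missing, or the timestep failed to decrease
            # (sigmas decrease monotonically within a single sampling run).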
need_reset = (
prev_timestep is None or
prev_input_state != input_state or
first_block_cache.get_current_cache_context() is None or
timestep >= prev_timestep
)
if need_reset:
reset_cache_state()
current_timestep = timestep
def update_cache_state(model_input: torch.Tensor, timestep: float):
# Updates the previous timestep and input state validation variables.
nonlocal prev_timestep, prev_input_state
prev_timestep = timestep
prev_input_state = (model_input.shape, model_input.dtype, model_input.device)
        # Clone so the patch does not mutate the caller's model.
        model = model.clone()
        diffusion_model = model.get_model_object(object_to_patch)
if diffusion_model.__class__.__name__ in ("UNetModel", "Flux"):
if diffusion_model.__class__.__name__ == "UNetModel":
create_patch_function = first_block_cache.create_patch_unet_model__forward
elif diffusion_model.__class__.__name__ == "Flux":
create_patch_function = first_block_cache.create_patch_flux_forward_orig
else:
raise ValueError(
f"Unsupported model {diffusion_model.__class__.__name__}")
patch_forward = create_patch_function(
diffusion_model,
residual_diff_threshold=residual_diff_threshold,
validate_can_use_cache_function=validate_use_cache,
)
            def model_unet_function_wrapper(model_function, kwargs):
                # Wraps the sampler's call into the model so the patched
                # forward (with FBCache) is active for exactly this call.
                try:
                    model_input = kwargs["input"]
                    timestep = kwargs["timestep"]
                    c = kwargs["c"]
                    t = timestep[0].item()

                    ensure_cache_state(model_input, t)

                    with patch_forward():
                        result = model_function(model_input, timestep, **c)
                        update_cache_state(model_input, t)
                        return result
                except Exception as exc:
                    reset_cache_state()
                    raise exc from None
else:
is_non_native_ltxv = False
if diffusion_model.__class__.__name__ == "LTXVTransformer3D":
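                # The non-native LTXV node wraps the real transformer; patch
                # the inner module instead of the wrapper.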
is_non_native_ltxv = True
diffusion_model = diffusion_model.transformer
double_blocks_name = None
single_blocks_name = None
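            # Architectures name their block lists differently; locate the
            # attribute holding the main ("double") transformer blocks.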
if hasattr(diffusion_model, "transformer_blocks"):
double_blocks_name = "transformer_blocks"
elif hasattr(diffusion_model, "double_blocks"):
double_blocks_name = "double_blocks"
elif hasattr(diffusion_model, "joint_blocks"):
double_blocks_name = "joint_blocks"
else:
raise ValueError(
f"No double blocks found for {diffusion_model.__class__.__name__}"
)
if hasattr(diffusion_model, "single_blocks"):
single_blocks_name = "single_blocks"
if is_non_native_ltxv:
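                # STG builds its skip-layer mask from the original block list,
                # which the FBCache block swap invalidates, so fail loudly.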
original_create_skip_layer_mask = getattr(
diffusion_model, "create_skip_layer_mask", None)
if original_create_skip_layer_mask is not None:
# original_double_blocks = getattr(diffusion_model,
# double_blocks_name)
def new_create_skip_layer_mask(self, *args, **kwargs):
# with unittest.mock.patch.object(self, double_blocks_name,
# original_double_blocks):
# return original_create_skip_layer_mask(*args, **kwargs)
# return original_create_skip_layer_mask(*args, **kwargs)
raise RuntimeError(
"STG is not supported with FBCache yet")
diffusion_model.create_skip_layer_mask = new_create_skip_layer_mask.__get__(
diffusion_model)
            model_class_name = diffusion_model.__class__.__name__
            cached_transformer_blocks = torch.nn.ModuleList([
                first_block_cache.CachedTransformerBlocks(
                    None if double_blocks_name is None else getattr(
                        diffusion_model, double_blocks_name),
                    None if single_blocks_name is None else getattr(
                        diffusion_model, single_blocks_name),
                    residual_diff_threshold=residual_diff_threshold,
                    validate_can_use_cache_function=validate_use_cache,
                    cat_hidden_states_first=model_class_name == "HunyuanVideo",
                    return_hidden_states_only=model_class_name == "LTXVModel"
                    or is_non_native_ltxv,
                    clone_original_hidden_states=model_class_name == "LTXVModel",
                    return_hidden_states_first=model_class_name
                    != "OpenAISignatureMMDITWrapper",
                    accept_hidden_states_first=model_class_name
                    != "OpenAISignatureMMDITWrapper",
                )
            ])
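            # The cached blocks above also run the single blocks, so the real
            # single-block list is swapped for an empty one during the call.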
dummy_single_transformer_blocks = torch.nn.ModuleList()
            def model_unet_function_wrapper(model_function, kwargs):
                # Temporarily swap the block lists for their cached versions
                # for exactly this call into the model.
                try:
                    model_input = kwargs["input"]
                    timestep = kwargs["timestep"]
                    c = kwargs["c"]
                    t = timestep[0].item()

                    ensure_cache_state(model_input, t)

                    with unittest.mock.patch.object(
                            diffusion_model,
                            double_blocks_name,
                            cached_transformer_blocks,
                    ), unittest.mock.patch.object(
                            diffusion_model,
                            single_blocks_name,
                            dummy_single_transformer_blocks,
                    ) if single_blocks_name is not None else contextlib.nullcontext():
                        result = model_function(model_input, timestep, **c)
                        update_cache_state(model_input, t)
                        return result
                except Exception as exc:
                    reset_cache_state()
                    raise exc from None
model.set_model_unet_function_wrapper(model_unet_function_wrapper)
return (model, )
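
# Hedged usage sketch (not part of this module): assumes a ComfyUI-style
# ModelPatcher `model_patcher` produced by a loader node; the tuple wrapping
# matches the calling convention this node's `patch` expects, and the
# parameter values are illustrative only.
#
#     node = ApplyFBCacheOnModel()
#     (patched_model, ) = node.patch(
#         (model_patcher, ),
#         object_to_patch="diffusion_model",
#         residual_diff_threshold=0.1,
#         max_consecutive_cache_hits=3,
#         start=0.1,
#         end=0.9,
#     )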