import contextlib
import unittest.mock

import torch

from . import first_block_cache


class ApplyFBCacheOnModel:

    def patch(
        self,
        model,
        object_to_patch,
        residual_diff_threshold,
        max_consecutive_cache_hits=-1,
        start=0.0,
        end=1.0,
    ):
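        """Patch a ComfyUI model so the module named by ``object_to_patch`` uses
        first block caching (FBCache).

        residual_diff_threshold: residual difference threshold below which a step
            may reuse the cached result; values <= 0.0 disable patching entirely.
        max_consecutive_cache_hits: maximum number of back-to-back cache hits
            (-1 means unlimited, 0 disables patching).
        start, end: portion of the sampling schedule (fractions converted to
            sigmas) within which cache hits are allowed.

        Returns a one-element tuple containing the model: a patched clone when
        caching is enabled, otherwise the input model unchanged.
        """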
        if residual_diff_threshold <= 0.0 or max_consecutive_cache_hits == 0:
            return (model, )
        # first_block_cache.patch_get_output_data()

        using_validation = max_consecutive_cache_hits >= 0 or start > 0 or end < 1
        if using_validation:
            model_sampling = model.get_model_object("model_sampling")
            start_sigma, end_sigma = (float(
                model_sampling.percent_to_sigma(pct)) for pct in (start, end))
            del model_sampling

            def validate_use_cache(use_cached):
                # Only allow a cache hit inside the configured sigma window and
                # while the consecutive-hit budget has not been exhausted.
                nonlocal consecutive_cache_hits
                use_cached = use_cached and end_sigma <= current_timestep <= start_sigma
                use_cached = use_cached and (max_consecutive_cache_hits < 0
                                             or consecutive_cache_hits
                                             < max_consecutive_cache_hits)
                consecutive_cache_hits = consecutive_cache_hits + 1 if use_cached else 0
                return use_cached
        else:
            validate_use_cache = None
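
        # Per-run tracking state used to decide when the FBCache context must be
        # reset (first call, new sampling run, or changed input shape/dtype/device).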
        prev_timestep = None
        prev_input_state = None
        current_timestep = None
        consecutive_cache_hits = 0

        def reset_cache_state():
            # Resets the cache state and hits/time tracking variables.
            nonlocal prev_input_state, prev_timestep, consecutive_cache_hits
            prev_input_state = prev_timestep = None
            consecutive_cache_hits = 0
            first_block_cache.set_current_cache_context(
                first_block_cache.create_cache_context())

        def ensure_cache_state(model_input: torch.Tensor, timestep: float):
            # Validates the current cache state and hits/time tracking variables
            # and triggers a reset if necessary. Also updates current_timestep.
            nonlocal current_timestep
            input_state = (model_input.shape, model_input.dtype, model_input.device)
            need_reset = (
                prev_timestep is None or
                prev_input_state != input_state or
                first_block_cache.get_current_cache_context() is None or
                timestep >= prev_timestep
            )
            if need_reset:
                reset_cache_state()
            current_timestep = timestep

        def update_cache_state(model_input: torch.Tensor, timestep: float):
            # Updates the previous timestep and input state validation variables.
            nonlocal prev_timestep, prev_input_state
            prev_timestep = timestep
            prev_input_state = (model_input.shape, model_input.dtype, model_input.device)

        model = model.clone()
        diffusion_model = model.get_model_object(object_to_patch)
        if diffusion_model.__class__.__name__ in ("UNetModel", "Flux"):
            if diffusion_model.__class__.__name__ == "UNetModel":
                create_patch_function = first_block_cache.create_patch_unet_model__forward
            elif diffusion_model.__class__.__name__ == "Flux":
                create_patch_function = first_block_cache.create_patch_flux_forward_orig
            else:
                raise ValueError(
                    f"Unsupported model {diffusion_model.__class__.__name__}")

            patch_forward = create_patch_function(
                diffusion_model,
                residual_diff_threshold=residual_diff_threshold,
                validate_can_use_cache_function=validate_use_cache,
            )
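            # ``patch_forward`` is a callable returning a context manager; while it
            # is active, the diffusion model's forward pass runs with the
            # FBCache-aware implementation from ``first_block_cache`` installed.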

            def model_unet_function_wrapper(model_function, kwargs):
                try:
                    input = kwargs["input"]
                    timestep = kwargs["timestep"]
                    c = kwargs["c"]
                    t = timestep[0].item()

                    ensure_cache_state(input, t)

                    with patch_forward():
                        result = model_function(input, timestep, **c)
                        update_cache_state(input, t)
                        return result
                except Exception as exc:
                    reset_cache_state()
                    raise exc from None
        else:
            is_non_native_ltxv = False
            if diffusion_model.__class__.__name__ == "LTXVTransformer3D":
                is_non_native_ltxv = True
                diffusion_model = diffusion_model.transformer

            double_blocks_name = None
            single_blocks_name = None
            if hasattr(diffusion_model, "transformer_blocks"):
                double_blocks_name = "transformer_blocks"
            elif hasattr(diffusion_model, "double_blocks"):
                double_blocks_name = "double_blocks"
            elif hasattr(diffusion_model, "joint_blocks"):
                double_blocks_name = "joint_blocks"
            else:
                raise ValueError(
                    f"No double blocks found for {diffusion_model.__class__.__name__}"
                )
if hasattr(diffusion_model, "single_blocks"): | |
single_blocks_name = "single_blocks" | |
if is_non_native_ltxv: | |
original_create_skip_layer_mask = getattr( | |
diffusion_model, "create_skip_layer_mask", None) | |
if original_create_skip_layer_mask is not None: | |
# original_double_blocks = getattr(diffusion_model, | |
# double_blocks_name) | |
def new_create_skip_layer_mask(self, *args, **kwargs): | |
# with unittest.mock.patch.object(self, double_blocks_name, | |
# original_double_blocks): | |
# return original_create_skip_layer_mask(*args, **kwargs) | |
# return original_create_skip_layer_mask(*args, **kwargs) | |
raise RuntimeError( | |
"STG is not supported with FBCache yet") | |
diffusion_model.create_skip_layer_mask = new_create_skip_layer_mask.__get__( | |
diffusion_model) | |

            model_cls_name = diffusion_model.__class__.__name__
            cached_transformer_blocks = torch.nn.ModuleList([
                first_block_cache.CachedTransformerBlocks(
                    None if double_blocks_name is None else getattr(
                        diffusion_model, double_blocks_name),
                    None if single_blocks_name is None else getattr(
                        diffusion_model, single_blocks_name),
                    residual_diff_threshold=residual_diff_threshold,
                    validate_can_use_cache_function=validate_use_cache,
                    cat_hidden_states_first=model_cls_name == "HunyuanVideo",
                    return_hidden_states_only=model_cls_name == "LTXVModel"
                    or is_non_native_ltxv,
                    clone_original_hidden_states=model_cls_name == "LTXVModel",
                    return_hidden_states_first=model_cls_name
                    != "OpenAISignatureMMDITWrapper",
                    accept_hidden_states_first=model_cls_name
                    != "OpenAISignatureMMDITWrapper",
                )
            ])
            dummy_single_transformer_blocks = torch.nn.ModuleList()
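            # The wrapper below swaps the model's block lists for the cached versions
            # via unittest.mock.patch.object for the duration of each call;
            # patch.object restores the originals when the context exits.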

            def model_unet_function_wrapper(model_function, kwargs):
                try:
                    input = kwargs["input"]
                    timestep = kwargs["timestep"]
                    c = kwargs["c"]
                    t = timestep[0].item()

                    ensure_cache_state(input, t)

                    with unittest.mock.patch.object(
                            diffusion_model,
                            double_blocks_name,
                            cached_transformer_blocks,
                    ), unittest.mock.patch.object(
                            diffusion_model,
                            single_blocks_name,
                            dummy_single_transformer_blocks,
                    ) if single_blocks_name is not None else contextlib.nullcontext():
                        result = model_function(input, timestep, **c)
                        update_cache_state(input, t)
                        return result
                except Exception as exc:
                    reset_cache_state()
                    raise exc from None

        model.set_model_unet_function_wrapper(model_unet_function_wrapper)
        return (model, )
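

# Example only: a minimal sketch of how this patcher could be exposed as a ComfyUI
# node. The class name, parameter ranges, and the commented-out NODE_CLASS_MAPPINGS
# entry are assumptions for illustration; the package's real node registration
# lives elsewhere.
class ApplyFBCacheOnModelExampleNode(ApplyFBCacheOnModel):
    CATEGORY = "model_patches"
    FUNCTION = "patch"
    RETURN_TYPES = ("MODEL", )

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL", ),
                "object_to_patch": ("STRING", {"default": "diffusion_model"}),
                "residual_diff_threshold": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
                "max_consecutive_cache_hits": ("INT", {"default": -1, "min": -1}),
                "start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
            }
        }


# NODE_CLASS_MAPPINGS = {"ApplyFBCacheOnModelExample": ApplyFBCacheOnModelExampleNode}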