import contextlib
import functools
import logging
from dataclasses import dataclass

import torch

try:
    from sfast.compilers.diffusion_pipeline_compiler import (
        CompilationConfig,
        _enable_xformers,
        _modify_model,
    )
    from sfast.cuda.graphs import make_dynamic_graphed_callable
    from sfast.jit import utils as jit_utils
    from sfast.jit.trace_helper import trace_with_kwargs
except ImportError:
    pass


def hash_arg(arg):
    # micro optimization: a bool obj is an instance of int, so one check covers both
    if isinstance(arg, (str, int, float, bytes)):
        return arg
    if isinstance(arg, (tuple, list)):
        return tuple(map(hash_arg, arg))
    if isinstance(arg, dict):
        return tuple(
            sorted(
                ((hash_arg(k), hash_arg(v)) for k, v in arg.items()),
                key=lambda x: x[0],
            )
        )
    return type(arg)
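# A minimal sketch of what hash_arg returns, using made-up sample inputs.
# Scalars and containers become nested hashable tuples; anything unrecognized
# (such as a torch.Tensor) collapses to its type, so cache keys ignore tensor
# values and shapes:
#
#     >>> hash_arg({"b": [1, 2], "a": 3.0})
#     (('a', 3.0), ('b', (1, 2)))
#     >>> hash_arg(torch.zeros(1))
#     <class 'torch.Tensor'>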
class ModuleFactory:
    def get_converted_kwargs(self):
        return self.converted_kwargs


class BaseModelApplyModelModule(torch.nn.Module):
    def __init__(self, func, module):
        super().__init__()
        self.func = func
        self.module = module

    def forward(
        self,
        input_x,
        timestep,
        c_concat=None,
        c_crossattn=None,
        y=None,
        control=None,
        transformer_options=None,
    ):
        kwargs = {"y": y}

        # transformer_options is rebuilt empty so the traced graph does not
        # capture per-call patch objects.
        new_transformer_options = {}

        return self.func(
            input_x,
            timestep,
            c_concat=c_concat,
            c_crossattn=c_crossattn,
            control=control,
            transformer_options=new_transformer_options,
            **kwargs,
        )


class BaseModelApplyModelModuleFactory(ModuleFactory):
    kwargs_name = (
        "input_x",
        "timestep",
        "c_concat",
        "c_crossattn",
        "y",
        "control",
    )

    def __init__(self, callable, kwargs) -> None:
        self.callable = callable
        self.unet_config = callable.__self__.model_config.unet_config
        self.kwargs = kwargs
        self.patch_module = {}
        self.patch_module_parameter = {}
        self.converted_kwargs = self.gen_converted_kwargs()

    def gen_converted_kwargs(self):
        converted_kwargs = {}
        for arg_name, arg in self.kwargs.items():
            if arg_name in self.kwargs_name:
                converted_kwargs[arg_name] = arg

        # Patches from transformer_options are intentionally not forwarded to
        # the traced module; only the bookkeeping dicts are kept.
        patch_module = {}
        patch_module_parameter = {}

        self.patch_module = patch_module
        self.patch_module_parameter = patch_module_parameter

        return converted_kwargs

    def gen_cache_key(self):
        key_kwargs = dict(self.converted_kwargs)
        patch_module_cache_key = {}
        return (
            self.callable.__class__.__qualname__,
            hash_arg(self.unet_config),
            hash_arg(key_kwargs),
            hash_arg(patch_module_cache_key),
        )

    @contextlib.contextmanager
    def converted_module_context(self):
        module = BaseModelApplyModelModule(self.callable, self.callable.__self__)
        yield (module, self.converted_kwargs)


logger = logging.getLogger()


@dataclass
class TracedModuleCacheItem:
    module: object
    patch_id: int
    device: str


class LazyTraceModule:
    traced_modules = {}

    def __init__(self, config=None, patch_id=None, **kwargs_) -> None:
        self.config = config
        self.patch_id = patch_id
        self.kwargs_ = kwargs_
        self.modify_model = functools.partial(
            _modify_model,
            enable_cnn_optimization=config.enable_cnn_optimization,
            prefer_lowp_gemm=config.prefer_lowp_gemm,
            enable_triton=config.enable_triton,
            enable_triton_reshape=config.enable_triton,
            memory_format=config.memory_format,
        )
        self.cuda_graph_modules = {}

    def ts_compiler(self, m):
        with torch.jit.optimized_execution(True):
            if self.config.enable_jit_freeze:
                # raw freeze causes a Tensor reference leak because the
                # constant Tensors in the GraphFunction of the compilation
                # unit are never freed
                m.eval()
                m = jit_utils.better_freeze(m)
            self.modify_model(m)

        if self.config.enable_cuda_graph:
            m = make_dynamic_graphed_callable(m)
        return m

    def __call__(self, model_function, /, **kwargs):
        module_factory = BaseModelApplyModelModuleFactory(model_function, kwargs)
        kwargs = module_factory.get_converted_kwargs()
        key = module_factory.gen_cache_key()

        traced_module = self.cuda_graph_modules.get(key)
        if traced_module is None:
            with module_factory.converted_module_context() as (m_model, m_kwargs):
                logger.info(
                    f'Tracing {getattr(m_model, "__name__", m_model.__class__.__name__)}'
                )
                traced_m, call_helper = trace_with_kwargs(
                    m_model, None, m_kwargs, **self.kwargs_
                )

            traced_m = self.ts_compiler(traced_m)
            traced_module = call_helper(traced_m)
            self.cuda_graph_modules[key] = traced_module

        return traced_module(**kwargs)


def build_lazy_trace_module(config, device, patch_id):
    config.enable_cuda_graph = config.enable_cuda_graph and device.type == "cuda"

    if config.enable_xformers:
        _enable_xformers(None)

    return LazyTraceModule(
        config=config,
        patch_id=patch_id,
        check_trace=True,
        strict=True,
    )


def gen_stable_fast_config():
    config = CompilationConfig.Default()
    try:
        import xformers  # noqa: F401

        config.enable_xformers = True
    except ImportError:
        print("xformers not installed, skipping xformers optimization")

    # CUDA graphs are suggested for small batch sizes.
    # After capture, the model only accepts one fixed image size.
    # If you want the model to stay dynamic, don't enable them.
    config.enable_cuda_graph = False
    # config.enable_jit_freeze = False
    return config


class StableFastPatch:
    def __init__(self, model, config):
        self.model = model
        self.config = config
        self.stable_fast_model = None

    def __call__(self, model_function, params):
        input_x = params.get("input")
        timestep_ = params.get("timestep")
        c = params.get("c")

        # Build the lazy tracer on first use, once the input device is known.
        if self.stable_fast_model is None:
            self.stable_fast_model = build_lazy_trace_module(
                self.config,
                input_x.device,
                id(self),
            )

        return self.stable_fast_model(
            model_function, input_x=input_x, timestep=timestep_, **c
        )

    def to(self, device):
        if isinstance(device, torch.device):
            if self.config.enable_cuda_graph or self.config.enable_jit_freeze:
                if device.type == "cpu":
                    del self.stable_fast_model
                    self.stable_fast_model = None
                    print(
                        "\33[93mWarning: Your graphics card doesn't have enough"
                        " video memory to keep the model. If you experience a"
                        " noticeable delay every time you start sampling,"
                        " please consider disabling enable_cuda_graph.\33[0m"
                    )
        return self


class ApplyStableFastUnet:
    def apply_stable_fast(self, model, enable_cuda_graph):
        config = gen_stable_fast_config()
        config.enable_cuda_graph = enable_cuda_graph

        if config.memory_format is not None:
            model.model.to(memory_format=config.memory_format)

        patch = StableFastPatch(model, config)
        model_stable_fast = model.clone()
        model_stable_fast.set_model_unet_function_wrapper(patch)
        return (model_stable_fast,)
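# A minimal usage sketch, assuming a ComfyUI environment where `model` is a
# ModelPatcher (e.g. the MODEL output of a checkpoint loader); the variable
# names below are illustrative, not part of this module. The first sampling
# call is slow because the UNet is traced and compiled; later calls whose
# kwargs hash to the same cache key reuse the compiled module:
#
#     node = ApplyStableFastUnet()
#     (model_stable_fast,) = node.apply_stable_fast(model, enable_cuda_graph=False)
#     # pass model_stable_fast to a KSampler as usual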