import contextlib
import functools
import logging
from dataclasses import dataclass

import torch

try:
    from sfast.compilers.diffusion_pipeline_compiler import (
        CompilationConfig,
        _enable_xformers,
        _modify_model,
    )
    from sfast.cuda.graphs import make_dynamic_graphed_callable
    from sfast.jit import utils as jit_utils
    from sfast.jit.trace_helper import trace_with_kwargs
except ImportError:
    pass
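
# The sfast imports come from the stable-fast package; the try/except lets this module
# be imported (e.g. during node discovery) even when stable-fast is not installed.
# Actually using the node without it will fail with a NameError.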

def hash_arg(arg):
    # micro optimization: bool obj is an instance of int
    if isinstance(arg, (str, int, float, bytes)):
        return arg
    if isinstance(arg, (tuple, list)):
        return tuple(map(hash_arg, arg))
    if isinstance(arg, dict):
        return tuple(
            sorted(
                ((hash_arg(k), hash_arg(v)) for k, v in arg.items()),
                key=lambda x: x[0],
            )
        )
    return type(arg)
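
# hash_arg collapses nested kwargs into a hashable, order-insensitive structure used as
# part of the trace cache key. For example (hypothetical values):
#   hash_arg({"b": 1, "a": (2, 3)}) == hash_arg({"a": (2, 3), "b": 1})
# Unhashable leaves such as tensors fall back to their type, so changing tensor values
# alone never changes the key.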


class ModuleFactory:
    def get_converted_kwargs(self):
        return self.converted_kwargs


import torch as th
import torch.nn as nn
import copy


class BaseModelApplyModelModule(torch.nn.Module):
    def __init__(self, func, module):
        super().__init__()
        self.func = func
        self.module = module

    def forward(
        self,
        input_x,
        timestep,
        c_concat=None,
        c_crossattn=None,
        y=None,
        control=None,
        transformer_options={},
    ):
        kwargs = {"y": y}

        new_transformer_options = {}

        return self.func(
            input_x,
            timestep,
            c_concat=c_concat,
            c_crossattn=c_crossattn,
            control=control,
            transformer_options=new_transformer_options,
            **kwargs,
        )
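
# The wrapper above gives apply_model an explicit tensor signature that torch.jit can
# trace; transformer_options is replaced with a fresh empty dict so patch callbacks and
# other non-tensor state stay out of the traced graph. BaseModelApplyModelModuleFactory
# converts ComfyUI's apply_model kwargs into that traceable form and derives a cache
# key for the resulting trace.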
class BaseModelApplyModelModuleFactory(ModuleFactory):
    kwargs_name = (
        "input_x",
        "timestep",
        "c_concat",
        "c_crossattn",
        "y",
        "control",
    )

    def __init__(self, callable, kwargs) -> None:
        self.callable = callable
        self.unet_config = callable.__self__.model_config.unet_config
        self.kwargs = kwargs
        self.patch_module = {}
        self.patch_module_parameter = {}
        self.converted_kwargs = self.gen_converted_kwargs()

    def gen_converted_kwargs(self):
        converted_kwargs = {}
        for arg_name, arg in self.kwargs.items():
            if arg_name in self.kwargs_name:
                converted_kwargs[arg_name] = arg

        transformer_options = self.kwargs.get("transformer_options", {})
        patches = transformer_options.get("patches", {})

        patch_module = {}
        patch_module_parameter = {}

        new_transformer_options = {}
        new_transformer_options["patches"] = patch_module_parameter

        self.patch_module = patch_module
        self.patch_module_parameter = patch_module_parameter
        return converted_kwargs
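
    # gen_cache_key identifies a trace by the callable's class name, the hashed UNet
    # config, and the structure of the converted kwargs; tensor values collapse to
    # their types via hash_arg, so new latents alone do not force a retrace.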
    def gen_cache_key(self):
        key_kwargs = {}
        for k, v in self.converted_kwargs.items():
            key_kwargs[k] = v

        patch_module_cache_key = {}
        return (
            self.callable.__class__.__qualname__,
            hash_arg(self.unet_config),
            hash_arg(key_kwargs),
            hash_arg(patch_module_cache_key),
        )
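
    # converted_module_context yields the wrapper module together with the converted
    # kwargs, ready to be handed to trace_with_kwargs.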
    @contextlib.contextmanager
    def converted_module_context(self):
        module = BaseModelApplyModelModule(self.callable, self.callable.__self__)
        yield (module, self.converted_kwargs)


logger = logging.getLogger()


@dataclass
class TracedModuleCacheItem:
    module: object
    patch_id: int
    device: str


class LazyTraceModule:
    traced_modules = {}

    def __init__(self, config=None, patch_id=None, **kwargs_) -> None:
        self.config = config
        self.patch_id = patch_id
        self.kwargs_ = kwargs_
        self.modify_model = functools.partial(
            _modify_model,
            enable_cnn_optimization=config.enable_cnn_optimization,
            prefer_lowp_gemm=config.prefer_lowp_gemm,
            enable_triton=config.enable_triton,
            enable_triton_reshape=config.enable_triton,
            memory_format=config.memory_format,
        )

        self.cuda_graph_modules = {}
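
    # ts_compiler post-processes a freshly traced module: optionally freeze it, apply
    # the stable-fast graph rewrites configured in modify_model above, then wrap it in
    # a dynamic CUDA graph callable when enable_cuda_graph is set.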
    def ts_compiler(self, m):
        with torch.jit.optimized_execution(True):
            if self.config.enable_jit_freeze:
                # raw freeze causes Tensor reference leak
                # because the constant Tensors in the GraphFunction of
                # the compilation unit are never freed.
                m.eval()
                m = jit_utils.better_freeze(m)
            self.modify_model(m)

        if self.config.enable_cuda_graph:
            m = make_dynamic_graphed_callable(m)
        return m
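
    # __call__ converts the incoming kwargs, looks up the compiled module by cache
    # key, and traces/compiles on the first miss; later calls with the same key reuse
    # the cached module directly.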
    def __call__(self, model_function, /, **kwargs):
        module_factory = BaseModelApplyModelModuleFactory(model_function, kwargs)
        kwargs = module_factory.get_converted_kwargs()
        key = module_factory.gen_cache_key()

        traced_module = self.cuda_graph_modules.get(key)
        if traced_module is None:
            with module_factory.converted_module_context() as (m_model, m_kwargs):
                logger.info(
                    f'Tracing {getattr(m_model, "__name__", m_model.__class__.__name__)}'
                )
                traced_m, call_helper = trace_with_kwargs(
                    m_model, None, m_kwargs, **self.kwargs_
                )

            traced_m = self.ts_compiler(traced_m)
            traced_module = call_helper(traced_m)
            self.cuda_graph_modules[key] = traced_module

        return traced_module(**kwargs)
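

# build_lazy_trace_module keeps CUDA graphs enabled only on CUDA devices, turns on
# xformers when requested, and forwards check_trace/strict through LazyTraceModule's
# **kwargs_ (presumably down to torch.jit.trace via trace_with_kwargs).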
def build_lazy_trace_module(config, device, patch_id):
    config.enable_cuda_graph = config.enable_cuda_graph and device.type == "cuda"

    if config.enable_xformers:
        _enable_xformers(None)

    return LazyTraceModule(
        config=config,
        patch_id=patch_id,
        check_trace=True,
        strict=True,
    )
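

# gen_stable_fast_config starts from stable-fast's default CompilationConfig, enables
# xformers when it is importable, and leaves CUDA graphs off by default because a
# captured graph only accepts one fixed image size.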
def gen_stable_fast_config():
    config = CompilationConfig.Default()
    try:
        import xformers

        config.enable_xformers = True
    except ImportError:
        print("xformers not installed, skipping")

    # CUDA Graph is suggested for small batch sizes.
    # After capturing, the model only accepts one fixed image size.
    # If you want the model to be dynamic, don't enable it.
    config.enable_cuda_graph = False
    # config.enable_jit_freeze = False
    return config
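

# Minimal usage sketch (an assumption for illustration, not part of the original file):
#
#   config = gen_stable_fast_config()
#   config.enable_cuda_graph = True  # only if batch size and image size stay fixed
#   trace_module = build_lazy_trace_module(config, torch.device("cuda"), patch_id=0)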


class StableFastPatch:
    def __init__(self, model, config):
        self.model = model
        self.config = config
        self.stable_fast_model = None

    def __call__(self, model_function, params):
        input_x = params.get("input")
        timestep_ = params.get("timestep")
        c = params.get("c")

        if self.stable_fast_model is None:
            self.stable_fast_model = build_lazy_trace_module(
                self.config,
                input_x.device,
                id(self),
            )

        return self.stable_fast_model(
            model_function, input_x=input_x, timestep=timestep_, **c
        )
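
    # StableFastPatch is installed as ComfyUI's unet function wrapper: it receives the
    # original apply_model callable plus its inputs, lazily builds the trace module on
    # the device of the first batch, and routes every call through it. to() below
    # handles device moves.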
    def to(self, device):
        if isinstance(device, torch.device):
            if self.config.enable_cuda_graph or self.config.enable_jit_freeze:
                if device.type == "cpu":
                    del self.stable_fast_model
                    self.stable_fast_model = None
                    print(
                        "\33[93mWarning: Your graphics card doesn't have enough video memory to keep the model resident. If you experience a noticeable delay every time sampling starts, consider disabling enable_cuda_graph.\33[0m"
                    )
        return self
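

# When ComfyUI offloads the model to the CPU with CUDA graphs or jit freeze enabled,
# the compiled module is discarded rather than moved, so the next sampling run has to
# retrace and recompile. ApplyStableFastUnet below is the node entry point: it builds
# a config, applies the preferred memory format to the UNet weights, and installs
# StableFastPatch on a clone of the model.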
class ApplyStableFastUnet:
    def apply_stable_fast(self, model, enable_cuda_graph):
        config = gen_stable_fast_config()
        # apply the node's enable_cuda_graph input on top of the config defaults
        config.enable_cuda_graph = enable_cuda_graph

        if config.memory_format is not None:
            model.model.to(memory_format=config.memory_format)

        patch = StableFastPatch(model, config)
        model_stable_fast = model.clone()
        model_stable_fast.set_model_unet_function_wrapper(patch)
        return (model_stable_fast,)
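

# Hypothetical ComfyUI registration sketch (an assumption, not part of the original
# file): custom nodes are normally exposed through class-level metadata and a
# NODE_CLASS_MAPPINGS dict in the package's __init__.py. The input names, defaults,
# and category string below are illustrative only.
#
#     class ApplyStableFastUnet:
#         @classmethod
#         def INPUT_TYPES(cls):
#             return {
#                 "required": {
#                     "model": ("MODEL",),
#                     "enable_cuda_graph": ("BOOLEAN", {"default": True}),
#                 }
#             }
#
#         RETURN_TYPES = ("MODEL",)
#         FUNCTION = "apply_stable_fast"
#         CATEGORY = "loaders"
#
#     NODE_CLASS_MAPPINGS = {"ApplyStableFastUnet": ApplyStableFastUnet}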