import contextlib
import functools
import logging
from dataclasses import dataclass

import torch

try:
    from sfast.compilers.diffusion_pipeline_compiler import (
        CompilationConfig,
        _enable_xformers,
        _modify_model,
    )
    from sfast.cuda.graphs import make_dynamic_graphed_callable
    from sfast.jit import utils as jit_utils
    from sfast.jit.trace_helper import trace_with_kwargs
except ImportError:
    # stable-fast is an optional dependency; anything below that needs it
    # will fail loudly at call time if this import did not succeed.
    pass


def hash_arg(arg):
    # micro optimization: bool obj is an instance of int
    if isinstance(arg, (str, int, float, bytes)):
        return arg
    if isinstance(arg, (tuple, list)):
        return tuple(map(hash_arg, arg))
    if isinstance(arg, dict):
        return tuple(
            sorted(
                ((hash_arg(k), hash_arg(v)) for k, v in arg.items()),
                key=lambda x: x[0],
            )
        )
    return type(arg)
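
# For illustration: hash_arg flattens nested containers into hashable tuples,
# with dict items sorted by key, so structurally equal kwargs map to the same
# cache key. Unhashable leaves (e.g. tensors) collapse to their type:
#   hash_arg({"b": [1, 2], "a": "x"})  ->  (("a", "x"), ("b", (1, 2)))
#   hash_arg(torch.zeros(1))           ->  <class 'torch.Tensor'>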


class ModuleFactory:
    # Minimal base interface: subclasses populate self.converted_kwargs in
    # __init__ and expose it here for tracing.
    def get_converted_kwargs(self):
        return self.converted_kwargs


class BaseModelApplyModelModule(torch.nn.Module):
    """Wraps a bound apply_model-style function together with its owning
    module so torch.jit can trace it as a regular nn.Module."""

    def __init__(self, func, module):
        super().__init__()
        self.func = func
        self.module = module

    def forward(
        self,
        input_x,
        timestep,
        c_concat=None,
        c_crossattn=None,
        y=None,
        control=None,
        transformer_options={},
    ):
        kwargs = {"y": y}

        # The incoming transformer_options is deliberately discarded: the
        # traced graph is always called with a fresh, empty dict, so patch
        # callables never get baked into the trace.
        new_transformer_options = {}

        return self.func(
            input_x,
            timestep,
            c_concat=c_concat,
            c_crossattn=c_crossattn,
            control=control,
            transformer_options=new_transformer_options,
            **kwargs,
        )


class BaseModelApplyModelModuleFactory(ModuleFactory):
    kwargs_name = (
        "input_x",
        "timestep",
        "c_concat",
        "c_crossattn",
        "y",
        "control",
    )

    def __init__(self, callable, kwargs) -> None:
        self.callable = callable
        self.unet_config = callable.__self__.model_config.unet_config
        self.kwargs = kwargs
        self.patch_module = {}
        self.patch_module_parameter = {}
        self.converted_kwargs = self.gen_converted_kwargs()

    def gen_converted_kwargs(self):
        converted_kwargs = {}
        for arg_name, arg in self.kwargs.items():
            if arg_name in self.kwargs_name:
                converted_kwargs[arg_name] = arg

        # Patches are read from transformer_options but not carried into the
        # traced module; the trace is always invoked with an empty patch dict
        # (see BaseModelApplyModelModule.forward).
        transformer_options = self.kwargs.get("transformer_options", {})
        patches = transformer_options.get("patches", {})

        patch_module = {}
        patch_module_parameter = {}

        new_transformer_options = {}
        new_transformer_options["patches"] = patch_module_parameter

        self.patch_module = patch_module
        self.patch_module_parameter = patch_module_parameter
        return converted_kwargs

    def gen_cache_key(self):
        key_kwargs = {}
        for k, v in self.converted_kwargs.items():
            key_kwargs[k] = v

        patch_module_cache_key = {}

        return (
            self.callable.__class__.__qualname__,
            hash_arg(self.unet_config),
            hash_arg(key_kwargs),
            hash_arg(patch_module_cache_key),
        )

    @contextlib.contextmanager
    def converted_module_context(self):
        module = BaseModelApplyModelModule(self.callable, self.callable.__self__)
        yield (module, self.converted_kwargs)


logger = logging.getLogger()


@dataclass
class TracedModuleCacheItem:
    module: object
    patch_id: int
    device: str


class LazyTraceModule:
    traced_modules = {}

    def __init__(self, config=None, patch_id=None, **kwargs_) -> None:
        self.config = config
        self.patch_id = patch_id
        self.kwargs_ = kwargs_
        self.modify_model = functools.partial(
            _modify_model,
            enable_cnn_optimization=config.enable_cnn_optimization,
            prefer_lowp_gemm=config.prefer_lowp_gemm,
            enable_triton=config.enable_triton,
            enable_triton_reshape=config.enable_triton,
            memory_format=config.memory_format,
        )
        self.cuda_graph_modules = {}

    def ts_compiler(self, m):
        with torch.jit.optimized_execution(True):
            if self.config.enable_jit_freeze:
                # raw freeze causes Tensor reference leak
                # because the constant Tensors in the GraphFunction of
                # the compilation unit are never freed.
                m.eval()
                m = jit_utils.better_freeze(m)
            self.modify_model(m)

        if self.config.enable_cuda_graph:
            m = make_dynamic_graphed_callable(m)
        return m

    def __call__(self, model_function, /, **kwargs):
        module_factory = BaseModelApplyModelModuleFactory(model_function, kwargs)
        kwargs = module_factory.get_converted_kwargs()
        key = module_factory.gen_cache_key()

        traced_module = self.cuda_graph_modules.get(key)
        if traced_module is None:
            with module_factory.converted_module_context() as (m_model, m_kwargs):
                logger.info(
                    f'Tracing {getattr(m_model, "__name__", m_model.__class__.__name__)}'
                )
                traced_m, call_helper = trace_with_kwargs(
                    m_model, None, m_kwargs, **self.kwargs_
                )

            traced_m = self.ts_compiler(traced_m)
            traced_module = call_helper(traced_m)
            self.cuda_graph_modules[key] = traced_module

        return traced_module(**kwargs)
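

# For illustration only (hypothetical values): LazyTraceModule is normally
# driven through StableFastPatch below, but it can be exercised directly with
# any bound apply_model-style method whose __self__ carries a model_config:
#   trace_module = LazyTraceModule(config=gen_stable_fast_config(), patch_id=0)
#   out = trace_module(base_model.apply_model, input_x=x, timestep=t,
#                      c_crossattn=cond)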


def build_lazy_trace_module(config, device, patch_id):
    config.enable_cuda_graph = config.enable_cuda_graph and device.type == "cuda"

    if config.enable_xformers:
        _enable_xformers(None)

    return LazyTraceModule(
        config=config,
        patch_id=patch_id,
        check_trace=True,
        strict=True,
    )


def gen_stable_fast_config():
    config = CompilationConfig.Default()
    try:
        import xformers

        config.enable_xformers = True
    except ImportError:
        print("xformers not installed, skipping xformers optimization")

    # CUDA Graph is suggested for small batch sizes.
    # After capturing, the model only accepts one fixed image size.
    # If you want the model to be dynamic, don't enable it.
    config.enable_cuda_graph = False
    # config.enable_jit_freeze = False
    return config
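

# The returned object is a plain stable-fast CompilationConfig, so callers can
# tweak it before building the trace module; a sketch:
#   config = gen_stable_fast_config()
#   config.enable_cuda_graph = True  # only worthwhile for fixed input sizes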


class StableFastPatch:
    def __init__(self, model, config):
        self.model = model
        self.config = config
        self.stable_fast_model = None

    def __call__(self, model_function, params):
        input_x = params.get("input")
        timestep_ = params.get("timestep")
        c = params.get("c")

        # Build the lazy trace module on first use, on the device the inputs
        # actually live on.
        if self.stable_fast_model is None:
            self.stable_fast_model = build_lazy_trace_module(
                self.config,
                input_x.device,
                id(self),
            )

        return self.stable_fast_model(
            model_function, input_x=input_x, timestep=timestep_, **c
        )

    def to(self, device):
        if isinstance(device, torch.device):
            if self.config.enable_cuda_graph or self.config.enable_jit_freeze:
                if device.type == "cpu":
                    # Dropping the traced model here means it will be rebuilt
                    # (and re-traced) the next time sampling runs on GPU.
                    del self.stable_fast_model
                    self.stable_fast_model = None
                    print(
                        "\33[93mWarning: Your graphics card doesn't have enough video memory to keep the model. If you experience a noticeable delay every time you start sampling, please consider disabling enable_cuda_graph.\33[0m"
                    )
        return self


class ApplyStableFastUnet:
    def apply_stable_fast(self, model, enable_cuda_graph):
        config = gen_stable_fast_config()
        # Honor the caller's choice; gen_stable_fast_config() defaults this
        # flag to False.
        config.enable_cuda_graph = enable_cuda_graph

        if config.memory_format is not None:
            model.model.to(memory_format=config.memory_format)

        patch = StableFastPatch(model, config)
        model_stable_fast = model.clone()
        model_stable_fast.set_model_unet_function_wrapper(patch)
        return (model_stable_fast,)
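

# Sketch of using the node outside a ComfyUI graph. `model` is assumed to be a
# comfy.model_patcher.ModelPatcher (the clone() /
# set_model_unet_function_wrapper calls above come from that API); everything
# else is illustrative:
#   node = ApplyStableFastUnet()
#   (fast_model,) = node.apply_stable_fast(model, enable_cuda_graph=False)
#   # fast_model can then be handed to the sampler in place of `model`.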