import contextlib
import functools
import logging
from dataclasses import dataclass
import torch
try:
from sfast.compilers.diffusion_pipeline_compiler import CompilationConfig
from sfast.compilers.diffusion_pipeline_compiler import (
_enable_xformers,
_modify_model,
)
from sfast.cuda.graphs import make_dynamic_graphed_callable
from sfast.jit import utils as jit_utils
from sfast.jit.trace_helper import trace_with_kwargs
except Exception:
    # stable-fast is an optional dependency; if it cannot be imported the
    # patch below simply stays unavailable
    pass
def hash_arg(arg):
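    """Reduce an argument to a hashable cache key; containers recurse and
    anything else (e.g. a Tensor) collapses to its type."""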
# micro optimization: bool obj is an instance of int
if isinstance(arg, (str, int, float, bytes)):
return arg
if isinstance(arg, (tuple, list)):
return tuple(map(hash_arg, arg))
if isinstance(arg, dict):
return tuple(
sorted(
((hash_arg(k), hash_arg(v)) for k, v in arg.items()), key=lambda x: x[0]
)
)
return type(arg)
class ModuleFactory:
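    """Base class for factories that prepare a callable and its kwargs for
    tracing; subclasses fill in self.converted_kwargs."""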
def get_converted_kwargs(self):
return self.converted_kwargs
class BaseModelApplyModelModule(torch.nn.Module):
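    """Wraps the bound apply_model callable in an nn.Module so it can be
    traced by torch.jit."""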
def __init__(self, func, module):
super().__init__()
self.func = func
self.module = module
def forward(
self,
input_x,
timestep,
c_concat=None,
c_crossattn=None,
y=None,
control=None,
        transformer_options=None,
    ):
        kwargs = {"y": y}
        # a fresh, empty transformer_options is used so no untraceable patch
        # callables leak into the traced graph
        new_transformer_options = {}
return self.func(
input_x,
timestep,
c_concat=c_concat,
c_crossattn=c_crossattn,
control=control,
transformer_options=new_transformer_options,
**kwargs,
)
class BaseModelApplyModelModuleFactory(ModuleFactory):
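    """Converts a model's apply_model call into a traceable module plus
    filtered kwargs, and produces a cache key for the resulting trace."""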
kwargs_name = (
"input_x",
"timestep",
"c_concat",
"c_crossattn",
"y",
"control",
)
def __init__(self, callable, kwargs) -> None:
self.callable = callable
self.unet_config = callable.__self__.model_config.unet_config
self.kwargs = kwargs
self.patch_module = {}
self.patch_module_parameter = {}
self.converted_kwargs = self.gen_converted_kwargs()
def gen_converted_kwargs(self):
converted_kwargs = {}
for arg_name, arg in self.kwargs.items():
if arg_name in self.kwargs_name:
converted_kwargs[arg_name] = arg
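        # transformer_options may carry Python callables ("patches") that
        # cannot be traced; they are dropped so the traced module always runs
        # with an empty patch dict (see BaseModelApplyModelModule.forward)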
transformer_options = self.kwargs.get("transformer_options", {})
patches = transformer_options.get("patches", {})
patch_module = {}
patch_module_parameter = {}
new_transformer_options = {}
new_transformer_options["patches"] = patch_module_parameter
self.patch_module = patch_module
self.patch_module_parameter = patch_module_parameter
return converted_kwargs
def gen_cache_key(self):
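        """Build a hashable key (callable class, UNet config, kwargs) so each
        distinct configuration gets its own traced module."""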
key_kwargs = {}
for k, v in self.converted_kwargs.items():
key_kwargs[k] = v
patch_module_cache_key = {}
return (
self.callable.__class__.__qualname__,
hash_arg(self.unet_config),
hash_arg(key_kwargs),
hash_arg(patch_module_cache_key),
)
@contextlib.contextmanager
def converted_module_context(self):
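        """Yield the wrapped module together with the kwargs to trace it with."""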
module = BaseModelApplyModelModule(self.callable, self.callable.__self__)
yield (module, self.converted_kwargs)
logger = logging.getLogger(__name__)
@dataclass
class TracedModuleCacheItem:
module: object
patch_id: int
device: str
class LazyTraceModule:
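    """Traces and compiles the model lazily on first call, caching the traced
    module per cache key so later calls skip tracing entirely."""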
traced_modules = {}
def __init__(self, config=None, patch_id=None, **kwargs_) -> None:
self.config = config
self.patch_id = patch_id
self.kwargs_ = kwargs_
self.modify_model = functools.partial(
_modify_model,
enable_cnn_optimization=config.enable_cnn_optimization,
prefer_lowp_gemm=config.prefer_lowp_gemm,
enable_triton=config.enable_triton,
enable_triton_reshape=config.enable_triton,
memory_format=config.memory_format,
)
self.cuda_graph_modules = {}
def ts_compiler(
self,
m,
):
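        """Run TorchScript-level passes: optional freeze, sfast model rewrites,
        then optional CUDA graph capture."""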
with torch.jit.optimized_execution(True):
if self.config.enable_jit_freeze:
# raw freeze causes Tensor reference leak
# because the constant Tensors in the GraphFunction of
# the compilation unit are never freed.
m.eval()
m = jit_utils.better_freeze(m)
self.modify_model(m)
if self.config.enable_cuda_graph:
m = make_dynamic_graphed_callable(m)
return m
def __call__(self, model_function, /, **kwargs):
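        """Look up (or trace and cache) the compiled module, then call it."""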
module_factory = BaseModelApplyModelModuleFactory(model_function, kwargs)
kwargs = module_factory.get_converted_kwargs()
key = module_factory.gen_cache_key()
traced_module = self.cuda_graph_modules.get(key)
if traced_module is None:
with module_factory.converted_module_context() as (m_model, m_kwargs):
logger.info(
f'Tracing {getattr(m_model, "__name__", m_model.__class__.__name__)}'
)
traced_m, call_helper = trace_with_kwargs(
m_model, None, m_kwargs, **self.kwargs_
)
traced_m = self.ts_compiler(traced_m)
traced_module = call_helper(traced_m)
self.cuda_graph_modules[key] = traced_module
return traced_module(**kwargs)
def build_lazy_trace_module(config, device, patch_id):
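    # CUDA graphs require a CUDA device; downgrade the flag everywhere else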
config.enable_cuda_graph = config.enable_cuda_graph and device.type == "cuda"
if config.enable_xformers:
_enable_xformers(None)
return LazyTraceModule(
config=config,
patch_id=patch_id,
check_trace=True,
strict=True,
)
def gen_stable_fast_config():
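    """Create a stable-fast CompilationConfig, enabling xformers when present."""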
config = CompilationConfig.Default()
try:
import xformers
config.enable_xformers = True
except ImportError:
print("xformers not installed, skip")
# CUDA Graph is suggested for small batch sizes.
# After capturing, the model only accepts one fixed image size.
# If you want the model to be dynamic, don't enable it.
config.enable_cuda_graph = False
# config.enable_jit_freeze = False
return config
class StableFastPatch:
def __init__(self, model, config):
self.model = model
self.config = config
self.stable_fast_model = None
def __call__(self, model_function, params):
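        # params follows ComfyUI's unet wrapper convention: the latent input,
        # the timestep tensor and the conditioning dict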
input_x = params.get("input")
timestep_ = params.get("timestep")
c = params.get("c")
if self.stable_fast_model is None:
self.stable_fast_model = build_lazy_trace_module(
self.config,
input_x.device,
id(self),
)
return self.stable_fast_model(
model_function, input_x=input_x, timestep=timestep_, **c
)
def to(self, device):
        if isinstance(device, torch.device):
if self.config.enable_cuda_graph or self.config.enable_jit_freeze:
if device.type == "cpu":
del self.stable_fast_model
self.stable_fast_model = None
                print(
                    "\033[93mWarning: Your graphics card doesn't have enough video "
                    "memory to keep the model. If you experience a noticeable delay "
                    "every time you start sampling, please consider disabling "
                    "enable_cuda_graph.\033[0m"
                )
return self
class ApplyStableFastUnet:
def apply_stable_fast(self, model, enable_cuda_graph):
        config = gen_stable_fast_config()
        # honor the node's input; gen_stable_fast_config defaults this to False
        config.enable_cuda_graph = enable_cuda_graph
if config.memory_format is not None:
model.model.to(memory_format=config.memory_format)
patch = StableFastPatch(model, config)
model_stable_fast = model.clone()
model_stable_fast.set_model_unet_function_wrapper(patch)
return (model_stable_fast,)
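
# A minimal usage sketch (hypothetical names; assumes a ComfyUI-style
# ModelPatcher exposing .clone() and .set_model_unet_function_wrapper(),
# as used by apply_stable_fast above):
#
#   (patched,) = ApplyStableFastUnet().apply_stable_fast(
#       model_patcher,           # e.g. from a checkpoint loader node
#       enable_cuda_graph=False, # keep dynamic shapes; see gen_stable_fast_config
#   )
#   # `patched` is then handed to the sampler in place of the original model.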