import contextlib
import functools
import logging
from dataclasses import dataclass
import torch
try:
from sfast.compilers.diffusion_pipeline_compiler import CompilationConfig
from sfast.compilers.diffusion_pipeline_compiler import (
_enable_xformers,
_modify_model,
)
from sfast.cuda.graphs import make_dynamic_graphed_callable
from sfast.jit import utils as jit_utils
from sfast.jit.trace_helper import trace_with_kwargs
except ImportError:
    # stable-fast (sfast) is an optional dependency; without it these names stay
    # undefined and the stable-fast patch cannot be built.
    pass
def hash_arg(arg):
# micro optimization: bool obj is an instance of int
if isinstance(arg, (str, int, float, bytes)):
return arg
if isinstance(arg, (tuple, list)):
return tuple(map(hash_arg, arg))
if isinstance(arg, dict):
return tuple(
sorted(
((hash_arg(k), hash_arg(v)) for k, v in arg.items()), key=lambda x: x[0]
)
)
return type(arg)
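# A minimal illustration (not part of the original module) of how hash_arg turns
# nested keyword arguments into hashable tuples usable as dict cache keys:
#
#   hash_arg({"steps": 20, "sizes": [512, 512]})
#   # -> (("sizes", (512, 512)), ("steps", 20))
#
# Unhashable leaves such as Tensors fall back to their type, so calls that differ
# only in tensor contents still map to the same cache key.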
class ModuleFactory:
def get_converted_kwargs(self):
return self.converted_kwargs
class BaseModelApplyModelModule(torch.nn.Module):
def __init__(self, func, module):
super().__init__()
self.func = func
self.module = module
def forward(
self,
input_x,
timestep,
c_concat=None,
c_crossattn=None,
y=None,
control=None,
        transformer_options=None,  # avoid a mutable default; this argument is not used below
):
kwargs = {"y": y}
new_transformer_options = {}
return self.func(
input_x,
timestep,
c_concat=c_concat,
c_crossattn=c_crossattn,
control=control,
transformer_options=new_transformer_options,
**kwargs,
)
class BaseModelApplyModelModuleFactory(ModuleFactory):
kwargs_name = (
"input_x",
"timestep",
"c_concat",
"c_crossattn",
"y",
"control",
)
def __init__(self, callable, kwargs) -> None:
self.callable = callable
self.unet_config = callable.__self__.model_config.unet_config
self.kwargs = kwargs
self.patch_module = {}
self.patch_module_parameter = {}
self.converted_kwargs = self.gen_converted_kwargs()
def gen_converted_kwargs(self):
converted_kwargs = {}
for arg_name, arg in self.kwargs.items():
if arg_name in self.kwargs_name:
converted_kwargs[arg_name] = arg
transformer_options = self.kwargs.get("transformer_options", {})
patches = transformer_options.get("patches", {})
patch_module = {}
patch_module_parameter = {}
new_transformer_options = {}
new_transformer_options["patches"] = patch_module_parameter
self.patch_module = patch_module
self.patch_module_parameter = patch_module_parameter
return converted_kwargs
def gen_cache_key(self):
key_kwargs = {}
for k, v in self.converted_kwargs.items():
key_kwargs[k] = v
patch_module_cache_key = {}
return (
self.callable.__class__.__qualname__,
hash_arg(self.unet_config),
hash_arg(key_kwargs),
hash_arg(patch_module_cache_key),
)
@contextlib.contextmanager
def converted_module_context(self):
module = BaseModelApplyModelModule(self.callable, self.callable.__self__)
yield (module, self.converted_kwargs)
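# Illustrative usage sketch (assumed call pattern; the factory is driven by
# LazyTraceModule below, nothing here runs at import time):
#
#   factory = BaseModelApplyModelModuleFactory(model_function, kwargs)
#   key = factory.gen_cache_key()          # identifies an already-traced module
#   with factory.converted_module_context() as (wrapper, traced_kwargs):
#       out = wrapper(**traced_kwargs)     # plain nn.Module call, traceable by TorchScript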
logger = logging.getLogger()
@dataclass
class TracedModuleCacheItem:
module: object
patch_id: int
device: str
class LazyTraceModule:
traced_modules = {}
def __init__(self, config=None, patch_id=None, **kwargs_) -> None:
self.config = config
self.patch_id = patch_id
self.kwargs_ = kwargs_
self.modify_model = functools.partial(
_modify_model,
enable_cnn_optimization=config.enable_cnn_optimization,
prefer_lowp_gemm=config.prefer_lowp_gemm,
enable_triton=config.enable_triton,
enable_triton_reshape=config.enable_triton,
memory_format=config.memory_format,
)
self.cuda_graph_modules = {}
def ts_compiler(
self,
m,
):
with torch.jit.optimized_execution(True):
if self.config.enable_jit_freeze:
# raw freeze causes Tensor reference leak
# because the constant Tensors in the GraphFunction of
# the compilation unit are never freed.
m.eval()
m = jit_utils.better_freeze(m)
self.modify_model(m)
if self.config.enable_cuda_graph:
m = make_dynamic_graphed_callable(m)
return m
def __call__(self, model_function, /, **kwargs):
module_factory = BaseModelApplyModelModuleFactory(model_function, kwargs)
kwargs = module_factory.get_converted_kwargs()
key = module_factory.gen_cache_key()
traced_module = self.cuda_graph_modules.get(key)
if traced_module is None:
with module_factory.converted_module_context() as (m_model, m_kwargs):
logger.info(
f'Tracing {getattr(m_model, "__name__", m_model.__class__.__name__)}'
)
traced_m, call_helper = trace_with_kwargs(
m_model, None, m_kwargs, **self.kwargs_
)
traced_m = self.ts_compiler(traced_m)
traced_module = call_helper(traced_m)
self.cuda_graph_modules[key] = traced_module
return traced_module(**kwargs)
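# Illustrative sketch of how LazyTraceModule is meant to be invoked (see
# StableFastPatch below for the real call site): the first call for a given
# cache key traces and compiles the wrapped apply_model function, later calls
# with the same key reuse the cached compiled module.
#
#   lazy_trace = LazyTraceModule(config=config, patch_id=patch_id,
#                                check_trace=True, strict=True)
#   out = lazy_trace(model_function, input_x=x, timestep=t, **c)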
def build_lazy_trace_module(config, device, patch_id):
config.enable_cuda_graph = config.enable_cuda_graph and device.type == "cuda"
if config.enable_xformers:
_enable_xformers(None)
return LazyTraceModule(
config=config,
patch_id=patch_id,
check_trace=True,
strict=True,
)
def gen_stable_fast_config():
config = CompilationConfig.Default()
try:
import xformers
config.enable_xformers = True
except ImportError:
print("xformers not installed, skip")
# CUDA Graph is suggested for small batch sizes.
# After capturing, the model only accepts one fixed image size.
# If you want the model to be dynamic, don't enable it.
config.enable_cuda_graph = False
# config.enable_jit_freeze = False
return config
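# Illustrative sketch of tweaking the generated config before building the patch
# (all fields below already appear elsewhere in this file; values are examples):
#
#   config = gen_stable_fast_config()
#   config.enable_triton = False      # e.g. when Triton is not installed
#   config.enable_jit_freeze = False  # skip freezing in ts_compiler above
#   config.enable_cuda_graph = True   # only if input shapes stay fixed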
class StableFastPatch:
def __init__(self, model, config):
self.model = model
self.config = config
self.stable_fast_model = None
def __call__(self, model_function, params):
input_x = params.get("input")
timestep_ = params.get("timestep")
c = params.get("c")
if self.stable_fast_model is None:
self.stable_fast_model = build_lazy_trace_module(
self.config,
input_x.device,
id(self),
)
return self.stable_fast_model(
model_function, input_x=input_x, timestep=timestep_, **c
)
def to(self, device):
        if isinstance(device, torch.device):
if self.config.enable_cuda_graph or self.config.enable_jit_freeze:
if device.type == "cpu":
del self.stable_fast_model
self.stable_fast_model = None
                    print(
                        "\33[93mWarning: Your graphics card doesn't have enough video memory to keep the model loaded. If you experience a noticeable delay every time you start sampling, please consider disabling enable_cuda_graph.\33[0m"
                    )
return self
class ApplyStableFastUnet:
def apply_stable_fast(self, model, enable_cuda_graph):
config = gen_stable_fast_config()
if config.memory_format is not None:
model.model.to(memory_format=config.memory_format)
patch = StableFastPatch(model, config)
model_stable_fast = model.clone()
model_stable_fast.set_model_unet_function_wrapper(patch)
        return (model_stable_fast,)
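# Illustrative usage sketch (assumes `model` is a ComfyUI ModelPatcher, as passed
# to custom nodes; hypothetical example, nothing runs at import time):
#
#   (model_sf,) = ApplyStableFastUnet().apply_stable_fast(model, enable_cuda_graph=False)
#   # model_sf behaves like the original model, but UNet apply_model calls are
#   # routed through StableFastPatch and the traced/compiled module.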