Spaces:
Running
on
Zero
Running
on
Zero
from __future__ import annotations | |
import torch | |
import comfy.utils | |
from comfy.patcher_extension import WrappersMP | |
from typing import TYPE_CHECKING, Callable, Optional | |
if TYPE_CHECKING: | |
from comfy.model_patcher import ModelPatcher | |
from comfy.patcher_extension import WrapperExecutor | |
COMPILE_KEY = "torch.compile" | |
TORCH_COMPILE_KWARGS = "torch_compile_kwargs" | |
def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable]) -> Callable: | |
''' | |
Create a wrapper that will refer to the compiled_diffusion_model. | |
''' | |
def apply_torch_compile_wrapper(executor: WrapperExecutor, *args, **kwargs): | |
try: | |
orig_modules = {} | |
for key, value in compiled_module_dict.items(): | |
orig_modules[key] = comfy.utils.get_attr(executor.class_obj, key) | |
comfy.utils.set_attr(executor.class_obj, key, value) | |
return executor(*args, **kwargs) | |
finally: | |
for key, value in orig_modules.items(): | |
comfy.utils.set_attr(executor.class_obj, key, value) | |
return apply_torch_compile_wrapper | |
def set_torch_compile_wrapper(model: ModelPatcher, backend: str, options: Optional[dict[str,str]]=None, | |
mode: Optional[str]=None, fullgraph=False, dynamic: Optional[bool]=None, | |
keys: list[str]=["diffusion_model"], *args, **kwargs): | |
''' | |
Perform torch.compile that will be applied at sample time for either the whole model or specific params of the BaseModel instance. | |
When keys is None, it will default to using ["diffusion_model"], compiling the whole diffusion_model. | |
When a list of keys is provided, it will perform torch.compile on only the selected modules. | |
''' | |
# clear out any other torch.compile wrappers | |
model.remove_wrappers_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY) | |
# if no keys, default to 'diffusion_model' | |
if not keys: | |
keys = ["diffusion_model"] | |
# create kwargs dict that can be referenced later | |
compile_kwargs = { | |
"backend": backend, | |
"options": options, | |
"mode": mode, | |
"fullgraph": fullgraph, | |
"dynamic": dynamic, | |
} | |
# get a dict of compiled keys | |
compiled_modules = {} | |
for key in keys: | |
compiled_modules[key] = torch.compile( | |
model=model.get_model_object(key), | |
**compile_kwargs, | |
) | |
# add torch.compile wrapper | |
wrapper_func = apply_torch_compile_factory( | |
compiled_module_dict=compiled_modules, | |
) | |
# store wrapper to run on BaseModel's apply_model function | |
model.add_wrapper_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY, wrapper_func) | |
# keep compile kwargs for reference | |
model.model_options[TORCH_COMPILE_KWARGS] = compile_kwargs | |