import logging
import math

import torch

from modules.Utilities import Latent
from modules.Device import Device
from modules.NeuralNetwork import unet
from modules.cond import cast, cond
from modules.sample import sampling


class BaseModel(torch.nn.Module):
    """#### Base class for models."""

    def __init__(
        self,
        model_config: object,
        model_type: sampling.ModelType = sampling.ModelType.EPS,
        device: torch.device = None,
        unet_model: object = unet.UNetModel1,
        flux: bool = False,
    ):
"""#### Initialize the BaseModel class. | |
#### Args: | |
- `model_config` (object): The model configuration. | |
- `model_type` (sampling.ModelType, optional): The model type. Defaults to sampling.ModelType.EPS. | |
- `device` (torch.device, optional): The device to use. Defaults to None. | |
- `unet_model` (object, optional): The UNet model. Defaults to unet.UNetModel1. | |
""" | |
        super().__init__()
        unet_config = model_config.unet_config
        self.latent_format = model_config.latent_format
        self.model_config = model_config
        self.manual_cast_dtype = model_config.manual_cast_dtype
        self.device = device

        if flux:
            if not unet_config.get("disable_unet_model_creation", False):
                operations = model_config.custom_operations
                self.diffusion_model = unet_model(
                    **unet_config, device=device, operations=operations
                )
                logging.info(
                    "model weight dtype {}, manual cast: {}".format(
                        self.get_dtype(), self.manual_cast_dtype
                    )
                )
        else:
            if not unet_config.get("disable_unet_model_creation", False):
                if self.manual_cast_dtype is not None:
                    operations = cast.manual_cast
                else:
                    operations = cast.disable_weight_init
                self.diffusion_model = unet_model(
                    **unet_config, device=device, operations=operations
                )

        self.model_type = model_type
        self.model_sampling = sampling.model_sampling(
            model_config, model_type, flux=flux
        )

        self.adm_channels = unet_config.get("adm_in_channels", None)
        if self.adm_channels is None:
            self.adm_channels = 0
        self.concat_keys = ()
        logging.info("model_type {}".format(model_type.name))
        logging.debug("adm {}".format(self.adm_channels))
        self.memory_usage_factor = model_config.memory_usage_factor if flux else 2.0

    def apply_model(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        c_concat: torch.Tensor = None,
        c_crossattn: torch.Tensor = None,
        control: torch.Tensor = None,
        transformer_options: dict = {},
        **kwargs,
    ) -> torch.Tensor:
"""#### Apply the model to the input tensor. | |
#### Args: | |
- `x` (torch.Tensor): The input tensor. | |
- `t` (torch.Tensor): The timestep tensor. | |
- `c_concat` (torch.Tensor, optional): The concatenated condition tensor. Defaults to None. | |
- `c_crossattn` (torch.Tensor, optional): The cross-attention condition tensor. Defaults to None. | |
- `control` (torch.Tensor, optional): The control tensor. Defaults to None. | |
- `transformer_options` (dict, optional): The transformer options. Defaults to {}. | |
- `**kwargs`: Additional keyword arguments. | |
#### Returns: | |
- `torch.Tensor`: The output tensor. | |
""" | |
        sigma = t
        xc = self.model_sampling.calculate_input(sigma, x)

        # Concatenate extra channel conditions (e.g. inpainting masks) directly
        if c_concat is not None:
            xc = torch.cat((xc, c_concat), dim=1)

        # Determine dtype once to avoid repeated calls to get_dtype()
        dtype = (
            self.manual_cast_dtype
            if self.manual_cast_dtype is not None
            else self.get_dtype()
        )

        # Batch the casts to reduce overhead
        xc = xc.to(dtype)
        t = self.model_sampling.timestep(t).float()
        context = c_crossattn.to(dtype) if c_crossattn is not None else None

        # Cast extra conditions, leaving integer tensors (e.g. indices) untouched
        extra_conds = {}
        for name, value in kwargs.items():
            if hasattr(value, "dtype") and value.dtype not in (torch.int, torch.long):
                extra_conds[name] = value.to(dtype)
            else:
                extra_conds[name] = value

        # Run the diffusion model, then convert its output to a denoised prediction
        model_output = self.diffusion_model(
            xc,
            t,
            context=context,
            control=control,
            transformer_options=transformer_options,
            **extra_conds,
        ).float()
        return self.model_sampling.calculate_denoised(sigma, model_output, x)
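
    # Usage sketch (hypothetical shapes; assumes a constructed model `m` with
    # SD-style 4-channel latents). Note that `t` is the raw sigma and the
    # return value is the denoised prediction, not the raw network output:
    #
    #   x = torch.randn(1, 4, 64, 64)
    #   sigma = torch.tensor([14.6])
    #   context = torch.randn(1, 77, 768)
    #   denoised = m.apply_model(x, sigma, c_crossattn=context)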

    def get_dtype(self) -> torch.dtype:
        """#### Get the data type of the model.

        #### Returns:
            - `torch.dtype`: The data type.
        """
        return self.diffusion_model.dtype

    def encode_adm(self, **kwargs) -> None:
        """#### Encode the ADM (additional conditioning) inputs.

        #### Args:
            - `**kwargs`: Additional keyword arguments.

        #### Returns:
            - `None`: The base model applies no ADM conditioning; subclasses
              override this to return an encoding.
        """
        return None

    def extra_conds(self, **kwargs) -> dict:
        """#### Get the extra conditions.

        #### Args:
            - `**kwargs`: Additional keyword arguments.

        #### Returns:
            - `dict`: The extra conditions.
        """
        out = {}
        adm = self.encode_adm(**kwargs)
        if adm is not None:
            out["y"] = cond.CONDRegular(adm)
        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out["c_crossattn"] = cond.CONDCrossAttn(cross_attn)
        cross_attn_cnet = kwargs.get("cross_attn_controlnet", None)
        if cross_attn_cnet is not None:
            out["crossattn_controlnet"] = cond.CONDCrossAttn(cross_attn_cnet)
        return out
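
    # Usage sketch (hypothetical tensor): given a cross-attention conditioning
    # tensor, extra_conds wraps it so the sampler can batch and resize it later:
    #
    #   out = m.extra_conds(cross_attn=torch.randn(1, 77, 768))
    #   # -> {"c_crossattn": cond.CONDCrossAttn(<tensor>)}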

    def load_model_weights(self, sd: dict, unet_prefix: str = "") -> "BaseModel":
        """#### Load the model weights.

        #### Args:
            - `sd` (dict): The state dictionary.
            - `unet_prefix` (str, optional): The UNet key prefix to strip. Defaults to "".

        #### Returns:
            - `BaseModel`: The model with loaded weights.
        """
        to_load = {}
        keys = list(sd.keys())
        for k in keys:
            if k.startswith(unet_prefix):
                to_load[k[len(unet_prefix) :]] = sd.pop(k)
        to_load = self.model_config.process_unet_state_dict(to_load)
        m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
        if len(m) > 0:
            logging.warning("unet missing: {}".format(m))
        if len(u) > 0:
            logging.warning("unet unexpected: {}".format(u))
        del to_load
        return self
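
    # Usage sketch (hypothetical checkpoint layout): for a checkpoint that
    # stores UNet weights under "model.diffusion_model.<param>", pass that
    # prefix so it is stripped before loading:
    #
    #   m.load_model_weights(sd, unet_prefix="model.diffusion_model.")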

    def process_latent_in(self, latent: torch.Tensor) -> torch.Tensor:
        """#### Process the latent input.

        #### Args:
            - `latent` (torch.Tensor): The latent tensor.

        #### Returns:
            - `torch.Tensor`: The processed latent tensor.
        """
        return self.latent_format.process_in(latent)

    def process_latent_out(self, latent: torch.Tensor) -> torch.Tensor:
        """#### Process the latent output.

        #### Args:
            - `latent` (torch.Tensor): The latent tensor.

        #### Returns:
            - `torch.Tensor`: The processed latent tensor.
        """
        return self.latent_format.process_out(latent)
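
    # Usage sketch: these two methods are inverses, delegating to the
    # configured latent format so the diffusion model always sees latents in
    # its trained range:
    #
    #   z = m.process_latent_in(vae_latent)        # VAE space -> model space
    #   vae_latent = m.process_latent_out(z)       # model space -> VAE space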

    def memory_required(self, input_shape: tuple) -> float:
        """#### Estimate the memory required to run the model.

        #### Args:
            - `input_shape` (tuple): The input shape, (batch, channels, *spatial).

        #### Returns:
            - `float`: The estimated memory required, in bytes.
        """
        dtype = self.get_dtype()
        if self.manual_cast_dtype is not None:
            dtype = self.manual_cast_dtype
        # TODO: this needs to be tweaked
        area = input_shape[0] * math.prod(input_shape[2:])
        return (area * Device.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (
            1024 * 1024
        )
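
    # Worked example (hypothetical numbers): for one 512x512 image the latent
    # input shape is (1, 4, 64, 64), so area = 1 * 64 * 64 = 4096. With fp16
    # weights (dtype_size == 2) and the default memory_usage_factor of 2.0:
    #
    #   4096 * 2 * 0.01 * 2.0 * 1024 * 1024 bytes ≈ 164 MiB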


class BASE:
    """#### Base class for model configurations."""

    unet_config = {}
    unet_extra_config = {
        "num_heads": -1,
        "num_head_channels": 64,
    }
    required_keys = {}
    clip_prefix = []
    clip_vision_prefix = None
    noise_aug_config = None
    sampling_settings = {}
    latent_format = Latent.LatentFormat
    vae_key_prefix = ["first_stage_model."]
    text_encoder_key_prefix = ["cond_stage_model."]
    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
    memory_usage_factor = 2.0
    manual_cast_dtype = None
    custom_operations = None

    @classmethod
    def matches(cls, unet_config: dict, state_dict: dict = None) -> bool:
        """#### Check if the UNet configuration matches this model class.

        #### Args:
            - `unet_config` (dict): The UNet configuration.
            - `state_dict` (dict, optional): The state dictionary. Defaults to None.

        #### Returns:
            - `bool`: Whether the configuration matches.
        """
        for k in cls.unet_config:
            if k not in unet_config or cls.unet_config[k] != unet_config[k]:
                return False
        if state_dict is not None:
            for k in cls.required_keys:
                if k not in state_dict:
                    return False
        return True
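
    # Usage sketch (hypothetical subclass): a config subclass pins down the
    # unet_config fingerprint it expects, and matches() returns True only when
    # every declared key is present with the same value:
    #
    #   class MyConfig(BASE):
    #       unet_config = {"context_dim": 768, "model_channels": 320}
    #
    #   MyConfig.matches(detected_unet_config)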

    def model_type(self, state_dict: dict, prefix: str = "") -> sampling.ModelType:
        """#### Get the model type.

        #### Args:
            - `state_dict` (dict): The state dictionary.
            - `prefix` (str, optional): The prefix. Defaults to "".

        #### Returns:
            - `sampling.ModelType`: The model type. Defaults to EPS; subclasses
              override this to detect other objectives.
        """
        return sampling.ModelType.EPS

    def inpaint_model(self) -> bool:
        """#### Check if the model is an inpaint model.

        #### Returns:
            - `bool`: Whether the model is an inpaint model (more than 4 input
              channels implies concatenated mask/masked-image conditioning).
        """
        return self.unet_config["in_channels"] > 4

    def __init__(self, unet_config: dict):
        """#### Initialize the BASE class.

        #### Args:
            - `unet_config` (dict): The UNet configuration.
        """
        self.unet_config = unet_config.copy()
        self.sampling_settings = self.sampling_settings.copy()
        self.latent_format = self.latent_format()
        # Overlay the per-class extra config on top of the detected config
        for x in self.unet_extra_config:
            self.unet_config[x] = self.unet_extra_config[x]

    def get_model(
        self, state_dict: dict, prefix: str = "", device: torch.device = None
    ) -> BaseModel:
        """#### Get the model.

        #### Args:
            - `state_dict` (dict): The state dictionary.
            - `prefix` (str, optional): The prefix. Defaults to "".
            - `device` (torch.device, optional): The device to use. Defaults to None.

        #### Returns:
            - `BaseModel`: The model.
        """
        out = BaseModel(
            self, model_type=self.model_type(state_dict, prefix), device=device
        )
        return out

    def process_unet_state_dict(self, state_dict: dict) -> dict:
        """#### Process the UNet state dictionary.

        #### Args:
            - `state_dict` (dict): The state dictionary.

        #### Returns:
            - `dict`: The processed state dictionary (unchanged by default).
        """
        return state_dict

    def process_vae_state_dict(self, state_dict: dict) -> dict:
        """#### Process the VAE state dictionary.

        #### Args:
            - `state_dict` (dict): The state dictionary.

        #### Returns:
            - `dict`: The processed state dictionary (unchanged by default).
        """
        return state_dict

    def set_inference_dtype(
        self, dtype: torch.dtype, manual_cast_dtype: torch.dtype
    ) -> None:
        """#### Set the inference data type.

        #### Args:
            - `dtype` (torch.dtype): The data type.
            - `manual_cast_dtype` (torch.dtype): The manual cast data type.
        """
        self.unet_config["dtype"] = dtype
        self.manual_cast_dtype = manual_cast_dtype
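

# End-to-end sketch (hypothetical subclass and names, shown for illustration):
# a config subclass pins down a unet_config fingerprint, get_model() builds
# the BaseModel, and load_model_weights() fills it from the checkpoint:
#
#   class MyConfig(BASE):
#       unet_config = {"context_dim": 768, "model_channels": 320}
#
#   config = MyConfig(detected_unet_config)
#   config.set_inference_dtype(torch.float16, manual_cast_dtype=None)
#   model = config.get_model(state_dict, device=torch.device("cuda"))
#   model.load_model_weights(state_dict, unet_prefix="model.diffusion_model.")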