import logging
import math

import torch

from modules.Utilities import Latent
from modules.Device import Device
from modules.NeuralNetwork import unet
from modules.cond import cast, cond
from modules.sample import sampling

class BaseModel(torch.nn.Module):
    """#### Base class for models."""

    def __init__(
        self,
        model_config: object,
        model_type: sampling.ModelType = sampling.ModelType.EPS,
        device: torch.device = None,
        unet_model: object = unet.UNetModel1,
        flux: bool = False,
    ):
        """#### Initialize the BaseModel class.
        #### Args:
            - `model_config` (object): The model configuration.
            - `model_type` (sampling.ModelType, optional): The model type. Defaults to sampling.ModelType.EPS.
            - `device` (torch.device, optional): The device to use. Defaults to None.
            - `unet_model` (object, optional): The UNet model class. Defaults to unet.UNetModel1.
            - `flux` (bool, optional): Whether to build the model as a Flux model. Defaults to False.
        """
        super().__init__()

        unet_config = model_config.unet_config
        self.latent_format = model_config.latent_format
        self.model_config = model_config
        self.manual_cast_dtype = model_config.manual_cast_dtype
        self.device = device

        if flux:
            # Flux models take their layer operations from the config object
            if not unet_config.get("disable_unet_model_creation", False):
                operations = model_config.custom_operations
                self.diffusion_model = unet_model(
                    **unet_config, device=device, operations=operations
                )
                logging.info(
                    "model weight dtype {}, manual cast: {}".format(
                        self.get_dtype(), self.manual_cast_dtype
                    )
                )
        else:
            if not unet_config.get("disable_unet_model_creation", False):
                # Pick weight-handling ops: manual casting when a cast dtype is
                # forced, otherwise skip default weight initialization (the
                # checkpoint weights overwrite it anyway)
                if self.manual_cast_dtype is not None:
                    operations = cast.manual_cast
                else:
                    operations = cast.disable_weight_init
                self.diffusion_model = unet_model(
                    **unet_config, device=device, operations=operations
                )

        self.model_type = model_type
        self.model_sampling = sampling.model_sampling(
            model_config, model_type, flux=flux
        )

        self.adm_channels = unet_config.get("adm_in_channels", None)
        if self.adm_channels is None:
            self.adm_channels = 0

        self.concat_keys = ()
        logging.info("model_type {}".format(model_type.name))
        logging.debug("adm {}".format(self.adm_channels))
        self.memory_usage_factor = model_config.memory_usage_factor if flux else 2.0

    def apply_model(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        c_concat: torch.Tensor = None,
        c_crossattn: torch.Tensor = None,
        control: torch.Tensor = None,
        transformer_options: dict = {},
        **kwargs,
    ) -> torch.Tensor:
        """#### Apply the model to the input tensor.
        #### Args:
            - `x` (torch.Tensor): The input tensor.
            - `t` (torch.Tensor): The timestep (sigma) tensor.
            - `c_concat` (torch.Tensor, optional): The concatenated conditioning tensor. Defaults to None.
            - `c_crossattn` (torch.Tensor, optional): The cross-attention conditioning tensor. Defaults to None.
            - `control` (torch.Tensor, optional): The control tensor. Defaults to None.
            - `transformer_options` (dict, optional): The transformer options. Defaults to {}.
            - `**kwargs`: Additional keyword arguments.
        #### Returns:
            - `torch.Tensor`: The denoised output tensor.
        """
        sigma = t
        xc = self.model_sampling.calculate_input(sigma, x)

        # Concatenate channel-wise conditioning (e.g. inpainting latents) if present
        if c_concat is not None:
            xc = torch.cat((xc, c_concat), dim=1)

        # Determine the dtype once to avoid repeated get_dtype() calls
        dtype = (
            self.manual_cast_dtype
            if self.manual_cast_dtype is not None
            else self.get_dtype()
        )

        # Cast inputs to the model dtype in one pass
        xc = xc.to(dtype)
        t = self.model_sampling.timestep(t).float()
        context = c_crossattn.to(dtype) if c_crossattn is not None else None

        # Cast floating-point extra conditions; leave integer tensors untouched
        extra_conds = {}
        for name, value in kwargs.items():
            if hasattr(value, "dtype") and value.dtype not in (torch.int, torch.long):
                extra_conds[name] = value.to(dtype)
            else:
                extra_conds[name] = value

        # Run the diffusion model, then convert its raw prediction back into a
        # denoised sample according to the sampling parameterization
        model_output = self.diffusion_model(
            xc,
            t,
            context=context,
            control=control,
            transformer_options=transformer_options,
            **extra_conds,
        ).float()
        return self.model_sampling.calculate_denoised(sigma, model_output, x)
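
    # Usage sketch (hypothetical shapes and names; assumes a constructed
    # `model` and a `text_embeds` tensor from a separate text encoder):
    #
    #   x = torch.randn(1, 4, 64, 64)    # noisy latent batch
    #   sigma = torch.tensor([14.6])     # one sigma per batch element
    #   denoised = model.apply_model(x, sigma, c_crossattn=text_embeds)
    #
    # A sampler calls this once per step: calculate_input() scales the latent
    # for the current sigma, the UNet predicts, and calculate_denoised() maps
    # that prediction back to a clean-sample estimate.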

    def get_dtype(self) -> torch.dtype:
        """#### Get the data type of the model.
        #### Returns:
            - `torch.dtype`: The data type.
        """
        return self.diffusion_model.dtype

    def encode_adm(self, **kwargs) -> None:
        """#### Encode the ADM conditioning.
        #### Args:
            - `**kwargs`: Additional keyword arguments.
        #### Returns:
            - `None`: The base class has no ADM conditioning; subclasses override this to return an encoded tensor.
        """
        return None

    def extra_conds(self, **kwargs) -> dict:
        """#### Get the extra conditions.
        #### Args:
            - `**kwargs`: Additional keyword arguments.
        #### Returns:
            - `dict`: The extra conditions.
        """
        out = {}
        adm = self.encode_adm(**kwargs)
        if adm is not None:
            out["y"] = cond.CONDRegular(adm)

        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out["c_crossattn"] = cond.CONDCrossAttn(cross_attn)

        cross_attn_cnet = kwargs.get("cross_attn_controlnet", None)
        if cross_attn_cnet is not None:
            out["crossattn_controlnet"] = cond.CONDCrossAttn(cross_attn_cnet)

        return out
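
    # Sketch of the resulting mapping (illustrative; `text_embeds` is a
    # hypothetical cross-attention tensor):
    #
    #   out = model.extra_conds(cross_attn=text_embeds)
    #   # -> {"c_crossattn": cond.CONDCrossAttn(text_embeds)}
    #
    # Subclasses whose encode_adm() returns a tensor also contribute "y",
    # wrapped in cond.CONDRegular.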

    def load_model_weights(self, sd: dict, unet_prefix: str = "") -> "BaseModel":
        """#### Load the model weights.
        #### Args:
            - `sd` (dict): The state dictionary.
            - `unet_prefix` (str, optional): The UNet prefix. Defaults to "".
        #### Returns:
            - `BaseModel`: The model with loaded weights.
        """
        to_load = {}
        keys = list(sd.keys())
        for k in keys:
            if k.startswith(unet_prefix):
                # Strip the prefix and move the entry out of the source dict
                to_load[k[len(unet_prefix) :]] = sd.pop(k)

        to_load = self.model_config.process_unet_state_dict(to_load)
        m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
        if len(m) > 0:
            logging.warning("unet missing: {}".format(m))
        if len(u) > 0:
            logging.warning("unet unexpected: {}".format(u))
        del to_load
        return self
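
    # Prefix sketch: with unet_prefix="model.diffusion_model." (an illustrative
    # value; callers pass whatever prefix their checkpoint layout uses), a key
    # like "model.diffusion_model.input_blocks.0.0.weight" is loaded into the
    # UNet as "input_blocks.0.0.weight". Matching entries are pop()ed from
    # `sd`, so the caller's dict retains only the non-UNet weights afterwards.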

    def process_latent_in(self, latent: torch.Tensor) -> torch.Tensor:
        """#### Process the latent input.
        #### Args:
            - `latent` (torch.Tensor): The latent tensor.
        #### Returns:
            - `torch.Tensor`: The processed latent tensor.
        """
        return self.latent_format.process_in(latent)

    def process_latent_out(self, latent: torch.Tensor) -> torch.Tensor:
        """#### Process the latent output.
        #### Args:
            - `latent` (torch.Tensor): The latent tensor.
        #### Returns:
            - `torch.Tensor`: The processed latent tensor.
        """
        return self.latent_format.process_out(latent)

    def memory_required(self, input_shape: tuple) -> float:
        """#### Estimate the memory required for the model, in bytes.
        #### Args:
            - `input_shape` (tuple): The input shape.
        #### Returns:
            - `float`: The estimated memory required.
        """
        dtype = self.get_dtype()
        if self.manual_cast_dtype is not None:
            dtype = self.manual_cast_dtype
        # TODO: this needs to be tweaked
        area = input_shape[0] * math.prod(input_shape[2:])
        return (area * Device.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (
            1024 * 1024
        )
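
    # Worked example (illustrative numbers): a batch of shape (2, 4, 64, 64)
    # in float16 (2 bytes/element) with memory_usage_factor=2.0 gives
    #   area = 2 * 64 * 64 = 8192
    #   8192 * 2 * 0.01 * 2.0 * (1024 * 1024) ≈ 3.4e8 bytes ≈ 328 MiB
    # Note that the channel dimension (input_shape[1]) is excluded from `area`.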


class BASE:
    """#### Base class for model configurations."""

    unet_config = {}
    unet_extra_config = {
        "num_heads": -1,
        "num_head_channels": 64,
    }

    required_keys = {}

    clip_prefix = []
    clip_vision_prefix = None
    noise_aug_config = None
    sampling_settings = {}
    latent_format = Latent.LatentFormat
    vae_key_prefix = ["first_stage_model."]
    text_encoder_key_prefix = ["cond_stage_model."]
    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]

    memory_usage_factor = 2.0

    manual_cast_dtype = None
    custom_operations = None

    @classmethod
    def matches(cls, unet_config: dict, state_dict: dict = None) -> bool:
        """#### Check if the UNet configuration matches.
        #### Args:
            - `unet_config` (dict): The UNet configuration.
            - `state_dict` (dict, optional): The state dictionary. Defaults to None.
        #### Returns:
            - `bool`: Whether the configuration matches.
        """
        for k in cls.unet_config:
            if k not in unet_config or cls.unet_config[k] != unet_config[k]:
                return False
        if state_dict is not None:
            for k in cls.required_keys:
                if k not in state_dict:
                    return False
        return True
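
    # Matching sketch (hypothetical subclass and values): a config class
    # declares the fingerprint it expects,
    #
    #   class SD15(BASE):
    #       unet_config = {"context_dim": 768, "model_channels": 320}
    #
    # SD15.matches(detected_config) is True only if every declared key appears
    # in detected_config with an equal value, and, when a state dict is given,
    # every entry of required_keys is present in it.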

    def model_type(self, state_dict: dict, prefix: str = "") -> sampling.ModelType:
        """#### Get the model type.
        #### Args:
            - `state_dict` (dict): The state dictionary.
            - `prefix` (str, optional): The prefix. Defaults to "".
        #### Returns:
            - `sampling.ModelType`: The model type.
        """
        return sampling.ModelType.EPS

    def inpaint_model(self) -> bool:
        """#### Check if the model is an inpaint model.
        #### Returns:
            - `bool`: Whether the model is an inpaint model.
        """
        return self.unet_config["in_channels"] > 4

    def __init__(self, unet_config: dict):
        """#### Initialize the BASE class.
        #### Args:
            - `unet_config` (dict): The UNet configuration.
        """
        self.unet_config = unet_config.copy()
        self.sampling_settings = self.sampling_settings.copy()
        self.latent_format = self.latent_format()
        for x in self.unet_extra_config:
            self.unet_config[x] = self.unet_extra_config[x]
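
    # Note: unet_extra_config entries overwrite same-named keys in the incoming
    # unet_config, so per-class defaults such as num_head_channels take
    # precedence over detected values.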

    def get_model(
        self, state_dict: dict, prefix: str = "", device: torch.device = None
    ) -> BaseModel:
        """#### Get the model.
        #### Args:
            - `state_dict` (dict): The state dictionary.
            - `prefix` (str, optional): The prefix. Defaults to "".
            - `device` (torch.device, optional): The device to use. Defaults to None.
        #### Returns:
            - `BaseModel`: The model.
        """
        out = BaseModel(
            self, model_type=self.model_type(state_dict, prefix), device=device
        )
        return out

    def process_unet_state_dict(self, state_dict: dict) -> dict:
        """#### Process the UNet state dictionary.
        #### Args:
            - `state_dict` (dict): The state dictionary.
        #### Returns:
            - `dict`: The processed state dictionary.
        """
        return state_dict

    def process_vae_state_dict(self, state_dict: dict) -> dict:
        """#### Process the VAE state dictionary.
        #### Args:
            - `state_dict` (dict): The state dictionary.
        #### Returns:
            - `dict`: The processed state dictionary.
        """
        return state_dict

    def set_inference_dtype(
        self, dtype: torch.dtype, manual_cast_dtype: torch.dtype
    ) -> None:
        """#### Set the inference data type.
        #### Args:
            - `dtype` (torch.dtype): The data type.
            - `manual_cast_dtype` (torch.dtype): The manual cast data type.
        """
        self.unet_config["dtype"] = dtype
        self.manual_cast_dtype = manual_cast_dtype
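

# End-to-end sketch (hypothetical; `SomeConfig` stands in for a concrete BASE
# subclass and `sd` for a loaded checkpoint state dict, neither defined here):
#
#   config = SomeConfig(detected_unet_config)
#   config.set_inference_dtype(torch.float16, manual_cast_dtype=None)
#   model = config.get_model(sd, prefix="model.diffusion_model.", device=device)
#   model.load_model_weights(sd, unet_prefix="model.diffusion_model.")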