from enum import Enum
import logging

import torch

from modules.Model import ModelPatcher
from modules.Attention import Attention
from modules.Device import Device
from modules.SD15 import SDToken
from modules.Utilities import util
from modules.clip import FluxClip
from modules.cond import cast


class CLIPAttention(torch.nn.Module):
    """#### The CLIPAttention module."""

    def __init__(
        self,
        embed_dim: int,
        heads: int,
        dtype: torch.dtype,
        device: torch.device,
        operations: object,
    ):
        """#### Initialize the CLIPAttention module.

        #### Args:
            - `embed_dim` (int): The embedding dimension.
            - `heads` (int): The number of attention heads.
            - `dtype` (torch.dtype): The data type.
            - `device` (torch.device): The device to use.
            - `operations` (object): The operations object.
        """
        super().__init__()
        self.heads = heads
        self.q_proj = operations.Linear(
            embed_dim, embed_dim, bias=True, dtype=dtype, device=device
        )
        self.k_proj = operations.Linear(
            embed_dim, embed_dim, bias=True, dtype=dtype, device=device
        )
        self.v_proj = operations.Linear(
            embed_dim, embed_dim, bias=True, dtype=dtype, device=device
        )
        self.out_proj = operations.Linear(
            embed_dim, embed_dim, bias=True, dtype=dtype, device=device
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor = None,
        optimized_attention: callable = None,
    ) -> torch.Tensor:
        """#### Forward pass for the CLIPAttention module.

        #### Args:
            - `x` (torch.Tensor): The input tensor.
            - `mask` (torch.Tensor, optional): The attention mask. Defaults to None.
            - `optimized_attention` (callable, optional): The optimized attention function. Defaults to None.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        out = optimized_attention(q, k, v, self.heads, mask)
        return self.out_proj(out)


ACTIVATIONS = {
    "quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
    "gelu": torch.nn.functional.gelu,
}
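

# Note on ACTIVATIONS (illustrative sketch, not part of the module's API):
# "quick_gelu" is the sigmoid approximation x * sigmoid(1.702 * x) used by the
# original CLIP weights; it stays within roughly 2e-2 of the exact GELU, e.g.:
#
#     x = torch.linspace(-4.0, 4.0, steps=9)
#     assert torch.allclose(
#         ACTIVATIONS["quick_gelu"](x), ACTIVATIONS["gelu"](x), atol=5e-2
#     )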
class CLIPMLP(torch.nn.Module):
    """#### The CLIPMLP module.

    (MLP stands for Multi-Layer Perceptron.)
    """

    def __init__(
        self,
        embed_dim: int,
        intermediate_size: int,
        activation: str,
        dtype: torch.dtype,
        device: torch.device,
        operations: object,
    ):
        """#### Initialize the CLIPMLP module.

        #### Args:
            - `embed_dim` (int): The embedding dimension.
            - `intermediate_size` (int): The intermediate size.
            - `activation` (str): The activation function.
            - `dtype` (torch.dtype): The data type.
            - `device` (torch.device): The device to use.
            - `operations` (object): The operations object.
        """
        super().__init__()
        self.fc1 = operations.Linear(
            embed_dim, intermediate_size, bias=True, dtype=dtype, device=device
        )
        self.activation = ACTIVATIONS[activation]
        self.fc2 = operations.Linear(
            intermediate_size, embed_dim, bias=True, dtype=dtype, device=device
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """#### Forward pass for the CLIPMLP module.

        #### Args:
            - `x` (torch.Tensor): The input tensor.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        x = self.fc1(x)
        x = self.activation(x)
        x = self.fc2(x)
        return x


class CLIPLayer(torch.nn.Module):
    """#### The CLIPLayer module."""

    def __init__(
        self,
        embed_dim: int,
        heads: int,
        intermediate_size: int,
        intermediate_activation: str,
        dtype: torch.dtype,
        device: torch.device,
        operations: object,
    ):
        """#### Initialize the CLIPLayer module.

        #### Args:
            - `embed_dim` (int): The embedding dimension.
            - `heads` (int): The number of attention heads.
            - `intermediate_size` (int): The intermediate size.
            - `intermediate_activation` (str): The intermediate activation function.
            - `dtype` (torch.dtype): The data type.
            - `device` (torch.device): The device to use.
            - `operations` (object): The operations object.
        """
        super().__init__()
        self.layer_norm1 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
        self.self_attn = CLIPAttention(embed_dim, heads, dtype, device, operations)
        self.layer_norm2 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
        self.mlp = CLIPMLP(
            embed_dim,
            intermediate_size,
            intermediate_activation,
            dtype,
            device,
            operations,
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor = None,
        optimized_attention: callable = None,
    ) -> torch.Tensor:
        """#### Forward pass for the CLIPLayer module.

        #### Args:
            - `x` (torch.Tensor): The input tensor.
            - `mask` (torch.Tensor, optional): The attention mask. Defaults to None.
            - `optimized_attention` (callable, optional): The optimized attention function. Defaults to None.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        x += self.self_attn(self.layer_norm1(x), mask, optimized_attention)
        x += self.mlp(self.layer_norm2(x))
        return x


class CLIPEncoder(torch.nn.Module):
    """#### The CLIPEncoder module."""

    def __init__(
        self,
        num_layers: int,
        embed_dim: int,
        heads: int,
        intermediate_size: int,
        intermediate_activation: str,
        dtype: torch.dtype,
        device: torch.device,
        operations: object,
    ):
        """#### Initialize the CLIPEncoder module.

        #### Args:
            - `num_layers` (int): The number of layers.
            - `embed_dim` (int): The embedding dimension.
            - `heads` (int): The number of attention heads.
            - `intermediate_size` (int): The intermediate size.
            - `intermediate_activation` (str): The intermediate activation function.
            - `dtype` (torch.dtype): The data type.
            - `device` (torch.device): The device to use.
            - `operations` (object): The operations object.
        """
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [
                CLIPLayer(
                    embed_dim,
                    heads,
                    intermediate_size,
                    intermediate_activation,
                    dtype,
                    device,
                    operations,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor = None,
        intermediate_output: int = None,
    ) -> tuple:
        """#### Forward pass for the CLIPEncoder module.

        #### Args:
            - `x` (torch.Tensor): The input tensor.
            - `mask` (torch.Tensor, optional): The attention mask. Defaults to None.
            - `intermediate_output` (int, optional): The intermediate output layer. Defaults to None.

        #### Returns:
            - `tuple`: The output tensor and the intermediate output tensor.
        """
        optimized_attention = Attention.optimized_attention_for_device()
        if intermediate_output is not None:
            if intermediate_output < 0:
                intermediate_output = len(self.layers) + intermediate_output
        intermediate = None
        for i, layer in enumerate(self.layers):
            x = layer(x, mask, optimized_attention)
            if i == intermediate_output:
                intermediate = x.clone()
        return x, intermediate
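
    # Note on CLIPEncoder.forward (illustrative): `intermediate_output` taps a
    # hidden state in addition to the final one, with negative indices counting
    # from the end. For a 12-layer SD1.5 text encoder, `intermediate_output=-2`
    # captures the penultimate hidden state (the one "clip skip 2" reads), e.g.:
    #
    #     final, penultimate = encoder(embeds, mask, intermediate_output=-2)
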
class CLIPEmbeddings(torch.nn.Module):
    """#### The CLIPEmbeddings module."""

    def __init__(
        self,
        embed_dim: int,
        vocab_size: int = 49408,
        num_positions: int = 77,
        dtype: torch.dtype = None,
        device: torch.device = None,
        operations: object = torch.nn,
    ):
        """#### Initialize the CLIPEmbeddings module.

        #### Args:
            - `embed_dim` (int): The embedding dimension.
            - `vocab_size` (int, optional): The vocabulary size. Defaults to 49408.
            - `num_positions` (int, optional): The number of positions. Defaults to 77.
            - `dtype` (torch.dtype, optional): The data type. Defaults to None.
            - `device` (torch.device, optional): The device to use. Defaults to None.
            - `operations` (object, optional): The operations object. Defaults to torch.nn.
        """
        super().__init__()
        self.token_embedding = operations.Embedding(
            vocab_size, embed_dim, dtype=dtype, device=device
        )
        self.position_embedding = operations.Embedding(
            num_positions, embed_dim, dtype=dtype, device=device
        )

    def forward(self, input_tokens: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
        """#### Forward pass for the CLIPEmbeddings module.

        #### Args:
            - `input_tokens` (torch.Tensor): The input tokens.
            - `dtype` (torch.dtype, optional): The data type. Defaults to torch.float32.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        return self.token_embedding(input_tokens, out_dtype=dtype) + cast.cast_to(
            self.position_embedding.weight, dtype=dtype, device=input_tokens.device
        )
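
# Shape note for CLIPEmbeddings.forward (illustrative): for `input_tokens` of
# shape (batch, 77), the token embedding yields (batch, 77, embed_dim) and the
# (77, embed_dim) position table broadcasts over the batch, so the returned
# sum is (batch, 77, embed_dim) in the requested `dtype`.
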
class CLIP:
    """#### The CLIP class."""

    def __init__(
        self,
        target: object = None,
        embedding_directory: str = None,
        no_init: bool = False,
        tokenizer_data={},
        parameters=0,
        model_options={},
    ):
        """#### Initialize the CLIP class.

        #### Args:
            - `target` (object, optional): The target object. Defaults to None.
            - `embedding_directory` (str, optional): The embedding directory. Defaults to None.
            - `no_init` (bool, optional): Whether to skip initialization. Defaults to False.
            - `tokenizer_data` (dict, optional): Extra data passed to the tokenizer. Defaults to {}.
            - `parameters` (int, optional): The parameter count, used to size the initial device. Defaults to 0.
            - `model_options` (dict, optional): The model options (devices, dtype). Defaults to {}.
        """
        if no_init:
            return
        params = target.params.copy()
        clip = target.clip
        tokenizer = target.tokenizer

        load_device = model_options.get("load_device", Device.text_encoder_device())
        offload_device = model_options.get(
            "offload_device", Device.text_encoder_offload_device()
        )
        dtype = model_options.get("dtype", None)
        if dtype is None:
            dtype = Device.text_encoder_dtype(load_device)

        params["dtype"] = dtype
        params["device"] = model_options.get(
            "initial_device",
            Device.text_encoder_initial_device(
                load_device, offload_device, parameters * Device.dtype_size(dtype)
            ),
        )
        params["model_options"] = model_options

        self.cond_stage_model = clip(**(params))
        # for dt in self.cond_stage_model.dtypes:
        #     if not Device.supports_cast(load_device, dt):
        #         load_device = offload_device
        #         if params["device"] != offload_device:
        #             self.cond_stage_model.to(offload_device)
        #             logging.warning("Had to shift TE back.")

        try:
            self.tokenizer = tokenizer(
                embedding_directory=embedding_directory, tokenizer_data=tokenizer_data
            )
        except TypeError:
            self.tokenizer = tokenizer(embedding_directory=embedding_directory)
        self.patcher = ModelPatcher.ModelPatcher(
            self.cond_stage_model,
            load_device=load_device,
            offload_device=offload_device,
        )
        if params["device"] == load_device:
            Device.load_models_gpu(
                [self.patcher], force_full_load=True, flux_enabled=True
            )
        self.layer_idx = None
        logging.debug(
            "CLIP model load device: {}, offload device: {}, current: {}".format(
                load_device, offload_device, params["device"]
            )
        )

    def clone(self) -> "CLIP":
        """#### Clone the CLIP object.

        #### Returns:
            - `CLIP`: The cloned CLIP object.
        """
        n = CLIP(no_init=True)
        n.patcher = self.patcher.clone()
        n.cond_stage_model = self.cond_stage_model
        n.tokenizer = self.tokenizer
        n.layer_idx = self.layer_idx
        return n

    def add_patches(
        self, patches: list, strength_patch: float = 1.0, strength_model: float = 1.0
    ) -> list:
        """#### Add patches to the model.

        #### Args:
            - `patches` (list): The patches to add.
            - `strength_patch` (float, optional): The strength of the patches. Defaults to 1.0.
            - `strength_model` (float, optional): The strength of the model. Defaults to 1.0.

        #### Returns:
            - `list`: The keys of the patches that were applied.
        """
        return self.patcher.add_patches(patches, strength_patch, strength_model)

    def clip_layer(self, layer_idx: int) -> None:
        """#### Set the clip layer.

        #### Args:
            - `layer_idx` (int): The layer index.
        """
        self.layer_idx = layer_idx

    def tokenize(self, text: str, return_word_ids: bool = False) -> list:
        """#### Tokenize the input text.

        #### Args:
            - `text` (str): The input text.
            - `return_word_ids` (bool, optional): Whether to return word IDs. Defaults to False.

        #### Returns:
            - `list`: The tokenized text.
        """
        return self.tokenizer.tokenize_with_weights(text, return_word_ids)

    def encode_from_tokens(
        self,
        tokens: list,
        return_pooled: bool = False,
        return_dict: bool = False,
        flux_enabled: bool = False,
    ) -> tuple:
        """#### Encode the input tokens.

        #### Args:
            - `tokens` (list): The input tokens.
            - `return_pooled` (bool, optional): Whether to return the pooled output; the string "unprojected" returns the unprojected pooled output. Defaults to False.
            - `return_dict` (bool, optional): Whether to return a dict with `cond` and `pooled_output` keys. Defaults to False.
            - `flux_enabled` (bool, optional): Whether to enable flux. Defaults to False.

        #### Returns:
            - `tuple`: The encoded tokens and the pooled output.
        """
        self.cond_stage_model.reset_clip_options()

        if self.layer_idx is not None:
            self.cond_stage_model.set_clip_options({"layer": self.layer_idx})

        if return_pooled == "unprojected":
            self.cond_stage_model.set_clip_options({"projected_pooled": False})

        self.load_model(flux_enabled=flux_enabled)
        o = self.cond_stage_model.encode_token_weights(tokens)
        cond, pooled = o[:2]
        if return_dict:
            out = {"cond": cond, "pooled_output": pooled}
            if len(o) > 2:
                for k in o[2]:
                    out[k] = o[2][k]
            return out

        if return_pooled:
            return cond, pooled
        return cond
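
    # Example for encode_from_tokens (illustrative sketch; `clip` is assumed to
    # be an already constructed CLIP instance, not built here):
    #
    #     tokens = clip.tokenize("a photo of a cat")
    #     cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
    #     # `cond` holds the per-token hidden states; `pooled` is the pooled
    #     # text embedding used by models that condition on it.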

    def load_sd(self, sd: dict, full_model: bool = False) -> tuple:
        """#### Load the state dictionary.

        #### Args:
            - `sd` (dict): The state dictionary.
            - `full_model` (bool, optional): Whether to load the full model. Defaults to False.

        #### Returns:
            - `tuple`: The missing and unexpected keys reported by the underlying load.
        """
        if full_model:
            return self.cond_stage_model.load_state_dict(sd, strict=False)
        else:
            return self.cond_stage_model.load_sd(sd)

    def load_model(self, flux_enabled: bool = False) -> ModelPatcher:
        """#### Load the model.

        #### Args:
            - `flux_enabled` (bool, optional): Whether to enable flux. Defaults to False.

        #### Returns:
            - `ModelPatcher`: The model patcher.
        """
        Device.load_model_gpu(self.patcher, flux_enabled=flux_enabled)
        return self.patcher

    def encode(self, text: str) -> torch.Tensor:
        """#### Encode the input text.

        #### Args:
            - `text` (str): The input text.

        #### Returns:
            - `torch.Tensor`: The encoded text.
        """
        tokens = self.tokenize(text)
        return self.encode_from_tokens(tokens)

    def get_sd(self) -> dict:
        """#### Get the state dictionary.

        #### Returns:
            - `dict`: The state dictionary.
        """
        sd_clip = self.cond_stage_model.state_dict()
        sd_tokenizer = self.tokenizer.state_dict()
        for k in sd_tokenizer:
            sd_clip[k] = sd_tokenizer[k]
        return sd_clip

    def get_key_patches(self):
        """#### Get the key patches.

        #### Returns:
            - `list`: The key patches.
        """
        return self.patcher.get_key_patches()


class CLIPType(Enum):
    STABLE_DIFFUSION = 1
    SD3 = 3
    FLUX = 6


def load_text_encoder_state_dicts(
    state_dicts=[],
    embedding_directory=None,
    clip_type=CLIPType.STABLE_DIFFUSION,
    model_options={},
):
    """#### Load the text encoder state dictionaries.

    #### Args:
        - `state_dicts` (list, optional): The state dictionaries. Defaults to [].
        - `embedding_directory` (str, optional): The embedding directory. Defaults to None.
        - `clip_type` (CLIPType, optional): The CLIP type. Defaults to CLIPType.STABLE_DIFFUSION.
        - `model_options` (dict, optional): The model options. Defaults to {}.

    #### Returns:
        - `CLIP`: The CLIP object.
    """
    clip_data = state_dicts

    class EmptyClass:
        pass

    for i in range(len(clip_data)):
        if "text_projection" in clip_data[i]:
            # old models saved with the CLIPSave node
            clip_data[i]["text_projection.weight"] = clip_data[i][
                "text_projection"
            ].transpose(0, 1)

    clip_target = EmptyClass()
    clip_target.params = {}
    if len(clip_data) == 2:
        if clip_type == CLIPType.FLUX:
            weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
            weight = clip_data[0].get(weight_name, clip_data[1].get(weight_name, None))
            dtype_t5 = None
            if weight is not None:
                dtype_t5 = weight.dtype
            clip_target.clip = FluxClip.flux_clip(dtype_t5=dtype_t5)
            clip_target.tokenizer = FluxClip.FluxTokenizer

    parameters = 0
    tokenizer_data = {}
    for c in clip_data:
        parameters += util.calculate_parameters(c)
        tokenizer_data, model_options = SDToken.model_options_long_clip(
            c, tokenizer_data, model_options
        )

    clip = CLIP(
        clip_target,
        embedding_directory=embedding_directory,
        parameters=parameters,
        tokenizer_data=tokenizer_data,
        model_options=model_options,
    )
    for c in clip_data:
        m, u = clip.load_sd(c)
        if len(m) > 0:
            logging.warning("clip missing: {}".format(m))

        if len(u) > 0:
            logging.debug("clip unexpected: {}".format(u))
    return clip


class CLIPTextEncode:
    """#### Text encoding class for the CLIP model."""

    def encode(self, clip: CLIP, text: str, flux_enabled: bool = False) -> tuple:
        """#### Encode the input text.

        #### Args:
            - `clip` (CLIP): The CLIP object.
            - `text` (str): The input text.
            - `flux_enabled` (bool, optional): Whether to enable flux. Defaults to False.

        #### Returns:
            - `tuple`: The encoded text and the pooled output.
        """
        tokens = clip.tokenize(text)
        cond, pooled = clip.encode_from_tokens(
            tokens, return_pooled=True, flux_enabled=flux_enabled
        )
        return ([[cond, {"pooled_output": pooled}]],)


class CLIPSetLastLayer:
    """#### Set the last layer class for the CLIP model."""

    def set_last_layer(self, clip: CLIP, stop_at_clip_layer: int) -> tuple:
        """#### Set the last layer of the CLIP model.

        Works the same as the Automatic1111 clip skip option.

        #### Args:
            - `clip` (CLIP): The CLIP object.
            - `stop_at_clip_layer` (int): The layer to stop at.

        #### Returns:
            - `tuple`: A one-element tuple containing the cloned CLIP object with the layer set.
        """
        clip = clip.clone()
        clip.clip_layer(stop_at_clip_layer)
        return (clip,)
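

# Example for CLIPSetLastLayer (illustrative sketch; `clip` is assumed to be an
# already loaded CLIP instance):
#
#     (skipped_clip,) = CLIPSetLastLayer().set_last_layer(clip, -2)
#     # -2 stops at the penultimate layer, matching "clip skip: 2" in
#     # Automatic1111-style UIs; the original `clip` is left untouched.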


class ClipTarget:
    """#### Target class for the CLIP model."""

    def __init__(self, tokenizer: object, clip: object):
        """#### Initialize the ClipTarget class.

        #### Args:
            - `tokenizer` (object): The tokenizer.
            - `clip` (object): The CLIP model.
        """
        self.clip = clip
        self.tokenizer = tokenizer
        self.params = {}
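

# Minimal end-to-end sketch (editorial addition, guarded so it never runs on
# import). Only the dual-encoder FLUX path is wired up in
# load_text_encoder_state_dicts above, so the sketch follows it; the checkpoint
# filenames are hypothetical and `util.load_torch_file` is assumed to exist in
# this codebase.
if __name__ == "__main__":
    clip_l_sd = util.load_torch_file("clip_l.safetensors")  # hypothetical path
    t5_sd = util.load_torch_file("t5xxl_fp16.safetensors")  # hypothetical path
    clip = load_text_encoder_state_dicts(
        [clip_l_sd, t5_sd],
        embedding_directory="./embeddings",  # hypothetical directory
        clip_type=CLIPType.FLUX,
    )

    # Encode a prompt into conditioning for the diffusion model.
    (conditioning,) = CLIPTextEncode().encode(
        clip, "a photo of a cat", flux_enabled=True
    )
    print(conditioning[0][0].shape)  # per-token conditioning tensor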