from enum import Enum
import logging

import torch

from modules.Model import ModelPatcher
from modules.Attention import Attention
from modules.Device import Device
from modules.SD15 import SDToken
from modules.Utilities import util
from modules.clip import FluxClip
from modules.cond import cast


class CLIPAttention(torch.nn.Module):
    """#### The CLIPAttention module."""

    def __init__(
        self,
        embed_dim: int,
        heads: int,
        dtype: torch.dtype,
        device: torch.device,
        operations: object,
    ):
        """#### Initialize the CLIPAttention module.

        #### Args:
            - `embed_dim` (int): The embedding dimension.
            - `heads` (int): The number of attention heads.
            - `dtype` (torch.dtype): The data type.
            - `device` (torch.device): The device to use.
            - `operations` (object): The operations object.
        """
        super().__init__()
        self.heads = heads
        self.q_proj = operations.Linear(
            embed_dim, embed_dim, bias=True, dtype=dtype, device=device
        )
        self.k_proj = operations.Linear(
            embed_dim, embed_dim, bias=True, dtype=dtype, device=device
        )
        self.v_proj = operations.Linear(
            embed_dim, embed_dim, bias=True, dtype=dtype, device=device
        )
        self.out_proj = operations.Linear(
            embed_dim, embed_dim, bias=True, dtype=dtype, device=device
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor = None,
        optimized_attention: callable = None,
    ) -> torch.Tensor:
        """#### Forward pass for the CLIPAttention module.

        #### Args:
            - `x` (torch.Tensor): The input tensor.
            - `mask` (torch.Tensor, optional): The attention mask. Defaults to None.
            - `optimized_attention` (callable, optional): The optimized attention function. Defaults to None.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        out = optimized_attention(q, k, v, self.heads, mask)
        return self.out_proj(out)
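

# Usage sketch (illustrative, not part of the original module): CLIPAttention expects
# `optimized_attention` to be a callable with the signature (q, k, v, heads, mask).
# A naive stand-in built on torch.nn.functional.scaled_dot_product_attention could look
# like the following; it is an assumption, not the project's actual Attention backend:
#
#   def naive_attention(q, k, v, heads, mask=None):
#       b, n, c = q.shape
#       q, k, v = (t.reshape(b, n, heads, c // heads).transpose(1, 2) for t in (q, k, v))
#       out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
#       return out.transpose(1, 2).reshape(b, n, c)
#
#   attn = CLIPAttention(768, 12, torch.float32, torch.device("cpu"), torch.nn)
#   y = attn(torch.randn(1, 77, 768), optimized_attention=naive_attention)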


ACTIVATIONS = {
    "quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
    "gelu": torch.nn.functional.gelu,
}
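
# Note (illustrative): "quick_gelu" is the sigmoid-based GELU approximation used by the
# original CLIP text encoder, x * sigmoid(1.702 * x); "gelu" is PyTorch's exact version.
# A quick numerical comparison, values approximate:
#
#   x = torch.linspace(-3, 3, 7)
#   gap = (ACTIVATIONS["quick_gelu"](x) - ACTIVATIONS["gelu"](x)).abs().max()
#   # gap is small but nonzero (on the order of 1e-2)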


class CLIPMLP(torch.nn.Module):
    """#### The CLIPMLP module.

    (MLP stands for Multi-Layer Perceptron.)"""

    def __init__(
        self,
        embed_dim: int,
        intermediate_size: int,
        activation: str,
        dtype: torch.dtype,
        device: torch.device,
        operations: object,
    ):
        """#### Initialize the CLIPMLP module.

        #### Args:
            - `embed_dim` (int): The embedding dimension.
            - `intermediate_size` (int): The intermediate size.
            - `activation` (str): The activation function.
            - `dtype` (torch.dtype): The data type.
            - `device` (torch.device): The device to use.
            - `operations` (object): The operations object.
        """
        super().__init__()
        self.fc1 = operations.Linear(
            embed_dim, intermediate_size, bias=True, dtype=dtype, device=device
        )
        self.activation = ACTIVATIONS[activation]
        self.fc2 = operations.Linear(
            intermediate_size, embed_dim, bias=True, dtype=dtype, device=device
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """#### Forward pass for the CLIPMLP module.

        #### Args:
            - `x` (torch.Tensor): The input tensor.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        x = self.fc1(x)
        x = self.activation(x)
        x = self.fc2(x)
        return x
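

# Usage sketch (illustrative): a standalone CLIPMLP built with plain torch.nn as the
# `operations` provider. The 768 / 3072 sizes match the standard CLIP-L text encoder
# but are assumptions here, not values taken from this file.
#
#   mlp = CLIPMLP(768, 3072, "quick_gelu", torch.float32, torch.device("cpu"), torch.nn)
#   hidden = torch.randn(2, 77, 768)
#   out = mlp(hidden)   # -> torch.Size([2, 77, 768])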


class CLIPLayer(torch.nn.Module):
    """#### The CLIPLayer module."""

    def __init__(
        self,
        embed_dim: int,
        heads: int,
        intermediate_size: int,
        intermediate_activation: str,
        dtype: torch.dtype,
        device: torch.device,
        operations: object,
    ):
        """#### Initialize the CLIPLayer module.

        #### Args:
            - `embed_dim` (int): The embedding dimension.
            - `heads` (int): The number of attention heads.
            - `intermediate_size` (int): The intermediate size.
            - `intermediate_activation` (str): The intermediate activation function.
            - `dtype` (torch.dtype): The data type.
            - `device` (torch.device): The device to use.
            - `operations` (object): The operations object.
        """
        super().__init__()
        self.layer_norm1 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
        self.self_attn = CLIPAttention(embed_dim, heads, dtype, device, operations)
        self.layer_norm2 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
        self.mlp = CLIPMLP(
            embed_dim,
            intermediate_size,
            intermediate_activation,
            dtype,
            device,
            operations,
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor = None,
        optimized_attention: callable = None,
    ) -> torch.Tensor:
        """#### Forward pass for the CLIPLayer module.

        #### Args:
            - `x` (torch.Tensor): The input tensor.
            - `mask` (torch.Tensor, optional): The attention mask. Defaults to None.
            - `optimized_attention` (callable, optional): The optimized attention function. Defaults to None.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        x += self.self_attn(self.layer_norm1(x), mask, optimized_attention)
        x += self.mlp(self.layer_norm2(x))
        return x


class CLIPEncoder(torch.nn.Module):
    """#### The CLIPEncoder module."""

    def __init__(
        self,
        num_layers: int,
        embed_dim: int,
        heads: int,
        intermediate_size: int,
        intermediate_activation: str,
        dtype: torch.dtype,
        device: torch.device,
        operations: object,
    ):
        """#### Initialize the CLIPEncoder module.

        #### Args:
            - `num_layers` (int): The number of layers.
            - `embed_dim` (int): The embedding dimension.
            - `heads` (int): The number of attention heads.
            - `intermediate_size` (int): The intermediate size.
            - `intermediate_activation` (str): The intermediate activation function.
            - `dtype` (torch.dtype): The data type.
            - `device` (torch.device): The device to use.
            - `operations` (object): The operations object.
        """
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [
                CLIPLayer(
                    embed_dim,
                    heads,
                    intermediate_size,
                    intermediate_activation,
                    dtype,
                    device,
                    operations,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor = None,
        intermediate_output: int = None,
    ) -> tuple:
        """#### Forward pass for the CLIPEncoder module.

        #### Args:
            - `x` (torch.Tensor): The input tensor.
            - `mask` (torch.Tensor, optional): The attention mask. Defaults to None.
            - `intermediate_output` (int, optional): The intermediate output layer. Defaults to None.

        #### Returns:
            - `tuple`: The output tensor and the intermediate output tensor.
        """
        optimized_attention = Attention.optimized_attention_for_device()
        if intermediate_output is not None:
            if intermediate_output < 0:
                intermediate_output = len(self.layers) + intermediate_output
        intermediate = None
        for i, layer in enumerate(self.layers):
            x = layer(x, mask, optimized_attention)
            if i == intermediate_output:
                intermediate = x.clone()
        return x, intermediate
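

# Note (illustrative): `intermediate_output` selects which layer's hidden states are
# returned alongside the final output. Negative values index from the end, so for a
# 12-layer encoder, intermediate_output=-2 captures the penultimate layer's output,
# which is what "clip skip 2" style sampling relies on. Sketch, sizes assumed:
#
#   enc = CLIPEncoder(12, 768, 12, 3072, "quick_gelu", torch.float32, torch.device("cpu"), torch.nn)
#   final, penultimate = enc(torch.randn(1, 77, 768), intermediate_output=-2)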


class CLIPEmbeddings(torch.nn.Module):
    """#### The CLIPEmbeddings module."""

    def __init__(
        self,
        embed_dim: int,
        vocab_size: int = 49408,
        num_positions: int = 77,
        dtype: torch.dtype = None,
        device: torch.device = None,
        operations: object = torch.nn,
    ):
        """#### Initialize the CLIPEmbeddings module.

        #### Args:
            - `embed_dim` (int): The embedding dimension.
            - `vocab_size` (int, optional): The vocabulary size. Defaults to 49408.
            - `num_positions` (int, optional): The number of positions. Defaults to 77.
            - `dtype` (torch.dtype, optional): The data type. Defaults to None.
            - `device` (torch.device, optional): The device to use. Defaults to None.
            - `operations` (object, optional): The operations object. Defaults to torch.nn.
        """
        super().__init__()
        self.token_embedding = operations.Embedding(
            vocab_size, embed_dim, dtype=dtype, device=device
        )
        self.position_embedding = operations.Embedding(
            num_positions, embed_dim, dtype=dtype, device=device
        )

    def forward(self, input_tokens: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
        """#### Forward pass for the CLIPEmbeddings module.

        #### Args:
            - `input_tokens` (torch.Tensor): The input tokens.
            - `dtype` (torch.dtype, optional): The data type. Defaults to torch.float32.

        #### Returns:
            - `torch.Tensor`: The output tensor.
        """
        return self.token_embedding(input_tokens, out_dtype=dtype) + cast.cast_to(
            self.position_embedding.weight, dtype=dtype, device=input_tokens.device
        )
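

# Note (illustrative): token and position embeddings are summed, and the position table
# holds `num_positions` rows (77 by default), so token ids are expected to already be
# padded or truncated to that length. Shape sketch only; the `out_dtype` forward
# argument assumes the project's custom `operations` Embedding, not plain torch.nn:
#
#   input_tokens: (batch, 77) int64  ->  output: (batch, 77, embed_dim)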


class CLIP:
    """#### The CLIP class."""

    def __init__(
        self,
        target: object = None,
        embedding_directory: str = None,
        no_init: bool = False,
        tokenizer_data={},
        parameters=0,
        model_options={},
    ):
        """#### Initialize the CLIP class.

        #### Args:
            - `target` (object, optional): The target object. Defaults to None.
            - `embedding_directory` (str, optional): The embedding directory. Defaults to None.
            - `no_init` (bool, optional): Whether to skip initialization. Defaults to False.
            - `tokenizer_data` (dict, optional): Extra tokenizer data. Defaults to {}.
            - `parameters` (int, optional): The parameter count, used for initial device placement. Defaults to 0.
            - `model_options` (dict, optional): The model options. Defaults to {}.
        """
        if no_init:
            return
        params = target.params.copy()
        clip = target.clip
        tokenizer = target.tokenizer

        load_device = model_options.get("load_device", Device.text_encoder_device())
        offload_device = model_options.get(
            "offload_device", Device.text_encoder_offload_device()
        )
        dtype = model_options.get("dtype", None)
        if dtype is None:
            dtype = Device.text_encoder_dtype(load_device)

        params["dtype"] = dtype
        params["device"] = model_options.get(
            "initial_device",
            Device.text_encoder_initial_device(
                load_device, offload_device, parameters * Device.dtype_size(dtype)
            ),
        )
        params["model_options"] = model_options

        self.cond_stage_model = clip(**(params))
        # for dt in self.cond_stage_model.dtypes:
        #     if not Device.supports_cast(load_device, dt):
        #         load_device = offload_device
        #         if params["device"] != offload_device:
        #             self.cond_stage_model.to(offload_device)
        #             logging.warning("Had to shift TE back.")

        try:
            self.tokenizer = tokenizer(
                embedding_directory=embedding_directory, tokenizer_data=tokenizer_data
            )
        except TypeError:
            self.tokenizer = tokenizer(embedding_directory=embedding_directory)
        self.patcher = ModelPatcher.ModelPatcher(
            self.cond_stage_model,
            load_device=load_device,
            offload_device=offload_device,
        )
        if params["device"] == load_device:
            Device.load_models_gpu(
                [self.patcher], force_full_load=True, flux_enabled=True
            )
        self.layer_idx = None
        logging.debug(
            "CLIP model load device: {}, offload device: {}, current: {}".format(
                load_device, offload_device, params["device"]
            )
        )

    def clone(self) -> "CLIP":
        """#### Clone the CLIP object.

        #### Returns:
            - `CLIP`: The cloned CLIP object.
        """
        n = CLIP(no_init=True)
        n.patcher = self.patcher.clone()
        n.cond_stage_model = self.cond_stage_model
        n.tokenizer = self.tokenizer
        n.layer_idx = self.layer_idx
        return n

    def add_patches(
        self, patches: list, strength_patch: float = 1.0, strength_model: float = 1.0
    ) -> list:
        """#### Add patches to the model.

        #### Args:
            - `patches` (list): The patches to add.
            - `strength_patch` (float, optional): The strength of the patches. Defaults to 1.0.
            - `strength_model` (float, optional): The strength of the model. Defaults to 1.0.

        #### Returns:
            - `list`: The keys of the applied patches.
        """
        return self.patcher.add_patches(patches, strength_patch, strength_model)

    def clip_layer(self, layer_idx: int) -> None:
        """#### Set the clip layer.

        #### Args:
            - `layer_idx` (int): The layer index.
        """
        self.layer_idx = layer_idx

    def tokenize(self, text: str, return_word_ids: bool = False) -> list:
        """#### Tokenize the input text.

        #### Args:
            - `text` (str): The input text.
            - `return_word_ids` (bool, optional): Whether to return word IDs. Defaults to False.

        #### Returns:
            - `list`: The tokenized text.
        """
        return self.tokenizer.tokenize_with_weights(text, return_word_ids)

    def encode_from_tokens(
        self,
        tokens: list,
        return_pooled: bool = False,
        return_dict: bool = False,
        flux_enabled: bool = False,
    ) -> tuple:
        """#### Encode the input tokens.

        #### Args:
            - `tokens` (list): The input tokens.
            - `return_pooled` (bool, optional): Whether to return the pooled output. Defaults to False.
            - `return_dict` (bool, optional): Whether to return a dictionary of outputs. Defaults to False.
            - `flux_enabled` (bool, optional): Whether to enable flux. Defaults to False.

        #### Returns:
            - `tuple`: The encoded tokens and the pooled output.
        """
        self.cond_stage_model.reset_clip_options()

        if self.layer_idx is not None:
            self.cond_stage_model.set_clip_options({"layer": self.layer_idx})

        if return_pooled == "unprojected":
            self.cond_stage_model.set_clip_options({"projected_pooled": False})

        self.load_model(flux_enabled=flux_enabled)
        o = self.cond_stage_model.encode_token_weights(tokens)
        cond, pooled = o[:2]
        if return_dict:
            out = {"cond": cond, "pooled_output": pooled}
            if len(o) > 2:
                for k in o[2]:
                    out[k] = o[2][k]
            return out

        if return_pooled:
            return cond, pooled
        return cond
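
    # Return modes sketch (illustrative): depending on the flags, encode_from_tokens
    # yields a conditioning tensor, a (cond, pooled) tuple, or a dict:
    #
    #   cond = clip.encode_from_tokens(tokens)
    #   cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
    #   out = clip.encode_from_tokens(tokens, return_dict=True)  # {"cond": ..., "pooled_output": ...}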

    def load_sd(self, sd: dict, full_model: bool = False) -> tuple:
        """#### Load the state dictionary.

        #### Args:
            - `sd` (dict): The state dictionary.
            - `full_model` (bool, optional): Whether to load the full model. Defaults to False.

        #### Returns:
            - `tuple`: The missing and unexpected keys.
        """
        if full_model:
            return self.cond_stage_model.load_state_dict(sd, strict=False)
        else:
            return self.cond_stage_model.load_sd(sd)

    def load_model(self, flux_enabled: bool = False) -> ModelPatcher:
        """#### Load the model.

        #### Args:
            - `flux_enabled` (bool, optional): Whether to enable flux. Defaults to False.

        #### Returns:
            - `ModelPatcher`: The model patcher.
        """
        Device.load_model_gpu(self.patcher, flux_enabled=flux_enabled)
        return self.patcher

    def encode(self, text: str) -> torch.Tensor:
        """#### Encode the input text.

        #### Args:
            - `text` (str): The input text.

        #### Returns:
            - `torch.Tensor`: The encoded text.
        """
        tokens = self.tokenize(text)
        return self.encode_from_tokens(tokens)

    def get_sd(self) -> dict:
        """#### Get the state dictionary.

        #### Returns:
            - `dict`: The state dictionary.
        """
        sd_clip = self.cond_stage_model.state_dict()
        sd_tokenizer = self.tokenizer.state_dict()
        for k in sd_tokenizer:
            sd_clip[k] = sd_tokenizer[k]
        return sd_clip

    def get_key_patches(self):
        """#### Get the key patches.

        #### Returns:
            - `list`: The key patches.
        """
        return self.patcher.get_key_patches()


class CLIPType(Enum):
    STABLE_DIFFUSION = 1
    SD3 = 3
    FLUX = 6


def load_text_encoder_state_dicts(
    state_dicts=[],
    embedding_directory=None,
    clip_type=CLIPType.STABLE_DIFFUSION,
    model_options={},
):
    """#### Load the text encoder state dictionaries.

    #### Args:
        - `state_dicts` (list, optional): The state dictionaries. Defaults to [].
        - `embedding_directory` (str, optional): The embedding directory. Defaults to None.
        - `clip_type` (CLIPType, optional): The CLIP type. Defaults to CLIPType.STABLE_DIFFUSION.
        - `model_options` (dict, optional): The model options. Defaults to {}.

    #### Returns:
        - `CLIP`: The CLIP object.
    """
    clip_data = state_dicts

    class EmptyClass:
        pass

    for i in range(len(clip_data)):
        if "text_projection" in clip_data[i]:
            # old models saved with the CLIPSave node
            clip_data[i]["text_projection.weight"] = clip_data[i][
                "text_projection"
            ].transpose(0, 1)

    clip_target = EmptyClass()
    clip_target.params = {}
    if len(clip_data) == 2:
        if clip_type == CLIPType.FLUX:
            weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
            weight = clip_data[0].get(weight_name, clip_data[1].get(weight_name, None))
            dtype_t5 = None
            if weight is not None:
                dtype_t5 = weight.dtype

            clip_target.clip = FluxClip.flux_clip(dtype_t5=dtype_t5)
            clip_target.tokenizer = FluxClip.FluxTokenizer

    parameters = 0
    tokenizer_data = {}
    for c in clip_data:
        parameters += util.calculate_parameters(c)
        tokenizer_data, model_options = SDToken.model_options_long_clip(
            c, tokenizer_data, model_options
        )

    clip = CLIP(
        clip_target,
        embedding_directory=embedding_directory,
        parameters=parameters,
        tokenizer_data=tokenizer_data,
        model_options=model_options,
    )
    for c in clip_data:
        m, u = clip.load_sd(c)
        if len(m) > 0:
            logging.warning("clip missing: {}".format(m))

        if len(u) > 0:
            logging.debug("clip unexpected: {}".format(u))
    return clip
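

# Usage sketch (assumptions: the checkpoint file names and the `load_torch_file` helper
# are hypothetical; FLUX text encoding expects both a CLIP-L and a T5-XXL state dict):
#
#   sd_clip_l = load_torch_file("clip_l.safetensors")
#   sd_t5xxl = load_torch_file("t5xxl_fp16.safetensors")
#   clip = load_text_encoder_state_dicts(
#       [sd_clip_l, sd_t5xxl],
#       embedding_directory="./embeddings",
#       clip_type=CLIPType.FLUX,
#   )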


class CLIPTextEncode:
    """#### Text encoding class for the CLIP model."""

    def encode(self, clip: CLIP, text: str, flux_enabled: bool = False) -> tuple:
        """#### Encode the input text.

        #### Args:
            - `clip` (CLIP): The CLIP object.
            - `text` (str): The input text.
            - `flux_enabled` (bool, optional): Whether to enable flux. Defaults to False.

        #### Returns:
            - `tuple`: The encoded text and the pooled output.
        """
        tokens = clip.tokenize(text)
        cond, pooled = clip.encode_from_tokens(
            tokens, return_pooled=True, flux_enabled=flux_enabled
        )
        return ([[cond, {"pooled_output": pooled}]],)
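

# Usage sketch (illustrative): the return value is a one-element tuple wrapping a
# conditioning list of the form [[cond, {"pooled_output": pooled}]], the ComfyUI-style
# format that downstream sampling code is assumed to consume.
#
#   (conditioning,) = CLIPTextEncode().encode(clip, "a photo of a cat")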


class CLIPSetLastLayer:
    """#### Set the last layer class for the CLIP model."""

    def set_last_layer(self, clip: CLIP, stop_at_clip_layer: int) -> tuple:
        """#### Set the last layer of the CLIP model.

        Works the same way as the Automatic1111 "clip skip" option.

        #### Args:
            - `clip` (CLIP): The CLIP object.
            - `stop_at_clip_layer` (int): The layer to stop at.

        #### Returns:
            - `tuple`: A one-element tuple containing the modified CLIP object.
        """
        clip = clip.clone()
        clip.clip_layer(stop_at_clip_layer)
        return (clip,)
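

# Usage sketch (illustrative): stop_at_clip_layer=-1 selects the last layer (effectively
# no skip), while -2 corresponds to "clip skip 2" in Automatic1111 terms, i.e. encode
# with the penultimate CLIP layer:
#
#   (skipped_clip,) = CLIPSetLastLayer().set_last_layer(clip, -2)
#   (conditioning,) = CLIPTextEncode().encode(skipped_clip, "a photo of a cat")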


class ClipTarget:
    """#### Target class for the CLIP model."""

    def __init__(self, tokenizer: object, clip: object):
        """#### Initialize the ClipTarget class.

        #### Args:
            - `tokenizer` (object): The tokenizer.
            - `clip` (object): The CLIP model.
        """
        self.clip = clip
        self.tokenizer = tokenizer
        self.params = {}
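

# Usage sketch (illustrative): ClipTarget bundles a tokenizer class and a text encoder
# class, mirroring the EmptyClass-based target assembled in load_text_encoder_state_dicts
# above. The concrete classes below are examples, not a prescribed pairing.
#
#   target = ClipTarget(FluxClip.FluxTokenizer, FluxClip.flux_clip(dtype_t5=torch.float16))
#   clip = CLIP(target, embedding_directory="./embeddings", parameters=0)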