from enum import Enum
import logging
import torch
from modules.Model import ModelPatcher
from modules.Attention import Attention
from modules.Device import Device
from modules.SD15 import SDToken
from modules.Utilities import util
from modules.clip import FluxClip
from modules.cond import cast
class CLIPAttention(torch.nn.Module):
"""#### The CLIPAttention module."""
def __init__(
self,
embed_dim: int,
heads: int,
dtype: torch.dtype,
device: torch.device,
operations: object,
):
"""#### Initialize the CLIPAttention module.
#### Args:
- `embed_dim` (int): The embedding dimension.
- `heads` (int): The number of attention heads.
- `dtype` (torch.dtype): The data type.
- `device` (torch.device): The device to use.
- `operations` (object): The operations object.
"""
super().__init__()
self.heads = heads
self.q_proj = operations.Linear(
embed_dim, embed_dim, bias=True, dtype=dtype, device=device
)
self.k_proj = operations.Linear(
embed_dim, embed_dim, bias=True, dtype=dtype, device=device
)
self.v_proj = operations.Linear(
embed_dim, embed_dim, bias=True, dtype=dtype, device=device
)
self.out_proj = operations.Linear(
embed_dim, embed_dim, bias=True, dtype=dtype, device=device
)
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor = None,
optimized_attention: callable = None,
) -> torch.Tensor:
"""#### Forward pass for the CLIPAttention module.
#### Args:
- `x` (torch.Tensor): The input tensor.
- `mask` (torch.Tensor, optional): The attention mask. Defaults to None.
- `optimized_attention` (callable, optional): The optimized attention function. Defaults to None.
#### Returns:
- `torch.Tensor`: The output tensor.
"""
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
out = optimized_attention(q, k, v, self.heads, mask)
return self.out_proj(out)
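# "quick_gelu" is the sigmoid approximation x * sigmoid(1.702 * x) to GELU used by the
# original CLIP text encoder; "gelu" defers to torch's exact implementation. The two are
# close but not interchangeable, so the key should match how the checkpoint was trained.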
ACTIVATIONS = {
"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
"gelu": torch.nn.functional.gelu,
}
class CLIPMLP(torch.nn.Module):
"""#### The CLIPMLP module.
(MLP stands for Multi-Layer Perceptron.)"""
def __init__(
self,
embed_dim: int,
intermediate_size: int,
activation: str,
dtype: torch.dtype,
device: torch.device,
operations: object,
):
"""#### Initialize the CLIPMLP module.
#### Args:
- `embed_dim` (int): The embedding dimension.
- `intermediate_size` (int): The intermediate size.
- `activation` (str): The activation function.
- `dtype` (torch.dtype): The data type.
- `device` (torch.device): The device to use.
- `operations` (object): The operations object.
"""
super().__init__()
self.fc1 = operations.Linear(
embed_dim, intermediate_size, bias=True, dtype=dtype, device=device
)
self.activation = ACTIVATIONS[activation]
self.fc2 = operations.Linear(
intermediate_size, embed_dim, bias=True, dtype=dtype, device=device
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""#### Forward pass for the CLIPMLP module.
#### Args:
- `x` (torch.Tensor): The input tensor.
#### Returns:
- `torch.Tensor`: The output tensor.
"""
x = self.fc1(x)
x = self.activation(x)
x = self.fc2(x)
return x
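# CLIPLayer below is a standard pre-norm transformer block: layer norm, self-attention
# and a residual add, followed by layer norm, the CLIPMLP feed-forward and a second
# residual add.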
class CLIPLayer(torch.nn.Module):
"""#### The CLIPLayer module."""
def __init__(
self,
embed_dim: int,
heads: int,
intermediate_size: int,
intermediate_activation: str,
dtype: torch.dtype,
device: torch.device,
operations: object,
):
"""#### Initialize the CLIPLayer module.
#### Args:
- `embed_dim` (int): The embedding dimension.
- `heads` (int): The number of attention heads.
- `intermediate_size` (int): The intermediate size.
- `intermediate_activation` (str): The intermediate activation function.
- `dtype` (torch.dtype): The data type.
- `device` (torch.device): The device to use.
- `operations` (object): The operations object.
"""
super().__init__()
self.layer_norm1 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
self.self_attn = CLIPAttention(embed_dim, heads, dtype, device, operations)
self.layer_norm2 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
self.mlp = CLIPMLP(
embed_dim,
intermediate_size,
intermediate_activation,
dtype,
device,
operations,
)
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor = None,
optimized_attention: callable = None,
) -> torch.Tensor:
"""#### Forward pass for the CLIPLayer module.
#### Args:
- `x` (torch.Tensor): The input tensor.
- `mask` (torch.Tensor, optional): The attention mask. Defaults to None.
- `optimized_attention` (callable, optional): The optimized attention function. Defaults to None.
#### Returns:
- `torch.Tensor`: The output tensor.
"""
x += self.self_attn(self.layer_norm1(x), mask, optimized_attention)
x += self.mlp(self.layer_norm2(x))
return x
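# CLIPEncoder below runs the stack of CLIPLayer blocks. `intermediate_output` is
# interpreted like a Python index: negative values are wrapped with
# `len(self.layers) + intermediate_output`, so -1 yields the final hidden state and -2
# the penultimate one. A sketch of a call (hypothetical `token_embeddings` tensor,
# assuming a constructed `encoder`):
#     final, penultimate = encoder(token_embeddings, intermediate_output=-2)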
class CLIPEncoder(torch.nn.Module):
"""#### The CLIPEncoder module."""
def __init__(
self,
num_layers: int,
embed_dim: int,
heads: int,
intermediate_size: int,
intermediate_activation: str,
dtype: torch.dtype,
device: torch.device,
operations: object,
):
"""#### Initialize the CLIPEncoder module.
#### Args:
- `num_layers` (int): The number of layers.
- `embed_dim` (int): The embedding dimension.
- `heads` (int): The number of attention heads.
- `intermediate_size` (int): The intermediate size.
- `intermediate_activation` (str): The intermediate activation function.
- `dtype` (torch.dtype): The data type.
- `device` (torch.device): The device to use.
- `operations` (object): The operations object.
"""
super().__init__()
self.layers = torch.nn.ModuleList(
[
CLIPLayer(
embed_dim,
heads,
intermediate_size,
intermediate_activation,
dtype,
device,
operations,
)
                for _ in range(num_layers)
]
)
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor = None,
intermediate_output: int = None,
) -> tuple:
"""#### Forward pass for the CLIPEncoder module.
#### Args:
- `x` (torch.Tensor): The input tensor.
- `mask` (torch.Tensor, optional): The attention mask. Defaults to None.
- `intermediate_output` (int, optional): The intermediate output layer. Defaults to None.
#### Returns:
- `tuple`: The output tensor and the intermediate output tensor.
"""
optimized_attention = Attention.optimized_attention_for_device()
if intermediate_output is not None:
if intermediate_output < 0:
intermediate_output = len(self.layers) + intermediate_output
intermediate = None
        for i, layer in enumerate(self.layers):
            x = layer(x, mask, optimized_attention)
if i == intermediate_output:
intermediate = x.clone()
return x, intermediate
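# CLIPEmbeddings below sums the learned token embedding of each input id with the
# learned position table; `cast.cast_to` moves the position weights to the requested
# dtype and the tokens' device before the broadcast add.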
class CLIPEmbeddings(torch.nn.Module):
"""#### The CLIPEmbeddings module."""
def __init__(
self,
embed_dim: int,
vocab_size: int = 49408,
num_positions: int = 77,
dtype: torch.dtype = None,
device: torch.device = None,
operations: object = torch.nn,
):
"""#### Initialize the CLIPEmbeddings module.
#### Args:
- `embed_dim` (int): The embedding dimension.
- `vocab_size` (int, optional): The vocabulary size. Defaults to 49408.
- `num_positions` (int, optional): The number of positions. Defaults to 77.
- `dtype` (torch.dtype, optional): The data type. Defaults to None.
            - `device` (torch.device, optional): The device to use. Defaults to None.
            - `operations` (object, optional): The operations object. Defaults to torch.nn.
        """
super().__init__()
self.token_embedding = operations.Embedding(
vocab_size, embed_dim, dtype=dtype, device=device
)
self.position_embedding = operations.Embedding(
num_positions, embed_dim, dtype=dtype, device=device
)
def forward(self, input_tokens: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
"""#### Forward pass for the CLIPEmbeddings module.
#### Args:
- `input_tokens` (torch.Tensor): The input tokens.
- `dtype` (torch.dtype, optional): The data type. Defaults to torch.float32.
#### Returns:
- `torch.Tensor`: The output tensor.
"""
return self.token_embedding(input_tokens, out_dtype=dtype) + cast.cast_to(
self.position_embedding.weight, dtype=dtype, device=input_tokens.device
)
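# The CLIP wrapper below ties together the text encoder model, its tokenizer and a
# ModelPatcher. A minimal usage sketch (assuming `clip` was built by
# load_text_encoder_state_dicts further down; the prompt text is arbitrary):
#     tokens = clip.tokenize("a photo of a cat")
#     cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)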
class CLIP:
"""#### The CLIP class."""
def __init__(
self,
target: object = None,
embedding_directory: str = None,
no_init: bool = False,
tokenizer_data={},
parameters=0,
model_options={},
):
"""#### Initialize the CLIP class.
#### Args:
- `target` (object, optional): The target object. Defaults to None.
- `embedding_directory` (str, optional): The embedding directory. Defaults to None.
- `no_init` (bool, optional): Whether to skip initialization. Defaults to False.
"""
if no_init:
return
params = target.params.copy()
clip = target.clip
tokenizer = target.tokenizer
load_device = model_options.get("load_device", Device.text_encoder_device())
offload_device = model_options.get(
"offload_device", Device.text_encoder_offload_device()
)
dtype = model_options.get("dtype", None)
if dtype is None:
dtype = Device.text_encoder_dtype(load_device)
params["dtype"] = dtype
params["device"] = model_options.get(
"initial_device",
Device.text_encoder_initial_device(
load_device, offload_device, parameters * Device.dtype_size(dtype)
),
)
params["model_options"] = model_options
self.cond_stage_model = clip(**(params))
# for dt in self.cond_stage_model.dtypes:
# if not Device.supports_cast(load_device, dt):
# load_device = offload_device
# if params["device"] != offload_device:
# self.cond_stage_model.to(offload_device)
# logging.warning("Had to shift TE back.")
try:
self.tokenizer = tokenizer(
embedding_directory=embedding_directory, tokenizer_data=tokenizer_data
)
except TypeError:
self.tokenizer = tokenizer(
embedding_directory=embedding_directory
)
self.patcher = ModelPatcher.ModelPatcher(
self.cond_stage_model,
load_device=load_device,
offload_device=offload_device,
)
if params["device"] == load_device:
Device.load_models_gpu([self.patcher], force_full_load=True, flux_enabled=True)
self.layer_idx = None
logging.debug(
"CLIP model load device: {}, offload device: {}, current: {}".format(
load_device, offload_device, params["device"]
)
)
def clone(self) -> "CLIP":
"""#### Clone the CLIP object.
#### Returns:
- `CLIP`: The cloned CLIP object.
"""
n = CLIP(no_init=True)
n.patcher = self.patcher.clone()
n.cond_stage_model = self.cond_stage_model
n.tokenizer = self.tokenizer
n.layer_idx = self.layer_idx
return n
def add_patches(
self, patches: list, strength_patch: float = 1.0, strength_model: float = 1.0
) -> None:
"""#### Add patches to the model.
#### Args:
- `patches` (list): The patches to add.
- `strength_patch` (float, optional): The strength of the patches. Defaults to 1.0.
- `strength_model` (float, optional): The strength of the model. Defaults to 1.0.
"""
return self.patcher.add_patches(patches, strength_patch, strength_model)
def clip_layer(self, layer_idx: int) -> None:
"""#### Set the clip layer.
#### Args:
- `layer_idx` (int): The layer index.
"""
self.layer_idx = layer_idx
def tokenize(self, text: str, return_word_ids: bool = False) -> list:
"""#### Tokenize the input text.
#### Args:
- `text` (str): The input text.
- `return_word_ids` (bool, optional): Whether to return word IDs. Defaults to False.
#### Returns:
- `list`: The tokenized text.
"""
return self.tokenizer.tokenize_with_weights(text, return_word_ids)
    def encode_from_tokens(
        self,
        tokens: list,
        return_pooled: bool = False,
        return_dict: bool = False,
        flux_enabled: bool = False,
    ) -> tuple:
        """#### Encode the input tokens.
        #### Args:
            - `tokens` (list): The input tokens.
            - `return_pooled` (bool, optional): Whether to return the pooled output; pass "unprojected" to skip the final projection. Defaults to False.
            - `return_dict` (bool, optional): Whether to return a dict with `cond`, `pooled_output` and any extra outputs. Defaults to False.
            - `flux_enabled` (bool, optional): Whether to enable flux. Defaults to False.
        #### Returns:
            - `tuple`: The encoded tokens and the pooled output.
        """
self.cond_stage_model.reset_clip_options()
if self.layer_idx is not None:
self.cond_stage_model.set_clip_options({"layer": self.layer_idx})
if return_pooled == "unprojected":
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model(flux_enabled=flux_enabled)
o = self.cond_stage_model.encode_token_weights(tokens)
cond, pooled = o[:2]
if return_dict:
out = {"cond": cond, "pooled_output": pooled}
if len(o) > 2:
for k in o[2]:
out[k] = o[2][k]
return out
if return_pooled:
return cond, pooled
return cond
def load_sd(self, sd: dict, full_model: bool = False) -> None:
"""#### Load the state dictionary.
#### Args:
- `sd` (dict): The state dictionary.
- `full_model` (bool, optional): Whether to load the full model. Defaults to False.
"""
if full_model:
return self.cond_stage_model.load_state_dict(sd, strict=False)
else:
return self.cond_stage_model.load_sd(sd)
    def load_model(self, flux_enabled: bool = False) -> ModelPatcher:
        """#### Load the model.
        #### Args:
            - `flux_enabled` (bool, optional): Whether to enable flux. Defaults to False.
        #### Returns:
            - `ModelPatcher`: The model patcher.
        """
Device.load_model_gpu(self.patcher, flux_enabled=flux_enabled)
return self.patcher
def encode(self, text):
"""#### Encode the input text.
#### Args:
- `text` (str): The input text.
#### Returns:
- `torch.Tensor`: The encoded text.
"""
tokens = self.tokenize(text)
return self.encode_from_tokens(tokens)
def get_sd(self):
"""#### Get the state dictionary.
#### Returns:
- `dict`: The state dictionary.
"""
sd_clip = self.cond_stage_model.state_dict()
sd_tokenizer = self.tokenizer.state_dict()
for k in sd_tokenizer:
sd_clip[k] = sd_tokenizer[k]
return sd_clip
def get_key_patches(self):
"""#### Get the key patches.
#### Returns:
- `list`: The key patches.
"""
return self.patcher.get_key_patches()
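# CLIPType below distinguishes the supported text-encoder configurations. Of the three
# members, only CLIPType.FLUX is given a dedicated branch in load_text_encoder_state_dicts
# (two state dicts: a CLIP model plus a T5 encoder whose dtype is read from the checkpoint).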
class CLIPType(Enum):
STABLE_DIFFUSION = 1
SD3 = 3
FLUX = 6
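# A sketch of how the loader below is typically called (file names are hypothetical; the
# state dicts only need to be plain tensor dicts, e.g. loaded with torch.load or a
# safetensors reader):
#     clip_l_sd = torch.load("clip_l.pt", map_location="cpu")
#     t5_sd = torch.load("t5xxl.pt", map_location="cpu")
#     clip = load_text_encoder_state_dicts(
#         [clip_l_sd, t5_sd],
#         embedding_directory="embeddings",
#         clip_type=CLIPType.FLUX,
#     )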
def load_text_encoder_state_dicts(
state_dicts=[],
embedding_directory=None,
clip_type=CLIPType.STABLE_DIFFUSION,
model_options={},
):
"""#### Load the text encoder state dictionaries.
#### Args:
- `state_dicts` (list, optional): The state dictionaries. Defaults to [].
- `embedding_directory` (str, optional): The embedding directory. Defaults to None.
- `clip_type` (CLIPType, optional): The CLIP type. Defaults to CLIPType.STABLE_DIFFUSION.
- `model_options` (dict, optional): The model options. Defaults to {}.
#### Returns:
- `CLIP`: The CLIP object.
"""
clip_data = state_dicts
class EmptyClass:
pass
    for i in range(len(clip_data)):
        if "text_projection" in clip_data[i]:
            # old models saved with the CLIPSave node store the projection transposed
            clip_data[i]["text_projection.weight"] = clip_data[i][
                "text_projection"
            ].transpose(0, 1)
clip_target = EmptyClass()
clip_target.params = {}
if len(clip_data) == 2:
if clip_type == CLIPType.FLUX:
weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
weight = clip_data[0].get(weight_name, clip_data[1].get(weight_name, None))
dtype_t5 = None
if weight is not None:
dtype_t5 = weight.dtype
clip_target.clip = FluxClip.flux_clip(dtype_t5=dtype_t5)
clip_target.tokenizer = FluxClip.FluxTokenizer
parameters = 0
tokenizer_data = {}
for c in clip_data:
parameters += util.calculate_parameters(c)
tokenizer_data, model_options = SDToken.model_options_long_clip(
c, tokenizer_data, model_options
)
clip = CLIP(
clip_target,
embedding_directory=embedding_directory,
parameters=parameters,
tokenizer_data=tokenizer_data,
model_options=model_options,
)
for c in clip_data:
m, u = clip.load_sd(c)
if len(m) > 0:
logging.warning("clip missing: {}".format(m))
if len(u) > 0:
logging.debug("clip unexpected: {}".format(u))
return clip
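# CLIPTextEncode below wraps CLIP.encode_from_tokens and returns a conditioning list in
# which the single entry pairs the encoded tokens with their pooled output
# (`[[cond, {"pooled_output": pooled}]]`), itself wrapped in a one-element tuple.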
class CLIPTextEncode:
"""#### Text encoding class for the CLIP model."""
def encode(self, clip: CLIP, text: str, flux_enabled: bool = False) -> tuple:
"""#### Encode the input text.
#### Args:
- `clip` (CLIP): The CLIP object.
- `text` (str): The input text.
- `flux_enabled` (bool, optional): Whether to enable flux. Defaults to False.
#### Returns:
- `tuple`: The encoded text and the pooled output.
"""
tokens = clip.tokenize(text)
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True, flux_enabled=flux_enabled)
return ([[cond, {"pooled_output": pooled}]],)
class CLIPSetLastLayer:
"""#### Set the last layer class for the CLIP model."""
def set_last_layer(self, clip: CLIP, stop_at_clip_layer: int) -> tuple:
"""#### Set the last layer of the CLIP model.
works same as Automatic1111 clip skip
#### Args:
- `clip` (CLIP): The CLIP object.
- `stop_at_clip_layer` (int): The layer to stop at.
#### Returns:
- `tuple`: Thefrom enum import Enum
"""
clip = clip.clone()
clip.clip_layer(stop_at_clip_layer)
return (clip,)
class ClipTarget:
"""#### Target class for the CLIP model."""
def __init__(self, tokenizer: object, clip: object):
"""#### Initialize the ClipTarget class.
#### Args:
- `tokenizer` (object): The tokenizer.
- `clip` (object): The CLIP model.
"""
self.clip = clip
self.tokenizer = tokenizer
self.params = {}