import json
import logging
import numbers

import torch

from modules.Device import Device
from modules.cond import cast
from modules.clip.CLIPTextModel import CLIPTextModel
def gen_empty_tokens(special_tokens: dict, length: int) -> list:
    """#### Generate a list of empty tokens.
    #### Args:
        - `special_tokens` (dict): The special tokens.
        - `length` (int): The length of the token list.
    #### Returns:
        - `list`: The list of empty tokens.
    """
    start_token = special_tokens.get("start", None)
    end_token = special_tokens.get("end", None)
    pad_token = special_tokens.get("pad")
    output = []
    if start_token is not None:
        output.append(start_token)
    if end_token is not None:
        output.append(end_token)
    output += [pad_token] * (length - len(output))
    return output
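
# Illustrative example: with the SD1 special tokens used below,
#   gen_empty_tokens({"start": 49406, "end": 49407, "pad": 49407}, 77)
# returns [49406, 49407] followed by 75 pad tokens (49407), i.e. an
# "empty prompt" of length 77.
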
class ClipTokenWeightEncoder:
    """#### Class representing a CLIP token weight encoder."""

    def encode_token_weights(self, token_weight_pairs: list) -> tuple:
        """#### Encode token weights.
        #### Args:
            - `token_weight_pairs` (list): The token weight pairs.
        #### Returns:
            - `tuple`: The encoded tokens and the pooled output.
        """
        to_encode = list()
        max_token_len = 0
        has_weights = False
        for x in token_weight_pairs:
            tokens = list(map(lambda a: a[0], x))
            max_token_len = max(len(tokens), max_token_len)
            has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
            to_encode.append(tokens)

        sections = len(to_encode)
        if has_weights or sections == 0:
            # Encode an "empty prompt" alongside the real sections; it serves as the
            # baseline that weighted tokens are interpolated against below.
            to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))

        o = self.encode(to_encode)
        out, pooled = o[:2]
        if pooled is not None:
            first_pooled = pooled[0:1].to(Device.intermediate_device())
        else:
            first_pooled = pooled

        output = []
        for k in range(0, sections):
            z = out[k : k + 1]
            if has_weights:
                z_empty = out[-1]
                for i in range(len(z)):
                    for j in range(len(z[i])):
                        weight = token_weight_pairs[k][j][1]
                        if weight != 1.0:
                            # Linear interpolation between the empty-prompt embedding
                            # and the token's embedding: weight 0 gives the empty
                            # baseline, weight 1 leaves it unchanged, >1 extrapolates.
                            z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
            output.append(z)

        if len(output) == 0:
            r = (out[-1:].to(Device.intermediate_device()), first_pooled)
        else:
            r = (
                torch.cat(output, dim=-2).to(Device.intermediate_device()),
                first_pooled,
            )

        if len(o) > 2:
            extra = {}
            for k in o[2]:
                v = o[2][k]
                if k == "attention_mask":
                    v = (
                        v[:sections]
                        .flatten()
                        .unsqueeze(dim=0)
                        .to(Device.intermediate_device())
                    )
                extra[k] = v
            r = r + (extra,)
        return r
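
# A minimal usage sketch (hypothetical names; assumes a concrete subclass such as
# SDClipModel defined below, and a tokenizer that yields lists of sections, each a
# list of (token_id, weight) pairs):
#
#   encoder = SDClipModel()
#   pairs = [[(49406, 1.0), (320, 1.2), (49407, 1.0)]]
#   cond, pooled = encoder.encode_token_weights(pairs)[:2]
#
# `cond` concatenates the per-section embeddings along the token axis and `pooled`
# is the pooled output of the first section (or None).
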
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
    """#### Uses the CLIP transformer encoder for text (from huggingface)."""

    LAYERS = ["last", "pooled", "hidden"]

    def __init__(
        self,
        version: str = "openai/clip-vit-large-patch14",
        device: str = "cpu",
        max_length: int = 77,
        freeze: bool = True,
        layer: str = "last",
        layer_idx: int = None,
        textmodel_json_config: str = None,
        dtype: torch.dtype = None,
        model_class: type = CLIPTextModel,
        special_tokens: dict = {"start": 49406, "end": 49407, "pad": 49407},
        layer_norm_hidden_state: bool = True,
        enable_attention_masks: bool = False,
        zero_out_masked: bool = False,
        return_projected_pooled: bool = True,
        return_attention_masks: bool = False,
        model_options={},
    ):
        """#### Initialize the SDClipModel.
        #### Args:
            - `version` (str, optional): The version of the model. Defaults to "openai/clip-vit-large-patch14".
            - `device` (str, optional): The device to use. Defaults to "cpu".
            - `max_length` (int, optional): The maximum length of the input. Defaults to 77.
            - `freeze` (bool, optional): Whether to freeze the model parameters. Defaults to True.
            - `layer` (str, optional): The layer to use. Defaults to "last".
            - `layer_idx` (int, optional): The index of the layer. Defaults to None.
            - `textmodel_json_config` (str, optional): The path to the JSON config file. Defaults to None.
            - `dtype` (torch.dtype, optional): The data type. Defaults to None.
            - `model_class` (type, optional): The model class. Defaults to CLIPTextModel.
            - `special_tokens` (dict, optional): The special tokens. Defaults to {"start": 49406, "end": 49407, "pad": 49407}.
            - `layer_norm_hidden_state` (bool, optional): Whether to normalize the hidden state. Defaults to True.
            - `enable_attention_masks` (bool, optional): Whether to enable attention masks. Defaults to False.
            - `zero_out_masked` (bool, optional): Whether to zero out masked tokens. Defaults to False.
            - `return_projected_pooled` (bool, optional): Whether to return the projected pooled output. Defaults to True.
            - `return_attention_masks` (bool, optional): Whether to return the attention masks. Defaults to False.
            - `model_options` (dict, optional): Additional model options. Defaults to {}.
        """
        super().__init__()
        assert layer in self.LAYERS

        if textmodel_json_config is None:
            textmodel_json_config = "./_internal/clip/sd1_clip_config.json"
        with open(textmodel_json_config) as f:
            config = json.load(f)

        operations = model_options.get("custom_operations", None)
        if operations is None:
            operations = cast.manual_cast

        self.operations = operations
        self.transformer = model_class(config, dtype, device, self.operations)
        self.num_layers = self.transformer.num_layers

        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = None
        self.special_tokens = special_tokens

        # CLIP logit scale parameter (exp(4.6055) is roughly 100).
        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
        self.enable_attention_masks = enable_attention_masks
        self.zero_out_masked = zero_out_masked

        self.layer_norm_hidden_state = layer_norm_hidden_state
        self.return_projected_pooled = return_projected_pooled
        self.return_attention_masks = return_attention_masks

        if layer == "hidden":
            assert layer_idx is not None
            assert abs(layer_idx) < self.num_layers
            self.set_clip_options({"layer": layer_idx})
        self.options_default = (
            self.layer,
            self.layer_idx,
            self.return_projected_pooled,
        )
    def freeze(self) -> None:
        """#### Freeze the model parameters."""
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def set_clip_options(self, options: dict) -> None:
        """#### Set the CLIP options.
        #### Args:
            - `options` (dict): The options to set.
        """
        layer_idx = options.get("layer", self.layer_idx)
        self.return_projected_pooled = options.get(
            "projected_pooled", self.return_projected_pooled
        )
        if layer_idx is None or abs(layer_idx) > self.num_layers:
            self.layer = "last"
        else:
            self.layer = "hidden"
            self.layer_idx = layer_idx

    def reset_clip_options(self) -> None:
        """#### Reset the CLIP options to default."""
        self.layer = self.options_default[0]
        self.layer_idx = self.options_default[1]
        self.return_projected_pooled = self.options_default[2]
    def set_up_textual_embeddings(self, tokens: list, current_embeds: torch.nn.Embedding) -> list:
        """#### Set up the textual embeddings.
        #### Args:
            - `tokens` (list): The input tokens.
            - `current_embeds` (torch.nn.Embedding): The current embeddings.
        #### Returns:
            - `list`: The processed tokens.
        """
        out_tokens = []
        next_new_token = token_dict_size = current_embeds.weight.shape[0]
        embedding_weights = []

        for x in tokens:
            tokens_temp = []
            for y in x:
                if isinstance(y, numbers.Integral):
                    tokens_temp += [int(y)]
                else:
                    # A raw embedding vector (e.g. a textual-inversion embedding) was
                    # passed instead of a token id; give it a fresh id past the end
                    # of the current vocabulary.
                    if y.shape[0] == current_embeds.weight.shape[1]:
                        embedding_weights += [y]
                        tokens_temp += [next_new_token]
                        next_new_token += 1
                    else:
                        logging.warning(
                            "WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(
                                y.shape[0], current_embeds.weight.shape[1]
                            )
                        )
            while len(tokens_temp) < len(x):
                tokens_temp += [self.special_tokens["pad"]]
            out_tokens += [tokens_temp]

        n = token_dict_size
        if len(embedding_weights) > 0:
            # Build an enlarged embedding table: the original vocabulary followed by
            # one row per injected embedding vector.
            new_embedding = self.operations.Embedding(
                next_new_token + 1,
                current_embeds.weight.shape[1],
                device=current_embeds.weight.device,
                dtype=current_embeds.weight.dtype,
            )
            new_embedding.weight[:token_dict_size] = current_embeds.weight
            for x in embedding_weights:
                new_embedding.weight[n] = x
                n += 1
            self.transformer.set_input_embeddings(new_embedding)

        processed_tokens = []
        for x in out_tokens:
            processed_tokens += [
                list(map(lambda a: n if a == -1 else a, x))
            ]  # The EOS token should always be the largest one

        return processed_tokens
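
    # Illustrative token input for the method above (hypothetical values): each inner
    # list may mix integer token ids with per-token embedding tensors whose length
    # equals the embedding dimension, e.g.
    #   [[49406, some_textual_inversion_vector, 49407, 49407, ...]]
    # Tensor entries receive temporary ids appended past the end of the vocabulary.
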
    def forward(self, tokens: list) -> tuple:
        """#### Forward pass of the model.
        #### Args:
            - `tokens` (list): The input tokens.
        #### Returns:
            - `tuple`: The output and the pooled output.
        """
        backup_embeds = self.transformer.get_input_embeddings()
        device = backup_embeds.weight.device
        tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
        tokens = torch.LongTensor(tokens).to(device)

        attention_mask = None
        if (
            self.enable_attention_masks
            or self.zero_out_masked
            or self.return_attention_masks
        ):
            # Mark every position up to and including the first end token as attended.
            attention_mask = torch.zeros_like(tokens)
            end_token = self.special_tokens.get("end", -1)
            for x in range(attention_mask.shape[0]):
                for y in range(attention_mask.shape[1]):
                    attention_mask[x, y] = 1
                    if tokens[x, y] == end_token:
                        break

        attention_mask_model = None
        if self.enable_attention_masks:
            attention_mask_model = attention_mask

        outputs = self.transformer(
            tokens,
            attention_mask_model,
            intermediate_output=self.layer_idx,
            final_layer_norm_intermediate=self.layer_norm_hidden_state,
            dtype=torch.float32,
        )
        self.transformer.set_input_embeddings(backup_embeds)

        if self.layer == "last":
            z = outputs[0].float()
        else:
            z = outputs[1].float()

        if self.zero_out_masked:
            z *= attention_mask.unsqueeze(-1).float()

        pooled_output = None
        if len(outputs) >= 3:
            if (
                not self.return_projected_pooled
                and len(outputs) >= 4
                and outputs[3] is not None
            ):
                pooled_output = outputs[3].float()
            elif outputs[2] is not None:
                pooled_output = outputs[2].float()

        extra = {}
        if self.return_attention_masks:
            extra["attention_mask"] = attention_mask

        if len(extra) > 0:
            return z, pooled_output, extra

        return z, pooled_output
    def encode(self, tokens: list) -> tuple:
        """#### Encode the input tokens.
        #### Args:
            - `tokens` (list): The input tokens.
        #### Returns:
            - `tuple`: The encoded tokens and the pooled output.
        """
        return self(tokens)

    def load_sd(self, sd: dict):
        """#### Load the state dictionary into the transformer.
        #### Args:
            - `sd` (dict): The state dictionary.
        #### Returns:
            - The result of `load_state_dict` (missing and unexpected keys).
        """
        return self.transformer.load_state_dict(sd, strict=False)
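
# A minimal usage sketch (assumes the default config JSON exists at
# ./_internal/clip/sd1_clip_config.json and that weights are loaded separately,
# e.g. via load_sd):
#
#   model = SDClipModel(layer="hidden", layer_idx=-2)   # clip-skip style setup
#   tokens = [[49406] + prompt_token_ids + [49407]]     # hypothetical token ids
#   cond, pooled = model(tokens)
#
# `cond` holds per-token hidden states; `pooled` is the (optionally projected)
# pooled text embedding, or None if the transformer does not return one.
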
class SD1ClipModel(torch.nn.Module):
    """#### Class representing the SD1ClipModel."""

    def __init__(
        self,
        device: str = "cpu",
        dtype: torch.dtype = None,
        clip_name: str = "l",
        clip_model: type = SDClipModel,
        **kwargs,
    ):
        """#### Initialize the SD1ClipModel.
        #### Args:
            - `device` (str, optional): The device to use. Defaults to "cpu".
            - `dtype` (torch.dtype, optional): The data type. Defaults to None.
            - `clip_name` (str, optional): The name of the CLIP model. Defaults to "l".
            - `clip_model` (type, optional): The CLIP model class. Defaults to SDClipModel.
            - `**kwargs`: Additional keyword arguments.
        """
        super().__init__()
        self.clip_name = clip_name
        self.clip = "clip_{}".format(self.clip_name)

        self.lowvram_patch_counter = 0
        self.model_loaded_weight_memory = 0

        setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))

    def set_clip_options(self, options: dict) -> None:
        """#### Set the CLIP options.
        #### Args:
            - `options` (dict): The options to set.
        """
        getattr(self, self.clip).set_clip_options(options)

    def reset_clip_options(self) -> None:
        """#### Reset the CLIP options to default."""
        getattr(self, self.clip).reset_clip_options()

    def encode_token_weights(self, token_weight_pairs: dict) -> tuple:
        """#### Encode token weights.
        #### Args:
            - `token_weight_pairs` (dict): The token weight pairs, keyed by clip name.
        #### Returns:
            - `tuple`: The encoded tokens and the pooled output.
        """
        token_weight_pairs = token_weight_pairs[self.clip_name]
        out, pooled = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
        return out, pooled