import json import logging import numbers import torch from modules.Device import Device from modules.cond import cast from modules.clip.CLIPTextModel import CLIPTextModel def gen_empty_tokens(special_tokens: dict, length: int) -> list: """#### Generate a list of empty tokens. #### Args: - `special_tokens` (dict): The special tokens. - `length` (int): The length of the token list. #### Returns: - `list`: The list of empty tokens. """ start_token = special_tokens.get("start", None) end_token = special_tokens.get("end", None) pad_token = special_tokens.get("pad") output = [] if start_token is not None: output.append(start_token) if end_token is not None: output.append(end_token) output += [pad_token] * (length - len(output)) return output class ClipTokenWeightEncoder: """#### Class representing a CLIP token weight encoder.""" def encode_token_weights(self, token_weight_pairs: list) -> tuple: """#### Encode token weights. #### Args: - `token_weight_pairs` (list): The token weight pairs. #### Returns: - `tuple`: The encoded tokens and the pooled output. """ to_encode = list() max_token_len = 0 has_weights = False for x in token_weight_pairs: tokens = list(map(lambda a: a[0], x)) max_token_len = max(len(tokens), max_token_len) has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x)) to_encode.append(tokens) sections = len(to_encode) if has_weights or sections == 0: to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len)) o = self.encode(to_encode) out, pooled = o[:2] if pooled is not None: first_pooled = pooled[0:1].to(Device.intermediate_device()) else: first_pooled = pooled output = [] for k in range(0, sections): z = out[k : k + 1] if has_weights: z_empty = out[-1] for i in range(len(z)): for j in range(len(z[i])): weight = token_weight_pairs[k][j][1] if weight != 1.0: z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j] output.append(z) if len(output) == 0: r = (out[-1:].to(Device.intermediate_device()), first_pooled) else: r = (torch.cat(output, dim=-2).to(Device.intermediate_device()), first_pooled) if len(o) > 2: extra = {} for k in o[2]: v = o[2][k] if k == "attention_mask": v = ( v[:sections] .flatten() .unsqueeze(dim=0) .to(Device.intermediate_device()) ) extra[k] = v r = r + (extra,) return r class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): """#### Uses the CLIP transformer encoder for text (from huggingface).""" LAYERS = ["last", "pooled", "hidden"] def __init__( self, version: str = "openai/clip-vit-large-patch14", device: str = "cpu", max_length: int = 77, freeze: bool = True, layer: str = "last", layer_idx: int = None, textmodel_json_config: str = None, dtype: torch.dtype = None, model_class: type = CLIPTextModel, special_tokens: dict = {"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state: bool = True, enable_attention_masks: bool = False, zero_out_masked:bool = False, return_projected_pooled: bool = True, return_attention_masks: bool = False, model_options={}, ): """#### Initialize the SDClipModel. #### Args: - `version` (str, optional): The version of the model. Defaults to "openai/clip-vit-large-patch14". - `device` (str, optional): The device to use. Defaults to "cpu". - `max_length` (int, optional): The maximum length of the input. Defaults to 77. - `freeze` (bool, optional): Whether to freeze the model parameters. Defaults to True. - `layer` (str, optional): The layer to use. Defaults to "last". - `layer_idx` (int, optional): The index of the layer. Defaults to None. - `textmodel_json_config` (str, optional): The path to the JSON config file. Defaults to None. - `dtype` (torch.dtype, optional): The data type. Defaults to None. - `model_class` (type, optional): The model class. Defaults to CLIPTextModel. - `special_tokens` (dict, optional): The special tokens. Defaults to {"start": 49406, "end": 49407, "pad": 49407}. - `layer_norm_hidden_state` (bool, optional): Whether to normalize the hidden state. Defaults to True. - `enable_attention_masks` (bool, optional): Whether to enable attention masks. Defaults to False. - `zero_out_masked` (bool, optional): Whether to zero out masked tokens. Defaults to False. - `return_projected_pooled` (bool, optional): Whether to return the projected pooled output. Defaults to True. - `return_attention_masks` (bool, optional): Whether to return the attention masks. Defaults to False. - `model_options` (dict, optional): Additional model options. Defaults to {}. """ super().__init__() assert layer in self.LAYERS if textmodel_json_config is None: textmodel_json_config = "./_internal/clip/sd1_clip_config.json" with open(textmodel_json_config) as f: config = json.load(f) operations = model_options.get("custom_operations", None) if operations is None: operations = cast.manual_cast self.operations = operations self.transformer = model_class(config, dtype, device, self.operations) self.num_layers = self.transformer.num_layers self.max_length = max_length if freeze: self.freeze() self.layer = layer self.layer_idx = None self.special_tokens = special_tokens self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) self.enable_attention_masks = enable_attention_masks self.zero_out_masked = zero_out_masked self.layer_norm_hidden_state = layer_norm_hidden_state self.return_projected_pooled = return_projected_pooled self.return_attention_masks = return_attention_masks if layer == "hidden": assert layer_idx is not None assert abs(layer_idx) < self.num_layers self.set_clip_options({"layer": layer_idx}) self.options_default = ( self.layer, self.layer_idx, self.return_projected_pooled, ) def freeze(self) -> None: """#### Freeze the model parameters.""" self.transformer = self.transformer.eval() for param in self.parameters(): param.requires_grad = False def set_clip_options(self, options: dict) -> None: """#### Set the CLIP options. #### Args: - `options` (dict): The options to set. """ layer_idx = options.get("layer", self.layer_idx) self.return_projected_pooled = options.get( "projected_pooled", self.return_projected_pooled ) if layer_idx is None or abs(layer_idx) > self.num_layers: self.layer = "last" else: self.layer = "hidden" self.layer_idx = layer_idx def reset_clip_options(self) -> None: """#### Reset the CLIP options to default.""" self.layer = self.options_default[0] self.layer_idx = self.options_default[1] self.return_projected_pooled = self.options_default[2] def set_up_textual_embeddings(self, tokens: list, current_embeds: torch.nn.Embedding) -> list: """#### Set up the textual embeddings. #### Args: - `tokens` (list): The input tokens. - `current_embeds` (torch.nn.Embedding): The current embeddings. #### Returns: - `list`: The processed tokens. """ out_tokens = [] next_new_token = token_dict_size = current_embeds.weight.shape[0] embedding_weights = [] for x in tokens: tokens_temp = [] for y in x: if isinstance(y, numbers.Integral): tokens_temp += [int(y)] else: if y.shape[0] == current_embeds.weight.shape[1]: embedding_weights += [y] tokens_temp += [next_new_token] next_new_token += 1 else: logging.warning( "WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format( y.shape[0], current_embeds.weight.shape[1] ) ) while len(tokens_temp) < len(x): tokens_temp += [self.special_tokens["pad"]] out_tokens += [tokens_temp] n = token_dict_size if len(embedding_weights) > 0: new_embedding = self.operations.Embedding( next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype, ) new_embedding.weight[:token_dict_size] = current_embeds.weight for x in embedding_weights: new_embedding.weight[n] = x n += 1 self.transformer.set_input_embeddings(new_embedding) processed_tokens = [] for x in out_tokens: processed_tokens += [ list(map(lambda a: n if a == -1 else a, x)) ] # The EOS token should always be the largest one return processed_tokens def forward(self, tokens: list) -> tuple: """#### Forward pass of the model. #### Args: - `tokens` (list): The input tokens. #### Returns: - `tuple`: The output and the pooled output. """ backup_embeds = self.transformer.get_input_embeddings() device = backup_embeds.weight.device tokens = self.set_up_textual_embeddings(tokens, backup_embeds) tokens = torch.LongTensor(tokens).to(device) attention_mask = None if ( self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks ): attention_mask = torch.zeros_like(tokens) end_token = self.special_tokens.get("end", -1) for x in range(attention_mask.shape[0]): for y in range(attention_mask.shape[1]): attention_mask[x, y] = 1 if tokens[x, y] == end_token: break attention_mask_model = None if self.enable_attention_masks: attention_mask_model = attention_mask outputs = self.transformer( tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32, ) self.transformer.set_input_embeddings(backup_embeds) if self.layer == "last": z = outputs[0].float() else: z = outputs[1].float() if self.zero_out_masked: z *= attention_mask.unsqueeze(-1).float() pooled_output = None if len(outputs) >= 3: if ( not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None ): pooled_output = outputs[3].float() elif outputs[2] is not None: pooled_output = outputs[2].float() extra = {} if self.return_attention_masks: extra["attention_mask"] = attention_mask if len(extra) > 0: return z, pooled_output, extra return z, pooled_output def encode(self, tokens: list) -> tuple: """#### Encode the input tokens. #### Args: - `tokens` (list): The input tokens. #### Returns: - `tuple`: The encoded tokens and the pooled output. """ return self(tokens) def load_sd(self, sd: dict) -> None: """#### Load the state dictionary. #### Args: - `sd` (dict): The state dictionary. """ return self.transformer.load_state_dict(sd, strict=False) class SD1ClipModel(torch.nn.Module): """#### Class representing the SD1ClipModel.""" def __init__( self, device: str = "cpu", dtype: torch.dtype = None, clip_name: str = "l", clip_model: type = SDClipModel, **kwargs ): """#### Initialize the SD1ClipModel. #### Args: - `device` (str, optional): The device to use. Defaults to "cpu". - `dtype` (torch.dtype, optional): The data type. Defaults to None. - `clip_name` (str, optional): The name of the CLIP model. Defaults to "l". - `clip_model` (type, optional): The CLIP model class. Defaults to SDClipModel. - `**kwargs`: Additional keyword arguments. """ super().__init__() self.clip_name = clip_name self.clip = "clip_{}".format(self.clip_name) self.lowvram_patch_counter = 0 self.model_loaded_weight_memory = 0 setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs)) def set_clip_options(self, options: dict) -> None: """#### Set the CLIP options. #### Args: - `options` (dict): The options to set. """ getattr(self, self.clip).set_clip_options(options) def reset_clip_options(self) -> None: """#### Reset the CLIP options to default.""" getattr(self, self.clip).reset_clip_options() def encode_token_weights(self, token_weight_pairs: dict) -> tuple: """#### Encode token weights. #### Args: - `token_weight_pairs` (dict): The token weight pairs. #### Returns: - `tuple`: The encoded tokens and the pooled output. """ token_weight_pairs = token_weight_pairs[self.clip_name] out, pooled = getattr(self, self.clip).encode_token_weights(token_weight_pairs) return out, pooled