import logging
import os
import traceback
import torch
from transformers import CLIPTokenizerFast
def model_options_long_clip(sd, tokenizer_data, model_options):
    """#### Look up the CLIP position-embedding weight in a state dict.
    #### Returns:
    - `tuple`: The tokenizer data and model options, returned unchanged here.
    """
    # The weight is fetched but not used further in this trimmed version.
    w = sd.get("clip_l.text_model.embeddings.position_embedding.weight", None)
    if w is None:
        w = sd.get("text_model.embeddings.position_embedding.weight", None)
    return tokenizer_data, model_options
def parse_parentheses(string: str) -> list:
"""#### Parse a string with nested parentheses.
#### Args:
- `string` (str): The input string.
#### Returns:
- `list`: The parsed list of strings.
"""
result = []
current_item = ""
nesting_level = 0
for char in string:
if char == "(":
if nesting_level == 0:
if current_item:
result.append(current_item)
current_item = "("
else:
current_item = "("
else:
current_item += char
nesting_level += 1
elif char == ")":
nesting_level -= 1
if nesting_level == 0:
result.append(current_item + ")")
current_item = ""
else:
current_item += char
else:
current_item += char
if current_item:
result.append(current_item)
return result
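# Example (illustrative): parse_parentheses("a (b (c)) d") returns
# ["a ", "(b (c))", " d"]; top-level parenthesised groups are kept as single
# items so that token_weights() can recurse into them.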
def token_weights(string: str, current_weight: float) -> list:
"""#### Parse a string into tokens with weights.
#### Args:
- `string` (str): The input string.
- `current_weight` (float): The current weight.
#### Returns:
    - `list`: The list of (text segment, weight) pairs.
"""
a = parse_parentheses(string)
out = []
for x in a:
weight = current_weight
if len(x) >= 2 and x[-1] == ")" and x[0] == "(":
x = x[1:-1]
xx = x.rfind(":")
weight *= 1.1
if xx > 0:
try:
weight = float(x[xx + 1 :])
x = x[:xx]
                except ValueError:
                    # the text after ":" was not a number; keep the 1.1 multiplier
                    pass
out += token_weights(x, weight)
else:
out += [(x, current_weight)]
return out
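# Examples (illustrative):
#   token_weights("(masterpiece:1.2)", 1.0) -> [("masterpiece", 1.2)]
#   token_weights("((detailed))", 1.0)      -> [("detailed", 1.1 * 1.1)]
# An explicit ":<number>" suffix sets the weight directly, each bare
# parenthesis level multiplies the current weight by 1.1, and plain text
# keeps the incoming weight unchanged.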
def escape_important(text: str) -> str:
"""#### Escape important characters in a string.
#### Args:
- `text` (str): The input text.
#### Returns:
- `str`: The escaped text.
"""
text = text.replace("\\)", "\0\1")
text = text.replace("\\(", "\0\2")
return text
def unescape_important(text: str) -> str:
"""#### Unescape important characters in a string.
#### Args:
- `text` (str): The input text.
#### Returns:
- `str`: The unescaped text.
"""
text = text.replace("\0\1", ")")
text = text.replace("\0\2", "(")
return text
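# escape_important()/unescape_important() shield user-escaped parentheses
# ("\(" and "\)") from the weight parser: they are swapped for the sentinel
# byte pairs "\0\2"/"\0\1" before parsing and restored afterwards, so a
# literal "\(" in a prompt becomes a plain "(" with no weight change.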
def expand_directory_list(directories: list) -> list:
"""#### Expand a list of directories to include all subdirectories.
#### Args:
- `directories` (list): The list of directories.
#### Returns:
- `list`: The expanded list of directories.
"""
dirs = set()
for x in directories:
dirs.add(x)
for root, subdir, file in os.walk(x, followlinks=True):
dirs.add(root)
return list(dirs)
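# Expanding the directory list up front lets load_embed() resolve embeddings
# stored in nested subfolders (e.g. a hypothetical "embeddings/negatives/foo.pt")
# even when the prompt only references "embedding:foo".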
def load_embed(embedding_name: str, embedding_directory: list, embedding_size: int, embed_key: str = None) -> torch.Tensor:
"""#### Load an embedding from a directory.
#### Args:
- `embedding_name` (str): The name of the embedding.
- `embedding_directory` (list): The list of directories to search.
- `embedding_size` (int): The size of the embedding.
- `embed_key` (str, optional): The key for the embedding. Defaults to None.
#### Returns:
    - `torch.Tensor`: The loaded embedding, or `None` if it cannot be found or loaded.
"""
if isinstance(embedding_directory, str):
embedding_directory = [embedding_directory]
embedding_directory = expand_directory_list(embedding_directory)
valid_file = None
for embed_dir in embedding_directory:
embed_path = os.path.abspath(os.path.join(embed_dir, embedding_name))
embed_dir = os.path.abspath(embed_dir)
try:
if os.path.commonpath((embed_dir, embed_path)) != embed_dir:
continue
        except ValueError:
            # commonpath() raises ValueError for mixed path types or paths on
            # different drives; skip such directories
            continue
if not os.path.isfile(embed_path):
extensions = [".safetensors", ".pt", ".bin"]
for x in extensions:
t = embed_path + x
if os.path.isfile(t):
valid_file = t
break
else:
valid_file = embed_path
if valid_file is not None:
break
if valid_file is None:
return None
embed_path = valid_file
embed_out = None
try:
if embed_path.lower().endswith(".safetensors"):
import safetensors.torch
embed = safetensors.torch.load_file(embed_path, device="cpu")
else:
if "weights_only" in torch.load.__code__.co_varnames:
embed = torch.load(embed_path, weights_only=True, map_location="cpu")
else:
embed = torch.load(embed_path, map_location="cpu")
except Exception:
logging.warning(
"{}\n\nerror loading embedding, skipping loading: {}".format(
traceback.format_exc(), embedding_name
)
)
return None
if embed_out is None:
if "string_to_param" in embed:
values = embed["string_to_param"].values()
embed_out = next(iter(values))
elif isinstance(embed, list):
out_list = []
for x in range(len(embed)):
for k in embed[x]:
t = embed[x][k]
if t.shape[-1] != embedding_size:
continue
out_list.append(t.reshape(-1, t.shape[-1]))
embed_out = torch.cat(out_list, dim=0)
elif embed_key is not None and embed_key in embed:
embed_out = embed[embed_key]
else:
values = embed.values()
embed_out = next(iter(values))
return embed_out
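# load_embed() accepts .safetensors, .pt and .bin files, including
# textual-inversion checkpoints that store their vectors under
# "string_to_param". Minimal usage sketch (the path and file name are
# illustrative, not part of this repo):
#
#   vec = load_embed("my_style", ["models/embeddings"], 768, "clip_l")
#   if vec is not None:
#       print(vec.shape)  # e.g. torch.Size([num_vectors, 768])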
class SDTokenizer:
"""#### Class representing a Stable Diffusion tokenizer."""
def __init__(
self,
tokenizer_path: str = None,
max_length: int = 77,
pad_with_end: bool = True,
embedding_directory: str = None,
embedding_size: int = 768,
embedding_key: str = "clip_l",
tokenizer_class: type = CLIPTokenizerFast,
has_start_token: bool = True,
pad_to_max_length: bool = True,
min_length: int = None,
):
"""#### Initialize the SDTokenizer.
#### Args:
- `tokenizer_path` (str, optional): The path to the tokenizer. Defaults to None.
- `max_length` (int, optional): The maximum length of the input. Defaults to 77.
- `pad_with_end` (bool, optional): Whether to pad with the end token. Defaults to True.
- `embedding_directory` (str, optional): The directory for embeddings. Defaults to None.
- `embedding_size` (int, optional): The size of the embeddings. Defaults to 768.
- `embedding_key` (str, optional): The key for the embeddings. Defaults to "clip_l".
        - `tokenizer_class` (type, optional): The tokenizer class. Defaults to CLIPTokenizerFast.
- `has_start_token` (bool, optional): Whether the tokenizer has a start token. Defaults to True.
- `pad_to_max_length` (bool, optional): Whether to pad to the maximum length. Defaults to True.
- `min_length` (int, optional): The minimum length of the input. Defaults to None.
"""
if tokenizer_path is None:
tokenizer_path = "_internal/sd1_tokenizer/"
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
self.max_length = max_length
self.min_length = min_length
empty = self.tokenizer("")["input_ids"]
if has_start_token:
self.tokens_start = 1
self.start_token = empty[0]
self.end_token = empty[1]
else:
self.tokens_start = 0
self.start_token = None
self.end_token = empty[0]
self.pad_with_end = pad_with_end
self.pad_to_max_length = pad_to_max_length
vocab = self.tokenizer.get_vocab()
self.inv_vocab = {v: k for k, v in vocab.items()}
self.embedding_directory = embedding_directory
self.max_word_length = 8
self.embedding_identifier = "embedding:"
self.embedding_size = embedding_size
self.embedding_key = embedding_key
def _try_get_embedding(self, embedding_name: str) -> tuple:
"""#### Try to get an embedding.
#### Args:
- `embedding_name` (str): The name of the embedding.
#### Returns:
- `tuple`: The embedding and any leftover text.
"""
embed = load_embed(
embedding_name,
self.embedding_directory,
self.embedding_size,
self.embedding_key,
)
if embed is None:
stripped = embedding_name.strip(",")
if len(stripped) < len(embedding_name):
embed = load_embed(
stripped,
self.embedding_directory,
self.embedding_size,
self.embedding_key,
)
return (embed, embedding_name[len(stripped) :])
return (embed, "")
def tokenize_with_weights(self, text: str, return_word_ids: bool = False) -> list:
"""#### Tokenize text with weights.
#### Args:
- `text` (str): The input text.
- `return_word_ids` (bool, optional): Whether to return word IDs. Defaults to False.
#### Returns:
- `list`: The tokenized text with weights.
"""
if self.pad_with_end:
pad_token = self.end_token
else:
pad_token = 0
text = escape_important(text)
parsed_weights = token_weights(text, 1.0)
# tokenize words
tokens = []
for weighted_segment, weight in parsed_weights:
to_tokenize = (
unescape_important(weighted_segment).replace("\n", " ").split(" ")
)
to_tokenize = [x for x in to_tokenize if x != ""]
for word in to_tokenize:
# if we find an embedding, deal with the embedding
if (
word.startswith(self.embedding_identifier)
and self.embedding_directory is not None
):
embedding_name = word[len(self.embedding_identifier) :].strip("\n")
embed, leftover = self._try_get_embedding(embedding_name)
if embed is None:
logging.warning(
f"warning, embedding:{embedding_name} does not exist, ignoring"
)
else:
if len(embed.shape) == 1:
tokens.append([(embed, weight)])
else:
tokens.append(
[(embed[x], weight) for x in range(embed.shape[0])]
)
                        logging.info(f"loading embedding: {embedding_name}")
# if we accidentally have leftover text, continue parsing using leftover, else move on to next word
if leftover != "":
word = leftover
else:
continue
# parse word
tokens.append(
[
(t, weight)
for t in self.tokenizer(word)["input_ids"][
self.tokens_start : -1
]
]
)
# reshape token array to CLIP input size
batched_tokens = []
batch = []
if self.start_token is not None:
batch.append((self.start_token, 1.0, 0))
batched_tokens.append(batch)
for i, t_group in enumerate(tokens):
# determine if we're going to try and keep the tokens in a single batch
is_large = len(t_group) >= self.max_word_length
while len(t_group) > 0:
if len(t_group) + len(batch) > self.max_length - 1:
remaining_length = self.max_length - len(batch) - 1
# break word in two and add end token
if is_large:
batch.extend(
[(t, w, i + 1) for t, w in t_group[:remaining_length]]
)
batch.append((self.end_token, 1.0, 0))
t_group = t_group[remaining_length:]
# add end token and pad
else:
batch.append((self.end_token, 1.0, 0))
if self.pad_to_max_length:
batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
# start new batch
batch = []
if self.start_token is not None:
batch.append((self.start_token, 1.0, 0))
batched_tokens.append(batch)
else:
batch.extend([(t, w, i + 1) for t, w in t_group])
t_group = []
# fill last batch
batch.append((self.end_token, 1.0, 0))
if self.pad_to_max_length:
batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
if self.min_length is not None and len(batch) < self.min_length:
batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch)))
if not return_word_ids:
batched_tokens = [[(t, w) for t, w, _ in x] for x in batched_tokens]
return batched_tokens
def untokenize(self, token_weight_pair: list) -> list:
"""#### Untokenize a list of token-weight pairs.
#### Args:
- `token_weight_pair` (list): The list of token-weight pairs.
#### Returns:
- `list`: The untokenized list.
"""
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
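# SDTokenizer.tokenize_with_weights() returns a list of fixed-length batches,
# each a list of (token_id_or_embedding, weight) pairs (or triples including a
# word index when return_word_ids=True), padded with the end/pad token to the
# CLIP context length of 77; prompts that do not fit spill into extra batches.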
class SD1Tokenizer:
"""#### Class representing the SD1Tokenizer."""
def __init__(self, embedding_directory: str = None, clip_name: str = "l", tokenizer: type = SDTokenizer):
"""#### Initialize the SD1Tokenizer.
#### Args:
- `embedding_directory` (str, optional): The directory for embeddings. Defaults to None.
- `clip_name` (str, optional): The name of the CLIP model. Defaults to "l".
- `tokenizer` (type, optional): The tokenizer class. Defaults to SDTokenizer.
"""
self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name)
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory))
def tokenize_with_weights(self, text: str, return_word_ids: bool = False) -> dict:
"""#### Tokenize text with weights.
#### Args:
- `text` (str): The input text.
- `return_word_ids` (bool, optional): Whether to return word IDs. Defaults to False.
#### Returns:
- `dict`: The tokenized text with weights.
"""
out = {}
out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(
text, return_word_ids
)
return out
def untokenize(self, token_weight_pair: list) -> list:
"""#### Untokenize a list of token-weight pairs.
#### Args:
- `token_weight_pair` (list): The list of token-weight pairs.
#### Returns:
- `list`: The untokenized list.
"""
return getattr(self, self.clip).untokenize(token_weight_pair)
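# Minimal usage sketch (assumes the CLIP tokenizer files are present at the
# default "_internal/sd1_tokenizer/" path and that any referenced embeddings
# live in a hypothetical "models/embeddings" folder):
if __name__ == "__main__":
    tokenizer = SD1Tokenizer(embedding_directory="models/embeddings")
    tokens = tokenizer.tokenize_with_weights("a (cozy:1.3) cabin in the woods")
    # tokens["l"] is a list of 77-entry batches of (token_id, weight) pairs
    for batch in tokens["l"]:
        print(len(batch), batch[:5])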