import logging
import os
import traceback

import torch
from transformers import CLIPTokenizerFast


def model_options_long_clip(sd, tokenizer_data, model_options):
    """#### Look up the CLIP position-embedding weight in a checkpoint state dict.

    The weight is read under either the prefixed or the unprefixed key (presumably to
    detect long-CLIP variants); the tokenizer data and model options are returned unchanged.

    #### Args:
        - `sd` (dict): The checkpoint state dict.
        - `tokenizer_data` (dict): The tokenizer data.
        - `model_options` (dict): The model options.

    #### Returns:
        - `tuple`: The (unchanged) tokenizer data and model options.
    """
    w = sd.get("clip_l.text_model.embeddings.position_embedding.weight", None)
    if w is None:
        w = sd.get("text_model.embeddings.position_embedding.weight", None)
    return tokenizer_data, model_options


def parse_parentheses(string: str) -> list:
    """#### Parse a string with nested parentheses.

    #### Args:
        - `string` (str): The input string.

    #### Returns:
        - `list`: The parsed list of strings.
    """
    result = []
    current_item = ""
    nesting_level = 0
    for char in string:
        if char == "(":
            if nesting_level == 0:
                if current_item:
                    result.append(current_item)
                    current_item = "("
                else:
                    current_item = "("
            else:
                current_item += char
            nesting_level += 1
        elif char == ")":
            nesting_level -= 1
            if nesting_level == 0:
                result.append(current_item + ")")
                current_item = ""
            else:
                current_item += char
        else:
            current_item += char
    if current_item:
        result.append(current_item)
    return result
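
# Illustrative sketch, not part of the original module: parse_parentheses splits a
# prompt into top-level segments while keeping each outermost parenthesised group
# intact, e.g. parse_parentheses("a (b (c)) d") -> ["a ", "(b (c))", " d"].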


def token_weights(string: str, current_weight: float) -> list:
    """#### Parse a string into tokens with weights.

    #### Args:
        - `string` (str): The input string.
        - `current_weight` (float): The current weight.

    #### Returns:
        - `list`: The list of token-weight pairs.
    """
    a = parse_parentheses(string)
    out = []
    for x in a:
        weight = current_weight
        if len(x) >= 2 and x[-1] == ")" and x[0] == "(":
            x = x[1:-1]
            xx = x.rfind(":")
            weight *= 1.1
            if xx > 0:
                try:
                    weight = float(x[xx + 1 :])
                    x = x[:xx]
                except ValueError:
                    pass
            out += token_weights(x, weight)
        else:
            out += [(x, current_weight)]
    return out
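
# Illustrative sketch, not part of the original module: each parenthesised group
# multiplies the weight by 1.1 unless an explicit ":<weight>" suffix overrides it, e.g.
# token_weights("a (b) (c:1.5)", 1.0) -> [("a ", 1.0), ("b", 1.1), (" ", 1.0), ("c", 1.5)].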


def escape_important(text: str) -> str:
    """#### Replace escaped parentheses with sentinel characters.

    Escaped parentheses ("\\)" and "\\(") are swapped for control-character
    sentinels so the weight parser treats them as literal text.

    #### Args:
        - `text` (str): The input text.

    #### Returns:
        - `str`: The escaped text.
    """
    text = text.replace("\\)", "\0\1")
    text = text.replace("\\(", "\0\2")
    return text


def unescape_important(text: str) -> str:
    """#### Restore literal parentheses from their sentinel characters.

    #### Args:
        - `text` (str): The input text.

    #### Returns:
        - `str`: The unescaped text.
    """
    text = text.replace("\0\1", ")")
    text = text.replace("\0\2", "(")
    return text
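
# Illustrative sketch, not part of the original module: the escape/unescape pair lets a
# prompt contain literal parentheses, e.g. escape_important(r"a \(cat\)") replaces the
# escaped parentheses with "\x00\x02" / "\x00\x01", and unescape_important restores them.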


def expand_directory_list(directories: list) -> list:
    """#### Expand a list of directories to include all subdirectories.

    #### Args:
        - `directories` (list): The list of directories.

    #### Returns:
        - `list`: The expanded list of directories.
    """
    dirs = set()
    for x in directories:
        dirs.add(x)
        for root, _subdirs, _files in os.walk(x, followlinks=True):
            dirs.add(root)
    return list(dirs)


def load_embed(
    embedding_name: str,
    embedding_directory: list,
    embedding_size: int,
    embed_key: str = None,
) -> torch.Tensor:
    """#### Load an embedding from a directory.

    #### Args:
        - `embedding_name` (str): The name of the embedding.
        - `embedding_directory` (list): The list of directories to search.
        - `embedding_size` (int): The size of the embedding.
        - `embed_key` (str, optional): The key for the embedding. Defaults to None.

    #### Returns:
        - `torch.Tensor`: The loaded embedding, or None if it cannot be found or loaded.
    """
    if isinstance(embedding_directory, str):
        embedding_directory = [embedding_directory]

    embedding_directory = expand_directory_list(embedding_directory)

    valid_file = None
    for embed_dir in embedding_directory:
        embed_path = os.path.abspath(os.path.join(embed_dir, embedding_name))
        embed_dir = os.path.abspath(embed_dir)
        try:
            # Reject paths that escape the embedding directory (e.g. via "..").
            if os.path.commonpath((embed_dir, embed_path)) != embed_dir:
                continue
        except ValueError:
            continue
        if not os.path.isfile(embed_path):
            extensions = [".safetensors", ".pt", ".bin"]
            for x in extensions:
                t = embed_path + x
                if os.path.isfile(t):
                    valid_file = t
                    break
        else:
            valid_file = embed_path
        if valid_file is not None:
            break

    if valid_file is None:
        return None

    embed_path = valid_file
    try:
        if embed_path.lower().endswith(".safetensors"):
            import safetensors.torch

            embed = safetensors.torch.load_file(embed_path, device="cpu")
        else:
            if "weights_only" in torch.load.__code__.co_varnames:
                embed = torch.load(embed_path, weights_only=True, map_location="cpu")
            else:
                embed = torch.load(embed_path, map_location="cpu")
    except Exception:
        logging.warning(
            "{}\n\nerror loading embedding, skipping loading: {}".format(
                traceback.format_exc(), embedding_name
            )
        )
        return None

    if "string_to_param" in embed:
        values = embed["string_to_param"].values()
        embed_out = next(iter(values))
    elif isinstance(embed, list):
        out_list = []
        for entry in embed:
            for k in entry:
                t = entry[k]
                if t.shape[-1] != embedding_size:
                    continue
                out_list.append(t.reshape(-1, t.shape[-1]))
        embed_out = torch.cat(out_list, dim=0)
    elif embed_key is not None and embed_key in embed:
        embed_out = embed[embed_key]
    else:
        values = embed.values()
        embed_out = next(iter(values))
    return embed_out
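
# Illustrative usage sketch, not part of the original module; the folder name and
# embedding file are hypothetical.
# emb = load_embed("my_style", ["_internal/embeddings/"], 768, "clip_l")
# if emb is not None:
#     print(emb.shape)  # e.g. torch.Size([num_vectors, 768]) for a CLIP-L embedding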


class SDTokenizer:
    """#### Class representing a Stable Diffusion tokenizer."""

    def __init__(
        self,
        tokenizer_path: str = None,
        max_length: int = 77,
        pad_with_end: bool = True,
        embedding_directory: str = None,
        embedding_size: int = 768,
        embedding_key: str = "clip_l",
        tokenizer_class: type = CLIPTokenizerFast,
        has_start_token: bool = True,
        pad_to_max_length: bool = True,
        min_length: int = None,
    ):
        """#### Initialize the SDTokenizer.

        #### Args:
            - `tokenizer_path` (str, optional): The path to the tokenizer. Defaults to None.
            - `max_length` (int, optional): The maximum length of the input. Defaults to 77.
            - `pad_with_end` (bool, optional): Whether to pad with the end token. Defaults to True.
            - `embedding_directory` (str, optional): The directory for embeddings. Defaults to None.
            - `embedding_size` (int, optional): The size of the embeddings. Defaults to 768.
            - `embedding_key` (str, optional): The key for the embeddings. Defaults to "clip_l".
            - `tokenizer_class` (type, optional): The tokenizer class. Defaults to CLIPTokenizerFast.
            - `has_start_token` (bool, optional): Whether the tokenizer has a start token. Defaults to True.
            - `pad_to_max_length` (bool, optional): Whether to pad to the maximum length. Defaults to True.
            - `min_length` (int, optional): The minimum length of the input. Defaults to None.
        """
        if tokenizer_path is None:
            tokenizer_path = "_internal/sd1_tokenizer/"
        self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
        self.max_length = max_length
        self.min_length = min_length

        # Tokenizing an empty string exposes the special start/end token ids.
        empty = self.tokenizer("")["input_ids"]
        if has_start_token:
            self.tokens_start = 1
            self.start_token = empty[0]
            self.end_token = empty[1]
        else:
            self.tokens_start = 0
            self.start_token = None
            self.end_token = empty[0]
        self.pad_with_end = pad_with_end
        self.pad_to_max_length = pad_to_max_length

        vocab = self.tokenizer.get_vocab()
        self.inv_vocab = {v: k for k, v in vocab.items()}
        self.embedding_directory = embedding_directory
        self.max_word_length = 8
        self.embedding_identifier = "embedding:"
        self.embedding_size = embedding_size
        self.embedding_key = embedding_key

    def _try_get_embedding(self, embedding_name: str) -> tuple:
        """#### Try to get an embedding.

        #### Args:
            - `embedding_name` (str): The name of the embedding.

        #### Returns:
            - `tuple`: The embedding and any leftover text.
        """
        embed = load_embed(
            embedding_name,
            self.embedding_directory,
            self.embedding_size,
            self.embedding_key,
        )
        if embed is None:
            # Retry with commas stripped from the name; any trimmed characters are
            # handed back as leftover text so the caller can keep tokenizing them.
            stripped = embedding_name.strip(",")
            if len(stripped) < len(embedding_name):
                embed = load_embed(
                    stripped,
                    self.embedding_directory,
                    self.embedding_size,
                    self.embedding_key,
                )
                return (embed, embedding_name[len(stripped) :])
        return (embed, "")

    def tokenize_with_weights(self, text: str, return_word_ids: bool = False) -> list:
        """#### Tokenize text with weights.

        #### Args:
            - `text` (str): The input text.
            - `return_word_ids` (bool, optional): Whether to return word IDs. Defaults to False.

        #### Returns:
            - `list`: The tokenized text with weights.
        """
        if self.pad_with_end:
            pad_token = self.end_token
        else:
            pad_token = 0

        text = escape_important(text)
        parsed_weights = token_weights(text, 1.0)

        # tokenize words
        tokens = []
        for weighted_segment, weight in parsed_weights:
            to_tokenize = (
                unescape_important(weighted_segment).replace("\n", " ").split(" ")
            )
            to_tokenize = [x for x in to_tokenize if x != ""]
            for word in to_tokenize:
                # if we find an embedding, deal with the embedding
                if (
                    word.startswith(self.embedding_identifier)
                    and self.embedding_directory is not None
                ):
                    embedding_name = word[len(self.embedding_identifier) :].strip("\n")
                    embed, leftover = self._try_get_embedding(embedding_name)
                    if embed is None:
                        logging.warning(
                            f"warning, embedding:{embedding_name} does not exist, ignoring"
                        )
                    else:
                        if len(embed.shape) == 1:
                            tokens.append([(embed, weight)])
                        else:
                            tokens.append(
                                [(embed[x], weight) for x in range(embed.shape[0])]
                            )
                        print("loading ", embedding_name)
                    # if we accidentally have leftover text, continue parsing using leftover, else move on to next word
                    if leftover != "":
                        word = leftover
                    else:
                        continue
                # parse word
                tokens.append(
                    [
                        (t, weight)
                        for t in self.tokenizer(word)["input_ids"][
                            self.tokens_start : -1
                        ]
                    ]
                )

        # reshape token array to CLIP input size
        batched_tokens = []
        batch = []
        if self.start_token is not None:
            batch.append((self.start_token, 1.0, 0))
        batched_tokens.append(batch)
        for i, t_group in enumerate(tokens):
            # determine if we're going to try and keep the tokens in a single batch
            is_large = len(t_group) >= self.max_word_length

            while len(t_group) > 0:
                if len(t_group) + len(batch) > self.max_length - 1:
                    remaining_length = self.max_length - len(batch) - 1
                    # break word in two and add end token
                    if is_large:
                        batch.extend(
                            [(t, w, i + 1) for t, w in t_group[:remaining_length]]
                        )
                        batch.append((self.end_token, 1.0, 0))
                        t_group = t_group[remaining_length:]
                    # add end token and pad
                    else:
                        batch.append((self.end_token, 1.0, 0))
                        if self.pad_to_max_length:
                            batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
                    # start new batch
                    batch = []
                    if self.start_token is not None:
                        batch.append((self.start_token, 1.0, 0))
                    batched_tokens.append(batch)
                else:
                    batch.extend([(t, w, i + 1) for t, w in t_group])
                    t_group = []

        # fill last batch
        batch.append((self.end_token, 1.0, 0))
        if self.pad_to_max_length:
            batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
        if self.min_length is not None and len(batch) < self.min_length:
            batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch)))

        if not return_word_ids:
            batched_tokens = [[(t, w) for t, w, _ in x] for x in batched_tokens]

        return batched_tokens

    def untokenize(self, token_weight_pair: list) -> list:
        """#### Untokenize a list of token-weight pairs.

        #### Args:
            - `token_weight_pair` (list): The list of token-weight pairs.

        #### Returns:
            - `list`: The untokenized list.
        """
        return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))


class SD1Tokenizer:
    """#### Class representing the SD1Tokenizer."""

    def __init__(
        self,
        embedding_directory: str = None,
        clip_name: str = "l",
        tokenizer: type = SDTokenizer,
    ):
        """#### Initialize the SD1Tokenizer.

        #### Args:
            - `embedding_directory` (str, optional): The directory for embeddings. Defaults to None.
            - `clip_name` (str, optional): The name of the CLIP model. Defaults to "l".
            - `tokenizer` (type, optional): The tokenizer class. Defaults to SDTokenizer.
        """
        self.clip_name = clip_name
        self.clip = "clip_{}".format(self.clip_name)
        setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory))

    def tokenize_with_weights(self, text: str, return_word_ids: bool = False) -> dict:
        """#### Tokenize text with weights.

        #### Args:
            - `text` (str): The input text.
            - `return_word_ids` (bool, optional): Whether to return word IDs. Defaults to False.

        #### Returns:
            - `dict`: The tokenized text with weights.
        """
        out = {}
        out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(
            text, return_word_ids
        )
        return out

    def untokenize(self, token_weight_pair: list) -> list:
        """#### Untokenize a list of token-weight pairs.

        #### Args:
            - `token_weight_pair` (list): The list of token-weight pairs.

        #### Returns:
            - `list`: The untokenized list.
        """
        return getattr(self, self.clip).untokenize(token_weight_pair)
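

# Illustrative usage sketch, not part of the original module. It assumes the default
# tokenizer files exist under "_internal/sd1_tokenizer/" and that "embeddings/" is a
# (hypothetical) local folder of textual-inversion embeddings.
if __name__ == "__main__":
    tokenizer = SD1Tokenizer(embedding_directory="embeddings/")
    out = tokenizer.tokenize_with_weights("a photo of a (cat:1.3) on a mat")
    # `out["l"]` is a list of batches; each batch holds max_length (token, weight)
    # pairs, wrapped with the start/end tokens and padded to the full length.
    print(len(out["l"]), len(out["l"][0]))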