import logging
from pathlib import Path
from typing import Optional, Union
import torch
from diffusers import DiffusionPipeline
from safetensors.torch import load_file
from torch import Tensor
from animatediff import get_dir
EMBED_DIR = get_dir("data").joinpath("embeddings")
EMBED_DIR_SDXL = get_dir("data").joinpath("sdxl_embeddings")
EMBED_EXTS = [".pt", ".pth", ".bin", ".safetensors"]
logger = logging.getLogger(__name__)
def scan_text_embeddings(is_sdxl: bool = False) -> list[Path]:
    """Return all embedding files found under the embeddings data directory.

    Args:
        is_sdxl: scan the SDXL embeddings directory instead of the SD1.x one.

    Returns:
        List of paths to files whose extension is one of EMBED_EXTS.
    """
    embed_dir = EMBED_DIR_SDXL if is_sdxl else EMBED_DIR
    # fix: rglob() is already recursive, so rglob("**/*") double-recursed
    # the tree; rglob("*") visits every file exactly once
    return [x for x in embed_dir.rglob("*") if x.is_file() and x.suffix.lower() in EMBED_EXTS]
def get_text_embeddings(return_tensors: bool = True, is_sdxl: bool = False) -> dict[str, Union[Tensor, Path]]:
    """Collect available textual-inversion embeddings, keyed by token (file stem).

    Args:
        return_tensors: if True, load and return the embedding tensors;
            otherwise return the embedding file paths.
        is_sdxl: scan the SDXL embeddings directory instead of the SD1.x one.

    Returns:
        Mapping of token -> Tensor (when return_tensors) or token -> Path.
        Files with duplicate tokens and embeddings that fail to load are
        skipped (with a warning).
    """
    embed_dir = EMBED_DIR_SDXL if is_sdxl else EMBED_DIR
    embeds: dict[str, Path] = {}
    skipped: dict[str, Path] = {}
    for path in scan_text_embeddings(is_sdxl):
        if path.stem not in embeds:
            # new token/name, add it
            logger.debug(f"Found embedding token {path.stem} at {path.relative_to(embed_dir)}")
            embeds[path.stem] = path
        else:
            # duplicate token/name, skip it
            skipped[path.stem] = path
            logger.debug(f"Duplicate embedding token {path.stem} at {path.relative_to(embed_dir)}")
    # warn the user if there are duplicates we skipped
    if skipped:
        # logger.warning: Logger.warn is deprecated
        logger.warning(f"Skipped {len(skipped)} embeddings with duplicate tokens!")
        logger.warning(f"Skipped paths: {[x.relative_to(embed_dir) for x in skipped.values()]}")
        logger.warning("Rename these files to avoid collisions!")
    # we can optionally return the tensors instead of the paths
    if return_tensors:
        # load the embeddings
        loaded = {k: load_embed_weights(v) for k, v in embeds.items()}
        # fix: filter out failed loads AND return the filtered dict — the
        # original computed the filtered dict but returned the one that still
        # contained None values, leaking failed loads to callers
        embeds = {k: v for k, v in loaded.items() if v is not None}
        if len(embeds) != len(loaded):
            logger.warning(f"Failed to load {len(loaded) - len(embeds)} embeddings!")
            logger.warning(f"Skipped embeddings: {[x for x in loaded.keys() if x not in embeds]}")
    # return a dict of {token: path | embedding}
    return embeds
def load_embed_weights(path: Path, key: Optional[str] = None) -> Optional[Tensor]:
    """Load an embedding tensor from a file.

    Accepts an optional key to load a specific embedding from a file with multiple
    embeddings, otherwise it will try to load the first one it finds.

    Args:
        path: embedding file (one of EMBED_EXTS: .pt/.pth/.bin/.safetensors).
        key: optional state-dict key selecting a specific embedding.

    Returns:
        The embedding Tensor, or None if the file could not be loaded or no
        Tensor was found in it.

    Raises:
        ValueError: if path does not exist, is not a file, or has an
            unsupported extension.
    """
    # fix: the original guard `not path.exists() and path.is_file()` could
    # never be true (a non-existent path is never a file), so bad paths fell
    # through; is_file() alone covers both conditions
    if not path.is_file():
        raise ValueError(f"Embedding path {path} does not exist or is not a file!")
    suffix = path.suffix.lower()
    # fix: an unrecognised extension previously left state_dict unbound and
    # raised NameError below; fail early with a clear error instead
    if suffix not in EMBED_EXTS:
        raise ValueError(f"Embedding path {path} has an unsupported extension!")
    try:
        if suffix == ".safetensors":
            state_dict = load_file(path, device="cpu")
        else:
            state_dict = torch.load(path, weights_only=True, map_location="cpu")
    except Exception:
        logger.error(f"Failed to load embedding {path}", exc_info=True)
        return None
    embedding = None
    if len(state_dict) == 1:
        logger.debug(f"Found single key in {path.stem}, using it")
        embedding = next(iter(state_dict.values()))
    elif key is not None and key in state_dict:
        logger.debug(f"Using passed key {key} for {path.stem}")
        embedding = state_dict[key]
    elif "string_to_param" in state_dict:
        logger.debug(f"A1111 style embedding found for {path.stem}")
        embedding = next(iter(state_dict["string_to_param"].values()))
    else:
        # we couldn't find the embedding key, warn the user and just use the first key that's a Tensor
        logger.warning(f"Could not find embedding key in {path.stem}!")
        logger.warning("Taking a wild guess and using the first Tensor we find...")
        # fix: distinct loop variable — the original shadowed the `key` parameter
        for sd_key, value in state_dict.items():
            if torch.is_tensor(value):
                embedding = value
                logger.warning(f"Using key: {sd_key}")
                break
    return embedding
def load_text_embeddings(
    pipeline: DiffusionPipeline,
    text_embeds: Optional[dict[str, Path]] = None,
    is_sdxl: bool = False,
) -> None:
    """Load textual-inversion (TI) embeddings into a pipeline's text encoder(s).

    Args:
        pipeline: the diffusers pipeline to load embeddings into.
        text_embeds: optional mapping of token -> embedding file path; when
            None, the embeddings data directory is scanned. (fix: annotation
            corrected — this was never a tuple of (str, Tensor).)
        is_sdxl: load into both SDXL text encoders (keys "clip_g"/"clip_l").

    Tokens already present in the tokenizer vocabulary are skipped; failures
    are logged per-token and summarised at the end.
    """
    if text_embeds is None:
        # return_tensors=False — we want the paths so we can select the
        # per-encoder key ("clip_g"/"clip_l") at load time for SDXL
        text_embeds = get_text_embeddings(False, is_sdxl)
    if len(text_embeds) < 1:
        logger.info("No TI embeddings found")
        return
    logger.info(f"Loading {len(text_embeds)} TI embeddings...")
    loaded, skipped, failed = [], [], []
    # fix: removed the dead `if True:` / `else:` scaffolding — the else
    # branch was unreachable and referenced undefined names
    # (`text_encoder_sd`, `pipe`)
    vocab = pipeline.tokenizer.get_vocab()  # get the tokenizer vocab so we can skip loaded embeddings
    for token, emb_path in text_embeds.items():
        try:
            if token in vocab:
                logger.debug(f"Skipping embedding '{token}' (already loaded)")
                skipped.append(token)
                continue
            if is_sdxl:
                # SDXL carries one tensor per text encoder in the same file
                embed = load_embed_weights(emb_path, "clip_g").to(pipeline.text_encoder_2.device)
                pipeline.load_textual_inversion(
                    embed, token=token, text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2
                )
                embed = load_embed_weights(emb_path, "clip_l").to(pipeline.text_encoder.device)
                pipeline.load_textual_inversion(
                    embed, token=token, text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer
                )
            else:
                embed = load_embed_weights(emb_path).to(pipeline.text_encoder.device)
                pipeline.load_textual_inversion({token: embed})
            logger.debug(f"Loaded embedding '{token}'")
            loaded.append(token)
        except Exception:
            logger.error(f"Failed to load TI embedding: {token}", exc_info=True)
            failed.append(token)
    # Print a summary of what we loaded
    logger.info(f"Loaded {len(loaded)} embeddings, {len(skipped)} existing, {len(failed)} failed")
    logger.info(f"Available embeddings: {', '.join(loaded + skipped)}")
    if failed:
        # only print failed if there were failures (logger.warning: warn is deprecated)
        logger.warning(f"Failed to load embeddings: {', '.join(failed)}")