|
import logging |
|
from pathlib import Path |
|
from typing import Optional, Union |
|
|
|
import torch |
|
from diffusers import DiffusionPipeline |
|
from safetensors.torch import load_file |
|
from torch import Tensor |
|
|
|
from animatediff import get_dir |
|
|
|
# Directories searched for textual-inversion embedding files. The SDXL directory
# is the one used by the functions below when they are called with is_sdxl=True.
EMBED_DIR = get_dir("data").joinpath("embeddings")
EMBED_DIR_SDXL = get_dir("data").joinpath("sdxl_embeddings")

# File suffixes recognised as embedding files; callers compare against the
# lower-cased suffix, so entries must be lower-case and include the dot.
EMBED_EXTS = [".pt", ".pth", ".bin", ".safetensors"]

# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
def scan_text_embeddings(is_sdxl=False) -> list[Path]:
    """Return every embedding file found below the embeddings directory.

    Args:
        is_sdxl: When true, search ``EMBED_DIR_SDXL`` instead of ``EMBED_DIR``.

    Returns:
        A sorted list of the files, at any depth, whose lower-cased suffix is
        listed in ``EMBED_EXTS``.
    """
    embed_dir = EMBED_DIR_SDXL if is_sdxl else EMBED_DIR
    # rglob() is already recursive, so "*" is sufficient (the former "**/*" was
    # redundant). Sorting makes the result independent of filesystem enumeration
    # order, so callers that keep the first file per token behave reproducibly.
    return sorted(
        candidate
        for candidate in embed_dir.rglob("*")
        if candidate.is_file() and candidate.suffix.lower() in EMBED_EXTS
    )
|
|
|
|
|
def get_text_embeddings(return_tensors: bool = True, is_sdxl: bool = False) -> dict[str, Union[Tensor, Path]]:
    """Collect the available embeddings, keyed by token.

    The token of an embedding is the stem of its file name. When several files
    share a stem, the first one returned by ``scan_text_embeddings`` is kept and
    the others are reported in a warning.

    Args:
        return_tensors: When true, load each file with ``load_embed_weights`` and
            return the loaded values; when false, return the file paths.
        is_sdxl: When true, use the SDXL embeddings directory.

    Returns:
        A dict mapping token to the loaded embedding, or to the file path when
        ``return_tensors`` is false. Embeddings that fail to load are logged and
        left out of the result.
    """
    embed_dir = EMBED_DIR_SDXL if is_sdxl else EMBED_DIR
    embeds: dict[str, Path] = {}
    # A list, not a dict keyed by token: a token can be duplicated more than
    # once and every ignored file has to be counted and reported.
    skipped: list[Path] = []

    for path in scan_text_embeddings(is_sdxl):
        token = path.stem
        if token in embeds:
            skipped.append(path)
            logger.debug(f"Duplicate embedding token {token} at {path.relative_to(embed_dir)}")
        else:
            logger.debug(f"Found embedding token {token} at {path.relative_to(embed_dir)}")
            embeds[token] = path

    if skipped:
        logger.warning(f"Skipped {len(skipped)} embeddings with duplicate tokens!")
        logger.warning(f"Skipped paths: {[x.relative_to(embed_dir) for x in skipped]}")
        logger.warning("Rename these files to avoid collisions!")

    if not return_tensors:
        return embeds

    weights = {token: load_embed_weights(path) for token, path in embeds.items()}
    loaded_embeds = {token: value for token, value in weights.items() if value is not None}
    if len(loaded_embeds) != len(weights):
        logger.warning(f"Failed to load {len(weights) - len(loaded_embeds)} embeddings!")
        logger.warning(f"Skipped embeddings: {[x for x in weights if x not in loaded_embeds]}")

    # Return only what actually loaded; the warning above already calls the
    # failures "skipped", and None is not a valid value per the return type.
    return loaded_embeds
|
|
|
|
|
def load_embed_weights(path: Path, key: Optional[str] = None) -> Optional[Tensor]:
    """Load an embedding from a file.

    Accepts an optional key to load a specific embedding from a file with
    multiple embeddings, otherwise it will try to load the first one it finds.

    Args:
        path: File to read. ``.safetensors`` files are read with safetensors,
            the other suffixes in ``EMBED_EXTS`` with ``torch.load``.
        key: Name of the entry to use when the file holds more than one.

    Returns:
        The selected embedding, or ``None`` when the file type is unsupported,
        the file cannot be read, or no tensor is found in it.

    Raises:
        ValueError: If ``path`` does not exist or is not a regular file.
    """
    # is_file() is False for missing paths as well as for directories. The old
    # check ("not exists and is_file") could never be true.
    if not path.is_file():
        raise ValueError(f"Embedding path {path} does not exist or is not a file!")

    suffix = path.suffix.lower()
    if suffix not in EMBED_EXTS:
        # Previously this fell through with state_dict unbound and crashed below.
        logger.error(f"Unsupported embedding file type '{path.suffix}' for {path}")
        return None

    try:
        if suffix == ".safetensors":
            state_dict = load_file(path, device="cpu")
        else:
            # weights_only=True restricts unpickling to tensors and plain containers.
            state_dict = torch.load(path, weights_only=True, map_location="cpu")
    except Exception:
        logger.error(f"Failed to load embedding {path}", exc_info=True)
        return None

    embedding = None
    if len(state_dict) == 1:
        logger.debug(f"Found single key in {path.stem}, using it")
        embedding = next(iter(state_dict.values()))
    elif key is not None and key in state_dict:
        logger.debug(f"Using passed key {key} for {path.stem}")
        embedding = state_dict[key]
    elif "string_to_param" in state_dict:
        logger.debug(f"A1111 style embedding found for {path.stem}")
        embedding = next(iter(state_dict["string_to_param"].values()))
    else:
        logger.warning(f"Could not find embedding key in {path.stem}!")
        logger.warning("Taking a wild guess and using the first Tensor we find...")
        # Separate loop names so the `key` parameter is not shadowed.
        for name, value in state_dict.items():
            if torch.is_tensor(value):
                embedding = value
                logger.warning(f"Using key: {name}")
                break

    return embedding
|
|
|
|
|
def _load_embed_or_raise(path: Path, key: Optional[str] = None) -> Tensor:
    """Return the embedding loaded from ``path``, raising if nothing could be loaded."""
    weights = load_embed_weights(path, key)
    if weights is None:
        suffix = f" (key '{key}')" if key is not None else ""
        raise ValueError(f"No embedding could be loaded from {path}{suffix}")
    return weights


def load_text_embeddings(
    pipeline: DiffusionPipeline, text_embeds: Optional[dict[str, Path]] = None, is_sdxl: bool = False
) -> None:
    """Register textual-inversion embeddings on a pipeline.

    Tokens already present in the vocabulary of ``pipeline.tokenizer`` are
    skipped. A failure on one embedding is logged and does not stop the others.

    Args:
        pipeline: Pipeline to load the embeddings into.
        text_embeds: Mapping of token to embedding file path. When ``None``, the
            mapping is obtained from ``get_text_embeddings(False, is_sdxl)``.
        is_sdxl: When true, the ``clip_g`` weights of each file are loaded into
            ``text_encoder_2``/``tokenizer_2`` and the ``clip_l`` weights into
            ``text_encoder``/``tokenizer``.
    """
    if text_embeds is None:
        text_embeds = get_text_embeddings(False, is_sdxl)
    if not text_embeds:
        logger.info("No TI embeddings found")
        return

    logger.info(f"Loading {len(text_embeds)} TI embeddings...")
    loaded, skipped, failed = [], [], []

    vocab = pipeline.tokenizer.get_vocab()
    for token, emb_path in text_embeds.items():
        if token in vocab:
            logger.debug(f"Skipping embedding '{token}' (already loaded)")
            skipped.append(token)
            continue

        try:
            if is_sdxl:
                # Read both weights before touching the pipeline, so an
                # unreadable file fails before either encoder is modified.
                embed_g = _load_embed_or_raise(emb_path, "clip_g")
                embed_l = _load_embed_or_raise(emb_path, "clip_l")
                pipeline.load_textual_inversion(
                    embed_g.to(pipeline.text_encoder_2.device),
                    token=token,
                    text_encoder=pipeline.text_encoder_2,
                    tokenizer=pipeline.tokenizer_2,
                )
                pipeline.load_textual_inversion(
                    embed_l.to(pipeline.text_encoder.device),
                    token=token,
                    text_encoder=pipeline.text_encoder,
                    tokenizer=pipeline.tokenizer,
                )
            else:
                embed = _load_embed_or_raise(emb_path).to(pipeline.text_encoder.device)
                pipeline.load_textual_inversion({token: embed})
        except Exception:
            # Best effort: record the failure and carry on with the next token.
            logger.error(f"Failed to load TI embedding: {token}", exc_info=True)
            failed.append(token)
        else:
            logger.debug(f"Loaded embedding '{token}'")
            loaded.append(token)

    logger.info(f"Loaded {len(loaded)} embeddings, {len(skipped)} existing, {len(failed)} failed")
    logger.info(f"Available embeddings: {', '.join(loaded + skipped)}")
    if failed:
        logger.warning(f"Failed to load embeddings: {', '.join(failed)}")
|
|