File size: 7,021 Bytes
d0ffe9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import logging
from pathlib import Path
from typing import Optional, Union

import torch
from diffusers import DiffusionPipeline
from safetensors.torch import load_file
from torch import Tensor

from animatediff import get_dir

EMBED_DIR = get_dir("data").joinpath("embeddings")
EMBED_DIR_SDXL = get_dir("data").joinpath("sdxl_embeddings")
EMBED_EXTS = [".pt", ".pth", ".bin", ".safetensors"]

logger = logging.getLogger(__name__)


def scan_text_embeddings(is_sdxl: bool = False) -> list[Path]:
    """Recursively scan the embeddings directory for TI embedding files.

    Args:
        is_sdxl: If True, scan the SDXL embeddings directory instead of the
            default (SD1.5) one.

    Returns:
        Every file under the directory whose extension is in EMBED_EXTS.
        Returns an empty list if the directory does not exist.
    """
    embed_dir = EMBED_DIR_SDXL if is_sdxl else EMBED_DIR
    # guard: Path.rglob raises FileNotFoundError on a missing directory
    if not embed_dir.is_dir():
        return []
    # rglob is already recursive, so "*" is sufficient ("**/*" was redundant)
    return [p for p in embed_dir.rglob("*") if p.is_file() and p.suffix.lower() in EMBED_EXTS]


def get_text_embeddings(return_tensors: bool = True, is_sdxl: bool = False) -> dict[str, Union[Tensor, Path]]:
    """Collect available TI embeddings, keyed by token (the file stem).

    Duplicate tokens (same stem in more than one file) are skipped with a
    warning; the first file found wins.

    Args:
        return_tensors: If True, load and return the embedding tensors;
            otherwise return the file paths.
        is_sdxl: Scan the SDXL embeddings directory instead of the default one.

    Returns:
        Mapping of token -> Tensor (when return_tensors) or token -> Path.
        When return_tensors is True, embeddings that fail to load are omitted.
    """
    embed_dir = EMBED_DIR_SDXL if is_sdxl else EMBED_DIR
    embeds: dict[str, Path] = {}
    skipped: dict[str, Path] = {}
    for path in scan_text_embeddings(is_sdxl):
        if path.stem not in embeds:
            # new token/name, add it
            logger.debug(f"Found embedding token {path.stem} at {path.relative_to(embed_dir)}")
            embeds[path.stem] = path
        else:
            # duplicate token/name, skip it
            skipped[path.stem] = path
            logger.debug(f"Duplicate embedding token {path.stem} at {path.relative_to(embed_dir)}")

    # warn the user if there are duplicates we skipped
    if skipped:
        logger.warning(f"Skipped {len(skipped)} embeddings with duplicate tokens!")
        logger.warning(f"Skipped paths: {[x.relative_to(embed_dir) for x in skipped.values()]}")
        logger.warning("Rename these files to avoid collisions!")

    # we can optionally return the tensors instead of the paths
    if return_tensors:
        # load the embeddings
        tensors = {k: load_embed_weights(v) for k, v in embeds.items()}
        # filter out the ones that failed to load
        loaded_embeds = {k: v for k, v in tensors.items() if v is not None}
        if len(loaded_embeds) != len(tensors):
            logger.warning(f"Failed to load {len(tensors) - len(loaded_embeds)} embeddings!")
            logger.warning(f"Skipped embeddings: {[k for k in tensors if k not in loaded_embeds]}")
        # BUG FIX: the original returned the unfiltered dict, so failed loads
        # leaked through as None values; return only successful loads.
        return loaded_embeds

    # return a dict of {token: path}
    return embeds


def load_embed_weights(path: Path, key: Optional[str] = None) -> Optional[Tensor]:
    """Load an embedding tensor from a file.

    Accepts an optional key to load a specific embedding from a file with
    multiple embeddings; otherwise tries, in order: a single-entry state dict,
    the passed key, the A1111 `string_to_param` layout, and finally the first
    Tensor value found.

    Args:
        path: Path to a .pt/.pth/.bin/.safetensors embedding file.
        key: Optional state-dict key to select.

    Returns:
        The embedding Tensor, or None if the file failed to load or no Tensor
        could be found in it.

    Raises:
        ValueError: If the path is missing, not a file, or has an unsupported
            extension.
    """
    # BUG FIX: original check was `not path.exists() and path.is_file()`, which
    # can never be true; we want to reject missing paths and non-files.
    if not (path.exists() and path.is_file()):
        raise ValueError(f"Embedding path {path} does not exist or is not a file!")
    suffix = path.suffix.lower()
    try:
        if suffix == ".safetensors":
            state_dict = load_file(path, device="cpu")
        elif suffix in EMBED_EXTS:
            state_dict = torch.load(path, weights_only=True, map_location="cpu")
        else:
            # BUG FIX: previously `state_dict` was left unbound for unknown
            # extensions, causing an UnboundLocalError below.
            raise ValueError(f"Embedding path {path} does not exist or is not a file!")
    except ValueError:
        raise
    except Exception:
        logger.error(f"Failed to load embedding {path}", exc_info=True)
        return None

    embedding = None
    if len(state_dict) == 1:
        logger.debug(f"Found single key in {path.stem}, using it")
        embedding = next(iter(state_dict.values()))
    elif key is not None and key in state_dict:
        logger.debug(f"Using passed key {key} for {path.stem}")
        embedding = state_dict[key]
    elif "string_to_param" in state_dict:
        logger.debug(f"A1111 style embedding found for {path.stem}")
        embedding = next(iter(state_dict["string_to_param"].values()))
    else:
        # we couldn't find the embedding key, warn the user and just use the first key that's a Tensor
        logger.warning(f"Could not find embedding key in {path.stem}!")
        logger.warning("Taking a wild guess and using the first Tensor we find...")
        # fresh names here: the original loop shadowed the `key` parameter
        for name, value in state_dict.items():
            if torch.is_tensor(value):
                embedding = value
                logger.warning(f"Using key: {name}")
                break

    return embedding


def load_text_embeddings(
    pipeline: DiffusionPipeline,
    text_embeds: Optional[dict[str, Union[Tensor, Path]]] = None,
    is_sdxl: bool = False,
) -> None:
    """Load textual-inversion (TI) embeddings into a diffusers pipeline.

    Tokens already present in the tokenizer vocabulary are skipped. A summary
    of loaded/skipped/failed tokens is logged at the end.

    Args:
        pipeline: The pipeline to load embeddings into.
        text_embeds: Optional mapping of token -> embedding path; if None, the
            embeddings directory is scanned. (The original hint said
            `tuple[str, Tensor]`, but the code iterates `.items()` of a dict
            of paths.)
        is_sdxl: If True, load "clip_g"/"clip_l" halves into the pipeline's two
            text encoders.
    """
    if text_embeds is None:
        # fetch paths only (return_tensors=False); tensors are loaded per-token below
        text_embeds = get_text_embeddings(False, is_sdxl)
    if len(text_embeds) < 1:
        logger.info("No TI embeddings found")
        return

    logger.info(f"Loading {len(text_embeds)} TI embeddings...")
    loaded, skipped, failed = [], [], []

    # NOTE: the original wrapped this in `if True:` with an unreachable else
    # branch referencing undefined names (`text_encoder_sd`, `pipe`); the dead
    # branch has been removed.
    vocab = pipeline.tokenizer.get_vocab()  # get the tokenizer vocab so we can skip loaded embeddings
    for token, emb_path in text_embeds.items():
        try:
            if token in vocab:
                logger.debug(f"Skipping embedding '{token}' (already loaded)")
                skipped.append(token)
                continue
            if is_sdxl:
                # SDXL has two text encoders; load the matching half into each
                embed = load_embed_weights(emb_path, "clip_g").to(pipeline.text_encoder_2.device)
                pipeline.load_textual_inversion(embed, token=token, text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
                embed = load_embed_weights(emb_path, "clip_l").to(pipeline.text_encoder.device)
                pipeline.load_textual_inversion(embed, token=token, text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
            else:
                embed = load_embed_weights(emb_path).to(pipeline.text_encoder.device)
                pipeline.load_textual_inversion({token: embed})
            logger.debug(f"Loaded embedding '{token}'")
            loaded.append(token)
        except Exception:
            logger.error(f"Failed to load TI embedding: {token}", exc_info=True)
            failed.append(token)

    # Print a summary of what we loaded
    logger.info(f"Loaded {len(loaded)} embeddings, {len(skipped)} existing, {len(failed)} failed")
    logger.info(f"Available embeddings: {', '.join(loaded + skipped)}")
    if failed:
        # only print failed if there were failures
        logger.warning(f"Failed to load embeddings: {', '.join(failed)}")