import torch

from modules.BlackForest import Flux
from modules.clip import Clip
from modules.Model import ModelBase
from modules.SD15 import SDClip, SDToken
from modules.Utilities import Latent, util


class sm_SD15(ModelBase.BASE):
    """#### Class representing the SD15 model.



    #### Args:

        - `ModelBase.BASE` (ModelBase.BASE): The base model class.

    """

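    # UNet settings that identify an SD1.5 checkpoint: a 768-dim CLIP text
    # context, 320 base channels, and no ADM (class-label) conditioning.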
    unet_config: dict = {
        "context_dim": 768,
        "model_channels": 320,
        "use_linear_in_transformer": False,
        "adm_in_channels": None,
        "use_temporal_attention": False,
    }

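    # Extra attention-head settings applied on top of unet_config;
    # num_head_channels = -1 means head width is derived from num_heads.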
    unet_extra_config: dict = {
        "num_heads": 8,
        "num_head_channels": -1,
    }

    latent_format = Latent.SD15

    def process_clip_state_dict(self, state_dict: dict) -> dict:
        """#### Process the state dictionary for the CLIP model.

        #### Args:
            - `state_dict` (dict): The state dictionary.

        #### Returns:
            - `dict`: The processed state dictionary.
        """
        for key in list(state_dict.keys()):
            if key.startswith("cond_stage_model.transformer.") and not key.startswith(
                "cond_stage_model.transformer.text_model."
            ):
                new_key = key.replace(
                    "cond_stage_model.transformer.",
                    "cond_stage_model.transformer.text_model.",
                )
                state_dict[new_key] = state_dict.pop(key)

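        # position_ids are integer indices, but some checkpoints store them as
        # floats that have drifted off exact integers; rounding restores valid
        # embedding indices before the tensor is loaded.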
        pos_ids_key = (
            "cond_stage_model.transformer.text_model.embeddings.position_ids"
        )
        if pos_ids_key in state_dict:
            ids = state_dict[pos_ids_key]
            if ids.dtype == torch.float32:
                state_dict[pos_ids_key] = ids.round()

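        # Finally, move the CLIP weights under the "clip_l." prefix;
        # filter_keys=True is assumed to drop keys that match no prefix.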
        replace_prefix = {"cond_stage_model.": "clip_l."}
        state_dict = util.state_dict_prefix_replace(
            state_dict, replace_prefix, filter_keys=True
        )
        return state_dict

    def clip_target(self) -> Clip.ClipTarget:
        """#### Get the target CLIP model.

        #### Returns:
            - `Clip.ClipTarget`: The target CLIP tokenizer/model pair.
        """
        return Clip.ClipTarget(SDToken.SD1Tokenizer, SDClip.SD1ClipModel)
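

# Model classes exported by this module; loaders are assumed to probe them
# in order when matching a checkpoint's state dict.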
models = [sm_SD15, Flux.Flux]