Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,528 Bytes
d9a2e19 1d117d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import torch
from modules.BlackForest import Flux
from modules.Utilities import util
from modules.Model import ModelBase
from modules.SD15 import SDClip, SDToken
from modules.Utilities import Latent
from modules.clip import Clip
class sm_SD15(ModelBase.BASE):
    """#### Class representing the SD15 model.

    Supplies the SD1.5 UNet configuration, latent format, CLIP state-dict
    key remapping and CLIP target on top of ``ModelBase.BASE``.

    #### Args:
        - `ModelBase.BASE` (ModelBase.BASE): The base model class.
    """

    # UNet configuration values for this model family. How they are consumed
    # is defined by ModelBase.BASE (not visible here). Class-level dicts are
    # shared by every instance, so treat them as read-only.
    unet_config: dict = {
        "context_dim": 768,
        "model_channels": 320,
        "use_linear_in_transformer": False,
        "adm_in_channels": None,
        "use_temporal_attention": False,
    }

    # Additional UNet settings kept separate from unet_config.
    unet_extra_config: dict = {
        "num_heads": 8,
        "num_head_channels": -1,
    }

    # Holds the Latent.SD15 class itself (not an instance); the annotation
    # below is therefore loose.
    latent_format: Latent.SD15 = Latent.SD15

    def process_clip_state_dict(self, state_dict: dict) -> dict:
        """#### Process the state dictionary for the CLIP model.

        Steps, in order:
            1. Keys under ``cond_stage_model.transformer.`` that lack the
               ``text_model.`` level are moved under
               ``cond_stage_model.transformer.text_model.``.
            2. If ``...text_model.embeddings.position_ids`` is float32, its
               values are rounded to whole numbers.
            3. The ``cond_stage_model.`` prefix is swapped for ``clip_l.`` via
               ``util.state_dict_prefix_replace(..., filter_keys=True)``.

        Note: steps 1 and 2 mutate ``state_dict`` in place.

        #### Args:
            - `state_dict` (dict): The state dictionary.

        #### Returns:
            - `dict`: The processed state dictionary.
        """
        transformer_prefix = "cond_stage_model.transformer."
        text_model_prefix = "cond_stage_model.transformer.text_model."
        position_ids_key = text_model_prefix + "embeddings.position_ids"

        # Snapshot the keys first: the loop below renames dict entries.
        for key in list(state_dict):
            if key.startswith(transformer_prefix) and not key.startswith(
                text_model_prefix
            ):
                # Re-root only the leading prefix. str.replace would also
                # rewrite any later occurrence of the substring in the key.
                new_key = text_model_prefix + key[len(transformer_prefix):]
                state_dict[new_key] = state_dict.pop(key)

        if position_ids_key in state_dict:
            ids = state_dict[position_ids_key]
            if ids.dtype == torch.float32:
                # NOTE(review): presumably snaps float position ids that have
                # drifted off integral values in some checkpoints — confirm.
                state_dict[position_ids_key] = ids.round()

        # NOTE(review): filter_keys=True presumably drops keys that do not
        # match the prefix — confirm against util.state_dict_prefix_replace.
        return util.state_dict_prefix_replace(
            state_dict, {"cond_stage_model.": "clip_l."}, filter_keys=True
        )

    def clip_target(self) -> Clip.ClipTarget:
        """#### Get the target CLIP model.

        #### Returns:
            - `Clip.ClipTarget`: Pairs ``SDToken.SD1Tokenizer`` with
              ``SDClip.SD1ClipModel``.
        """
        return Clip.ClipTarget(SDToken.SD1Tokenizer, SDClip.SD1ClipModel)
# Module-level list of the model classes defined/imported here (SD1.5 and
# Flux). NOTE(review): presumably scanned by the loader to pick a model
# class for a checkpoint — confirm against the caller.
models = [
    sm_SD15,
    Flux.Flux,
]