import torch
from modules.BlackForest import Flux
from modules.Utilities import util
from modules.Model import ModelBase
from modules.SD15 import SDClip, SDToken
from modules.Utilities import Latent
from modules.clip import Clip


class sm_SD15(ModelBase.BASE):
"""#### Class representing the SD15 model.
#### Args:
- `ModelBase.BASE` (ModelBase.BASE): The base model class.
"""
unet_config: dict = {
"context_dim": 768,
"model_channels": 320,
"use_linear_in_transformer": False,
"adm_in_channels": None,
"use_temporal_attention": False,
}
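    # Attention head layout for SD1.x: a fixed 8 heads, with the head
    # dimension derived from the channel count (num_head_channels == -1).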
unet_extra_config: dict = {
"num_heads": 8,
"num_head_channels": -1,
}
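    # SD1.x latent space (4 channels, factor-8 spatial downscaling).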
latent_format: Latent.SD15 = Latent.SD15

    def process_clip_state_dict(self, state_dict: dict) -> dict:
"""#### Process the state dictionary for the CLIP model.
#### Args:
- `state_dict` (dict): The state dictionary.
#### Returns:
- `dict`: The processed state dictionary.
"""
k = list(state_dict.keys())
for x in k:
if x.startswith("cond_stage_model.transformer.") and not x.startswith(
"cond_stage_model.transformer.text_model."
):
y = x.replace(
"cond_stage_model.transformer.",
"cond_stage_model.transformer.text_model.",
)
state_dict[y] = state_dict.pop(x)
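        # Some checkpoints store position_ids as floats; round them so they map
        # cleanly back to integer token positions when the weights are loaded.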
if (
"cond_stage_model.transformer.text_model.embeddings.position_ids"
in state_dict
):
ids = state_dict[
"cond_stage_model.transformer.text_model.embeddings.position_ids"
]
if ids.dtype == torch.float32:
state_dict[
"cond_stage_model.transformer.text_model.embeddings.position_ids"
] = ids.round()
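        # Rename the "cond_stage_model." prefix to "clip_l." and keep only the
        # keys carrying that prefix (filter_keys=True drops the rest).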
replace_prefix = {}
replace_prefix["cond_stage_model."] = "clip_l."
state_dict = util.state_dict_prefix_replace(
state_dict, replace_prefix, filter_keys=True
)
return state_dict

    def clip_target(self) -> Clip.ClipTarget:
"""#### Get the target CLIP model.
#### Returns:
- `Clip.ClipTarget`: The target CLIP model.
"""
return Clip.ClipTarget(SDToken.SD1Tokenizer, SDClip.SD1ClipModel)


models = [sm_SD15, Flux.Flux]
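

if __name__ == "__main__":
    # Minimal sketch of the CLIP key remapping performed by
    # sm_SD15.process_clip_state_dict on a toy state dict. It exercises only
    # code defined in this file; the instance is created with __new__ because
    # ModelBase.BASE's constructor signature is not visible here (assumption:
    # the method does not rely on instance state, which holds for the
    # implementation above).
    toy_state_dict = {
        "cond_stage_model.transformer.embeddings.position_ids": torch.arange(
            77, dtype=torch.float32
        ).unsqueeze(0),
        "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": torch.zeros(1),
        "first_stage_model.decoder.conv_in.weight": torch.zeros(1),
    }
    model = sm_SD15.__new__(sm_SD15)
    remapped = model.process_clip_state_dict(toy_state_dict)
    # Expected: both CLIP keys now live under "clip_l.transformer.text_model."
    # and the VAE key ("first_stage_model.*") is dropped by filter_keys=True.
    print(sorted(remapped.keys()))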