diff --git "a/stable_diffusion_model.py" "b/stable_diffusion_model.py" new file mode 100644--- /dev/null +++ "b/stable_diffusion_model.py" @@ -0,0 +1,2745 @@ +import copy +import gc +import json +import random +import shutil +import typing +from typing import Union, List, Literal, Iterator +import sys +import os +from collections import OrderedDict +import copy +import yaml +from PIL import Image +from diffusers.pipelines.pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_1024_BIN, ASPECT_RATIO_512_BIN, \ + ASPECT_RATIO_2048_BIN, ASPECT_RATIO_256_BIN +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg +from safetensors.torch import save_file, load_file +from torch import autocast +from torch.nn import Parameter +from torch.utils.checkpoint import checkpoint +from tqdm import tqdm +from torchvision.transforms import Resize, transforms + +from toolkit.assistant_lora import load_assistant_lora_from_path +from toolkit.clip_vision_adapter import ClipVisionAdapter +from toolkit.custom_adapter import CustomAdapter +from toolkit.ip_adapter import IPAdapter +from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \ + convert_vae_state_dict, load_vae +from toolkit import train_tools +from toolkit.config_modules import ModelConfig, GenerateImageConfig +from toolkit.metadata import get_meta_for_safetensors +from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT +from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds +from toolkit.reference_adapter import ReferenceAdapter +from toolkit.sampler import get_sampler +from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler +from toolkit.saving import save_ldm_model_from_diffusers, get_ldm_state_dict_from_diffusers +from toolkit.sd_device_states_presets import empty_preset +from toolkit.train_tools import get_torch_dtype, apply_noise_offset +from einops import rearrange, repeat +import torch +from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \ + StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline, FluxWithCFGPipeline +from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \ + StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \ + StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \ + StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, \ + StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline, AuraFlowPipeline, AuraFlowTransformer2DModel, FluxPipeline, \ + FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel +import diffusers +from diffusers import \ + AutoencoderKL, \ + UNet2DConditionModel +from diffusers import PixArtAlphaPipeline, DPMSolverMultistepScheduler, PixArtSigmaPipeline +from transformers import T5EncoderModel, BitsAndBytesConfig, UMT5EncoderModel, T5TokenizerFast +from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection + +from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT +from huggingface_hub import hf_hub_download + +from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4 +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from toolkit.lora_special import LoRASpecialNetwork + +# tell it to shut up +diffusers.logging.set_verbosity(diffusers.logging.ERROR) + +SD_PREFIX_VAE = "vae" +SD_PREFIX_UNET = "unet" +SD_PREFIX_REFINER_UNET = "refiner_unet" +SD_PREFIX_TEXT_ENCODER = "te" + +SD_PREFIX_TEXT_ENCODER1 = "te0" +SD_PREFIX_TEXT_ENCODER2 = "te1" + +# prefixed diffusers keys +DO_NOT_TRAIN_WEIGHTS = [ + "unet_time_embedding.linear_1.bias", + "unet_time_embedding.linear_1.weight", + "unet_time_embedding.linear_2.bias", + "unet_time_embedding.linear_2.weight", + "refiner_unet_time_embedding.linear_1.bias", + "refiner_unet_time_embedding.linear_1.weight", + "refiner_unet_time_embedding.linear_2.bias", + "refiner_unet_time_embedding.linear_2.weight", +] + +DeviceStatePreset = Literal['cache_latents', 'generate'] + + +class BlankNetwork: + + def __init__(self): + self.multiplier = 1.0 + self.is_active = True + self.is_merged_in = False + self.can_merge_in = False + + def __enter__(self): + self.is_active = True + + def __exit__(self, exc_type, exc_val, exc_tb): + self.is_active = False + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +UNET_IN_CHANNELS = 4 # Stable Diffusion の in_channels は 4 で固定。XLも同じ。 +# VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8 + + + +class StableDiffusion: + + def __init__( + self, + device, + model_config: ModelConfig, + dtype='fp16', + custom_pipeline=None, + noise_scheduler=None, + quantize_device=None, + ): + self.custom_pipeline = custom_pipeline + self.device = device + self.dtype = dtype + self.torch_dtype = get_torch_dtype(dtype) + self.device_torch = torch.device(self.device) + + self.vae_device_torch = torch.device(self.device) if model_config.vae_device is None else torch.device( + model_config.vae_device) + self.vae_torch_dtype = get_torch_dtype(model_config.vae_dtype) + + self.te_device_torch = torch.device(self.device) if model_config.te_device is None else torch.device( + model_config.te_device) + self.te_torch_dtype = get_torch_dtype(model_config.te_dtype) + + self.model_config = model_config + self.prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon" + + self.device_state = None + + self.pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline', 'PixArtAlphaPipeline'] + self.vae: Union[None, 'AutoencoderKL'] + self.unet: Union[None, 'UNet2DConditionModel'] + self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]] + self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']] + self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler + + self.refiner_unet: Union[None, 'UNet2DConditionModel'] = None + self.assistant_lora: Union[None, 'LoRASpecialNetwork'] = None + + # sdxl stuff + self.logit_scale = None + self.ckppt_info = None + self.is_loaded = False + + # to hold network if there is one + self.network = None + self.adapter: Union['ControlNetModel', 'T2IAdapter', 'IPAdapter', 'ReferenceAdapter', None] = None + self.is_xl = model_config.is_xl + self.is_v2 = model_config.is_v2 + self.is_ssd = model_config.is_ssd + self.is_v3 = model_config.is_v3 + self.is_vega = model_config.is_vega + self.is_pixart = model_config.is_pixart + self.is_auraflow = model_config.is_auraflow + self.is_flux = model_config.is_flux + + self.use_text_encoder_1 = model_config.use_text_encoder_1 + self.use_text_encoder_2 = model_config.use_text_encoder_2 + + self.config_file = None + + self.is_flow_matching = False + if self.is_flux or self.is_v3 or self.is_auraflow or isinstance(self.noise_scheduler, CustomFlowMatchEulerDiscreteScheduler): + self.is_flow_matching = True + + self.quantize_device = quantize_device if quantize_device is not None else self.device + self.low_vram = self.model_config.low_vram + + # merge in and preview active with -1 weight + self.invert_assistant_lora = False + + def load_model(self): + if self.is_loaded: + return + dtype = get_torch_dtype(self.dtype) + + # move the betas alphas and alphas_cumprod to device. Sometimed they get stuck on cpu, not sure why + # self.noise_scheduler.betas = self.noise_scheduler.betas.to(self.device_torch) + # self.noise_scheduler.alphas = self.noise_scheduler.alphas.to(self.device_torch) + # self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device_torch) + + model_path = self.model_config.name_or_path + if 'civitai.com' in self.model_config.name_or_path: + # load is a civit ai model, use the loader. + from toolkit.civitai import get_model_path_from_url + model_path = get_model_path_from_url(self.model_config.name_or_path) + + load_args = {} + if self.noise_scheduler: + load_args['scheduler'] = self.noise_scheduler + + if self.model_config.vae_path is not None: + load_args['vae'] = load_vae(self.model_config.vae_path, dtype) + if self.model_config.is_xl or self.model_config.is_ssd or self.model_config.is_vega: + if self.custom_pipeline is not None: + pipln = self.custom_pipeline + else: + pipln = StableDiffusionXLPipeline + # pipln = StableDiffusionKDiffusionXLPipeline + + # see if path exists + if not os.path.exists(model_path) or os.path.isdir(model_path): + # try to load with default diffusers + pipe = pipln.from_pretrained( + model_path, + dtype=dtype, + device=self.device_torch, + # variant="fp16", + use_safetensors=True, + **load_args + ) + else: + pipe = pipln.from_single_file( + model_path, + device=self.device_torch, + torch_dtype=self.torch_dtype, + ) + + if 'vae' in load_args and load_args['vae'] is not None: + pipe.vae = load_args['vae'] + flush() + + text_encoders = [pipe.text_encoder, pipe.text_encoder_2] + tokenizer = [pipe.tokenizer, pipe.tokenizer_2] + for text_encoder in text_encoders: + text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype) + text_encoder.requires_grad_(False) + text_encoder.eval() + text_encoder = text_encoders + + pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype) + + if self.model_config.experimental_xl: + print("Experimental XL mode enabled") + print("Loading and injecting alt weights") + # load the mismatched weight and force it in + raw_state_dict = load_file(model_path) + replacement_weight = raw_state_dict['conditioner.embedders.1.model.text_projection'].clone() + del raw_state_dict + # get state dict for for 2nd text encoder + te1_state_dict = text_encoders[1].state_dict() + # replace weight with mismatched weight + te1_state_dict['text_projection.weight'] = replacement_weight.to(self.device_torch, dtype=dtype) + flush() + print("Injecting alt weights") + elif self.model_config.is_v3: + if self.custom_pipeline is not None: + pipln = self.custom_pipeline + else: + pipln = StableDiffusion3Pipeline + + print("Loading SD3 model") + # assume it is the large model + base_model_path = "stabilityai/stable-diffusion-3.5-large" + print("Loading transformer") + subfolder = 'transformer' + transformer_path = model_path + # check if HF_DATASETS_OFFLINE or TRANSFORMERS_OFFLINE is set + if os.path.exists(transformer_path): + subfolder = None + transformer_path = os.path.join(transformer_path, 'transformer') + # check if the path is a full checkpoint. + te_folder_path = os.path.join(model_path, 'text_encoder') + # if we have the te, this folder is a full checkpoint, use it as the base + if os.path.exists(te_folder_path): + base_model_path = model_path + else: + # is remote use whatever path we were given + base_model_path = model_path + + transformer = SD3Transformer2DModel.from_pretrained( + transformer_path, + subfolder=subfolder, + torch_dtype=dtype, + ) + if not self.low_vram: + # for low v ram, we leave it on the cpu. Quantizes slower, but allows training on primary gpu + transformer.to(torch.device(self.quantize_device), dtype=dtype) + flush() + + if self.model_config.lora_path is not None: + raise ValueError("LoRA is not supported for SD3 models currently") + + if self.model_config.quantize: + quantization_type = qfloat8 + print("Quantizing transformer") + quantize(transformer, weights=quantization_type) + freeze(transformer) + transformer.to(self.device_torch) + else: + transformer.to(self.device_torch, dtype=dtype) + + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler") + print("Loading vae") + vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype) + flush() + + print("Loading t5") + tokenizer_3 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_3", torch_dtype=dtype) + text_encoder_3 = T5EncoderModel.from_pretrained( + base_model_path, + subfolder="text_encoder_3", + torch_dtype=dtype + ) + + text_encoder_3.to(self.device_torch, dtype=dtype) + flush() + + if self.model_config.quantize: + print("Quantizing T5") + quantize(text_encoder_3, weights=qfloat8) + freeze(text_encoder_3) + flush() + + + # see if path exists + if not os.path.exists(model_path) or os.path.isdir(model_path): + try: + # try to load with default diffusers + pipe = pipln.from_pretrained( + base_model_path, + dtype=dtype, + device=self.device_torch, + tokenizer_3=tokenizer_3, + text_encoder_3=text_encoder_3, + transformer=transformer, + # variant="fp16", + use_safetensors=True, + repo_type="model", + ignore_patterns=["*.md", "*..gitattributes"], + **load_args + ) + except Exception as e: + print(f"Error loading from pretrained: {e}") + raise e + + else: + pipe = pipln.from_single_file( + model_path, + transformer=transformer, + device=self.device_torch, + torch_dtype=self.torch_dtype, + tokenizer_3=tokenizer_3, + text_encoder_3=text_encoder_3, + **load_args + ) + + flush() + + text_encoders = [pipe.text_encoder, pipe.text_encoder_2, pipe.text_encoder_3] + tokenizer = [pipe.tokenizer, pipe.tokenizer_2, pipe.tokenizer_3] + # replace the to function with a no-op since it throws an error instead of a warning + # text_encoders[2].to = lambda *args, **kwargs: None + for text_encoder in text_encoders: + text_encoder.to(self.device_torch, dtype=dtype) + text_encoder.requires_grad_(False) + text_encoder.eval() + text_encoder = text_encoders + + + elif self.model_config.is_pixart: + te_kwargs = {} + # handle quantization of TE + te_is_quantized = False + if self.model_config.text_encoder_bits == 8: + te_kwargs['load_in_8bit'] = True + te_kwargs['device_map'] = "auto" + te_is_quantized = True + elif self.model_config.text_encoder_bits == 4: + te_kwargs['load_in_4bit'] = True + te_kwargs['device_map'] = "auto" + te_is_quantized = True + + main_model_path = "PixArt-alpha/PixArt-XL-2-1024-MS" + if self.model_config.is_pixart_sigma: + main_model_path = "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers" + + main_model_path = model_path + + # load the TE in 8bit mode + text_encoder = T5EncoderModel.from_pretrained( + main_model_path, + subfolder="text_encoder", + torch_dtype=self.torch_dtype, + **te_kwargs + ) + + # load the transformer + subfolder = "transformer" + # check if it is just the unet + if os.path.exists(model_path) and not os.path.exists(os.path.join(model_path, subfolder)): + subfolder = None + + if te_is_quantized: + # replace the to function with a no-op since it throws an error instead of a warning + text_encoder.to = lambda *args, **kwargs: None + + text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype) + + if self.model_config.is_pixart_sigma: + # load the transformer only from the save + transformer = Transformer2DModel.from_pretrained( + model_path if self.model_config.unet_path is None else self.model_config.unet_path, + torch_dtype=self.torch_dtype, + subfolder='transformer' + ) + pipe: PixArtSigmaPipeline = PixArtSigmaPipeline.from_pretrained( + main_model_path, + transformer=transformer, + text_encoder=text_encoder, + dtype=dtype, + device=self.device_torch, + **load_args + ) + + else: + + # load the transformer only from the save + transformer = Transformer2DModel.from_pretrained(model_path, torch_dtype=self.torch_dtype, + subfolder=subfolder) + pipe: PixArtAlphaPipeline = PixArtAlphaPipeline.from_pretrained( + main_model_path, + transformer=transformer, + text_encoder=text_encoder, + dtype=dtype, + device=self.device_torch, + **load_args + ).to(self.device_torch) + + if self.model_config.unet_sample_size is not None: + pipe.transformer.config.sample_size = self.model_config.unet_sample_size + pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype) + + flush() + # text_encoder = pipe.text_encoder + # text_encoder.to(self.device_torch, dtype=dtype) + text_encoder.requires_grad_(False) + text_encoder.eval() + pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype) + tokenizer = pipe.tokenizer + + pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype) + if self.noise_scheduler is None: + self.noise_scheduler = pipe.scheduler + + + elif self.model_config.is_auraflow: + te_kwargs = {} + # handle quantization of TE + te_is_quantized = False + if self.model_config.text_encoder_bits == 8: + te_kwargs['load_in_8bit'] = True + te_kwargs['device_map'] = "auto" + te_is_quantized = True + elif self.model_config.text_encoder_bits == 4: + te_kwargs['load_in_4bit'] = True + te_kwargs['device_map'] = "auto" + te_is_quantized = True + + main_model_path = model_path + + # load the TE in 8bit mode + text_encoder = UMT5EncoderModel.from_pretrained( + main_model_path, + subfolder="text_encoder", + torch_dtype=self.torch_dtype, + **te_kwargs + ) + + # load the transformer + subfolder = "transformer" + # check if it is just the unet + if os.path.exists(model_path) and not os.path.exists(os.path.join(model_path, subfolder)): + subfolder = None + + if te_is_quantized: + # replace the to function with a no-op since it throws an error instead of a warning + text_encoder.to = lambda *args, **kwargs: None + + # load the transformer only from the save + transformer = AuraFlowTransformer2DModel.from_pretrained( + model_path if self.model_config.unet_path is None else self.model_config.unet_path, + torch_dtype=self.torch_dtype, + subfolder='transformer' + ) + pipe: AuraFlowPipeline = AuraFlowPipeline.from_pretrained( + main_model_path, + transformer=transformer, + text_encoder=text_encoder, + dtype=dtype, + device=self.device_torch, + **load_args + ) + + pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype) + + # patch auraflow so it can handle other aspect ratios + # patch_auraflow_pos_embed(pipe.transformer.pos_embed) + + flush() + # text_encoder = pipe.text_encoder + # text_encoder.to(self.device_torch, dtype=dtype) + text_encoder.requires_grad_(False) + text_encoder.eval() + pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype) + tokenizer = pipe.tokenizer + + elif self.model_config.is_flux: + print("Loading Flux model") + base_model_path = "black-forest-labs/FLUX.1-schnell" + print("Loading transformer") + subfolder = 'transformer' + transformer_path = model_path + local_files_only = False + # check if HF_DATASETS_OFFLINE or TRANSFORMERS_OFFLINE is set + if os.path.exists(transformer_path): + subfolder = None + transformer_path = os.path.join(transformer_path, 'transformer') + # check if the path is a full checkpoint. + te_folder_path = os.path.join(model_path, 'text_encoder') + # if we have the te, this folder is a full checkpoint, use it as the base + if os.path.exists(te_folder_path): + base_model_path = model_path + + transformer = FluxTransformer2DModel.from_pretrained( + transformer_path, + subfolder=subfolder, + torch_dtype=dtype, + # low_cpu_mem_usage=False, + # device_map=None + ) + if not self.low_vram: + # for low v ram, we leave it on the cpu. Quantizes slower, but allows training on primary gpu + transformer.to(torch.device(self.quantize_device), dtype=dtype) + flush() + + if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None: + if self.model_config.inference_lora_path is not None and self.model_config.assistant_lora_path is not None: + raise ValueError("Cannot load both assistant lora and inference lora at the same time") + + if self.model_config.lora_path: + raise ValueError("Cannot load both assistant lora and lora at the same time") + + if not self.is_flux: + raise ValueError("Assistant/ inference lora is only supported for flux models currently") + + load_lora_path = self.model_config.inference_lora_path + if load_lora_path is None: + load_lora_path = self.model_config.assistant_lora_path + + if os.path.isdir(load_lora_path): + load_lora_path = os.path.join( + load_lora_path, "pytorch_lora_weights.safetensors" + ) + elif not os.path.exists(load_lora_path): + print(f"Grabbing lora from the hub: {load_lora_path}") + new_lora_path = hf_hub_download( + load_lora_path, + filename="pytorch_lora_weights.safetensors" + ) + # replace the path + load_lora_path = new_lora_path + + if self.model_config.inference_lora_path is not None: + self.model_config.inference_lora_path = new_lora_path + if self.model_config.assistant_lora_path is not None: + self.model_config.assistant_lora_path = new_lora_path + + if self.model_config.assistant_lora_path is not None: + # for flux, we assume it is flux schnell. We cannot merge in the assistant lora and unmerge it on + # quantized weights so it had to process unmerged (slow). Since schnell samples in just 4 steps + # it is better to merge it in now, and sample slowly later, otherwise training is slowed in half + # so we will merge in now and sample with -1 weight later + self.invert_assistant_lora = True + # trigger it to get merged in + self.model_config.lora_path = self.model_config.assistant_lora_path + + if self.model_config.lora_path is not None: + print("Fusing in LoRA") + # need the pipe for peft + pipe: FluxPipeline = FluxPipeline( + scheduler=None, + text_encoder=None, + tokenizer=None, + text_encoder_2=None, + tokenizer_2=None, + vae=None, + transformer=transformer, + ) + if self.low_vram: + # we cannot fuse the loras all at once without ooming in lowvram mode, so we have to do it in parts + # we can do it on the cpu but it takes about 5-10 mins vs seconds on the gpu + # we are going to separate it into the two transformer blocks one at a time + + lora_state_dict = load_file(self.model_config.lora_path) + single_transformer_lora = {} + single_block_key = "transformer.single_transformer_blocks." + double_transformer_lora = {} + double_block_key = "transformer.transformer_blocks." + for key, value in lora_state_dict.items(): + if single_block_key in key: + single_transformer_lora[key] = value + elif double_block_key in key: + double_transformer_lora[key] = value + else: + raise ValueError(f"Unknown lora key: {key}. Cannot load this lora in low vram mode") + + # double blocks + transformer.transformer_blocks = transformer.transformer_blocks.to( + torch.device(self.quantize_device), dtype=dtype + ) + pipe.load_lora_weights(double_transformer_lora, adapter_name=f"lora1_double") + pipe.fuse_lora() + pipe.unload_lora_weights() + transformer.transformer_blocks = transformer.transformer_blocks.to( + 'cpu', dtype=dtype + ) + + # single blocks + transformer.single_transformer_blocks = transformer.single_transformer_blocks.to( + torch.device(self.quantize_device), dtype=dtype + ) + pipe.load_lora_weights(single_transformer_lora, adapter_name=f"lora1_single") + pipe.fuse_lora() + pipe.unload_lora_weights() + transformer.single_transformer_blocks = transformer.single_transformer_blocks.to( + 'cpu', dtype=dtype + ) + + # cleanup + del single_transformer_lora + del double_transformer_lora + del lora_state_dict + flush() + + else: + # need the pipe to do this unfortunately for now + # we have to fuse in the weights before quantizing + pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1") + pipe.fuse_lora() + # unfortunately, not an easier way with peft + pipe.unload_lora_weights() + flush() + + if self.model_config.quantize: + quantization_type = qfloat8 + print("Quantizing transformer") + quantize(transformer, weights=quantization_type) + freeze(transformer) + transformer.to(self.device_torch) + else: + transformer.to(self.device_torch, dtype=dtype) + + flush() + + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler") + print("Loading vae") + vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype) + flush() + + print("Loading t5") + tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype) + text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2", + torch_dtype=dtype) + + text_encoder_2.to(self.device_torch, dtype=dtype) + flush() + + print("Quantizing T5") + quantize(text_encoder_2, weights=qfloat8) + freeze(text_encoder_2) + flush() + + print("Loading clip") + text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype) + tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype) + text_encoder.to(self.device_torch, dtype=dtype) + + print("making pipe") + pipe: FluxPipeline = FluxPipeline( + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=None, + tokenizer_2=tokenizer_2, + vae=vae, + transformer=None, + ) + pipe.text_encoder_2 = text_encoder_2 + pipe.transformer = transformer + + print("preparing") + + text_encoder = [pipe.text_encoder, pipe.text_encoder_2] + tokenizer = [pipe.tokenizer, pipe.tokenizer_2] + + pipe.transformer = pipe.transformer.to(self.device_torch) + + flush() + text_encoder[0].to(self.device_torch) + text_encoder[0].requires_grad_(False) + text_encoder[0].eval() + text_encoder[1].to(self.device_torch) + text_encoder[1].requires_grad_(False) + text_encoder[1].eval() + pipe.transformer = pipe.transformer.to(self.device_torch) + flush() + else: + if self.custom_pipeline is not None: + pipln = self.custom_pipeline + else: + pipln = StableDiffusionPipeline + + if self.model_config.text_encoder_bits < 16: + # this is only supported for T5 models for now + te_kwargs = {} + # handle quantization of TE + te_is_quantized = False + if self.model_config.text_encoder_bits == 8: + te_kwargs['load_in_8bit'] = True + te_kwargs['device_map'] = "auto" + te_is_quantized = True + elif self.model_config.text_encoder_bits == 4: + te_kwargs['load_in_4bit'] = True + te_kwargs['device_map'] = "auto" + te_is_quantized = True + + text_encoder = T5EncoderModel.from_pretrained( + model_path, + subfolder="text_encoder", + torch_dtype=self.te_torch_dtype, + **te_kwargs + ) + # replace the to function with a no-op since it throws an error instead of a warning + text_encoder.to = lambda *args, **kwargs: None + + load_args['text_encoder'] = text_encoder + + # see if path exists + if not os.path.exists(model_path) or os.path.isdir(model_path): + # try to load with default diffusers + pipe = pipln.from_pretrained( + model_path, + dtype=dtype, + device=self.device_torch, + load_safety_checker=False, + requires_safety_checker=False, + safety_checker=None, + # variant="fp16", + trust_remote_code=True, + **load_args + ) + else: + pipe = pipln.from_single_file( + model_path, + dtype=dtype, + device=self.device_torch, + load_safety_checker=False, + requires_safety_checker=False, + torch_dtype=self.torch_dtype, + safety_checker=None, + trust_remote_code=True, + **load_args + ) + flush() + + pipe.register_to_config(requires_safety_checker=False) + text_encoder = pipe.text_encoder + text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype) + text_encoder.requires_grad_(False) + text_encoder.eval() + tokenizer = pipe.tokenizer + + # scheduler doesn't get set sometimes, so we set it here + pipe.scheduler = self.noise_scheduler + + # add hacks to unet to help training + # pipe.unet = prepare_unet_for_training(pipe.unet) + + if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux: + # pixart and sd3 dont use a unet + self.unet = pipe.transformer + else: + self.unet: 'UNet2DConditionModel' = pipe.unet + self.vae: 'AutoencoderKL' = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype) + self.vae.eval() + self.vae.requires_grad_(False) + VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1) + self.vae_scale_factor = VAE_SCALE_FACTOR + self.unet.to(self.device_torch, dtype=dtype) + self.unet.requires_grad_(False) + self.unet.eval() + + # load any loras we have + if self.model_config.lora_path is not None and not self.is_flux: + pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1") + pipe.fuse_lora() + # unfortunately, not an easier way with peft + pipe.unload_lora_weights() + + self.tokenizer = tokenizer + self.text_encoder = text_encoder + self.pipeline = pipe + self.load_refiner() + self.is_loaded = True + + if self.model_config.assistant_lora_path is not None: + print("Loading assistant lora") + self.assistant_lora: 'LoRASpecialNetwork' = load_assistant_lora_from_path( + self.model_config.assistant_lora_path, self) + + if self.invert_assistant_lora: + # invert and disable during training + self.assistant_lora.multiplier = -1.0 + self.assistant_lora.is_active = False + + if self.model_config.inference_lora_path is not None: + print("Loading inference lora") + self.assistant_lora: 'LoRASpecialNetwork' = load_assistant_lora_from_path( + self.model_config.inference_lora_path, self) + # disable during training + self.assistant_lora.is_active = False + + if self.is_pixart and self.vae_scale_factor == 16: + # TODO make our own pipeline? + # we generate an image 2x larger, so we need to copy the sizes from larger ones down + # ASPECT_RATIO_1024_BIN, ASPECT_RATIO_512_BIN, ASPECT_RATIO_2048_BIN, ASPECT_RATIO_256_BIN + for key in ASPECT_RATIO_256_BIN.keys(): + ASPECT_RATIO_256_BIN[key] = [ASPECT_RATIO_256_BIN[key][0] * 2, ASPECT_RATIO_256_BIN[key][1] * 2] + for key in ASPECT_RATIO_512_BIN.keys(): + ASPECT_RATIO_512_BIN[key] = [ASPECT_RATIO_512_BIN[key][0] * 2, ASPECT_RATIO_512_BIN[key][1] * 2] + for key in ASPECT_RATIO_1024_BIN.keys(): + ASPECT_RATIO_1024_BIN[key] = [ASPECT_RATIO_1024_BIN[key][0] * 2, ASPECT_RATIO_1024_BIN[key][1] * 2] + for key in ASPECT_RATIO_2048_BIN.keys(): + ASPECT_RATIO_2048_BIN[key] = [ASPECT_RATIO_2048_BIN[key][0] * 2, ASPECT_RATIO_2048_BIN[key][1] * 2] + + def te_train(self): + if isinstance(self.text_encoder, list): + for te in self.text_encoder: + te.train() + else: + self.text_encoder.train() + + def te_eval(self): + if isinstance(self.text_encoder, list): + for te in self.text_encoder: + te.eval() + else: + self.text_encoder.eval() + + def load_refiner(self): + # for now, we are just going to rely on the TE from the base model + # which is TE2 for SDXL and TE for SD (no refiner currently) + # and completely ignore a TE that may or may not be packaged with the refiner + if self.model_config.refiner_name_or_path is not None: + refiner_config_path = os.path.join(ORIG_CONFIGS_ROOT, 'sd_xl_refiner.yaml') + # load the refiner model + dtype = get_torch_dtype(self.dtype) + model_path = self.model_config.refiner_name_or_path + if not os.path.exists(model_path) or os.path.isdir(model_path): + # TODO only load unet?? + refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained( + model_path, + dtype=dtype, + device=self.device_torch, + # variant="fp16", + use_safetensors=True, + ).to(self.device_torch) + else: + refiner = StableDiffusionXLImg2ImgPipeline.from_single_file( + model_path, + dtype=dtype, + device=self.device_torch, + torch_dtype=self.torch_dtype, + original_config_file=refiner_config_path, + ).to(self.device_torch) + + self.refiner_unet = refiner.unet + del refiner + flush() + + @torch.no_grad() + def generate_images( + self, + image_configs: List[GenerateImageConfig], + sampler=None, + pipeline: Union[None, StableDiffusionPipeline, StableDiffusionXLPipeline] = None, + ): + merge_multiplier = 1.0 + flush() + # if using assistant, unfuse it + if self.model_config.assistant_lora_path is not None: + print("Unloading assistant lora") + if self.invert_assistant_lora: + self.assistant_lora.is_active = True + # move weights on to the device + self.assistant_lora.force_to(self.device_torch, self.torch_dtype) + else: + self.assistant_lora.is_active = False + + if self.model_config.inference_lora_path is not None: + print("Loading inference lora") + self.assistant_lora.is_active = True + # move weights on to the device + self.assistant_lora.force_to(self.device_torch, self.torch_dtype) + + if self.network is not None: + self.network.eval() + network = self.network + # check if we have the same network weight for all samples. If we do, we can merge in th + # the network to drastically speed up inference + unique_network_weights = set([x.network_multiplier for x in image_configs]) + if len(unique_network_weights) == 1 and self.network.can_merge_in: + can_merge_in = True + merge_multiplier = unique_network_weights.pop() + network.merge_in(merge_weight=merge_multiplier) + else: + network = BlankNetwork() + + self.save_device_state() + self.set_device_state_preset('generate') + + # save current seed state for training + rng_state = torch.get_rng_state() + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + + if pipeline is None: + noise_scheduler = self.noise_scheduler + if sampler is not None: + if sampler.startswith("sample_"): # sample_dpmpp_2m + # using ksampler + noise_scheduler = get_sampler( + 'lms', { + "prediction_type": self.prediction_type, + }) + else: + noise_scheduler = get_sampler( + sampler, + { + "prediction_type": self.prediction_type, + }, + 'sd' if not self.is_pixart else 'pixart' + ) + + try: + noise_scheduler = noise_scheduler.to(self.device_torch, self.torch_dtype) + except: + pass + + if sampler.startswith("sample_") and self.is_xl: + # using kdiffusion + Pipe = StableDiffusionKDiffusionXLPipeline + elif self.is_xl: + Pipe = StableDiffusionXLPipeline + elif self.is_v3: + Pipe = StableDiffusion3Pipeline + else: + Pipe = StableDiffusionPipeline + + extra_args = {} + if self.adapter is not None: + if isinstance(self.adapter, T2IAdapter): + if self.is_xl: + Pipe = StableDiffusionXLAdapterPipeline + else: + Pipe = StableDiffusionAdapterPipeline + extra_args['adapter'] = self.adapter + elif isinstance(self.adapter, ControlNetModel): + if self.is_xl: + Pipe = StableDiffusionXLControlNetPipeline + else: + Pipe = StableDiffusionControlNetPipeline + extra_args['controlnet'] = self.adapter + elif isinstance(self.adapter, ReferenceAdapter): + # pass the noise scheduler to the adapter + self.adapter.noise_scheduler = noise_scheduler + else: + if self.is_xl: + extra_args['add_watermarker'] = False + + # TODO add clip skip + if self.is_xl: + pipeline = Pipe( + vae=self.vae, + unet=self.unet, + text_encoder=self.text_encoder[0], + text_encoder_2=self.text_encoder[1], + tokenizer=self.tokenizer[0], + tokenizer_2=self.tokenizer[1], + scheduler=noise_scheduler, + **extra_args + ).to(self.device_torch) + pipeline.watermark = None + elif self.is_flux: + if self.model_config.use_flux_cfg: + pipeline = FluxWithCFGPipeline( + vae=self.vae, + transformer=self.unet, + text_encoder=self.text_encoder[0], + text_encoder_2=self.text_encoder[1], + tokenizer=self.tokenizer[0], + tokenizer_2=self.tokenizer[1], + scheduler=noise_scheduler, + **extra_args + ) + + else: + pipeline = FluxPipeline( + vae=self.vae, + transformer=self.unet, + text_encoder=self.text_encoder[0], + text_encoder_2=self.text_encoder[1], + tokenizer=self.tokenizer[0], + tokenizer_2=self.tokenizer[1], + scheduler=noise_scheduler, + **extra_args + ) + pipeline.watermark = None + elif self.is_v3: + pipeline = Pipe( + vae=self.vae, + transformer=self.unet, + text_encoder=self.text_encoder[0], + text_encoder_2=self.text_encoder[1], + text_encoder_3=self.text_encoder[2], + tokenizer=self.tokenizer[0], + tokenizer_2=self.tokenizer[1], + tokenizer_3=self.tokenizer[2], + scheduler=noise_scheduler, + **extra_args + ) + elif self.is_pixart: + pipeline = PixArtSigmaPipeline( + vae=self.vae, + transformer=self.unet, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + scheduler=noise_scheduler, + **extra_args + ) + + elif self.is_auraflow: + pipeline = AuraFlowPipeline( + vae=self.vae, + transformer=self.unet, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + scheduler=noise_scheduler, + **extra_args + ) + + else: + pipeline = Pipe( + vae=self.vae, + unet=self.unet, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + scheduler=noise_scheduler, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + **extra_args + ) + flush() + # disable progress bar + pipeline.set_progress_bar_config(disable=True) + + if sampler.startswith("sample_"): + pipeline.set_scheduler(sampler) + + refiner_pipeline = None + if self.refiner_unet: + # build refiner pipeline + refiner_pipeline = StableDiffusionXLImg2ImgPipeline( + vae=pipeline.vae, + unet=self.refiner_unet, + text_encoder=None, + text_encoder_2=pipeline.text_encoder_2, + tokenizer=None, + tokenizer_2=pipeline.tokenizer_2, + scheduler=pipeline.scheduler, + add_watermarker=False, + requires_aesthetics_score=True, + ).to(self.device_torch) + # refiner_pipeline.register_to_config(requires_aesthetics_score=False) + refiner_pipeline.watermark = None + refiner_pipeline.set_progress_bar_config(disable=True) + flush() + + start_multiplier = 1.0 + if self.network is not None: + start_multiplier = self.network.multiplier + + # pipeline.to(self.device_torch) + + with network: + with torch.no_grad(): + if self.network is not None: + assert self.network.is_active + + for i in tqdm(range(len(image_configs)), desc=f"Generating Images", leave=False): + gen_config = image_configs[i] + + extra = {} + validation_image = None + if self.adapter is not None and gen_config.adapter_image_path is not None: + validation_image = Image.open(gen_config.adapter_image_path).convert("RGB") + if isinstance(self.adapter, T2IAdapter): + # not sure why this is double?? + validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2)) + extra['image'] = validation_image + extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale + if isinstance(self.adapter, ControlNetModel): + validation_image = validation_image.resize((gen_config.width, gen_config.height)) + extra['image'] = validation_image + extra['controlnet_conditioning_scale'] = gen_config.adapter_conditioning_scale + if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter): + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + validation_image = transform(validation_image) + if isinstance(self.adapter, CustomAdapter): + # todo allow loading multiple + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + validation_image = transform(validation_image) + self.adapter.num_images = 1 + if isinstance(self.adapter, ReferenceAdapter): + # need -1 to 1 + validation_image = transforms.ToTensor()(validation_image) + validation_image = validation_image * 2.0 - 1.0 + validation_image = validation_image.unsqueeze(0) + self.adapter.set_reference_images(validation_image) + + if self.network is not None: + self.network.multiplier = gen_config.network_multiplier + torch.manual_seed(gen_config.seed) + torch.cuda.manual_seed(gen_config.seed) + + if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter) \ + and gen_config.adapter_image_path is not None: + # run through the adapter to saturate the embeds + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image) + self.adapter(conditional_clip_embeds) + + if self.adapter is not None and isinstance(self.adapter, CustomAdapter): + # handle condition the prompts + gen_config.prompt = self.adapter.condition_prompt( + gen_config.prompt, + is_unconditional=False, + ) + gen_config.prompt_2 = gen_config.prompt + gen_config.negative_prompt = self.adapter.condition_prompt( + gen_config.negative_prompt, + is_unconditional=True, + ) + gen_config.negative_prompt_2 = gen_config.negative_prompt + + if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and validation_image is not None: + self.adapter.trigger_pre_te( + tensors_0_1=validation_image, + is_training=False, + has_been_preprocessed=False, + quad_count=4 + ) + + # encode the prompt ourselves so we can do fun stuff with embeddings + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True) + + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = True + unconditional_embeds = self.encode_prompt( + gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True + ) + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + + # allow any manipulations to take place to embeddings + gen_config.post_process_embeddings( + conditional_embeds, + unconditional_embeds, + ) + + if self.adapter is not None and isinstance(self.adapter, IPAdapter) \ + and gen_config.adapter_image_path is not None: + # apply the image projection + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image) + unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image, + True) + conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds, is_unconditional=False) + unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds, is_unconditional=True) + + if self.adapter is not None and isinstance(self.adapter, + CustomAdapter) and validation_image is not None: + conditional_embeds = self.adapter.condition_encoded_embeds( + tensors_0_1=validation_image, + prompt_embeds=conditional_embeds, + is_training=False, + has_been_preprocessed=False, + is_generating_samples=True, + ) + unconditional_embeds = self.adapter.condition_encoded_embeds( + tensors_0_1=validation_image, + prompt_embeds=unconditional_embeds, + is_training=False, + has_been_preprocessed=False, + is_unconditional=True, + is_generating_samples=True, + ) + + if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and len( + gen_config.extra_values) > 0: + extra_values = torch.tensor([gen_config.extra_values], device=self.device_torch, + dtype=self.torch_dtype) + # apply extra values to the embeddings + self.adapter.add_extra_values(extra_values, is_unconditional=False) + self.adapter.add_extra_values(torch.zeros_like(extra_values), is_unconditional=True) + pass # todo remove, for debugging + + if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0: + # if we have a refiner loaded, set the denoising end at the refiner start + extra['denoising_end'] = gen_config.refiner_start_at + extra['output_type'] = 'latent' + if not self.is_xl: + raise ValueError("Refiner is only supported for XL models") + + conditional_embeds = conditional_embeds.to(self.device_torch, dtype=self.unet.dtype) + unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=self.unet.dtype) + + if self.is_xl: + # fix guidance rescale for sdxl + # was trained on 0.7 (I believe) + + grs = gen_config.guidance_rescale + # if grs is None or grs < 0.00001: + # grs = 0.7 + # grs = 0.0 + + if sampler.startswith("sample_"): + extra['use_karras_sigmas'] = True + extra = { + **extra, + **gen_config.extra_kwargs, + } + + img = pipeline( + # prompt=gen_config.prompt, + # prompt_2=gen_config.prompt_2, + prompt_embeds=conditional_embeds.text_embeds, + pooled_prompt_embeds=conditional_embeds.pooled_embeds, + negative_prompt_embeds=unconditional_embeds.text_embeds, + negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds, + # negative_prompt=gen_config.negative_prompt, + # negative_prompt_2=gen_config.negative_prompt_2, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + guidance_rescale=grs, + latents=gen_config.latents, + **extra + ).images[0] + elif self.is_v3: + img = pipeline( + prompt_embeds=conditional_embeds.text_embeds, + pooled_prompt_embeds=conditional_embeds.pooled_embeds, + negative_prompt_embeds=unconditional_embeds.text_embeds, + negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + **extra + ).images[0] + elif self.is_flux: + if self.model_config.use_flux_cfg: + img = pipeline( + prompt_embeds=conditional_embeds.text_embeds, + pooled_prompt_embeds=conditional_embeds.pooled_embeds, + negative_prompt_embeds=unconditional_embeds.text_embeds, + negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + **extra + ).images[0] + else: + img = pipeline( + prompt_embeds=conditional_embeds.text_embeds, + pooled_prompt_embeds=conditional_embeds.pooled_embeds, + # negative_prompt_embeds=unconditional_embeds.text_embeds, + # negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + **extra + ).images[0] + elif self.is_pixart: + # needs attention masks for some reason + img = pipeline( + prompt=None, + prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype), + prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch, + dtype=self.unet.dtype), + negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch, + dtype=self.unet.dtype), + negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch, + dtype=self.unet.dtype), + negative_prompt=None, + # negative_prompt=gen_config.negative_prompt, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + **extra + ).images[0] + elif self.is_auraflow: + pipeline: AuraFlowPipeline = pipeline + + img = pipeline( + prompt=None, + prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype), + prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch, + dtype=self.unet.dtype), + negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch, + dtype=self.unet.dtype), + negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch, + dtype=self.unet.dtype), + negative_prompt=None, + # negative_prompt=gen_config.negative_prompt, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + **extra + ).images[0] + else: + img = pipeline( + # prompt=gen_config.prompt, + prompt_embeds=conditional_embeds.text_embeds, + negative_prompt_embeds=unconditional_embeds.text_embeds, + # negative_prompt=gen_config.negative_prompt, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + **extra + ).images[0] + + if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0: + # slide off just the last 1280 on the last dim as refiner does not use first text encoder + # todo, should we just use the Text encoder for the refiner? Fine tuned versions will differ + refiner_text_embeds = conditional_embeds.text_embeds[:, :, -1280:] + refiner_unconditional_text_embeds = unconditional_embeds.text_embeds[:, :, -1280:] + # run through refiner + img = refiner_pipeline( + # prompt=gen_config.prompt, + # prompt_2=gen_config.prompt_2, + + # slice these as it does not use both text encoders + # height=gen_config.height, + # width=gen_config.width, + prompt_embeds=refiner_text_embeds, + pooled_prompt_embeds=conditional_embeds.pooled_embeds, + negative_prompt_embeds=refiner_unconditional_text_embeds, + negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + guidance_rescale=grs, + denoising_start=gen_config.refiner_start_at, + denoising_end=gen_config.num_inference_steps, + image=img.unsqueeze(0) + ).images[0] + + gen_config.save_image(img, i) + gen_config.log_image(img, i) + + if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter): + self.adapter.clear_memory() + + # clear pipeline and cache to reduce vram usage + del pipeline + if refiner_pipeline is not None: + del refiner_pipeline + torch.cuda.empty_cache() + + # restore training state + torch.set_rng_state(rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + + self.restore_device_state() + if self.network is not None: + self.network.train() + self.network.multiplier = start_multiplier + + self.unet.to(self.device_torch, dtype=self.torch_dtype) + if network.is_merged_in: + network.merge_out(merge_multiplier) + # self.tokenizer.to(original_device_dict['tokenizer']) + + # refuse loras + if self.model_config.assistant_lora_path is not None: + print("Loading assistant lora") + if self.invert_assistant_lora: + self.assistant_lora.is_active = False + # move weights off the device + self.assistant_lora.force_to('cpu', self.torch_dtype) + else: + self.assistant_lora.is_active = True + + if self.model_config.inference_lora_path is not None: + print("Unloading inference lora") + self.assistant_lora.is_active = False + # move weights off the device + self.assistant_lora.force_to('cpu', self.torch_dtype) + + flush() + + def get_latent_noise( + self, + height=None, + width=None, + pixel_height=None, + pixel_width=None, + batch_size=1, + noise_offset=0.0, + ): + VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1) + if height is None and pixel_height is None: + raise ValueError("height or pixel_height must be specified") + if width is None and pixel_width is None: + raise ValueError("width or pixel_width must be specified") + if height is None: + height = pixel_height // VAE_SCALE_FACTOR + if width is None: + width = pixel_width // VAE_SCALE_FACTOR + + num_channels = self.unet.config['in_channels'] + if self.is_flux: + # has 64 channels in for some reason + num_channels = 16 + noise = torch.randn( + ( + batch_size, + num_channels, + height, + width, + ), + device=self.unet.device, + ) + noise = apply_noise_offset(noise, noise_offset) + return noise + + def get_time_ids_from_latents(self, latents: torch.Tensor, requires_aesthetic_score=False): + VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1) + if self.is_xl: + bs, ch, h, w = list(latents.shape) + + height = h * VAE_SCALE_FACTOR + width = w * VAE_SCALE_FACTOR + + dtype = latents.dtype + # just do it without any cropping nonsense + target_size = (height, width) + original_size = (height, width) + crops_coords_top_left = (0, 0) + if requires_aesthetic_score: + # refiner + # https://huggingface.co/papers/2307.01952 + aesthetic_score = 6.0 # simulate one + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids]) + add_time_ids = add_time_ids.to(latents.device, dtype=dtype) + + batch_time_ids = torch.cat( + [add_time_ids for _ in range(bs)] + ) + return batch_time_ids + else: + return None + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor + ) -> torch.FloatTensor: + original_samples_chunks = torch.chunk(original_samples, original_samples.shape[0], dim=0) + noise_chunks = torch.chunk(noise, noise.shape[0], dim=0) + timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0) + + if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks): + timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks) + + noisy_latents_chunks = [] + + for idx in range(original_samples.shape[0]): + noisy_latents = self.noise_scheduler.add_noise(original_samples_chunks[idx], noise_chunks[idx], + timesteps_chunks[idx]) + noisy_latents_chunks.append(noisy_latents) + + noisy_latents = torch.cat(noisy_latents_chunks, dim=0) + return noisy_latents + + def predict_noise( + self, + latents: torch.Tensor, + text_embeddings: Union[PromptEmbeds, None] = None, + timestep: Union[int, torch.Tensor] = 1, + guidance_scale=7.5, + guidance_rescale=0, + add_time_ids=None, + conditional_embeddings: Union[PromptEmbeds, None] = None, + unconditional_embeddings: Union[PromptEmbeds, None] = None, + is_input_scaled=False, + detach_unconditional=False, + rescale_cfg=None, + return_conditional_pred=False, + guidance_embedding_scale=1.0, + **kwargs, + ): + conditional_pred = None + # get the embeddings + if text_embeddings is None and conditional_embeddings is None: + raise ValueError("Either text_embeddings or conditional_embeddings must be specified") + if text_embeddings is None and unconditional_embeddings is not None: + text_embeddings = concat_prompt_embeds([ + unconditional_embeddings, # negative embedding + conditional_embeddings, # positive embedding + ]) + elif text_embeddings is None and conditional_embeddings is not None: + # not doing cfg + text_embeddings = conditional_embeddings + + # CFG is comparing neg and positive, if we have concatenated embeddings + # then we are doing it, otherwise we are not and takes half the time. + do_classifier_free_guidance = True + + # check if batch size of embeddings matches batch size of latents + if latents.shape[0] == text_embeddings.text_embeds.shape[0]: + do_classifier_free_guidance = False + elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]: + raise ValueError("Batch size of latents must be the same or half the batch size of text embeddings") + latents = latents.to(self.device_torch) + text_embeddings = text_embeddings.to(self.device_torch) + timestep = timestep.to(self.device_torch) + + # if timestep is zero dim, unsqueeze it + if len(timestep.shape) == 0: + timestep = timestep.unsqueeze(0) + + # if we only have 1 timestep, we can just use the same timestep for all + if timestep.shape[0] == 1 and latents.shape[0] > 1: + # check if it is rank 1 or 2 + if len(timestep.shape) == 1: + timestep = timestep.repeat(latents.shape[0]) + else: + timestep = timestep.repeat(latents.shape[0], 0) + + # handle t2i adapters + if 'down_intrablock_additional_residuals' in kwargs: + # go through each item and concat if doing cfg and it doesnt have the same shape + for idx, item in enumerate(kwargs['down_intrablock_additional_residuals']): + if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]: + kwargs['down_intrablock_additional_residuals'][idx] = torch.cat([item] * 2, dim=0) + + # handle controlnet + if 'down_block_additional_residuals' in kwargs and 'mid_block_additional_residual' in kwargs: + # go through each item and concat if doing cfg and it doesnt have the same shape + for idx, item in enumerate(kwargs['down_block_additional_residuals']): + if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]: + kwargs['down_block_additional_residuals'][idx] = torch.cat([item] * 2, dim=0) + for idx, item in enumerate(kwargs['mid_block_additional_residual']): + if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]: + kwargs['mid_block_additional_residual'][idx] = torch.cat([item] * 2, dim=0) + + def scale_model_input(model_input, timestep_tensor): + if is_input_scaled: + return model_input + mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0) + timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0) + out_chunks = [] + # unsqueeze if timestep is zero dim + for idx in range(model_input.shape[0]): + # if scheduler has step_index + if hasattr(self.noise_scheduler, '_step_index'): + self.noise_scheduler._step_index = None + out_chunks.append( + self.noise_scheduler.scale_model_input(mi_chunks[idx], timestep_chunks[idx]) + ) + return torch.cat(out_chunks, dim=0) + + if self.is_xl: + with torch.no_grad(): + # 16, 6 for bs of 4 + if add_time_ids is None: + add_time_ids = self.get_time_ids_from_latents(latents) + + if do_classifier_free_guidance: + # todo check this with larget batches + add_time_ids = torch.cat([add_time_ids] * 2) + + if do_classifier_free_guidance: + latent_model_input = torch.cat([latents] * 2) + timestep = torch.cat([timestep] * 2) + else: + latent_model_input = latents + + latent_model_input = scale_model_input(latent_model_input, timestep) + + added_cond_kwargs = { + # todo can we zero here the second text encoder? or match a blank string? + "text_embeds": text_embeddings.pooled_embeds, + "time_ids": add_time_ids, + } + + if self.model_config.refiner_name_or_path is not None: + # we have the refiner on the second half of everything. Do Both + if do_classifier_free_guidance: + raise ValueError("Refiner is not supported with classifier free guidance") + + if self.unet.training: + input_chunks = torch.chunk(latent_model_input, 2, dim=0) + timestep_chunks = torch.chunk(timestep, 2, dim=0) + added_cond_kwargs_chunked = { + "text_embeds": torch.chunk(text_embeddings.pooled_embeds, 2, dim=0), + "time_ids": torch.chunk(add_time_ids, 2, dim=0), + } + text_embeds_chunks = torch.chunk(text_embeddings.text_embeds, 2, dim=0) + + # predict the noise residual + base_pred = self.unet( + input_chunks[0], + timestep_chunks[0], + encoder_hidden_states=text_embeds_chunks[0], + added_cond_kwargs={ + "text_embeds": added_cond_kwargs_chunked['text_embeds'][0], + "time_ids": added_cond_kwargs_chunked['time_ids'][0], + }, + **kwargs, + ).sample + + refiner_pred = self.refiner_unet( + input_chunks[1], + timestep_chunks[1], + encoder_hidden_states=text_embeds_chunks[1][:, :, -1280:], + # just use the first second text encoder + added_cond_kwargs={ + "text_embeds": added_cond_kwargs_chunked['text_embeds'][1], + # "time_ids": added_cond_kwargs_chunked['time_ids'][1], + "time_ids": self.get_time_ids_from_latents(input_chunks[1], requires_aesthetic_score=True), + }, + **kwargs, + ).sample + + noise_pred = torch.cat([base_pred, refiner_pred], dim=0) + else: + noise_pred = self.refiner_unet( + latent_model_input, + timestep, + encoder_hidden_states=text_embeddings.text_embeds[:, :, -1280:], + # just use the first second text encoder + added_cond_kwargs={ + "text_embeds": text_embeddings.pooled_embeds, + "time_ids": self.get_time_ids_from_latents(latent_model_input, + requires_aesthetic_score=True), + }, + **kwargs, + ).sample + + else: + + # predict the noise residual + noise_pred = self.unet( + latent_model_input.to(self.device_torch, self.torch_dtype), + timestep, + encoder_hidden_states=text_embeddings.text_embeds, + added_cond_kwargs=added_cond_kwargs, + **kwargs, + ).sample + + conditional_pred = noise_pred + + if do_classifier_free_guidance: + # perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + conditional_pred = noise_pred_text + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775 + if guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + else: + with torch.no_grad(): + if do_classifier_free_guidance: + # if we are doing classifier free guidance, need to double up + latent_model_input = torch.cat([latents] * 2, dim=0) + timestep = torch.cat([timestep] * 2) + else: + latent_model_input = latents + + latent_model_input = scale_model_input(latent_model_input, timestep) + + # check if we need to concat timesteps + if isinstance(timestep, torch.Tensor) and len(timestep.shape) > 1: + ts_bs = timestep.shape[0] + if ts_bs != latent_model_input.shape[0]: + if ts_bs == 1: + timestep = torch.cat([timestep] * latent_model_input.shape[0]) + elif ts_bs * 2 == latent_model_input.shape[0]: + timestep = torch.cat([timestep] * 2, dim=0) + else: + raise ValueError( + f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}") + + # predict the noise residual + if self.is_pixart: + VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1) + batch_size, ch, h, w = list(latents.shape) + + height = h * VAE_SCALE_FACTOR + width = w * VAE_SCALE_FACTOR + + if self.pipeline.transformer.config.sample_size == 256: + aspect_ratio_bin = ASPECT_RATIO_2048_BIN + elif self.pipeline.transformer.config.sample_size == 128: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + elif self.pipeline.transformer.config.sample_size == 64: + aspect_ratio_bin = ASPECT_RATIO_512_BIN + elif self.pipeline.transformer.config.sample_size == 32: + aspect_ratio_bin = ASPECT_RATIO_256_BIN + else: + raise ValueError(f"Invalid sample size: {self.pipeline.transformer.config.sample_size}") + orig_height, orig_width = height, width + height, width = self.pipeline.image_processor.classify_height_width_bin(height, width, + ratios=aspect_ratio_bin) + + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + if self.unet.config.sample_size == 128 or ( + self.vae_scale_factor == 16 and self.unet.config.sample_size == 64): + resolution = torch.tensor([height, width]).repeat(batch_size, 1) + aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size, 1) + resolution = resolution.to(dtype=text_embeddings.text_embeds.dtype, device=self.device_torch) + aspect_ratio = aspect_ratio.to(dtype=text_embeddings.text_embeds.dtype, device=self.device_torch) + + if do_classifier_free_guidance: + resolution = torch.cat([resolution, resolution], dim=0) + aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0) + + added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio} + + noise_pred = self.unet( + latent_model_input.to(self.device_torch, self.torch_dtype), + encoder_hidden_states=text_embeddings.text_embeds, + encoder_attention_mask=text_embeddings.attention_mask, + timestep=timestep, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + **kwargs + )[0] + + # learned sigma + if self.unet.config.out_channels // 2 == self.unet.config.in_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + else: + noise_pred = noise_pred + else: + if self.unet.device != self.device_torch: + self.unet.to(self.device_torch) + if self.unet.dtype != self.torch_dtype: + self.unet = self.unet.to(dtype=self.torch_dtype) + if self.is_flux: + with torch.no_grad(): + + bs, c, h, w = latent_model_input.shape + latent_model_input_packed = rearrange( + latent_model_input, + "b c (h ph) (w pw) -> b (h w) (c ph pw)", + ph=2, + pw=2 + ) + + img_ids = torch.zeros(h // 2, w // 2, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs).to(self.device_torch) + + txt_ids = torch.zeros(bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch) + + # # handle guidance + if self.unet.config.guidance_embeds: + if isinstance(guidance_scale, list): + guidance = torch.tensor(guidance_scale, device=self.device_torch) + else: + guidance = torch.tensor([guidance_scale], device=self.device_torch) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + cast_dtype = self.unet.dtype + # with torch.amp.autocast(device_type='cuda', dtype=cast_dtype): + noise_pred = self.unet( + hidden_states=latent_model_input_packed.to(self.device_torch, cast_dtype), # [1, 4096, 64] + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) + # todo make sure this doesnt change + timestep=timestep / 1000, # timestep is 1000 scale + encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, cast_dtype), + # [1, 512, 4096] + pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, cast_dtype), # [1, 768] + txt_ids=txt_ids, # [1, 512, 3] + img_ids=img_ids, # [1, 4096, 3] + guidance=guidance, + return_dict=False, + **kwargs, + )[0] + + if isinstance(noise_pred, QTensor): + noise_pred = noise_pred.dequantize() + + noise_pred = rearrange( + noise_pred, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=latent_model_input.shape[2] // 2, + w=latent_model_input.shape[3] // 2, + ph=2, + pw=2, + c=latent_model_input.shape[1], + ) + elif self.is_v3: + noise_pred = self.unet( + hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype), + timestep=timestep, + encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype), + pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, self.torch_dtype), + **kwargs, + ).sample + if isinstance(noise_pred, QTensor): + noise_pred = noise_pred.dequantize() + elif self.is_auraflow: + # aura use timestep value between 0 and 1, with t=1 as noise and t=0 as the image + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + t = torch.tensor([timestep / 1000]).expand(latent_model_input.shape[0]) + t = t.to(self.device_torch, self.torch_dtype) + + noise_pred = self.unet( + latent_model_input, + encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype), + timestep=t, + return_dict=False, + )[0] + else: + noise_pred = self.unet( + latent_model_input.to(self.device_torch, self.torch_dtype), + timestep=timestep, + encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype), + **kwargs, + ).sample + + conditional_pred = noise_pred + + if do_classifier_free_guidance: + # perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2, dim=0) + conditional_pred = noise_pred_text + if detach_unconditional: + noise_pred_uncond = noise_pred_uncond.detach() + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + if rescale_cfg is not None and rescale_cfg != guidance_scale: + with torch.no_grad(): + # do cfg at the target rescale so we can match it + target_pred_mean_std = noise_pred_uncond + rescale_cfg * ( + noise_pred_text - noise_pred_uncond + ) + target_mean = target_pred_mean_std.mean([1, 2, 3], keepdim=True).detach() + target_std = target_pred_mean_std.std([1, 2, 3], keepdim=True).detach() + + pred_mean = noise_pred.mean([1, 2, 3], keepdim=True).detach() + pred_std = noise_pred.std([1, 2, 3], keepdim=True).detach() + + # match the mean and std + noise_pred = (noise_pred - pred_mean) / pred_std + noise_pred = (noise_pred * target_std) + target_mean + + # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775 + if guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + if return_conditional_pred: + return noise_pred, conditional_pred + return noise_pred + + def step_scheduler(self, model_input, latent_input, timestep_tensor, noise_scheduler=None): + if noise_scheduler is None: + noise_scheduler = self.noise_scheduler + # // sometimes they are on the wrong device, no idea why + if isinstance(noise_scheduler, DDPMScheduler) or isinstance(noise_scheduler, LCMScheduler): + try: + noise_scheduler.betas = noise_scheduler.betas.to(self.device_torch) + noise_scheduler.alphas = noise_scheduler.alphas.to(self.device_torch) + noise_scheduler.alphas_cumprod = noise_scheduler.alphas_cumprod.to(self.device_torch) + except Exception as e: + pass + + mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0) + latent_chunks = torch.chunk(latent_input, latent_input.shape[0], dim=0) + timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0) + out_chunks = [] + if len(timestep_chunks) == 1 and len(mi_chunks) > 1: + # expand timestep to match + timestep_chunks = timestep_chunks * len(mi_chunks) + + for idx in range(model_input.shape[0]): + # Reset it so it is unique for the + if hasattr(noise_scheduler, '_step_index'): + noise_scheduler._step_index = None + if hasattr(noise_scheduler, 'is_scale_input_called'): + noise_scheduler.is_scale_input_called = True + out_chunks.append( + noise_scheduler.step(mi_chunks[idx], timestep_chunks[idx], latent_chunks[idx], return_dict=False)[ + 0] + ) + return torch.cat(out_chunks, dim=0) + + # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746 + def diffuse_some_steps( + self, + latents: torch.FloatTensor, + text_embeddings: PromptEmbeds, + total_timesteps: int = 1000, + start_timesteps=0, + guidance_scale=1, + add_time_ids=None, + bleed_ratio: float = 0.5, + bleed_latents: torch.FloatTensor = None, + is_input_scaled=False, + return_first_prediction=False, + **kwargs, + ): + timesteps_to_run = self.noise_scheduler.timesteps[start_timesteps:total_timesteps] + + first_prediction = None + + for timestep in tqdm(timesteps_to_run, leave=False): + timestep = timestep.unsqueeze_(0) + noise_pred, conditional_pred = self.predict_noise( + latents, + text_embeddings, + timestep, + guidance_scale=guidance_scale, + add_time_ids=add_time_ids, + is_input_scaled=is_input_scaled, + return_conditional_pred=True, + **kwargs, + ) + # some schedulers need to run separately, so do that. (euler for example) + + if return_first_prediction and first_prediction is None: + first_prediction = conditional_pred + + latents = self.step_scheduler(noise_pred, latents, timestep) + + # if not last step, and bleeding, bleed in some latents + if bleed_latents is not None and timestep != self.noise_scheduler.timesteps[-1]: + latents = (latents * (1 - bleed_ratio)) + (bleed_latents * bleed_ratio) + + # only skip first scaling + is_input_scaled = False + + # return latents_steps + if return_first_prediction: + return latents, first_prediction + return latents + + def encode_prompt( + self, + prompt, + prompt2=None, + num_images_per_prompt=1, + force_all=False, + long_prompts=False, + max_length=None, + dropout_prob=0.0, + ) -> PromptEmbeds: + # sd1.5 embeddings are (bs, 77, 768) + prompt = prompt + # if it is not a list, make it one + if not isinstance(prompt, list): + prompt = [prompt] + + if prompt2 is not None and not isinstance(prompt2, list): + prompt2 = [prompt2] + if self.is_xl: + # todo make this a config + # 50% chance to use an encoder anyway even if it is disabled + # allows the other TE to compensate for the disabled one + # use_encoder_1 = self.use_text_encoder_1 or force_all or random.random() > 0.5 + # use_encoder_2 = self.use_text_encoder_2 or force_all or random.random() > 0.5 + use_encoder_1 = True + use_encoder_2 = True + + return PromptEmbeds( + train_tools.encode_prompts_xl( + self.tokenizer, + self.text_encoder, + prompt, + prompt2, + num_images_per_prompt=num_images_per_prompt, + use_text_encoder_1=use_encoder_1, + use_text_encoder_2=use_encoder_2, + truncate=not long_prompts, + max_length=max_length, + dropout_prob=dropout_prob, + ) + ) + if self.is_v3: + return PromptEmbeds( + train_tools.encode_prompts_sd3( + self.tokenizer, + self.text_encoder, + prompt, + num_images_per_prompt=num_images_per_prompt, + truncate=not long_prompts, + max_length=max_length, + dropout_prob=dropout_prob, + pipeline=self.pipeline, + ) + ) + elif self.is_pixart: + embeds, attention_mask = train_tools.encode_prompts_pixart( + self.tokenizer, + self.text_encoder, + prompt, + truncate=not long_prompts, + max_length=300 if self.model_config.is_pixart_sigma else 120, + dropout_prob=dropout_prob + ) + return PromptEmbeds( + embeds, + attention_mask=attention_mask, + ) + elif self.is_auraflow: + embeds, attention_mask = train_tools.encode_prompts_auraflow( + self.tokenizer, + self.text_encoder, + prompt, + truncate=not long_prompts, + max_length=256, + dropout_prob=dropout_prob + ) + return PromptEmbeds( + embeds, + attention_mask=attention_mask, # not used + ) + elif self.is_flux: + prompt_embeds, pooled_prompt_embeds = train_tools.encode_prompts_flux( + self.tokenizer, # list + self.text_encoder, # list + prompt, + truncate=not long_prompts, + max_length=512, + dropout_prob=dropout_prob, + attn_mask=self.model_config.attn_masking + ) + pe = PromptEmbeds( + prompt_embeds + ) + pe.pooled_embeds = pooled_prompt_embeds + return pe + + + elif isinstance(self.text_encoder, T5EncoderModel): + embeds, attention_mask = train_tools.encode_prompts_pixart( + self.tokenizer, + self.text_encoder, + prompt, + truncate=not long_prompts, + max_length=256, + dropout_prob=dropout_prob + ) + + # just mask the attention mask + prompt_attention_mask = attention_mask.unsqueeze(-1).expand(embeds.shape) + embeds = embeds * prompt_attention_mask.to(dtype=embeds.dtype, device=embeds.device) + return PromptEmbeds( + embeds, + + # do we want attn mask here? + # attention_mask=attention_mask, + ) + else: + return PromptEmbeds( + train_tools.encode_prompts( + self.tokenizer, + self.text_encoder, + prompt, + truncate=not long_prompts, + max_length=max_length, + dropout_prob=dropout_prob + ) + ) + + @torch.no_grad() + def encode_images( + self, + image_list: List[torch.Tensor], + device=None, + dtype=None + ): + if device is None: + device = self.vae_device_torch + if dtype is None: + dtype = self.vae_torch_dtype + + latent_list = [] + # Move to vae to device if on cpu + if self.vae.device == 'cpu': + self.vae.to(device) + self.vae.eval() + self.vae.requires_grad_(False) + # move to device and dtype + image_list = [image.to(device, dtype=dtype) for image in image_list] + + VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1) + + # resize images if not divisible by 8 + for i in range(len(image_list)): + image = image_list[i] + if image.shape[1] % VAE_SCALE_FACTOR != 0 or image.shape[2] % VAE_SCALE_FACTOR != 0: + image_list[i] = Resize((image.shape[1] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR, + image.shape[2] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR))(image) + + images = torch.stack(image_list) + if isinstance(self.vae, AutoencoderTiny): + latents = self.vae.encode(images, return_dict=False)[0] + else: + latents = self.vae.encode(images).latent_dist.sample() + shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0 + + # flux ref https://github.com/black-forest-labs/flux/blob/c23ae247225daba30fbd56058d247cc1b1fc20a3/src/flux/modules/autoencoder.py#L303 + # z = self.scale_factor * (z - self.shift_factor) + latents = self.vae.config['scaling_factor'] * (latents - shift) + latents = latents.to(device, dtype=dtype) + + return latents + + def decode_latents( + self, + latents: torch.Tensor, + device=None, + dtype=None + ): + if device is None: + device = self.device + if dtype is None: + dtype = self.torch_dtype + + # Move to vae to device if on cpu + if self.vae.device == 'cpu': + self.vae.to(self.device) + latents = latents.to(device, dtype=dtype) + latents = (latents / self.vae.config['scaling_factor']) + self.vae.config['shift_factor'] + images = self.vae.decode(latents).sample + images = images.to(device, dtype=dtype) + + return images + + def encode_image_prompt_pairs( + self, + prompt_list: List[str], + image_list: List[torch.Tensor], + device=None, + dtype=None + ): + # todo check image types and expand and rescale as needed + # device and dtype are for outputs + if device is None: + device = self.device + if dtype is None: + dtype = self.torch_dtype + + embedding_list = [] + latent_list = [] + # embed the prompts + for prompt in prompt_list: + embedding = self.encode_prompt(prompt).to(self.device_torch, dtype=dtype) + embedding_list.append(embedding) + + return embedding_list, latent_list + + def get_weight_by_name(self, name): + # weights begin with te{te_num}_ for text encoder + # weights begin with unet_ for unet_ + if name.startswith('te'): + key = name[4:] + # text encoder + te_num = int(name[2]) + if isinstance(self.text_encoder, list): + return self.text_encoder[te_num].state_dict()[key] + else: + return self.text_encoder.state_dict()[key] + elif name.startswith('unet'): + key = name[5:] + # unet + return self.unet.state_dict()[key] + + raise ValueError(f"Unknown weight name: {name}") + + def inject_trigger_into_prompt(self, prompt, trigger=None, to_replace_list=None, add_if_not_present=False): + return inject_trigger_into_prompt( + prompt, + trigger=trigger, + to_replace_list=to_replace_list, + add_if_not_present=add_if_not_present, + ) + + def state_dict(self, vae=True, text_encoder=True, unet=True): + state_dict = OrderedDict() + if vae: + for k, v in self.vae.state_dict().items(): + new_key = k if k.startswith(f"{SD_PREFIX_VAE}") else f"{SD_PREFIX_VAE}_{k}" + state_dict[new_key] = v + if text_encoder: + if isinstance(self.text_encoder, list): + for i, encoder in enumerate(self.text_encoder): + for k, v in encoder.state_dict().items(): + new_key = k if k.startswith( + f"{SD_PREFIX_TEXT_ENCODER}{i}_") else f"{SD_PREFIX_TEXT_ENCODER}{i}_{k}" + state_dict[new_key] = v + else: + for k, v in self.text_encoder.state_dict().items(): + new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER}_") else f"{SD_PREFIX_TEXT_ENCODER}_{k}" + state_dict[new_key] = v + if unet: + for k, v in self.unet.state_dict().items(): + new_key = k if k.startswith(f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}" + state_dict[new_key] = v + return state_dict + + def named_parameters(self, vae=True, text_encoder=True, unet=True, refiner=False, state_dict_keys=False) -> \ + OrderedDict[ + str, Parameter]: + named_params: OrderedDict[str, Parameter] = OrderedDict() + if vae: + for name, param in self.vae.named_parameters(recurse=True, prefix=f"{SD_PREFIX_VAE}"): + named_params[name] = param + if text_encoder: + if isinstance(self.text_encoder, list): + for i, encoder in enumerate(self.text_encoder): + if self.is_xl and not self.model_config.use_text_encoder_1 and i == 0: + # dont add these params + continue + if self.is_xl and not self.model_config.use_text_encoder_2 and i == 1: + # dont add these params + continue + + for name, param in encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}{i}"): + named_params[name] = param + else: + for name, param in self.text_encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}"): + named_params[name] = param + if unet: + if self.is_flux: + # Just train the middle 2 blocks of each transformer block + # block_list = [] + # num_transformer_blocks = 2 + # start_block = len(self.unet.transformer_blocks) // 2 - (num_transformer_blocks // 2) + # for i in range(num_transformer_blocks): + # block_list.append(self.unet.transformer_blocks[start_block + i]) + # + # num_single_transformer_blocks = 4 + # start_block = len(self.unet.single_transformer_blocks) // 2 - (num_single_transformer_blocks // 2) + # for i in range(num_single_transformer_blocks): + # block_list.append(self.unet.single_transformer_blocks[start_block + i]) + # + # for block in block_list: + # for name, param in block.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"): + # named_params[name] = param + + # train the guidance embedding + # if self.unet.config.guidance_embeds: + # transformer: FluxTransformer2DModel = self.unet + # for name, param in transformer.time_text_embed.named_parameters(recurse=True, + # prefix=f"{SD_PREFIX_UNET}"): + # named_params[name] = param + + for name, param in self.unet.transformer_blocks.named_parameters(recurse=True, + prefix=f"{SD_PREFIX_UNET}"): + named_params[name] = param + for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True, + prefix=f"{SD_PREFIX_UNET}"): + named_params[name] = param + else: + for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"): + named_params[name] = param + + if refiner: + for name, param in self.refiner_unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_REFINER_UNET}"): + named_params[name] = param + + # convert to state dict keys, jsut replace . with _ on keys + if state_dict_keys: + new_named_params = OrderedDict() + for k, v in named_params.items(): + # replace only the first . with an _ + new_key = k.replace('.', '_', 1) + new_named_params[new_key] = v + named_params = new_named_params + + return named_params + + def save_refiner(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16')): + + # load the full refiner since we only train unet + if self.model_config.refiner_name_or_path is None: + raise ValueError("Refiner must be specified to save it") + refiner_config_path = os.path.join(ORIG_CONFIGS_ROOT, 'sd_xl_refiner.yaml') + # load the refiner model + dtype = get_torch_dtype(self.dtype) + model_path = self.model_config._original_refiner_name_or_path + if not os.path.exists(model_path) or os.path.isdir(model_path): + # TODO only load unet?? + refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained( + model_path, + dtype=dtype, + device='cpu', + # variant="fp16", + use_safetensors=True, + ) + else: + refiner = StableDiffusionXLImg2ImgPipeline.from_single_file( + model_path, + dtype=dtype, + device='cpu', + torch_dtype=self.torch_dtype, + original_config_file=refiner_config_path, + ) + # replace original unet + refiner.unet = self.refiner_unet + flush() + + diffusers_state_dict = OrderedDict() + for k, v in refiner.vae.state_dict().items(): + new_key = k if k.startswith(f"{SD_PREFIX_VAE}") else f"{SD_PREFIX_VAE}_{k}" + diffusers_state_dict[new_key] = v + for k, v in refiner.text_encoder_2.state_dict().items(): + new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER2}_") else f"{SD_PREFIX_TEXT_ENCODER2}_{k}" + diffusers_state_dict[new_key] = v + for k, v in refiner.unet.state_dict().items(): + new_key = k if k.startswith(f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}" + diffusers_state_dict[new_key] = v + + converted_state_dict = get_ldm_state_dict_from_diffusers( + diffusers_state_dict, + 'sdxl_refiner', + device='cpu', + dtype=save_dtype + ) + + # make sure parent folder exists + os.makedirs(os.path.dirname(output_file), exist_ok=True) + save_file(converted_state_dict, output_file, metadata=meta) + + if self.config_file is not None: + output_path_no_ext = os.path.splitext(output_file)[0] + output_config_path = f"{output_path_no_ext}.yaml" + shutil.copyfile(self.config_file, output_config_path) + + def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None): + version_string = '1' + if self.is_v2: + version_string = '2' + if self.is_xl: + version_string = 'sdxl' + if self.is_ssd: + # overwrite sdxl because both wil be true here + version_string = 'ssd' + if self.is_ssd and self.is_vega: + version_string = 'vega' + # if output file does not end in .safetensors, then it is a directory and we are + # saving in diffusers format + if not output_file.endswith('.safetensors'): + # diffusers + # if self.is_pixart: + # self.unet.save_pretrained( + # save_directory=output_file, + # safe_serialization=True, + # ) + # else: + if self.is_flux: + # only save the unet + transformer: FluxTransformer2DModel = self.unet + transformer.save_pretrained( + save_directory=os.path.join(output_file, 'transformer'), + safe_serialization=True, + ) + else: + + self.pipeline.save_pretrained( + save_directory=output_file, + safe_serialization=True, + ) + # save out meta config + meta_path = os.path.join(output_file, 'aitk_meta.yaml') + with open(meta_path, 'w') as f: + yaml.dump(meta, f) + + else: + save_ldm_model_from_diffusers( + sd=self, + output_file=output_file, + meta=meta, + save_dtype=save_dtype, + sd_version=version_string, + ) + if self.config_file is not None: + output_path_no_ext = os.path.splitext(output_file)[0] + output_config_path = f"{output_path_no_ext}.yaml" + shutil.copyfile(self.config_file, output_config_path) + + def prepare_optimizer_params( + self, + unet=False, + text_encoder=False, + text_encoder_lr=None, + unet_lr=None, + refiner_lr=None, + refiner=False, + default_lr=1e-6, + ): + # todo maybe only get locon ones? + # not all items are saved, to make it match, we need to match out save mappings + # and not train anything not mapped. Also add learning rate + version = 'sd1' + if self.is_xl: + version = 'sdxl' + if self.is_v2: + version = 'sd2' + mapping_filename = f"stable_diffusion_{version}.json" + mapping_path = os.path.join(KEYMAPS_ROOT, mapping_filename) + with open(mapping_path, 'r') as f: + mapping = json.load(f) + ldm_diffusers_keymap = mapping['ldm_diffusers_keymap'] + + trainable_parameters = [] + + # we use state dict to find params + + if unet: + named_params = self.named_parameters(vae=False, unet=unet, text_encoder=False, state_dict_keys=True) + unet_lr = unet_lr if unet_lr is not None else default_lr + params = [] + if self.is_pixart or self.is_auraflow or self.is_flux: + for param in named_params.values(): + if param.requires_grad: + params.append(param) + else: + for key, diffusers_key in ldm_diffusers_keymap.items(): + if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS: + if named_params[diffusers_key].requires_grad: + params.append(named_params[diffusers_key]) + param_data = {"params": params, "lr": unet_lr} + trainable_parameters.append(param_data) + print(f"Found {len(params)} trainable parameter in unet") + + if text_encoder: + named_params = self.named_parameters(vae=False, unet=False, text_encoder=text_encoder, state_dict_keys=True) + text_encoder_lr = text_encoder_lr if text_encoder_lr is not None else default_lr + params = [] + for key, diffusers_key in ldm_diffusers_keymap.items(): + if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS: + if named_params[diffusers_key].requires_grad: + params.append(named_params[diffusers_key]) + param_data = {"params": params, "lr": text_encoder_lr} + trainable_parameters.append(param_data) + + print(f"Found {len(params)} trainable parameter in text encoder") + + if refiner: + named_params = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True, + state_dict_keys=True) + refiner_lr = refiner_lr if refiner_lr is not None else default_lr + params = [] + for key, diffusers_key in ldm_diffusers_keymap.items(): + diffusers_key = f"refiner_{diffusers_key}" + if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS: + if named_params[diffusers_key].requires_grad: + params.append(named_params[diffusers_key]) + param_data = {"params": params, "lr": refiner_lr} + trainable_parameters.append(param_data) + + print(f"Found {len(params)} trainable parameter in refiner") + + return trainable_parameters + + def save_device_state(self): + # saves the current device state for all modules + # this is useful for when we want to alter the state and restore it + if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux: + unet_has_grad = self.unet.proj_out.weight.requires_grad + else: + unet_has_grad = self.unet.conv_in.weight.requires_grad + + self.device_state = { + **empty_preset, + 'vae': { + 'training': self.vae.training, + 'device': self.vae.device, + }, + 'unet': { + 'training': self.unet.training, + 'device': self.unet.device, + 'requires_grad': unet_has_grad, + }, + } + if isinstance(self.text_encoder, list): + self.device_state['text_encoder']: List[dict] = [] + for encoder in self.text_encoder: + try: + te_has_grad = encoder.text_model.final_layer_norm.weight.requires_grad + except: + te_has_grad = encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad + self.device_state['text_encoder'].append({ + 'training': encoder.training, + 'device': encoder.device, + # todo there has to be a better way to do this + 'requires_grad': te_has_grad + }) + else: + if isinstance(self.text_encoder, T5EncoderModel) or isinstance(self.text_encoder, UMT5EncoderModel): + te_has_grad = self.text_encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad + else: + te_has_grad = self.text_encoder.text_model.final_layer_norm.weight.requires_grad + + self.device_state['text_encoder'] = { + 'training': self.text_encoder.training, + 'device': self.text_encoder.device, + 'requires_grad': te_has_grad + } + if self.adapter is not None: + if isinstance(self.adapter, IPAdapter): + requires_grad = self.adapter.image_proj_model.training + adapter_device = self.unet.device + elif isinstance(self.adapter, T2IAdapter): + requires_grad = self.adapter.adapter.conv_in.weight.requires_grad + adapter_device = self.adapter.device + elif isinstance(self.adapter, ControlNetModel): + requires_grad = self.adapter.conv_in.training + adapter_device = self.adapter.device + elif isinstance(self.adapter, ClipVisionAdapter): + requires_grad = self.adapter.embedder.training + adapter_device = self.adapter.device + elif isinstance(self.adapter, CustomAdapter): + requires_grad = self.adapter.training + adapter_device = self.adapter.device + elif isinstance(self.adapter, ReferenceAdapter): + # todo update this!! + requires_grad = True + adapter_device = self.adapter.device + else: + raise ValueError(f"Unknown adapter type: {type(self.adapter)}") + self.device_state['adapter'] = { + 'training': self.adapter.training, + 'device': adapter_device, + 'requires_grad': requires_grad, + } + + if self.refiner_unet is not None: + self.device_state['refiner_unet'] = { + 'training': self.refiner_unet.training, + 'device': self.refiner_unet.device, + 'requires_grad': self.refiner_unet.conv_in.weight.requires_grad, + } + + def restore_device_state(self): + # restores the device state for all modules + # this is useful for when we want to alter the state and restore it + if self.device_state is None: + return + self.set_device_state(self.device_state) + self.device_state = None + + def set_device_state(self, state): + if state['vae']['training']: + self.vae.train() + else: + self.vae.eval() + self.vae.to(state['vae']['device']) + if state['unet']['training']: + self.unet.train() + else: + self.unet.eval() + self.unet.to(state['unet']['device']) + if state['unet']['requires_grad']: + self.unet.requires_grad_(True) + else: + self.unet.requires_grad_(False) + if isinstance(self.text_encoder, list): + for i, encoder in enumerate(self.text_encoder): + if isinstance(state['text_encoder'], list): + if state['text_encoder'][i]['training']: + encoder.train() + else: + encoder.eval() + encoder.to(state['text_encoder'][i]['device']) + encoder.requires_grad_(state['text_encoder'][i]['requires_grad']) + else: + if state['text_encoder']['training']: + encoder.train() + else: + encoder.eval() + encoder.to(state['text_encoder']['device']) + encoder.requires_grad_(state['text_encoder']['requires_grad']) + else: + if state['text_encoder']['training']: + self.text_encoder.train() + else: + self.text_encoder.eval() + self.text_encoder.to(state['text_encoder']['device']) + self.text_encoder.requires_grad_(state['text_encoder']['requires_grad']) + + if self.adapter is not None: + self.adapter.to(state['adapter']['device']) + self.adapter.requires_grad_(state['adapter']['requires_grad']) + if state['adapter']['training']: + self.adapter.train() + else: + self.adapter.eval() + + if self.refiner_unet is not None: + self.refiner_unet.to(state['refiner_unet']['device']) + self.refiner_unet.requires_grad_(state['refiner_unet']['requires_grad']) + if state['refiner_unet']['training']: + self.refiner_unet.train() + else: + self.refiner_unet.eval() + flush() + + def set_device_state_preset(self, device_state_preset: DeviceStatePreset): + # sets a preset for device state + + # save current state first + self.save_device_state() + + active_modules = [] + training_modules = [] + if device_state_preset in ['cache_latents']: + active_modules = ['vae'] + if device_state_preset in ['cache_clip']: + active_modules = ['clip'] + if device_state_preset in ['generate']: + active_modules = ['vae', 'unet', 'text_encoder', 'adapter', 'refiner_unet'] + + state = copy.deepcopy(empty_preset) + # vae + state['vae'] = { + 'training': 'vae' in training_modules, + 'device': self.vae_device_torch if 'vae' in active_modules else 'cpu', + 'requires_grad': 'vae' in training_modules, + } + + # unet + state['unet'] = { + 'training': 'unet' in training_modules, + 'device': self.device_torch if 'unet' in active_modules else 'cpu', + 'requires_grad': 'unet' in training_modules, + } + + if self.refiner_unet is not None: + state['refiner_unet'] = { + 'training': 'refiner_unet' in training_modules, + 'device': self.device_torch if 'refiner_unet' in active_modules else 'cpu', + 'requires_grad': 'refiner_unet' in training_modules, + } + + # text encoder + if isinstance(self.text_encoder, list): + state['text_encoder'] = [] + for i, encoder in enumerate(self.text_encoder): + state['text_encoder'].append({ + 'training': 'text_encoder' in training_modules, + 'device': self.te_device_torch if 'text_encoder' in active_modules else 'cpu', + 'requires_grad': 'text_encoder' in training_modules, + }) + else: + state['text_encoder'] = { + 'training': 'text_encoder' in training_modules, + 'device': self.te_device_torch if 'text_encoder' in active_modules else 'cpu', + 'requires_grad': 'text_encoder' in training_modules, + } + + if self.adapter is not None: + state['adapter'] = { + 'training': 'adapter' in training_modules, + 'device': self.device_torch if 'adapter' in active_modules else 'cpu', + 'requires_grad': 'adapter' in training_modules, + } + + self.set_device_state(state) + + def text_encoder_to(self, *args, **kwargs): + if isinstance(self.text_encoder, list): + for encoder in self.text_encoder: + encoder.to(*args, **kwargs) + else: + self.text_encoder.to(*args, **kwargs)