import copy import gc import json import random import shutil import typing from typing import Union, List, Literal, Iterator import sys import os from collections import OrderedDict import copy import yaml from PIL import Image from diffusers.pipelines.pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_1024_BIN, ASPECT_RATIO_512_BIN, \ ASPECT_RATIO_2048_BIN, ASPECT_RATIO_256_BIN from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg from safetensors.torch import save_file, load_file from torch import autocast from torch.nn import Parameter from torch.utils.checkpoint import checkpoint from tqdm import tqdm from torchvision.transforms import Resize, transforms from toolkit.assistant_lora import load_assistant_lora_from_path from toolkit.clip_vision_adapter import ClipVisionAdapter from toolkit.custom_adapter import CustomAdapter from toolkit.ip_adapter import IPAdapter from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \ convert_vae_state_dict, load_vae from toolkit import train_tools from toolkit.config_modules import ModelConfig, GenerateImageConfig from toolkit.metadata import get_meta_for_safetensors from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds from toolkit.reference_adapter import ReferenceAdapter from toolkit.sampler import get_sampler from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler from toolkit.saving import save_ldm_model_from_diffusers, get_ldm_state_dict_from_diffusers from toolkit.sd_device_states_presets import empty_preset from toolkit.train_tools import get_torch_dtype, apply_noise_offset from einops import rearrange, repeat import torch from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \ StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline, FluxWithCFGPipeline from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \ StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \ StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \ StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, \ StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline, AuraFlowPipeline, AuraFlowTransformer2DModel, FluxPipeline, \ FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel import diffusers from diffusers import \ AutoencoderKL, \ UNet2DConditionModel from diffusers import PixArtAlphaPipeline, DPMSolverMultistepScheduler, PixArtSigmaPipeline from transformers import T5EncoderModel, BitsAndBytesConfig, UMT5EncoderModel, T5TokenizerFast from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT from huggingface_hub import hf_hub_download from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4 from typing import TYPE_CHECKING if TYPE_CHECKING: from toolkit.lora_special import LoRASpecialNetwork # tell it to shut up diffusers.logging.set_verbosity(diffusers.logging.ERROR) SD_PREFIX_VAE = "vae" SD_PREFIX_UNET = "unet" SD_PREFIX_REFINER_UNET = "refiner_unet" SD_PREFIX_TEXT_ENCODER = "te" SD_PREFIX_TEXT_ENCODER1 = "te0" SD_PREFIX_TEXT_ENCODER2 = "te1" # prefixed diffusers keys DO_NOT_TRAIN_WEIGHTS = [ "unet_time_embedding.linear_1.bias", "unet_time_embedding.linear_1.weight", "unet_time_embedding.linear_2.bias", "unet_time_embedding.linear_2.weight", "refiner_unet_time_embedding.linear_1.bias", "refiner_unet_time_embedding.linear_1.weight", "refiner_unet_time_embedding.linear_2.bias", "refiner_unet_time_embedding.linear_2.weight", ] DeviceStatePreset = Literal['cache_latents', 'generate'] class BlankNetwork: def __init__(self): self.multiplier = 1.0 self.is_active = True self.is_merged_in = False self.can_merge_in = False def __enter__(self): self.is_active = True def __exit__(self, exc_type, exc_val, exc_tb): self.is_active = False def flush(): torch.cuda.empty_cache() gc.collect() UNET_IN_CHANNELS = 4 # Stable Diffusion の in_channels は 4 で固定。XLも同じ。 # VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8 class StableDiffusion: def __init__( self, device, model_config: ModelConfig, dtype='fp16', custom_pipeline=None, noise_scheduler=None, quantize_device=None, ): self.custom_pipeline = custom_pipeline self.device = device self.dtype = dtype self.torch_dtype = get_torch_dtype(dtype) self.device_torch = torch.device(self.device) self.vae_device_torch = torch.device(self.device) if model_config.vae_device is None else torch.device( model_config.vae_device) self.vae_torch_dtype = get_torch_dtype(model_config.vae_dtype) self.te_device_torch = torch.device(self.device) if model_config.te_device is None else torch.device( model_config.te_device) self.te_torch_dtype = get_torch_dtype(model_config.te_dtype) self.model_config = model_config self.prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon" self.device_state = None self.pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline', 'PixArtAlphaPipeline'] self.vae: Union[None, 'AutoencoderKL'] self.unet: Union[None, 'UNet2DConditionModel'] self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]] self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']] self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler self.refiner_unet: Union[None, 'UNet2DConditionModel'] = None self.assistant_lora: Union[None, 'LoRASpecialNetwork'] = None # sdxl stuff self.logit_scale = None self.ckppt_info = None self.is_loaded = False # to hold network if there is one self.network = None self.adapter: Union['ControlNetModel', 'T2IAdapter', 'IPAdapter', 'ReferenceAdapter', None] = None self.is_xl = model_config.is_xl self.is_v2 = model_config.is_v2 self.is_ssd = model_config.is_ssd self.is_v3 = model_config.is_v3 self.is_vega = model_config.is_vega self.is_pixart = model_config.is_pixart self.is_auraflow = model_config.is_auraflow self.is_flux = model_config.is_flux self.use_text_encoder_1 = model_config.use_text_encoder_1 self.use_text_encoder_2 = model_config.use_text_encoder_2 self.config_file = None self.is_flow_matching = False if self.is_flux or self.is_v3 or self.is_auraflow or isinstance(self.noise_scheduler, CustomFlowMatchEulerDiscreteScheduler): self.is_flow_matching = True self.quantize_device = quantize_device if quantize_device is not None else self.device self.low_vram = self.model_config.low_vram # merge in and preview active with -1 weight self.invert_assistant_lora = False def load_model(self): if self.is_loaded: return dtype = get_torch_dtype(self.dtype) # move the betas alphas and alphas_cumprod to device. Sometimed they get stuck on cpu, not sure why # self.noise_scheduler.betas = self.noise_scheduler.betas.to(self.device_torch) # self.noise_scheduler.alphas = self.noise_scheduler.alphas.to(self.device_torch) # self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device_torch) model_path = self.model_config.name_or_path if 'civitai.com' in self.model_config.name_or_path: # load is a civit ai model, use the loader. from toolkit.civitai import get_model_path_from_url model_path = get_model_path_from_url(self.model_config.name_or_path) load_args = {} if self.noise_scheduler: load_args['scheduler'] = self.noise_scheduler if self.model_config.vae_path is not None: load_args['vae'] = load_vae(self.model_config.vae_path, dtype) if self.model_config.is_xl or self.model_config.is_ssd or self.model_config.is_vega: if self.custom_pipeline is not None: pipln = self.custom_pipeline else: pipln = StableDiffusionXLPipeline # pipln = StableDiffusionKDiffusionXLPipeline # see if path exists if not os.path.exists(model_path) or os.path.isdir(model_path): # try to load with default diffusers pipe = pipln.from_pretrained( model_path, dtype=dtype, device=self.device_torch, # variant="fp16", use_safetensors=True, **load_args ) else: pipe = pipln.from_single_file( model_path, device=self.device_torch, torch_dtype=self.torch_dtype, ) if 'vae' in load_args and load_args['vae'] is not None: pipe.vae = load_args['vae'] flush() text_encoders = [pipe.text_encoder, pipe.text_encoder_2] tokenizer = [pipe.tokenizer, pipe.tokenizer_2] for text_encoder in text_encoders: text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype) text_encoder.requires_grad_(False) text_encoder.eval() text_encoder = text_encoders pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype) if self.model_config.experimental_xl: print("Experimental XL mode enabled") print("Loading and injecting alt weights") # load the mismatched weight and force it in raw_state_dict = load_file(model_path) replacement_weight = raw_state_dict['conditioner.embedders.1.model.text_projection'].clone() del raw_state_dict # get state dict for for 2nd text encoder te1_state_dict = text_encoders[1].state_dict() # replace weight with mismatched weight te1_state_dict['text_projection.weight'] = replacement_weight.to(self.device_torch, dtype=dtype) flush() print("Injecting alt weights") elif self.model_config.is_v3: if self.custom_pipeline is not None: pipln = self.custom_pipeline else: pipln = StableDiffusion3Pipeline print("Loading SD3 model") # assume it is the large model base_model_path = "stabilityai/stable-diffusion-3.5-large" print("Loading transformer") subfolder = 'transformer' transformer_path = model_path # check if HF_DATASETS_OFFLINE or TRANSFORMERS_OFFLINE is set if os.path.exists(transformer_path): subfolder = None transformer_path = os.path.join(transformer_path, 'transformer') # check if the path is a full checkpoint. te_folder_path = os.path.join(model_path, 'text_encoder') # if we have the te, this folder is a full checkpoint, use it as the base if os.path.exists(te_folder_path): base_model_path = model_path else: # is remote use whatever path we were given base_model_path = model_path transformer = SD3Transformer2DModel.from_pretrained( transformer_path, subfolder=subfolder, torch_dtype=dtype, ) if not self.low_vram: # for low v ram, we leave it on the cpu. Quantizes slower, but allows training on primary gpu transformer.to(torch.device(self.quantize_device), dtype=dtype) flush() if self.model_config.lora_path is not None: raise ValueError("LoRA is not supported for SD3 models currently") if self.model_config.quantize: quantization_type = qfloat8 print("Quantizing transformer") quantize(transformer, weights=quantization_type) freeze(transformer) transformer.to(self.device_torch) else: transformer.to(self.device_torch, dtype=dtype) scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler") print("Loading vae") vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype) flush() print("Loading t5") tokenizer_3 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_3", torch_dtype=dtype) text_encoder_3 = T5EncoderModel.from_pretrained( base_model_path, subfolder="text_encoder_3", torch_dtype=dtype ) text_encoder_3.to(self.device_torch, dtype=dtype) flush() if self.model_config.quantize: print("Quantizing T5") quantize(text_encoder_3, weights=qfloat8) freeze(text_encoder_3) flush() # see if path exists if not os.path.exists(model_path) or os.path.isdir(model_path): try: # try to load with default diffusers pipe = pipln.from_pretrained( base_model_path, dtype=dtype, device=self.device_torch, tokenizer_3=tokenizer_3, text_encoder_3=text_encoder_3, transformer=transformer, # variant="fp16", use_safetensors=True, repo_type="model", ignore_patterns=["*.md", "*..gitattributes"], **load_args ) except Exception as e: print(f"Error loading from pretrained: {e}") raise e else: pipe = pipln.from_single_file( model_path, transformer=transformer, device=self.device_torch, torch_dtype=self.torch_dtype, tokenizer_3=tokenizer_3, text_encoder_3=text_encoder_3, **load_args ) flush() text_encoders = [pipe.text_encoder, pipe.text_encoder_2, pipe.text_encoder_3] tokenizer = [pipe.tokenizer, pipe.tokenizer_2, pipe.tokenizer_3] # replace the to function with a no-op since it throws an error instead of a warning # text_encoders[2].to = lambda *args, **kwargs: None for text_encoder in text_encoders: text_encoder.to(self.device_torch, dtype=dtype) text_encoder.requires_grad_(False) text_encoder.eval() text_encoder = text_encoders elif self.model_config.is_pixart: te_kwargs = {} # handle quantization of TE te_is_quantized = False if self.model_config.text_encoder_bits == 8: te_kwargs['load_in_8bit'] = True te_kwargs['device_map'] = "auto" te_is_quantized = True elif self.model_config.text_encoder_bits == 4: te_kwargs['load_in_4bit'] = True te_kwargs['device_map'] = "auto" te_is_quantized = True main_model_path = "PixArt-alpha/PixArt-XL-2-1024-MS" if self.model_config.is_pixart_sigma: main_model_path = "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers" main_model_path = model_path # load the TE in 8bit mode text_encoder = T5EncoderModel.from_pretrained( main_model_path, subfolder="text_encoder", torch_dtype=self.torch_dtype, **te_kwargs ) # load the transformer subfolder = "transformer" # check if it is just the unet if os.path.exists(model_path) and not os.path.exists(os.path.join(model_path, subfolder)): subfolder = None if te_is_quantized: # replace the to function with a no-op since it throws an error instead of a warning text_encoder.to = lambda *args, **kwargs: None text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype) if self.model_config.is_pixart_sigma: # load the transformer only from the save transformer = Transformer2DModel.from_pretrained( model_path if self.model_config.unet_path is None else self.model_config.unet_path, torch_dtype=self.torch_dtype, subfolder='transformer' ) pipe: PixArtSigmaPipeline = PixArtSigmaPipeline.from_pretrained( main_model_path, transformer=transformer, text_encoder=text_encoder, dtype=dtype, device=self.device_torch, **load_args ) else: # load the transformer only from the save transformer = Transformer2DModel.from_pretrained(model_path, torch_dtype=self.torch_dtype, subfolder=subfolder) pipe: PixArtAlphaPipeline = PixArtAlphaPipeline.from_pretrained( main_model_path, transformer=transformer, text_encoder=text_encoder, dtype=dtype, device=self.device_torch, **load_args ).to(self.device_torch) if self.model_config.unet_sample_size is not None: pipe.transformer.config.sample_size = self.model_config.unet_sample_size pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype) flush() # text_encoder = pipe.text_encoder # text_encoder.to(self.device_torch, dtype=dtype) text_encoder.requires_grad_(False) text_encoder.eval() pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype) tokenizer = pipe.tokenizer pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype) if self.noise_scheduler is None: self.noise_scheduler = pipe.scheduler elif self.model_config.is_auraflow: te_kwargs = {} # handle quantization of TE te_is_quantized = False if self.model_config.text_encoder_bits == 8: te_kwargs['load_in_8bit'] = True te_kwargs['device_map'] = "auto" te_is_quantized = True elif self.model_config.text_encoder_bits == 4: te_kwargs['load_in_4bit'] = True te_kwargs['device_map'] = "auto" te_is_quantized = True main_model_path = model_path # load the TE in 8bit mode text_encoder = UMT5EncoderModel.from_pretrained( main_model_path, subfolder="text_encoder", torch_dtype=self.torch_dtype, **te_kwargs ) # load the transformer subfolder = "transformer" # check if it is just the unet if os.path.exists(model_path) and not os.path.exists(os.path.join(model_path, subfolder)): subfolder = None if te_is_quantized: # replace the to function with a no-op since it throws an error instead of a warning text_encoder.to = lambda *args, **kwargs: None # load the transformer only from the save transformer = AuraFlowTransformer2DModel.from_pretrained( model_path if self.model_config.unet_path is None else self.model_config.unet_path, torch_dtype=self.torch_dtype, subfolder='transformer' ) pipe: AuraFlowPipeline = AuraFlowPipeline.from_pretrained( main_model_path, transformer=transformer, text_encoder=text_encoder, dtype=dtype, device=self.device_torch, **load_args ) pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype) # patch auraflow so it can handle other aspect ratios # patch_auraflow_pos_embed(pipe.transformer.pos_embed) flush() # text_encoder = pipe.text_encoder # text_encoder.to(self.device_torch, dtype=dtype) text_encoder.requires_grad_(False) text_encoder.eval() pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype) tokenizer = pipe.tokenizer elif self.model_config.is_flux: print("Loading Flux model") base_model_path = "black-forest-labs/FLUX.1-schnell" print("Loading transformer") subfolder = 'transformer' transformer_path = model_path local_files_only = False # check if HF_DATASETS_OFFLINE or TRANSFORMERS_OFFLINE is set if os.path.exists(transformer_path): subfolder = None transformer_path = os.path.join(transformer_path, 'transformer') # check if the path is a full checkpoint. te_folder_path = os.path.join(model_path, 'text_encoder') # if we have the te, this folder is a full checkpoint, use it as the base if os.path.exists(te_folder_path): base_model_path = model_path transformer = FluxTransformer2DModel.from_pretrained( transformer_path, subfolder=subfolder, torch_dtype=dtype, # low_cpu_mem_usage=False, # device_map=None ) if not self.low_vram: # for low v ram, we leave it on the cpu. Quantizes slower, but allows training on primary gpu transformer.to(torch.device(self.quantize_device), dtype=dtype) flush() if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None: if self.model_config.inference_lora_path is not None and self.model_config.assistant_lora_path is not None: raise ValueError("Cannot load both assistant lora and inference lora at the same time") if self.model_config.lora_path: raise ValueError("Cannot load both assistant lora and lora at the same time") if not self.is_flux: raise ValueError("Assistant/ inference lora is only supported for flux models currently") load_lora_path = self.model_config.inference_lora_path if load_lora_path is None: load_lora_path = self.model_config.assistant_lora_path if os.path.isdir(load_lora_path): load_lora_path = os.path.join( load_lora_path, "pytorch_lora_weights.safetensors" ) elif not os.path.exists(load_lora_path): print(f"Grabbing lora from the hub: {load_lora_path}") new_lora_path = hf_hub_download( load_lora_path, filename="pytorch_lora_weights.safetensors" ) # replace the path load_lora_path = new_lora_path if self.model_config.inference_lora_path is not None: self.model_config.inference_lora_path = new_lora_path if self.model_config.assistant_lora_path is not None: self.model_config.assistant_lora_path = new_lora_path if self.model_config.assistant_lora_path is not None: # for flux, we assume it is flux schnell. We cannot merge in the assistant lora and unmerge it on # quantized weights so it had to process unmerged (slow). Since schnell samples in just 4 steps # it is better to merge it in now, and sample slowly later, otherwise training is slowed in half # so we will merge in now and sample with -1 weight later self.invert_assistant_lora = True # trigger it to get merged in self.model_config.lora_path = self.model_config.assistant_lora_path if self.model_config.lora_path is not None: print("Fusing in LoRA") # need the pipe for peft pipe: FluxPipeline = FluxPipeline( scheduler=None, text_encoder=None, tokenizer=None, text_encoder_2=None, tokenizer_2=None, vae=None, transformer=transformer, ) if self.low_vram: # we cannot fuse the loras all at once without ooming in lowvram mode, so we have to do it in parts # we can do it on the cpu but it takes about 5-10 mins vs seconds on the gpu # we are going to separate it into the two transformer blocks one at a time lora_state_dict = load_file(self.model_config.lora_path) single_transformer_lora = {} single_block_key = "transformer.single_transformer_blocks." double_transformer_lora = {} double_block_key = "transformer.transformer_blocks." for key, value in lora_state_dict.items(): if single_block_key in key: single_transformer_lora[key] = value elif double_block_key in key: double_transformer_lora[key] = value else: raise ValueError(f"Unknown lora key: {key}. Cannot load this lora in low vram mode") # double blocks transformer.transformer_blocks = transformer.transformer_blocks.to( torch.device(self.quantize_device), dtype=dtype ) pipe.load_lora_weights(double_transformer_lora, adapter_name=f"lora1_double") pipe.fuse_lora() pipe.unload_lora_weights() transformer.transformer_blocks = transformer.transformer_blocks.to( 'cpu', dtype=dtype ) # single blocks transformer.single_transformer_blocks = transformer.single_transformer_blocks.to( torch.device(self.quantize_device), dtype=dtype ) pipe.load_lora_weights(single_transformer_lora, adapter_name=f"lora1_single") pipe.fuse_lora() pipe.unload_lora_weights() transformer.single_transformer_blocks = transformer.single_transformer_blocks.to( 'cpu', dtype=dtype ) # cleanup del single_transformer_lora del double_transformer_lora del lora_state_dict flush() else: # need the pipe to do this unfortunately for now # we have to fuse in the weights before quantizing pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1") pipe.fuse_lora() # unfortunately, not an easier way with peft pipe.unload_lora_weights() flush() if self.model_config.quantize: quantization_type = qfloat8 print("Quantizing transformer") quantize(transformer, weights=quantization_type) freeze(transformer) transformer.to(self.device_torch) else: transformer.to(self.device_torch, dtype=dtype) flush() scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler") print("Loading vae") vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype) flush() print("Loading t5") tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype) text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2", torch_dtype=dtype) text_encoder_2.to(self.device_torch, dtype=dtype) flush() print("Quantizing T5") quantize(text_encoder_2, weights=qfloat8) freeze(text_encoder_2) flush() print("Loading clip") text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype) tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype) text_encoder.to(self.device_torch, dtype=dtype) print("making pipe") pipe: FluxPipeline = FluxPipeline( scheduler=scheduler, text_encoder=text_encoder, tokenizer=tokenizer, text_encoder_2=None, tokenizer_2=tokenizer_2, vae=vae, transformer=None, ) pipe.text_encoder_2 = text_encoder_2 pipe.transformer = transformer print("preparing") text_encoder = [pipe.text_encoder, pipe.text_encoder_2] tokenizer = [pipe.tokenizer, pipe.tokenizer_2] pipe.transformer = pipe.transformer.to(self.device_torch) flush() text_encoder[0].to(self.device_torch) text_encoder[0].requires_grad_(False) text_encoder[0].eval() text_encoder[1].to(self.device_torch) text_encoder[1].requires_grad_(False) text_encoder[1].eval() pipe.transformer = pipe.transformer.to(self.device_torch) flush() else: if self.custom_pipeline is not None: pipln = self.custom_pipeline else: pipln = StableDiffusionPipeline if self.model_config.text_encoder_bits < 16: # this is only supported for T5 models for now te_kwargs = {} # handle quantization of TE te_is_quantized = False if self.model_config.text_encoder_bits == 8: te_kwargs['load_in_8bit'] = True te_kwargs['device_map'] = "auto" te_is_quantized = True elif self.model_config.text_encoder_bits == 4: te_kwargs['load_in_4bit'] = True te_kwargs['device_map'] = "auto" te_is_quantized = True text_encoder = T5EncoderModel.from_pretrained( model_path, subfolder="text_encoder", torch_dtype=self.te_torch_dtype, **te_kwargs ) # replace the to function with a no-op since it throws an error instead of a warning text_encoder.to = lambda *args, **kwargs: None load_args['text_encoder'] = text_encoder # see if path exists if not os.path.exists(model_path) or os.path.isdir(model_path): # try to load with default diffusers pipe = pipln.from_pretrained( model_path, dtype=dtype, device=self.device_torch, load_safety_checker=False, requires_safety_checker=False, safety_checker=None, # variant="fp16", trust_remote_code=True, **load_args ) else: pipe = pipln.from_single_file( model_path, dtype=dtype, device=self.device_torch, load_safety_checker=False, requires_safety_checker=False, torch_dtype=self.torch_dtype, safety_checker=None, trust_remote_code=True, **load_args ) flush() pipe.register_to_config(requires_safety_checker=False) text_encoder = pipe.text_encoder text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype) text_encoder.requires_grad_(False) text_encoder.eval() tokenizer = pipe.tokenizer # scheduler doesn't get set sometimes, so we set it here pipe.scheduler = self.noise_scheduler # add hacks to unet to help training # pipe.unet = prepare_unet_for_training(pipe.unet) if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux: # pixart and sd3 dont use a unet self.unet = pipe.transformer else: self.unet: 'UNet2DConditionModel' = pipe.unet self.vae: 'AutoencoderKL' = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype) self.vae.eval() self.vae.requires_grad_(False) VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1) self.vae_scale_factor = VAE_SCALE_FACTOR self.unet.to(self.device_torch, dtype=dtype) self.unet.requires_grad_(False) self.unet.eval() # load any loras we have if self.model_config.lora_path is not None and not self.is_flux: pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1") pipe.fuse_lora() # unfortunately, not an easier way with peft pipe.unload_lora_weights() self.tokenizer = tokenizer self.text_encoder = text_encoder self.pipeline = pipe self.load_refiner() self.is_loaded = True if self.model_config.assistant_lora_path is not None: print("Loading assistant lora") self.assistant_lora: 'LoRASpecialNetwork' = load_assistant_lora_from_path( self.model_config.assistant_lora_path, self) if self.invert_assistant_lora: # invert and disable during training self.assistant_lora.multiplier = -1.0 self.assistant_lora.is_active = False if self.model_config.inference_lora_path is not None: print("Loading inference lora") self.assistant_lora: 'LoRASpecialNetwork' = load_assistant_lora_from_path( self.model_config.inference_lora_path, self) # disable during training self.assistant_lora.is_active = False if self.is_pixart and self.vae_scale_factor == 16: # TODO make our own pipeline? # we generate an image 2x larger, so we need to copy the sizes from larger ones down # ASPECT_RATIO_1024_BIN, ASPECT_RATIO_512_BIN, ASPECT_RATIO_2048_BIN, ASPECT_RATIO_256_BIN for key in ASPECT_RATIO_256_BIN.keys(): ASPECT_RATIO_256_BIN[key] = [ASPECT_RATIO_256_BIN[key][0] * 2, ASPECT_RATIO_256_BIN[key][1] * 2] for key in ASPECT_RATIO_512_BIN.keys(): ASPECT_RATIO_512_BIN[key] = [ASPECT_RATIO_512_BIN[key][0] * 2, ASPECT_RATIO_512_BIN[key][1] * 2] for key in ASPECT_RATIO_1024_BIN.keys(): ASPECT_RATIO_1024_BIN[key] = [ASPECT_RATIO_1024_BIN[key][0] * 2, ASPECT_RATIO_1024_BIN[key][1] * 2] for key in ASPECT_RATIO_2048_BIN.keys(): ASPECT_RATIO_2048_BIN[key] = [ASPECT_RATIO_2048_BIN[key][0] * 2, ASPECT_RATIO_2048_BIN[key][1] * 2] def te_train(self): if isinstance(self.text_encoder, list): for te in self.text_encoder: te.train() else: self.text_encoder.train() def te_eval(self): if isinstance(self.text_encoder, list): for te in self.text_encoder: te.eval() else: self.text_encoder.eval() def load_refiner(self): # for now, we are just going to rely on the TE from the base model # which is TE2 for SDXL and TE for SD (no refiner currently) # and completely ignore a TE that may or may not be packaged with the refiner if self.model_config.refiner_name_or_path is not None: refiner_config_path = os.path.join(ORIG_CONFIGS_ROOT, 'sd_xl_refiner.yaml') # load the refiner model dtype = get_torch_dtype(self.dtype) model_path = self.model_config.refiner_name_or_path if not os.path.exists(model_path) or os.path.isdir(model_path): # TODO only load unet?? refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained( model_path, dtype=dtype, device=self.device_torch, # variant="fp16", use_safetensors=True, ).to(self.device_torch) else: refiner = StableDiffusionXLImg2ImgPipeline.from_single_file( model_path, dtype=dtype, device=self.device_torch, torch_dtype=self.torch_dtype, original_config_file=refiner_config_path, ).to(self.device_torch) self.refiner_unet = refiner.unet del refiner flush() @torch.no_grad() def generate_images( self, image_configs: List[GenerateImageConfig], sampler=None, pipeline: Union[None, StableDiffusionPipeline, StableDiffusionXLPipeline] = None, ): merge_multiplier = 1.0 flush() # if using assistant, unfuse it if self.model_config.assistant_lora_path is not None: print("Unloading assistant lora") if self.invert_assistant_lora: self.assistant_lora.is_active = True # move weights on to the device self.assistant_lora.force_to(self.device_torch, self.torch_dtype) else: self.assistant_lora.is_active = False if self.model_config.inference_lora_path is not None: print("Loading inference lora") self.assistant_lora.is_active = True # move weights on to the device self.assistant_lora.force_to(self.device_torch, self.torch_dtype) if self.network is not None: self.network.eval() network = self.network # check if we have the same network weight for all samples. If we do, we can merge in th # the network to drastically speed up inference unique_network_weights = set([x.network_multiplier for x in image_configs]) if len(unique_network_weights) == 1 and self.network.can_merge_in: can_merge_in = True merge_multiplier = unique_network_weights.pop() network.merge_in(merge_weight=merge_multiplier) else: network = BlankNetwork() self.save_device_state() self.set_device_state_preset('generate') # save current seed state for training rng_state = torch.get_rng_state() cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None if pipeline is None: noise_scheduler = self.noise_scheduler if sampler is not None: if sampler.startswith("sample_"): # sample_dpmpp_2m # using ksampler noise_scheduler = get_sampler( 'lms', { "prediction_type": self.prediction_type, }) else: noise_scheduler = get_sampler( sampler, { "prediction_type": self.prediction_type, }, 'sd' if not self.is_pixart else 'pixart' ) try: noise_scheduler = noise_scheduler.to(self.device_torch, self.torch_dtype) except: pass if sampler.startswith("sample_") and self.is_xl: # using kdiffusion Pipe = StableDiffusionKDiffusionXLPipeline elif self.is_xl: Pipe = StableDiffusionXLPipeline elif self.is_v3: Pipe = StableDiffusion3Pipeline else: Pipe = StableDiffusionPipeline extra_args = {} if self.adapter is not None: if isinstance(self.adapter, T2IAdapter): if self.is_xl: Pipe = StableDiffusionXLAdapterPipeline else: Pipe = StableDiffusionAdapterPipeline extra_args['adapter'] = self.adapter elif isinstance(self.adapter, ControlNetModel): if self.is_xl: Pipe = StableDiffusionXLControlNetPipeline else: Pipe = StableDiffusionControlNetPipeline extra_args['controlnet'] = self.adapter elif isinstance(self.adapter, ReferenceAdapter): # pass the noise scheduler to the adapter self.adapter.noise_scheduler = noise_scheduler else: if self.is_xl: extra_args['add_watermarker'] = False # TODO add clip skip if self.is_xl: pipeline = Pipe( vae=self.vae, unet=self.unet, text_encoder=self.text_encoder[0], text_encoder_2=self.text_encoder[1], tokenizer=self.tokenizer[0], tokenizer_2=self.tokenizer[1], scheduler=noise_scheduler, **extra_args ).to(self.device_torch) pipeline.watermark = None elif self.is_flux: if self.model_config.use_flux_cfg: pipeline = FluxWithCFGPipeline( vae=self.vae, transformer=self.unet, text_encoder=self.text_encoder[0], text_encoder_2=self.text_encoder[1], tokenizer=self.tokenizer[0], tokenizer_2=self.tokenizer[1], scheduler=noise_scheduler, **extra_args ) else: pipeline = FluxPipeline( vae=self.vae, transformer=self.unet, text_encoder=self.text_encoder[0], text_encoder_2=self.text_encoder[1], tokenizer=self.tokenizer[0], tokenizer_2=self.tokenizer[1], scheduler=noise_scheduler, **extra_args ) pipeline.watermark = None elif self.is_v3: pipeline = Pipe( vae=self.vae, transformer=self.unet, text_encoder=self.text_encoder[0], text_encoder_2=self.text_encoder[1], text_encoder_3=self.text_encoder[2], tokenizer=self.tokenizer[0], tokenizer_2=self.tokenizer[1], tokenizer_3=self.tokenizer[2], scheduler=noise_scheduler, **extra_args ) elif self.is_pixart: pipeline = PixArtSigmaPipeline( vae=self.vae, transformer=self.unet, text_encoder=self.text_encoder, tokenizer=self.tokenizer, scheduler=noise_scheduler, **extra_args ) elif self.is_auraflow: pipeline = AuraFlowPipeline( vae=self.vae, transformer=self.unet, text_encoder=self.text_encoder, tokenizer=self.tokenizer, scheduler=noise_scheduler, **extra_args ) else: pipeline = Pipe( vae=self.vae, unet=self.unet, text_encoder=self.text_encoder, tokenizer=self.tokenizer, scheduler=noise_scheduler, safety_checker=None, feature_extractor=None, requires_safety_checker=False, **extra_args ) flush() # disable progress bar pipeline.set_progress_bar_config(disable=True) if sampler.startswith("sample_"): pipeline.set_scheduler(sampler) refiner_pipeline = None if self.refiner_unet: # build refiner pipeline refiner_pipeline = StableDiffusionXLImg2ImgPipeline( vae=pipeline.vae, unet=self.refiner_unet, text_encoder=None, text_encoder_2=pipeline.text_encoder_2, tokenizer=None, tokenizer_2=pipeline.tokenizer_2, scheduler=pipeline.scheduler, add_watermarker=False, requires_aesthetics_score=True, ).to(self.device_torch) # refiner_pipeline.register_to_config(requires_aesthetics_score=False) refiner_pipeline.watermark = None refiner_pipeline.set_progress_bar_config(disable=True) flush() start_multiplier = 1.0 if self.network is not None: start_multiplier = self.network.multiplier # pipeline.to(self.device_torch) with network: with torch.no_grad(): if self.network is not None: assert self.network.is_active for i in tqdm(range(len(image_configs)), desc=f"Generating Images", leave=False): gen_config = image_configs[i] extra = {} validation_image = None if self.adapter is not None and gen_config.adapter_image_path is not None: validation_image = Image.open(gen_config.adapter_image_path).convert("RGB") if isinstance(self.adapter, T2IAdapter): # not sure why this is double?? validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2)) extra['image'] = validation_image extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale if isinstance(self.adapter, ControlNetModel): validation_image = validation_image.resize((gen_config.width, gen_config.height)) extra['image'] = validation_image extra['controlnet_conditioning_scale'] = gen_config.adapter_conditioning_scale if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter): transform = transforms.Compose([ transforms.ToTensor(), ]) validation_image = transform(validation_image) if isinstance(self.adapter, CustomAdapter): # todo allow loading multiple transform = transforms.Compose([ transforms.ToTensor(), ]) validation_image = transform(validation_image) self.adapter.num_images = 1 if isinstance(self.adapter, ReferenceAdapter): # need -1 to 1 validation_image = transforms.ToTensor()(validation_image) validation_image = validation_image * 2.0 - 1.0 validation_image = validation_image.unsqueeze(0) self.adapter.set_reference_images(validation_image) if self.network is not None: self.network.multiplier = gen_config.network_multiplier torch.manual_seed(gen_config.seed) torch.cuda.manual_seed(gen_config.seed) if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter) \ and gen_config.adapter_image_path is not None: # run through the adapter to saturate the embeds conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image) self.adapter(conditional_clip_embeds) if self.adapter is not None and isinstance(self.adapter, CustomAdapter): # handle condition the prompts gen_config.prompt = self.adapter.condition_prompt( gen_config.prompt, is_unconditional=False, ) gen_config.prompt_2 = gen_config.prompt gen_config.negative_prompt = self.adapter.condition_prompt( gen_config.negative_prompt, is_unconditional=True, ) gen_config.negative_prompt_2 = gen_config.negative_prompt if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and validation_image is not None: self.adapter.trigger_pre_te( tensors_0_1=validation_image, is_training=False, has_been_preprocessed=False, quad_count=4 ) # encode the prompt ourselves so we can do fun stuff with embeddings if isinstance(self.adapter, CustomAdapter): self.adapter.is_unconditional_run = False conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True) if isinstance(self.adapter, CustomAdapter): self.adapter.is_unconditional_run = True unconditional_embeds = self.encode_prompt( gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True ) if isinstance(self.adapter, CustomAdapter): self.adapter.is_unconditional_run = False # allow any manipulations to take place to embeddings gen_config.post_process_embeddings( conditional_embeds, unconditional_embeds, ) if self.adapter is not None and isinstance(self.adapter, IPAdapter) \ and gen_config.adapter_image_path is not None: # apply the image projection conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image) unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image, True) conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds, is_unconditional=False) unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds, is_unconditional=True) if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and validation_image is not None: conditional_embeds = self.adapter.condition_encoded_embeds( tensors_0_1=validation_image, prompt_embeds=conditional_embeds, is_training=False, has_been_preprocessed=False, is_generating_samples=True, ) unconditional_embeds = self.adapter.condition_encoded_embeds( tensors_0_1=validation_image, prompt_embeds=unconditional_embeds, is_training=False, has_been_preprocessed=False, is_unconditional=True, is_generating_samples=True, ) if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and len( gen_config.extra_values) > 0: extra_values = torch.tensor([gen_config.extra_values], device=self.device_torch, dtype=self.torch_dtype) # apply extra values to the embeddings self.adapter.add_extra_values(extra_values, is_unconditional=False) self.adapter.add_extra_values(torch.zeros_like(extra_values), is_unconditional=True) pass # todo remove, for debugging if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0: # if we have a refiner loaded, set the denoising end at the refiner start extra['denoising_end'] = gen_config.refiner_start_at extra['output_type'] = 'latent' if not self.is_xl: raise ValueError("Refiner is only supported for XL models") conditional_embeds = conditional_embeds.to(self.device_torch, dtype=self.unet.dtype) unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=self.unet.dtype) if self.is_xl: # fix guidance rescale for sdxl # was trained on 0.7 (I believe) grs = gen_config.guidance_rescale # if grs is None or grs < 0.00001: # grs = 0.7 # grs = 0.0 if sampler.startswith("sample_"): extra['use_karras_sigmas'] = True extra = { **extra, **gen_config.extra_kwargs, } img = pipeline( # prompt=gen_config.prompt, # prompt_2=gen_config.prompt_2, prompt_embeds=conditional_embeds.text_embeds, pooled_prompt_embeds=conditional_embeds.pooled_embeds, negative_prompt_embeds=unconditional_embeds.text_embeds, negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds, # negative_prompt=gen_config.negative_prompt, # negative_prompt_2=gen_config.negative_prompt_2, height=gen_config.height, width=gen_config.width, num_inference_steps=gen_config.num_inference_steps, guidance_scale=gen_config.guidance_scale, guidance_rescale=grs, latents=gen_config.latents, **extra ).images[0] elif self.is_v3: img = pipeline( prompt_embeds=conditional_embeds.text_embeds, pooled_prompt_embeds=conditional_embeds.pooled_embeds, negative_prompt_embeds=unconditional_embeds.text_embeds, negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds, height=gen_config.height, width=gen_config.width, num_inference_steps=gen_config.num_inference_steps, guidance_scale=gen_config.guidance_scale, latents=gen_config.latents, **extra ).images[0] elif self.is_flux: if self.model_config.use_flux_cfg: img = pipeline( prompt_embeds=conditional_embeds.text_embeds, pooled_prompt_embeds=conditional_embeds.pooled_embeds, negative_prompt_embeds=unconditional_embeds.text_embeds, negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds, height=gen_config.height, width=gen_config.width, num_inference_steps=gen_config.num_inference_steps, guidance_scale=gen_config.guidance_scale, latents=gen_config.latents, **extra ).images[0] else: img = pipeline( prompt_embeds=conditional_embeds.text_embeds, pooled_prompt_embeds=conditional_embeds.pooled_embeds, # negative_prompt_embeds=unconditional_embeds.text_embeds, # negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds, height=gen_config.height, width=gen_config.width, num_inference_steps=gen_config.num_inference_steps, guidance_scale=gen_config.guidance_scale, latents=gen_config.latents, **extra ).images[0] elif self.is_pixart: # needs attention masks for some reason img = pipeline( prompt=None, prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype), prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch, dtype=self.unet.dtype), negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype), negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch, dtype=self.unet.dtype), negative_prompt=None, # negative_prompt=gen_config.negative_prompt, height=gen_config.height, width=gen_config.width, num_inference_steps=gen_config.num_inference_steps, guidance_scale=gen_config.guidance_scale, latents=gen_config.latents, **extra ).images[0] elif self.is_auraflow: pipeline: AuraFlowPipeline = pipeline img = pipeline( prompt=None, prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype), prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch, dtype=self.unet.dtype), negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype), negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch, dtype=self.unet.dtype), negative_prompt=None, # negative_prompt=gen_config.negative_prompt, height=gen_config.height, width=gen_config.width, num_inference_steps=gen_config.num_inference_steps, guidance_scale=gen_config.guidance_scale, latents=gen_config.latents, **extra ).images[0] else: img = pipeline( # prompt=gen_config.prompt, prompt_embeds=conditional_embeds.text_embeds, negative_prompt_embeds=unconditional_embeds.text_embeds, # negative_prompt=gen_config.negative_prompt, height=gen_config.height, width=gen_config.width, num_inference_steps=gen_config.num_inference_steps, guidance_scale=gen_config.guidance_scale, latents=gen_config.latents, **extra ).images[0] if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0: # slide off just the last 1280 on the last dim as refiner does not use first text encoder # todo, should we just use the Text encoder for the refiner? Fine tuned versions will differ refiner_text_embeds = conditional_embeds.text_embeds[:, :, -1280:] refiner_unconditional_text_embeds = unconditional_embeds.text_embeds[:, :, -1280:] # run through refiner img = refiner_pipeline( # prompt=gen_config.prompt, # prompt_2=gen_config.prompt_2, # slice these as it does not use both text encoders # height=gen_config.height, # width=gen_config.width, prompt_embeds=refiner_text_embeds, pooled_prompt_embeds=conditional_embeds.pooled_embeds, negative_prompt_embeds=refiner_unconditional_text_embeds, negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds, num_inference_steps=gen_config.num_inference_steps, guidance_scale=gen_config.guidance_scale, guidance_rescale=grs, denoising_start=gen_config.refiner_start_at, denoising_end=gen_config.num_inference_steps, image=img.unsqueeze(0) ).images[0] gen_config.save_image(img, i) gen_config.log_image(img, i) if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter): self.adapter.clear_memory() # clear pipeline and cache to reduce vram usage del pipeline if refiner_pipeline is not None: del refiner_pipeline torch.cuda.empty_cache() # restore training state torch.set_rng_state(rng_state) if cuda_rng_state is not None: torch.cuda.set_rng_state(cuda_rng_state) self.restore_device_state() if self.network is not None: self.network.train() self.network.multiplier = start_multiplier self.unet.to(self.device_torch, dtype=self.torch_dtype) if network.is_merged_in: network.merge_out(merge_multiplier) # self.tokenizer.to(original_device_dict['tokenizer']) # refuse loras if self.model_config.assistant_lora_path is not None: print("Loading assistant lora") if self.invert_assistant_lora: self.assistant_lora.is_active = False # move weights off the device self.assistant_lora.force_to('cpu', self.torch_dtype) else: self.assistant_lora.is_active = True if self.model_config.inference_lora_path is not None: print("Unloading inference lora") self.assistant_lora.is_active = False # move weights off the device self.assistant_lora.force_to('cpu', self.torch_dtype) flush() def get_latent_noise( self, height=None, width=None, pixel_height=None, pixel_width=None, batch_size=1, noise_offset=0.0, ): VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1) if height is None and pixel_height is None: raise ValueError("height or pixel_height must be specified") if width is None and pixel_width is None: raise ValueError("width or pixel_width must be specified") if height is None: height = pixel_height // VAE_SCALE_FACTOR if width is None: width = pixel_width // VAE_SCALE_FACTOR num_channels = self.unet.config['in_channels'] if self.is_flux: # has 64 channels in for some reason num_channels = 16 noise = torch.randn( ( batch_size, num_channels, height, width, ), device=self.unet.device, ) noise = apply_noise_offset(noise, noise_offset) return noise def get_time_ids_from_latents(self, latents: torch.Tensor, requires_aesthetic_score=False): VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1) if self.is_xl: bs, ch, h, w = list(latents.shape) height = h * VAE_SCALE_FACTOR width = w * VAE_SCALE_FACTOR dtype = latents.dtype # just do it without any cropping nonsense target_size = (height, width) original_size = (height, width) crops_coords_top_left = (0, 0) if requires_aesthetic_score: # refiner # https://huggingface.co/papers/2307.01952 aesthetic_score = 6.0 # simulate one add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) else: add_time_ids = list(original_size + crops_coords_top_left + target_size) add_time_ids = torch.tensor([add_time_ids]) add_time_ids = add_time_ids.to(latents.device, dtype=dtype) batch_time_ids = torch.cat( [add_time_ids for _ in range(bs)] ) return batch_time_ids else: return None def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor ) -> torch.FloatTensor: original_samples_chunks = torch.chunk(original_samples, original_samples.shape[0], dim=0) noise_chunks = torch.chunk(noise, noise.shape[0], dim=0) timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0) if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks): timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks) noisy_latents_chunks = [] for idx in range(original_samples.shape[0]): noisy_latents = self.noise_scheduler.add_noise(original_samples_chunks[idx], noise_chunks[idx], timesteps_chunks[idx]) noisy_latents_chunks.append(noisy_latents) noisy_latents = torch.cat(noisy_latents_chunks, dim=0) return noisy_latents def predict_noise( self, latents: torch.Tensor, text_embeddings: Union[PromptEmbeds, None] = None, timestep: Union[int, torch.Tensor] = 1, guidance_scale=7.5, guidance_rescale=0, add_time_ids=None, conditional_embeddings: Union[PromptEmbeds, None] = None, unconditional_embeddings: Union[PromptEmbeds, None] = None, is_input_scaled=False, detach_unconditional=False, rescale_cfg=None, return_conditional_pred=False, guidance_embedding_scale=1.0, **kwargs, ): conditional_pred = None # get the embeddings if text_embeddings is None and conditional_embeddings is None: raise ValueError("Either text_embeddings or conditional_embeddings must be specified") if text_embeddings is None and unconditional_embeddings is not None: text_embeddings = concat_prompt_embeds([ unconditional_embeddings, # negative embedding conditional_embeddings, # positive embedding ]) elif text_embeddings is None and conditional_embeddings is not None: # not doing cfg text_embeddings = conditional_embeddings # CFG is comparing neg and positive, if we have concatenated embeddings # then we are doing it, otherwise we are not and takes half the time. do_classifier_free_guidance = True # check if batch size of embeddings matches batch size of latents if latents.shape[0] == text_embeddings.text_embeds.shape[0]: do_classifier_free_guidance = False elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]: raise ValueError("Batch size of latents must be the same or half the batch size of text embeddings") latents = latents.to(self.device_torch) text_embeddings = text_embeddings.to(self.device_torch) timestep = timestep.to(self.device_torch) # if timestep is zero dim, unsqueeze it if len(timestep.shape) == 0: timestep = timestep.unsqueeze(0) # if we only have 1 timestep, we can just use the same timestep for all if timestep.shape[0] == 1 and latents.shape[0] > 1: # check if it is rank 1 or 2 if len(timestep.shape) == 1: timestep = timestep.repeat(latents.shape[0]) else: timestep = timestep.repeat(latents.shape[0], 0) # handle t2i adapters if 'down_intrablock_additional_residuals' in kwargs: # go through each item and concat if doing cfg and it doesnt have the same shape for idx, item in enumerate(kwargs['down_intrablock_additional_residuals']): if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]: kwargs['down_intrablock_additional_residuals'][idx] = torch.cat([item] * 2, dim=0) # handle controlnet if 'down_block_additional_residuals' in kwargs and 'mid_block_additional_residual' in kwargs: # go through each item and concat if doing cfg and it doesnt have the same shape for idx, item in enumerate(kwargs['down_block_additional_residuals']): if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]: kwargs['down_block_additional_residuals'][idx] = torch.cat([item] * 2, dim=0) for idx, item in enumerate(kwargs['mid_block_additional_residual']): if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]: kwargs['mid_block_additional_residual'][idx] = torch.cat([item] * 2, dim=0) def scale_model_input(model_input, timestep_tensor): if is_input_scaled: return model_input mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0) timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0) out_chunks = [] # unsqueeze if timestep is zero dim for idx in range(model_input.shape[0]): # if scheduler has step_index if hasattr(self.noise_scheduler, '_step_index'): self.noise_scheduler._step_index = None out_chunks.append( self.noise_scheduler.scale_model_input(mi_chunks[idx], timestep_chunks[idx]) ) return torch.cat(out_chunks, dim=0) if self.is_xl: with torch.no_grad(): # 16, 6 for bs of 4 if add_time_ids is None: add_time_ids = self.get_time_ids_from_latents(latents) if do_classifier_free_guidance: # todo check this with larget batches add_time_ids = torch.cat([add_time_ids] * 2) if do_classifier_free_guidance: latent_model_input = torch.cat([latents] * 2) timestep = torch.cat([timestep] * 2) else: latent_model_input = latents latent_model_input = scale_model_input(latent_model_input, timestep) added_cond_kwargs = { # todo can we zero here the second text encoder? or match a blank string? "text_embeds": text_embeddings.pooled_embeds, "time_ids": add_time_ids, } if self.model_config.refiner_name_or_path is not None: # we have the refiner on the second half of everything. Do Both if do_classifier_free_guidance: raise ValueError("Refiner is not supported with classifier free guidance") if self.unet.training: input_chunks = torch.chunk(latent_model_input, 2, dim=0) timestep_chunks = torch.chunk(timestep, 2, dim=0) added_cond_kwargs_chunked = { "text_embeds": torch.chunk(text_embeddings.pooled_embeds, 2, dim=0), "time_ids": torch.chunk(add_time_ids, 2, dim=0), } text_embeds_chunks = torch.chunk(text_embeddings.text_embeds, 2, dim=0) # predict the noise residual base_pred = self.unet( input_chunks[0], timestep_chunks[0], encoder_hidden_states=text_embeds_chunks[0], added_cond_kwargs={ "text_embeds": added_cond_kwargs_chunked['text_embeds'][0], "time_ids": added_cond_kwargs_chunked['time_ids'][0], }, **kwargs, ).sample refiner_pred = self.refiner_unet( input_chunks[1], timestep_chunks[1], encoder_hidden_states=text_embeds_chunks[1][:, :, -1280:], # just use the first second text encoder added_cond_kwargs={ "text_embeds": added_cond_kwargs_chunked['text_embeds'][1], # "time_ids": added_cond_kwargs_chunked['time_ids'][1], "time_ids": self.get_time_ids_from_latents(input_chunks[1], requires_aesthetic_score=True), }, **kwargs, ).sample noise_pred = torch.cat([base_pred, refiner_pred], dim=0) else: noise_pred = self.refiner_unet( latent_model_input, timestep, encoder_hidden_states=text_embeddings.text_embeds[:, :, -1280:], # just use the first second text encoder added_cond_kwargs={ "text_embeds": text_embeddings.pooled_embeds, "time_ids": self.get_time_ids_from_latents(latent_model_input, requires_aesthetic_score=True), }, **kwargs, ).sample else: # predict the noise residual noise_pred = self.unet( latent_model_input.to(self.device_torch, self.torch_dtype), timestep, encoder_hidden_states=text_embeddings.text_embeds, added_cond_kwargs=added_cond_kwargs, **kwargs, ).sample conditional_pred = noise_pred if do_classifier_free_guidance: # perform guidance noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) conditional_pred = noise_pred_text noise_pred = noise_pred_uncond + guidance_scale * ( noise_pred_text - noise_pred_uncond ) # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775 if guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) else: with torch.no_grad(): if do_classifier_free_guidance: # if we are doing classifier free guidance, need to double up latent_model_input = torch.cat([latents] * 2, dim=0) timestep = torch.cat([timestep] * 2) else: latent_model_input = latents latent_model_input = scale_model_input(latent_model_input, timestep) # check if we need to concat timesteps if isinstance(timestep, torch.Tensor) and len(timestep.shape) > 1: ts_bs = timestep.shape[0] if ts_bs != latent_model_input.shape[0]: if ts_bs == 1: timestep = torch.cat([timestep] * latent_model_input.shape[0]) elif ts_bs * 2 == latent_model_input.shape[0]: timestep = torch.cat([timestep] * 2, dim=0) else: raise ValueError( f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}") # predict the noise residual if self.is_pixart: VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1) batch_size, ch, h, w = list(latents.shape) height = h * VAE_SCALE_FACTOR width = w * VAE_SCALE_FACTOR if self.pipeline.transformer.config.sample_size == 256: aspect_ratio_bin = ASPECT_RATIO_2048_BIN elif self.pipeline.transformer.config.sample_size == 128: aspect_ratio_bin = ASPECT_RATIO_1024_BIN elif self.pipeline.transformer.config.sample_size == 64: aspect_ratio_bin = ASPECT_RATIO_512_BIN elif self.pipeline.transformer.config.sample_size == 32: aspect_ratio_bin = ASPECT_RATIO_256_BIN else: raise ValueError(f"Invalid sample size: {self.pipeline.transformer.config.sample_size}") orig_height, orig_width = height, width height, width = self.pipeline.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) added_cond_kwargs = {"resolution": None, "aspect_ratio": None} if self.unet.config.sample_size == 128 or ( self.vae_scale_factor == 16 and self.unet.config.sample_size == 64): resolution = torch.tensor([height, width]).repeat(batch_size, 1) aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size, 1) resolution = resolution.to(dtype=text_embeddings.text_embeds.dtype, device=self.device_torch) aspect_ratio = aspect_ratio.to(dtype=text_embeddings.text_embeds.dtype, device=self.device_torch) if do_classifier_free_guidance: resolution = torch.cat([resolution, resolution], dim=0) aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0) added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio} noise_pred = self.unet( latent_model_input.to(self.device_torch, self.torch_dtype), encoder_hidden_states=text_embeddings.text_embeds, encoder_attention_mask=text_embeddings.attention_mask, timestep=timestep, added_cond_kwargs=added_cond_kwargs, return_dict=False, **kwargs )[0] # learned sigma if self.unet.config.out_channels // 2 == self.unet.config.in_channels: noise_pred = noise_pred.chunk(2, dim=1)[0] else: noise_pred = noise_pred else: if self.unet.device != self.device_torch: self.unet.to(self.device_torch) if self.unet.dtype != self.torch_dtype: self.unet = self.unet.to(dtype=self.torch_dtype) if self.is_flux: with torch.no_grad(): bs, c, h, w = latent_model_input.shape latent_model_input_packed = rearrange( latent_model_input, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2 ) img_ids = torch.zeros(h // 2, w // 2, 3) img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs).to(self.device_torch) txt_ids = torch.zeros(bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch) # # handle guidance if self.unet.config.guidance_embeds: if isinstance(guidance_scale, list): guidance = torch.tensor(guidance_scale, device=self.device_torch) else: guidance = torch.tensor([guidance_scale], device=self.device_torch) guidance = guidance.expand(latents.shape[0]) else: guidance = None cast_dtype = self.unet.dtype # with torch.amp.autocast(device_type='cuda', dtype=cast_dtype): noise_pred = self.unet( hidden_states=latent_model_input_packed.to(self.device_torch, cast_dtype), # [1, 4096, 64] # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) # todo make sure this doesnt change timestep=timestep / 1000, # timestep is 1000 scale encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, cast_dtype), # [1, 512, 4096] pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, cast_dtype), # [1, 768] txt_ids=txt_ids, # [1, 512, 3] img_ids=img_ids, # [1, 4096, 3] guidance=guidance, return_dict=False, **kwargs, )[0] if isinstance(noise_pred, QTensor): noise_pred = noise_pred.dequantize() noise_pred = rearrange( noise_pred, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=latent_model_input.shape[2] // 2, w=latent_model_input.shape[3] // 2, ph=2, pw=2, c=latent_model_input.shape[1], ) elif self.is_v3: noise_pred = self.unet( hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype), timestep=timestep, encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype), pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, self.torch_dtype), **kwargs, ).sample if isinstance(noise_pred, QTensor): noise_pred = noise_pred.dequantize() elif self.is_auraflow: # aura use timestep value between 0 and 1, with t=1 as noise and t=0 as the image # broadcast to batch dimension in a way that's compatible with ONNX/Core ML t = torch.tensor([timestep / 1000]).expand(latent_model_input.shape[0]) t = t.to(self.device_torch, self.torch_dtype) noise_pred = self.unet( latent_model_input, encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype), timestep=t, return_dict=False, )[0] else: noise_pred = self.unet( latent_model_input.to(self.device_torch, self.torch_dtype), timestep=timestep, encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype), **kwargs, ).sample conditional_pred = noise_pred if do_classifier_free_guidance: # perform guidance noise_pred_uncond, noise_pred_text = noise_pred.chunk(2, dim=0) conditional_pred = noise_pred_text if detach_unconditional: noise_pred_uncond = noise_pred_uncond.detach() noise_pred = noise_pred_uncond + guidance_scale * ( noise_pred_text - noise_pred_uncond ) if rescale_cfg is not None and rescale_cfg != guidance_scale: with torch.no_grad(): # do cfg at the target rescale so we can match it target_pred_mean_std = noise_pred_uncond + rescale_cfg * ( noise_pred_text - noise_pred_uncond ) target_mean = target_pred_mean_std.mean([1, 2, 3], keepdim=True).detach() target_std = target_pred_mean_std.std([1, 2, 3], keepdim=True).detach() pred_mean = noise_pred.mean([1, 2, 3], keepdim=True).detach() pred_std = noise_pred.std([1, 2, 3], keepdim=True).detach() # match the mean and std noise_pred = (noise_pred - pred_mean) / pred_std noise_pred = (noise_pred * target_std) + target_mean # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775 if guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) if return_conditional_pred: return noise_pred, conditional_pred return noise_pred def step_scheduler(self, model_input, latent_input, timestep_tensor, noise_scheduler=None): if noise_scheduler is None: noise_scheduler = self.noise_scheduler # // sometimes they are on the wrong device, no idea why if isinstance(noise_scheduler, DDPMScheduler) or isinstance(noise_scheduler, LCMScheduler): try: noise_scheduler.betas = noise_scheduler.betas.to(self.device_torch) noise_scheduler.alphas = noise_scheduler.alphas.to(self.device_torch) noise_scheduler.alphas_cumprod = noise_scheduler.alphas_cumprod.to(self.device_torch) except Exception as e: pass mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0) latent_chunks = torch.chunk(latent_input, latent_input.shape[0], dim=0) timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0) out_chunks = [] if len(timestep_chunks) == 1 and len(mi_chunks) > 1: # expand timestep to match timestep_chunks = timestep_chunks * len(mi_chunks) for idx in range(model_input.shape[0]): # Reset it so it is unique for the if hasattr(noise_scheduler, '_step_index'): noise_scheduler._step_index = None if hasattr(noise_scheduler, 'is_scale_input_called'): noise_scheduler.is_scale_input_called = True out_chunks.append( noise_scheduler.step(mi_chunks[idx], timestep_chunks[idx], latent_chunks[idx], return_dict=False)[ 0] ) return torch.cat(out_chunks, dim=0) # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746 def diffuse_some_steps( self, latents: torch.FloatTensor, text_embeddings: PromptEmbeds, total_timesteps: int = 1000, start_timesteps=0, guidance_scale=1, add_time_ids=None, bleed_ratio: float = 0.5, bleed_latents: torch.FloatTensor = None, is_input_scaled=False, return_first_prediction=False, **kwargs, ): timesteps_to_run = self.noise_scheduler.timesteps[start_timesteps:total_timesteps] first_prediction = None for timestep in tqdm(timesteps_to_run, leave=False): timestep = timestep.unsqueeze_(0) noise_pred, conditional_pred = self.predict_noise( latents, text_embeddings, timestep, guidance_scale=guidance_scale, add_time_ids=add_time_ids, is_input_scaled=is_input_scaled, return_conditional_pred=True, **kwargs, ) # some schedulers need to run separately, so do that. (euler for example) if return_first_prediction and first_prediction is None: first_prediction = conditional_pred latents = self.step_scheduler(noise_pred, latents, timestep) # if not last step, and bleeding, bleed in some latents if bleed_latents is not None and timestep != self.noise_scheduler.timesteps[-1]: latents = (latents * (1 - bleed_ratio)) + (bleed_latents * bleed_ratio) # only skip first scaling is_input_scaled = False # return latents_steps if return_first_prediction: return latents, first_prediction return latents def encode_prompt( self, prompt, prompt2=None, num_images_per_prompt=1, force_all=False, long_prompts=False, max_length=None, dropout_prob=0.0, ) -> PromptEmbeds: # sd1.5 embeddings are (bs, 77, 768) prompt = prompt # if it is not a list, make it one if not isinstance(prompt, list): prompt = [prompt] if prompt2 is not None and not isinstance(prompt2, list): prompt2 = [prompt2] if self.is_xl: # todo make this a config # 50% chance to use an encoder anyway even if it is disabled # allows the other TE to compensate for the disabled one # use_encoder_1 = self.use_text_encoder_1 or force_all or random.random() > 0.5 # use_encoder_2 = self.use_text_encoder_2 or force_all or random.random() > 0.5 use_encoder_1 = True use_encoder_2 = True return PromptEmbeds( train_tools.encode_prompts_xl( self.tokenizer, self.text_encoder, prompt, prompt2, num_images_per_prompt=num_images_per_prompt, use_text_encoder_1=use_encoder_1, use_text_encoder_2=use_encoder_2, truncate=not long_prompts, max_length=max_length, dropout_prob=dropout_prob, ) ) if self.is_v3: return PromptEmbeds( train_tools.encode_prompts_sd3( self.tokenizer, self.text_encoder, prompt, num_images_per_prompt=num_images_per_prompt, truncate=not long_prompts, max_length=max_length, dropout_prob=dropout_prob, pipeline=self.pipeline, ) ) elif self.is_pixart: embeds, attention_mask = train_tools.encode_prompts_pixart( self.tokenizer, self.text_encoder, prompt, truncate=not long_prompts, max_length=300 if self.model_config.is_pixart_sigma else 120, dropout_prob=dropout_prob ) return PromptEmbeds( embeds, attention_mask=attention_mask, ) elif self.is_auraflow: embeds, attention_mask = train_tools.encode_prompts_auraflow( self.tokenizer, self.text_encoder, prompt, truncate=not long_prompts, max_length=256, dropout_prob=dropout_prob ) return PromptEmbeds( embeds, attention_mask=attention_mask, # not used ) elif self.is_flux: prompt_embeds, pooled_prompt_embeds = train_tools.encode_prompts_flux( self.tokenizer, # list self.text_encoder, # list prompt, truncate=not long_prompts, max_length=512, dropout_prob=dropout_prob, attn_mask=self.model_config.attn_masking ) pe = PromptEmbeds( prompt_embeds ) pe.pooled_embeds = pooled_prompt_embeds return pe elif isinstance(self.text_encoder, T5EncoderModel): embeds, attention_mask = train_tools.encode_prompts_pixart( self.tokenizer, self.text_encoder, prompt, truncate=not long_prompts, max_length=256, dropout_prob=dropout_prob ) # just mask the attention mask prompt_attention_mask = attention_mask.unsqueeze(-1).expand(embeds.shape) embeds = embeds * prompt_attention_mask.to(dtype=embeds.dtype, device=embeds.device) return PromptEmbeds( embeds, # do we want attn mask here? # attention_mask=attention_mask, ) else: return PromptEmbeds( train_tools.encode_prompts( self.tokenizer, self.text_encoder, prompt, truncate=not long_prompts, max_length=max_length, dropout_prob=dropout_prob ) ) @torch.no_grad() def encode_images( self, image_list: List[torch.Tensor], device=None, dtype=None ): if device is None: device = self.vae_device_torch if dtype is None: dtype = self.vae_torch_dtype latent_list = [] # Move to vae to device if on cpu if self.vae.device == 'cpu': self.vae.to(device) self.vae.eval() self.vae.requires_grad_(False) # move to device and dtype image_list = [image.to(device, dtype=dtype) for image in image_list] VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1) # resize images if not divisible by 8 for i in range(len(image_list)): image = image_list[i] if image.shape[1] % VAE_SCALE_FACTOR != 0 or image.shape[2] % VAE_SCALE_FACTOR != 0: image_list[i] = Resize((image.shape[1] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR, image.shape[2] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR))(image) images = torch.stack(image_list) if isinstance(self.vae, AutoencoderTiny): latents = self.vae.encode(images, return_dict=False)[0] else: latents = self.vae.encode(images).latent_dist.sample() shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0 # flux ref https://github.com/black-forest-labs/flux/blob/c23ae247225daba30fbd56058d247cc1b1fc20a3/src/flux/modules/autoencoder.py#L303 # z = self.scale_factor * (z - self.shift_factor) latents = self.vae.config['scaling_factor'] * (latents - shift) latents = latents.to(device, dtype=dtype) return latents def decode_latents( self, latents: torch.Tensor, device=None, dtype=None ): if device is None: device = self.device if dtype is None: dtype = self.torch_dtype # Move to vae to device if on cpu if self.vae.device == 'cpu': self.vae.to(self.device) latents = latents.to(device, dtype=dtype) latents = (latents / self.vae.config['scaling_factor']) + self.vae.config['shift_factor'] images = self.vae.decode(latents).sample images = images.to(device, dtype=dtype) return images def encode_image_prompt_pairs( self, prompt_list: List[str], image_list: List[torch.Tensor], device=None, dtype=None ): # todo check image types and expand and rescale as needed # device and dtype are for outputs if device is None: device = self.device if dtype is None: dtype = self.torch_dtype embedding_list = [] latent_list = [] # embed the prompts for prompt in prompt_list: embedding = self.encode_prompt(prompt).to(self.device_torch, dtype=dtype) embedding_list.append(embedding) return embedding_list, latent_list def get_weight_by_name(self, name): # weights begin with te{te_num}_ for text encoder # weights begin with unet_ for unet_ if name.startswith('te'): key = name[4:] # text encoder te_num = int(name[2]) if isinstance(self.text_encoder, list): return self.text_encoder[te_num].state_dict()[key] else: return self.text_encoder.state_dict()[key] elif name.startswith('unet'): key = name[5:] # unet return self.unet.state_dict()[key] raise ValueError(f"Unknown weight name: {name}") def inject_trigger_into_prompt(self, prompt, trigger=None, to_replace_list=None, add_if_not_present=False): return inject_trigger_into_prompt( prompt, trigger=trigger, to_replace_list=to_replace_list, add_if_not_present=add_if_not_present, ) def state_dict(self, vae=True, text_encoder=True, unet=True): state_dict = OrderedDict() if vae: for k, v in self.vae.state_dict().items(): new_key = k if k.startswith(f"{SD_PREFIX_VAE}") else f"{SD_PREFIX_VAE}_{k}" state_dict[new_key] = v if text_encoder: if isinstance(self.text_encoder, list): for i, encoder in enumerate(self.text_encoder): for k, v in encoder.state_dict().items(): new_key = k if k.startswith( f"{SD_PREFIX_TEXT_ENCODER}{i}_") else f"{SD_PREFIX_TEXT_ENCODER}{i}_{k}" state_dict[new_key] = v else: for k, v in self.text_encoder.state_dict().items(): new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER}_") else f"{SD_PREFIX_TEXT_ENCODER}_{k}" state_dict[new_key] = v if unet: for k, v in self.unet.state_dict().items(): new_key = k if k.startswith(f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}" state_dict[new_key] = v return state_dict def named_parameters(self, vae=True, text_encoder=True, unet=True, refiner=False, state_dict_keys=False) -> \ OrderedDict[ str, Parameter]: named_params: OrderedDict[str, Parameter] = OrderedDict() if vae: for name, param in self.vae.named_parameters(recurse=True, prefix=f"{SD_PREFIX_VAE}"): named_params[name] = param if text_encoder: if isinstance(self.text_encoder, list): for i, encoder in enumerate(self.text_encoder): if self.is_xl and not self.model_config.use_text_encoder_1 and i == 0: # dont add these params continue if self.is_xl and not self.model_config.use_text_encoder_2 and i == 1: # dont add these params continue for name, param in encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}{i}"): named_params[name] = param else: for name, param in self.text_encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}"): named_params[name] = param if unet: if self.is_flux: # Just train the middle 2 blocks of each transformer block # block_list = [] # num_transformer_blocks = 2 # start_block = len(self.unet.transformer_blocks) // 2 - (num_transformer_blocks // 2) # for i in range(num_transformer_blocks): # block_list.append(self.unet.transformer_blocks[start_block + i]) # # num_single_transformer_blocks = 4 # start_block = len(self.unet.single_transformer_blocks) // 2 - (num_single_transformer_blocks // 2) # for i in range(num_single_transformer_blocks): # block_list.append(self.unet.single_transformer_blocks[start_block + i]) # # for block in block_list: # for name, param in block.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"): # named_params[name] = param # train the guidance embedding # if self.unet.config.guidance_embeds: # transformer: FluxTransformer2DModel = self.unet # for name, param in transformer.time_text_embed.named_parameters(recurse=True, # prefix=f"{SD_PREFIX_UNET}"): # named_params[name] = param for name, param in self.unet.transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"): named_params[name] = param for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"): named_params[name] = param else: for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"): named_params[name] = param if refiner: for name, param in self.refiner_unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_REFINER_UNET}"): named_params[name] = param # convert to state dict keys, jsut replace . with _ on keys if state_dict_keys: new_named_params = OrderedDict() for k, v in named_params.items(): # replace only the first . with an _ new_key = k.replace('.', '_', 1) new_named_params[new_key] = v named_params = new_named_params return named_params def save_refiner(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16')): # load the full refiner since we only train unet if self.model_config.refiner_name_or_path is None: raise ValueError("Refiner must be specified to save it") refiner_config_path = os.path.join(ORIG_CONFIGS_ROOT, 'sd_xl_refiner.yaml') # load the refiner model dtype = get_torch_dtype(self.dtype) model_path = self.model_config._original_refiner_name_or_path if not os.path.exists(model_path) or os.path.isdir(model_path): # TODO only load unet?? refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained( model_path, dtype=dtype, device='cpu', # variant="fp16", use_safetensors=True, ) else: refiner = StableDiffusionXLImg2ImgPipeline.from_single_file( model_path, dtype=dtype, device='cpu', torch_dtype=self.torch_dtype, original_config_file=refiner_config_path, ) # replace original unet refiner.unet = self.refiner_unet flush() diffusers_state_dict = OrderedDict() for k, v in refiner.vae.state_dict().items(): new_key = k if k.startswith(f"{SD_PREFIX_VAE}") else f"{SD_PREFIX_VAE}_{k}" diffusers_state_dict[new_key] = v for k, v in refiner.text_encoder_2.state_dict().items(): new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER2}_") else f"{SD_PREFIX_TEXT_ENCODER2}_{k}" diffusers_state_dict[new_key] = v for k, v in refiner.unet.state_dict().items(): new_key = k if k.startswith(f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}" diffusers_state_dict[new_key] = v converted_state_dict = get_ldm_state_dict_from_diffusers( diffusers_state_dict, 'sdxl_refiner', device='cpu', dtype=save_dtype ) # make sure parent folder exists os.makedirs(os.path.dirname(output_file), exist_ok=True) save_file(converted_state_dict, output_file, metadata=meta) if self.config_file is not None: output_path_no_ext = os.path.splitext(output_file)[0] output_config_path = f"{output_path_no_ext}.yaml" shutil.copyfile(self.config_file, output_config_path) def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None): version_string = '1' if self.is_v2: version_string = '2' if self.is_xl: version_string = 'sdxl' if self.is_ssd: # overwrite sdxl because both wil be true here version_string = 'ssd' if self.is_ssd and self.is_vega: version_string = 'vega' # if output file does not end in .safetensors, then it is a directory and we are # saving in diffusers format if not output_file.endswith('.safetensors'): # diffusers # if self.is_pixart: # self.unet.save_pretrained( # save_directory=output_file, # safe_serialization=True, # ) # else: if self.is_flux: # only save the unet transformer: FluxTransformer2DModel = self.unet transformer.save_pretrained( save_directory=os.path.join(output_file, 'transformer'), safe_serialization=True, ) else: self.pipeline.save_pretrained( save_directory=output_file, safe_serialization=True, ) # save out meta config meta_path = os.path.join(output_file, 'aitk_meta.yaml') with open(meta_path, 'w') as f: yaml.dump(meta, f) else: save_ldm_model_from_diffusers( sd=self, output_file=output_file, meta=meta, save_dtype=save_dtype, sd_version=version_string, ) if self.config_file is not None: output_path_no_ext = os.path.splitext(output_file)[0] output_config_path = f"{output_path_no_ext}.yaml" shutil.copyfile(self.config_file, output_config_path) def prepare_optimizer_params( self, unet=False, text_encoder=False, text_encoder_lr=None, unet_lr=None, refiner_lr=None, refiner=False, default_lr=1e-6, ): # todo maybe only get locon ones? # not all items are saved, to make it match, we need to match out save mappings # and not train anything not mapped. Also add learning rate version = 'sd1' if self.is_xl: version = 'sdxl' if self.is_v2: version = 'sd2' mapping_filename = f"stable_diffusion_{version}.json" mapping_path = os.path.join(KEYMAPS_ROOT, mapping_filename) with open(mapping_path, 'r') as f: mapping = json.load(f) ldm_diffusers_keymap = mapping['ldm_diffusers_keymap'] trainable_parameters = [] # we use state dict to find params if unet: named_params = self.named_parameters(vae=False, unet=unet, text_encoder=False, state_dict_keys=True) unet_lr = unet_lr if unet_lr is not None else default_lr params = [] if self.is_pixart or self.is_auraflow or self.is_flux: for param in named_params.values(): if param.requires_grad: params.append(param) else: for key, diffusers_key in ldm_diffusers_keymap.items(): if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS: if named_params[diffusers_key].requires_grad: params.append(named_params[diffusers_key]) param_data = {"params": params, "lr": unet_lr} trainable_parameters.append(param_data) print(f"Found {len(params)} trainable parameter in unet") if text_encoder: named_params = self.named_parameters(vae=False, unet=False, text_encoder=text_encoder, state_dict_keys=True) text_encoder_lr = text_encoder_lr if text_encoder_lr is not None else default_lr params = [] for key, diffusers_key in ldm_diffusers_keymap.items(): if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS: if named_params[diffusers_key].requires_grad: params.append(named_params[diffusers_key]) param_data = {"params": params, "lr": text_encoder_lr} trainable_parameters.append(param_data) print(f"Found {len(params)} trainable parameter in text encoder") if refiner: named_params = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True, state_dict_keys=True) refiner_lr = refiner_lr if refiner_lr is not None else default_lr params = [] for key, diffusers_key in ldm_diffusers_keymap.items(): diffusers_key = f"refiner_{diffusers_key}" if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS: if named_params[diffusers_key].requires_grad: params.append(named_params[diffusers_key]) param_data = {"params": params, "lr": refiner_lr} trainable_parameters.append(param_data) print(f"Found {len(params)} trainable parameter in refiner") return trainable_parameters def save_device_state(self): # saves the current device state for all modules # this is useful for when we want to alter the state and restore it if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux: unet_has_grad = self.unet.proj_out.weight.requires_grad else: unet_has_grad = self.unet.conv_in.weight.requires_grad self.device_state = { **empty_preset, 'vae': { 'training': self.vae.training, 'device': self.vae.device, }, 'unet': { 'training': self.unet.training, 'device': self.unet.device, 'requires_grad': unet_has_grad, }, } if isinstance(self.text_encoder, list): self.device_state['text_encoder']: List[dict] = [] for encoder in self.text_encoder: try: te_has_grad = encoder.text_model.final_layer_norm.weight.requires_grad except: te_has_grad = encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad self.device_state['text_encoder'].append({ 'training': encoder.training, 'device': encoder.device, # todo there has to be a better way to do this 'requires_grad': te_has_grad }) else: if isinstance(self.text_encoder, T5EncoderModel) or isinstance(self.text_encoder, UMT5EncoderModel): te_has_grad = self.text_encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad else: te_has_grad = self.text_encoder.text_model.final_layer_norm.weight.requires_grad self.device_state['text_encoder'] = { 'training': self.text_encoder.training, 'device': self.text_encoder.device, 'requires_grad': te_has_grad } if self.adapter is not None: if isinstance(self.adapter, IPAdapter): requires_grad = self.adapter.image_proj_model.training adapter_device = self.unet.device elif isinstance(self.adapter, T2IAdapter): requires_grad = self.adapter.adapter.conv_in.weight.requires_grad adapter_device = self.adapter.device elif isinstance(self.adapter, ControlNetModel): requires_grad = self.adapter.conv_in.training adapter_device = self.adapter.device elif isinstance(self.adapter, ClipVisionAdapter): requires_grad = self.adapter.embedder.training adapter_device = self.adapter.device elif isinstance(self.adapter, CustomAdapter): requires_grad = self.adapter.training adapter_device = self.adapter.device elif isinstance(self.adapter, ReferenceAdapter): # todo update this!! requires_grad = True adapter_device = self.adapter.device else: raise ValueError(f"Unknown adapter type: {type(self.adapter)}") self.device_state['adapter'] = { 'training': self.adapter.training, 'device': adapter_device, 'requires_grad': requires_grad, } if self.refiner_unet is not None: self.device_state['refiner_unet'] = { 'training': self.refiner_unet.training, 'device': self.refiner_unet.device, 'requires_grad': self.refiner_unet.conv_in.weight.requires_grad, } def restore_device_state(self): # restores the device state for all modules # this is useful for when we want to alter the state and restore it if self.device_state is None: return self.set_device_state(self.device_state) self.device_state = None def set_device_state(self, state): if state['vae']['training']: self.vae.train() else: self.vae.eval() self.vae.to(state['vae']['device']) if state['unet']['training']: self.unet.train() else: self.unet.eval() self.unet.to(state['unet']['device']) if state['unet']['requires_grad']: self.unet.requires_grad_(True) else: self.unet.requires_grad_(False) if isinstance(self.text_encoder, list): for i, encoder in enumerate(self.text_encoder): if isinstance(state['text_encoder'], list): if state['text_encoder'][i]['training']: encoder.train() else: encoder.eval() encoder.to(state['text_encoder'][i]['device']) encoder.requires_grad_(state['text_encoder'][i]['requires_grad']) else: if state['text_encoder']['training']: encoder.train() else: encoder.eval() encoder.to(state['text_encoder']['device']) encoder.requires_grad_(state['text_encoder']['requires_grad']) else: if state['text_encoder']['training']: self.text_encoder.train() else: self.text_encoder.eval() self.text_encoder.to(state['text_encoder']['device']) self.text_encoder.requires_grad_(state['text_encoder']['requires_grad']) if self.adapter is not None: self.adapter.to(state['adapter']['device']) self.adapter.requires_grad_(state['adapter']['requires_grad']) if state['adapter']['training']: self.adapter.train() else: self.adapter.eval() if self.refiner_unet is not None: self.refiner_unet.to(state['refiner_unet']['device']) self.refiner_unet.requires_grad_(state['refiner_unet']['requires_grad']) if state['refiner_unet']['training']: self.refiner_unet.train() else: self.refiner_unet.eval() flush() def set_device_state_preset(self, device_state_preset: DeviceStatePreset): # sets a preset for device state # save current state first self.save_device_state() active_modules = [] training_modules = [] if device_state_preset in ['cache_latents']: active_modules = ['vae'] if device_state_preset in ['cache_clip']: active_modules = ['clip'] if device_state_preset in ['generate']: active_modules = ['vae', 'unet', 'text_encoder', 'adapter', 'refiner_unet'] state = copy.deepcopy(empty_preset) # vae state['vae'] = { 'training': 'vae' in training_modules, 'device': self.vae_device_torch if 'vae' in active_modules else 'cpu', 'requires_grad': 'vae' in training_modules, } # unet state['unet'] = { 'training': 'unet' in training_modules, 'device': self.device_torch if 'unet' in active_modules else 'cpu', 'requires_grad': 'unet' in training_modules, } if self.refiner_unet is not None: state['refiner_unet'] = { 'training': 'refiner_unet' in training_modules, 'device': self.device_torch if 'refiner_unet' in active_modules else 'cpu', 'requires_grad': 'refiner_unet' in training_modules, } # text encoder if isinstance(self.text_encoder, list): state['text_encoder'] = [] for i, encoder in enumerate(self.text_encoder): state['text_encoder'].append({ 'training': 'text_encoder' in training_modules, 'device': self.te_device_torch if 'text_encoder' in active_modules else 'cpu', 'requires_grad': 'text_encoder' in training_modules, }) else: state['text_encoder'] = { 'training': 'text_encoder' in training_modules, 'device': self.te_device_torch if 'text_encoder' in active_modules else 'cpu', 'requires_grad': 'text_encoder' in training_modules, } if self.adapter is not None: state['adapter'] = { 'training': 'adapter' in training_modules, 'device': self.device_torch if 'adapter' in active_modules else 'cpu', 'requires_grad': 'adapter' in training_modules, } self.set_device_state(state) def text_encoder_to(self, *args, **kwargs): if isinstance(self.text_encoder, list): for encoder in self.text_encoder: encoder.to(*args, **kwargs) else: self.text_encoder.to(*args, **kwargs)