import argparse
import copy
import math
import random
from typing import Any, Optional

import torch
from accelerate import Accelerator
from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()

from library import (
    flux_models,
    flux_train_utils_recraft as flux_train_utils,
    flux_utils,
    sd3_train_utils,
    strategy_base,
    strategy_flux,
    train_util,
)
from torchvision import transforms
import train_network
from library.utils import setup_logging
from diffusers.utils import load_image
import numpy as np
from PIL import Image, ImageOps

setup_logging()
import logging

logger = logging.getLogger(__name__)

# NUM_SPLIT = 2


class ResizeWithPadding:
    def __init__(self, size, fill=255):
        self.size = size
        self.fill = fill

    def __call__(self, img):
        if isinstance(img, np.ndarray):
            img = Image.fromarray(img)
        elif not isinstance(img, Image.Image):
            raise TypeError("Input must be a PIL Image or a NumPy array")

        width, height = img.size
        if width == height:
            img = img.resize((self.size, self.size), Image.LANCZOS)
        else:
            max_dim = max(width, height)
            new_img = Image.new("RGB", (max_dim, max_dim), (self.fill, self.fill, self.fill))
            new_img.paste(img, ((max_dim - width) // 2, (max_dim - height) // 2))
            img = new_img.resize((self.size, self.size), Image.LANCZOS)
        return img
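
# Usage sketch (illustrative only): the transform pads the shorter side with the fill
# color to make the image square, then resizes with LANCZOS.
#
#   resize = ResizeWithPadding(size=512, fill=255)
#   img = resize(np.zeros((200, 300, 3), dtype=np.uint8))  # 300x200 -> white-padded 512x512 PIL image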

class FluxNetworkTrainer(train_network.NetworkTrainer):
    def __init__(self):
        super().__init__()
        self.sample_prompts_te_outputs = None
        self.sample_conditions = None
        self.is_schnell: Optional[bool] = None

    def assert_extra_args(self, args, train_dataset_group):
        super().assert_extra_args(args, train_dataset_group)
        # sdxl_train_util.verify_sdxl_training_args(args)

        if args.fp8_base_unet:
            args.fp8_base = True  # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1

        if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
            logger.warning(
                "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
            )
            args.cache_text_encoder_outputs = True

        if args.cache_text_encoder_outputs:
            assert (
                train_dataset_group.is_text_encoder_output_cacheable()
            ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"

        # prepare CLIP-L/T5XXL training flags
        self.train_clip_l = not args.network_train_unet_only
        self.train_t5xxl = False  # default is False even if args.network_train_unet_only is False

        if args.max_token_length is not None:
            logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")

        assert not args.split_mode or not args.cpu_offload_checkpointing, (
            "split_mode and cpu_offload_checkpointing cannot be used together"
            " / split_modeとcpu_offload_checkpointingは同時に使用できません"
        )

        train_dataset_group.verify_bucket_reso_steps(32)  # TODO check this
    def load_target_model(self, args, weight_dtype, accelerator):
        # currently offload to cpu for some models

        # if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
        loading_dtype = None if args.fp8_base else weight_dtype

        # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
        self.is_schnell, model = flux_utils.load_flow_model(
            args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
        )
        if args.fp8_base:
            # check dtype of model
            if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
                raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
            elif model.dtype == torch.float8_e4m3fn:
                logger.info("Loaded fp8 FLUX model")

        if args.split_mode:
            model = self.prepare_split_model(args, model, weight_dtype, accelerator)

        clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
        clip_l.eval()

        # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
        if args.fp8_base and not args.fp8_base_unet:
            loading_dtype = None  # as is
        else:
            loading_dtype = weight_dtype

        # loading t5xxl to cpu takes a long time, so we should load to gpu in future
        t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
        t5xxl.eval()
        if args.fp8_base and not args.fp8_base_unet:
            # check dtype of model
            if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
                raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
            elif t5xxl.dtype == torch.float8_e4m3fn:
                logger.info("Loaded fp8 T5XXL model")

        ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)

        return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
    def prepare_split_model(self, args, model, weight_dtype, accelerator):
        # `args` is passed explicitly because fp8_base is needed to pick the upper model's dtype
        from accelerate import init_empty_weights

        logger.info("prepare split model")
        with init_empty_weights():
            flux_upper = flux_models.FluxUpper(model.params)
            flux_lower = flux_models.FluxLower(model.params)
        sd = model.state_dict()

        # lower (trainable)
        logger.info("load state dict for lower")
        flux_lower.load_state_dict(sd, strict=False, assign=True)
        flux_lower.to(dtype=weight_dtype)

        # upper (frozen)
        logger.info("load state dict for upper")
        flux_upper.load_state_dict(sd, strict=False, assign=True)

        logger.info("prepare upper model")
        target_dtype = torch.float8_e4m3fn if args.fp8_base else weight_dtype
        flux_upper.to(accelerator.device, dtype=target_dtype)
        flux_upper.eval()

        if args.fp8_base:
            # this is required to run on fp8
            flux_upper = accelerator.prepare(flux_upper)

        flux_upper.to("cpu")

        self.flux_upper = flux_upper
        del model  # we don't need the full model anymore
        clean_memory_on_device(accelerator.device)

        logger.info("split model prepared")

        return flux_lower
    def get_tokenize_strategy(self, args):
        _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)

        if args.t5xxl_max_token_length is None:
            if is_schnell:
                t5xxl_max_token_length = 256
            else:
                t5xxl_max_token_length = 512
        else:
            t5xxl_max_token_length = args.t5xxl_max_token_length

        logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
        return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir)

    def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy):
        return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]

    def get_latents_caching_strategy(self, args):
        latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False)
        return latents_caching_strategy

    def get_text_encoding_strategy(self, args):
        return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask)
    def post_process_network(self, args, accelerator, network, text_encoders, unet):
        # check whether T5XXL is trained or not
        self.train_t5xxl = network.train_t5xxl

        if self.train_t5xxl and args.cache_text_encoder_outputs:
            raise ValueError(
                "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません"
            )

    def get_models_for_text_encoding(self, args, accelerator, text_encoders):
        if args.cache_text_encoder_outputs:
            if self.train_clip_l and not self.train_t5xxl:
                return text_encoders[0:1]  # only CLIP-L is needed for encoding because T5XXL is cached
            else:
                return None  # no text encoders are needed for encoding because both are cached
        else:
            return text_encoders  # both CLIP-L and T5XXL are needed for encoding

    def get_text_encoders_train_flags(self, args, text_encoders):
        return [self.train_clip_l, self.train_t5xxl]

    def get_text_encoder_outputs_caching_strategy(self, args):
        if args.cache_text_encoder_outputs:
            # if either text encoder is trained, tokenization is still needed, so is_partial is True
            return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
                args.cache_text_encoder_outputs_to_disk,
                args.text_encoder_batch_size,
                args.skip_cache_check,
                is_partial=self.train_clip_l or self.train_t5xxl,
                apply_t5_attn_mask=args.apply_t5_attn_mask,
            )
        else:
            return None
    def cache_text_encoder_outputs_if_needed(
        self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
    ):
        if args.cache_text_encoder_outputs:
            if not args.lowram:
                # reduce memory consumption
                logger.info("move vae and unet to cpu to save memory")
                org_vae_device = vae.device
                org_unet_device = unet.device
                vae.to("cpu")
                unet.to("cpu")
                clean_memory_on_device(accelerator.device)

            # when the text encoders are not trained, they are not prepared by accelerate, so explicit autocast is needed
            logger.info("move text encoders to gpu")
            text_encoders[0].to(accelerator.device, dtype=weight_dtype)  # always not fp8
            text_encoders[1].to(accelerator.device)

            if text_encoders[1].dtype == torch.float8_e4m3fn:
                # if we load fp8 weights, the model is already fp8, so we use it as is
                self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
            else:
                # otherwise, we need to convert it to target dtype
                text_encoders[1].to(weight_dtype)

            with accelerator.autocast():
                dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)

            # cache sample prompts
            if args.sample_prompts is not None:
                logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")

                tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
                text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()

                prompts = train_util.load_prompts(args.sample_prompts)
                sample_prompts_te_outputs = {}  # key: prompt, value: text encoder outputs
                with accelerator.autocast(), torch.no_grad():
                    for prompt_dict in prompts:
                        for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
                            if p not in sample_prompts_te_outputs:
                                logger.info(f"cache Text Encoder outputs for prompt: {p}")
                                tokens_and_masks = tokenize_strategy.tokenize(p)
                                sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
                                    tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask
                                )
                self.sample_prompts_te_outputs = sample_prompts_te_outputs

            # cache conditioning latents for sample images
            if args.sample_images is not None:
                logger.info(f"cache conditions for sample images: {args.sample_images}")

                resize_transform = (
                    ResizeWithPadding(size=512, fill=255) if args.frame_num == 4 else ResizeWithPadding(size=352, fill=255)
                )
                img_transforms = transforms.Compose(
                    [
                        resize_transform,
                        transforms.ToTensor(),
                        transforms.Normalize([0.5], [0.5]),
                    ]
                )

                if args.sample_images.endswith(".txt"):
                    with open(args.sample_images, "r", encoding="utf-8") as f:
                        lines = f.readlines()
                    sample_images = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
                else:
                    raise NotImplementedError(f"sample_images file format not supported: {args.sample_images}")

                prompts = train_util.load_prompts(args.sample_prompts)
                conditions = {}  # key: prompt, value: latents
                with torch.no_grad():
                    for image, prompt_dict in zip(sample_images, prompts):
                        prompt = prompt_dict.get("prompt", "")
                        if prompt not in conditions:
                            logger.info(f"cache conditions for image: {image} with prompt: {prompt}")
                            image = (
                                img_transforms(np.array(load_image(image), dtype=np.uint8))
                                .unsqueeze(0)
                                .to(vae.device, dtype=vae.dtype)
                            )
                            latents = self.encode_images_to_latents2(args, accelerator, vae, image)
                            conditions[prompt] = latents
                            # if args.frame_num == 4:
                            #     conditions[prompt] = latents[:, :, 2 * latents.shape[2] // 3 :, 2 * latents.shape[3] // 3 :].to("cpu")
                            # else:
                            #     conditions[prompt] = latents[:, :, latents.shape[2] // 2 :, : latents.shape[3] // 2].to("cpu")
                self.sample_conditions = conditions

            accelerator.wait_for_everyone()

            # move back to cpu
            if not self.is_train_text_encoder(args):
                logger.info("move CLIP-L back to cpu")
                text_encoders[0].to("cpu")
            logger.info("move t5XXL back to cpu")
            text_encoders[1].to("cpu")
            clean_memory_on_device(accelerator.device)

            if not args.lowram:
                logger.info("move vae and unet back to original device")
                vae.to(org_vae_device)
                unet.to(org_unet_device)
        else:
            # text encoder outputs are computed every step, so keep the text encoders on GPU
            text_encoders[0].to(accelerator.device, dtype=weight_dtype)
            text_encoders[1].to(accelerator.device)
    # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
    #     noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype
    #
    #     # get size embeddings
    #     orig_size = batch["original_sizes_hw"]
    #     crop_size = batch["crop_top_lefts"]
    #     target_size = batch["target_sizes_hw"]
    #     embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
    #
    #     # concat embeddings
    #     encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
    #     vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
    #     text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
    #
    #     noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
    #     return noise_pred
    def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
        text_encoders = text_encoder  # for compatibility
        text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)

        # use the precomputed conditioning latents directly
        conditions = None
        if self.sample_conditions is not None:
            conditions = {k: v.to(accelerator.device) for k, v in self.sample_conditions.items()}

        if not args.split_mode:
            flux_train_utils.sample_images(
                accelerator, args, epoch, global_step, flux, ae, text_encoder, self.sample_prompts_te_outputs, None, conditions
            )
            return

        class FluxUpperLowerWrapper(torch.nn.Module):
            def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device):
                super().__init__()
                self.flux_upper = flux_upper
                self.flux_lower = flux_lower
                self.target_device = device

            def prepare_block_swap_before_forward(self):
                pass

            def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None):
                self.flux_lower.to("cpu")
                clean_memory_on_device(self.target_device)
                self.flux_upper.to(self.target_device)
                img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask)
                self.flux_upper.to("cpu")
                clean_memory_on_device(self.target_device)
                self.flux_lower.to(self.target_device)
                return self.flux_lower(img, txt, vec, pe, txt_attention_mask)

        wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device)
        clean_memory_on_device(accelerator.device)
        # keep the argument order consistent with the non-split call above
        flux_train_utils.sample_images(
            accelerator, args, epoch, global_step, wrapper, ae, text_encoder, self.sample_prompts_te_outputs, None, conditions
        )
        clean_memory_on_device(accelerator.device)
    def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
        noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
        self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
        return noise_scheduler
    def encode_images_to_latents(self, args, accelerator, vae, images):
        # get image dimensions
        b, c, h, w = images.shape
        # num_split = NUM_SPLIT
        num_split = 2 if args.frame_num == 4 else 3
        # split the image into num_split vertical strips along the width
        img_parts = [images[:, :, :, i * w // num_split : (i + 1) * w // num_split] for i in range(num_split)]
        # encode each strip separately
        latents = [vae.encode(img) for img in img_parts]
        # concatenate back into the full image in latent space
        latents = torch.cat(latents, dim=-1)
        return latents
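
    # Shape sketch (illustrative, assuming the FLUX.1 AE's 8x spatial downsampling): with
    # frame_num=9 and a 1056x1056 grid (three 352-pixel columns, matching the condition
    # resize above), the image is split into three 1056-tall, 352-wide strips; each strip
    # encodes to a latent of width 44, and concatenating along width restores a 132x132
    # latent with the same column layout as the pixels.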
    def encode_images_to_latents2(self, args, accelerator, vae, images):
        # encode the whole image at once (used to cache the conditioning latents)
        b, c, h, w = images.shape
        # num_split = NUM_SPLIT
        num_split = 2 if args.frame_num == 4 else 3
        latents = vae.encode(images)
        return latents
    def encode_images_to_latents3(self, args, accelerator, vae, images):
        b, c, h, w = images.shape

        # number of splits along each dimension
        num_split = 3

        # check that the image can be evenly divided into a 3x3 grid
        assert h % num_split == 0 and w % num_split == 0, "Image dimensions must be divisible by 3."

        # height and width of each split
        split_h, split_w = h // num_split, w // num_split

        # store latents for each split
        latents = []
        for i in range(num_split):
            for j in range(num_split):
                # extract the (i, j) sub-image
                img_part = images[:, :, i * split_h : (i + 1) * split_h, j * split_w : (j + 1) * split_w]
                # encode the sub-image using the VAE
                latent = vae.encode(img_part)
                latents.append(latent)

        # combine latents into a 3x3 grid in the latent space
        # latents list -> tensor [num_split^2, b, latent_dim, h', w']
        latents = torch.stack(latents, dim=0)
        # reshape into a 3x3 grid: [num_split, num_split, b, latent_dim, h', w']
        latents = latents.view(num_split, num_split, b, *latents.shape[2:])
        # concatenate along width within each row, then concatenate rows along height
        latents = torch.cat([torch.cat(list(latents[i]), dim=-1) for i in range(num_split)], dim=-2)
        # final shape: [b, latent_dim, h', w']
        return latents
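
    # Illustrative shapes for the 3x3 variant (assuming 8x AE downsampling): a 1056x1056
    # batch yields nine 352x352 crops, each encoded to (b, latent_dim, 44, 44); stacking and
    # re-concatenating rebuilds a (b, latent_dim, 132, 132) latent laid out in the same 3x3
    # order as the pixel grid.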
    def shift_scale_latents(self, args, latents):
        return latents
    def get_noise_pred_and_target(
        self,
        args,
        accelerator,
        noise_scheduler,
        latents,
        batch,
        text_encoder_conds,
        unet: flux_models.Flux,
        network,
        weight_dtype,
        train_unet,
    ):
        # sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]

        # get noisy model input and timesteps
        noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
            args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
        )

        # pack latents and get img_ids
        packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input)  # b, c, h*2, w*2 -> b, h*w, c*4
        packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
        img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)

        # get guidance; ensure guidance_scale in args is float
        guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)

        # ensure the hidden states will require grad
        if args.gradient_checkpointing:
            noisy_model_input.requires_grad_(True)
            for t in text_encoder_conds:
                if t is not None and t.dtype.is_floating_point:
                    t.requires_grad_(True)
            img_ids.requires_grad_(True)
            guidance_vec.requires_grad_(True)

        # predict the noise residual
        l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
        if not args.apply_t5_attn_mask:
            t5_attn_mask = None
        def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
            if not args.split_mode:
                # normal forward
                with accelerator.autocast():
                    # timesteps are divided by 1000 because the transformer scales them by 1000 internally
                    model_pred = unet(
                        img=img,
                        img_ids=img_ids,
                        txt=t5_out,
                        txt_ids=txt_ids,
                        y=l_pooled,
                        timesteps=timesteps / 1000,
                        guidance=guidance_vec,
                        txt_attention_mask=t5_attn_mask,
                    )
            else:
                # split forward to reduce memory usage
                assert network.train_blocks == "single", "train_blocks must be single for split mode"
                with accelerator.autocast():
                    # move flux lower to cpu, and then move flux upper to gpu
                    unet.to("cpu")
                    clean_memory_on_device(accelerator.device)
                    self.flux_upper.to(accelerator.device)

                    # upper model does not require grad
                    with torch.no_grad():
                        intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
                            img=img,  # use the `img` argument so calls with a sub-batch also work
                            img_ids=img_ids,
                            txt=t5_out,
                            txt_ids=txt_ids,
                            y=l_pooled,
                            timesteps=timesteps / 1000,
                            guidance=guidance_vec,
                            txt_attention_mask=t5_attn_mask,
                        )

                    # move flux upper back to cpu, and then move flux lower to gpu
                    self.flux_upper.to("cpu")
                    clean_memory_on_device(accelerator.device)
                    unet.to(accelerator.device)

                    # lower model requires grad
                    intermediate_img.requires_grad_(True)
                    intermediate_txt.requires_grad_(True)
                    vec.requires_grad_(True)
                    pe.requires_grad_(True)
                    model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)

            return model_pred
        model_pred = call_dit(
            img=packed_noisy_model_input,
            img_ids=img_ids,
            t5_out=t5_out,
            txt_ids=txt_ids,
            l_pooled=l_pooled,
            timesteps=timesteps,
            guidance_vec=guidance_vec,
            t5_attn_mask=t5_attn_mask,
        )

        # unpack latents
        model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)

        # apply model prediction type
        model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)

        # flow matching loss: this is different from SD3
        target = noise - latents

        # differential output preservation
        if "custom_attributes" in batch:
            diff_output_pr_indices = []
            for i, custom_attributes in enumerate(batch["custom_attributes"]):
                if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
                    diff_output_pr_indices.append(i)

            if len(diff_output_pr_indices) > 0:
                network.set_multiplier(0.0)
                with torch.no_grad():
                    model_pred_prior = call_dit(
                        img=packed_noisy_model_input[diff_output_pr_indices],
                        img_ids=img_ids[diff_output_pr_indices],
                        t5_out=t5_out[diff_output_pr_indices],
                        txt_ids=txt_ids[diff_output_pr_indices],
                        l_pooled=l_pooled[diff_output_pr_indices],
                        timesteps=timesteps[diff_output_pr_indices],
                        guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
                        t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
                    )
                network.set_multiplier(1.0)  # may be overwritten by "network_multipliers" in the next step

                model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width)
                model_pred_prior, _ = flux_train_utils.apply_model_prediction_type(
                    args,
                    model_pred_prior,
                    noisy_model_input[diff_output_pr_indices],
                    sigmas[diff_output_pr_indices] if sigmas is not None else None,
                )
                target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)

        # exclude the condition region of the grid (here the bottom-right portion) from the loss
        # by copying the model prediction into the target, so its loss contribution is zero
        # (note: with num_split == 2 the start index 2*h//num_split equals h, so this slice is empty)
        h, w = target.shape[2], target.shape[3]
        # num_split = NUM_SPLIT
        num_split = 2 if args.frame_num == 4 else 3
        # target[:, :, :, : w // num_split] = model_pred[:, :, :, : w // num_split]
        target[:, :, 2 * h // num_split :, 2 * w // num_split :] = model_pred[:, :, 2 * h // num_split :, 2 * w // num_split :]

        return model_pred, target, timesteps, None, weighting
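
    # Worked example (illustrative; assumes the FLUX.1 AE's 16 latent channels and 8x
    # downsampling): a 1056x1056 pixel grid becomes a (b, 16, 132, 132) latent; pack_latents
    # turns it into (132/2)*(132/2) = 4356 tokens of dimension 16*4 = 64, and unpack_latents
    # inverts this so the masking above can address the grid cells spatially again.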
    def post_process_loss(self, loss, args, timesteps, noise_scheduler):
        return loss

    def get_sai_model_spec(self, args):
        return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev")

    def update_metadata(self, metadata, args):
        metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
        metadata["ss_weighting_scheme"] = args.weighting_scheme
        metadata["ss_logit_mean"] = args.logit_mean
        metadata["ss_logit_std"] = args.logit_std
        metadata["ss_mode_scale"] = args.mode_scale
        metadata["ss_guidance_scale"] = args.guidance_scale
        metadata["ss_timestep_sampling"] = args.timestep_sampling
        metadata["ss_sigmoid_scale"] = args.sigmoid_scale
        metadata["ss_model_prediction_type"] = args.model_prediction_type
        metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift

    def is_text_encoder_not_needed_for_training(self, args):
        return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)

    def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
        if index == 0:  # CLIP-L
            return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
        else:  # T5XXL
            text_encoder.encoder.embed_tokens.requires_grad_(True)

    def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
        if index == 0:  # CLIP-L
            logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
            text_encoder.to(te_weight_dtype)  # fp8
            text_encoder.text_model.embeddings.to(dtype=weight_dtype)
        else:  # T5XXL

            def prepare_fp8(text_encoder, target_dtype):
                def forward_hook(module):
                    def forward(hidden_states):
                        hidden_gelu = module.act(module.wi_0(hidden_states))
                        hidden_linear = module.wi_1(hidden_states)
                        hidden_states = hidden_gelu * hidden_linear
                        hidden_states = module.dropout(hidden_states)
                        hidden_states = module.wo(hidden_states)
                        return hidden_states

                    return forward

                for module in text_encoder.modules():
                    if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
                        # print("set", module.__class__.__name__, "to", target_dtype)
                        module.to(target_dtype)
                    if module.__class__.__name__ in ["T5DenseGatedActDense"]:
                        # print("set", module.__class__.__name__, "hooks")
                        module.forward = forward_hook(module)

            if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
                logger.info("T5XXL already prepared for fp8")
            else:
                logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
                text_encoder.to(te_weight_dtype)  # fp8
                prepare_fp8(text_encoder, weight_dtype)


def setup_parser() -> argparse.ArgumentParser:
    parser = train_network.setup_parser()
    flux_train_utils.add_flux_train_arguments(parser)

    parser.add_argument(
        "--split_mode",
        action="store_true",
        help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required"
        + " / [実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要",
    )
    parser.add_argument(
        "--frame_num",
        type=int,
        choices=[4, 9],
        required=True,
        help="The number of steps in the generated step diagram (choose 4 or 9)",
    )
    return parser
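

# Illustrative invocation (a sketch only; paths are placeholders and most arguments come
# from train_network.setup_parser / add_flux_train_arguments rather than this file):
#
#   accelerate launch <this_script>.py \
#       --pretrained_model_name_or_path flux1-dev.safetensors \
#       --clip_l clip_l.safetensors --t5xxl t5xxl.safetensors --ae ae.safetensors \
#       --frame_num 9 --sample_prompts prompts.txt --sample_images images.txt \
#       --cache_text_encoder_outputs ...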


if __name__ == "__main__":
    parser = setup_parser()
    args = parser.parse_args()
    train_util.verify_command_line_training_args(args)
    args = train_util.read_config_from_file(args, parser)

    trainer = FluxNetworkTrainer()
    trainer.train(args)