# *************************************************************************
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
# ytedance Inc..
# *************************************************************************
import os

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from PIL import Image
from torchvision import transforms

from accelerate import Accelerator
from accelerate.utils import set_seed
from transformers import AutoTokenizer, PretrainedConfig

import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
from diffusers.models.attention_processor import (
    AttnAddedKVProcessor,
    AttnAddedKVProcessor2_0,
    SlicedAttnAddedKVProcessor,
)
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import unet_lora_state_dict
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available

# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.24.0")


# Resolve the text-encoder class declared in the checkpoint's `text_encoder` config.
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

        return RobertaSeriesModelWithTransformation
    elif model_class == "T5EncoderModel":
        from transformers import T5EncoderModel

        return T5EncoderModel
    else:
        raise ValueError(f"{model_class} is not supported.")


def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
    if tokenizer_max_length is not None:
        max_length = tokenizer_max_length
    else:
        max_length = tokenizer.model_max_length

    text_inputs = tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )
    return text_inputs


def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=False):
    text_input_ids = input_ids.to(text_encoder.device)

    if text_encoder_use_attention_mask:
        attention_mask = attention_mask.to(text_encoder.device)
    else:
        attention_mask = None

    prompt_embeds = text_encoder(
        text_input_ids,
        attention_mask=attention_mask,
    )
    prompt_embeds = prompt_embeds[0]
    return prompt_embeds
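

# Illustrative sketch of how the two helpers above compose (not part of the original
# training flow; `tokenizer`, `text_encoder`, and the prompt string are assumed to be
# loaded/chosen by the caller, as `train_lora` below does):
#
#   text_inputs = tokenize_prompt(tokenizer, "a photo of a cat")
#   prompt_embeds = encode_prompt(
#       text_encoder,
#       text_inputs.input_ids,
#       text_inputs.attention_mask,
#       text_encoder_use_attention_mask=False,
#   )
#   # prompt_embeds has shape (1, sequence_length, hidden_size) and is used as the
#   # cross-attention conditioning for the UNet.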


# model_path: path of the base diffusion model
# vae_path: path of the VAE to use, or "default" to use the VAE bundled with model_path
# image: input image as a numpy array (no pre-processing applied yet)
# save_lora_path: the directory in which to save the LoRA weights
# prompt: the user input prompt
# lora_step: number of LoRA training steps
# lora_lr: learning rate of LoRA training
# lora_batch_size: training batch size (the single input image is re-cropped per sample)
# lora_rank: the rank of the LoRA layers
# progress: a progress reporter exposing a `tqdm` method (e.g. gradio.Progress)
# save_interval: how often (in steps) to save intermediate LoRA checkpoints; <= 0 disables them
# See the commented example call at the end of this file for typical usage.
def train_lora(image,
               prompt,
               model_path,
               vae_path,
               save_lora_path,
               lora_step,
               lora_lr,
               lora_batch_size,
               lora_rank,
               progress,
               save_interval=-1):
    # initialize accelerator
    accelerator = Accelerator(
        gradient_accumulation_steps=1,
        mixed_precision='fp16'
    )
    set_seed(0)

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        subfolder="tokenizer",
        revision=None,
        use_fast=False,
    )

    # initialize the model
    noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
    text_encoder_cls = import_model_class_from_model_name_or_path(model_path, revision=None)
    text_encoder = text_encoder_cls.from_pretrained(
        model_path, subfolder="text_encoder", revision=None
    )
    if vae_path == "default":
        vae = AutoencoderKL.from_pretrained(
            model_path, subfolder="vae", revision=None
        )
    else:
        vae = AutoencoderKL.from_pretrained(vae_path)
    unet = UNet2DConditionModel.from_pretrained(
        model_path, subfolder="unet", revision=None
    )
    pipeline = StableDiffusionPipeline.from_pretrained(
        pretrained_model_name_or_path=model_path,
        vae=vae,
        unet=unet,
        text_encoder=text_encoder,
        scheduler=noise_scheduler,
        torch_dtype=torch.float16,
    )

    # set device and dtype; freeze the pre-trained weights (only the LoRA layers added below are trained)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.requires_grad_(False)
    unet.to(device, dtype=torch.float16)
    vae.to(device, dtype=torch.float16)
    text_encoder.to(device, dtype=torch.float16)

    # Set the correct LoRA layers: wrap the q/k/v/out projections of every attention
    # module in the UNet with trainable low-rank adapters.
    unet_lora_parameters = []
    for attn_processor_name, attn_processor in unet.attn_processors.items():
        # Parse the attention module.
        attn_module = unet
        for n in attn_processor_name.split(".")[:-1]:
            attn_module = getattr(attn_module, n)

        # Set the `lora_layer` attribute of the attention-related matrices.
        attn_module.to_q.set_lora_layer(
            LoRALinearLayer(
                in_features=attn_module.to_q.in_features,
                out_features=attn_module.to_q.out_features,
                rank=lora_rank,
            )
        )
        attn_module.to_k.set_lora_layer(
            LoRALinearLayer(
                in_features=attn_module.to_k.in_features,
                out_features=attn_module.to_k.out_features,
                rank=lora_rank,
            )
        )
        attn_module.to_v.set_lora_layer(
            LoRALinearLayer(
                in_features=attn_module.to_v.in_features,
                out_features=attn_module.to_v.out_features,
                rank=lora_rank,
            )
        )
        attn_module.to_out[0].set_lora_layer(
            LoRALinearLayer(
                in_features=attn_module.to_out[0].in_features,
                out_features=attn_module.to_out[0].out_features,
                rank=lora_rank,
            )
        )

        # Accumulate the LoRA params to optimize.
        unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
        unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
        unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
        unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())

        # Processors with added KV projections also get LoRA on those projections.
        if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
            attn_module.add_k_proj.set_lora_layer(
                LoRALinearLayer(
                    in_features=attn_module.add_k_proj.in_features,
                    out_features=attn_module.add_k_proj.out_features,
                    rank=lora_rank,
                )
            )
            attn_module.add_v_proj.set_lora_layer(
                LoRALinearLayer(
                    in_features=attn_module.add_v_proj.in_features,
                    out_features=attn_module.add_v_proj.out_features,
                    rank=lora_rank,
                )
            )
            unet_lora_parameters.extend(attn_module.add_k_proj.lora_layer.parameters())
            unet_lora_parameters.extend(attn_module.add_v_proj.lora_layer.parameters())

    # Optimizer creation: only the LoRA parameters collected above are trained.
    params_to_optimize = unet_lora_parameters
    optimizer = torch.optim.AdamW(
        params_to_optimize,
        lr=lora_lr,
        betas=(0.9, 0.999),
        weight_decay=1e-2,
        eps=1e-08,
    )

    lr_scheduler = get_scheduler(
        "constant",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=lora_step,
        num_cycles=1,
        power=1.0,
    )

    # prepare model, optimizer and scheduler with accelerator
    unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)

    # initialize text embeddings
    with torch.no_grad():
        text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None)
        text_embedding = encode_prompt(
            text_encoder,
            text_inputs.input_ids,
            text_inputs.attention_mask,
            text_encoder_use_attention_mask=False
        )
        text_embedding = text_embedding.repeat(lora_batch_size, 1, 1)

    # initialize image transforms
    image_transforms_pil = transforms.Compose(
        [
            transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.RandomCrop(512),
        ]
    )
    image_transforms_tensor = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
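
    # Training loop. Each step:
    #   1. build a batch by re-cropping the single input image `lora_batch_size` times,
    #   2. encode the batch into latents with the frozen VAE,
    #   3. add noise at randomly sampled timesteps (forward diffusion),
    #   4. predict the noise with the LoRA-augmented UNet, conditioned on the prompt embedding,
    #   5. minimize the MSE between the prediction and the scheduler's target
    #      (the noise itself for epsilon-prediction, the velocity for v-prediction).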
    for step in progress.tqdm(range(lora_step), desc="training LoRA"):
        unet.train()
        image_batch = []
        image_pil_batch = []
        for _ in range(lora_batch_size):
            # first store the PIL image (a randomly cropped view of the input)
            image_transformed = image_transforms_pil(Image.fromarray(image))
            image_pil_batch.append(image_transformed)

            # then store the tensor image
            image_transformed = image_transforms_tensor(image_transformed).to(device, dtype=torch.float16)
            image_transformed = image_transformed.unsqueeze(dim=0)
            image_batch.append(image_transformed)

        # concatenate the cropped views to enable multi-batch training
        image_batch = torch.cat(image_batch, dim=0)

        latents_dist = vae.encode(image_batch).latent_dist
        model_input = latents_dist.sample() * vae.config.scaling_factor

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(model_input)
        bsz, channels, height, width = model_input.shape

        # Sample a random timestep for each image
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
        )
        timesteps = timesteps.long()

        # Add noise to the model input according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)

        # Predict the noise residual
        model_pred = unet(noisy_model_input, timesteps, text_embedding).sample

        # Get the target for loss depending on the prediction type
        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(model_input, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        # periodically save intermediate LoRA checkpoints
        if save_interval > 0 and (step + 1) % save_interval == 0:
            save_lora_path_intermediate = os.path.join(save_lora_path, str(step + 1))
            if not os.path.isdir(save_lora_path_intermediate):
                os.mkdir(save_lora_path_intermediate)
            # unwrap_model is only needed to strip the special modules added for distributed
            # training, so there is no need to call it here
            unet_lora_layers = unet_lora_state_dict(unet)
            LoraLoaderMixin.save_lora_weights(
                save_directory=save_lora_path_intermediate,
                unet_lora_layers=unet_lora_layers,
                text_encoder_lora_layers=None,
            )

    # save the trained LoRA
    # unwrap_model is only needed to strip the special modules added for distributed
    # training, so there is no need to call it here
    unet_lora_layers = unet_lora_state_dict(unet)
    LoraLoaderMixin.save_lora_weights(
        save_directory=save_lora_path,
        unet_lora_layers=unet_lora_layers,
        text_encoder_lora_layers=None,
    )
    return
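

# Example call (illustrative sketch only, not part of the original module). The model
# path, prompt, output directory, and hyper-parameter values below are placeholders,
# and `_Progress` is a hypothetical minimal stand-in for the progress object expected
# by `train_lora` (anything exposing a `tqdm` method, e.g. gradio.Progress):
#
#   import numpy as np
#   from tqdm import tqdm
#
#   class _Progress:
#       def tqdm(self, iterable, desc=None):
#           return tqdm(iterable, desc=desc)
#
#   image = np.array(Image.open("input.png").convert("RGB"))
#   train_lora(
#       image=image,
#       prompt="a photo of a dog",
#       model_path="runwayml/stable-diffusion-v1-5",
#       vae_path="default",
#       save_lora_path="./lora_ckpt",
#       lora_step=80,
#       lora_lr=5e-4,
#       lora_batch_size=4,
#       lora_rank=16,
#       progress=_Progress(),
#       save_interval=-1,
#   )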