Spaces:
Paused
Paused
| # Copyright 2020-2025 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import os | |
| import textwrap | |
| import warnings | |
| from collections import defaultdict | |
| from pathlib import Path | |
| from typing import Any, Callable, Optional, Union | |
| import torch | |
| from accelerate import Accelerator | |
| from accelerate.logging import get_logger | |
| from accelerate.utils import ProjectConfiguration, set_seed | |
| from huggingface_hub import PyTorchModelHubMixin | |
| from transformers import is_wandb_available | |
| from ..models import DDPOStableDiffusionPipeline | |
| from .alignprop_config import AlignPropConfig | |
| from .utils import generate_model_card, get_comet_experiment_url | |
| if is_wandb_available(): | |
| import wandb | |
| logger = get_logger(__name__) | |
class AlignPropTrainer(PyTorchModelHubMixin):
    """
    The AlignPropTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models. Note, this trainer is
    heavily inspired by the work here: https://github.com/mihirp1998/AlignProp/ As of now only Stable Diffusion based
    pipelines are supported

    Attributes:
        config (`AlignPropConfig`):
            Configuration object for AlignPropTrainer. Check the documentation of `AlignPropConfig` for more details.
        reward_function (`Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]`):
            Reward function to be used
        prompt_function (`Callable[[], tuple[str, Any]]`):
            Function to generate prompts to guide model
        sd_pipeline (`DDPOStableDiffusionPipeline`):
            Stable Diffusion pipeline to be used for training.
        image_samples_hook (`Optional[Callable[[Any, Any, Any], Any]]`):
            Hook to be called to log images
    """

    # Tags embedded in the generated model card (see `create_model_card`).
    _tag_names = ["trl", "alignprop"]
    def __init__(
        self,
        config: AlignPropConfig,
        reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
        prompt_function: Callable[[], tuple[str, Any]],
        sd_pipeline: DDPOStableDiffusionPipeline,
        image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
    ):
        """
        Set up the trainer: accelerator and trackers, pipeline dtype placement, optimizer,
        negative-prompt embedding, and (optionally) checkpoint resumption. See the class
        docstring for argument semantics.
        """
        warnings.warn(
            "AlignPropTrainer is deprecated and will be removed in version 0.23.0.",
            DeprecationWarning,
        )

        if image_samples_hook is None:
            warnings.warn("No image_samples_hook provided; no images will be logged")

        self.prompt_fn = prompt_function
        self.reward_fn = reward_function
        self.config = config
        self.image_samples_callback = image_samples_hook

        accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)

        # Normalize `resume_from`; if it points at a directory of checkpoints rather than a single
        # "checkpoint_<n>" folder, resolve it to the most recent checkpoint inside.
        if self.config.resume_from:
            self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
            if "checkpoint_" not in os.path.basename(self.config.resume_from):
                # get the most recent checkpoint in this directory
                checkpoints = list(
                    filter(
                        lambda x: "checkpoint_" in x,
                        os.listdir(self.config.resume_from),
                    )
                )
                if len(checkpoints) == 0:
                    raise ValueError(f"No checkpoints found in {self.config.resume_from}")
                checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
                self.config.resume_from = os.path.join(
                    self.config.resume_from,
                    f"checkpoint_{checkpoint_numbers[-1]}",
                )

                accelerator_project_config.iteration = checkpoint_numbers[-1] + 1

        self.accelerator = Accelerator(
            log_with=self.config.log_with,
            mixed_precision=self.config.mixed_precision,
            project_config=accelerator_project_config,
            # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
            # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
            # the total number of optimizer steps to accumulate across.
            gradient_accumulation_steps=self.config.train_gradient_accumulation_steps,
            **self.config.accelerator_kwargs,
        )

        is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"

        if self.accelerator.is_main_process:
            # Tensorboard receives the flat config dict; other trackers get it nested under
            # "alignprop_trainer_config".
            self.accelerator.init_trackers(
                self.config.tracker_project_name,
                config=dict(alignprop_trainer_config=config.to_dict())
                if not is_using_tensorboard
                else config.to_dict(),
                init_kwargs=self.config.tracker_kwargs,
            )

        logger.info(f"\n{config}")

        set_seed(self.config.seed, device_specific=True)

        self.sd_pipeline = sd_pipeline

        self.sd_pipeline.set_progress_bar_config(
            position=1,
            disable=not self.accelerator.is_local_main_process,
            leave=False,
            desc="Timestep",
            dynamic_ncols=True,
        )

        # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
        # as these weights are only used for inference, keeping weights in full precision is not required.
        if self.accelerator.mixed_precision == "fp16":
            inference_dtype = torch.float16
        elif self.accelerator.mixed_precision == "bf16":
            inference_dtype = torch.bfloat16
        else:
            inference_dtype = torch.float32

        self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
        self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
        self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)

        trainable_layers = self.sd_pipeline.get_trainable_layers()

        # Let the pipeline own (de)serialization of its weights during accelerate save/load
        # (the hooks below delegate to `sd_pipeline.save_checkpoint`/`load_checkpoint`).
        self.accelerator.register_save_state_pre_hook(self._save_model_hook)
        self.accelerator.register_load_state_pre_hook(self._load_model_hook)

        # Enable TF32 for faster training on Ampere GPUs,
        # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
        if self.config.allow_tf32:
            torch.backends.cuda.matmul.allow_tf32 = True

        self.optimizer = self._setup_optimizer(
            trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
        )

        # Pre-compute the fixed negative-prompt embedding once; `_generate_samples` reuses it
        # (repeated across the batch) for every generation.
        self.neg_prompt_embed = self.sd_pipeline.text_encoder(
            self.sd_pipeline.tokenizer(
                [""] if self.config.negative_prompts is None else self.config.negative_prompts,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=self.sd_pipeline.tokenizer.model_max_length,
            ).input_ids.to(self.accelerator.device)
        )[0]

        # NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
        # more memory
        self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast

        if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
            # For LoRA, prepare the whole unet but only keep the trainable (LoRA) parameters.
            unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
            self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
        else:
            self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)

        if config.resume_from:
            logger.info(f"Resuming from {config.resume_from}")
            self.accelerator.load_state(config.resume_from)
            # Checkpoint dirs are named "checkpoint_<epoch>", so training resumes at the next epoch.
            self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
        else:
            self.first_epoch = 0
| def compute_rewards(self, prompt_image_pairs): | |
| reward, reward_metadata = self.reward_fn( | |
| prompt_image_pairs["images"], prompt_image_pairs["prompts"], prompt_image_pairs["prompt_metadata"] | |
| ) | |
| return reward | |
    def step(self, epoch: int, global_step: int):
        """
        Perform a single step of training.

        Args:
            epoch (int): The current epoch.
            global_step (int): The current global step.

        Side Effects:
            - Model weights are updated
            - Logs the statistics to the accelerator trackers.
            - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step,
              and the accelerator tracker.

        Returns:
            global_step (int): The updated global step.
        """
        info = defaultdict(list)
        self.sd_pipeline.unet.train()

        # Each micro-step samples a fresh batch, scores it with the (differentiable) reward
        # function, and backpropagates through the sampling process; `accelerator.accumulate`
        # decides when gradients are actually synchronized/applied.
        for _ in range(self.config.train_gradient_accumulation_steps):
            with self.accelerator.accumulate(self.sd_pipeline.unet), self.autocast(), torch.enable_grad():
                prompt_image_pairs = self._generate_samples(
                    batch_size=self.config.train_batch_size,
                )

                rewards = self.compute_rewards(prompt_image_pairs)

                prompt_image_pairs["rewards"] = rewards

                # Gathered/detached copy is for logging only; the loss below uses the local,
                # still-differentiable rewards.
                rewards_vis = self.accelerator.gather(rewards).detach().cpu().numpy()

                loss = self.calculate_loss(rewards)

                self.accelerator.backward(loss)

                # Clip only when gradients are about to be synchronized and applied.
                if self.accelerator.sync_gradients:
                    self.accelerator.clip_grad_norm_(
                        self.trainable_layers.parameters()
                        if not isinstance(self.trainable_layers, list)
                        else self.trainable_layers,
                        self.config.train_max_grad_norm,
                    )

                self.optimizer.step()
                self.optimizer.zero_grad()

            info["reward_mean"].append(rewards_vis.mean())
            info["reward_std"].append(rewards_vis.std())
            info["loss"].append(loss.item())

        # Checks if the accelerator has performed an optimization step behind the scenes
        if self.accelerator.sync_gradients:
            # log training-related stuff
            info = {k: torch.mean(torch.tensor(v)) for k, v in info.items()}
            info = self.accelerator.reduce(info, reduction="mean")
            info.update({"epoch": epoch})
            self.accelerator.log(info, step=global_step)
            global_step += 1
            info = defaultdict(list)
        else:
            raise ValueError(
                "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
            )

        # Logs generated images
        # NOTE: only the last micro-batch's `prompt_image_pairs` is passed to the callback.
        if self.image_samples_callback is not None and global_step % self.config.log_image_freq == 0:
            self.image_samples_callback(prompt_image_pairs, global_step, self.accelerator.trackers[0])

        if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
            self.accelerator.save_state()

        return global_step
| def calculate_loss(self, rewards): | |
| """ | |
| Calculate the loss for a batch of an unpacked sample | |
| Args: | |
| rewards (torch.Tensor): | |
| Differentiable reward scalars for each generated image, shape: [batch_size] | |
| Returns: | |
| loss (torch.Tensor) (all of these are of shape (1,)) | |
| """ | |
| # Loss is specific to Aesthetic Reward function used in AlignProp (https://huggingface.co/papers/2310.03739) | |
| loss = 10.0 - (rewards).mean() | |
| return loss | |
| def loss( | |
| self, | |
| advantages: torch.Tensor, | |
| clip_range: float, | |
| ratio: torch.Tensor, | |
| ): | |
| unclipped_loss = -advantages * ratio | |
| clipped_loss = -advantages * torch.clamp( | |
| ratio, | |
| 1.0 - clip_range, | |
| 1.0 + clip_range, | |
| ) | |
| return torch.mean(torch.maximum(unclipped_loss, clipped_loss)) | |
| def _setup_optimizer(self, trainable_layers_parameters): | |
| if self.config.train_use_8bit_adam: | |
| import bitsandbytes | |
| optimizer_cls = bitsandbytes.optim.AdamW8bit | |
| else: | |
| optimizer_cls = torch.optim.AdamW | |
| return optimizer_cls( | |
| trainable_layers_parameters, | |
| lr=self.config.train_learning_rate, | |
| betas=(self.config.train_adam_beta1, self.config.train_adam_beta2), | |
| weight_decay=self.config.train_adam_weight_decay, | |
| eps=self.config.train_adam_epsilon, | |
| ) | |
| def _save_model_hook(self, models, weights, output_dir): | |
| self.sd_pipeline.save_checkpoint(models, weights, output_dir) | |
| weights.pop() # ensures that accelerate doesn't try to handle saving of the model | |
| def _load_model_hook(self, models, input_dir): | |
| self.sd_pipeline.load_checkpoint(models, input_dir) | |
| models.pop() # ensures that accelerate doesn't try to handle loading of the model | |
| def _generate_samples(self, batch_size, with_grad=True, prompts=None): | |
| """ | |
| Generate samples from the model | |
| Args: | |
| batch_size (int): Batch size to use for sampling | |
| with_grad (bool): Whether the generated RGBs should have gradients attached to it. | |
| Returns: | |
| prompt_image_pairs (dict[Any]) | |
| """ | |
| prompt_image_pairs = {} | |
| sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1) | |
| if prompts is None: | |
| prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)]) | |
| else: | |
| prompt_metadata = [{} for _ in range(batch_size)] | |
| prompt_ids = self.sd_pipeline.tokenizer( | |
| prompts, | |
| return_tensors="pt", | |
| padding="max_length", | |
| truncation=True, | |
| max_length=self.sd_pipeline.tokenizer.model_max_length, | |
| ).input_ids.to(self.accelerator.device) | |
| prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0] | |
| if with_grad: | |
| sd_output = self.sd_pipeline.rgb_with_grad( | |
| prompt_embeds=prompt_embeds, | |
| negative_prompt_embeds=sample_neg_prompt_embeds, | |
| num_inference_steps=self.config.sample_num_steps, | |
| guidance_scale=self.config.sample_guidance_scale, | |
| eta=self.config.sample_eta, | |
| truncated_backprop_rand=self.config.truncated_backprop_rand, | |
| truncated_backprop_timestep=self.config.truncated_backprop_timestep, | |
| truncated_rand_backprop_minmax=self.config.truncated_rand_backprop_minmax, | |
| output_type="pt", | |
| ) | |
| else: | |
| sd_output = self.sd_pipeline( | |
| prompt_embeds=prompt_embeds, | |
| negative_prompt_embeds=sample_neg_prompt_embeds, | |
| num_inference_steps=self.config.sample_num_steps, | |
| guidance_scale=self.config.sample_guidance_scale, | |
| eta=self.config.sample_eta, | |
| output_type="pt", | |
| ) | |
| images = sd_output.images | |
| prompt_image_pairs["images"] = images | |
| prompt_image_pairs["prompts"] = prompts | |
| prompt_image_pairs["prompt_metadata"] = prompt_metadata | |
| return prompt_image_pairs | |
| def train(self, epochs: Optional[int] = None): | |
| """ | |
| Train the model for a given number of epochs | |
| """ | |
| global_step = 0 | |
| if epochs is None: | |
| epochs = self.config.num_epochs | |
| for epoch in range(self.first_epoch, epochs): | |
| global_step = self.step(epoch, global_step) | |
| def _save_pretrained(self, save_directory): | |
| self.sd_pipeline.save_pretrained(save_directory) | |
| self.create_model_card() | |
    # Ensure the model card is saved along with the checkpoint
    def _save_checkpoint(self, model, trial):
        # NOTE(review): this method reads `self.args` and calls `super()._save_checkpoint()`, but
        # neither is defined by this class or its visible base (`PyTorchModelHubMixin`) — elsewhere
        # in this file the settings object is `self.config`. It looks copied from a
        # `transformers.Trainer` subclass; confirm it is ever reached before relying on it.
        if self.args.hub_model_id is None:
            model_name = Path(self.args.output_dir).name
        else:
            model_name = self.args.hub_model_id.split("/")[-1]
        self.create_model_card(model_name=model_name)
        super()._save_checkpoint(model, trial)
    def create_model_card(
        self,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        tags: Union[str, list[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        Args:
            model_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the model.
            dataset_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the dataset used for training.
            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
                Tags to be associated with the model card.
        """
        # NOTE(review): `is_world_process_zero`, `self.model`, `self.hub_model_id` and `self.args`
        # are not defined in this class or its visible base (`PyTorchModelHubMixin`); verify these
        # are provided externally before relying on this method.
        if not self.is_world_process_zero():
            return

        # Record the base checkpoint only when `_name_or_path` is not a local directory
        # (i.e. presumably a hub model id — confirm against callers).
        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
            base_model = self.model.config._name_or_path
        else:
            base_model = None

        # normalize `tags` to a mutable set
        if tags is None:
            tags = set()
        elif isinstance(tags, str):
            tags = {tags}
        else:
            tags = set(tags)

        # Tag models produced through the unsloth integration.
        if hasattr(self.model.config, "unsloth_version"):
            tags.add("unsloth")

        tags.update(self._tag_names)

        # BibTeX citation for the AlignProp paper, embedded into the generated card.
        citation = textwrap.dedent("""\
        @article{prabhudesai2024aligning,
            title = {{Aligning Text-to-Image Diffusion Models with Reward Backpropagation}},
            author = {Mihir Prabhudesai and Anirudh Goyal and Deepak Pathak and Katerina Fragkiadaki},
            year = 2024,
            eprint = {arXiv:2310.03739}
        }""")

        model_card = generate_model_card(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=tags,
            # `wandb.run` is only touched when wandb is importable (short-circuit).
            wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
            comet_url=get_comet_experiment_url(),
            trainer_name="AlignProp",
            trainer_citation=citation,
            paper_title="Aligning Text-to-Image Diffusion Models with Reward Backpropagation",
            paper_id="2310.03739",
        )

        model_card.save(os.path.join(self.args.output_dir, "README.md"))