import math
import tempfile
from typing import List, Optional

import numpy as np
import PIL.Image
import torch
from accelerate import Accelerator
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from diffusers import AutoencoderKL, DiffusionPipeline, DPMSolverMultistepScheduler, UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
from diffusers.models.attention_processor import (
    AttnAddedKVProcessor,
    AttnAddedKVProcessor2_0,
    LoRAAttnAddedKVProcessor,
    LoRAAttnProcessor,
    LoRAAttnProcessor2_0,
    SlicedAttnAddedKVProcessor,
)
from diffusers.optimization import get_scheduler


class SdeDragPipeline(DiffusionPipeline):
    r"""
    Pipeline for image drag-and-drop editing using stochastic differential equations: https://arxiv.org/abs/2311.01410.
    Please refer to the [official repository](https://github.com/ML-GSAI/SDE-Drag) for more information.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.).

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Please use
            [`DDIMScheduler`].
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: DPMSolverMultistepScheduler,
    ):
        super().__init__()
        self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)

    def __call__(
        self,
        prompt: str,
        image: PIL.Image.Image,
        mask_image: PIL.Image.Image,
        source_points: List[List[int]],
        target_points: List[List[int]],
        t0: Optional[float] = 0.6,
        steps: Optional[int] = 200,
        step_size: Optional[int] = 2,
        image_scale: Optional[float] = 0.3,
        adapt_radius: Optional[int] = 5,
        min_lora_scale: Optional[float] = 0.5,
        generator: Optional[torch.Generator] = None,
    ):
r""" | |
Function invoked when calling the pipeline for image editing. | |
Args: | |
prompt (`str`, *required*): | |
The prompt to guide the image editing. | |
image (`PIL.Image.Image`, *required*): | |
Which will be edited, parts of the image will be masked out with `mask_image` and edited | |
according to `prompt`. | |
mask_image (`PIL.Image.Image`, *required*): | |
To mask `image`. White pixels in the mask will be edited, while black pixels will be preserved. | |
source_points (`List[List[int]]`, *required*): | |
Used to mark the starting positions of drag editing in the image, with each pixel represented as a | |
`List[int]` of length 2. | |
target_points (`List[List[int]]`, *required*): | |
Used to mark the target positions of drag editing in the image, with each pixel represented as a | |
`List[int]` of length 2. | |
t0 (`float`, *optional*, defaults to 0.6): | |
The time parameter. Higher t0 improves the fidelity while lowering the faithfulness of the edited images | |
and vice versa. | |
steps (`int`, *optional*, defaults to 200): | |
The number of sampling iterations. | |
step_size (`int`, *optional*, defaults to 2): | |
The drag diatance of each drag step. | |
image_scale (`float`, *optional*, defaults to 0.3): | |
To avoid duplicating the content, use image_scale to perturbs the source. | |
adapt_radius (`int`, *optional*, defaults to 5): | |
The size of the region for copy and paste operations during each step of the drag process. | |
min_lora_scale (`float`, *optional*, defaults to 0.5): | |
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. | |
min_lora_scale specifies the minimum LoRA scale during the image drag-editing process. | |
generator ('torch.Generator', *optional*, defaults to None): | |
To make generation deterministic(https://pytorch.org/docs/stable/generated/torch.Generator.html). | |
Examples: | |
```py | |
>>> import PIL | |
>>> import torch | |
>>> from diffusers import DDIMScheduler, DiffusionPipeline | |
>>> # Load the pipeline | |
>>> model_path = "runwayml/stable-diffusion-v1-5" | |
>>> scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler") | |
>>> pipe = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, custom_pipeline="sde_drag") | |
>>> pipe.to('cuda') | |
>>> # To save GPU memory, torch.float16 can be used, but it may compromise image quality. | |
>>> # If not training LoRA, please avoid using torch.float16 | |
>>> # pipe.to(torch.float16) | |
>>> # Provide prompt, image, mask image, and the starting and target points for drag editing. | |
>>> prompt = "prompt of the image" | |
>>> image = PIL.Image.open('/path/to/image') | |
>>> mask_image = PIL.Image.open('/path/to/mask_image') | |
>>> source_points = [[123, 456]] | |
>>> target_points = [[234, 567]] | |
>>> # train_lora is optional, and in most cases, using train_lora can better preserve consistency with the original image. | |
>>> pipe.train_lora(prompt, image) | |
>>> output = pipe(prompt, image, mask_image, source_points, target_points) | |
>>> output_image = PIL.Image.fromarray(output) | |
>>> output_image.save("./output.png") | |
``` | |
""" | |
        self.scheduler.set_timesteps(steps)

        noise_scale = (1 - image_scale**2) ** (0.5)  # chosen so that image_scale**2 + noise_scale**2 == 1

        text_embeddings = self._get_text_embed(prompt)
        uncond_embeddings = self._get_text_embed([""])
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        latent = self._get_img_latent(image)

        mask = mask_image.resize((latent.shape[3], latent.shape[2]))
        mask = torch.tensor(np.array(mask))
        mask = mask.unsqueeze(0).expand_as(latent).to(self.device)

        source_points = torch.tensor(source_points).div(torch.tensor([8]), rounding_mode="trunc")
        target_points = torch.tensor(target_points).div(torch.tensor([8]), rounding_mode="trunc")

        # Split the total drag distance into roughly step_size-sized steps.
        distance = target_points - source_points
        distance_norm_max = torch.norm(distance.float(), dim=1, keepdim=True).max()

        if distance_norm_max <= step_size:
            drag_num = 1
        else:
            drag_num = distance_norm_max.div(torch.tensor([step_size]), rounding_mode="trunc")
            if (distance_norm_max / drag_num - step_size).abs() > (
                distance_norm_max / (drag_num + 1) - step_size
            ).abs():
                drag_num += 1

        latents = []
        for i in tqdm(range(int(drag_num)), desc="SDE Drag"):
            source_new = source_points + (i / drag_num * distance).to(torch.int)
            target_new = source_points + ((i + 1) / drag_num * distance).to(torch.int)

            latent, noises, hook_latents, lora_scales, cfg_scales = self._forward(
                latent, steps, t0, min_lora_scale, text_embeddings, generator
            )
            latent = self._copy_and_paste(
                latent,
                source_new,
                target_new,
                adapt_radius,
                latent.shape[2] - 1,
                latent.shape[3] - 1,
                image_scale,
                noise_scale,
                generator,
            )
            latent = self._backward(
                latent, mask, steps, t0, noises, hook_latents, lora_scales, cfg_scales, text_embeddings, generator
            )

            latents.append(latent)

        # Decode the final latent back to image space (0.18215 is the SD v1 VAE scaling factor).
        result_image = 1 / 0.18215 * latents[-1]

        with torch.no_grad():
            result_image = self.vae.decode(result_image).sample

        result_image = (result_image / 2 + 0.5).clamp(0, 1)
        result_image = result_image.cpu().permute(0, 2, 3, 1).numpy()[0]
        result_image = (result_image * 255).astype(np.uint8)

        return result_image

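    # Optional LoRA fine-tuning on the single input image; this helps the edited output stay consistent with the
    # original image (see the example in `__call__`).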
    def train_lora(self, prompt, image, lora_step=100, lora_rank=16, generator=None):
        accelerator = Accelerator(gradient_accumulation_steps=1, mixed_precision="fp16")

        # Freeze the base model weights; only the LoRA parameters added below are trained.
        self.vae.requires_grad_(False)
        self.text_encoder.requires_grad_(False)
        self.unet.requires_grad_(False)

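        # Attach a LoRA attention processor, sized to match each block's hidden dimension, to every attention
        # layer of the UNet.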
        unet_lora_attn_procs = {}
        for name, attn_processor in self.unet.attn_processors.items():
            cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = self.unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = self.unet.config.block_out_channels[block_id]
            else:
                raise NotImplementedError("name must start with up_blocks, mid_block, or down_blocks")

            if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
                lora_attn_processor_class = LoRAAttnAddedKVProcessor
            else:
                lora_attn_processor_class = (
                    LoRAAttnProcessor2_0
                    if hasattr(torch.nn.functional, "scaled_dot_product_attention")
                    else LoRAAttnProcessor
                )
            unet_lora_attn_procs[name] = lora_attn_processor_class(
                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank
            )

        self.unet.set_attn_processor(unet_lora_attn_procs)
        unet_lora_layers = AttnProcsLayers(self.unet.attn_processors)

        params_to_optimize = unet_lora_layers.parameters()
        optimizer = torch.optim.AdamW(
            params_to_optimize,
            lr=2e-4,
            betas=(0.9, 0.999),
            weight_decay=1e-2,
            eps=1e-08,
        )

        lr_scheduler = get_scheduler(
            "constant",
            optimizer=optimizer,
            num_warmup_steps=0,
            num_training_steps=lora_step,
            num_cycles=1,
            power=1.0,
        )

        unet_lora_layers = accelerator.prepare_model(unet_lora_layers)
        optimizer = accelerator.prepare_optimizer(optimizer)
        lr_scheduler = accelerator.prepare_scheduler(lr_scheduler)

        with torch.no_grad():
            text_inputs = self._tokenize_prompt(prompt, tokenizer_max_length=None)
            text_embedding = self._encode_prompt(
                text_inputs.input_ids, text_inputs.attention_mask, text_encoder_use_attention_mask=False
            )

        image_transforms = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        image = image_transforms(image).to(self.device, dtype=self.vae.dtype)
        image = image.unsqueeze(dim=0)
        latents_dist = self.vae.encode(image).latent_dist

        for _ in tqdm(range(lora_step), desc="Train LoRA"):
            self.unet.train()
            model_input = latents_dist.sample() * self.vae.config.scaling_factor

            # Sample noise that we'll add to the latents
            noise = torch.randn(
                model_input.size(),
                dtype=model_input.dtype,
                layout=model_input.layout,
                device=model_input.device,
                generator=generator,
            )
            bsz, channels, height, width = model_input.shape

            # Sample a random timestep for each image
            timesteps = torch.randint(
                0, self.scheduler.config.num_train_timesteps, (bsz,), device=model_input.device, generator=generator
            )
            timesteps = timesteps.long()

            # Add noise to the model input according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_model_input = self.scheduler.add_noise(model_input, noise, timesteps)

            # Predict the noise residual
            model_pred = self.unet(noisy_model_input, timesteps, text_embedding).sample

            # Get the target for loss depending on the prediction type
            if self.scheduler.config.prediction_type == "epsilon":
                target = noise
            elif self.scheduler.config.prediction_type == "v_prediction":
                target = self.scheduler.get_velocity(model_input, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {self.scheduler.config.prediction_type}")

            loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="mean")
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        # Serialize the trained LoRA weights and load them back into the UNet as attention processors.
        with tempfile.TemporaryDirectory() as save_lora_dir:
            LoraLoaderMixin.save_lora_weights(
                save_directory=save_lora_dir,
                unet_lora_layers=unet_lora_layers,
                text_encoder_lora_layers=None,
            )
            self.unet.load_attn_procs(save_lora_dir)

    def _tokenize_prompt(self, prompt, tokenizer_max_length=None):
        if tokenizer_max_length is not None:
            max_length = tokenizer_max_length
        else:
            max_length = self.tokenizer.model_max_length

        text_inputs = self.tokenizer(
            prompt,
            truncation=True,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt",
        )
        return text_inputs

    def _encode_prompt(self, input_ids, attention_mask, text_encoder_use_attention_mask=False):
        text_input_ids = input_ids.to(self.device)

        if text_encoder_use_attention_mask:
            attention_mask = attention_mask.to(self.device)
        else:
            attention_mask = None

        prompt_embeds = self.text_encoder(
            text_input_ids,
            attention_mask=attention_mask,
        )
        prompt_embeds = prompt_embeds[0]
        return prompt_embeds

    def _get_text_embed(self, prompt):
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
        return text_embeddings

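    # Drag the latent: re-noise the source patch and copy its original content onto the target patch.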
    def _copy_and_paste(
        self, latent, source_new, target_new, adapt_radius, max_height, max_width, image_scale, noise_scale, generator
    ):
        def adaption_r(source, target, adapt_radius, max_height, max_width):
            # Shrink the patch radius where needed so both the source and target patches stay inside the latent grid.
            r_x_lower = min(adapt_radius, source[0], target[0])
            r_x_upper = min(adapt_radius, max_width - source[0], max_width - target[0])
            r_y_lower = min(adapt_radius, source[1], target[1])
            r_y_upper = min(adapt_radius, max_height - source[1], max_height - target[1])
            return r_x_lower, r_x_upper, r_y_lower, r_y_upper

        for source_, target_ in zip(source_new, target_new):
            r_x_lower, r_x_upper, r_y_lower, r_y_upper = adaption_r(
                source_, target_, adapt_radius, max_height, max_width
            )

            source_feature = latent[
                :, :, source_[1] - r_y_lower : source_[1] + r_y_upper, source_[0] - r_x_lower : source_[0] + r_x_upper
            ].clone()

            # Perturb the source patch with fresh noise so its content is not duplicated in the output.
            latent[
                :, :, source_[1] - r_y_lower : source_[1] + r_y_upper, source_[0] - r_x_lower : source_[0] + r_x_upper
            ] = image_scale * source_feature + noise_scale * torch.randn(
                latent.shape[0],
                4,
                r_y_lower + r_y_upper,
                r_x_lower + r_x_upper,
                device=self.device,
                generator=generator,
            )

            # Paste the (slightly amplified) source feature at the target location.
            latent[
                :, :, target_[1] - r_y_lower : target_[1] + r_y_upper, target_[0] - r_x_lower : target_[0] + r_x_upper
            ] = source_feature * 1.1
        return latent

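    # Encode a PIL image into the VAE latent space (scaled by the SD v1 factor 0.18215).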
    def _get_img_latent(self, image, height=None, width=None):
        data = image.convert("RGB")
        if height is not None:
            data = data.resize((width, height))
        transform = transforms.ToTensor()
        data = transform(data).unsqueeze(0)
        data = (data * 2.0) - 1.0
        data = data.to(self.device, dtype=self.vae.dtype)
        latent = self.vae.encode(data).latent_dist.sample()
        latent = 0.18215 * latent
        return latent

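    # Predict the noise residual at `timestep`, applying classifier-free guidance when guidance_scale > 1.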
    def _get_eps(self, latent, timestep, guidance_scale, text_embeddings, lora_scale=None):
        latent_model_input = torch.cat([latent] * 2) if guidance_scale > 1.0 else latent
        text_embeddings = text_embeddings if guidance_scale > 1.0 else text_embeddings.chunk(2)[1]

        cross_attention_kwargs = None if lora_scale is None else {"scale": lora_scale}

        with torch.no_grad():
            noise_pred = self.unet(
                latent_model_input,
                timestep,
                encoder_hidden_states=text_embeddings,
                cross_attention_kwargs=cross_attention_kwargs,
            ).sample

        if guidance_scale > 1.0:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        elif guidance_scale == 1.0:
            noise_pred_text = noise_pred
            noise_pred_uncond = 0.0
        else:
            raise NotImplementedError(guidance_scale)

        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        return noise_pred

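    # One forward (noising) SDE step: move the sample to the next noisier timestep and return both the noised
    # latent and the noise term it implies for the matching reverse update (re-injected later in _backward).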
    def _forward_sde(
        self, timestep, sample, guidance_scale, text_embeddings, steps, eta=1.0, lora_scale=None, generator=None
    ):
        num_train_timesteps = len(self.scheduler)
        alphas_cumprod = self.scheduler.alphas_cumprod
        initial_alpha_cumprod = torch.tensor(1.0)

        prev_timestep = timestep + num_train_timesteps // steps

        alpha_prod_t = alphas_cumprod[timestep] if timestep >= 0 else initial_alpha_cumprod
        alpha_prod_t_prev = alphas_cumprod[prev_timestep]
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        x_prev = (alpha_prod_t_prev / alpha_prod_t) ** (0.5) * sample + (1 - alpha_prod_t_prev / alpha_prod_t) ** (
            0.5
        ) * torch.randn(
            sample.size(), dtype=sample.dtype, layout=sample.layout, device=self.device, generator=generator
        )
        eps = self._get_eps(x_prev, prev_timestep, guidance_scale, text_embeddings, lora_scale)

        sigma_t_prev = (
            eta
            * (1 - alpha_prod_t) ** (0.5)
            * (1 - alpha_prod_t_prev / (1 - alpha_prod_t_prev) * (1 - alpha_prod_t) / alpha_prod_t) ** (0.5)
        )

        pred_original_sample = (x_prev - beta_prod_t_prev ** (0.5) * eps) / alpha_prod_t_prev ** (0.5)
        pred_sample_direction_coeff = (1 - alpha_prod_t - sigma_t_prev**2) ** (0.5)

        noise = (
            sample - alpha_prod_t ** (0.5) * pred_original_sample - pred_sample_direction_coeff * eps
        ) / sigma_t_prev

        return x_prev, noise

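    # One reverse denoising step: a deterministic DDIM update when sde=False, or a stochastic update that
    # re-injects the supplied noise when sde=True.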
    def _sample(
        self,
        timestep,
        sample,
        guidance_scale,
        text_embeddings,
        steps,
        sde=False,
        noise=None,
        eta=1.0,
        lora_scale=None,
        generator=None,
    ):
        num_train_timesteps = len(self.scheduler)
        alphas_cumprod = self.scheduler.alphas_cumprod
        final_alpha_cumprod = torch.tensor(1.0)

        eps = self._get_eps(sample, timestep, guidance_scale, text_embeddings, lora_scale)

        prev_timestep = timestep - num_train_timesteps // steps

        alpha_prod_t = alphas_cumprod[timestep]
        alpha_prod_t_prev = alphas_cumprod[prev_timestep] if prev_timestep >= 0 else final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t

        sigma_t = (
            eta
            * ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** (0.5)
            * (1 - alpha_prod_t / alpha_prod_t_prev) ** (0.5)
            if sde
            else 0
        )

        pred_original_sample = (sample - beta_prod_t ** (0.5) * eps) / alpha_prod_t ** (0.5)
        pred_sample_direction_coeff = (1 - alpha_prod_t_prev - sigma_t**2) ** (0.5)

        noise = (
            torch.randn(
                sample.size(), dtype=sample.dtype, layout=sample.layout, device=self.device, generator=generator
            )
            if noise is None
            else noise
        )
        latent = (
            alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction_coeff * eps + sigma_t * noise
        )

        return latent

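    # Run the forward noising process from the t0 cutoff to the last timestep, recording the injected noises,
    # the intermediate "hook" latents, and the LoRA/CFG scale schedules so _backward can reuse them.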
    def _forward(self, latent, steps, t0, lora_scale_min, text_embeddings, generator):
        def scale_schedule(begin, end, n, length, type="linear"):
            if type == "constant":
                return end
            elif type == "linear":
                return begin + (end - begin) * n / length
            elif type == "cos":
                factor = (1 - math.cos(n * math.pi / length)) / 2
                return (1 - factor) * begin + factor * end
            else:
                raise NotImplementedError(type)

        noises = []
        latents = []
        lora_scales = []
        cfg_scales = []
        latents.append(latent)
        t0 = int(t0 * steps)
        t_begin = steps - t0

        length = len(self.scheduler.timesteps[t_begin - 1 : -1]) - 1
        index = 1
        for t in self.scheduler.timesteps[t_begin:].flip(dims=[0]):
            lora_scale = scale_schedule(1, lora_scale_min, index, length, type="cos")
            cfg_scale = scale_schedule(1, 3.0, index, length, type="linear")

            latent, noise = self._forward_sde(
                t, latent, cfg_scale, text_embeddings, steps, lora_scale=lora_scale, generator=generator
            )

            noises.append(noise)
            latents.append(latent)
            lora_scales.append(lora_scale)
            cfg_scales.append(cfg_scale)
            index += 1

        return latent, noises, latents, lora_scales, cfg_scales

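    # Denoise the dragged latent back toward the image latent, re-injecting the recorded noises; outside the mask
    # the latent is reset to the hooked latents so unmasked regions are preserved.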
    def _backward(
        self, latent, mask, steps, t0, noises, hook_latents, lora_scales, cfg_scales, text_embeddings, generator
    ):
        t0 = int(t0 * steps)
        t_begin = steps - t0

        hook_latent = hook_latents.pop()
        latent = torch.where(mask > 128, latent, hook_latent)
        for t in self.scheduler.timesteps[t_begin - 1 : -1]:
            latent = self._sample(
                t,
                latent,
                cfg_scales.pop(),
                text_embeddings,
                steps,
                sde=True,
                noise=noises.pop(),
                lora_scale=lora_scales.pop(),
                generator=generator,
            )
            hook_latent = hook_latents.pop()
            latent = torch.where(mask > 128, latent, hook_latent)
        return latent