# ************************************************************************* # Copyright (2023) Bytedance Inc. # # Copyright (2023) DragDiffusion Authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ************************************************************************* import os import cv2 import numpy as np import gradio as gr from copy import deepcopy from einops import rearrange from types import SimpleNamespace import datetime import PIL from PIL import Image from PIL.ImageOps import exif_transpose import torch import torch.nn.functional as F from diffusers import DDIMScheduler, AutoencoderKL, DPMSolverMultistepScheduler from diffusers.models.embeddings import ImageProjection from drag_pipeline import DragPipeline from torchvision.utils import save_image from pytorch_lightning import seed_everything from .drag_utils import drag_diffusion_update, drag_diffusion_update_gen from .lora_utils import train_lora from .attn_utils import register_attention_editor_diffusers, MutualSelfAttentionControl from .freeu_utils import register_free_upblock2d, register_free_crossattn_upblock2d # -------------- general UI functionality -------------- def clear_all(length=480): return gr.Image.update(value=None, height=length, width=length, interactive=True), \ gr.Image.update(value=None, height=length, width=length, interactive=False), \ gr.Image.update(value=None, height=length, width=length, interactive=False), \ [], None, None def clear_all_gen(length=480): return gr.Image.update(value=None, height=length, width=length, interactive=False), \ gr.Image.update(value=None, height=length, width=length, interactive=False), \ gr.Image.update(value=None, height=length, width=length, interactive=False), \ [], None, None, None def mask_image(image, mask, color=[255,0,0], alpha=0.5): """ Overlay mask on image for visualization purpose. Args: image (H, W, 3) or (H, W): input image mask (H, W): mask to be overlaid color: the color of overlaid mask alpha: the transparency of the mask """ out = deepcopy(image) img = deepcopy(image) img[mask == 1] = color out = cv2.addWeighted(img, alpha, out, 1-alpha, 0, out) return out def store_img(img, length=512): image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255. height,width,_ = image.shape image = Image.fromarray(image) image = exif_transpose(image) image = image.resize((length,int(length*height/width)), PIL.Image.BILINEAR) mask = cv2.resize(mask, (length,int(length*height/width)), interpolation=cv2.INTER_NEAREST) image = np.array(image) if mask.sum() > 0: mask = np.uint8(mask > 0) masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3) else: masked_img = image.copy() # when new image is uploaded, `selected_points` should be empty return image, [], gr.Image.update(value=masked_img, interactive=True), mask # once user upload an image, the original image is stored in `original_image` # the same image is displayed in `input_image` for point clicking purpose def store_img_gen(img): image, mask = img["image"], np.float32(img["mask"][:, :, 0]) / 255. image = Image.fromarray(image) image = exif_transpose(image) image = np.array(image) if mask.sum() > 0: mask = np.uint8(mask > 0) masked_img = mask_image(image, 1 - mask, color=[0, 0, 0], alpha=0.3) else: masked_img = image.copy() # when new image is uploaded, `selected_points` should be empty return image, [], masked_img, mask # user click the image to get points, and show the points on the image def get_points(img, sel_pix, evt: gr.SelectData): # collect the selected point sel_pix.append(evt.index) # draw points points = [] for idx, point in enumerate(sel_pix): if idx % 2 == 0: # draw a red circle at the handle point cv2.circle(img, tuple(point), 10, (255, 0, 0), -1) else: # draw a blue circle at the handle point cv2.circle(img, tuple(point), 10, (0, 0, 255), -1) points.append(tuple(point)) # draw an arrow from handle point to target point if len(points) == 2: cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 4, tipLength=0.5) points = [] return img if isinstance(img, np.ndarray) else np.array(img) # clear all handle/target points def undo_points(original_image, mask): if mask.sum() > 0: mask = np.uint8(mask > 0) masked_img = mask_image(original_image, 1 - mask, color=[0, 0, 0], alpha=0.3) else: masked_img = original_image.copy() return masked_img, [] # ------------------------------------------------------ # ----------- dragging user-input image utils ----------- def train_lora_interface(original_image, prompt, model_path, vae_path, lora_path, lora_step, lora_lr, lora_batch_size, lora_rank, progress=gr.Progress()): train_lora( original_image, prompt, model_path, vae_path, lora_path, lora_step, lora_lr, lora_batch_size, lora_rank, progress) return "Training LoRA Done!" def preprocess_image(image, device, dtype=torch.float32): image = torch.from_numpy(image).float() / 127.5 - 1 # [-1, 1] image = rearrange(image, "h w c -> 1 c h w") image = image.to(device, dtype) return image def run_drag(source_image, image_with_clicks, mask, prompt, points, inversion_strength, lam, latent_lr, n_pix_step, model_path, vae_path, lora_path, start_step, start_layer, save_dir="./results" ): # initialize model device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False, steps_offset=1) model = DragPipeline.from_pretrained(model_path, scheduler=scheduler, torch_dtype=torch.float16) # call this function to override unet forward function, # so that intermediate features are returned after forward model.modify_unet_forward() # set vae if vae_path != "default": model.vae = AutoencoderKL.from_pretrained( vae_path ).to(model.vae.device, model.vae.dtype) # off load model to cpu, which save some memory. model.enable_model_cpu_offload() # initialize parameters seed = 42 # random seed used by a lot of people for unknown reason seed_everything(seed) args = SimpleNamespace() args.prompt = prompt args.points = points args.n_inference_step = 50 args.n_actual_inference_step = round(inversion_strength * args.n_inference_step) args.guidance_scale = 1.0 args.unet_feature_idx = [3] args.r_m = 1 args.r_p = 3 args.lam = lam args.lr = latent_lr args.n_pix_step = n_pix_step full_h, full_w = source_image.shape[:2] args.sup_res_h = int(0.5*full_h) args.sup_res_w = int(0.5*full_w) print(args) source_image = preprocess_image(source_image, device, dtype=torch.float16) image_with_clicks = preprocess_image(image_with_clicks, device) # preparing editing meta data (handle, target, mask) mask = torch.from_numpy(mask).float() / 255. mask[mask > 0.0] = 1.0 mask = rearrange(mask, "h w -> 1 1 h w").cuda() mask = F.interpolate(mask, (args.sup_res_h, args.sup_res_w), mode="nearest") handle_points = [] target_points = [] # here, the point is in x,y coordinate for idx, point in enumerate(points): cur_point = torch.tensor([point[1]/full_h*args.sup_res_h, point[0]/full_w*args.sup_res_w]) cur_point = torch.round(cur_point) if idx % 2 == 0: handle_points.append(cur_point) else: target_points.append(cur_point) print('handle points:', handle_points) print('target points:', target_points) # set lora if lora_path == "": print("applying default parameters") model.unet.set_default_attn_processor() else: print("applying lora: " + lora_path) model.unet.load_attn_procs(lora_path) # obtain text embeddings text_embeddings = model.get_text_embeddings(prompt) # invert the source image # the latent code resolution is too small, only 64*64 invert_code = model.invert(source_image, prompt, encoder_hidden_states=text_embeddings, guidance_scale=args.guidance_scale, num_inference_steps=args.n_inference_step, num_actual_inference_steps=args.n_actual_inference_step) # empty cache to save memory torch.cuda.empty_cache() init_code = invert_code init_code_orig = deepcopy(init_code) model.scheduler.set_timesteps(args.n_inference_step) t = model.scheduler.timesteps[args.n_inference_step - args.n_actual_inference_step] # feature shape: [1280,16,16], [1280,32,32], [640,64,64], [320,64,64] # convert dtype to float for optimization init_code = init_code.float() text_embeddings = text_embeddings.float() model.unet = model.unet.float() updated_init_code = drag_diffusion_update( model, init_code, text_embeddings, t, handle_points, target_points, mask, args) updated_init_code = updated_init_code.half() text_embeddings = text_embeddings.half() model.unet = model.unet.half() # empty cache to save memory torch.cuda.empty_cache() # hijack the attention module # inject the reference branch to guide the generation editor = MutualSelfAttentionControl(start_step=start_step, start_layer=start_layer, total_steps=args.n_inference_step, guidance_scale=args.guidance_scale) if lora_path == "": register_attention_editor_diffusers(model, editor, attn_processor='attn_proc') else: register_attention_editor_diffusers(model, editor, attn_processor='lora_attn_proc') # inference the synthesized image gen_image = model( prompt=args.prompt, encoder_hidden_states=torch.cat([text_embeddings]*2, dim=0), batch_size=2, latents=torch.cat([init_code_orig, updated_init_code], dim=0), guidance_scale=args.guidance_scale, num_inference_steps=args.n_inference_step, num_actual_inference_steps=args.n_actual_inference_step )[1].unsqueeze(dim=0) # resize gen_image into the size of source_image # we do this because shape of gen_image will be rounded to multipliers of 8 gen_image = F.interpolate(gen_image, (full_h, full_w), mode='bilinear') # save the original image, user editing instructions, synthesized image save_result = torch.cat([ source_image.float() * 0.5 + 0.5, torch.ones((1,3,full_h,25)).cuda(), image_with_clicks.float() * 0.5 + 0.5, torch.ones((1,3,full_h,25)).cuda(), gen_image[0:1].float() ], dim=-1) if not os.path.isdir(save_dir): os.mkdir(save_dir) save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S") save_image(save_result, os.path.join(save_dir, save_prefix + '.png')) out_image = gen_image.cpu().permute(0, 2, 3, 1).numpy()[0] out_image = (out_image * 255).astype(np.uint8) return out_image # ------------------------------------------------------- # ----------- dragging generated image utils ----------- # once the user generated an image # it will be displayed on mask drawing-areas and point-clicking area def gen_img( length, # length of the window displaying the image height, # height of the generated image width, # width of the generated image n_inference_step, scheduler_name, seed, guidance_scale, prompt, neg_prompt, model_path, vae_path, lora_path, b1, b2, s1, s2): # initialize model device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = DragPipeline.from_pretrained(model_path, torch_dtype=torch.float16).to(device) if scheduler_name == "DDIM": scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False, steps_offset=1) elif scheduler_name == "DPM++2M": scheduler = DPMSolverMultistepScheduler.from_config( model.scheduler.config ) elif scheduler_name == "DPM++2M_karras": scheduler = DPMSolverMultistepScheduler.from_config( model.scheduler.config, use_karras_sigmas=True ) else: raise NotImplementedError("scheduler name not correct") model.scheduler = scheduler # call this function to override unet forward function, # so that intermediate features are returned after forward model.modify_unet_forward() # set vae if vae_path != "default": model.vae = AutoencoderKL.from_pretrained( vae_path ).to(model.vae.device, model.vae.dtype) # set lora #if lora_path != "": # print("applying lora for image generation: " + lora_path) # model.unet.load_attn_procs(lora_path) if lora_path != "": print("applying lora: " + lora_path) model.load_lora_weights(lora_path, weight_name="lora.safetensors") # apply FreeU if b1 != 1.0 or b2!=1.0 or s1!=1.0 or s2!=1.0: print('applying FreeU') register_free_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s2) register_free_crossattn_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s2) else: print('do not apply FreeU') # initialize init noise seed_everything(seed) init_noise = torch.randn([1, 4, height // 8, width // 8], device=device, dtype=model.vae.dtype) gen_image, intermediate_latents = model(prompt=prompt, neg_prompt=neg_prompt, num_inference_steps=n_inference_step, latents=init_noise, guidance_scale=guidance_scale, return_intermediates=True) gen_image = gen_image.cpu().permute(0, 2, 3, 1).numpy()[0] gen_image = (gen_image * 255).astype(np.uint8) if height < width: # need to do this due to Gradio's bug return gr.Image.update(value=gen_image, height=int(length*height/width), width=length, interactive=True), \ gr.Image.update(height=int(length*height/width), width=length, interactive=True), \ gr.Image.update(height=int(length*height/width), width=length), \ None, \ intermediate_latents else: return gr.Image.update(value=gen_image, height=length, width=length, interactive=True), \ gr.Image.update(value=None, height=length, width=length, interactive=True), \ gr.Image.update(value=None, height=length, width=length), \ None, \ intermediate_latents def run_drag_gen( n_inference_step, scheduler_name, source_image, image_with_clicks, intermediate_latents_gen, guidance_scale, mask, prompt, neg_prompt, points, inversion_strength, lam, latent_lr, n_pix_step, model_path, vae_path, lora_path, start_step, start_layer, b1, b2, s1, s2, save_dir="./results"): # initialize model device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = DragPipeline.from_pretrained(model_path, torch_dtype=torch.float16) if scheduler_name == "DDIM": scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False, steps_offset=1) elif scheduler_name == "DPM++2M": scheduler = DPMSolverMultistepScheduler.from_config( model.scheduler.config ) elif scheduler_name == "DPM++2M_karras": scheduler = DPMSolverMultistepScheduler.from_config( model.scheduler.config, use_karras_sigmas=True ) else: raise NotImplementedError("scheduler name not correct") model.scheduler = scheduler # call this function to override unet forward function, # so that intermediate features are returned after forward model.modify_unet_forward() # set vae if vae_path != "default": model.vae = AutoencoderKL.from_pretrained( vae_path ).to(model.vae.device, model.vae.dtype) # off load model to cpu, which save some memory. model.enable_model_cpu_offload() # initialize parameters seed = 42 # random seed used by a lot of people for unknown reason seed_everything(seed) args = SimpleNamespace() args.prompt = prompt args.neg_prompt = neg_prompt args.points = points args.n_inference_step = n_inference_step args.n_actual_inference_step = round(n_inference_step * inversion_strength) args.guidance_scale = guidance_scale args.unet_feature_idx = [3] full_h, full_w = source_image.shape[:2] args.sup_res_h = int(0.5*full_h) args.sup_res_w = int(0.5*full_w) args.r_m = 1 args.r_p = 3 args.lam = lam args.lr = latent_lr args.n_pix_step = n_pix_step print(args) source_image = preprocess_image(source_image, device) image_with_clicks = preprocess_image(image_with_clicks, device) if lora_path != "": print("applying lora: " + lora_path) model.load_lora_weights(lora_path, weight_name="lora.safetensors") # preparing editing meta data (handle, target, mask) mask = torch.from_numpy(mask).float() / 255. mask[mask > 0.0] = 1.0 mask = rearrange(mask, "h w -> 1 1 h w").cuda() mask = F.interpolate(mask, (args.sup_res_h, args.sup_res_w), mode="nearest") handle_points = [] target_points = [] # here, the point is in x,y coordinate for idx, point in enumerate(points): cur_point = torch.tensor([point[1]/full_h*args.sup_res_h, point[0]/full_w*args.sup_res_w]) cur_point = torch.round(cur_point) if idx % 2 == 0: handle_points.append(cur_point) else: target_points.append(cur_point) print('handle points:', handle_points) print('target points:', target_points) # apply FreeU if b1 != 1.0 or b2!=1.0 or s1!=1.0 or s2!=1.0: print('applying FreeU') register_free_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s2) register_free_crossattn_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s2) else: print('do not apply FreeU') # obtain text embeddings text_embeddings = model.get_text_embeddings(prompt) model.scheduler.set_timesteps(args.n_inference_step) t = model.scheduler.timesteps[args.n_inference_step - args.n_actual_inference_step] init_code = deepcopy(intermediate_latents_gen[args.n_inference_step - args.n_actual_inference_step]) init_code_orig = deepcopy(init_code) # feature shape: [1280,16,16], [1280,32,32], [640,64,64], [320,64,64] # update according to the given supervision torch.cuda.empty_cache() init_code = init_code.to(torch.float32) text_embeddings = text_embeddings.to(torch.float32) model.unet = model.unet.to(torch.float32) updated_init_code = drag_diffusion_update_gen(model, init_code, text_embeddings, t, handle_points, target_points, mask, args) updated_init_code = updated_init_code.to(torch.float16) text_embeddings = text_embeddings.to(torch.float16) model.unet = model.unet.to(torch.float16) torch.cuda.empty_cache() # hijack the attention module # inject the reference branch to guide the generation editor = MutualSelfAttentionControl(start_step=start_step, start_layer=start_layer, total_steps=args.n_inference_step, guidance_scale=args.guidance_scale) if lora_path == "": register_attention_editor_diffusers(model, editor, attn_processor='attn_proc') else: register_attention_editor_diffusers(model, editor, attn_processor='lora_attn_proc') # inference the synthesized image gen_image = model( prompt=args.prompt, neg_prompt=args.neg_prompt, batch_size=2, # batch size is 2 because we have reference init_code and updated init_code latents=torch.cat([init_code_orig, updated_init_code], dim=0), guidance_scale=args.guidance_scale, num_inference_steps=args.n_inference_step, num_actual_inference_steps=args.n_actual_inference_step )[1].unsqueeze(dim=0) # resize gen_image into the size of source_image # we do this because shape of gen_image will be rounded to multipliers of 8 gen_image = F.interpolate(gen_image, (full_h, full_w), mode='bilinear') # save the original image, user editing instructions, synthesized image save_result = torch.cat([ source_image * 0.5 + 0.5, torch.ones((1,3,full_h,25)).cuda(), image_with_clicks * 0.5 + 0.5, torch.ones((1,3,full_h,25)).cuda(), gen_image[0:1] ], dim=-1) if not os.path.isdir(save_dir): os.mkdir(save_dir) save_prefix = datetime.datetime.now().strftime("%Y-%m-%d-%H%M-%S") save_image(save_result, os.path.join(save_dir, save_prefix + '.png')) out_image = gen_image.cpu().permute(0, 2, 3, 1).numpy()[0] out_image = (out_image * 255).astype(np.uint8) return out_image # ------------------------------------------------------