import numpy as np
from tqdm import tqdm
import gradio as gr
import os
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from accelerate import Accelerator
# from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler, UNet2DConditionModel
# from huggingface_hub import HfFolder, Repository, whoami
from PIL import Image
from torchvision import transforms
from transformers import CLIPTextModel, CLIPTokenizer
from torch import autocast
from src.diffusers_ import StableDiffusionPipeline
from diffusers.utils.import_utils import is_xformers_available
from packaging import version


def launch_source():
    image = Image.open('./tmp/train_images_source.png').convert("RGB")
    output = temp_save([image], num_rows=1)
    return output


def launch_opt800():
    image = Image.open('./tmp/train_images_step800.png').convert("RGB")
    output = temp_save([image], num_rows=1)
    return output


def launch_opt900():
    image = Image.open('./tmp/train_images_step900.png').convert("RGB")
    output = temp_save([image], num_rows=1)
    return output


def launch_opt1000():
    image = Image.open('./tmp/train_images_step1000.png').convert("RGB")
    output = temp_save([image], num_rows=1)
    return output


def launch_opt1100():
    image = Image.open('./tmp/train_images_step1100.png').convert("RGB")
    output = temp_save([image], num_rows=1)
    return output


def launch_optimize(img_in_real, prompt, n_hiper):
    os.makedirs("tmp", exist_ok=True)

    # Setting
    accelerator = Accelerator(
        gradient_accumulation_steps=1,
        mixed_precision="fp16",
    )
    seed = 2220000
    set_seed(seed)
    g_cuda = torch.Generator(device='cuda')
    g_cuda.manual_seed(seed)
    optimizer_class = torch.optim.Adam
    weight_dtype = torch.float16
    pretrained_model_name = 'CompVis/stable-diffusion-v1-4'

    # Load pretrained models
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name, subfolder="tokenizer")  # , use_auth_token=True)
    CLIP_text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name, subfolder="text_encoder")  # , use_auth_token=True)
    vae = AutoencoderKL.from_pretrained(pretrained_model_name, subfolder="vae")  # , use_auth_token=True)
    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name, subfolder="unet")  # , use_auth_token=True)
    noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

    if is_xformers_available():
        import xformers

        xformers_version = version.parse(xformers.__version__)
        if xformers_version == version.parse("0.0.16"):
            print(
                "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, "
                "please update xFormers to at least 0.0.17. "
                "See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
            )
        unet.enable_xformers_memory_efficient_attention()
    else:
        raise ValueError("xformers is not available. Make sure it is installed correctly")

    # Encode the input image.
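    # The frozen VAE encodes the 512x512 RGB input into a 4x64x64 latent;
    # 0.18215 is the latent scaling factor used by Stable Diffusion v1.x checkpoints.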
    vae.to(accelerator.device, dtype=weight_dtype)
    input_image = img_in_real.convert("RGB")
    img_in_real.save(os.path.join("tmp", "train_images_source.png"))
    image_transforms = transforms.Compose(
        [
            transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(512),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
    init_image = image_transforms(input_image)
    init_image = init_image[None].to(device=accelerator.device, dtype=weight_dtype)
    with torch.inference_mode():
        init_latents = vae.encode(init_image).latent_dist.sample()
        init_latents = 0.18215 * init_latents

    # Encode the source and target text.
    CLIP_text_encoder.to(accelerator.device, dtype=weight_dtype)
    text_ids_src = tokenizer(prompt, padding="max_length", truncation=True, max_length=tokenizer.model_max_length, return_tensors="pt").input_ids
    text_ids_src = text_ids_src.to(device=accelerator.device)
    with torch.inference_mode():
        source_embeddings = CLIP_text_encoder(text_ids_src)[0].float()

    # The VAE and text encoder are no longer needed during optimization.
    del vae, CLIP_text_encoder
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # For inference
    ddim_scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
    pipe = StableDiffusionPipeline.from_pretrained(pretrained_model_name, scheduler=ddim_scheduler, torch_dtype=torch.float16).to("cuda")
    num_samples = 1
    guidance_scale = 7.5
    num_inference_steps = 50
    height = 512
    width = 512

    # Optimize the HiPer embedding: the last n_hiper token embeddings of the
    # source prompt are treated as free parameters, the rest stay fixed.
    n_hiper = int(n_hiper)
    hiper_embeddings = source_embeddings[:, -n_hiper:].clone().detach()
    src_embeddings = source_embeddings[:, :-n_hiper].clone().detach()
    hiper_embeddings.requires_grad_(True)
    optimizer = optimizer_class(
        [hiper_embeddings],
        lr=5e-3,
        betas=(0.9, 0.999),
        eps=1e-08,
    )
    unet, optimizer = accelerator.prepare(unet, optimizer)

    emb_train_steps = 1101
    # emb_train_steps = 201

    def train_loop(optimizer, hiper_embeddings):
        inf_images = []
        for step in tqdm(range(emb_train_steps)):
            with accelerator.accumulate(unet):
                noise = torch.randn_like(init_latents)
                bsz = init_latents.shape[0]
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=init_latents.device)
                timesteps = timesteps.long()
                noisy_latents = noise_scheduler.add_noise(init_latents, noise, timesteps)

                train_embeddings = torch.cat([src_embeddings, hiper_embeddings], 1)
                noise_pred = unet(noisy_latents, timesteps, train_embeddings).sample
                loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

            # Check inference and save the embedding at selected steps.
            if step in [800, 900, 1000, 1100]:
                inf_emb = torch.cat([src_embeddings, hiper_embeddings.clone().detach()], 1)
                with autocast("cuda"), torch.inference_mode():
                    images = pipe(text_embeddings=inf_emb, height=height, width=width, num_images_per_prompt=num_samples,
                                  num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=g_cuda).images
                    inf_images.append(images[0])
                    images[0].save(os.path.join("tmp", "train_images_step{}.png".format(step)))
                    del images
                torch.save(hiper_embeddings.cpu(), os.path.join("tmp", "hiper_embeddings_step{}.pt".format(step)))

        accelerator.wait_for_everyone()

    train_loop(optimizer, hiper_embeddings)

    image = Image.open('./tmp/train_images_source.png').convert("RGB")
    output = temp_save([image], num_rows=1)
    return "tmp", output
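
# Illustrative, non-UI usage of the two entry points in this file (a sketch; the
# image path, prompts, and seed below are placeholders, not part of the demo):
#
#   img = Image.open("examples/source.png")
#   launch_optimize(img, "a standing dog", n_hiper=5)
#   grid = launch_main("a sitting dog", "Step 1000", None, 123456)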

def launch_main(dest, step, fpath_z_gen, seed):
    seed = int(seed)
    set_seed(seed)
    g_cuda = torch.Generator(device='cuda')
    g_cuda.manual_seed(seed)

    # Load pretrained models
    pretrained_model_name = 'CompVis/stable-diffusion-v1-4'
    scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
    pipe = StableDiffusionPipeline.from_pretrained(pretrained_model_name, scheduler=scheduler, torch_dtype=torch.float16).to("cuda")
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name, subfolder="tokenizer")  # , use_auth_token=True)
    CLIP_text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name, subfolder="text_encoder")  # , use_auth_token=True)

    # Encode the target text.
    text_ids_tgt = tokenizer(dest, padding="max_length", truncation=True, max_length=tokenizer.model_max_length, return_tensors="pt").input_ids
    CLIP_text_encoder.to('cuda', dtype=torch.float32)
    with torch.inference_mode():
        target_embedding = CLIP_text_encoder(text_ids_tgt.to('cuda'))[0].to('cuda')
    del CLIP_text_encoder

    # Concatenate the target embedding with the optimized HiPer embedding:
    # the HiPer term replaces the last n_hiper tokens of the target prompt
    # and is down-weighted by 0.8 at inference time.
    step = int(step.replace("Step ", ""))
    hiper_embeddings = torch.load('./tmp/hiper_embeddings_step{}.pt'.format(step)).to("cuda")
    n_hiper = hiper_embeddings.shape[1]
    inference_embeddings = torch.cat([target_embedding[:, :-n_hiper], hiper_embeddings * 0.8], 1)

    # Generate target images
    num_samples = 1
    guidance_scale = 7.5
    num_inference_steps = 50
    height = 512
    width = 512
    with autocast("cuda"), torch.inference_mode():
        image_set = []
        for idx, embd in enumerate([inference_embeddings]):
            for i in range(10):
                images = pipe(
                    text_embeddings=embd,
                    height=height,
                    width=width,
                    num_images_per_prompt=num_samples,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    generator=g_cuda,
                ).images
                image_set.append(images[0])
    out_image = temp_save(image_set, num_rows=5)
    return out_image


def set_visible_true():
    return gr.update(visible=True)


def set_visible_false():
    return gr.update(visible=False)
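
# Sketch of how the visibility helpers above are typically wired into a Gradio
# Blocks UI (illustrative only; the widget names here are placeholders and the
# real layout of the demo may differ):
#
#   with gr.Blocks(css=CSS_main) as demo:
#       options_col = gr.Column(visible=False)
#       gr.Button("Show options").click(set_visible_true, inputs=None, outputs=options_col)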

CSS_main = """
body {
  font-family: "HelveticaNeue-Light", "Helvetica Neue Light", "Helvetica Neue", Helvetica, Arial, "Lucida Grande", sans-serif;
  font-weight: 300;
  font-size: 18px;
  margin-left: auto;
  margin-right: auto;
  padding-left: 10px;
  padding-right: 10px;
  width: 800px;
}
h1 { font-size: 32px; font-weight: 300; text-align: center; }
h2 { font-size: 32px; font-weight: 300; text-align: center; }
#lbl_gallery_input { font-family: 'Helvetica', 'Arial', sans-serif; text-align: center; color: #fff; font-size: 28px; display: inline; }
#lbl_gallery_comparision { font-family: 'Helvetica', 'Arial', sans-serif; text-align: center; color: #fff; font-size: 28px; }
.disclaimerbox { background-color: #eee; border: 1px solid #eeeeee; border-radius: 10px; -moz-border-radius: 10px; -webkit-border-radius: 10px; padding: 20px; }
video.header-vid { height: 140px; border: 1px solid black; border-radius: 10px; -moz-border-radius: 10px; -webkit-border-radius: 10px; }
img.header-img { height: 140px; border: 1px solid black; border-radius: 10px; -moz-border-radius: 10px; -webkit-border-radius: 10px; }
img.rounded { border: 1px solid #eeeeee; border-radius: 10px; -moz-border-radius: 10px; -webkit-border-radius: 10px; }
a:link { color: #941120; text-decoration: none; }
a:visited { color: #941120; text-decoration: none; }
a:hover { color: #941120; }
td.dl-link { height: 160px; text-align: center; font-size: 22px; }
.layered-paper-big {
  /* modified from: http://css-tricks.com/snippets/css/layered-paper/ */
  box-shadow:
    0px 0px 1px 1px rgba(0,0,0,0.35),   /* The top layer shadow */
    5px 5px 0 0px #fff,                 /* The second layer */
    5px 5px 1px 1px rgba(0,0,0,0.35),   /* The second layer shadow */
    10px 10px 0 0px #fff,               /* The third layer */
    10px 10px 1px 1px rgba(0,0,0,0.35), /* The third layer shadow */
    15px 15px 0 0px #fff,               /* The fourth layer */
    15px 15px 1px 1px rgba(0,0,0,0.35), /* The fourth layer shadow */
    20px 20px 0 0px #fff,               /* The fifth layer */
    20px 20px 1px 1px rgba(0,0,0,0.35), /* The fifth layer shadow */
    25px 25px 0 0px #fff,               /* The sixth layer */
    25px 25px 1px 1px rgba(0,0,0,0.35); /* The sixth layer shadow */
  margin-left: 10px;
  margin-right: 45px;
}
.paper-big {
  /* modified from: http://css-tricks.com/snippets/css/layered-paper/ */
  box-shadow: 0px 0px 1px 1px rgba(0,0,0,0.35); /* The top layer shadow */
  margin-left: 10px;
  margin-right: 45px;
}
.layered-paper {
  /* modified from: http://css-tricks.com/snippets/css/layered-paper/ */
  box-shadow:
    0px 0px 1px 1px rgba(0,0,0,0.35),   /* The top layer shadow */
    5px 5px 0 0px #fff,                 /* The second layer */
    5px 5px 1px 1px rgba(0,0,0,0.35),   /* The second layer shadow */
    10px 10px 0 0px #fff,               /* The third layer */
    10px 10px 1px 1px rgba(0,0,0,0.35); /* The third layer shadow */
  margin-top: 5px;
  margin-left: 10px;
  margin-right: 30px;
  margin-bottom: 5px;
}
.vert-cent { position: relative; top: 50%; transform: translateY(-50%); }
hr { border: 0; height: 1px; background-image: linear-gradient(to right, rgba(0, 0, 0, 0), rgba(0, 0, 0, 0.75), rgba(0, 0, 0, 0)); }
.card { /* width: 130px; height: 195px; width: 1px; height: 1px; */ position: relative; display: inline-block; /* margin: 50px; */ }
.card .img-top { display: none; position: absolute; top: 0; left: 0; z-index: 99; }
.card:hover .img-top { display: inline; }
details { user-select: none; }
details > summary span.icon { width: 24px; height: 24px; transition: all 0.3s; margin-left: auto; }
details[open] summary span.icon { transform: rotate(180deg); }
summary { display: flex; cursor: pointer; }
summary::-webkit-details-marker { display: none; }
ul { display: table; margin: 0 auto; text-align: left; }
.dark { padding: 1em 2em; background-color: #333; box-shadow: 3px 3px 3px #333; border: 1px #333; }
.column { float: left; width: 20%; padding: 0.5%; }
.galleryImg { transition: opacity 0.3s; -webkit-transition: opacity 0.3s; filter: grayscale(100%); /* filter: blur(2px); */ -webkit-transition: -webkit-filter 250ms linear; /* opacity: 0.5; */ cursor: pointer; }
.selected { /* outline: 100px solid var(--hover-background) !important; */ /* outline-offset: -100px; */ filter: grayscale(0%); -webkit-transition: -webkit-filter 250ms linear; /* opacity: 1.0 !important; */ }
.galleryImg:hover { filter: grayscale(0%); -webkit-transition: -webkit-filter 250ms linear; }
.row { margin-bottom: 1em; padding: 0px 1em; }
/* Clear floats after the columns */
.row:after { content: ""; display: table; clear: both; }
/* The expanding image container */
#gallery { position: relative; /* display: none; */ }
#section_comparison { position: relative; width: 100%; height: max-content; }

/* SLIDER -------------------------------------------------- */
.slider-container { position: relative; height: 384px; width: 512px; cursor: grab; overflow: hidden; margin: auto; }
.slider-after { display: block; position: absolute; top: 0; right: 0; bottom: 0; left: 0; width: 100%; height: 100%; overflow: hidden; }
.slider-before { display: block; position: absolute; top: 0; /* right: 0; */ bottom: 0; left: 0; width: 100%; height: 100%; z-index: 15; overflow: hidden; }
.slider-before-inset { position: absolute; top: 0; bottom: 0; left: 0; }
.slider-after img,
.slider-before img { object-fit: cover; position: absolute; width: 100%; height: 100%; object-position: 50% 50%; top: 0; bottom: 0; left: 0; -webkit-user-select: none; -khtml-user-select: none; -moz-user-select: none; -o-user-select: none; user-select: none; }
#lbl_inset_left { text-align: center; position: absolute; top: 384px; width: 150px; left: calc(50% - 256px); z-index: 11; font-size: 16px; color: #fff; margin: 10px; }
.inset-before { position: absolute; width: 150px; height: 150px; box-shadow: 3px 3px 3px #333; border: 1px #333; border-style: solid; z-index: 16; top: 410px; left: calc(50% - 256px); margin: 10px; font-size: 1em; background-repeat: no-repeat; pointer-events: none; }
#lbl_inset_right { text-align: center; position: absolute; top: 384px; width: 150px; right: calc(50% - 256px); z-index: 11; font-size: 16px; color: #fff; margin: 10px; }
.inset-after { position: absolute; width: 150px; height: 150px; box-shadow: 3px 3px 3px #333; border: 1px #333; border-style: solid; z-index: 16; top: 410px; right: calc(50% - 256px); margin: 10px; font-size: 1em; background-repeat: no-repeat; pointer-events: none; }
#lbl_inset_input { text-align: center; position: absolute; top: 384px; width: 150px; left: calc(50% - 256px + 150px + 20px); z-index: 11; font-size: 16px; color: #fff; margin: 10px; }
.inset-target { position: absolute; width: 150px; height: 150px; box-shadow: 3px 3px 3px #333; border: 1px #333; border-style: solid; z-index: 16; top: 410px; right: calc(50% - 256px + 150px + 20px); margin: 10px; font-size: 1em; background-repeat: no-repeat; pointer-events: none; }
.slider-beforePosition { background: #121212; color: #fff; left: 0; pointer-events: none; border-radius: 0.2rem; padding: 2px 10px; }
.slider-afterPosition { background: #121212; color: #fff; right: 0; pointer-events: none; border-radius: 0.2rem; padding: 2px 10px; }
.beforeLabel { position: absolute; top: 0; margin: 1rem; font-size: 1em; -webkit-user-select: none; -khtml-user-select: none; -moz-user-select: none; -o-user-select: none; user-select: none; }
.afterLabel { position: absolute; top: 0; margin: 1rem; font-size: 1em; -webkit-user-select: none; -khtml-user-select: none; -moz-user-select: none; -o-user-select: none; user-select: none; }
.slider-handle { height: 101px; width: 41px; position: absolute; left: 50%; top: 50%; margin-left: -20px; margin-top: -21px; border: 2px solid #fff; border-radius: 1000px; z-index: 20; pointer-events: none; box-shadow: 0 0 10px rgb(12, 12, 12); }
.handle-left-arrow,
.handle-right-arrow { width: 0; height: 0; border: 6px inset transparent; position: absolute; top: 50%; margin-top: -6px; }
.handle-left-arrow { border-right: 6px solid #fff; left: 50%; margin-left: -17px; }
.handle-right-arrow { border-left: 6px solid #fff; right: 50%; margin-right: -17px; }
.slider-handle::before { bottom: 50%; margin-bottom: 20px; box-shadow: 0 0 10px rgb(12, 12, 12); }
.slider-handle::after { top: 50%; margin-top: 20.5px; box-shadow: 0 0 5px rgb(12, 12, 12); }
.slider-handle::before,
.slider-handle::after { content: " "; display: block; width: 2px; background: #fff; height: 9999px; position: absolute; left: 50%; margin-left: -1.5px; }

/* -------------------------------------------------
   The editing results shown below inversion results
   ------------------------------------------------- */
.edit_labels { font-weight: 500; font-size: 24px; color: #fff; height: 20px; margin-left: 20px; position: relative; top: 20px; }
.open > a:hover { color: #555; background-color: red; }
#directions { padding-top: 30px; padding-bottom: 0; margin-bottom: 0px; height: 20px; }
#custom_task { padding-top: 0; padding-bottom: 0; margin-bottom: 0px; height: 20px; }
#slider_ddim { accent-color: #941120; }
#slider_ddim::-webkit-slider-thumb { background-color: #941120; }
#slider_xa { accent-color: #941120; }
#slider_xa::-webkit-slider-thumb { background-color: #941120; }
#slider_edit_mul { accent-color: #941120; }
#slider_edit_mul::-webkit-slider-thumb { background-color: #941120; }
#input_image [data-testid="image"] { height: unset; }
#input_image_synth [data-testid="image"] { height: unset; }
"""

HTML_header = f"""
Highly Personalized Text Embedding for Image Manipulation by Stable Diffusion
[Project page] [Github]

We present a simple yet highly effective approach to personalization using highly personalized (HiPer) text embedding by decomposing the CLIP embedding space for personalization and content manipulation. Our method does not require model fine-tuning or identifiers, yet still enables manipulation of background, texture, and motion with just a single image and target text.


""" HTML_input_header = f"""

Step 1: select a real input image.

""" HTML_middle_header = f"""

Step 2: select the editing options.

""" HTML_output_header = f"""

Step 3: translated image!

""" import numpy as np import torch from PIL import Image, ImageDraw, ImageFont import cv2 from typing import Optional, Union, Tuple, List, Callable, Dict # from tqdm.notebook import tqdm #codes for 'show_image' and 'text_under_image' are from # https://github.com/google/prompt-to-prompt/blob/main/prompt-to-prompt_stable.ipynb def show_images(images, num_rows=2, offset_ratio=0.02): if type(images) is list: num_empty = len(images) % num_rows elif images.ndim == 4: num_empty = images.shape[0] % num_rows else: images = [images] num_empty = 0 empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255 images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty num_items = len(images) h, w, c = images[0].shape offset = int(h * offset_ratio) num_cols = num_items // num_rows image_ = np.ones((h * num_rows + offset * (num_rows - 1), w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255 for i in range(num_rows): for j in range(num_cols): image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[ i * num_cols + j] pil_img = Image.fromarray(image_) # pil_img.save(name) return pil_img def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)): h, w, c = image.shape offset = int(h * .2) img = np.ones((h + offset, w, c), dtype=np.uint8) * 255 font = cv2.FONT_HERSHEY_SIMPLEX img[:h] = image textsize = cv2.getTextSize(text, font, 1, 2)[0] text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2 cv2.putText(img, text, (text_x, text_y ), font, 1, text_color, 2) return img def inf_save(inf_img, name): images = [] for i in range(len(inf_img)): image = np.array(inf_img[i].resize((256,256))) image = text_under_image(image, name[i]) images.append(image) inf_image = show_images(np.stack(images, axis=0),num_rows=1) return inf_image def temp_save(inf_img,num_rows): images = [] for i in range(len(inf_img)): image = np.array(inf_img[i].resize((256,256))) images.append(image) inf_image = show_images(np.stack(images, axis=0),num_rows=num_rows) return inf_image