import numpy as np
from tqdm import tqdm
import gradio as gr
import os
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from accelerate import Accelerator
# from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler, UNet2DConditionModel
# from huggingface_hub import HfFolder, Repository, whoami
from PIL import Image
from torchvision import transforms
from transformers import CLIPTextModel, CLIPTokenizer
from torch import autocast
from src.diffusers_ import StableDiffusionPipeline
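
# Utilities backing the HiPer Gradio demo: launch_optimize fits a highly
# personalized (HiPer) text embedding to a single input image, launch_main
# generates edited images from a target prompt plus a saved HiPer embedding,
# the launch_opt* callbacks reload intermediate previews from ./tmp, and the
# helpers at the bottom (show_images / inf_save / temp_save) build image grids.
# CSS_main and the HTML_* strings style the demo page.
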
def launch_source():
    image = Image.open('./tmp/train_images_source.png').convert("RGB")
    output = temp_save([image], num_rows=1)
    return output


def launch_opt800():
    image = Image.open('./tmp/train_images_step800.png').convert("RGB")
    output = temp_save([image], num_rows=1)
    return output


def launch_opt900():
    image = Image.open('./tmp/train_images_step900.png').convert("RGB")
    output = temp_save([image], num_rows=1)
    return output


def launch_opt1000():
    image = Image.open('./tmp/train_images_step1000.png').convert("RGB")
    output = temp_save([image], num_rows=1)
    return output


def launch_opt1100():
    image = Image.open('./tmp/train_images_step1100.png').convert("RGB")
    output = temp_save([image], num_rows=1)
    return output
def launch_optimize(img_in_real, prompt, n_hiper):
    os.makedirs("tmp", exist_ok=True)

    # Settings
    accelerator = Accelerator(
        gradient_accumulation_steps=1,
        mixed_precision="fp16",
    )
    seed = 2220000
    set_seed(seed)
    g_cuda = torch.Generator(device='cuda')
    g_cuda.manual_seed(seed)
    optimizer_class = torch.optim.Adam
    weight_dtype = torch.float16
    pretrained_model_name = 'CompVis/stable-diffusion-v1-4'

    # Load pretrained models
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name, subfolder="tokenizer", use_auth_token=True)
    CLIP_text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name, subfolder="text_encoder", use_auth_token=True)
    vae = AutoencoderKL.from_pretrained(pretrained_model_name, subfolder="vae", use_auth_token=True)
    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name, subfolder="unet", use_auth_token=True)
    noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

    # Encode the input image.
    vae.to(accelerator.device, dtype=weight_dtype)
    input_image = img_in_real.convert("RGB")
    img_in_real.save(os.path.join("tmp", "train_images_source.png"))
    image_transforms = transforms.Compose(
        [
            transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(512),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
    init_image = image_transforms(input_image)
    init_image = init_image[None].to(device=accelerator.device, dtype=weight_dtype)
    with torch.inference_mode():
        init_latents = vae.encode(init_image).latent_dist.sample()
        init_latents = 0.18215 * init_latents

    # Encode the source text.
    CLIP_text_encoder.to(accelerator.device, dtype=weight_dtype)
    text_ids_src = tokenizer(prompt, padding="max_length", truncation=True, max_length=tokenizer.model_max_length, return_tensors="pt").input_ids
    text_ids_src = text_ids_src.to(device=accelerator.device)
    with torch.inference_mode():
        source_embeddings = CLIP_text_encoder(text_ids_src)[0].float()

    # The VAE and text encoder are only needed for the one-time encodings above.
    del vae, CLIP_text_encoder
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # For inference
    ddim_scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
    pipe = StableDiffusionPipeline.from_pretrained(pretrained_model_name, scheduler=ddim_scheduler, torch_dtype=torch.float16).to("cuda")
    num_samples = 1
    guidance_scale = 7.5
    num_inference_steps = 50
    height = 512
    width = 512

    # Optimize the HiPer embedding: the last n_hiper token embeddings of the source
    # prompt are treated as free parameters, the remaining embeddings stay fixed.
    n_hiper = int(n_hiper)
    hiper_embeddings = source_embeddings[:, -n_hiper:].clone().detach()
    src_embeddings = source_embeddings[:, :-n_hiper].clone().detach()
    hiper_embeddings.requires_grad_(True)

    optimizer = optimizer_class(
        [hiper_embeddings],
        lr=5e-3,
        betas=(0.9, 0.999),
        eps=1e-08,
    )
    unet, optimizer = accelerator.prepare(unet, optimizer)

    emb_train_steps = 1101
    # emb_train_steps = 201

    def train_loop(optimizer, hiper_embeddings):
        inf_images = []
        for step in tqdm(range(emb_train_steps)):
            with accelerator.accumulate(unet):
                noise = torch.randn_like(init_latents)
                bsz = init_latents.shape[0]
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=init_latents.device)
                timesteps = timesteps.long()
                noisy_latents = noise_scheduler.add_noise(init_latents, noise, timesteps)
                source_embeddings = torch.cat([src_embeddings, hiper_embeddings], 1)
                # Standard denoising objective; only hiper_embeddings receives gradients.
                noise_pred = unet(noisy_latents, timesteps, source_embeddings).sample
                loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

            # Check inference and save the intermediate embedding.
            if step in [800, 900, 1000, 1100]:
                inf_emb = torch.cat([src_embeddings, hiper_embeddings.clone().detach()], 1)
                with autocast("cuda"), torch.inference_mode():
                    images = pipe(text_embeddings=inf_emb, height=height, width=width, num_images_per_prompt=num_samples,
                                  num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=g_cuda).images
                inf_images.append(images[0])
                images[0].save(os.path.join("tmp", "train_images_step{}.png".format(step)))
                del images
                torch.save(hiper_embeddings.cpu(), os.path.join("tmp", "hiper_embeddings_step{}.pt".format(step)))

        accelerator.wait_for_everyone()

    train_loop(optimizer, hiper_embeddings)

    image = Image.open('./tmp/train_images_source.png').convert("RGB")
    output = temp_save([image], num_rows=1)
    return "tmp", output
def launch_main(dest, step, fpath_z_gen, seed):
    seed = int(seed)
    set_seed(seed)
    g_cuda = torch.Generator(device='cuda')
    g_cuda.manual_seed(seed)

    # Load pretrained models
    pretrained_model_name = 'CompVis/stable-diffusion-v1-4'
    scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
    pipe = StableDiffusionPipeline.from_pretrained(pretrained_model_name, scheduler=scheduler, torch_dtype=torch.float16).to("cuda")
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name, subfolder="tokenizer", use_auth_token=True)
    CLIP_text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name, subfolder="text_encoder", use_auth_token=True)

    # Encode the target text.
    text_ids_tgt = tokenizer(dest, padding="max_length", truncation=True, max_length=tokenizer.model_max_length, return_tensors="pt").input_ids
    CLIP_text_encoder.to('cuda', dtype=torch.float32)
    with torch.inference_mode():
        target_embedding = CLIP_text_encoder(text_ids_tgt.to('cuda'))[0].to('cuda')
    del CLIP_text_encoder

    # Concatenate the target embedding with the optimized HiPer embedding.
    step = int(step.replace("Step ", ""))
    hiper_embeddings = torch.load('./tmp/hiper_embeddings_step{}.pt'.format(step)).to("cuda")
    n_hiper = hiper_embeddings.shape[1]
    inference_embeddings = torch.cat([target_embedding[:, :-n_hiper], hiper_embeddings * 0.8], 1)

    # Generate target images
    num_samples = 1
    guidance_scale = 7.5
    num_inference_steps = 50
    height = 512
    width = 512
    with autocast("cuda"), torch.inference_mode():
        image_set = []
        for embd in [inference_embeddings]:
            for i in range(10):
                images = pipe(
                    text_embeddings=embd,
                    height=height,
                    width=width,
                    num_images_per_prompt=num_samples,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    generator=g_cuda,
                ).images
                image_set.append(images[0])
    out_image = temp_save(image_set, num_rows=5)
    return out_image
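
# Hypothetical usage (assumes launch_optimize has already written
# ./tmp/hiper_embeddings_step1000.pt; the prompt and seed below are placeholders):
#   grid = launch_main("a dog sitting on the beach", "Step 1000", "tmp", 123456)
#   grid.save("manipulated.png")
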
def set_visible_true():
    return gr.update(visible=True)


def set_visible_false():
    return gr.update(visible=False)
CSS_main = """
body {
font-family: "HelveticaNeue-Light", "Helvetica Neue Light", "Helvetica Neue", Helvetica, Arial, "Lucida Grande", sans-serif;
font-weight:300;
font-size:18px;
margin-left: auto;
margin-right: auto;
padding-left: 10px;
padding-right: 10px;
width: 800px;
}
h1 {
font-size:32px;
font-weight:300;
text-align: center;
}
h2 {
font-size:32px;
font-weight:300;
text-align: center;
}
#lbl_gallery_input{
font-family: 'Helvetica', 'Arial', sans-serif;
text-align: center;
color: #fff;
font-size: 28px;
display: inline
}
#lbl_gallery_comparision{
font-family: 'Helvetica', 'Arial', sans-serif;
text-align: center;
color: #fff;
font-size: 28px;
}
.disclaimerbox {
background-color: #eee;
border: 1px solid #eeeeee;
border-radius: 10px ;
-moz-border-radius: 10px ;
-webkit-border-radius: 10px ;
padding: 20px;
}
video.header-vid {
height: 140px;
border: 1px solid black;
border-radius: 10px ;
-moz-border-radius: 10px ;
-webkit-border-radius: 10px ;
}
img.header-img {
height: 140px;
border: 1px solid black;
border-radius: 10px ;
-moz-border-radius: 10px ;
-webkit-border-radius: 10px ;
}
img.rounded {
border: 1px solid #eeeeee;
border-radius: 10px ;
-moz-border-radius: 10px ;
-webkit-border-radius: 10px ;
}
a:link
{
color: #941120;
text-decoration: none;
}
a:visited
{
color: #941120;
text-decoration: none;
}
a:hover {
color: #941120;
}
td.dl-link {
height: 160px;
text-align: center;
font-size: 22px;
}
.layered-paper-big { /* modified from: http://css-tricks.com/snippets/css/layered-paper/ */
box-shadow:
0px 0px 1px 1px rgba(0,0,0,0.35), /* The top layer shadow */
5px 5px 0 0px #fff, /* The second layer */
5px 5px 1px 1px rgba(0,0,0,0.35), /* The second layer shadow */
10px 10px 0 0px #fff, /* The third layer */
10px 10px 1px 1px rgba(0,0,0,0.35), /* The third layer shadow */
15px 15px 0 0px #fff, /* The fourth layer */
15px 15px 1px 1px rgba(0,0,0,0.35), /* The fourth layer shadow */
20px 20px 0 0px #fff, /* The fifth layer */
20px 20px 1px 1px rgba(0,0,0,0.35), /* The fifth layer shadow */
25px 25px 0 0px #fff, /* The sixth layer */
25px 25px 1px 1px rgba(0,0,0,0.35); /* The sixth layer shadow */
margin-left: 10px;
margin-right: 45px;
}
.paper-big { /* modified from: http://css-tricks.com/snippets/css/layered-paper/ */
box-shadow:
0px 0px 1px 1px rgba(0,0,0,0.35); /* The top layer shadow */
margin-left: 10px;
margin-right: 45px;
}
.layered-paper { /* modified from: http://css-tricks.com/snippets/css/layered-paper/ */
box-shadow:
0px 0px 1px 1px rgba(0,0,0,0.35), /* The top layer shadow */
5px 5px 0 0px #fff, /* The second layer */
5px 5px 1px 1px rgba(0,0,0,0.35), /* The second layer shadow */
10px 10px 0 0px #fff, /* The third layer */
10px 10px 1px 1px rgba(0,0,0,0.35); /* The third layer shadow */
margin-top: 5px;
margin-left: 10px;
margin-right: 30px;
margin-bottom: 5px;
}
.vert-cent {
position: relative;
top: 50%;
transform: translateY(-50%);
}
hr
{
border: 0;
height: 1px;
background-image: linear-gradient(to right, rgba(0, 0, 0, 0), rgba(0, 0, 0, 0.75), rgba(0, 0, 0, 0));
}
.card {
/* width: 130px;
height: 195px;
width: 1px;
height: 1px; */
position: relative;
display: inline-block;
/* margin: 50px; */
}
.card .img-top {
display: none;
position: absolute;
top: 0;
left: 0;
z-index: 99;
}
.card:hover .img-top {
display: inline;
}
details {
user-select: none;
}
details>summary span.icon {
width: 24px;
height: 24px;
transition: all 0.3s;
margin-left: auto;
}
details[open] summary span.icon {
transform: rotate(180deg);
}
summary {
display: flex;
cursor: pointer;
}
summary::-webkit-details-marker {
display: none;
}
ul {
display: table;
margin: 0 auto;
text-align: left;
}
.dark {
padding: 1em 2em;
background-color: #333;
box-shadow: 3px 3px 3px #333;
border: 1px #333;
}
.column {
float: left;
width: 20%;
padding: 0.5%;
}
.galleryImg {
transition: opacity 0.3s;
-webkit-transition: opacity 0.3s;
filter: grayscale(100%);
/* filter: blur(2px); */
-webkit-transition : -webkit-filter 250ms linear;
/* opacity: 0.5; */
cursor: pointer;
}
.selected {
/* outline: 100px solid var(--hover-background) !important; */
/* outline-offset: -100px; */
filter: grayscale(0%);
-webkit-transition : -webkit-filter 250ms linear;
/*opacity: 1.0 !important; */
}
.galleryImg:hover {
filter: grayscale(0%);
-webkit-transition : -webkit-filter 250ms linear;
}
.row {
margin-bottom: 1em;
padding: 0px 1em;
}
/* Clear floats after the columns */
.row:after {
content: "";
display: table;
clear: both;
}
/* The expanding image container */
#gallery {
position: relative;
/*display: none;*/
}
#section_comparison{
position: relative;
width: 100%;
height: max-content;
}
/* SLIDER
-------------------------------------------------- */
.slider-container {
position: relative;
height: 384px;
width: 512px;
cursor: grab;
overflow: hidden;
margin: auto;
}
.slider-after {
display: block;
position: absolute;
top: 0;
right: 0;
bottom: 0;
left: 0;
width: 100%;
height: 100%;
overflow: hidden;
}
.slider-before {
display: block;
position: absolute;
top: 0;
/* right: 0; */
bottom: 0;
left: 0;
width: 100%;
height: 100%;
z-index: 15;
overflow: hidden;
}
.slider-before-inset {
position: absolute;
top: 0;
bottom: 0;
left: 0;
}
.slider-after img,
.slider-before img {
object-fit: cover;
position: absolute;
width: 100%;
height: 100%;
object-position: 50% 50%;
top: 0;
bottom: 0;
left: 0;
-webkit-user-select: none;
-khtml-user-select: none;
-moz-user-select: none;
-o-user-select: none;
user-select: none;
}
#lbl_inset_left{
text-align: center;
position: absolute;
top: 384px;
width: 150px;
left: calc(50% - 256px);
z-index: 11;
font-size: 16px;
color: #fff;
margin: 10px;
}
.inset-before {
position: absolute;
width: 150px;
height: 150px;
box-shadow: 3px 3px 3px #333;
border: 1px #333;
border-style: solid;
z-index: 16;
top: 410px;
left: calc(50% - 256px);
margin: 10px;
font-size: 1em;
background-repeat: no-repeat;
pointer-events: none;
}
#lbl_inset_right{
text-align: center;
position: absolute;
top: 384px;
width: 150px;
right: calc(50% - 256px);
z-index: 11;
font-size: 16px;
color: #fff;
margin: 10px;
}
.inset-after {
position: absolute;
width: 150px;
height: 150px;
box-shadow: 3px 3px 3px #333;
border: 1px #333;
border-style: solid;
z-index: 16;
top: 410px;
right: calc(50% - 256px);
margin: 10px;
font-size: 1em;
background-repeat: no-repeat;
pointer-events: none;
}
#lbl_inset_input{
text-align: center;
position: absolute;
top: 384px;
width: 150px;
left: calc(50% - 256px + 150px + 20px);
z-index: 11;
font-size: 16px;
color: #fff;
margin: 10px;
}
.inset-target {
position: absolute;
width: 150px;
height: 150px;
box-shadow: 3px 3px 3px #333;
border: 1px #333;
border-style: solid;
z-index: 16;
top: 410px;
right: calc(50% - 256px + 150px + 20px);
margin: 10px;
font-size: 1em;
background-repeat: no-repeat;
pointer-events: none;
}
.slider-beforePosition {
background: #121212;
color: #fff;
left: 0;
pointer-events: none;
border-radius: 0.2rem;
padding: 2px 10px;
}
.slider-afterPosition {
background: #121212;
color: #fff;
right: 0;
pointer-events: none;
border-radius: 0.2rem;
padding: 2px 10px;
}
.beforeLabel {
position: absolute;
top: 0;
margin: 1rem;
font-size: 1em;
-webkit-user-select: none;
-khtml-user-select: none;
-moz-user-select: none;
-o-user-select: none;
user-select: none;
}
.afterLabel {
position: absolute;
top: 0;
margin: 1rem;
font-size: 1em;
-webkit-user-select: none;
-khtml-user-select: none;
-moz-user-select: none;
-o-user-select: none;
user-select: none;
}
.slider-handle {
height: 101px;
width: 41px;
position: absolute;
left: 50%;
top: 50%;
margin-left: -20px;
margin-top: -21px;
border: 2px solid #fff;
border-radius: 1000px;
z-index: 20;
pointer-events: none;
box-shadow: 0 0 10px rgb(12, 12, 12);
}
.handle-left-arrow,
.handle-right-arrow {
width: 0;
height: 0;
border: 6px inset transparent;
position: absolute;
top: 50%;
margin-top: -6px;
}
.handle-left-arrow {
border-right: 6px solid #fff;
left: 50%;
margin-left: -17px;
}
.handle-right-arrow {
border-left: 6px solid #fff;
right: 50%;
margin-right: -17px;
}
.slider-handle::before {
bottom: 50%;
margin-bottom: 20px;
box-shadow: 0 0 10px rgb(12, 12, 12);
}
.slider-handle::after {
top: 50%;
margin-top: 20.5px;
box-shadow: 0 0 5px rgb(12, 12, 12);
}
.slider-handle::before,
.slider-handle::after {
content: " ";
display: block;
width: 2px;
background: #fff;
height: 9999px;
position: absolute;
left: 50%;
margin-left: -1.5px;
}
/*
-------------------------------------------------
The editing results shown below inversion results
-------------------------------------------------
*/
.edit_labels{
font-weight:500;
font-size: 24px;
color: #fff;
height: 20px;
margin-left: 20px;
position: relative;
top: 20px;
}
.open > a:hover {
color: #555;
background-color: red;
}
#directions { padding-top:30px; padding-bottom:0; margin-bottom: 0px; height: 20px; }
#custom_task { padding-top:0; padding-bottom:0; margin-bottom: 0px; height: 20px; }
#slider_ddim {accent-color: #941120;}
#slider_ddim::-webkit-slider-thumb {background-color: #941120;}
#slider_xa {accent-color: #941120;}
#slider_xa::-webkit-slider-thumb {background-color: #941120;}
#slider_edit_mul {accent-color: #941120;}
#slider_edit_mul::-webkit-slider-thumb {background-color: #941120;}
#input_image [data-testid="image"]{
height: unset;
}
#input_image_synth [data-testid="image"]{
height: unset;
}
"""
HTML_header = f"""
<body>
<center>
<span style="font-size:32px">Highly Personalized Text Embedding for Image Manipulation by Stable Diffusion</span>
<table align=center>
<tr>
<td align=center>
<center>
<span style="font-size:24px; margin-left: 0px;"><a href='https://hiper0.github.io/'>[Project page]</a></span>
<span style="font-size:24px; margin-left: 20px;"><a href='https://github.com/HiPer0/HiPer'>[Github]</a></span>
</center>
</td>
</tr>
</table>
</center>
<center>
<div align=center>
<p align=left>
We present a simple yet highly effective approach to personalization using highly personalized (HiPer) text embedding by decomposing the CLIP embedding space for personalization and content manipulation. Our method does not require model fine-tuning or identifiers, yet still enables manipulation of background, texture, and motion with just a single image and target text.
<br>
</p>
</div>
</center>
<hr>
</body>
"""
HTML_input_header = f"""
<p style="font-size:150%; padding: 0px">
<span style="font-weight: 800; color: #941120;"> Step 1: </span> select a real input image.
</p>
"""
HTML_middle_header = f"""
<p style="font-size:150%;">
<span style="font-weight: 800; color: #941120;"> Step 2: </span> select the editing options.
</p>
"""
HTML_output_header = f"""
<p style="font-size:150%;">
<span style="font-weight: 800; color: #941120;"> Step 3: </span> the translated image!
</p>
"""
# Grid/visualization helpers. 'show_images' and 'text_under_image' are adapted from
# https://github.com/google/prompt-to-prompt/blob/main/prompt-to-prompt_stable.ipynb
import cv2
from typing import Tuple
def show_images(images, num_rows=2, offset_ratio=0.02):
    if type(images) is list:
        num_empty = len(images) % num_rows
    elif images.ndim == 4:
        num_empty = images.shape[0] % num_rows
    else:
        images = [images]
        num_empty = 0
    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
    num_items = len(images)

    # Lay the images out on a white canvas with a small offset between cells.
    h, w, c = images[0].shape
    offset = int(h * offset_ratio)
    num_cols = num_items // num_rows
    image_ = np.ones((h * num_rows + offset * (num_rows - 1),
                      w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
    for i in range(num_rows):
        for j in range(num_cols):
            image_[i * (h + offset): i * (h + offset) + h,
                   j * (w + offset): j * (w + offset) + w] = images[i * num_cols + j]

    pil_img = Image.fromarray(image_)
    return pil_img
def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
    h, w, c = image.shape
    offset = int(h * .2)
    img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
    font = cv2.FONT_HERSHEY_SIMPLEX
    img[:h] = image
    textsize = cv2.getTextSize(text, font, 1, 2)[0]
    text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
    cv2.putText(img, text, (text_x, text_y), font, 1, text_color, 2)
    return img


def inf_save(inf_img, name):
    images = []
    for i in range(len(inf_img)):
        image = np.array(inf_img[i].resize((256, 256)))
        image = text_under_image(image, name[i])
        images.append(image)
    inf_image = show_images(np.stack(images, axis=0), num_rows=1)
    return inf_image


def temp_save(inf_img, num_rows):
    images = []
    for i in range(len(inf_img)):
        image = np.array(inf_img[i].resize((256, 256)))
        images.append(image)
    inf_image = show_images(np.stack(images, axis=0), num_rows=num_rows)
    return inf_image
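

# A minimal, self-contained sketch (an assumption for illustration, not the
# Space's actual wiring) of how the callbacks above could be hooked into a
# Gradio Blocks UI. The real app builds its interface elsewhere, using
# CSS_main, the HTML_* headers and the set_visible_* helpers; every component
# name below is hypothetical.
if __name__ == "__main__":
    with gr.Blocks(css=CSS_main) as demo:
        gr.HTML(HTML_header)
        gr.HTML(HTML_input_header)
        img_in_real = gr.Image(type="pil", label="Input image", elem_id="input_image")
        src_prompt = gr.Textbox(label="Source text describing the input image")
        n_hiper = gr.Number(value=5, label="Number of HiPer tokens")
        gr.HTML(HTML_middle_header)
        tgt_prompt = gr.Textbox(label="Target text")
        step_choice = gr.Radio(["Step 800", "Step 900", "Step 1000", "Step 1100"],
                               value="Step 1000", label="Optimization checkpoint")
        seed = gr.Number(value=123456, label="Seed")
        gr.HTML(HTML_output_header)
        fpath_z_gen = gr.Textbox(visible=False)
        preview = gr.Image(label="Optimization preview")
        result = gr.Image(label="Generated images")
        gr.Button("Optimize HiPer embedding").click(
            launch_optimize, [img_in_real, src_prompt, n_hiper], [fpath_z_gen, preview]
        )
        gr.Button("Generate from target text").click(
            launch_main, [tgt_prompt, step_choice, fpath_z_gen, seed], [result]
        )
    demo.launch()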