import gradio as gr
import spaces
import numpy as np
import torch
import requests
import random
import os
import sys
import pickle
from PIL import Image
from tqdm.auto import tqdm
from datetime import datetime
import torch.nn as nn
import torch.nn.functional as F
class AttnProcessor(nn.Module):
r"""
Default processor for performing attention-related computations.
"""
def __init__(
self,
hidden_size=None,
cross_attention_dim=None,
):
super().__init__()
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
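# The AttnProcessor above reproduces the stock diffusers attention path. It is kept so that
# layers which are not patched with the consistent self-attention (everything except the
# up-block self-attention layers selected in set_attention_processor below) retain their
# default behaviour.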
import diffusers
from diffusers import StableDiffusionXLPipeline
from utils import PhotoMakerStableDiffusionXLPipeline
from diffusers import DDIMScheduler
import torch.nn.functional as F
def cal_attn_mask(total_length,id_length,sa16,sa32,sa64,device="cuda",dtype= torch.float16):
bool_matrix256 = torch.rand((1, total_length * 256),device = device,dtype = dtype) < sa16
bool_matrix1024 = torch.rand((1, total_length * 1024),device = device,dtype = dtype) < sa32
bool_matrix4096 = torch.rand((1, total_length * 4096),device = device,dtype = dtype) < sa64
bool_matrix256 = bool_matrix256.repeat(total_length,1)
bool_matrix1024 = bool_matrix1024.repeat(total_length,1)
bool_matrix4096 = bool_matrix4096.repeat(total_length,1)
for i in range(total_length):
bool_matrix256[i:i+1,id_length*256:] = False
bool_matrix1024[i:i+1,id_length*1024:] = False
bool_matrix4096[i:i+1,id_length*4096:] = False
bool_matrix256[i:i+1,i*256:(i+1)*256] = True
bool_matrix1024[i:i+1,i*1024:(i+1)*1024] = True
bool_matrix4096[i:i+1,i*4096:(i+1)*4096] = True
mask256 = bool_matrix256.unsqueeze(1).repeat(1,256,1).reshape(-1,total_length * 256)
mask1024 = bool_matrix1024.unsqueeze(1).repeat(1,1024,1).reshape(-1,total_length * 1024)
mask4096 = bool_matrix4096.unsqueeze(1).repeat(1,4096,1).reshape(-1,total_length * 4096)
return mask256,mask1024,mask4096
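# Shape sketch for cal_attn_mask (per-frame token counts of 256 / 1024 / 4096): with
# total_length = 5 and id_length = 4, each returned mask is square with side
# total_length * tokens, e.g. mask1024 has shape (5 * 1024, 5 * 1024). In row block i, only
# the columns of frame i itself plus a random subset (probability sa16 / sa32 / sa64) of the
# first id_length identity frames are True, so every frame attends to itself and to sampled
# identity tokens.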
def cal_attn_mask_xl(total_length,id_length,sa32,sa64,height,width,device="cuda",dtype= torch.float16):
nums_1024 = (height // 32) * (width // 32)
nums_4096 = (height // 16) * (width // 16)
bool_matrix1024 = torch.rand((1, total_length * nums_1024),device = device,dtype = dtype) < sa32
bool_matrix4096 = torch.rand((1, total_length * nums_4096),device = device,dtype = dtype) < sa64
bool_matrix1024 = bool_matrix1024.repeat(total_length,1)
bool_matrix4096 = bool_matrix4096.repeat(total_length,1)
for i in range(total_length):
bool_matrix1024[i:i+1,id_length*nums_1024:] = False
bool_matrix4096[i:i+1,id_length*nums_4096:] = False
bool_matrix1024[i:i+1,i*nums_1024:(i+1)*nums_1024] = True
bool_matrix4096[i:i+1,i*nums_4096:(i+1)*nums_4096] = True
mask1024 = bool_matrix1024.unsqueeze(1).repeat(1,nums_1024,1).reshape(-1,total_length * nums_1024)
mask4096 = bool_matrix4096.unsqueeze(1).repeat(1,nums_4096,1).reshape(-1,total_length * nums_4096)
return mask1024,mask4096
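# Example (sketch) for cal_attn_mask_xl: at height = width = 768 the two SDXL self-attention
# resolutions give nums_1024 = (768 // 32) ** 2 = 576 and nums_4096 = (768 // 16) ** 2 = 2304
# tokens per frame, so with total_length = 5 the returned masks have shapes
# (5 * 576, 5 * 576) and (5 * 2304, 5 * 2304). The masks are regenerated after every
# denoising step inside SpatialAttnProcessor2_0.__call__.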
import copy
import os
from huggingface_hub import hf_hub_download
from diffusers.utils import load_image
from utils.utils import get_comic # must remove this one
style_list = [
{
"name": "(No style)",
"prompt": "{prompt}",
"negative_prompt": "",
},
{
"name": "Japanese Anime",
"prompt": "anime artwork illustrating {prompt}. created by japanese anime studio. highly emotional. best quality, high resolution",
"negative_prompt": "low quality, low resolution"
},
{
"name": "Cinematic",
"prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
"negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
},
{
"name": "Disney Charactor",
"prompt": "A Pixar animation character of {prompt} . pixar-style, studio anime, Disney, high-quality",
"negative_prompt": "lowres, bad anatomy, bad hands, text, bad eyes, bad arms, bad legs, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, blurry, grayscale, noisy, sloppy, messy, grainy, highly detailed, ultra textured, photo",
},
{
"name": "Photographic",
"prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
"negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
},
{
"name": "Comic book",
"prompt": "comic {prompt} . graphic illustration, comic art, graphic novel art, vibrant, highly detailed",
"negative_prompt": "photograph, deformed, glitch, noisy, realistic, stock photo",
},
{
"name": "Line art",
"prompt": "line art drawing {prompt} . professional, sleek, modern, minimalist, graphic, line art, vector graphics",
"negative_prompt": "anime, photorealistic, 35mm film, deformed, glitch, blurry, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, mutated, realism, realistic, impressionism, expressionism, oil, acrylic",
}
]
styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
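# Example (sketch, hypothetical prompt): a style entry is applied by formatting its positive
# template and reusing its negative prompt as-is, e.g.
#   positive_tpl, negative = styles["Japanese Anime"]
#   positive = positive_tpl.format(prompt="a boy reading a book")
#   # -> "anime artwork illustrating a boy reading a book. created by japanese anime studio. ..."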
image_encoder_path = "./data/models/ip_adapter/sdxl_models/image_encoder"
ip_ckpt = "./data/models/ip_adapter/sdxl_models/ip-adapter_sdxl_vit-h.bin"
os.environ["no_proxy"] = "localhost,127.0.0.1,::1"
STYLE_NAMES = list(styles.keys())
DEFAULT_STYLE_NAME = "Japanese Anime"
global models_dict
use_va = True
models_dict = {
# "Juggernaut": "RunDiffusion/Juggernaut-XL-v8",
"RealVision": "SG161222/RealVisXL_V4.0" ,
# "SDXL":"stabilityai/stable-diffusion-xl-base-1.0" ,
"Unstable": "stablediffusionapi/sdxl-unstable-diffusers-y"
}
photomaker_path = hf_hub_download(repo_id="TencentARC/PhotoMaker", filename="photomaker-v1.bin", repo_type="model")
MAX_SEED = np.iinfo(np.int32).max
def setup_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
def set_text_unfinished():
    return gr.update(visible=True, value="(Not Finished) Generating ··· The intermediate results will be shown.")
def set_text_finished():
    return gr.update(visible=True, value="Generation Finished")
#################################################
def get_image_path_list(folder_name):
image_basename_list = os.listdir(folder_name)
image_path_list = sorted([os.path.join(folder_name, basename) for basename in image_basename_list])
return image_path_list
#################################################
class SpatialAttnProcessor2_0(torch.nn.Module):
r"""
    Consistent (paired) self-attention processor built on PyTorch 2.0 scaled_dot_product_attention.
    Args:
        hidden_size (`int`, *optional*):
            The hidden size of the attention layer.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the `encoder_hidden_states`.
        id_length (`int`, defaults to 4):
            Number of identity (reference) images whose hidden states are cached and reused.
        device (`str`, defaults to `"cuda"`):
            Device on which the attention masks are created.
        dtype (`torch.dtype`, defaults to `torch.float16`):
            Dtype used for the attention masks.
"""
def __init__(self, hidden_size = None, cross_attention_dim=None,id_length = 4,device = "cuda",dtype = torch.float16):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("SpatialAttnProcessor2_0 requires PyTorch 2.0; please upgrade PyTorch to use it.")
self.device = device
self.dtype = dtype
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.total_length = id_length + 1
self.id_length = id_length
self.id_bank = {}
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None):
# un_cond_hidden_states, cond_hidden_states = hidden_states.chunk(2)
# un_cond_hidden_states = self.__call2__(attn, un_cond_hidden_states,encoder_hidden_states,attention_mask,temb)
global total_count,attn_count,cur_step,mask1024,mask4096
global sa32, sa64
global write
global height,width
global num_steps
if write:
# print(f"white:{cur_step}")
self.id_bank[cur_step] = [hidden_states[:self.id_length], hidden_states[self.id_length:]]
else:
encoder_hidden_states = torch.cat((self.id_bank[cur_step][0].to(self.device),hidden_states[:1],self.id_bank[cur_step][1].to(self.device),hidden_states[1:]))
        # Randomly decide whether to apply the masked paired-attention path on this step
if cur_step <=1:
hidden_states = self.__call2__(attn, hidden_states,None,attention_mask,temb)
else: # 256 1024 4096
random_number = random.random()
if cur_step <0.4 * num_steps:
rand_num = 0.3
else:
rand_num = 0.1
# print(f"hidden state shape {hidden_states.shape[1]}")
if random_number > rand_num:
# print("mask shape",mask1024.shape,mask4096.shape)
if not write:
if hidden_states.shape[1] == (height//32) * (width//32):
attention_mask = mask1024[mask1024.shape[0] // self.total_length * self.id_length:]
else:
attention_mask = mask4096[mask4096.shape[0] // self.total_length * self.id_length:]
else:
# print(self.total_length,self.id_length,hidden_states.shape,(height//32) * (width//32))
if hidden_states.shape[1] == (height//32) * (width//32):
attention_mask = mask1024[:mask1024.shape[0] // self.total_length * self.id_length,:mask1024.shape[0] // self.total_length * self.id_length]
else:
attention_mask = mask4096[:mask4096.shape[0] // self.total_length * self.id_length,:mask4096.shape[0] // self.total_length * self.id_length]
# print(attention_mask.shape)
# print("before attention",hidden_states.shape,attention_mask.shape,encoder_hidden_states.shape if encoder_hidden_states is not None else "None")
hidden_states = self.__call1__(attn, hidden_states,encoder_hidden_states,attention_mask,temb)
else:
hidden_states = self.__call2__(attn, hidden_states,None,attention_mask,temb)
attn_count +=1
if attn_count == total_count:
attn_count = 0
cur_step += 1
mask1024,mask4096 = cal_attn_mask_xl(self.total_length,self.id_length,sa32,sa64,height,width, device=self.device, dtype= self.dtype)
return hidden_states
def __call1__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
):
# print("hidden state shape",hidden_states.shape,self.id_length)
residual = hidden_states
# if encoder_hidden_states is not None:
# raise Exception("not implement")
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
total_batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(total_batch_size, channel, height * width).transpose(1, 2)
total_batch_size,nums_token,channel = hidden_states.shape
img_nums = total_batch_size//2
hidden_states = hidden_states.view(-1,img_nums,nums_token,channel).reshape(-1,img_nums * nums_token,channel)
batch_size, sequence_length, _ = hidden_states.shape
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states # B, N, C
else:
encoder_hidden_states = encoder_hidden_states.view(-1,self.id_length+1,nums_token,channel).reshape(-1,(self.id_length+1) * nums_token,channel)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# print(key.shape,value.shape,query.shape,attention_mask.shape)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
#print(query.shape,key.shape,value.shape,attention_mask.shape)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(total_batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
# if input_ndim == 4:
# tile_hidden_states = tile_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
# if attn.residual_connection:
# tile_hidden_states = tile_hidden_states + residual
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(total_batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
# print(hidden_states.shape)
return hidden_states
def __call2__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, channel = (
hidden_states.shape
)
# print(hidden_states.shape)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states # B, N, C
else:
encoder_hidden_states = encoder_hidden_states.view(-1,self.id_length+1,sequence_length,channel).reshape(-1,(self.id_length+1) * sequence_length,channel)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
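# SpatialAttnProcessor2_0 operates in two phases controlled by the global `write` flag:
# while generating the identity images (write == True) it caches their per-step hidden
# states in self.id_bank; while generating the remaining frames (write == False) it
# concatenates the cached identity states with the current frame's states and, on most
# steps, routes them through __call1__ with the random masks from cal_attn_mask_xl so each
# frame can attend to the identity frames. __call2__ is the plain per-frame attention used
# for the first steps and for the randomly skipped ones.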
def set_attention_processor(unet,id_length,is_ipadapter = False):
global total_count
total_count = 0
attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
if cross_attention_dim is None:
if name.startswith("up_blocks") :
attn_procs[name] = SpatialAttnProcessor2_0(id_length = id_length)
total_count +=1
else:
attn_procs[name] = AttnProcessor()
else:
if is_ipadapter:
attn_procs[name] = IPAttnProcessor2_0(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1,
num_tokens=4,
).to(unet.device, dtype=torch.float16)
else:
attn_procs[name] = AttnProcessor()
unet.set_attn_processor(copy.deepcopy(attn_procs))
print("successsfully load paired self-attention")
print(f"number of the processor : {total_count}")
canvas_html = ""
load_js = """
async () => {
const url = "https://huggingface.co/datasets/radames/gradio-components/raw/main/sketch-canvas.js"
fetch(url)
.then(res => res.text())
.then(text => {
const script = document.createElement('script');
script.type = "module"
script.src = URL.createObjectURL(new Blob([text], { type: 'application/javascript' }));
document.head.appendChild(script);
});
}
"""
get_js_colors = """
async (canvasData) => {
const canvasEl = document.getElementById("canvas-root");
return [canvasEl._data]
}
"""
css = '''
#color-bg{display:flex;justify-content: center;align-items: center;}
.color-bg-item{width: 100%; height: 32px}
#main_button{width:100%}