from email.policy import default from json import encoder import gradio as gr import spaces import numpy as np import torch import requests import random import os import sys import pickle from PIL import Image from tqdm.auto import tqdm from datetime import datetime import torch.nn as nn import torch.nn.functional as F class AttnProcessor(nn.Module): r""" Default processor for performing attention-related computations. """ def __init__( self, hidden_size=None, cross_attention_dim=None, ): super().__init__() def __call__( self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, ): residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class AttnProcessor2_0(torch.nn.Module): r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). """ def __init__( self, hidden_size=None, cross_attention_dim=None, ): super().__init__() if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call__( self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, ): residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states def is_torch2_available(): return hasattr(F, "scaled_dot_product_attention") if is_torch2_available(): from utils.gradio_utils import \ AttnProcessor2_0 as AttnProcessor # from utils.gradio_utils import SpatialAttnProcessor2_0 else: from utils.gradio_utils import AttnProcessor import diffusers from diffusers import StableDiffusionXLPipeline from utils import PhotoMakerStableDiffusionXLPipeline from diffusers import DDIMScheduler import torch.nn.functional as F def cal_attn_mask(total_length,id_length,sa16,sa32,sa64,device="cuda",dtype= torch.float16): bool_matrix256 = torch.rand((1, total_length * 256),device = device,dtype = dtype) < sa16 bool_matrix1024 = torch.rand((1, total_length * 1024),device = device,dtype = dtype) < sa32 bool_matrix4096 = torch.rand((1, total_length * 4096),device = device,dtype = dtype) < sa64 bool_matrix256 = bool_matrix256.repeat(total_length,1) bool_matrix1024 = bool_matrix1024.repeat(total_length,1) bool_matrix4096 = bool_matrix4096.repeat(total_length,1) for i in range(total_length): bool_matrix256[i:i+1,id_length*256:] = False bool_matrix1024[i:i+1,id_length*1024:] = False bool_matrix4096[i:i+1,id_length*4096:] = False bool_matrix256[i:i+1,i*256:(i+1)*256] = True bool_matrix1024[i:i+1,i*1024:(i+1)*1024] = True bool_matrix4096[i:i+1,i*4096:(i+1)*4096] = True mask256 = bool_matrix256.unsqueeze(1).repeat(1,256,1).reshape(-1,total_length * 256) mask1024 = bool_matrix1024.unsqueeze(1).repeat(1,1024,1).reshape(-1,total_length * 1024) mask4096 = bool_matrix4096.unsqueeze(1).repeat(1,4096,1).reshape(-1,total_length * 4096) return mask256,mask1024,mask4096 def cal_attn_mask_xl(total_length,id_length,sa32,sa64,height,width,device="cuda",dtype= torch.float16): nums_1024 = (height // 32) * (width // 32) nums_4096 = (height // 16) * (width // 16) bool_matrix1024 = torch.rand((1, total_length * nums_1024),device = device,dtype = dtype) < sa32 bool_matrix4096 = torch.rand((1, total_length * nums_4096),device = device,dtype = dtype) < sa64 bool_matrix1024 = bool_matrix1024.repeat(total_length,1) bool_matrix4096 = bool_matrix4096.repeat(total_length,1) for i in range(total_length): bool_matrix1024[i:i+1,id_length*nums_1024:] = False bool_matrix4096[i:i+1,id_length*nums_4096:] = False bool_matrix1024[i:i+1,i*nums_1024:(i+1)*nums_1024] = True bool_matrix4096[i:i+1,i*nums_4096:(i+1)*nums_4096] = True mask1024 = bool_matrix1024.unsqueeze(1).repeat(1,nums_1024,1).reshape(-1,total_length * nums_1024) mask4096 = bool_matrix4096.unsqueeze(1).repeat(1,nums_4096,1).reshape(-1,total_length * nums_4096) return mask1024,mask4096 import copy import os from huggingface_hub import hf_hub_download from diffusers.utils import load_image from utils.utils import get_comic # must remove this one style_list = [ { "name": "(No style)", "prompt": "{prompt}", "negative_prompt": "", }, { "name": "Japanese Anime", "prompt": "anime artwork illustrating {prompt}. created by japanese anime studio. highly emotional. best quality, high resolution", "negative_prompt": "low quality, low resolution" }, { "name": "Cinematic", "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy", "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured", }, { "name": "Disney Charactor", "prompt": "A Pixar animation character of {prompt} . pixar-style, studio anime, Disney, high-quality", "negative_prompt": "lowres, bad anatomy, bad hands, text, bad eyes, bad arms, bad legs, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, blurry, grayscale, noisy, sloppy, messy, grainy, highly detailed, ultra textured, photo", }, { "name": "Photographic", "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed", "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly", }, { "name": "Comic book", "prompt": "comic {prompt} . graphic illustration, comic art, graphic novel art, vibrant, highly detailed", "negative_prompt": "photograph, deformed, glitch, noisy, realistic, stock photo", }, { "name": "Line art", "prompt": "line art drawing {prompt} . professional, sleek, modern, minimalist, graphic, line art, vector graphics", "negative_prompt": "anime, photorealistic, 35mm film, deformed, glitch, blurry, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, mutated, realism, realistic, impressionism, expressionism, oil, acrylic", } ] styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list} image_encoder_path = "./data/models/ip_adapter/sdxl_models/image_encoder" ip_ckpt = "./data/models/ip_adapter/sdxl_models/ip-adapter_sdxl_vit-h.bin" os.environ["no_proxy"] = "localhost,127.0.0.1,::1" STYLE_NAMES = list(styles.keys()) DEFAULT_STYLE_NAME = "Japanese Anime" global models_dict use_va = True models_dict = { # "Juggernaut": "RunDiffusion/Juggernaut-XL-v8", "RealVision": "SG161222/RealVisXL_V4.0" , # "SDXL":"stabilityai/stable-diffusion-xl-base-1.0" , "Unstable": "stablediffusionapi/sdxl-unstable-diffusers-y" } photomaker_path = hf_hub_download(repo_id="TencentARC/PhotoMaker", filename="photomaker-v1.bin", repo_type="model") MAX_SEED = np.iinfo(np.int32).max def setup_seed(seed): torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) np.random.seed(seed) random.seed(seed) torch.backends.cudnn.deterministic = True def set_text_unfinished(): return gr.update(visible=True, value="

(Not Finished) Generating ··· The intermediate results will be shown.

") def set_text_finished(): return gr.update(visible=True, value="

Generation Finished

") ################################################# def get_image_path_list(folder_name): image_basename_list = os.listdir(folder_name) image_path_list = sorted([os.path.join(folder_name, basename) for basename in image_basename_list]) return image_path_list ################################################# class SpatialAttnProcessor2_0(torch.nn.Module): r""" Attention processor for IP-Adapater for PyTorch 2.0. Args: hidden_size (`int`): The hidden size of the attention layer. cross_attention_dim (`int`): The number of channels in the `encoder_hidden_states`. text_context_len (`int`, defaults to 77): The context length of the text features. scale (`float`, defaults to 1.0): the weight scale of image prompt. """ def __init__(self, hidden_size = None, cross_attention_dim=None,id_length = 4,device = "cuda",dtype = torch.float16): super().__init__() if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") self.device = device self.dtype = dtype self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim self.total_length = id_length + 1 self.id_length = id_length self.id_bank = {} def __call__( self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None): # un_cond_hidden_states, cond_hidden_states = hidden_states.chunk(2) # un_cond_hidden_states = self.__call2__(attn, un_cond_hidden_states,encoder_hidden_states,attention_mask,temb) global total_count,attn_count,cur_step,mask1024,mask4096 global sa32, sa64 global write global height,width global num_steps if write: # print(f"white:{cur_step}") self.id_bank[cur_step] = [hidden_states[:self.id_length], hidden_states[self.id_length:]] else: encoder_hidden_states = torch.cat((self.id_bank[cur_step][0].to(self.device),hidden_states[:1],self.id_bank[cur_step][1].to(self.device),hidden_states[1:])) # 判断随机数是否大于0.5 if cur_step <=1: hidden_states = self.__call2__(attn, hidden_states,None,attention_mask,temb) else: # 256 1024 4096 random_number = random.random() if cur_step <0.4 * num_steps: rand_num = 0.3 else: rand_num = 0.1 # print(f"hidden state shape {hidden_states.shape[1]}") if random_number > rand_num: # print("mask shape",mask1024.shape,mask4096.shape) if not write: if hidden_states.shape[1] == (height//32) * (width//32): attention_mask = mask1024[mask1024.shape[0] // self.total_length * self.id_length:] else: attention_mask = mask4096[mask4096.shape[0] // self.total_length * self.id_length:] else: # print(self.total_length,self.id_length,hidden_states.shape,(height//32) * (width//32)) if hidden_states.shape[1] == (height//32) * (width//32): attention_mask = mask1024[:mask1024.shape[0] // self.total_length * self.id_length,:mask1024.shape[0] // self.total_length * self.id_length] else: attention_mask = mask4096[:mask4096.shape[0] // self.total_length * self.id_length,:mask4096.shape[0] // self.total_length * self.id_length] # print(attention_mask.shape) # print("before attention",hidden_states.shape,attention_mask.shape,encoder_hidden_states.shape if encoder_hidden_states is not None else "None") hidden_states = self.__call1__(attn, hidden_states,encoder_hidden_states,attention_mask,temb) else: hidden_states = self.__call2__(attn, hidden_states,None,attention_mask,temb) attn_count +=1 if attn_count == total_count: attn_count = 0 cur_step += 1 mask1024,mask4096 = cal_attn_mask_xl(self.total_length,self.id_length,sa32,sa64,height,width, device=self.device, dtype= self.dtype) return hidden_states def __call1__( self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, ): # print("hidden state shape",hidden_states.shape,self.id_length) residual = hidden_states # if encoder_hidden_states is not None: # raise Exception("not implement") if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: total_batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(total_batch_size, channel, height * width).transpose(1, 2) total_batch_size,nums_token,channel = hidden_states.shape img_nums = total_batch_size//2 hidden_states = hidden_states.view(-1,img_nums,nums_token,channel).reshape(-1,img_nums * nums_token,channel) batch_size, sequence_length, _ = hidden_states.shape if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states # B, N, C else: encoder_hidden_states = encoder_hidden_states.view(-1,self.id_length+1,nums_token,channel).reshape(-1,(self.id_length+1) * nums_token,channel) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # print(key.shape,value.shape,query.shape,attention_mask.shape) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 #print(query.shape,key.shape,value.shape,attention_mask.shape) hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(total_batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) # if input_ndim == 4: # tile_hidden_states = tile_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) # if attn.residual_connection: # tile_hidden_states = tile_hidden_states + residual if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(total_batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor # print(hidden_states.shape) return hidden_states def __call2__( self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None): residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, channel = ( hidden_states.shape ) # print(hidden_states.shape) if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states # B, N, C else: encoder_hidden_states = encoder_hidden_states.view(-1,self.id_length+1,sequence_length,channel).reshape(-1,(self.id_length+1) * sequence_length,channel) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states def set_attention_processor(unet,id_length,is_ipadapter = False): global total_count total_count = 0 attn_procs = {} for name in unet.attn_processors.keys(): cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim if name.startswith("mid_block"): hidden_size = unet.config.block_out_channels[-1] elif name.startswith("up_blocks"): block_id = int(name[len("up_blocks.")]) hidden_size = list(reversed(unet.config.block_out_channels))[block_id] elif name.startswith("down_blocks"): block_id = int(name[len("down_blocks.")]) hidden_size = unet.config.block_out_channels[block_id] if cross_attention_dim is None: if name.startswith("up_blocks") : attn_procs[name] = SpatialAttnProcessor2_0(id_length = id_length) total_count +=1 else: attn_procs[name] = AttnProcessor() else: if is_ipadapter: attn_procs[name] = IPAttnProcessor2_0( hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1, num_tokens=4, ).to(unet.device, dtype=torch.float16) else: attn_procs[name] = AttnProcessor() unet.set_attn_processor(copy.deepcopy(attn_procs)) print("successsfully load paired self-attention") print(f"number of the processor : {total_count}") canvas_html = "
" load_js = """ async () => { const url = "https://huggingface.co/datasets/radames/gradio-components/raw/main/sketch-canvas.js" fetch(url) .then(res => res.text()) .then(text => { const script = document.createElement('script'); script.type = "module" script.src = URL.createObjectURL(new Blob([text], { type: 'application/javascript' })); document.head.appendChild(script); }); } """ get_js_colors = """ async (canvasData) => { const canvasEl = document.getElementById("canvas-root"); return [canvasEl._data] } """ css = ''' #color-bg{display:flex;justify-content: center;align-items: center;} .color-bg-item{width: 100%; height: 32px} #main_button{width:100%}