# pipeline_stable_diffusion_xl_instantid_img2img.py
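"""Img2img variant of the InstantID pipeline for Stable Diffusion XL.

This module subclasses ``StableDiffusionXLControlNetImg2ImgPipeline`` and adds the
InstantID components: a Perceiver-style ``Resampler`` that projects face embeddings
into image-prompt tokens, plus helpers that install IP-Adapter attention processors
into the UNet and load their weights from an InstantID checkpoint.
"""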
import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import cv2
import numpy as np
import PIL.Image
import torch
import torch.nn as nn
from diffusers import StableDiffusionXLControlNetImg2ImgPipeline
from diffusers.image_processor import PipelineImageInput
from diffusers.models import ControlNetModel
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
from diffusers.utils import deprecate, logging, replace_example_docstring
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version

# Attention processors from the InstantID / IP-Adapter codebase. The module path is
# assumed to match the upstream InstantID pipelines, which bundle an ``ip_adapter`` package.
from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor

logger = logging.get_logger(__name__) # Initialize logger
# Check for xformers availability
try:
import xformers
import xformers.ops
xformers_available = True
except ImportError:
xformers_available = False
def reshape_tensor(x: torch.Tensor, heads: int) -> torch.Tensor:
    """Reshape (batch, length, heads * dim_head) into (batch, heads, length, dim_head)."""
    bs, length, _ = x.shape
    return x.view(bs, length, heads, -1).transpose(1, 2)
class PerceiverAttention(nn.Module):
def __init__(self, dim: int, dim_head: int = 64, heads: int = 8):
super().__init__()
self.scale = dim_head**-0.5
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
x, latents = self.norm1(x), self.norm2(latents)
q, kv = self.to_q(latents), self.to_kv(torch.cat((x, latents), dim=1))
k, v = kv.chunk(2, dim=-1)
q, k, v = map(lambda t: reshape_tensor(t, self.heads), (q, k, v))
# Scaled dot-product attention
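        # The 1/sqrt(sqrt(d)) factor is applied to both q and k so their product carries
        # the usual 1/sqrt(d) scaling while staying numerically stable in reduced precision.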
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
weight = (q * scale) @ (k * scale).transpose(-2, -1)
weight = torch.softmax(weight.float(), dim=-1).to(weight.dtype)
out = weight @ v
return self.to_out(out.permute(0, 2, 1, 3).reshape(latents.shape[0], latents.shape[1], -1))
def FeedForward(dim: int, mult: int = 4) -> nn.Sequential:
    """Feed-forward block paired with each Resampler attention layer."""
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias=False),
    )
class Resampler(nn.Module):
    def __init__(
        self,
        dim: int = 1024,
        depth: int = 8,
        dim_head: int = 64,
        heads: int = 16,
        num_queries: int = 8,
        embedding_dim: int = 768,
        output_dim: int = 1024,
        ff_mult: int = 4,
    ):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / math.sqrt(dim))
        self.proj_in = nn.Linear(embedding_dim, dim)
        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)
        # Each layer pairs Perceiver cross-attention with a feed-forward block, matching
        # the layer layout expected by the InstantID image-projection checkpoint.
        self.layers = nn.ModuleList(
            [nn.ModuleList([PerceiverAttention(dim, dim_head, heads), FeedForward(dim, ff_mult)]) for _ in range(depth)]
        )
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        latents = self.latents.expand(x.size(0), -1, -1)
        x = self.proj_in(x)
        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents
        return self.norm_out(self.proj_out(latents))
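
# Minimal sketch of how the Resampler is used further down (the shapes follow the defaults
# passed in ``set_image_proj_model``; SDXL's cross-attention width of 2048 is assumed):
# a 512-d face embedding reshaped to (1, 1, 512) is projected to ``num_queries`` tokens.
#
#   resampler = Resampler(dim=1280, depth=4, dim_head=64, heads=20,
#                         num_queries=16, embedding_dim=512, output_dim=2048)
#   tokens = resampler(torch.randn(1, 1, 512))   # -> (1, 16, 2048)
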
class StableDiffusionXLInstantIDImg2ImgPipeline(StableDiffusionXLControlNetImg2ImgPipeline):
def cuda(self, dtype: torch.dtype = torch.float16, use_xformers: bool = False):
self.to("cuda", dtype)
if hasattr(self, "image_proj_model"):
self.image_proj_model.to(self.unet.device).to(self.unet.dtype)
if use_xformers:
if is_xformers_available():
self.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xFormers is not available. Ensure it is installed correctly.")
def load_ip_adapter_instantid(self, model_ckpt: str, image_emb_dim: int = 512, num_tokens: int = 16, scale: float = 0.5):
self.set_image_proj_model(model_ckpt, image_emb_dim, num_tokens)
self.set_ip_adapter(model_ckpt, num_tokens, scale)
def set_image_proj_model(self, model_ckpt: str, image_emb_dim: int = 512, num_tokens: int = 16):
self.image_proj_model = Resampler(
dim=1280, depth=4, dim_head=64, heads=20, num_queries=num_tokens, embedding_dim=image_emb_dim, output_dim=self.unet.config.cross_attention_dim
).to(self.device, dtype=self.dtype).eval()
        state_dict = torch.load(model_ckpt, map_location="cpu")
        self.image_proj_model.load_state_dict(state_dict.get("image_proj", state_dict))
self.image_proj_model_in_features = image_emb_dim
    def set_ip_adapter(self, model_ckpt: str, num_tokens: int, scale: float):
        attn_procs = {}
        for name in self.unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = self.unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                # Up blocks index into the reversed channel list.
                hidden_size = list(reversed(self.unet.config.block_out_channels))[int(name[len("up_blocks.")])]
            else:  # down_blocks
                hidden_size = self.unet.config.block_out_channels[int(name[len("down_blocks.")])]
            # Cross-attention layers get the IP-Adapter processor; self-attention layers keep a plain one.
            attn_procs[name] = (IPAttnProcessor(hidden_size, cross_attention_dim, scale, num_tokens).to(self.unet.device, dtype=self.unet.dtype)
                                if cross_attention_dim else AttnProcessor())
        self.unet.set_attn_processor(attn_procs)
        state_dict = torch.load(model_ckpt, map_location="cpu")
        nn.ModuleList(self.unet.attn_processors.values()).load_state_dict(state_dict.get("ip_adapter", state_dict))
def _encode_prompt_image_emb(self, prompt_image_emb, device, dtype, do_classifier_free_guidance):
prompt_image_emb = torch.tensor(prompt_image_emb) if not isinstance(prompt_image_emb, torch.Tensor) else prompt_image_emb.clone().detach()
prompt_image_emb = prompt_image_emb.to(device, dtype=dtype).reshape([1, -1, self.image_proj_model_in_features])
if do_classifier_free_guidance:
prompt_image_emb = torch.cat([torch.zeros_like(prompt_image_emb), prompt_image_emb], dim=0)
return self.image_proj_model.to(device)(prompt_image_emb)
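
# Illustrative usage, as a minimal sketch: the repo ids, checkpoint path and the random
# face embedding below are stand-ins, not values defined by this module.
#
#   controlnet = ControlNetModel.from_pretrained(
#       "InstantX/InstantID", subfolder="ControlNetModel", torch_dtype=torch.float16
#   )
#   pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
#       "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
#   )
#   pipe.cuda()
#   pipe.load_ip_adapter_instantid("checkpoints/ip-adapter.bin")
#   face_emb = np.random.rand(512).astype(np.float32)  # stand-in for an InsightFace embedding
#   tokens = pipe._encode_prompt_image_emb(face_emb, pipe.device, torch.float16, do_classifier_free_guidance=True)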