import math
from typing import Dict, Optional

import torch
import torchvision.transforms.functional as FF
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from diffusers import StableDiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import USE_PEFT_BACKEND

try:
    from compel import Compel
except ImportError:
    Compel = None

KCOMM = "ADDCOMM"
KBRK = "BREAK"

class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
    r"""
    Args for the Regional Prompting Pipeline:
        rp_args: dict
            Required:
                rp_args["mode"]: "cols", "rows", "prompt", or "prompt-ex"
            For "cols" / "rows" mode:
                rp_args["div"]: e.g. "1;1;1" (divide into 3 regions)
            For "prompt" / "prompt-ex" mode:
                rp_args["th"]: e.g. "0.5,0.5,0.6" (thresholds for prompt mode)
            Optional:
                rp_args["save_mask"]: True/False (save masks in prompt mode)
                rp_args["power"]: int, exponent applied when accumulating attention maps in prompt mode (default 1)

    A commented usage sketch follows the class definition below.

    Pipeline for text-to-image generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.).

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
        requires_safety_checker: bool = True,
    ):
        super().__init__(
            vae,
            text_encoder,
            tokenizer,
            unet,
            scheduler,
            safety_checker,
            feature_extractor,
            requires_safety_checker,
        )
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
    def __call__(
        self,
        prompt: str,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: str = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        rp_args: Dict[str, str] = None,
    ):
        # Regional prompting is only activated when the prompt contains the "BREAK" keyword.
        active = KBRK in prompt[0] if isinstance(prompt, list) else KBRK in prompt
        if negative_prompt is None:
            negative_prompt = "" if isinstance(prompt, str) else [""] * len(prompt)

        device = self._execution_device
        regions = 0

        self.power = int(rp_args["power"]) if "power" in rp_args else 1

        prompts = prompt if isinstance(prompt, list) else [prompt]
        n_prompts = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt]
        self.batch = batch = num_images_per_prompt * len(prompts)
        all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt)
        all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt)

        equal = len(all_prompts_cn) == len(all_n_prompts_cn)

        if Compel:
            compel = Compel(tokenizer=self.tokenizer, text_encoder=self.text_encoder)

            def getcompelembs(prps):
                embl = []
                for prp in prps:
                    embl.append(compel.build_conditioning_tensor(prp))
                return torch.cat(embl)

            conds = getcompelembs(all_prompts_cn)
            unconds = getcompelembs(all_n_prompts_cn)
            embs = getcompelembs(prompts)
            n_embs = getcompelembs(n_prompts)
            prompt = negative_prompt = None
        else:
            conds = self.encode_prompt(prompts, device, 1, True)[0]
            unconds = (
                self.encode_prompt(n_prompts, device, 1, True)[0]
                if equal
                else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0]
            )
            embs = n_embs = None

        if not active:
            pcallback = None
            mode = None
        else:
            if any(x in rp_args["mode"].upper() for x in ["COL", "ROW"]):
                mode = "COL" if "COL" in rp_args["mode"].upper() else "ROW"
                ocells, icells, regions = make_cells(rp_args["div"])

            elif "PRO" in rp_args["mode"].upper():
                regions = len(all_prompts_p[0])
                mode = "PROMPT"
                reset_attnmaps(self)
                self.ex = "EX" in rp_args["mode"].upper()
                self.target_tokens = target_tokens = tokendealer(self, all_prompts_p)
                thresholds = [float(x) for x in rp_args["th"].split(",")]

            orig_hw = (height, width)
            revers = True

            def pcallback(s_self, step: int, timestep: int, latents: torch.Tensor, selfs=None):
                if "PRO" in mode:  # in prompt mode, build masks from the sum of attention maps
                    self.step = step

                    if len(self.attnmaps_sizes) > 3:
                        self.history[step] = self.attnmaps.copy()
                        for hw in self.attnmaps_sizes:
                            allmasks = []
                            basemasks = [None] * batch
                            for tt, th in zip(target_tokens, thresholds):
                                for b in range(batch):
                                    key = f"{tt}-{b}"
                                    _, mask, _ = makepmask(self, self.attnmaps[key], hw[0], hw[1], th, step)
                                    mask = mask.unsqueeze(0).unsqueeze(-1)
                                    if self.ex:
                                        allmasks[b::batch] = [x - mask for x in allmasks[b::batch]]
                                        allmasks[b::batch] = [torch.where(x > 0, 1, 0) for x in allmasks[b::batch]]
                                    allmasks.append(mask)
                                    basemasks[b] = mask if basemasks[b] is None else basemasks[b] + mask
                            basemasks = [1 - mask for mask in basemasks]
                            basemasks = [torch.where(x > 0, 1, 0) for x in basemasks]
                            allmasks = basemasks + allmasks

                            self.attnmasks[hw] = torch.cat(allmasks)
                        self.maskready = True
                return latents
            # Replace the forward pass of every cross-attention ("attn2") module so that each
            # region's latent slice attends to its own prompt embedding.
            def hook_forward(module):
                # diffusers==0.23.2
                def forward(
                    hidden_states: torch.Tensor,
                    encoder_hidden_states: Optional[torch.Tensor] = None,
                    attention_mask: Optional[torch.Tensor] = None,
                    temb: Optional[torch.Tensor] = None,
                    scale: float = 1.0,
                ) -> torch.Tensor:
                    attn = module
                    xshape = hidden_states.shape
                    self.hw = (h, w) = split_dims(xshape[1], *orig_hw)

                    if revers:
                        nx, px = hidden_states.chunk(2)
                    else:
                        px, nx = hidden_states.chunk(2)

                    if equal:
                        hidden_states = torch.cat(
                            [px for i in range(regions)] + [nx for i in range(regions)],
                            0,
                        )
                        encoder_hidden_states = torch.cat([conds] + [unconds])
                    else:
                        hidden_states = torch.cat([px for i in range(regions)] + [nx], 0)
                        encoder_hidden_states = torch.cat([conds] + [unconds])

                    residual = hidden_states

                    args = () if USE_PEFT_BACKEND else (scale,)

                    if attn.spatial_norm is not None:
                        hidden_states = attn.spatial_norm(hidden_states, temb)

                    input_ndim = hidden_states.ndim

                    if input_ndim == 4:
                        batch_size, channel, height, width = hidden_states.shape
                        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

                    batch_size, sequence_length, _ = (
                        hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
                    )

                    if attention_mask is not None:
                        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
                        attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

                    if attn.group_norm is not None:
                        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

                    args = () if USE_PEFT_BACKEND else (scale,)
                    query = attn.to_q(hidden_states, *args)

                    if encoder_hidden_states is None:
                        encoder_hidden_states = hidden_states
                    elif attn.norm_cross:
                        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

                    key = attn.to_k(encoder_hidden_states, *args)
                    value = attn.to_v(encoder_hidden_states, *args)

                    inner_dim = key.shape[-1]
                    head_dim = inner_dim // attn.heads

                    query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                    key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                    value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

                    # the output of sdp = (batch, num_heads, seq_len, head_dim)
                    # TODO: add support for attn.scale when we move to Torch 2.1
                    hidden_states = scaled_dot_product_attention(
                        self,
                        query,
                        key,
                        value,
                        attn_mask=attention_mask,
                        dropout_p=0.0,
                        is_causal=False,
                        getattn="PRO" in mode,
                    )

                    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
                    hidden_states = hidden_states.to(query.dtype)

                    # linear proj
                    hidden_states = attn.to_out[0](hidden_states, *args)
                    # dropout
                    hidden_states = attn.to_out[1](hidden_states)

                    if input_ndim == 4:
                        hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

                    if attn.residual_connection:
                        hidden_states = hidden_states + residual

                    hidden_states = hidden_states / attn.rescale_output_factor

                    #### Regional Prompting Col/Row mode: paste each region's output back into its cell
                    if any(x in mode for x in ["COL", "ROW"]):
                        reshaped = hidden_states.reshape(hidden_states.size()[0], h, w, hidden_states.size()[2])
                        center = reshaped.shape[0] // 2
                        px = reshaped[0:center] if equal else reshaped[0:-batch]
                        nx = reshaped[center:] if equal else reshaped[-batch:]
                        outs = [px, nx] if equal else [px]
                        for out in outs:
                            c = 0
                            for i, ocell in enumerate(ocells):
                                for icell in icells[i]:
                                    if "ROW" in mode:
                                        out[
                                            0:batch,
                                            int(h * ocell[0]) : int(h * ocell[1]),
                                            int(w * icell[0]) : int(w * icell[1]),
                                            :,
                                        ] = out[
                                            c * batch : (c + 1) * batch,
                                            int(h * ocell[0]) : int(h * ocell[1]),
                                            int(w * icell[0]) : int(w * icell[1]),
                                            :,
                                        ]
                                    else:
                                        out[
                                            0:batch,
                                            int(h * icell[0]) : int(h * icell[1]),
                                            int(w * ocell[0]) : int(w * ocell[1]),
                                            :,
                                        ] = out[
                                            c * batch : (c + 1) * batch,
                                            int(h * icell[0]) : int(h * icell[1]),
                                            int(w * ocell[0]) : int(w * ocell[1]),
                                            :,
                                        ]
                                    c += 1
                        px, nx = (px[0:batch], nx[0:batch]) if equal else (px[0:batch], nx)
                        hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
                        hidden_states = hidden_states.reshape(xshape)
                    #### Regional Prompting Prompt mode: apply the attention-derived region masks
                    elif "PRO" in mode:
                        center = hidden_states.shape[0] // 2
                        px = hidden_states[0:center] if equal else hidden_states[0:-batch]
                        nx = hidden_states[center:] if equal else hidden_states[-batch:]
                        if (h, w) in self.attnmasks and self.maskready:

                            def mask(input):
                                out = torch.multiply(input, self.attnmasks[(h, w)])
                                for b in range(batch):
                                    for r in range(1, regions):
                                        out[b] = out[b] + out[r * batch + b]
                                return out

                            px, nx = (mask(px), mask(nx)) if equal else (mask(px), nx)
                        px, nx = (px[0:batch], nx[0:batch]) if equal else (px[0:batch], nx)
                        hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
                    return hidden_states

                return forward

            def hook_forwards(root_module: torch.nn.Module):
                for name, module in root_module.named_modules():
                    if "attn2" in name and module.__class__.__name__ == "Attention":
                        module.forward = hook_forward(module)

            hook_forwards(self.unet)
        output = StableDiffusionPipeline(**self.components)(
            prompt=prompt,
            prompt_embeds=embs,
            negative_prompt=negative_prompt,
            negative_prompt_embeds=n_embs,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback_on_step_end=pcallback,
        )

        if "save_mask" in rp_args:
            save_mask = rp_args["save_mask"]
        else:
            save_mask = False

        if mode == "PROMPT" and save_mask:
            saveattnmaps(
                self,
                output,
                height,
                width,
                thresholds,
                num_inference_steps // 2,
                regions,
            )

        return output
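
# Usage sketch (illustrative only; nothing here runs on import). The checkpoint id, prompts
# and rp_args values are assumptions picked for the example; any Stable Diffusion 1.x
# checkpoint that works with StableDiffusionPipeline should do:
#
#   pipe = RegionalPromptingStableDiffusionPipeline.from_pretrained(
#       "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
#   ).to("cuda")
#   prompt = "high quality portrait ADDCOMM red dress BREAK blue hair BREAK green eyes"
#   image = pipe(
#       prompt=prompt,
#       negative_prompt="low quality",
#       rp_args={"mode": "rows", "div": "1;1;1"},
#   ).images[0]
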
### Make the prompt list for each region
def promptsmaker(prompts, batch):
    out_p = []
    plen = len(prompts)
    for prompt in prompts:
        add = ""
        if KCOMM in prompt:
            add, prompt = prompt.split(KCOMM)
            add = add + " "
        prompts = prompt.split(KBRK)
        out_p.append([add + p for p in prompts])
    out = [None] * batch * len(out_p[0]) * len(out_p)
    for p, prs in enumerate(out_p):  # input prompts
        for r, pr in enumerate(prs):  # prompts for regions
            start = (p + r * plen) * batch
            out[start : start + batch] = [pr] * batch  # P1R1B1,P1R1B2...,P1R2B1,P1R2B2...,P2R1B1...
    return out, out_p
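
# Worked example for promptsmaker (illustrative; extra whitespace produced by the splits is
# left untouched):
#
#   out, out_p = promptsmaker(["portrait ADDCOMM red dress BREAK blue hair"], 2)
#   # out_p -> one entry per input prompt holding its per-region prompts, roughly
#   #          [["portrait  red dress", "portrait  blue hair"]]
#   # out   -> flattened as P1R1B1, P1R1B2, P1R2B1, P1R2B2: the "red dress" prompt twice,
#   #          then the "blue hair" prompt twice.
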
### Make regions from ratios
### ";" separates outer cells, "," separates inner cells
def make_cells(ratios):
    if ";" not in ratios and "," in ratios:
        ratios = ratios.replace(",", ";")
    ratios = ratios.split(";")
    ratios = [inratios.split(",") for inratios in ratios]

    icells = []
    ocells = []

    def startend(cells, array):
        current_start = 0
        array = [float(x) for x in array]
        for value in array:
            end = current_start + (value / sum(array))
            cells.append([current_start, end])
            current_start = end

    startend(ocells, [r[0] for r in ratios])

    for inratios in ratios:
        if 2 > len(inratios):
            icells.append([[0, 1]])
        else:
            add = []
            startend(add, inratios[1:])
            icells.append(add)

    return ocells, icells, sum(len(cell) for cell in icells)
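
# Worked examples for make_cells: the first number in each ";"-separated group is that
# group's outer ratio; any remaining ","-separated numbers are its inner ratios.
#
#   make_cells("1;1;1")
#   # -> outer cells [[0, 1/3], [1/3, 2/3], [2/3, 1]], inner cells [[[0, 1]]] * 3, 3 regions
#   make_cells("1,2,1;1,1")
#   # -> outer cells [[0, 0.5], [0.5, 1]], inner cells [[[0, 2/3], [2/3, 1]], [[0, 1]]], 3 regions
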
def make_emblist(self, prompts):
    with torch.no_grad():
        tokens = self.tokenizer(
            prompts,
            max_length=self.tokenizer.model_max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).input_ids.to(self.device)

        embs = self.text_encoder(tokens, output_hidden_states=True).last_hidden_state.to(self.device, dtype=self.dtype)
    return embs
def split_dims(xs, height, width):
    """Split the flattened attention sequence length `xs` back into the (height, width) of the current U-Net block."""

    def repeat_div(x, y):
        while y > 0:
            x = math.ceil(x / 2)
            y = y - 1
        return x

    scale = math.ceil(math.log2(math.sqrt(height * width / xs)))
    dsh = repeat_div(height, scale)
    dsw = repeat_div(width, scale)
    return dsh, dsw
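
# Worked example for split_dims: a 512x512 image has a 64x64 latent, so the cross-attention
# sequence lengths seen inside the U-Net map back to block resolutions as
#
#   split_dims(4096, 512, 512)  # -> (64, 64)
#   split_dims(1024, 512, 512)  # -> (32, 32)
#   split_dims(256, 512, 512)   # -> (16, 16)
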
##### For prompt mode
def get_attn_maps(self, attn):
    height, width = self.hw
    target_tokens = self.target_tokens
    if (height, width) not in self.attnmaps_sizes:
        self.attnmaps_sizes.append((height, width))

    for b in range(self.batch):
        for t in target_tokens:
            power = self.power
            add = attn[b, :, :, t[0] : t[0] + len(t)] ** (power) * (self.attnmaps_sizes.index((height, width)) + 1)
            add = torch.sum(add, dim=2)
            key = f"{t}-{b}"
            if key not in self.attnmaps:
                self.attnmaps[key] = add
            else:
                if self.attnmaps[key].shape[1] != add.shape[1]:
                    add = add.view(8, height, width)
                    add = FF.resize(add, self.attnmaps_sizes[0], antialias=None)
                    add = add.reshape_as(self.attnmaps[key])

                self.attnmaps[key] = self.attnmaps[key] + add
def reset_attnmaps(self):  # init parameters for every batch
    self.step = 0
    self.attnmaps = {}  # built from attention maps
    self.attnmaps_sizes = []  # (height, width) set of U-Net blocks
    self.attnmasks = {}  # built from attnmaps for regions
    self.maskready = False
    self.history = {}
def saveattnmaps(self, output, h, w, th, step, regions):
    masks = []
    for i, mask in enumerate(self.history[step].values()):
        img, _, mask = makepmask(self, mask, h, w, th[i % len(th)], step)
        if self.ex:
            masks = [x - mask for x in masks]
            masks.append(mask)
            if len(masks) == regions - 1:
                output.images.extend([FF.to_pil_image(mask) for mask in masks])
                masks = []
        else:
            output.images.append(img)
def makepmask(
    self, mask, h, w, th, step
):  # make masks from the attention cache; returns [preview image, attention mask, latent mask]
    th = th - step * 0.005
    if 0.05 >= th:
        th = 0.05
    mask = torch.mean(mask, dim=0)
    mask = mask / mask.max().item()
    mask = torch.where(mask > th, 1, 0)
    mask = mask.float()
    mask = mask.view(1, *self.attnmaps_sizes[0])
    img = FF.to_pil_image(mask)
    img = img.resize((w, h))
    mask = FF.resize(mask, (h, w), interpolation=FF.InterpolationMode.NEAREST, antialias=None)
    lmask = mask
    mask = mask.reshape(h * w)
    mask = torch.where(mask > 0.1, 1, 0)
    return img, mask, lmask
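
# Example of the threshold decay above: with th=0.6 at denoising step 20 the effective
# threshold is 0.6 - 20 * 0.005 = 0.5; it is clamped so it never drops below 0.05.
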
def tokendealer(self, all_prompts):
    # Find where each region's target tokens (the last comma-separated chunk of every region
    # prompt after the first) sit inside the tokenized first prompt.
    for prompts in all_prompts:
        targets = [p.split(",")[-1] for p in prompts[1:]]
        tt = []

        for target in targets:
            ptokens = self.tokenizer(
                prompts,
                max_length=self.tokenizer.model_max_length,
                padding=True,
                truncation=True,
                return_tensors="pt",
            ).input_ids[0]
            ttokens = self.tokenizer(
                target,
                max_length=self.tokenizer.model_max_length,
                padding=True,
                truncation=True,
                return_tensors="pt",
            ).input_ids[0]

            tlist = []

            for t in range(ttokens.shape[0] - 2):
                for p in range(ptokens.shape[0]):
                    if ttokens[t + 1] == ptokens[p]:
                        tlist.append(p)
            if tlist != []:
                tt.append(tlist)

    return tt
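
# Illustrative example for tokendealer, assuming `pipe` is a loaded pipeline and each word
# maps to a single CLIP token (exact indices depend on the tokenizer):
#
#   all_prompts = [["a girl with red dress and blue hair", "red dress", "blue hair"]]
#   tokendealer(pipe, all_prompts)
#   # -> roughly [[4, 5], [7, 8]]: the positions of "red dress" and "blue hair" inside the
#   #    tokenized first prompt (index 0 is the BOS token).
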
def scaled_dot_product_attention(
    self,
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    getattn=False,
) -> torch.Tensor:
    # Manual equivalent of torch.nn.functional.scaled_dot_product_attention that can also
    # pass the attention weights to get_attn_maps() when getattn=True (prompt mode).
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=self.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=self.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias = attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    if getattn:
        get_attn_maps(self, attn_weight)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value
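

if __name__ == "__main__":
    # Minimal self-check (an illustrative sketch, not part of the pipeline): with no mask,
    # no dropout and getattn=False, the manual implementation above should match
    # torch.nn.functional.scaled_dot_product_attention. _Holder only supplies the `device`
    # attribute that the function reads from `self`.
    class _Holder:
        device = "cpu"

    q, k, v = (torch.randn(1, 2, 4, 8) for _ in range(3))
    ref = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    out = scaled_dot_product_attention(_Holder(), q, k, v)
    print("matches F.scaled_dot_product_attention:", torch.allclose(ref, out, atol=1e-5))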