Upload 25 files
- library/__init__.py +0 -0
- library/adafactor_fused.py +106 -0
- library/attention_processors.py +227 -0
- library/config_util.py +721 -0
- library/custom_train_functions.py +559 -0
- library/deepspeed_utils.py +139 -0
- library/device_utils.py +84 -0
- library/huggingface_util.py +84 -0
- library/hypernetwork.py +223 -0
- library/ipex/__init__.py +180 -0
- library/ipex/attention.py +177 -0
- library/ipex/diffusers.py +312 -0
- library/ipex/gradscaler.py +183 -0
- library/ipex/hijacks.py +313 -0
- library/lpw_stable_diffusion.py +1233 -0
- library/model_util.py +1356 -0
- library/original_unet.py +1919 -0
- library/sai_model_spec.py +309 -0
- library/sdxl_lpw_stable_diffusion.py +1347 -0
- library/sdxl_model_util.py +583 -0
- library/sdxl_original_unet.py +1286 -0
- library/sdxl_train_util.py +381 -0
- library/slicing_vae.py +682 -0
- library/train_util.py +0 -0
- library/utils.py +287 -0
library/__init__.py
ADDED
File without changes
library/adafactor_fused.py
ADDED
@@ -0,0 +1,106 @@
import math
import torch
from transformers import Adafactor


@torch.no_grad()
def adafactor_step_param(self, p, group):
    if p.grad is None:
        return
    grad = p.grad
    if grad.dtype in {torch.float16, torch.bfloat16}:
        grad = grad.float()
    if grad.is_sparse:
        raise RuntimeError("Adafactor does not support sparse gradients.")

    state = self.state[p]
    grad_shape = grad.shape

    factored, use_first_moment = Adafactor._get_options(group, grad_shape)
    # State Initialization
    if len(state) == 0:
        state["step"] = 0

        if use_first_moment:
            # Exponential moving average of gradient values
            state["exp_avg"] = torch.zeros_like(grad)
        if factored:
            state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
            state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
        else:
            state["exp_avg_sq"] = torch.zeros_like(grad)

        state["RMS"] = 0
    else:
        if use_first_moment:
            state["exp_avg"] = state["exp_avg"].to(grad)
        if factored:
            state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
            state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
        else:
            state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)

    p_data_fp32 = p
    if p.dtype in {torch.float16, torch.bfloat16}:
        p_data_fp32 = p_data_fp32.float()

    state["step"] += 1
    state["RMS"] = Adafactor._rms(p_data_fp32)
    lr = Adafactor._get_lr(group, state)

    beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
    update = (grad ** 2) + group["eps"][0]
    if factored:
        exp_avg_sq_row = state["exp_avg_sq_row"]
        exp_avg_sq_col = state["exp_avg_sq_col"]

        exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
        exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))

        # Approximation of exponential moving average of square of gradient
        update = Adafactor._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
        update.mul_(grad)
    else:
        exp_avg_sq = state["exp_avg_sq"]

        exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
        update = exp_avg_sq.rsqrt().mul_(grad)

    update.div_((Adafactor._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
    update.mul_(lr)

    if use_first_moment:
        exp_avg = state["exp_avg"]
        exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
        update = exp_avg

    if group["weight_decay"] != 0:
        p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))

    p_data_fp32.add_(-update)

    if p.dtype in {torch.float16, torch.bfloat16}:
        p.copy_(p_data_fp32)


@torch.no_grad()
def adafactor_step(self, closure=None):
    """
    Performs a single optimization step

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    loss = None
    if closure is not None:
        loss = closure()

    for group in self.param_groups:
        for p in group["params"]:
            adafactor_step_param(self, p, group)

    return loss


def patch_adafactor_fused(optimizer: Adafactor):
    optimizer.step_param = adafactor_step_param.__get__(optimizer)
    optimizer.step = adafactor_step.__get__(optimizer)
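
Usage sketch (illustrative, not part of the uploaded files; the toy model and the Adafactor constructor arguments are assumptions): patch_adafactor_fused adds a per-parameter step_param, so the caller can apply each update as soon as that parameter's gradient exists and free the gradient immediately, lowering peak memory.

# Minimal sketch, assuming the transformers Adafactor defaults plus an explicit lr.
import torch
from transformers import Adafactor
from library.adafactor_fused import patch_adafactor_fused

model = torch.nn.Linear(128, 128)  # stand-in model for illustration
optimizer = Adafactor(model.parameters(), scale_parameter=False, relative_step=False, lr=1e-4)
patch_adafactor_fused(optimizer)  # adds optimizer.step_param and replaces optimizer.step

loss = model(torch.randn(4, 128)).pow(2).mean()
loss.backward()

# apply the update parameter-by-parameter and drop each gradient right away
for group in optimizer.param_groups:
    for p in group["params"]:
        optimizer.step_param(p, group)
        p.grad = None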
library/attention_processors.py
ADDED
@@ -0,0 +1,227 @@
import math
from typing import Any
from einops import rearrange
import torch
from diffusers.models.attention_processor import Attention


# flash attention forwards and backwards

# https://arxiv.org/abs/2205.14135

EPSILON = 1e-6


class FlashAttentionFunction(torch.autograd.function.Function):
    @staticmethod
    @torch.no_grad()
    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
        """Algorithm 2 in the paper"""

        device = q.device
        dtype = q.dtype
        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        o = torch.zeros_like(q)
        all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
        all_row_maxes = torch.full(
            (*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device
        )

        scale = q.shape[-1] ** -0.5

        if mask is None:
            mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
        else:
            mask = rearrange(mask, "b n -> b 1 1 n")
            mask = mask.split(q_bucket_size, dim=-1)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            mask,
            all_row_sums.split(q_bucket_size, dim=-2),
            all_row_maxes.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = (
                    torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
                )

                if row_mask is not None:
                    attn_weights.masked_fill_(~row_mask, max_neg_value)

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones(
                        (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
                    ).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
                attn_weights -= block_row_maxes
                exp_weights = torch.exp(attn_weights)

                if row_mask is not None:
                    exp_weights.masked_fill_(~row_mask, 0.0)

                block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(
                    min=EPSILON
                )

                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

                exp_values = torch.einsum(
                    "... i j, ... j d -> ... i d", exp_weights, vc
                )

                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
                exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)

                new_row_sums = (
                    exp_row_max_diff * row_sums
                    + exp_block_row_max_diff * block_row_sums
                )

                oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_(
                    (exp_block_row_max_diff / new_row_sums) * exp_values
                )

                row_maxes.copy_(new_row_maxes)
                row_sums.copy_(new_row_sums)

        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
        ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)

        return o

    @staticmethod
    @torch.no_grad()
    def backward(ctx, do):
        """Algorithm 4 in the paper"""

        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
        q, k, v, o, l, m = ctx.saved_tensors

        device = q.device

        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            do.split(q_bucket_size, dim=-2),
            mask,
            l.split(q_bucket_size, dim=-2),
            m.split(q_bucket_size, dim=-2),
            dq.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
                dk.split(k_bucket_size, dim=-2),
                dv.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = (
                    torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
                )

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones(
                        (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
                    ).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                exp_attn_weights = torch.exp(attn_weights - mc)

                if row_mask is not None:
                    exp_attn_weights.masked_fill_(~row_mask, 0.0)

                p = exp_attn_weights / lc

                dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
                dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)

                D = (doc * oc).sum(dim=-1, keepdims=True)
                ds = p * scale * (dp - D)

                dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
                dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)

                dqc.add_(dq_chunk)
                dkc.add_(dk_chunk)
                dvc.add_(dv_chunk)

        return dq, dk, dv, None, None, None, None


class FlashAttnProcessor:
    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
    ) -> Any:
        q_bucket_size = 512
        k_bucket_size = 1024

        h = attn.heads
        q = attn.to_q(hidden_states)

        encoder_hidden_states = (
            encoder_hidden_states
            if encoder_hidden_states is not None
            else hidden_states
        )
        encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype)

        if hasattr(attn, "hypernetwork") and attn.hypernetwork is not None:
            context_k, context_v = attn.hypernetwork.forward(
                hidden_states, encoder_hidden_states
            )
            context_k = context_k.to(hidden_states.dtype)
            context_v = context_v.to(hidden_states.dtype)
        else:
            context_k = encoder_hidden_states
            context_v = encoder_hidden_states

        k = attn.to_k(context_k)
        v = attn.to_v(context_v)
        del encoder_hidden_states, hidden_states

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))

        out = FlashAttentionFunction.apply(
            q, k, v, attention_mask, False, q_bucket_size, k_bucket_size
        )

        out = rearrange(out, "b h n d -> b n (h d)")

        out = attn.to_out[0](out)
        out = attn.to_out[1](out)
        return out
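
Wiring sketch (illustrative assumption, not from the upload; the checkpoint id is a placeholder): a diffusers UNet can route all of its Attention modules through FlashAttnProcessor by passing a single processor instance to set_attn_processor.

# Minimal sketch, assuming a standard Stable Diffusion UNet checkpoint is available locally or on the Hub.
from diffusers import UNet2DConditionModel
from library.attention_processors import FlashAttnProcessor

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"  # placeholder model id
)
unet.set_attn_processor(FlashAttnProcessor())  # one shared processor instance for every Attention block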
library/config_util.py
ADDED
@@ -0,0 +1,721 @@
import argparse
from dataclasses import (
    asdict,
    dataclass,
)
import functools
import random
from textwrap import dedent, indent
import json
from pathlib import Path

# from toolz import curry
from typing import (
    List,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import toml
import voluptuous
from voluptuous import (
    Any,
    ExactSequence,
    MultipleInvalid,
    Object,
    Required,
    Schema,
)
from transformers import CLIPTokenizer

from . import train_util
from .train_util import (
    DreamBoothSubset,
    FineTuningSubset,
    ControlNetSubset,
    DreamBoothDataset,
    FineTuningDataset,
    ControlNetDataset,
    DatasetGroup,
)
from .utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


def add_config_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル"
    )


# TODO: inherit Params class in Subset, Dataset


@dataclass
class BaseSubsetParams:
    image_dir: Optional[str] = None
    num_repeats: int = 1
    shuffle_caption: bool = False
    caption_separator: str = (",",)
    keep_tokens: int = 0
    keep_tokens_separator: str = (None,)
    secondary_separator: Optional[str] = None
    enable_wildcard: bool = False
    color_aug: bool = False
    flip_aug: bool = False
    face_crop_aug_range: Optional[Tuple[float, float]] = None
    random_crop: bool = False
    caption_prefix: Optional[str] = None
    caption_suffix: Optional[str] = None
    caption_dropout_rate: float = 0.0
    caption_dropout_every_n_epochs: int = 0
    caption_tag_dropout_rate: float = 0.0
    token_warmup_min: int = 1
    token_warmup_step: float = 0


@dataclass
class DreamBoothSubsetParams(BaseSubsetParams):
    is_reg: bool = False
    class_tokens: Optional[str] = None
    caption_extension: str = ".caption"
    cache_info: bool = False
    alpha_mask: bool = False


@dataclass
class FineTuningSubsetParams(BaseSubsetParams):
    metadata_file: Optional[str] = None
    alpha_mask: bool = False


@dataclass
class ControlNetSubsetParams(BaseSubsetParams):
    conditioning_data_dir: str = None
    caption_extension: str = ".caption"
    cache_info: bool = False


@dataclass
class BaseDatasetParams:
    tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None
    max_token_length: int = None
    resolution: Optional[Tuple[int, int]] = None
    network_multiplier: float = 1.0
    debug_dataset: bool = False


@dataclass
class DreamBoothDatasetParams(BaseDatasetParams):
    batch_size: int = 1
    enable_bucket: bool = False
    min_bucket_reso: int = 256
    max_bucket_reso: int = 1024
    bucket_reso_steps: int = 64
    bucket_no_upscale: bool = False
    prior_loss_weight: float = 1.0


@dataclass
class FineTuningDatasetParams(BaseDatasetParams):
    batch_size: int = 1
    enable_bucket: bool = False
    min_bucket_reso: int = 256
    max_bucket_reso: int = 1024
    bucket_reso_steps: int = 64
    bucket_no_upscale: bool = False


@dataclass
class ControlNetDatasetParams(BaseDatasetParams):
    batch_size: int = 1
    enable_bucket: bool = False
    min_bucket_reso: int = 256
    max_bucket_reso: int = 1024
    bucket_reso_steps: int = 64
    bucket_no_upscale: bool = False


@dataclass
class SubsetBlueprint:
    params: Union[DreamBoothSubsetParams, FineTuningSubsetParams]


@dataclass
class DatasetBlueprint:
    is_dreambooth: bool
    is_controlnet: bool
    params: Union[DreamBoothDatasetParams, FineTuningDatasetParams]
    subsets: Sequence[SubsetBlueprint]


@dataclass
class DatasetGroupBlueprint:
    datasets: Sequence[DatasetBlueprint]


@dataclass
class Blueprint:
    dataset_group: DatasetGroupBlueprint


class ConfigSanitizer:
    # @curry
    @staticmethod
    def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
        Schema(ExactSequence([klass, klass]))(value)
        return tuple(value)

    # @curry
    @staticmethod
    def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
        Schema(Any(klass, ExactSequence([klass, klass])))(value)
        try:
            Schema(klass)(value)
            return (value, value)
        except:
            return ConfigSanitizer.__validate_and_convert_twodim(klass, value)

    # subset schema
    SUBSET_ASCENDABLE_SCHEMA = {
        "color_aug": bool,
        "face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float),
        "flip_aug": bool,
        "num_repeats": int,
        "random_crop": bool,
        "shuffle_caption": bool,
        "keep_tokens": int,
        "keep_tokens_separator": str,
        "secondary_separator": str,
        "caption_separator": str,
        "enable_wildcard": bool,
        "token_warmup_min": int,
        "token_warmup_step": Any(float, int),
        "caption_prefix": str,
        "caption_suffix": str,
    }
    # DO means DropOut
    DO_SUBSET_ASCENDABLE_SCHEMA = {
        "caption_dropout_every_n_epochs": int,
        "caption_dropout_rate": Any(float, int),
        "caption_tag_dropout_rate": Any(float, int),
    }
    # DB means DreamBooth
    DB_SUBSET_ASCENDABLE_SCHEMA = {
        "caption_extension": str,
        "class_tokens": str,
        "cache_info": bool,
    }
    DB_SUBSET_DISTINCT_SCHEMA = {
        Required("image_dir"): str,
        "is_reg": bool,
        "alpha_mask": bool,
    }
    # FT means FineTuning
    FT_SUBSET_DISTINCT_SCHEMA = {
        Required("metadata_file"): str,
        "image_dir": str,
        "alpha_mask": bool,
    }
    CN_SUBSET_ASCENDABLE_SCHEMA = {
        "caption_extension": str,
        "cache_info": bool,
    }
    CN_SUBSET_DISTINCT_SCHEMA = {
        Required("image_dir"): str,
        Required("conditioning_data_dir"): str,
    }

    # datasets schema
    DATASET_ASCENDABLE_SCHEMA = {
        "batch_size": int,
        "bucket_no_upscale": bool,
        "bucket_reso_steps": int,
        "enable_bucket": bool,
        "max_bucket_reso": int,
        "min_bucket_reso": int,
        "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
        "network_multiplier": float,
    }

    # options handled by argparse but not handled by user config
    ARGPARSE_SPECIFIC_SCHEMA = {
        "debug_dataset": bool,
        "max_token_length": Any(None, int),
        "prior_loss_weight": Any(float, int),
    }
    # for handling default None value of argparse
    ARGPARSE_NULLABLE_OPTNAMES = [
        "face_crop_aug_range",
        "resolution",
    ]
    # prepare map because option name may differ among argparse and user config
    ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = {
        "train_batch_size": "batch_size",
        "dataset_repeats": "num_repeats",
    }

    def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
        assert support_dreambooth or support_finetuning or support_controlnet, (
            "Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more."
            + " / DreamBooth モードか fine tuning モードか controlnet モードのどれも指定されていません。1つ以上指定してください。"
        )

        self.db_subset_schema = self.__merge_dict(
            self.SUBSET_ASCENDABLE_SCHEMA,
            self.DB_SUBSET_DISTINCT_SCHEMA,
            self.DB_SUBSET_ASCENDABLE_SCHEMA,
            self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
        )

        self.ft_subset_schema = self.__merge_dict(
            self.SUBSET_ASCENDABLE_SCHEMA,
            self.FT_SUBSET_DISTINCT_SCHEMA,
            self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
        )

        self.cn_subset_schema = self.__merge_dict(
            self.SUBSET_ASCENDABLE_SCHEMA,
            self.CN_SUBSET_DISTINCT_SCHEMA,
            self.CN_SUBSET_ASCENDABLE_SCHEMA,
            self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
        )

        self.db_dataset_schema = self.__merge_dict(
            self.DATASET_ASCENDABLE_SCHEMA,
            self.SUBSET_ASCENDABLE_SCHEMA,
            self.DB_SUBSET_ASCENDABLE_SCHEMA,
            self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
            {"subsets": [self.db_subset_schema]},
        )

        self.ft_dataset_schema = self.__merge_dict(
            self.DATASET_ASCENDABLE_SCHEMA,
            self.SUBSET_ASCENDABLE_SCHEMA,
            self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
            {"subsets": [self.ft_subset_schema]},
        )

        self.cn_dataset_schema = self.__merge_dict(
            self.DATASET_ASCENDABLE_SCHEMA,
            self.SUBSET_ASCENDABLE_SCHEMA,
            self.CN_SUBSET_ASCENDABLE_SCHEMA,
            self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
            {"subsets": [self.cn_subset_schema]},
        )

        if support_dreambooth and support_finetuning:

            def validate_flex_dataset(dataset_config: dict):
                subsets_config = dataset_config.get("subsets", [])

                if support_controlnet and all(["conditioning_data_dir" in subset for subset in subsets_config]):
                    return Schema(self.cn_dataset_schema)(dataset_config)
                # check dataset meets FT style
                # NOTE: all FT subsets should have "metadata_file"
                elif all(["metadata_file" in subset for subset in subsets_config]):
                    return Schema(self.ft_dataset_schema)(dataset_config)
                # check dataset meets DB style
                # NOTE: all DB subsets should have no "metadata_file"
                elif all(["metadata_file" not in subset for subset in subsets_config]):
                    return Schema(self.db_dataset_schema)(dataset_config)
                else:
                    raise voluptuous.Invalid(
                        "DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuninのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。"
                    )

            self.dataset_schema = validate_flex_dataset
        elif support_dreambooth:
            if support_controlnet:
                self.dataset_schema = self.cn_dataset_schema
            else:
                self.dataset_schema = self.db_dataset_schema
        elif support_finetuning:
            self.dataset_schema = self.ft_dataset_schema
        elif support_controlnet:
            self.dataset_schema = self.cn_dataset_schema

        self.general_schema = self.__merge_dict(
            self.DATASET_ASCENDABLE_SCHEMA,
            self.SUBSET_ASCENDABLE_SCHEMA,
            self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {},
            self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {},
            self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
        )

        self.user_config_validator = Schema(
            {
                "general": self.general_schema,
                "datasets": [self.dataset_schema],
            }
        )

        self.argparse_schema = self.__merge_dict(
            self.general_schema,
            self.ARGPARSE_SPECIFIC_SCHEMA,
            {optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES},
            {a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()},
        )

        self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)

    def sanitize_user_config(self, user_config: dict) -> dict:
        try:
            return self.user_config_validator(user_config)
        except MultipleInvalid:
            # TODO: エラー発生時のメッセージをわかりやすくする
            logger.error("Invalid user config / ユーザ設定の形式が正しくないようです")
            raise

    # NOTE: In nature, argument parser result is not needed to be sanitize
    # However this will help us to detect program bug
    def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
        try:
            return self.argparse_config_validator(argparse_namespace)
        except MultipleInvalid:
            # XXX: this should be a bug
            logger.error(
                "Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
            )
            raise

    # NOTE: value would be overwritten by latter dict if there is already the same key
    @staticmethod
    def __merge_dict(*dict_list: dict) -> dict:
        merged = {}
        for schema in dict_list:
            # merged |= schema
            for k, v in schema.items():
                merged[k] = v
        return merged


class BlueprintGenerator:
    BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {}

    def __init__(self, sanitizer: ConfigSanitizer):
        self.sanitizer = sanitizer

    # runtime_params is for parameters which is only configurable on runtime, such as tokenizer
    def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
        sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
        sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)

        # convert argparse namespace to dict like config
        # NOTE: it is ok to have extra entries in dict
        optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME
        argparse_config = {
            optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items()
        }

        general_config = sanitized_user_config.get("general", {})

        dataset_blueprints = []
        for dataset_config in sanitized_user_config.get("datasets", []):
            # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets
            subsets = dataset_config.get("subsets", [])
            is_dreambooth = all(["metadata_file" not in subset for subset in subsets])
            is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets])
            if is_controlnet:
                subset_params_klass = ControlNetSubsetParams
                dataset_params_klass = ControlNetDatasetParams
            elif is_dreambooth:
                subset_params_klass = DreamBoothSubsetParams
                dataset_params_klass = DreamBoothDatasetParams
            else:
                subset_params_klass = FineTuningSubsetParams
                dataset_params_klass = FineTuningDatasetParams

            subset_blueprints = []
            for subset_config in subsets:
                params = self.generate_params_by_fallbacks(
                    subset_params_klass, [subset_config, dataset_config, general_config, argparse_config, runtime_params]
                )
                subset_blueprints.append(SubsetBlueprint(params))

            params = self.generate_params_by_fallbacks(
                dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params]
            )
            dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints))

        dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)

        return Blueprint(dataset_group_blueprint)

    @staticmethod
    def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
        name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
        search_value = BlueprintGenerator.search_value
        default_params = asdict(param_klass())
        param_names = default_params.keys()

        params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names}

        return param_klass(**params)

    @staticmethod
    def search_value(key: str, fallbacks: Sequence[dict], default_value=None):
        for cand in fallbacks:
            value = cand.get(key)
            if value is not None:
                return value

        return default_value


def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
    datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []

    for dataset_blueprint in dataset_group_blueprint.datasets:
        if dataset_blueprint.is_controlnet:
            subset_klass = ControlNetSubset
            dataset_klass = ControlNetDataset
        elif dataset_blueprint.is_dreambooth:
            subset_klass = DreamBoothSubset
            dataset_klass = DreamBoothDataset
        else:
            subset_klass = FineTuningSubset
            dataset_klass = FineTuningDataset

        subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
        dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
        datasets.append(dataset)

    # print info
    info = ""
    for i, dataset in enumerate(datasets):
        is_dreambooth = isinstance(dataset, DreamBoothDataset)
        is_controlnet = isinstance(dataset, ControlNetDataset)
        info += dedent(
            f"""\
            [Dataset {i}]
              batch_size: {dataset.batch_size}
              resolution: {(dataset.width, dataset.height)}
              enable_bucket: {dataset.enable_bucket}
              network_multiplier: {dataset.network_multiplier}
            """
        )

        if dataset.enable_bucket:
            info += indent(
                dedent(
                    f"""\
                    min_bucket_reso: {dataset.min_bucket_reso}
                    max_bucket_reso: {dataset.max_bucket_reso}
                    bucket_reso_steps: {dataset.bucket_reso_steps}
                    bucket_no_upscale: {dataset.bucket_no_upscale}
                    \n"""
                ),
                "  ",
            )
        else:
            info += "\n"

        for j, subset in enumerate(dataset.subsets):
            info += indent(
                dedent(
                    f"""\
                    [Subset {j} of Dataset {i}]
                      image_dir: "{subset.image_dir}"
                      image_count: {subset.img_count}
                      num_repeats: {subset.num_repeats}
                      shuffle_caption: {subset.shuffle_caption}
                      keep_tokens: {subset.keep_tokens}
                      keep_tokens_separator: {subset.keep_tokens_separator}
                      caption_separator: {subset.caption_separator}
                      secondary_separator: {subset.secondary_separator}
                      enable_wildcard: {subset.enable_wildcard}
                      caption_dropout_rate: {subset.caption_dropout_rate}
                      caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
                      caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
                      caption_prefix: {subset.caption_prefix}
                      caption_suffix: {subset.caption_suffix}
                      color_aug: {subset.color_aug}
                      flip_aug: {subset.flip_aug}
                      face_crop_aug_range: {subset.face_crop_aug_range}
                      random_crop: {subset.random_crop}
                      token_warmup_min: {subset.token_warmup_min},
                      token_warmup_step: {subset.token_warmup_step},
                      alpha_mask: {subset.alpha_mask},
                    """
                ),
                "  ",
            )

            if is_dreambooth:
                info += indent(
                    dedent(
                        f"""\
                        is_reg: {subset.is_reg}
                        class_tokens: {subset.class_tokens}
                        caption_extension: {subset.caption_extension}
                        \n"""
                    ),
                    "  ",
                )
            elif not is_controlnet:
                info += indent(
                    dedent(
                        f"""\
                        metadata_file: {subset.metadata_file}
                        \n"""
                    ),
                    "  ",
                )

    logger.info(f"{info}")

    # make buckets first because it determines the length of dataset
    # and set the same seed for all datasets
    seed = random.randint(0, 2**31)  # actual seed is seed + epoch_no
    for i, dataset in enumerate(datasets):
        logger.info(f"[Dataset {i}]")
        dataset.make_buckets()
        dataset.set_seed(seed)

    return DatasetGroup(datasets)


def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
    def extract_dreambooth_params(name: str) -> Tuple[int, str]:
        tokens = name.split("_")
        try:
            n_repeats = int(tokens[0])
        except ValueError as e:
            logger.warning(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}")
            return 0, ""
        caption_by_folder = "_".join(tokens[1:])
        return n_repeats, caption_by_folder

    def generate(base_dir: Optional[str], is_reg: bool):
        if base_dir is None:
            return []

        base_dir: Path = Path(base_dir)
        if not base_dir.is_dir():
            return []

        subsets_config = []
        for subdir in base_dir.iterdir():
            if not subdir.is_dir():
                continue

            num_repeats, class_tokens = extract_dreambooth_params(subdir.name)
            if num_repeats < 1:
                continue

            subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens}
            subsets_config.append(subset_config)

        return subsets_config

    subsets_config = []
    subsets_config += generate(train_data_dir, False)
    subsets_config += generate(reg_data_dir, True)

    return subsets_config


def generate_controlnet_subsets_config_by_subdirs(
    train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt"
):
    def generate(base_dir: Optional[str]):
        if base_dir is None:
            return []

        base_dir: Path = Path(base_dir)
        if not base_dir.is_dir():
            return []

        subsets_config = []
        subset_config = {
            "image_dir": train_data_dir,
            "conditioning_data_dir": conditioning_data_dir,
            "caption_extension": caption_extension,
            "num_repeats": 1,
        }
        subsets_config.append(subset_config)

        return subsets_config

    subsets_config = []
    subsets_config += generate(train_data_dir)

    return subsets_config


def load_user_config(file: str) -> dict:
    file: Path = Path(file)
    if not file.is_file():
        raise ValueError(f"file not found / ファイルが見つかりません: {file}")

    if file.name.lower().endswith(".json"):
        try:
            with open(file, "r") as f:
                config = json.load(f)
        except Exception:
            logger.error(
                f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
            )
            raise
    elif file.name.lower().endswith(".toml"):
        try:
            config = toml.load(file)
        except Exception:
            logger.error(
                f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
            )
            raise
    else:
        raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")

    return config


# for config test
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--support_dreambooth", action="store_true")
    parser.add_argument("--support_finetuning", action="store_true")
    parser.add_argument("--support_controlnet", action="store_true")
    parser.add_argument("--support_dropout", action="store_true")
    parser.add_argument("dataset_config")
    config_args, remain = parser.parse_known_args()

    parser = argparse.ArgumentParser()
    train_util.add_dataset_arguments(
        parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout
    )
    train_util.add_training_arguments(parser, config_args.support_dreambooth)
    argparse_namespace = parser.parse_args(remain)
    train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)

    logger.info("[argparse_namespace]")
    logger.info(f"{vars(argparse_namespace)}")

    user_config = load_user_config(config_args.dataset_config)

    logger.info("")
    logger.info("[user_config]")
    logger.info(f"{user_config}")

    sanitizer = ConfigSanitizer(
        config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout
    )
    sanitized_user_config = sanitizer.sanitize_user_config(user_config)

    logger.info("")
    logger.info("[sanitized_user_config]")
    logger.info(f"{sanitized_user_config}")

    blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)

    logger.info("")
    logger.info("[blueprint]")
    logger.info(f"{blueprint}")
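
Illustrative sketch (assumptions: the inline config dict, the empty argparse namespace, the example paths and class tokens): how ConfigSanitizer and BlueprintGenerator are chained. Real training scripts pass the full parsed arguments plus a tokenizer, then hand blueprint.dataset_group to generate_dataset_group_by_blueprint to build the actual datasets.

# Minimal sketch, assuming the library package is importable; values below are placeholders.
import argparse

from library.config_util import BlueprintGenerator, ConfigSanitizer

user_config = {
    "general": {"enable_bucket": True, "resolution": 512},
    "datasets": [
        {
            "subsets": [
                {"image_dir": "/path/to/train_images", "num_repeats": 10, "class_tokens": "sks girl"}
            ]
        }
    ],
}

args = argparse.Namespace()  # stand-in for the real parsed training arguments
sanitizer = ConfigSanitizer(True, True, False, True)  # DreamBooth + fine tuning, no ControlNet, dropout supported
blueprint = BlueprintGenerator(sanitizer).generate(user_config, args, tokenizer=None)
print(blueprint.dataset_group.datasets[0].subsets[0].params.num_repeats)  # -> 10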
library/custom_train_functions.py
ADDED
@@ -0,0 +1,559 @@
import torch
import argparse
import random
import re
from typing import List, Optional, Union
from .utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


def prepare_scheduler_for_custom_training(noise_scheduler, device):
    if hasattr(noise_scheduler, "all_snr"):
        return

    alphas_cumprod = noise_scheduler.alphas_cumprod
    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
    alpha = sqrt_alphas_cumprod
    sigma = sqrt_one_minus_alphas_cumprod
    all_snr = (alpha / sigma) ** 2

    noise_scheduler.all_snr = all_snr.to(device)


def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
    # fix beta: zero terminal SNR
    logger.info(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")

    def enforce_zero_terminal_snr(betas):
        # Convert betas to alphas_bar_sqrt
        alphas = 1 - betas
        alphas_bar = alphas.cumprod(0)
        alphas_bar_sqrt = alphas_bar.sqrt()

        # Store old values.
        alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
        alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
        # Shift so last timestep is zero.
        alphas_bar_sqrt -= alphas_bar_sqrt_T
        # Scale so first timestep is back to old value.
        alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

        # Convert alphas_bar_sqrt to betas
        alphas_bar = alphas_bar_sqrt**2
        alphas = alphas_bar[1:] / alphas_bar[:-1]
        alphas = torch.cat([alphas_bar[0:1], alphas])
        betas = 1 - alphas
        return betas

    betas = noise_scheduler.betas
    betas = enforce_zero_terminal_snr(betas)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)

    # logger.info(f"original: {noise_scheduler.betas}")
    # logger.info(f"fixed: {betas}")

    noise_scheduler.betas = betas
    noise_scheduler.alphas = alphas
    noise_scheduler.alphas_cumprod = alphas_cumprod


def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
    snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
    min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
    if v_prediction:
        snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
    else:
        snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
    loss = loss * snr_weight
    return loss


def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
    scale = get_snr_scale(timesteps, noise_scheduler)
    loss = loss * scale
    return loss


def get_snr_scale(timesteps, noise_scheduler):
    snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])  # batch_size
    snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)  # if timestep is 0, snr_t is inf, so limit it to 1000
    scale = snr_t / (snr_t + 1)
    # # show debug info
    # logger.info(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
    return scale


def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
    scale = get_snr_scale(timesteps, noise_scheduler)
    # logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
    loss = loss + loss / scale * v_pred_like_loss
    return loss


def apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False):
    snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])  # batch_size
    snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)  # if timestep is 0, snr_t is inf, so limit it to 1000
    if v_prediction:
        weight = 1 / (snr_t + 1)
    else:
        weight = 1 / torch.sqrt(snr_t)
    loss = weight * loss
    return loss


# TODO train_utilと分散しているのでどちらかに寄せる


def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
    parser.add_argument(
        "--min_snr_gamma",
        type=float,
        default=None,
        help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
    )
    parser.add_argument(
        "--scale_v_pred_loss_like_noise_pred",
        action="store_true",
        help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
    )
    parser.add_argument(
        "--v_pred_like_loss",
        type=float,
        default=None,
        help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する",
    )
    parser.add_argument(
        "--debiased_estimation_loss",
        action="store_true",
        help="debiased estimation loss / debiased estimation loss",
    )
    if support_weighted_captions:
        parser.add_argument(
            "--weighted_captions",
            action="store_true",
            default=False,
            help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
        )


re_attention = re.compile(
    r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
    re.X,
)


def parse_prompt_attention(text):
    """
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text
    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """

    res = []
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1
+
def multiply_range(start_position, multiplier):
|
| 208 |
+
for p in range(start_position, len(res)):
|
| 209 |
+
res[p][1] *= multiplier
|
| 210 |
+
|
| 211 |
+
for m in re_attention.finditer(text):
|
| 212 |
+
text = m.group(0)
|
| 213 |
+
weight = m.group(1)
|
| 214 |
+
|
| 215 |
+
if text.startswith("\\"):
|
| 216 |
+
res.append([text[1:], 1.0])
|
| 217 |
+
elif text == "(":
|
| 218 |
+
round_brackets.append(len(res))
|
| 219 |
+
elif text == "[":
|
| 220 |
+
square_brackets.append(len(res))
|
| 221 |
+
elif weight is not None and len(round_brackets) > 0:
|
| 222 |
+
multiply_range(round_brackets.pop(), float(weight))
|
| 223 |
+
elif text == ")" and len(round_brackets) > 0:
|
| 224 |
+
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
| 225 |
+
elif text == "]" and len(square_brackets) > 0:
|
| 226 |
+
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
| 227 |
+
else:
|
| 228 |
+
res.append([text, 1.0])
|
| 229 |
+
|
| 230 |
+
for pos in round_brackets:
|
| 231 |
+
multiply_range(pos, round_bracket_multiplier)
|
| 232 |
+
|
| 233 |
+
for pos in square_brackets:
|
| 234 |
+
multiply_range(pos, square_bracket_multiplier)
|
| 235 |
+
|
| 236 |
+
if len(res) == 0:
|
| 237 |
+
res = [["", 1.0]]
|
| 238 |
+
|
| 239 |
+
# merge runs of identical weights
|
| 240 |
+
i = 0
|
| 241 |
+
while i + 1 < len(res):
|
| 242 |
+
if res[i][1] == res[i + 1][1]:
|
| 243 |
+
res[i][0] += res[i + 1][0]
|
| 244 |
+
res.pop(i + 1)
|
| 245 |
+
else:
|
| 246 |
+
i += 1
|
| 247 |
+
|
| 248 |
+
return res
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
|
| 252 |
+
r"""
|
| 253 |
+
Tokenize a list of prompts and return its tokens with weights of each token.
|
| 254 |
+
|
| 255 |
+
No padding, starting or ending token is included.
|
| 256 |
+
"""
|
| 257 |
+
tokens = []
|
| 258 |
+
weights = []
|
| 259 |
+
truncated = False
|
| 260 |
+
for text in prompt:
|
| 261 |
+
texts_and_weights = parse_prompt_attention(text)
|
| 262 |
+
text_token = []
|
| 263 |
+
text_weight = []
|
| 264 |
+
for word, weight in texts_and_weights:
|
| 265 |
+
# tokenize and discard the starting and the ending token
|
| 266 |
+
token = tokenizer(word).input_ids[1:-1]
|
| 267 |
+
text_token += token
|
| 268 |
+
# copy the weight by length of token
|
| 269 |
+
text_weight += [weight] * len(token)
|
| 270 |
+
# stop if the text is too long (longer than truncation limit)
|
| 271 |
+
if len(text_token) > max_length:
|
| 272 |
+
truncated = True
|
| 273 |
+
break
|
| 274 |
+
# truncate
|
| 275 |
+
if len(text_token) > max_length:
|
| 276 |
+
truncated = True
|
| 277 |
+
text_token = text_token[:max_length]
|
| 278 |
+
text_weight = text_weight[:max_length]
|
| 279 |
+
tokens.append(text_token)
|
| 280 |
+
weights.append(text_weight)
|
| 281 |
+
if truncated:
|
| 282 |
+
logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
|
| 283 |
+
return tokens, weights
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
|
| 287 |
+
r"""
|
| 288 |
+
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
|
| 289 |
+
"""
|
| 290 |
+
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
|
| 291 |
+
weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
|
| 292 |
+
for i in range(len(tokens)):
|
| 293 |
+
tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
|
| 294 |
+
if no_boseos_middle:
|
| 295 |
+
weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
|
| 296 |
+
else:
|
| 297 |
+
w = []
|
| 298 |
+
if len(weights[i]) == 0:
|
| 299 |
+
w = [1.0] * weights_length
|
| 300 |
+
else:
|
| 301 |
+
for j in range(max_embeddings_multiples):
|
| 302 |
+
w.append(1.0) # weight for starting token in this chunk
|
| 303 |
+
w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
|
| 304 |
+
w.append(1.0) # weight for ending token in this chunk
|
| 305 |
+
w += [1.0] * (weights_length - len(w))
|
| 306 |
+
weights[i] = w[:]
|
| 307 |
+
|
| 308 |
+
return tokens, weights
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def get_unweighted_text_embeddings(
|
| 312 |
+
tokenizer,
|
| 313 |
+
text_encoder,
|
| 314 |
+
text_input: torch.Tensor,
|
| 315 |
+
chunk_length: int,
|
| 316 |
+
clip_skip: int,
|
| 317 |
+
eos: int,
|
| 318 |
+
pad: int,
|
| 319 |
+
no_boseos_middle: Optional[bool] = True,
|
| 320 |
+
):
|
| 321 |
+
"""
|
| 322 |
+
When the length of tokens is a multiple of the capacity of the text encoder,
|
| 323 |
+
it should be split into chunks and sent to the text encoder individually.
|
| 324 |
+
"""
|
| 325 |
+
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
|
| 326 |
+
if max_embeddings_multiples > 1:
|
| 327 |
+
text_embeddings = []
|
| 328 |
+
for i in range(max_embeddings_multiples):
|
| 329 |
+
# extract the i-th chunk
|
| 330 |
+
text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
|
| 331 |
+
|
| 332 |
+
# cover the head and the tail by the starting and the ending tokens
|
| 333 |
+
text_input_chunk[:, 0] = text_input[0, 0]
|
| 334 |
+
if pad == eos: # v1
|
| 335 |
+
text_input_chunk[:, -1] = text_input[0, -1]
|
| 336 |
+
else: # v2
|
| 337 |
+
for j in range(len(text_input_chunk)):
|
| 338 |
+
if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある
|
| 339 |
+
text_input_chunk[j, -1] = eos
|
| 340 |
+
if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD
|
| 341 |
+
text_input_chunk[j, 1] = eos
|
| 342 |
+
|
| 343 |
+
if clip_skip is None or clip_skip == 1:
|
| 344 |
+
text_embedding = text_encoder(text_input_chunk)[0]
|
| 345 |
+
else:
|
| 346 |
+
enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
|
| 347 |
+
text_embedding = enc_out["hidden_states"][-clip_skip]
|
| 348 |
+
text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
|
| 349 |
+
|
| 350 |
+
if no_boseos_middle:
|
| 351 |
+
if i == 0:
|
| 352 |
+
# discard the ending token
|
| 353 |
+
text_embedding = text_embedding[:, :-1]
|
| 354 |
+
elif i == max_embeddings_multiples - 1:
|
| 355 |
+
# discard the starting token
|
| 356 |
+
text_embedding = text_embedding[:, 1:]
|
| 357 |
+
else:
|
| 358 |
+
# discard both starting and ending tokens
|
| 359 |
+
text_embedding = text_embedding[:, 1:-1]
|
| 360 |
+
|
| 361 |
+
text_embeddings.append(text_embedding)
|
| 362 |
+
text_embeddings = torch.concat(text_embeddings, axis=1)
|
| 363 |
+
else:
|
| 364 |
+
if clip_skip is None or clip_skip == 1:
|
| 365 |
+
text_embeddings = text_encoder(text_input)[0]
|
| 366 |
+
else:
|
| 367 |
+
enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
|
| 368 |
+
text_embeddings = enc_out["hidden_states"][-clip_skip]
|
| 369 |
+
text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
|
| 370 |
+
return text_embeddings
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
def get_weighted_text_embeddings(
|
| 374 |
+
tokenizer,
|
| 375 |
+
text_encoder,
|
| 376 |
+
prompt: Union[str, List[str]],
|
| 377 |
+
device,
|
| 378 |
+
max_embeddings_multiples: Optional[int] = 3,
|
| 379 |
+
no_boseos_middle: Optional[bool] = False,
|
| 380 |
+
clip_skip=None,
|
| 381 |
+
):
|
| 382 |
+
r"""
|
| 383 |
+
Prompts can be assigned with local weights using brackets. For example,
|
| 384 |
+
prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
|
| 385 |
+
and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
|
| 386 |
+
|
| 387 |
+
Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
|
| 388 |
+
|
| 389 |
+
Args:
|
| 390 |
+
prompt (`str` or `List[str]`):
|
| 391 |
+
The prompt or prompts to guide the image generation.
|
| 392 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
| 393 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
| 394 |
+
no_boseos_middle (`bool`, *optional*, defaults to `False`):
|
| 395 |
+
If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
|
| 396 |
+
ending token in each of the chunk in the middle.
|
| 397 |
+
skip_parsing (`bool`, *optional*, defaults to `False`):
|
| 398 |
+
Skip the parsing of brackets.
|
| 399 |
+
skip_weighting (`bool`, *optional*, defaults to `False`):
|
| 400 |
+
Skip the weighting. When the parsing is skipped, it is forced True.
|
| 401 |
+
"""
|
| 402 |
+
max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
| 403 |
+
if isinstance(prompt, str):
|
| 404 |
+
prompt = [prompt]
|
| 405 |
+
|
| 406 |
+
prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2)
|
| 407 |
+
|
| 408 |
+
# round up the longest length of tokens to a multiple of (model_max_length - 2)
|
| 409 |
+
max_length = max([len(token) for token in prompt_tokens])
|
| 410 |
+
|
| 411 |
+
max_embeddings_multiples = min(
|
| 412 |
+
max_embeddings_multiples,
|
| 413 |
+
(max_length - 1) // (tokenizer.model_max_length - 2) + 1,
|
| 414 |
+
)
|
| 415 |
+
max_embeddings_multiples = max(1, max_embeddings_multiples)
|
| 416 |
+
max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
| 417 |
+
|
| 418 |
+
# pad the length of tokens and weights
|
| 419 |
+
bos = tokenizer.bos_token_id
|
| 420 |
+
eos = tokenizer.eos_token_id
|
| 421 |
+
pad = tokenizer.pad_token_id
|
| 422 |
+
prompt_tokens, prompt_weights = pad_tokens_and_weights(
|
| 423 |
+
prompt_tokens,
|
| 424 |
+
prompt_weights,
|
| 425 |
+
max_length,
|
| 426 |
+
bos,
|
| 427 |
+
eos,
|
| 428 |
+
no_boseos_middle=no_boseos_middle,
|
| 429 |
+
chunk_length=tokenizer.model_max_length,
|
| 430 |
+
)
|
| 431 |
+
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device)
|
| 432 |
+
|
| 433 |
+
# get the embeddings
|
| 434 |
+
text_embeddings = get_unweighted_text_embeddings(
|
| 435 |
+
tokenizer,
|
| 436 |
+
text_encoder,
|
| 437 |
+
prompt_tokens,
|
| 438 |
+
tokenizer.model_max_length,
|
| 439 |
+
clip_skip,
|
| 440 |
+
eos,
|
| 441 |
+
pad,
|
| 442 |
+
no_boseos_middle=no_boseos_middle,
|
| 443 |
+
)
|
| 444 |
+
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device)
|
| 445 |
+
|
| 446 |
+
# assign weights to the prompts and normalize in the sense of mean
|
| 447 |
+
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
| 448 |
+
text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1)
|
| 449 |
+
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
| 450 |
+
text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
| 451 |
+
|
| 452 |
+
return text_embeddings
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
|
| 456 |
+
def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
|
| 457 |
+
b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant!
|
| 458 |
+
u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
|
| 459 |
+
for i in range(iterations):
|
| 460 |
+
r = random.random() * 2 + 2 # Rather than always going 2x,
|
| 461 |
+
wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
|
| 462 |
+
noise += u(torch.randn(b, c, wn, hn).to(device)) * discount**i
|
| 463 |
+
if wn == 1 or hn == 1:
|
| 464 |
+
break # Lowest resolution is 1x1
|
| 465 |
+
return noise / noise.std() # Scaled back to roughly unit variance
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
| 469 |
+
def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
|
| 470 |
+
if noise_offset is None:
|
| 471 |
+
return noise
|
| 472 |
+
if adaptive_noise_scale is not None:
|
| 473 |
+
# latent shape: (batch_size, channels, height, width)
|
| 474 |
+
# abs mean value for each channel
|
| 475 |
+
latent_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True))
|
| 476 |
+
|
| 477 |
+
# multiply adaptive noise scale to the mean value and add it to the noise offset
|
| 478 |
+
noise_offset = noise_offset + adaptive_noise_scale * latent_mean
|
| 479 |
+
noise_offset = torch.clamp(noise_offset, 0.0, None) # in case of adaptive noise scale is negative
|
| 480 |
+
|
| 481 |
+
noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
| 482 |
+
return noise
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
def apply_masked_loss(loss, batch):
|
| 486 |
+
if "conditioning_images" in batch:
|
| 487 |
+
# conditioning image is -1 to 1. we need to convert it to 0 to 1
|
| 488 |
+
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
|
| 489 |
+
mask_image = mask_image / 2 + 0.5
|
| 490 |
+
# print(f"conditioning_image: {mask_image.shape}")
|
| 491 |
+
elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
|
| 492 |
+
# alpha mask is 0 to 1
|
| 493 |
+
mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension
|
| 494 |
+
# print(f"mask_image: {mask_image.shape}, {mask_image.mean()}")
|
| 495 |
+
else:
|
| 496 |
+
return loss
|
| 497 |
+
|
| 498 |
+
# resize to the same size as the loss
|
| 499 |
+
mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
|
| 500 |
+
loss = loss * mask_image
|
| 501 |
+
return loss
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
"""
|
| 505 |
+
##########################################
|
| 506 |
+
# Perlin Noise
|
| 507 |
+
def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
|
| 508 |
+
delta = (res[0] / shape[0], res[1] / shape[1])
|
| 509 |
+
d = (shape[0] // res[0], shape[1] // res[1])
|
| 510 |
+
|
| 511 |
+
grid = (
|
| 512 |
+
torch.stack(
|
| 513 |
+
torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)),
|
| 514 |
+
dim=-1,
|
| 515 |
+
)
|
| 516 |
+
% 1
|
| 517 |
+
)
|
| 518 |
+
angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
|
| 519 |
+
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
|
| 520 |
+
|
| 521 |
+
tile_grads = (
|
| 522 |
+
lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
|
| 523 |
+
.repeat_interleave(d[0], 0)
|
| 524 |
+
.repeat_interleave(d[1], 1)
|
| 525 |
+
)
|
| 526 |
+
dot = lambda grad, shift: (
|
| 527 |
+
torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1)
|
| 528 |
+
* grad[: shape[0], : shape[1]]
|
| 529 |
+
).sum(dim=-1)
|
| 530 |
+
|
| 531 |
+
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
|
| 532 |
+
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
|
| 533 |
+
n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
|
| 534 |
+
n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
|
| 535 |
+
t = fade(grid[: shape[0], : shape[1]])
|
| 536 |
+
return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
|
| 540 |
+
noise = torch.zeros(shape, device=device)
|
| 541 |
+
frequency = 1
|
| 542 |
+
amplitude = 1
|
| 543 |
+
for _ in range(octaves):
|
| 544 |
+
noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1]))
|
| 545 |
+
frequency *= 2
|
| 546 |
+
amplitude *= persistence
|
| 547 |
+
return noise
|
| 548 |
+
|
| 549 |
+
|
| 550 |
+
def perlin_noise(noise, device, octaves):
|
| 551 |
+
_, c, w, h = noise.shape
|
| 552 |
+
perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
|
| 553 |
+
noise_perlin = []
|
| 554 |
+
for _ in range(c):
|
| 555 |
+
noise_perlin.append(perlin())
|
| 556 |
+
noise_perlin = torch.stack(noise_perlin).unsqueeze(0) # (1, c, w, h)
|
| 557 |
+
noise += noise_perlin # broadcast for each batch
|
| 558 |
+
return noise / noise.std() # Scaled back to roughly unit variance
|
| 559 |
+
"""
|
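Note: the following is an illustrative sketch (not part of the uploaded file) of how the min-SNR helper above is usually wired into a training step; `model_pred`, `target`, `timesteps` and a `noise_scheduler` carrying a precomputed `all_snr` tensor are assumed to exist in the caller.

# Sketch only: mirrors the typical calling pattern of apply_snr_weight, assuming
# noise_scheduler.all_snr was precomputed elsewhere (hypothetical helper name).
def example_min_snr_loss(model_pred, target, timesteps, noise_scheduler, gamma=5.0, v_prediction=False):
    loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="none")
    loss = loss.mean(dim=list(range(1, loss.dim())))  # reduce to one value per sample
    loss = apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction)  # weight = min(SNR, gamma) / SNR
    return loss.mean()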
library/deepspeed_utils.py
ADDED
@@ -0,0 +1,139 @@
import os
import argparse
import torch
from accelerate import DeepSpeedPlugin, Accelerator

from .utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


def add_deepspeed_arguments(parser: argparse.ArgumentParser):
    # DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed
    parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training")
    parser.add_argument("--zero_stage", type=int, default=2, choices=[0, 1, 2, 3], help="Possible options are 0,1,2,3.")
    parser.add_argument(
        "--offload_optimizer_device",
        type=str,
        default=None,
        choices=[None, "cpu", "nvme"],
        help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.",
    )
    parser.add_argument(
        "--offload_optimizer_nvme_path",
        type=str,
        default=None,
        help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
    )
    parser.add_argument(
        "--offload_param_device",
        type=str,
        default=None,
        choices=[None, "cpu", "nvme"],
        help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.",
    )
    parser.add_argument(
        "--offload_param_nvme_path",
        type=str,
        default=None,
        help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
    )
    parser.add_argument(
        "--zero3_init_flag",
        action="store_true",
        help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models."
        "Only applicable with ZeRO Stage-3.",
    )
    parser.add_argument(
        "--zero3_save_16bit_model",
        action="store_true",
        help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.",
    )
    parser.add_argument(
        "--fp16_master_weights_and_gradients",
        action="store_true",
        help="fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32.",
    )


def prepare_deepspeed_args(args: argparse.Namespace):
    if not args.deepspeed:
        return

    # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
    args.max_data_loader_n_workers = 1


def prepare_deepspeed_plugin(args: argparse.Namespace):
    if not args.deepspeed:
        return None

    try:
        import deepspeed
    except ImportError as e:
        logger.error(
            "deepspeed is not installed. please install deepspeed in your environment with following command. DS_BUILD_OPS=0 pip install deepspeed"
        )
        exit(1)

    deepspeed_plugin = DeepSpeedPlugin(
        zero_stage=args.zero_stage,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        gradient_clipping=args.max_grad_norm,
        offload_optimizer_device=args.offload_optimizer_device,
        offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
        offload_param_device=args.offload_param_device,
        offload_param_nvme_path=args.offload_param_nvme_path,
        zero3_init_flag=args.zero3_init_flag,
        zero3_save_16bit_model=args.zero3_save_16bit_model,
    )
    deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
    deepspeed_plugin.deepspeed_config["train_batch_size"] = (
        args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
    )
    deepspeed_plugin.set_mixed_precision(args.mixed_precision)
    if args.mixed_precision.lower() == "fp16":
        deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0  # preventing overflow.
    if args.full_fp16 or args.fp16_master_weights_and_gradients:
        if args.offload_optimizer_device == "cpu" and args.zero_stage == 2:
            deepspeed_plugin.deepspeed_config["fp16"]["fp16_master_weights_and_grads"] = True
            logger.info("[DeepSpeed] full fp16 enable.")
        else:
            logger.info(
                "[DeepSpeed]full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO-2 stage."
            )

    if args.offload_optimizer_device is not None:
        logger.info("[DeepSpeed] start to manually build cpu_adam.")
        deepspeed.ops.op_builder.CPUAdamBuilder().load()
        logger.info("[DeepSpeed] building cpu_adam done.")

    return deepspeed_plugin


# Accelerate library does not support multiple models for deepspeed. So, we need to wrap multiple models into a single model.
def prepare_deepspeed_model(args: argparse.Namespace, **models):
    # remove None from models
    models = {k: v for k, v in models.items() if v is not None}

    class DeepSpeedWrapper(torch.nn.Module):
        def __init__(self, **kw_models) -> None:
            super().__init__()
            self.models = torch.nn.ModuleDict()

            for key, model in kw_models.items():
                if isinstance(model, list):
                    model = torch.nn.ModuleList(model)
                assert isinstance(
                    model, torch.nn.Module
                ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
                self.models.update(torch.nn.ModuleDict({key: model}))

        def get_models(self):
            return self.models

    ds_model = DeepSpeedWrapper(**models)
    return ds_model
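Note: an illustrative sketch (not part of the upload) of how these helpers can be combined; `args`, `unet`, `text_encoder`, `optimizer` and `train_dataloader` are hypothetical stand-ins for objects the training scripts create.

# Sketch only: wrap multiple modules into one so accelerate + DeepSpeed can prepare them together.
deepspeed_plugin = prepare_deepspeed_plugin(args)  # returns None unless --deepspeed was passed
accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)
if args.deepspeed:
    ds_model = prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
    ds_model, optimizer, train_dataloader = accelerator.prepare(ds_model, optimizer, train_dataloader)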
library/device_utils.py
ADDED
@@ -0,0 +1,84 @@
import functools
import gc

import torch

try:
    HAS_CUDA = torch.cuda.is_available()
except Exception:
    HAS_CUDA = False

try:
    HAS_MPS = torch.backends.mps.is_available()
except Exception:
    HAS_MPS = False

try:
    import intel_extension_for_pytorch as ipex  # noqa

    HAS_XPU = torch.xpu.is_available()
except Exception:
    HAS_XPU = False


def clean_memory():
    gc.collect()
    if HAS_CUDA:
        torch.cuda.empty_cache()
    if HAS_XPU:
        torch.xpu.empty_cache()
    if HAS_MPS:
        torch.mps.empty_cache()


def clean_memory_on_device(device: torch.device):
    r"""
    Clean memory on the specified device, will be called from training scripts.
    """
    gc.collect()

    # device may "cuda" or "cuda:0", so we need to check the type of device
    if device.type == "cuda":
        torch.cuda.empty_cache()
    if device.type == "xpu":
        torch.xpu.empty_cache()
    if device.type == "mps":
        torch.mps.empty_cache()


@functools.lru_cache(maxsize=None)
def get_preferred_device() -> torch.device:
    r"""
    Do not call this function from training scripts. Use accelerator.device instead.
    """
    if HAS_CUDA:
        device = torch.device("cuda")
    elif HAS_XPU:
        device = torch.device("xpu")
    elif HAS_MPS:
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f"get_preferred_device() -> {device}")
    return device


def init_ipex():
    """
    Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`.

    This function should run right after importing torch and before doing anything else.

    If IPEX is not available, this function does nothing.
    """
    try:
        if HAS_XPU:
            from library.ipex import ipex_init

            is_initialized, error_message = ipex_init()
            if not is_initialized:
                print("failed to initialize ipex:", error_message)
        else:
            return
    except Exception as e:
        print("failed to initialize ipex:", e)
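Note: an illustrative sketch (not part of the upload) of how a script outside accelerate-managed training might use these helpers.

# Sketch only: pick the best available backend, then release its cache when done.
device = get_preferred_device()          # cuda > xpu > mps > cpu
x = torch.randn(1, 4, 64, 64).to(device)
del x
clean_memory_on_device(device)           # gc.collect() plus the matching empty_cache()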
library/huggingface_util.py
ADDED
@@ -0,0 +1,84 @@
from typing import Union, BinaryIO
from huggingface_hub import HfApi
from pathlib import Path
import argparse
import os
from library.utils import fire_in_thread
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)

def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None):
    api = HfApi(
        token=token,
    )
    try:
        api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
        return True
    except:
        return False


def upload(
    args: argparse.Namespace,
    src: Union[str, Path, bytes, BinaryIO],
    dest_suffix: str = "",
    force_sync_upload: bool = False,
):
    repo_id = args.huggingface_repo_id
    repo_type = args.huggingface_repo_type
    token = args.huggingface_token
    path_in_repo = args.huggingface_path_in_repo + dest_suffix if args.huggingface_path_in_repo is not None else None
    private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public"
    api = HfApi(token=token)
    if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
        try:
            api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
        except Exception as e:  # とりあえずRepositoryNotFoundErrorは確認したが他にあると困るので
            logger.error("===========================================")
            logger.error(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
            logger.error("===========================================")

    is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir())

    def uploader():
        try:
            if is_folder:
                api.upload_folder(
                    repo_id=repo_id,
                    repo_type=repo_type,
                    folder_path=src,
                    path_in_repo=path_in_repo,
                )
            else:
                api.upload_file(
                    repo_id=repo_id,
                    repo_type=repo_type,
                    path_or_fileobj=src,
                    path_in_repo=path_in_repo,
                )
        except Exception as e:  # RuntimeErrorを確認済みだが他にあると困るので
            logger.error("===========================================")
            logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
            logger.error("===========================================")

    if args.async_upload and not force_sync_upload:
        fire_in_thread(uploader)
    else:
        uploader()


def list_dir(
    repo_id: str,
    subfolder: str,
    repo_type: str,
    revision: str = "main",
    token: str = None,
):
    api = HfApi(
        token=token,
    )
    repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
    file_list = [file for file in repo_info.siblings if file.rfilename.startswith(subfolder)]
    return file_list
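Note: an illustrative sketch (not part of the upload); the repo id and subfolder below are placeholders.

# Sketch only: list files under a subfolder of a Hub repo and print their paths.
files = list_dir(repo_id="user/example-repo", subfolder="models", repo_type="model")
print([f.rfilename for f in files])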
library/hypernetwork.py
ADDED
@@ -0,0 +1,223 @@
import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import (
    Attention,
    AttnProcessor2_0,
    SlicedAttnProcessor,
    XFormersAttnProcessor
)

try:
    import xformers.ops
except:
    xformers = None


loaded_networks = []


def apply_single_hypernetwork(
    hypernetwork, hidden_states, encoder_hidden_states
):
    context_k, context_v = hypernetwork.forward(hidden_states, encoder_hidden_states)
    return context_k, context_v


def apply_hypernetworks(context_k, context_v, layer=None):
    if len(loaded_networks) == 0:
        return context_v, context_v
    for hypernetwork in loaded_networks:
        context_k, context_v = hypernetwork.forward(context_k, context_v)

    context_k = context_k.to(dtype=context_k.dtype)
    context_v = context_v.to(dtype=context_k.dtype)

    return context_k, context_v



def xformers_forward(
    self: XFormersAttnProcessor,
    attn: Attention,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor = None,
    attention_mask: torch.Tensor = None,
):
    batch_size, sequence_length, _ = (
        hidden_states.shape
        if encoder_hidden_states is None
        else encoder_hidden_states.shape
    )

    attention_mask = attn.prepare_attention_mask(
        attention_mask, sequence_length, batch_size
    )

    query = attn.to_q(hidden_states)

    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif attn.norm_cross:
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

    context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)

    key = attn.to_k(context_k)
    value = attn.to_v(context_v)

    query = attn.head_to_batch_dim(query).contiguous()
    key = attn.head_to_batch_dim(key).contiguous()
    value = attn.head_to_batch_dim(value).contiguous()

    hidden_states = xformers.ops.memory_efficient_attention(
        query,
        key,
        value,
        attn_bias=attention_mask,
        op=self.attention_op,
        scale=attn.scale,
    )
    hidden_states = hidden_states.to(query.dtype)
    hidden_states = attn.batch_to_head_dim(hidden_states)

    # linear proj
    hidden_states = attn.to_out[0](hidden_states)
    # dropout
    hidden_states = attn.to_out[1](hidden_states)
    return hidden_states


def sliced_attn_forward(
    self: SlicedAttnProcessor,
    attn: Attention,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor = None,
    attention_mask: torch.Tensor = None,
):
    batch_size, sequence_length, _ = (
        hidden_states.shape
        if encoder_hidden_states is None
        else encoder_hidden_states.shape
    )
    attention_mask = attn.prepare_attention_mask(
        attention_mask, sequence_length, batch_size
    )

    query = attn.to_q(hidden_states)
    dim = query.shape[-1]
    query = attn.head_to_batch_dim(query)

    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif attn.norm_cross:
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

    context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)

    key = attn.to_k(context_k)
    value = attn.to_v(context_v)
    key = attn.head_to_batch_dim(key)
    value = attn.head_to_batch_dim(value)

    batch_size_attention, query_tokens, _ = query.shape
    hidden_states = torch.zeros(
        (batch_size_attention, query_tokens, dim // attn.heads),
        device=query.device,
        dtype=query.dtype,
    )

    for i in range(batch_size_attention // self.slice_size):
        start_idx = i * self.slice_size
        end_idx = (i + 1) * self.slice_size

        query_slice = query[start_idx:end_idx]
        key_slice = key[start_idx:end_idx]
        attn_mask_slice = (
            attention_mask[start_idx:end_idx] if attention_mask is not None else None
        )

        attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)

        attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

        hidden_states[start_idx:end_idx] = attn_slice

    hidden_states = attn.batch_to_head_dim(hidden_states)

    # linear proj
    hidden_states = attn.to_out[0](hidden_states)
    # dropout
    hidden_states = attn.to_out[1](hidden_states)

    return hidden_states


def v2_0_forward(
    self: AttnProcessor2_0,
    attn: Attention,
    hidden_states,
    encoder_hidden_states=None,
    attention_mask=None,
):
    batch_size, sequence_length, _ = (
        hidden_states.shape
        if encoder_hidden_states is None
        else encoder_hidden_states.shape
    )
    inner_dim = hidden_states.shape[-1]

    if attention_mask is not None:
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )
        # scaled_dot_product_attention expects attention_mask shape to be
        # (batch, heads, source_length, target_length)
        attention_mask = attention_mask.view(
            batch_size, attn.heads, -1, attention_mask.shape[-1]
        )

    query = attn.to_q(hidden_states)

    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif attn.norm_cross:
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

    context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)

    key = attn.to_k(context_k)
    value = attn.to_v(context_v)

    head_dim = inner_dim // attn.heads
    query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

    # the output of sdp = (batch, num_heads, seq_len, head_dim)
    # TODO: add support for attn.scale when we move to Torch 2.1
    hidden_states = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )

    hidden_states = hidden_states.transpose(1, 2).reshape(
        batch_size, -1, attn.heads * head_dim
    )
    hidden_states = hidden_states.to(query.dtype)

    # linear proj
    hidden_states = attn.to_out[0](hidden_states)
    # dropout
    hidden_states = attn.to_out[1](hidden_states)
    return hidden_states


def replace_attentions_for_hypernetwork():
    import diffusers.models.attention_processor

    diffusers.models.attention_processor.XFormersAttnProcessor.__call__ = (
        xformers_forward
    )
    diffusers.models.attention_processor.SlicedAttnProcessor.__call__ = (
        sliced_attn_forward
    )
    diffusers.models.attention_processor.AttnProcessor2_0.__call__ = v2_0_forward
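Note: an illustrative sketch (not part of the upload); `my_hypernetwork` is a hypothetical module exposing `forward(context_k, context_v)`.

# Sketch only: register a loaded hypernetwork, then monkey-patch the diffusers attention processors once.
loaded_networks.append(my_hypernetwork)
replace_attentions_for_hypernetwork()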
library/ipex/__init__.py
ADDED
@@ -0,0 +1,180 @@
import os
import sys
import contextlib
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
from .hijacks import ipex_hijacks

# pylint: disable=protected-access, missing-function-docstring, line-too-long

def ipex_init(): # pylint: disable=too-many-statements
    try:
        if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked:
            return True, "Skipping IPEX hijack"
        else:
            # Replace cuda with xpu:
            torch.cuda.current_device = torch.xpu.current_device
            torch.cuda.current_stream = torch.xpu.current_stream
            torch.cuda.device = torch.xpu.device
            torch.cuda.device_count = torch.xpu.device_count
            torch.cuda.device_of = torch.xpu.device_of
            torch.cuda.get_device_name = torch.xpu.get_device_name
            torch.cuda.get_device_properties = torch.xpu.get_device_properties
            torch.cuda.init = torch.xpu.init
            torch.cuda.is_available = torch.xpu.is_available
            torch.cuda.is_initialized = torch.xpu.is_initialized
            torch.cuda.is_current_stream_capturing = lambda: False
            torch.cuda.set_device = torch.xpu.set_device
            torch.cuda.stream = torch.xpu.stream
            torch.cuda.synchronize = torch.xpu.synchronize
            torch.cuda.Event = torch.xpu.Event
            torch.cuda.Stream = torch.xpu.Stream
            torch.cuda.FloatTensor = torch.xpu.FloatTensor
            torch.Tensor.cuda = torch.Tensor.xpu
            torch.Tensor.is_cuda = torch.Tensor.is_xpu
            torch.nn.Module.cuda = torch.nn.Module.xpu
            torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
            torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
            torch.cuda._initialized = torch.xpu.lazy_init._initialized
            torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
            torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
            torch.cuda._tls = torch.xpu.lazy_init._tls
            torch.cuda.threading = torch.xpu.lazy_init.threading
            torch.cuda.traceback = torch.xpu.lazy_init.traceback
            torch.cuda.Optional = torch.xpu.Optional
            torch.cuda.__cached__ = torch.xpu.__cached__
            torch.cuda.__loader__ = torch.xpu.__loader__
            torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
            torch.cuda.Tuple = torch.xpu.Tuple
            torch.cuda.streams = torch.xpu.streams
            torch.cuda._lazy_new = torch.xpu._lazy_new
            torch.cuda.FloatStorage = torch.xpu.FloatStorage
            torch.cuda.Any = torch.xpu.Any
            torch.cuda.__doc__ = torch.xpu.__doc__
            torch.cuda.default_generators = torch.xpu.default_generators
            torch.cuda.HalfTensor = torch.xpu.HalfTensor
            torch.cuda._get_device_index = torch.xpu._get_device_index
            torch.cuda.__path__ = torch.xpu.__path__
            torch.cuda.Device = torch.xpu.Device
            torch.cuda.IntTensor = torch.xpu.IntTensor
            torch.cuda.ByteStorage = torch.xpu.ByteStorage
            torch.cuda.set_stream = torch.xpu.set_stream
            torch.cuda.BoolStorage = torch.xpu.BoolStorage
            torch.cuda.os = torch.xpu.os
            torch.cuda.torch = torch.xpu.torch
            torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
            torch.cuda.Union = torch.xpu.Union
            torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
            torch.cuda.ShortTensor = torch.xpu.ShortTensor
            torch.cuda.LongTensor = torch.xpu.LongTensor
            torch.cuda.IntStorage = torch.xpu.IntStorage
            torch.cuda.LongStorage = torch.xpu.LongStorage
            torch.cuda.__annotations__ = torch.xpu.__annotations__
            torch.cuda.__package__ = torch.xpu.__package__
            torch.cuda.__builtins__ = torch.xpu.__builtins__
            torch.cuda.CharTensor = torch.xpu.CharTensor
            torch.cuda.List = torch.xpu.List
            torch.cuda._lazy_init = torch.xpu._lazy_init
            torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
            torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
            torch.cuda.ByteTensor = torch.xpu.ByteTensor
            torch.cuda.StreamContext = torch.xpu.StreamContext
            torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
            torch.cuda.ShortStorage = torch.xpu.ShortStorage
            torch.cuda._lazy_call = torch.xpu._lazy_call
            torch.cuda.HalfStorage = torch.xpu.HalfStorage
            torch.cuda.random = torch.xpu.random
            torch.cuda._device = torch.xpu._device
            torch.cuda.classproperty = torch.xpu.classproperty
            torch.cuda.__name__ = torch.xpu.__name__
            torch.cuda._device_t = torch.xpu._device_t
            torch.cuda.warnings = torch.xpu.warnings
            torch.cuda.__spec__ = torch.xpu.__spec__
            torch.cuda.BoolTensor = torch.xpu.BoolTensor
            torch.cuda.CharStorage = torch.xpu.CharStorage
            torch.cuda.__file__ = torch.xpu.__file__
            torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
            # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing

            # Memory:
            torch.cuda.memory = torch.xpu.memory
            if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
                torch.xpu.empty_cache = lambda: None
            torch.cuda.empty_cache = torch.xpu.empty_cache
            torch.cuda.memory_stats = torch.xpu.memory_stats
            torch.cuda.memory_summary = torch.xpu.memory_summary
            torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
            torch.cuda.memory_allocated = torch.xpu.memory_allocated
            torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
            torch.cuda.memory_reserved = torch.xpu.memory_reserved
            torch.cuda.memory_cached = torch.xpu.memory_reserved
            torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved
            torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved
            torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats
            torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats
            torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats
            torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
            torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats

            # RNG:
            torch.cuda.get_rng_state = torch.xpu.get_rng_state
            torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
            torch.cuda.set_rng_state = torch.xpu.set_rng_state
            torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all
            torch.cuda.manual_seed = torch.xpu.manual_seed
            torch.cuda.manual_seed_all = torch.xpu.manual_seed_all
            torch.cuda.seed = torch.xpu.seed
            torch.cuda.seed_all = torch.xpu.seed_all
            torch.cuda.initial_seed = torch.xpu.initial_seed

            # AMP:
            torch.cuda.amp = torch.xpu.amp
            torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
            torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype

            if not hasattr(torch.cuda.amp, "common"):
                torch.cuda.amp.common = contextlib.nullcontext()
                torch.cuda.amp.common.amp_definitely_not_available = lambda: False

            try:
                torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
            except Exception: # pylint: disable=broad-exception-caught
                try:
                    from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
                    gradscaler_init()
                    torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
                except Exception: # pylint: disable=broad-exception-caught
                    torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler

            # C
            torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
            ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
            ipex._C._DeviceProperties.major = 2024
            ipex._C._DeviceProperties.minor = 0

            # Fix functions with ipex:
            torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
            torch._utils._get_available_device_type = lambda: "xpu"
            torch.has_cuda = True
            torch.cuda.has_half = True
            torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
            torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
            torch.backends.cuda.is_built = lambda *args, **kwargs: True
            torch.version.cuda = "12.1"
            torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1]
            torch.cuda.get_device_properties.major = 12
            torch.cuda.get_device_properties.minor = 1
            torch.cuda.ipc_collect = lambda *args, **kwargs: None
            torch.cuda.utilization = lambda *args, **kwargs: 0

            ipex_hijacks()
            if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None:
                try:
                    from .diffusers import ipex_diffusers
                    ipex_diffusers()
                except Exception: # pylint: disable=broad-exception-caught
                    pass
            torch.cuda.is_xpu_hijacked = True
    except Exception as e:
        return False, e
    return True, None
library/ipex/attention.py
ADDED
@@ -0,0 +1,177 @@
import os
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
from functools import cache

# pylint: disable=protected-access, missing-function-docstring, line-too-long

# ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers

sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4))
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))

# Find something divisible with the input_tokens
@cache
def find_slice_size(slice_size, slice_block_size):
    while (slice_size * slice_block_size) > attention_slice_rate:
        slice_size = slice_size // 2
        if slice_size <= 1:
            slice_size = 1
            break
    return slice_size

# Find slice sizes for SDPA
@cache
def find_sdpa_slice_sizes(query_shape, query_element_size):
    if len(query_shape) == 3:
        batch_size_attention, query_tokens, shape_three = query_shape
        shape_four = 1
    else:
        batch_size_attention, query_tokens, shape_three, shape_four = query_shape

    slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
    block_size = batch_size_attention * slice_block_size

    split_slice_size = batch_size_attention
    split_2_slice_size = query_tokens
    split_3_slice_size = shape_three

    do_split = False
    do_split_2 = False
    do_split_3 = False

    if block_size > sdpa_slice_trigger_rate:
        do_split = True
        split_slice_size = find_slice_size(split_slice_size, slice_block_size)
        if split_slice_size * slice_block_size > attention_slice_rate:
            slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
            do_split_2 = True
            split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
            if split_2_slice_size * slice_2_block_size > attention_slice_rate:
                slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
                do_split_3 = True
                split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)

    return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size

# Find slice sizes for BMM
@cache
def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape):
    batch_size_attention, input_tokens, mat2_atten_shape = input_shape[0], input_shape[1], mat2_shape[2]
    slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size
    block_size = batch_size_attention * slice_block_size

    split_slice_size = batch_size_attention
    split_2_slice_size = input_tokens
    split_3_slice_size = mat2_atten_shape

    do_split = False
    do_split_2 = False
    do_split_3 = False

    if block_size > attention_slice_rate:
        do_split = True
        split_slice_size = find_slice_size(split_slice_size, slice_block_size)
        if split_slice_size * slice_block_size > attention_slice_rate:
            slice_2_block_size = split_slice_size * mat2_atten_shape / 1024 / 1024 * input_element_size
            do_split_2 = True
            split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
            if split_2_slice_size * slice_2_block_size > attention_slice_rate:
                slice_3_block_size = split_slice_size * split_2_slice_size / 1024 / 1024 * input_element_size
                do_split_3 = True
                split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)

    return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
original_torch_bmm = torch.bmm
|
| 88 |
+
def torch_bmm_32_bit(input, mat2, *, out=None):
|
| 89 |
+
if input.device.type != "xpu":
|
| 90 |
+
return original_torch_bmm(input, mat2, out=out)
|
| 91 |
+
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape)
|
| 92 |
+
|
| 93 |
+
# Slice BMM
|
| 94 |
+
if do_split:
|
| 95 |
+
batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2]
|
| 96 |
+
hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
|
| 97 |
+
for i in range(batch_size_attention // split_slice_size):
|
| 98 |
+
start_idx = i * split_slice_size
|
| 99 |
+
end_idx = (i + 1) * split_slice_size
|
| 100 |
+
if do_split_2:
|
| 101 |
+
for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
| 102 |
+
start_idx_2 = i2 * split_2_slice_size
|
| 103 |
+
end_idx_2 = (i2 + 1) * split_2_slice_size
|
| 104 |
+
if do_split_3:
|
| 105 |
+
for i3 in range(mat2_atten_shape // split_3_slice_size): # pylint: disable=invalid-name
|
| 106 |
+
start_idx_3 = i3 * split_3_slice_size
|
| 107 |
+
end_idx_3 = (i3 + 1) * split_3_slice_size
|
| 108 |
+
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm(
|
| 109 |
+
input[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
| 110 |
+
mat2[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
| 111 |
+
out=out
|
| 112 |
+
)
|
| 113 |
+
else:
|
| 114 |
+
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
|
| 115 |
+
input[start_idx:end_idx, start_idx_2:end_idx_2],
|
| 116 |
+
mat2[start_idx:end_idx, start_idx_2:end_idx_2],
|
| 117 |
+
out=out
|
| 118 |
+
)
|
| 119 |
+
else:
|
| 120 |
+
hidden_states[start_idx:end_idx] = original_torch_bmm(
|
| 121 |
+
input[start_idx:end_idx],
|
| 122 |
+
mat2[start_idx:end_idx],
|
| 123 |
+
out=out
|
| 124 |
+
)
|
| 125 |
+
torch.xpu.synchronize(input.device)
|
| 126 |
+
else:
|
| 127 |
+
return original_torch_bmm(input, mat2, out=out)
|
| 128 |
+
return hidden_states
|
| 129 |
+
|
| 130 |
+
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
| 131 |
+
def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
|
| 132 |
+
if query.device.type != "xpu":
|
| 133 |
+
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
|
| 134 |
+
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size())
|
| 135 |
+
|
| 136 |
+
# Slice SDPA
|
| 137 |
+
if do_split:
|
| 138 |
+
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
|
| 139 |
+
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
|
| 140 |
+
for i in range(batch_size_attention // split_slice_size):
|
| 141 |
+
start_idx = i * split_slice_size
|
| 142 |
+
end_idx = (i + 1) * split_slice_size
|
| 143 |
+
if do_split_2:
|
| 144 |
+
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
| 145 |
+
start_idx_2 = i2 * split_2_slice_size
|
| 146 |
+
end_idx_2 = (i2 + 1) * split_2_slice_size
|
| 147 |
+
if do_split_3:
|
| 148 |
+
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
|
| 149 |
+
start_idx_3 = i3 * split_3_slice_size
|
| 150 |
+
end_idx_3 = (i3 + 1) * split_3_slice_size
|
| 151 |
+
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention(
|
| 152 |
+
query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
| 153 |
+
key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
| 154 |
+
value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
| 155 |
+
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
|
| 156 |
+
dropout_p=dropout_p, is_causal=is_causal, **kwargs
|
| 157 |
+
)
|
| 158 |
+
else:
|
| 159 |
+
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
|
| 160 |
+
query[start_idx:end_idx, start_idx_2:end_idx_2],
|
| 161 |
+
key[start_idx:end_idx, start_idx_2:end_idx_2],
|
| 162 |
+
value[start_idx:end_idx, start_idx_2:end_idx_2],
|
| 163 |
+
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
|
| 164 |
+
dropout_p=dropout_p, is_causal=is_causal, **kwargs
|
| 165 |
+
)
|
| 166 |
+
else:
|
| 167 |
+
hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
|
| 168 |
+
query[start_idx:end_idx],
|
| 169 |
+
key[start_idx:end_idx],
|
| 170 |
+
value[start_idx:end_idx],
|
| 171 |
+
attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
|
| 172 |
+
dropout_p=dropout_p, is_causal=is_causal, **kwargs
|
| 173 |
+
)
|
| 174 |
+
torch.xpu.synchronize(query.device)
|
| 175 |
+
else:
|
| 176 |
+
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
|
| 177 |
+
return hidden_states
|
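attention.py replaces torch.bmm and scaled_dot_product_attention with versions that slice the batch/token/head dimensions whenever a single block would exceed the IPEX_SDPA_SLICE_TRIGGER_RATE / IPEX_ATTENTION_SLICE_RATE thresholds (both default to 4, measured in MB by the arithmetic above). The slice-size heuristic can be probed on its own; a sketch, assuming the module is importable as library.ipex.attention (importing it pulls in intel_extension_for_pytorch):

from library.ipex.attention import find_sdpa_slice_sizes

# A hand-made 3D attention shape, e.g. (batch*heads, query_tokens, head_dim),
# with 2-byte (fp16/bf16) elements.
query_shape = (16, 4096, 4096)
element_size = 2

do_split, do_split_2, do_split_3, s1, s2, s3 = find_sdpa_slice_sizes(query_shape, element_size)
print(do_split, do_split_2, do_split_3, s1, s2, s3)  # which dimensions get sliced, and the chunk size per dimension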
library/ipex/diffusers.py
ADDED
|
@@ -0,0 +1,312 @@
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
| 4 |
+
import diffusers #0.24.0 # pylint: disable=import-error
|
| 5 |
+
from diffusers.models.attention_processor import Attention
|
| 6 |
+
from diffusers.utils import USE_PEFT_BACKEND
|
| 7 |
+
from functools import cache
|
| 8 |
+
|
| 9 |
+
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
| 10 |
+
|
| 11 |
+
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
|
| 12 |
+
|
| 13 |
+
@cache
|
| 14 |
+
def find_slice_size(slice_size, slice_block_size):
|
| 15 |
+
while (slice_size * slice_block_size) > attention_slice_rate:
|
| 16 |
+
slice_size = slice_size // 2
|
| 17 |
+
if slice_size <= 1:
|
| 18 |
+
slice_size = 1
|
| 19 |
+
break
|
| 20 |
+
return slice_size
|
| 21 |
+
|
| 22 |
+
@cache
|
| 23 |
+
def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None):
|
| 24 |
+
if len(query_shape) == 3:
|
| 25 |
+
batch_size_attention, query_tokens, shape_three = query_shape
|
| 26 |
+
shape_four = 1
|
| 27 |
+
else:
|
| 28 |
+
batch_size_attention, query_tokens, shape_three, shape_four = query_shape
|
| 29 |
+
if slice_size is not None:
|
| 30 |
+
batch_size_attention = slice_size
|
| 31 |
+
|
| 32 |
+
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
|
| 33 |
+
block_size = batch_size_attention * slice_block_size
|
| 34 |
+
|
| 35 |
+
split_slice_size = batch_size_attention
|
| 36 |
+
split_2_slice_size = query_tokens
|
| 37 |
+
split_3_slice_size = shape_three
|
| 38 |
+
|
| 39 |
+
do_split = False
|
| 40 |
+
do_split_2 = False
|
| 41 |
+
do_split_3 = False
|
| 42 |
+
|
| 43 |
+
if query_device_type != "xpu":
|
| 44 |
+
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
| 45 |
+
|
| 46 |
+
if block_size > attention_slice_rate:
|
| 47 |
+
do_split = True
|
| 48 |
+
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
|
| 49 |
+
if split_slice_size * slice_block_size > attention_slice_rate:
|
| 50 |
+
slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
|
| 51 |
+
do_split_2 = True
|
| 52 |
+
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
|
| 53 |
+
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
|
| 54 |
+
slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
|
| 55 |
+
do_split_3 = True
|
| 56 |
+
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
|
| 57 |
+
|
| 58 |
+
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
| 59 |
+
|
| 60 |
+
class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
|
| 61 |
+
r"""
|
| 62 |
+
Processor for implementing sliced attention.
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
slice_size (`int`, *optional*):
|
| 66 |
+
The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
|
| 67 |
+
`attention_head_dim` must be a multiple of the `slice_size`.
|
| 68 |
+
"""
|
| 69 |
+
|
| 70 |
+
def __init__(self, slice_size):
|
| 71 |
+
self.slice_size = slice_size
|
| 72 |
+
|
| 73 |
+
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
|
| 74 |
+
encoder_hidden_states=None, attention_mask=None) -> torch.FloatTensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
|
| 75 |
+
|
| 76 |
+
residual = hidden_states
|
| 77 |
+
|
| 78 |
+
input_ndim = hidden_states.ndim
|
| 79 |
+
|
| 80 |
+
if input_ndim == 4:
|
| 81 |
+
batch_size, channel, height, width = hidden_states.shape
|
| 82 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
| 83 |
+
|
| 84 |
+
batch_size, sequence_length, _ = (
|
| 85 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
| 86 |
+
)
|
| 87 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
| 88 |
+
|
| 89 |
+
if attn.group_norm is not None:
|
| 90 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
| 91 |
+
|
| 92 |
+
query = attn.to_q(hidden_states)
|
| 93 |
+
dim = query.shape[-1]
|
| 94 |
+
query = attn.head_to_batch_dim(query)
|
| 95 |
+
|
| 96 |
+
if encoder_hidden_states is None:
|
| 97 |
+
encoder_hidden_states = hidden_states
|
| 98 |
+
elif attn.norm_cross:
|
| 99 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
| 100 |
+
|
| 101 |
+
key = attn.to_k(encoder_hidden_states)
|
| 102 |
+
value = attn.to_v(encoder_hidden_states)
|
| 103 |
+
key = attn.head_to_batch_dim(key)
|
| 104 |
+
value = attn.head_to_batch_dim(value)
|
| 105 |
+
|
| 106 |
+
batch_size_attention, query_tokens, shape_three = query.shape
|
| 107 |
+
hidden_states = torch.zeros(
|
| 108 |
+
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
####################################################################
|
| 112 |
+
# ARC GPUs can't allocate more than 4GB to a single block, so slice it:
|
| 113 |
+
_, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size)
|
| 114 |
+
|
| 115 |
+
for i in range(batch_size_attention // split_slice_size):
|
| 116 |
+
start_idx = i * split_slice_size
|
| 117 |
+
end_idx = (i + 1) * split_slice_size
|
| 118 |
+
if do_split_2:
|
| 119 |
+
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
| 120 |
+
start_idx_2 = i2 * split_2_slice_size
|
| 121 |
+
end_idx_2 = (i2 + 1) * split_2_slice_size
|
| 122 |
+
if do_split_3:
|
| 123 |
+
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
|
| 124 |
+
start_idx_3 = i3 * split_3_slice_size
|
| 125 |
+
end_idx_3 = (i3 + 1) * split_3_slice_size
|
| 126 |
+
|
| 127 |
+
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
| 128 |
+
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
| 129 |
+
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
|
| 130 |
+
|
| 131 |
+
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
| 132 |
+
del query_slice
|
| 133 |
+
del key_slice
|
| 134 |
+
del attn_mask_slice
|
| 135 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
|
| 136 |
+
|
| 137 |
+
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
|
| 138 |
+
del attn_slice
|
| 139 |
+
else:
|
| 140 |
+
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
|
| 141 |
+
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
|
| 142 |
+
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
|
| 143 |
+
|
| 144 |
+
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
| 145 |
+
del query_slice
|
| 146 |
+
del key_slice
|
| 147 |
+
del attn_mask_slice
|
| 148 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
|
| 149 |
+
|
| 150 |
+
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
|
| 151 |
+
del attn_slice
|
| 152 |
+
torch.xpu.synchronize(query.device)
|
| 153 |
+
else:
|
| 154 |
+
query_slice = query[start_idx:end_idx]
|
| 155 |
+
key_slice = key[start_idx:end_idx]
|
| 156 |
+
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
| 157 |
+
|
| 158 |
+
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
| 159 |
+
del query_slice
|
| 160 |
+
del key_slice
|
| 161 |
+
del attn_mask_slice
|
| 162 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
| 163 |
+
|
| 164 |
+
hidden_states[start_idx:end_idx] = attn_slice
|
| 165 |
+
del attn_slice
|
| 166 |
+
####################################################################
|
| 167 |
+
|
| 168 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
| 169 |
+
|
| 170 |
+
# linear proj
|
| 171 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 172 |
+
# dropout
|
| 173 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 174 |
+
|
| 175 |
+
if input_ndim == 4:
|
| 176 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
| 177 |
+
|
| 178 |
+
if attn.residual_connection:
|
| 179 |
+
hidden_states = hidden_states + residual
|
| 180 |
+
|
| 181 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
| 182 |
+
|
| 183 |
+
return hidden_states
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class AttnProcessor:
|
| 187 |
+
r"""
|
| 188 |
+
Default processor for performing attention-related computations.
|
| 189 |
+
"""
|
| 190 |
+
|
| 191 |
+
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
|
| 192 |
+
encoder_hidden_states=None, attention_mask=None,
|
| 193 |
+
temb=None, scale: float = 1.0) -> torch.Tensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
|
| 194 |
+
|
| 195 |
+
residual = hidden_states
|
| 196 |
+
|
| 197 |
+
args = () if USE_PEFT_BACKEND else (scale,)
|
| 198 |
+
|
| 199 |
+
if attn.spatial_norm is not None:
|
| 200 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
| 201 |
+
|
| 202 |
+
input_ndim = hidden_states.ndim
|
| 203 |
+
|
| 204 |
+
if input_ndim == 4:
|
| 205 |
+
batch_size, channel, height, width = hidden_states.shape
|
| 206 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
| 207 |
+
|
| 208 |
+
batch_size, sequence_length, _ = (
|
| 209 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
| 210 |
+
)
|
| 211 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
| 212 |
+
|
| 213 |
+
if attn.group_norm is not None:
|
| 214 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
| 215 |
+
|
| 216 |
+
query = attn.to_q(hidden_states, *args)
|
| 217 |
+
|
| 218 |
+
if encoder_hidden_states is None:
|
| 219 |
+
encoder_hidden_states = hidden_states
|
| 220 |
+
elif attn.norm_cross:
|
| 221 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
| 222 |
+
|
| 223 |
+
key = attn.to_k(encoder_hidden_states, *args)
|
| 224 |
+
value = attn.to_v(encoder_hidden_states, *args)
|
| 225 |
+
|
| 226 |
+
query = attn.head_to_batch_dim(query)
|
| 227 |
+
key = attn.head_to_batch_dim(key)
|
| 228 |
+
value = attn.head_to_batch_dim(value)
|
| 229 |
+
|
| 230 |
+
####################################################################
|
| 231 |
+
# ARC GPUs can't allocate more than 4GB to a single block, so slice it:
|
| 232 |
+
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
|
| 233 |
+
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
|
| 234 |
+
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type)
|
| 235 |
+
|
| 236 |
+
if do_split:
|
| 237 |
+
for i in range(batch_size_attention // split_slice_size):
|
| 238 |
+
start_idx = i * split_slice_size
|
| 239 |
+
end_idx = (i + 1) * split_slice_size
|
| 240 |
+
if do_split_2:
|
| 241 |
+
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
| 242 |
+
start_idx_2 = i2 * split_2_slice_size
|
| 243 |
+
end_idx_2 = (i2 + 1) * split_2_slice_size
|
| 244 |
+
if do_split_3:
|
| 245 |
+
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
|
| 246 |
+
start_idx_3 = i3 * split_3_slice_size
|
| 247 |
+
end_idx_3 = (i3 + 1) * split_3_slice_size
|
| 248 |
+
|
| 249 |
+
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
| 250 |
+
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
| 251 |
+
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
|
| 252 |
+
|
| 253 |
+
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
| 254 |
+
del query_slice
|
| 255 |
+
del key_slice
|
| 256 |
+
del attn_mask_slice
|
| 257 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
|
| 258 |
+
|
| 259 |
+
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
|
| 260 |
+
del attn_slice
|
| 261 |
+
else:
|
| 262 |
+
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
|
| 263 |
+
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
|
| 264 |
+
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
|
| 265 |
+
|
| 266 |
+
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
| 267 |
+
del query_slice
|
| 268 |
+
del key_slice
|
| 269 |
+
del attn_mask_slice
|
| 270 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
|
| 271 |
+
|
| 272 |
+
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
|
| 273 |
+
del attn_slice
|
| 274 |
+
else:
|
| 275 |
+
query_slice = query[start_idx:end_idx]
|
| 276 |
+
key_slice = key[start_idx:end_idx]
|
| 277 |
+
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
| 278 |
+
|
| 279 |
+
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
| 280 |
+
del query_slice
|
| 281 |
+
del key_slice
|
| 282 |
+
del attn_mask_slice
|
| 283 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
| 284 |
+
|
| 285 |
+
hidden_states[start_idx:end_idx] = attn_slice
|
| 286 |
+
del attn_slice
|
| 287 |
+
torch.xpu.synchronize(query.device)
|
| 288 |
+
else:
|
| 289 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
| 290 |
+
hidden_states = torch.bmm(attention_probs, value)
|
| 291 |
+
####################################################################
|
| 292 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
| 293 |
+
|
| 294 |
+
# linear proj
|
| 295 |
+
hidden_states = attn.to_out[0](hidden_states, *args)
|
| 296 |
+
# dropout
|
| 297 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 298 |
+
|
| 299 |
+
if input_ndim == 4:
|
| 300 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
| 301 |
+
|
| 302 |
+
if attn.residual_connection:
|
| 303 |
+
hidden_states = hidden_states + residual
|
| 304 |
+
|
| 305 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
| 306 |
+
|
| 307 |
+
return hidden_states
|
| 308 |
+
|
| 309 |
+
def ipex_diffusers():
|
| 310 |
+
#ARC GPUs can't allocate more than 4GB to a single block:
|
| 311 |
+
diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor
|
| 312 |
+
diffusers.models.attention_processor.AttnProcessor = AttnProcessor
|
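diffusers.py swaps diffusers' AttnProcessor and SlicedAttnProcessor for versions that apply the same 4GB-per-block slicing on XPU. A sketch of wiring them into a loaded pipeline; the checkpoint path is a placeholder and everything else is standard diffusers usage:

import torch
from diffusers import StableDiffusionPipeline
from library.ipex.diffusers import ipex_diffusers, SlicedAttnProcessor

ipex_diffusers()  # patch diffusers.models.attention_processor with the classes above

pipe = StableDiffusionPipeline.from_pretrained("path/to/sd-checkpoint", torch_dtype=torch.float16)  # placeholder path
pipe = pipe.to("xpu")
pipe.unet.set_attn_processor(SlicedAttnProcessor(slice_size=4))  # opt into the sliced processor explicitly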
library/ipex/gradscaler.py
ADDED
|
@@ -0,0 +1,183 @@
|
| 1 |
+
from collections import defaultdict
|
| 2 |
+
import torch
|
| 3 |
+
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
| 4 |
+
import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import
|
| 5 |
+
|
| 6 |
+
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
| 7 |
+
|
| 8 |
+
device_supports_fp64 = torch.xpu.has_fp64_dtype()
|
| 9 |
+
OptState = ipex.cpu.autocast._grad_scaler.OptState
|
| 10 |
+
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
|
| 11 |
+
_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
|
| 12 |
+
|
| 13 |
+
def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument
|
| 14 |
+
per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
|
| 15 |
+
per_device_found_inf = _MultiDeviceReplicator(found_inf)
|
| 16 |
+
|
| 17 |
+
# To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
|
| 18 |
+
# There could be hundreds of grads, so we'd like to iterate through them just once.
|
| 19 |
+
# However, we don't know their devices or dtypes in advance.
|
| 20 |
+
|
| 21 |
+
# https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
|
| 22 |
+
# Google says mypy struggles with defaultdicts type annotations.
|
| 23 |
+
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
|
| 24 |
+
# sync grad to master weight
|
| 25 |
+
if hasattr(optimizer, "sync_grad"):
|
| 26 |
+
optimizer.sync_grad()
|
| 27 |
+
with torch.no_grad():
|
| 28 |
+
for group in optimizer.param_groups:
|
| 29 |
+
for param in group["params"]:
|
| 30 |
+
if param.grad is None:
|
| 31 |
+
continue
|
| 32 |
+
if (not allow_fp16) and param.grad.dtype == torch.float16:
|
| 33 |
+
raise ValueError("Attempting to unscale FP16 gradients.")
|
| 34 |
+
if param.grad.is_sparse:
|
| 35 |
+
# is_coalesced() == False means the sparse grad has values with duplicate indices.
|
| 36 |
+
# coalesce() deduplicates indices and adds all values that have the same index.
|
| 37 |
+
# For scaled fp16 values, there's a good chance coalescing will cause overflow,
|
| 38 |
+
# so we should check the coalesced _values().
|
| 39 |
+
if param.grad.dtype is torch.float16:
|
| 40 |
+
param.grad = param.grad.coalesce()
|
| 41 |
+
to_unscale = param.grad._values()
|
| 42 |
+
else:
|
| 43 |
+
to_unscale = param.grad
|
| 44 |
+
|
| 45 |
+
# -: is there a way to split by device and dtype without appending in the inner loop?
|
| 46 |
+
to_unscale = to_unscale.to("cpu")
|
| 47 |
+
per_device_and_dtype_grads[to_unscale.device][
|
| 48 |
+
to_unscale.dtype
|
| 49 |
+
].append(to_unscale)
|
| 50 |
+
|
| 51 |
+
for _, per_dtype_grads in per_device_and_dtype_grads.items():
|
| 52 |
+
for grads in per_dtype_grads.values():
|
| 53 |
+
core._amp_foreach_non_finite_check_and_unscale_(
|
| 54 |
+
grads,
|
| 55 |
+
per_device_found_inf.get("cpu"),
|
| 56 |
+
per_device_inv_scale.get("cpu"),
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
return per_device_found_inf._per_device_tensors
|
| 60 |
+
|
| 61 |
+
def unscale_(self, optimizer):
|
| 62 |
+
"""
|
| 63 |
+
Divides ("unscales") the optimizer's gradient tensors by the scale factor.
|
| 64 |
+
:meth:`unscale_` is optional, serving cases where you need to
|
| 65 |
+
:ref:`modify or inspect gradients<working-with-unscaled-gradients>`
|
| 66 |
+
between the backward pass(es) and :meth:`step`.
|
| 67 |
+
If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
|
| 68 |
+
Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
|
| 69 |
+
...
|
| 70 |
+
scaler.scale(loss).backward()
|
| 71 |
+
scaler.unscale_(optimizer)
|
| 72 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
|
| 73 |
+
scaler.step(optimizer)
|
| 74 |
+
scaler.update()
|
| 75 |
+
Args:
|
| 76 |
+
optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
|
| 77 |
+
.. warning::
|
| 78 |
+
:meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
|
| 79 |
+
and only after all gradients for that optimizer's assigned parameters have been accumulated.
|
| 80 |
+
Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
|
| 81 |
+
.. warning::
|
| 82 |
+
:meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
|
| 83 |
+
"""
|
| 84 |
+
if not self._enabled:
|
| 85 |
+
return
|
| 86 |
+
|
| 87 |
+
self._check_scale_growth_tracker("unscale_")
|
| 88 |
+
|
| 89 |
+
optimizer_state = self._per_optimizer_states[id(optimizer)]
|
| 90 |
+
|
| 91 |
+
if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise
|
| 92 |
+
raise RuntimeError(
|
| 93 |
+
"unscale_() has already been called on this optimizer since the last update()."
|
| 94 |
+
)
|
| 95 |
+
elif optimizer_state["stage"] is OptState.STEPPED:
|
| 96 |
+
raise RuntimeError("unscale_() is being called after step().")
|
| 97 |
+
|
| 98 |
+
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
|
| 99 |
+
assert self._scale is not None
|
| 100 |
+
if device_supports_fp64:
|
| 101 |
+
inv_scale = self._scale.double().reciprocal().float()
|
| 102 |
+
else:
|
| 103 |
+
inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
|
| 104 |
+
found_inf = torch.full(
|
| 105 |
+
(1,), 0.0, dtype=torch.float32, device=self._scale.device
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
optimizer_state["found_inf_per_device"] = self._unscale_grads_(
|
| 109 |
+
optimizer, inv_scale, found_inf, False
|
| 110 |
+
)
|
| 111 |
+
optimizer_state["stage"] = OptState.UNSCALED
|
| 112 |
+
|
| 113 |
+
def update(self, new_scale=None):
|
| 114 |
+
"""
|
| 115 |
+
Updates the scale factor.
|
| 116 |
+
If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
|
| 117 |
+
to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
|
| 118 |
+
the scale is multiplied by ``growth_factor`` to increase it.
|
| 119 |
+
Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
|
| 120 |
+
used directly, it's used to fill GradScaler's internal scale tensor. So if
|
| 121 |
+
``new_scale`` was a tensor, later in-place changes to that tensor will not further
|
| 122 |
+
affect the scale GradScaler uses internally.)
|
| 123 |
+
Args:
|
| 124 |
+
new_scale (float or :class:`torch.FloatTensor`, optional, default=None): New scale factor.
|
| 125 |
+
.. warning::
|
| 126 |
+
:meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
|
| 127 |
+
been invoked for all optimizers used this iteration.
|
| 128 |
+
"""
|
| 129 |
+
if not self._enabled:
|
| 130 |
+
return
|
| 131 |
+
|
| 132 |
+
_scale, _growth_tracker = self._check_scale_growth_tracker("update")
|
| 133 |
+
|
| 134 |
+
if new_scale is not None:
|
| 135 |
+
# Accept a new user-defined scale.
|
| 136 |
+
if isinstance(new_scale, float):
|
| 137 |
+
self._scale.fill_(new_scale) # type: ignore[union-attr]
|
| 138 |
+
else:
|
| 139 |
+
reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False."
|
| 140 |
+
assert isinstance(new_scale, torch.FloatTensor), reason # type: ignore[attr-defined]
|
| 141 |
+
assert new_scale.numel() == 1, reason
|
| 142 |
+
assert new_scale.requires_grad is False, reason
|
| 143 |
+
self._scale.copy_(new_scale) # type: ignore[union-attr]
|
| 144 |
+
else:
|
| 145 |
+
# Consume shared inf/nan data collected from optimizers to update the scale.
|
| 146 |
+
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
|
| 147 |
+
found_infs = [
|
| 148 |
+
found_inf.to(device="cpu", non_blocking=True)
|
| 149 |
+
for state in self._per_optimizer_states.values()
|
| 150 |
+
for found_inf in state["found_inf_per_device"].values()
|
| 151 |
+
]
|
| 152 |
+
|
| 153 |
+
assert len(found_infs) > 0, "No inf checks were recorded prior to update."
|
| 154 |
+
|
| 155 |
+
found_inf_combined = found_infs[0]
|
| 156 |
+
if len(found_infs) > 1:
|
| 157 |
+
for i in range(1, len(found_infs)):
|
| 158 |
+
found_inf_combined += found_infs[i]
|
| 159 |
+
|
| 160 |
+
to_device = _scale.device
|
| 161 |
+
_scale = _scale.to("cpu")
|
| 162 |
+
_growth_tracker = _growth_tracker.to("cpu")
|
| 163 |
+
|
| 164 |
+
core._amp_update_scale_(
|
| 165 |
+
_scale,
|
| 166 |
+
_growth_tracker,
|
| 167 |
+
found_inf_combined,
|
| 168 |
+
self._growth_factor,
|
| 169 |
+
self._backoff_factor,
|
| 170 |
+
self._growth_interval,
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
_scale = _scale.to(to_device)
|
| 174 |
+
_growth_tracker = _growth_tracker.to(to_device)
|
| 175 |
+
# To prepare for next iteration, clear the data collected from optimizers this iteration.
|
| 176 |
+
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
|
| 177 |
+
|
| 178 |
+
def gradscaler_init():
|
| 179 |
+
torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
|
| 180 |
+
torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_
|
| 181 |
+
torch.xpu.amp.GradScaler.unscale_ = unscale_
|
| 182 |
+
torch.xpu.amp.GradScaler.update = update
|
| 183 |
+
return torch.xpu.amp.GradScaler
|
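gradscaler.py back-ports the CPU-autocast GradScaler to XPU by doing the unscale and scale-update bookkeeping on the CPU (with a CPU round-trip for the FP64 reciprocal on hardware without FP64 support). Usage follows the pattern already quoted in the unscale_ docstring; a sketch, where model, optimizer, loss_fn and loader are placeholders that are not defined here:

import torch
from library.ipex.gradscaler import gradscaler_init

GradScaler = gradscaler_init()  # torch.xpu.amp.GradScaler with the patched _unscale_grads_/unscale_/update
scaler = GradScaler()

for inputs, targets in loader:  # model/optimizer/loss_fn/loader assumed to exist
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast("xpu", dtype=torch.float16):
        loss = loss_fn(model(inputs.to("xpu")), targets.to("xpu"))
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # optional, e.g. before gradient clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)
    scaler.update()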
library/ipex/hijacks.py
ADDED
|
@@ -0,0 +1,313 @@
|
| 1 |
+
import os
|
| 2 |
+
from functools import wraps
|
| 3 |
+
from contextlib import nullcontext
|
| 4 |
+
import torch
|
| 5 |
+
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
device_supports_fp64 = torch.xpu.has_fp64_dtype()
|
| 9 |
+
|
| 10 |
+
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
|
| 11 |
+
|
| 12 |
+
class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
|
| 13 |
+
def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
|
| 14 |
+
if isinstance(device_ids, list) and len(device_ids) > 1:
|
| 15 |
+
print("IPEX backend doesn't support DataParallel on multiple XPU devices")
|
| 16 |
+
return module.to("xpu")
|
| 17 |
+
|
| 18 |
+
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
|
| 19 |
+
return nullcontext()
|
| 20 |
+
|
| 21 |
+
@property
|
| 22 |
+
def is_cuda(self):
|
| 23 |
+
return self.device.type == 'xpu' or self.device.type == 'cuda'
|
| 24 |
+
|
| 25 |
+
def check_device(device):
|
| 26 |
+
return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
|
| 27 |
+
|
| 28 |
+
def return_xpu(device):
|
| 29 |
+
return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Autocast
|
| 33 |
+
original_autocast_init = torch.amp.autocast_mode.autocast.__init__
|
| 34 |
+
@wraps(torch.amp.autocast_mode.autocast.__init__)
|
| 35 |
+
def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None):
|
| 36 |
+
if device_type == "cuda":
|
| 37 |
+
return original_autocast_init(self, device_type="xpu", dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
|
| 38 |
+
else:
|
| 39 |
+
return original_autocast_init(self, device_type=device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
|
| 40 |
+
|
| 41 |
+
# Latent Antialias CPU Offload:
|
| 42 |
+
original_interpolate = torch.nn.functional.interpolate
|
| 43 |
+
@wraps(torch.nn.functional.interpolate)
|
| 44 |
+
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
|
| 45 |
+
if antialias or align_corners is not None or mode == 'bicubic':
|
| 46 |
+
return_device = tensor.device
|
| 47 |
+
return_dtype = tensor.dtype
|
| 48 |
+
return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
|
| 49 |
+
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias).to(return_device, dtype=return_dtype)
|
| 50 |
+
else:
|
| 51 |
+
return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
|
| 52 |
+
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# Diffusers Float64 (Alchemist GPUs don't support 64 bit):
|
| 56 |
+
original_from_numpy = torch.from_numpy
|
| 57 |
+
@wraps(torch.from_numpy)
|
| 58 |
+
def from_numpy(ndarray):
|
| 59 |
+
if ndarray.dtype == float:
|
| 60 |
+
return original_from_numpy(ndarray.astype('float32'))
|
| 61 |
+
else:
|
| 62 |
+
return original_from_numpy(ndarray)
|
| 63 |
+
|
| 64 |
+
original_as_tensor = torch.as_tensor
|
| 65 |
+
@wraps(torch.as_tensor)
|
| 66 |
+
def as_tensor(data, dtype=None, device=None):
|
| 67 |
+
if check_device(device):
|
| 68 |
+
device = return_xpu(device)
|
| 69 |
+
if isinstance(data, np.ndarray) and data.dtype == float and not (
|
| 70 |
+
(isinstance(device, torch.device) and device.type == "cpu") or (isinstance(device, str) and "cpu" in device)):
|
| 71 |
+
return original_as_tensor(data, dtype=torch.float32, device=device)
|
| 72 |
+
else:
|
| 73 |
+
return original_as_tensor(data, dtype=dtype, device=device)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None:
|
| 77 |
+
original_torch_bmm = torch.bmm
|
| 78 |
+
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
| 79 |
+
else:
|
| 80 |
+
# 32 bit attention workarounds for Alchemist:
|
| 81 |
+
try:
|
| 82 |
+
from .attention import torch_bmm_32_bit as original_torch_bmm
|
| 83 |
+
from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention
|
| 84 |
+
except Exception: # pylint: disable=broad-exception-caught
|
| 85 |
+
original_torch_bmm = torch.bmm
|
| 86 |
+
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# Data Type Errors:
|
| 90 |
+
@wraps(torch.bmm)
|
| 91 |
+
def torch_bmm(input, mat2, *, out=None):
|
| 92 |
+
if input.dtype != mat2.dtype:
|
| 93 |
+
mat2 = mat2.to(input.dtype)
|
| 94 |
+
return original_torch_bmm(input, mat2, out=out)
|
| 95 |
+
|
| 96 |
+
@wraps(torch.nn.functional.scaled_dot_product_attention)
|
| 97 |
+
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
|
| 98 |
+
if query.dtype != key.dtype:
|
| 99 |
+
key = key.to(dtype=query.dtype)
|
| 100 |
+
if query.dtype != value.dtype:
|
| 101 |
+
value = value.to(dtype=query.dtype)
|
| 102 |
+
if attn_mask is not None and query.dtype != attn_mask.dtype:
|
| 103 |
+
attn_mask = attn_mask.to(dtype=query.dtype)
|
| 104 |
+
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
|
| 105 |
+
|
| 106 |
+
# A1111 FP16
|
| 107 |
+
original_functional_group_norm = torch.nn.functional.group_norm
|
| 108 |
+
@wraps(torch.nn.functional.group_norm)
|
| 109 |
+
def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
|
| 110 |
+
if weight is not None and input.dtype != weight.data.dtype:
|
| 111 |
+
input = input.to(dtype=weight.data.dtype)
|
| 112 |
+
if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype:
|
| 113 |
+
bias.data = bias.data.to(dtype=weight.data.dtype)
|
| 114 |
+
return original_functional_group_norm(input, num_groups, weight=weight, bias=bias, eps=eps)
|
| 115 |
+
|
| 116 |
+
# A1111 BF16
|
| 117 |
+
original_functional_layer_norm = torch.nn.functional.layer_norm
|
| 118 |
+
@wraps(torch.nn.functional.layer_norm)
|
| 119 |
+
def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
|
| 120 |
+
if weight is not None and input.dtype != weight.data.dtype:
|
| 121 |
+
input = input.to(dtype=weight.data.dtype)
|
| 122 |
+
if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype:
|
| 123 |
+
bias.data = bias.data.to(dtype=weight.data.dtype)
|
| 124 |
+
return original_functional_layer_norm(input, normalized_shape, weight=weight, bias=bias, eps=eps)
|
| 125 |
+
|
| 126 |
+
# Training
|
| 127 |
+
original_functional_linear = torch.nn.functional.linear
|
| 128 |
+
@wraps(torch.nn.functional.linear)
|
| 129 |
+
def functional_linear(input, weight, bias=None):
|
| 130 |
+
if input.dtype != weight.data.dtype:
|
| 131 |
+
input = input.to(dtype=weight.data.dtype)
|
| 132 |
+
if bias is not None and bias.data.dtype != weight.data.dtype:
|
| 133 |
+
bias.data = bias.data.to(dtype=weight.data.dtype)
|
| 134 |
+
return original_functional_linear(input, weight, bias=bias)
|
| 135 |
+
|
| 136 |
+
original_functional_conv2d = torch.nn.functional.conv2d
|
| 137 |
+
@wraps(torch.nn.functional.conv2d)
|
| 138 |
+
def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
| 139 |
+
if input.dtype != weight.data.dtype:
|
| 140 |
+
input = input.to(dtype=weight.data.dtype)
|
| 141 |
+
if bias is not None and bias.data.dtype != weight.data.dtype:
|
| 142 |
+
bias.data = bias.data.to(dtype=weight.data.dtype)
|
| 143 |
+
return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
| 144 |
+
|
| 145 |
+
# A1111 Embedding BF16
|
| 146 |
+
original_torch_cat = torch.cat
|
| 147 |
+
@wraps(torch.cat)
|
| 148 |
+
def torch_cat(tensor, *args, **kwargs):
|
| 149 |
+
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
|
| 150 |
+
return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
|
| 151 |
+
else:
|
| 152 |
+
return original_torch_cat(tensor, *args, **kwargs)
|
| 153 |
+
|
| 154 |
+
# SwinIR BF16:
|
| 155 |
+
original_functional_pad = torch.nn.functional.pad
|
| 156 |
+
@wraps(torch.nn.functional.pad)
|
| 157 |
+
def functional_pad(input, pad, mode='constant', value=None):
|
| 158 |
+
if mode == 'reflect' and input.dtype == torch.bfloat16:
|
| 159 |
+
return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16)
|
| 160 |
+
else:
|
| 161 |
+
return original_functional_pad(input, pad, mode=mode, value=value)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
original_torch_tensor = torch.tensor
|
| 165 |
+
@wraps(torch.tensor)
|
| 166 |
+
def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
|
| 167 |
+
if check_device(device):
|
| 168 |
+
device = return_xpu(device)
|
| 169 |
+
if not device_supports_fp64:
|
| 170 |
+
if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device):
|
| 171 |
+
if dtype == torch.float64:
|
| 172 |
+
dtype = torch.float32
|
| 173 |
+
elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)):
|
| 174 |
+
dtype = torch.float32
|
| 175 |
+
return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs)
|
| 176 |
+
|
| 177 |
+
original_Tensor_to = torch.Tensor.to
|
| 178 |
+
@wraps(torch.Tensor.to)
|
| 179 |
+
def Tensor_to(self, device=None, *args, **kwargs):
|
| 180 |
+
if check_device(device):
|
| 181 |
+
return original_Tensor_to(self, return_xpu(device), *args, **kwargs)
|
| 182 |
+
else:
|
| 183 |
+
return original_Tensor_to(self, device, *args, **kwargs)
|
| 184 |
+
|
| 185 |
+
original_Tensor_cuda = torch.Tensor.cuda
|
| 186 |
+
@wraps(torch.Tensor.cuda)
|
| 187 |
+
def Tensor_cuda(self, device=None, *args, **kwargs):
|
| 188 |
+
if check_device(device):
|
| 189 |
+
return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs)
|
| 190 |
+
else:
|
| 191 |
+
return original_Tensor_cuda(self, device, *args, **kwargs)
|
| 192 |
+
|
| 193 |
+
original_Tensor_pin_memory = torch.Tensor.pin_memory
|
| 194 |
+
@wraps(torch.Tensor.pin_memory)
|
| 195 |
+
def Tensor_pin_memory(self, device=None, *args, **kwargs):
|
| 196 |
+
if device is None:
|
| 197 |
+
device = "xpu"
|
| 198 |
+
if check_device(device):
|
| 199 |
+
return original_Tensor_pin_memory(self, return_xpu(device), *args, **kwargs)
|
| 200 |
+
else:
|
| 201 |
+
return original_Tensor_pin_memory(self, device, *args, **kwargs)
|
| 202 |
+
|
| 203 |
+
original_UntypedStorage_init = torch.UntypedStorage.__init__
|
| 204 |
+
@wraps(torch.UntypedStorage.__init__)
|
| 205 |
+
def UntypedStorage_init(*args, device=None, **kwargs):
|
| 206 |
+
if check_device(device):
|
| 207 |
+
return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs)
|
| 208 |
+
else:
|
| 209 |
+
return original_UntypedStorage_init(*args, device=device, **kwargs)
|
| 210 |
+
|
| 211 |
+
original_UntypedStorage_cuda = torch.UntypedStorage.cuda
|
| 212 |
+
@wraps(torch.UntypedStorage.cuda)
|
| 213 |
+
def UntypedStorage_cuda(self, device=None, *args, **kwargs):
|
| 214 |
+
if check_device(device):
|
| 215 |
+
return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs)
|
| 216 |
+
else:
|
| 217 |
+
return original_UntypedStorage_cuda(self, device, *args, **kwargs)
|
| 218 |
+
|
| 219 |
+
original_torch_empty = torch.empty
|
| 220 |
+
@wraps(torch.empty)
|
| 221 |
+
def torch_empty(*args, device=None, **kwargs):
|
| 222 |
+
if check_device(device):
|
| 223 |
+
return original_torch_empty(*args, device=return_xpu(device), **kwargs)
|
| 224 |
+
else:
|
| 225 |
+
return original_torch_empty(*args, device=device, **kwargs)
|
| 226 |
+
|
| 227 |
+
original_torch_randn = torch.randn
|
| 228 |
+
@wraps(torch.randn)
|
| 229 |
+
def torch_randn(*args, device=None, dtype=None, **kwargs):
|
| 230 |
+
if dtype == bytes:
|
| 231 |
+
dtype = None
|
| 232 |
+
if check_device(device):
|
| 233 |
+
return original_torch_randn(*args, device=return_xpu(device), **kwargs)
|
| 234 |
+
else:
|
| 235 |
+
return original_torch_randn(*args, device=device, **kwargs)
|
| 236 |
+
|
| 237 |
+
original_torch_ones = torch.ones
|
| 238 |
+
@wraps(torch.ones)
|
| 239 |
+
def torch_ones(*args, device=None, **kwargs):
|
| 240 |
+
if check_device(device):
|
| 241 |
+
return original_torch_ones(*args, device=return_xpu(device), **kwargs)
|
| 242 |
+
else:
|
| 243 |
+
return original_torch_ones(*args, device=device, **kwargs)
|
| 244 |
+
|
| 245 |
+
original_torch_zeros = torch.zeros
|
| 246 |
+
@wraps(torch.zeros)
|
| 247 |
+
def torch_zeros(*args, device=None, **kwargs):
|
| 248 |
+
if check_device(device):
|
| 249 |
+
return original_torch_zeros(*args, device=return_xpu(device), **kwargs)
|
| 250 |
+
else:
|
| 251 |
+
return original_torch_zeros(*args, device=device, **kwargs)
|
| 252 |
+
|
| 253 |
+
original_torch_linspace = torch.linspace
|
| 254 |
+
@wraps(torch.linspace)
|
| 255 |
+
def torch_linspace(*args, device=None, **kwargs):
|
| 256 |
+
if check_device(device):
|
| 257 |
+
return original_torch_linspace(*args, device=return_xpu(device), **kwargs)
|
| 258 |
+
else:
|
| 259 |
+
return original_torch_linspace(*args, device=device, **kwargs)
|
| 260 |
+
|
| 261 |
+
original_torch_Generator = torch.Generator
|
| 262 |
+
@wraps(torch.Generator)
|
| 263 |
+
def torch_Generator(device=None):
|
| 264 |
+
if check_device(device):
|
| 265 |
+
return original_torch_Generator(return_xpu(device))
|
| 266 |
+
else:
|
| 267 |
+
return original_torch_Generator(device)
|
| 268 |
+
|
| 269 |
+
original_torch_load = torch.load
|
| 270 |
+
@wraps(torch.load)
|
| 271 |
+
def torch_load(f, map_location=None, *args, **kwargs):
|
| 272 |
+
if map_location is None:
|
| 273 |
+
map_location = "xpu"
|
| 274 |
+
if check_device(map_location):
|
| 275 |
+
return original_torch_load(f, *args, map_location=return_xpu(map_location), **kwargs)
|
| 276 |
+
else:
|
| 277 |
+
return original_torch_load(f, *args, map_location=map_location, **kwargs)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
# Hijack Functions:
|
| 281 |
+
def ipex_hijacks():
|
| 282 |
+
torch.tensor = torch_tensor
|
| 283 |
+
torch.Tensor.to = Tensor_to
|
| 284 |
+
torch.Tensor.cuda = Tensor_cuda
|
| 285 |
+
torch.Tensor.pin_memory = Tensor_pin_memory
|
| 286 |
+
torch.UntypedStorage.__init__ = UntypedStorage_init
|
| 287 |
+
torch.UntypedStorage.cuda = UntypedStorage_cuda
|
| 288 |
+
torch.empty = torch_empty
|
| 289 |
+
torch.randn = torch_randn
|
| 290 |
+
torch.ones = torch_ones
|
| 291 |
+
torch.zeros = torch_zeros
|
| 292 |
+
torch.linspace = torch_linspace
|
| 293 |
+
torch.Generator = torch_Generator
|
| 294 |
+
torch.load = torch_load
|
| 295 |
+
|
| 296 |
+
torch.backends.cuda.sdp_kernel = return_null_context
|
| 297 |
+
torch.nn.DataParallel = DummyDataParallel
|
| 298 |
+
torch.UntypedStorage.is_cuda = is_cuda
|
| 299 |
+
torch.amp.autocast_mode.autocast.__init__ = autocast_init
|
| 300 |
+
|
| 301 |
+
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
|
| 302 |
+
torch.nn.functional.group_norm = functional_group_norm
|
| 303 |
+
torch.nn.functional.layer_norm = functional_layer_norm
|
| 304 |
+
torch.nn.functional.linear = functional_linear
|
| 305 |
+
torch.nn.functional.conv2d = functional_conv2d
|
| 306 |
+
torch.nn.functional.interpolate = interpolate
|
| 307 |
+
torch.nn.functional.pad = functional_pad
|
| 308 |
+
|
| 309 |
+
torch.bmm = torch_bmm
|
| 310 |
+
torch.cat = torch_cat
|
| 311 |
+
if not device_supports_fp64:
|
| 312 |
+
torch.from_numpy = from_numpy
|
| 313 |
+
torch.as_tensor = as_tensor
|
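hijacks.py is the catch-all layer: the tensor factories (empty, randn, ones, zeros, linspace), .to()/.cuda(), storage and generator constructors, and torch.load are rerouted from CUDA device strings to the XPU, and a handful of dtype mismatches (group_norm, layer_norm, linear, conv2d, cat, pad) are patched around. A sketch of the visible effect after applying it, assuming an IPEX/XPU install:

import torch
from library.ipex.hijacks import ipex_hijacks

ipex_hijacks()  # install the wrappers defined above

x = torch.ones(4, device="cuda")   # intercepted by torch_ones -> allocated on the XPU
g = torch.Generator("cuda:0")      # intercepted by torch_Generator -> "xpu:0"
y = x.to("cuda:0")                 # intercepted by Tensor_to
print(x.device, g.device, y.device)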
library/lpw_stable_diffusion.py
ADDED
|
@@ -0,0 +1,1233 @@
# copy from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py
# and modify to support SD2.x

import inspect
import re
from typing import Callable, List, Optional, Union

import numpy as np
import PIL.Image
import torch
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

import diffusers
from diffusers import SchedulerMixin, StableDiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.utils import logging

try:
    from diffusers.utils import PIL_INTERPOLATION
except ImportError:
    if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
        PIL_INTERPOLATION = {
            "linear": PIL.Image.Resampling.BILINEAR,
            "bilinear": PIL.Image.Resampling.BILINEAR,
            "bicubic": PIL.Image.Resampling.BICUBIC,
            "lanczos": PIL.Image.Resampling.LANCZOS,
            "nearest": PIL.Image.Resampling.NEAREST,
        }
    else:
        PIL_INTERPOLATION = {
            "linear": PIL.Image.LINEAR,
            "bilinear": PIL.Image.BILINEAR,
            "bicubic": PIL.Image.BICUBIC,
            "lanczos": PIL.Image.LANCZOS,
            "nearest": PIL.Image.NEAREST,
        }
# ------------------------------------------------------------------------------

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

re_attention = re.compile(
    r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
    re.X,
)


def parse_prompt_attention(text):
    """
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text
    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """

    res = []
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        text = m.group(0)
        weight = m.group(1)

        if text.startswith("\\"):
            res.append([text[1:], 1.0])
        elif text == "(":
            round_brackets.append(len(res))
        elif text == "[":
            square_brackets.append(len(res))
        elif weight is not None and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), float(weight))
        elif text == ")" and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif text == "]" and len(square_brackets) > 0:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            res.append([text, 1.0])

    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if len(res) == 0:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res


def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
    r"""
    Tokenize a list of prompts and return its tokens with weights of each token.

    No padding, starting or ending token is included.
    """
    tokens = []
    weights = []
    truncated = False
    for text in prompt:
        texts_and_weights = parse_prompt_attention(text)
        text_token = []
        text_weight = []
        for word, weight in texts_and_weights:
            # tokenize and discard the starting and the ending token
            token = pipe.tokenizer(word).input_ids[1:-1]
            text_token += token
            # copy the weight by length of token
            text_weight += [weight] * len(token)
            # stop if the text is too long (longer than truncation limit)
            if len(text_token) > max_length:
                truncated = True
                break
        # truncate
        if len(text_token) > max_length:
            truncated = True
            text_token = text_token[:max_length]
            text_weight = text_weight[:max_length]
        tokens.append(text_token)
        weights.append(text_weight)
    if truncated:
        logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
    return tokens, weights
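
The helper above only needs an object exposing a `tokenizer` attribute, so it can be exercised on its own. A minimal sketch (editor's illustration, not part of this file), assuming the openai/clip-vit-large-patch14 tokenizer can be downloaded:

from types import SimpleNamespace

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
fake_pipe = SimpleNamespace(tokenizer=tokenizer)  # stands in for a full StableDiffusionPipeline

tokens, weights = get_prompts_with_weights(fake_pipe, ["a (red:1.3) apple"], max_length=75)
print(tokens[0])   # token ids for "a red apple", without BOS/EOS
print(weights[0])  # 1.0 for the "a " tokens, 1.3 for "red", 1.0 for " apple"
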


def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
    r"""
    Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
    """
    max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
    weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
    for i in range(len(tokens)):
        tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
        if no_boseos_middle:
            weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
        else:
            w = []
            if len(weights[i]) == 0:
                w = [1.0] * weights_length
            else:
                for j in range(max_embeddings_multiples):
                    w.append(1.0)  # weight for starting token in this chunk
                    w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
                    w.append(1.0)  # weight for ending token in this chunk
                w += [1.0] * (weights_length - len(w))
            weights[i] = w[:]

    return tokens, weights
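
For orientation, the chunking arithmetic used here and in get_weighted_text_embeddings further below, sketched with the usual CLIP defaults of a 77-token window, i.e. 75 content tokens per chunk plus BOS/EOS (editor's illustration only):

model_max_length = 77                      # CLIP tokenizer default
max_embeddings_multiples = 3
max_length = (model_max_length - 2) * max_embeddings_multiples + 2
print(max_length)                          # 227 slots: 3 x 75 content tokens plus one BOS and one EOS

prompt_tokens = 100                        # e.g. a long weighted prompt
chunks = (prompt_tokens - 1) // (model_max_length - 2) + 1
print(chunks)                              # 2 chunks of up to 75 content tokens each
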


def get_unweighted_text_embeddings(
    pipe: StableDiffusionPipeline,
    text_input: torch.Tensor,
    chunk_length: int,
    clip_skip: int,
    eos: int,
    pad: int,
    no_boseos_middle: Optional[bool] = True,
):
    """
    When the length of tokens is a multiple of the capacity of the text encoder,
    it should be split into chunks and sent to the text encoder individually.
    """
    max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
    if max_embeddings_multiples > 1:
        text_embeddings = []
        for i in range(max_embeddings_multiples):
            # extract the i-th chunk
            text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()

            # cover the head and the tail by the starting and the ending tokens
            text_input_chunk[:, 0] = text_input[0, 0]
            if pad == eos:  # v1
                text_input_chunk[:, -1] = text_input[0, -1]
            else:  # v2
                for j in range(len(text_input_chunk)):
                    if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad:  # the chunk ends with an ordinary token
                        text_input_chunk[j, -1] = eos
                    if text_input_chunk[j, 1] == pad:  # only BOS, the rest is PAD
                        text_input_chunk[j, 1] = eos

            if clip_skip is None or clip_skip == 1:
                text_embedding = pipe.text_encoder(text_input_chunk)[0]
            else:
                enc_out = pipe.text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
                text_embedding = enc_out["hidden_states"][-clip_skip]
                text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding)

            if no_boseos_middle:
                if i == 0:
                    # discard the ending token
                    text_embedding = text_embedding[:, :-1]
                elif i == max_embeddings_multiples - 1:
                    # discard the starting token
                    text_embedding = text_embedding[:, 1:]
                else:
                    # discard both starting and ending tokens
                    text_embedding = text_embedding[:, 1:-1]

            text_embeddings.append(text_embedding)
        text_embeddings = torch.concat(text_embeddings, axis=1)
    else:
        if clip_skip is None or clip_skip == 1:
            text_embeddings = pipe.text_encoder(text_input)[0]
        else:
            enc_out = pipe.text_encoder(text_input, output_hidden_states=True, return_dict=True)
            text_embeddings = enc_out["hidden_states"][-clip_skip]
            text_embeddings = pipe.text_encoder.text_model.final_layer_norm(text_embeddings)
    return text_embeddings


def get_weighted_text_embeddings(
    pipe: StableDiffusionPipeline,
    prompt: Union[str, List[str]],
    uncond_prompt: Optional[Union[str, List[str]]] = None,
    max_embeddings_multiples: Optional[int] = 3,
    no_boseos_middle: Optional[bool] = False,
    skip_parsing: Optional[bool] = False,
    skip_weighting: Optional[bool] = False,
    clip_skip=None,
):
    r"""
    Prompts can be assigned with local weights using brackets. For example,
    prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
    and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.

    Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.

    Args:
        pipe (`StableDiffusionPipeline`):
            Pipe to provide access to the tokenizer and the text encoder.
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation.
        uncond_prompt (`str` or `List[str]`):
            The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
            is provided, the embeddings of prompt and uncond_prompt are concatenated.
        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
            The max multiple length of prompt embeddings compared to the max output length of text encoder.
        no_boseos_middle (`bool`, *optional*, defaults to `False`):
            If the length of the text tokens is a multiple of the capacity of the text encoder, whether to keep the
            starting and ending tokens in each of the middle chunks.
        skip_parsing (`bool`, *optional*, defaults to `False`):
            Skip the parsing of brackets.
        skip_weighting (`bool`, *optional*, defaults to `False`):
            Skip the weighting. When the parsing is skipped, it is forced True.
    """
    max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
    if isinstance(prompt, str):
        prompt = [prompt]

    if not skip_parsing:
        prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
        if uncond_prompt is not None:
            if isinstance(uncond_prompt, str):
                uncond_prompt = [uncond_prompt]
            uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
    else:
        prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
        prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
        if uncond_prompt is not None:
            if isinstance(uncond_prompt, str):
                uncond_prompt = [uncond_prompt]
            uncond_tokens = [
                token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
            ]
            uncond_weights = [[1.0] * len(token) for token in uncond_tokens]

    # round up the longest length of tokens to a multiple of (model_max_length - 2)
    max_length = max([len(token) for token in prompt_tokens])
    if uncond_prompt is not None:
        max_length = max(max_length, max([len(token) for token in uncond_tokens]))

    max_embeddings_multiples = min(
        max_embeddings_multiples,
        (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
    )
    max_embeddings_multiples = max(1, max_embeddings_multiples)
    max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2

    # pad the length of tokens and weights
    bos = pipe.tokenizer.bos_token_id
    eos = pipe.tokenizer.eos_token_id
    pad = pipe.tokenizer.pad_token_id
    prompt_tokens, prompt_weights = pad_tokens_and_weights(
        prompt_tokens,
        prompt_weights,
        max_length,
        bos,
        eos,
        no_boseos_middle=no_boseos_middle,
        chunk_length=pipe.tokenizer.model_max_length,
    )
    prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
    if uncond_prompt is not None:
        uncond_tokens, uncond_weights = pad_tokens_and_weights(
            uncond_tokens,
            uncond_weights,
            max_length,
            bos,
            eos,
            no_boseos_middle=no_boseos_middle,
            chunk_length=pipe.tokenizer.model_max_length,
        )
        uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)

    # get the embeddings
    text_embeddings = get_unweighted_text_embeddings(
        pipe,
        prompt_tokens,
        pipe.tokenizer.model_max_length,
        clip_skip,
        eos,
        pad,
        no_boseos_middle=no_boseos_middle,
    )
    prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
    if uncond_prompt is not None:
        uncond_embeddings = get_unweighted_text_embeddings(
            pipe,
            uncond_tokens,
            pipe.tokenizer.model_max_length,
            clip_skip,
            eos,
            pad,
            no_boseos_middle=no_boseos_middle,
        )
        uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)

    # assign weights to the prompts and normalize in the sense of mean
    # TODO: should we normalize by chunk or in a whole (current implementation)?
    if (not skip_parsing) and (not skip_weighting):
        previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
        text_embeddings *= prompt_weights.unsqueeze(-1)
        current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
        text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
        if uncond_prompt is not None:
            previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
            uncond_embeddings *= uncond_weights.unsqueeze(-1)
            current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
            uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)

    if uncond_prompt is not None:
        return text_embeddings, uncond_embeddings
    return text_embeddings, None
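
A minimal usage sketch for the function above (editor's illustration, not part of the file), assuming a standard SD 1.x checkpoint such as runwayml/stable-diffusion-v1-5 is available and a CUDA device is present:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

cond, uncond = get_weighted_text_embeddings(
    pipe,
    prompt="a photo of a (red:1.3) apple on a table",
    uncond_prompt="blurry, low quality",
    max_embeddings_multiples=3,
)
print(cond.shape, uncond.shape)  # e.g. torch.Size([1, 77, 768]) for prompts that fit in one chunk
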


def preprocess_image(image):
    w, h = image.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0


def preprocess_mask(mask, scale_factor=8):
    mask = mask.convert("L")
    w, h = mask.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
    mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
    mask = np.array(mask).astype(np.float32) / 255.0
    mask = np.tile(mask, (4, 1, 1))
    mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
    mask = 1 - mask  # repaint white, keep black
    mask = torch.from_numpy(mask)
    return mask
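
A small shape check for the two helpers above (editor's illustration only), showing what they return for a 512x512 input:

import PIL.Image

img = PIL.Image.new("RGB", (512, 512))
msk = PIL.Image.new("L", (512, 512), color=255)

print(preprocess_image(img).shape)  # torch.Size([1, 3, 512, 512]), values scaled to [-1, 1]
print(preprocess_mask(msk).shape)   # torch.Size([1, 4, 64, 64]) with the default scale_factor=8
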


def prepare_controlnet_image(
    image: PIL.Image.Image,
    width: int,
    height: int,
    batch_size: int,
    num_images_per_prompt: int,
    device: torch.device,
    dtype: torch.dtype,
    do_classifier_free_guidance: bool = False,
    guess_mode: bool = False,
):
    if not isinstance(image, torch.Tensor):
        if isinstance(image, PIL.Image.Image):
            image = [image]

        if isinstance(image[0], PIL.Image.Image):
            images = []

            for image_ in image:
                image_ = image_.convert("RGB")
                image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
                image_ = np.array(image_)
                image_ = image_[None, :]
                images.append(image_)

            image = images

            image = np.concatenate(image, axis=0)
            image = np.array(image).astype(np.float32) / 255.0
            image = image.transpose(0, 3, 1, 2)
            image = torch.from_numpy(image)
        elif isinstance(image[0], torch.Tensor):
            image = torch.cat(image, dim=0)

    image_batch_size = image.shape[0]

    if image_batch_size == 1:
        repeat_by = batch_size
    else:
        # image batch size is the same as prompt batch size
        repeat_by = num_images_per_prompt

    image = image.repeat_interleave(repeat_by, dim=0)

    image = image.to(device=device, dtype=dtype)

    if do_classifier_free_guidance and not guess_mode:
        image = torch.cat([image] * 2)

    return image


class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
    weighting in prompt.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    # if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: SchedulerMixin,
        # clip_skip: int,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
        requires_safety_checker: bool = True,
        image_encoder: CLIPVisionModelWithProjection = None,
        clip_skip: int = 1,
    ):
        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            requires_safety_checker=requires_safety_checker,
            image_encoder=image_encoder,
        )
        self.custom_clip_skip = clip_skip
        self.__init__additional__()

    def __init__additional__(self):
        if not hasattr(self, "vae_scale_factor"):
            setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))

    @property
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        max_embeddings_multiples,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `list(int)`):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
                The max multiple length of prompt embeddings compared to the max output length of text encoder.
        """
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        if negative_prompt is None:
            negative_prompt = [""] * batch_size
        elif isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt] * batch_size
        if batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )

        text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
            pipe=self,
            prompt=prompt,
            uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
            max_embeddings_multiples=max_embeddings_multiples,
            clip_skip=self.custom_clip_skip,
        )
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        if do_classifier_free_guidance:
            bs_embed, seq_len, _ = uncond_embeddings.shape
            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        return text_embeddings

    def check_inputs(self, prompt, height, width, strength, callback_steps):
        if not isinstance(prompt, str) and not isinstance(prompt, list):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if strength < 0 or strength > 1:
            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

        if height % 8 != 0 or width % 8 != 0:
            logger.info(f'{height} {width}')
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}."
            )

    def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
        if is_text2img:
            return self.scheduler.timesteps.to(device), num_inference_steps
        else:
            # get the original timestep using init_timestep
            offset = self.scheduler.config.get("steps_offset", 0)
            init_timestep = int(num_inference_steps * strength) + offset
            init_timestep = min(init_timestep, num_inference_steps)

            t_start = max(num_inference_steps - init_timestep + offset, 0)
            timesteps = self.scheduler.timesteps[t_start:].to(device)
            return timesteps, num_inference_steps - t_start
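
The img2img branch of get_timesteps above keeps only the tail of the noise schedule. A worked example of that arithmetic (editor's illustration, not part of the file):

num_inference_steps = 50
strength = 0.8
offset = 0  # scheduler steps_offset, often 0 or 1

init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)
t_start = max(num_inference_steps - init_timestep + offset, 0)
print(t_start, num_inference_steps - t_start)  # 10 40 -> only the last 40 of 50 timesteps are run
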

    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype))
        else:
            has_nsfw_concept = None
        return image, has_nsfw_concept

    def decode_latents(self, latents):
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None):
        if image is None:
            shape = (
                batch_size,
                self.unet.in_channels,
                height // self.vae_scale_factor,
                width // self.vae_scale_factor,
            )

            if latents is None:
                if device.type == "mps":
                    # randn does not work reproducibly on mps
                    latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
                else:
                    latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
            else:
                if latents.shape != shape:
                    raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
                latents = latents.to(device)

            # scale the initial noise by the standard deviation required by the scheduler
            latents = latents * self.scheduler.init_noise_sigma
            return latents, None, None
        else:
            init_latent_dist = self.vae.encode(image).latent_dist
            init_latents = init_latent_dist.sample(generator=generator)
            init_latents = 0.18215 * init_latents
            init_latents = torch.cat([init_latents] * batch_size, dim=0)
            init_latents_orig = init_latents
            shape = init_latents.shape

            # add noise to latents using the timesteps
            if device.type == "mps":
                noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
            else:
                noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
            latents = self.scheduler.add_noise(init_latents, noise, timestep)
            return latents, init_latents_orig, noise
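
For reference, the latent shape produced by the text-to-image branch above, assuming the usual SD 1.x configuration of 4 latent channels and a VAE scale factor of 8 (editor's illustration only):

batch_size, in_channels, vae_scale_factor = 1, 4, 8
height = width = 512
shape = (batch_size, in_channels, height // vae_scale_factor, width // vae_scale_factor)
print(shape)  # (1, 4, 64, 64); the noise is then scaled by scheduler.init_noise_sigma before the first step
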

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
        mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        strength: float = 0.8,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        max_embeddings_multiples: Optional[int] = 3,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        controlnet=None,
        controlnet_image=None,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        is_cancelled_callback: Optional[Callable[[], bool]] = None,
        callback_steps: int = 1,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            image (`torch.FloatTensor` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, that will be used as the starting point for the
                process.
            mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
                replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
                PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
                contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            strength (`float`, *optional*, defaults to 0.8):
                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
                `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
                number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
                noise will be maximum and the denoising process will run for the full number of iterations specified in
                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            max_embeddings_multiples (`int`, *optional*, defaults to `3`):
                The max multiple length of prompt embeddings compared to the max output length of text encoder.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            controlnet (`diffusers.ControlNetModel`, *optional*):
                A controlnet model to be used for the inference. If not provided, controlnet will be disabled.
            controlnet_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
                `Image`, or tensor representing an image batch, to be used as the starting point for the controlnet
                inference.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            is_cancelled_callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. If the function returns
                `True`, the inference will be cancelled.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.

        Returns:
            `None` if cancelled by `is_cancelled_callback`,
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        if controlnet is not None and controlnet_image is None:
            raise ValueError("controlnet_image must be provided if controlnet is not None.")

        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, strength, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_embeddings = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            max_embeddings_multiples,
        )
        dtype = text_embeddings.dtype

        # 4. Preprocess image and mask
        if isinstance(image, PIL.Image.Image):
            image = preprocess_image(image)
        if image is not None:
            image = image.to(device=self.device, dtype=dtype)
        if isinstance(mask_image, PIL.Image.Image):
            mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
        if mask_image is not None:
            mask = mask_image.to(device=self.device, dtype=dtype)
            mask = torch.cat([mask] * batch_size * num_images_per_prompt)
        else:
            mask = None

        if controlnet_image is not None:
            controlnet_image = prepare_controlnet_image(
                controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False
            )

        # 5. set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        # 6. Prepare latent variables
        latents, init_latents_orig, noise = self.prepare_latents(
            image,
            latent_timestep,
            batch_size * num_images_per_prompt,
            height,
            width,
            dtype,
            device,
            generator,
            latents,
        )

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 8. Denoising loop
        for i, t in enumerate(self.progress_bar(timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            unet_additional_args = {}
            if controlnet is not None:
                down_block_res_samples, mid_block_res_sample = controlnet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=text_embeddings,
                    controlnet_cond=controlnet_image,
                    conditioning_scale=1.0,
                    guess_mode=False,
                    return_dict=False,
                )
                unet_additional_args["down_block_additional_residuals"] = down_block_res_samples
                unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, **unet_additional_args).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            if mask is not None:
                # masking
                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
                latents = (init_latents_proper * mask) + (latents * (1 - mask))

            # call the callback, if provided
            if i % callback_steps == 0:
                if callback is not None:
                    callback(i, t, latents)
                if is_cancelled_callback is not None and is_cancelled_callback():
                    return None

        return latents

    def latents_to_image(self, latents):
        # 9. Post-processing
        image = self.decode_latents(latents.to(self.vae.dtype))
        image = self.numpy_to_pil(image)
        return image
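
Putting the pieces together, a hedged end-to-end sketch (editor's illustration, not part of the file): build the pipeline from a stock SD 1.x checkpoint, generate latents with text2img below (note that __call__ above returns latents, not images), and decode them with latents_to_image. The model id, the CUDA device, and passing image_encoder through the constructor are assumptions that depend on the installed diffusers version.

import torch
from diffusers import StableDiffusionPipeline

base = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe = StableDiffusionLongPromptWeightingPipeline(
    vae=base.vae,
    text_encoder=base.text_encoder,
    tokenizer=base.tokenizer,
    unet=base.unet,
    scheduler=base.scheduler,
    safety_checker=None,                 # skipping the safety checker for the sketch
    feature_extractor=base.feature_extractor,
    requires_safety_checker=False,
    clip_skip=1,
).to("cuda")

latents = pipe.text2img("a (highly detailed:1.2) watercolor of a lighthouse", width=512, height=512)
images = pipe.latents_to_image(latents)
images[0].save("lighthouse.png")
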
|
| 945 |
+
|
| 946 |
+
def text2img(
|
| 947 |
+
self,
|
| 948 |
+
prompt: Union[str, List[str]],
|
| 949 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 950 |
+
height: int = 512,
|
| 951 |
+
width: int = 512,
|
| 952 |
+
num_inference_steps: int = 50,
|
| 953 |
+
guidance_scale: float = 7.5,
|
| 954 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 955 |
+
eta: float = 0.0,
|
| 956 |
+
generator: Optional[torch.Generator] = None,
|
| 957 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 958 |
+
max_embeddings_multiples: Optional[int] = 3,
|
| 959 |
+
output_type: Optional[str] = "pil",
|
| 960 |
+
return_dict: bool = True,
|
| 961 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
| 962 |
+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
| 963 |
+
callback_steps: int = 1,
|
| 964 |
+
):
|
| 965 |
+
r"""
|
| 966 |
+
Function for text-to-image generation.
|
| 967 |
+
Args:
|
| 968 |
+
prompt (`str` or `List[str]`):
|
| 969 |
+
The prompt or prompts to guide the image generation.
|
| 970 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 971 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
| 972 |
+
if `guidance_scale` is less than `1`).
|
| 973 |
+
height (`int`, *optional*, defaults to 512):
|
| 974 |
+
The height in pixels of the generated image.
|
| 975 |
+
width (`int`, *optional*, defaults to 512):
|
| 976 |
+
The width in pixels of the generated image.
|
| 977 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 978 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 979 |
+
expense of slower inference.
|
| 980 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
| 981 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
| 982 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
| 983 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
| 984 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
| 985 |
+
usually at the expense of lower image quality.
|
| 986 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 987 |
+
The number of images to generate per prompt.
|
| 988 |
+
eta (`float`, *optional*, defaults to 0.0):
|
| 989 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
| 990 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
| 991 |
+
generator (`torch.Generator`, *optional*):
|
| 992 |
+
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
| 993 |
+
deterministic.
|
| 994 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 995 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 996 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 997 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
| 998 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
| 999 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
| 1000 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 1001 |
+
The output format of the generate image. Choose between
|
| 1002 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 1003 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 1004 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
| 1005 |
+
plain tuple.
|
| 1006 |
+
callback (`Callable`, *optional*):
|
| 1007 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
| 1008 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
| 1009 |
+
is_cancelled_callback (`Callable`, *optional*):
|
| 1010 |
+
A function that will be called every `callback_steps` steps during inference. If the function returns
|
| 1011 |
+
`True`, the inference will be cancelled.
|
| 1012 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
| 1013 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
| 1014 |
+
called at every step.
|
| 1015 |
+
Returns:
|
| 1016 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
| 1017 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
| 1018 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
| 1019 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
| 1020 |
+
(nsfw) content, according to the `safety_checker`.
|
| 1021 |
+
"""
|
| 1022 |
+
return self.__call__(
|
| 1023 |
+
prompt=prompt,
|
| 1024 |
+
negative_prompt=negative_prompt,
|
| 1025 |
+
height=height,
|
| 1026 |
+
width=width,
|
| 1027 |
+
num_inference_steps=num_inference_steps,
|
| 1028 |
+
guidance_scale=guidance_scale,
|
| 1029 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 1030 |
+
eta=eta,
|
| 1031 |
+
generator=generator,
|
| 1032 |
+
latents=latents,
|
| 1033 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
| 1034 |
+
output_type=output_type,
|
| 1035 |
+
return_dict=return_dict,
|
| 1036 |
+
callback=callback,
|
| 1037 |
+
is_cancelled_callback=is_cancelled_callback,
|
| 1038 |
+
callback_steps=callback_steps,
|
| 1039 |
+
)
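For orientation, a minimal usage sketch of the txt2img entry point above. The loading call (`DiffusionPipeline.from_pretrained` with `custom_pipeline`), the model ID, and the device are assumptions for illustration only and are not taken from this file:

import torch
from diffusers import DiffusionPipeline

# Assumed: the long-prompt-weighting pipeline is loaded as a community/custom pipeline;
# the checkpoint ID, prompt, and output path are placeholders.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="lpw_stable_diffusion",
    torch_dtype=torch.float16,
).to("cuda")
image = pipe.txt2img(
    "a watercolor painting of a lighthouse at dawn",
    negative_prompt="low quality, blurry",
    num_inference_steps=50,
    guidance_scale=7.5,
).images[0]
image.save("lighthouse.png")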
|
| 1040 |
+
|
| 1041 |
+
def img2img(
|
| 1042 |
+
self,
|
| 1043 |
+
image: Union[torch.FloatTensor, PIL.Image.Image],
|
| 1044 |
+
prompt: Union[str, List[str]],
|
| 1045 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 1046 |
+
strength: float = 0.8,
|
| 1047 |
+
num_inference_steps: Optional[int] = 50,
|
| 1048 |
+
guidance_scale: Optional[float] = 7.5,
|
| 1049 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 1050 |
+
eta: Optional[float] = 0.0,
|
| 1051 |
+
generator: Optional[torch.Generator] = None,
|
| 1052 |
+
max_embeddings_multiples: Optional[int] = 3,
|
| 1053 |
+
output_type: Optional[str] = "pil",
|
| 1054 |
+
return_dict: bool = True,
|
| 1055 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
| 1056 |
+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
| 1057 |
+
callback_steps: int = 1,
|
| 1058 |
+
):
|
| 1059 |
+
r"""
|
| 1060 |
+
Function for image-to-image generation.
|
| 1061 |
+
Args:
|
| 1062 |
+
image (`torch.FloatTensor` or `PIL.Image.Image`):
|
| 1063 |
+
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
| 1064 |
+
process.
|
| 1065 |
+
prompt (`str` or `List[str]`):
|
| 1066 |
+
The prompt or prompts to guide the image generation.
|
| 1067 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 1068 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
| 1069 |
+
if `guidance_scale` is less than `1`).
|
| 1070 |
+
strength (`float`, *optional*, defaults to 0.8):
|
| 1071 |
+
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
|
| 1072 |
+
`image` will be used as a starting point, adding more noise to it the larger the `strength`. The
|
| 1073 |
+
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
|
| 1074 |
+
noise will be maximum and the denoising process will run for the full number of iterations specified in
|
| 1075 |
+
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
|
| 1076 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 1077 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 1078 |
+
expense of slower inference. This parameter will be modulated by `strength`.
|
| 1079 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
| 1080 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
| 1081 |
+
`guidance_scale` is defined as `w` in equation 2 of the [Imagen
|
| 1082 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
| 1083 |
+
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
|
| 1084 |
+
usually at the expense of lower image quality.
|
| 1085 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 1086 |
+
The number of images to generate per prompt.
|
| 1087 |
+
eta (`float`, *optional*, defaults to 0.0):
|
| 1088 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
| 1089 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
| 1090 |
+
generator (`torch.Generator`, *optional*):
|
| 1091 |
+
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
| 1092 |
+
deterministic.
|
| 1093 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
| 1094 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
| 1095 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 1096 |
+
The output format of the generated image. Choose between
|
| 1097 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 1098 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 1099 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
| 1100 |
+
plain tuple.
|
| 1101 |
+
callback (`Callable`, *optional*):
|
| 1102 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
| 1103 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
| 1104 |
+
is_cancelled_callback (`Callable`, *optional*):
|
| 1105 |
+
A function that will be called every `callback_steps` steps during inference. If the function returns
|
| 1106 |
+
`True`, the inference will be cancelled.
|
| 1107 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
| 1108 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
| 1109 |
+
called at every step.
|
| 1110 |
+
Returns:
|
| 1111 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
| 1112 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
|
| 1113 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
| 1114 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
| 1115 |
+
(nsfw) content, according to the `safety_checker`.
|
| 1116 |
+
"""
|
| 1117 |
+
return self.__call__(
|
| 1118 |
+
prompt=prompt,
|
| 1119 |
+
negative_prompt=negative_prompt,
|
| 1120 |
+
image=image,
|
| 1121 |
+
num_inference_steps=num_inference_steps,
|
| 1122 |
+
guidance_scale=guidance_scale,
|
| 1123 |
+
strength=strength,
|
| 1124 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 1125 |
+
eta=eta,
|
| 1126 |
+
generator=generator,
|
| 1127 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
| 1128 |
+
output_type=output_type,
|
| 1129 |
+
return_dict=return_dict,
|
| 1130 |
+
callback=callback,
|
| 1131 |
+
is_cancelled_callback=is_cancelled_callback,
|
| 1132 |
+
callback_steps=callback_steps,
|
| 1133 |
+
)
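To make the `strength` / `num_inference_steps` interaction described above concrete, here is a small sketch of the scheduling arithmetic that diffusers-style img2img pipelines typically use; it illustrates the general pattern, not the exact code of this pipeline:

def effective_img2img_steps(num_inference_steps: int, strength: float) -> int:
    # noise is added up to `strength` of the schedule, and denoising runs only
    # over the remaining portion, so fewer steps are actually executed
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    return num_inference_steps - t_start

# e.g. 50 steps at strength=0.8 -> about 40 denoising steps actually run,
# while strength=1.0 runs the full 50 steps and effectively ignores the input image
assert effective_img2img_steps(50, 0.8) == 40
assert effective_img2img_steps(50, 1.0) == 50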
|
| 1134 |
+
|
| 1135 |
+
def inpaint(
|
| 1136 |
+
self,
|
| 1137 |
+
image: Union[torch.FloatTensor, PIL.Image.Image],
|
| 1138 |
+
mask_image: Union[torch.FloatTensor, PIL.Image.Image],
|
| 1139 |
+
prompt: Union[str, List[str]],
|
| 1140 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 1141 |
+
strength: float = 0.8,
|
| 1142 |
+
num_inference_steps: Optional[int] = 50,
|
| 1143 |
+
guidance_scale: Optional[float] = 7.5,
|
| 1144 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 1145 |
+
eta: Optional[float] = 0.0,
|
| 1146 |
+
generator: Optional[torch.Generator] = None,
|
| 1147 |
+
max_embeddings_multiples: Optional[int] = 3,
|
| 1148 |
+
output_type: Optional[str] = "pil",
|
| 1149 |
+
return_dict: bool = True,
|
| 1150 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
| 1151 |
+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
| 1152 |
+
callback_steps: int = 1,
|
| 1153 |
+
):
|
| 1154 |
+
r"""
|
| 1155 |
+
Function for inpaint.
|
| 1156 |
+
Args:
|
| 1157 |
+
image (`torch.FloatTensor` or `PIL.Image.Image`):
|
| 1158 |
+
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
| 1159 |
+
process. This is the image whose masked region will be inpainted.
|
| 1160 |
+
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
|
| 1161 |
+
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
|
| 1162 |
+
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
|
| 1163 |
+
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
|
| 1164 |
+
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
|
| 1165 |
+
prompt (`str` or `List[str]`):
|
| 1166 |
+
The prompt or prompts to guide the image generation.
|
| 1167 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 1168 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
| 1169 |
+
if `guidance_scale` is less than `1`).
|
| 1170 |
+
strength (`float`, *optional*, defaults to 0.8):
|
| 1171 |
+
Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
|
| 1172 |
+
is 1, the denoising process will be run on the masked area for the full number of iterations specified
|
| 1173 |
+
in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
|
| 1174 |
+
noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
|
| 1175 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 1176 |
+
The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
|
| 1177 |
+
the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
|
| 1178 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
| 1179 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
| 1180 |
+
`guidance_scale` is defined as `w` in equation 2 of the [Imagen
|
| 1181 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
| 1182 |
+
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
|
| 1183 |
+
usually at the expense of lower image quality.
|
| 1184 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 1185 |
+
The number of images to generate per prompt.
|
| 1186 |
+
eta (`float`, *optional*, defaults to 0.0):
|
| 1187 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
| 1188 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
| 1189 |
+
generator (`torch.Generator`, *optional*):
|
| 1190 |
+
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
| 1191 |
+
deterministic.
|
| 1192 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
| 1193 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
| 1194 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 1195 |
+
The output format of the generated image. Choose between
|
| 1196 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 1197 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 1198 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
| 1199 |
+
plain tuple.
|
| 1200 |
+
callback (`Callable`, *optional*):
|
| 1201 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
| 1202 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
| 1203 |
+
is_cancelled_callback (`Callable`, *optional*):
|
| 1204 |
+
A function that will be called every `callback_steps` steps during inference. If the function returns
|
| 1205 |
+
`True`, the inference will be cancelled.
|
| 1206 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
| 1207 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
| 1208 |
+
called at every step.
|
| 1209 |
+
Returns:
|
| 1210 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
| 1211 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
|
| 1212 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
| 1213 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
| 1214 |
+
(nsfw) content, according to the `safety_checker`.
|
| 1215 |
+
"""
|
| 1216 |
+
return self.__call__(
|
| 1217 |
+
prompt=prompt,
|
| 1218 |
+
negative_prompt=negative_prompt,
|
| 1219 |
+
image=image,
|
| 1220 |
+
mask_image=mask_image,
|
| 1221 |
+
num_inference_steps=num_inference_steps,
|
| 1222 |
+
guidance_scale=guidance_scale,
|
| 1223 |
+
strength=strength,
|
| 1224 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 1225 |
+
eta=eta,
|
| 1226 |
+
generator=generator,
|
| 1227 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
| 1228 |
+
output_type=output_type,
|
| 1229 |
+
return_dict=return_dict,
|
| 1230 |
+
callback=callback,
|
| 1231 |
+
is_cancelled_callback=is_cancelled_callback,
|
| 1232 |
+
callback_steps=callback_steps,
|
| 1233 |
+
)
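As a concrete illustration of the mask convention documented above (white pixels are repainted, black pixels are preserved, and a PIL mask is reduced to a single luminance channel), a short sketch using PIL; the file name and rectangle coordinates are placeholders:

from PIL import Image, ImageDraw

init_image = Image.open("photo.png").convert("RGB")   # placeholder input image
mask = Image.new("L", init_image.size, 0)             # fully black: keep everything
draw = ImageDraw.Draw(mask)
draw.rectangle((64, 64, 256, 256), fill=255)          # white region: repaint this area
# `init_image` and `mask` would then be passed as `image=` and `mask_image=`
# to the inpaint entry point above.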
|
library/model_util.py
ADDED
|
@@ -0,0 +1,1356 @@
| 1 |
+
# v1: split from train_db_fixed.py.
|
| 2 |
+
# v2: support safetensors
|
| 3 |
+
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from library.device_utils import init_ipex
|
| 9 |
+
init_ipex()
|
| 10 |
+
|
| 11 |
+
import diffusers
|
| 12 |
+
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
|
| 13 |
+
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel
|
| 14 |
+
from safetensors.torch import load_file, save_file
|
| 15 |
+
from library.original_unet import UNet2DConditionModel
|
| 16 |
+
from library.utils import setup_logging
|
| 17 |
+
setup_logging()
|
| 18 |
+
import logging
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
# model parameters for the Diffusers version of Stable Diffusion
|
| 22 |
+
NUM_TRAIN_TIMESTEPS = 1000
|
| 23 |
+
BETA_START = 0.00085
|
| 24 |
+
BETA_END = 0.0120
|
| 25 |
+
|
| 26 |
+
UNET_PARAMS_MODEL_CHANNELS = 320
|
| 27 |
+
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
|
| 28 |
+
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
|
| 29 |
+
UNET_PARAMS_IMAGE_SIZE = 64 # fixed from old invalid value `32`
|
| 30 |
+
UNET_PARAMS_IN_CHANNELS = 4
|
| 31 |
+
UNET_PARAMS_OUT_CHANNELS = 4
|
| 32 |
+
UNET_PARAMS_NUM_RES_BLOCKS = 2
|
| 33 |
+
UNET_PARAMS_CONTEXT_DIM = 768
|
| 34 |
+
UNET_PARAMS_NUM_HEADS = 8
|
| 35 |
+
# UNET_PARAMS_USE_LINEAR_PROJECTION = False
|
| 36 |
+
|
| 37 |
+
VAE_PARAMS_Z_CHANNELS = 4
|
| 38 |
+
VAE_PARAMS_RESOLUTION = 256
|
| 39 |
+
VAE_PARAMS_IN_CHANNELS = 3
|
| 40 |
+
VAE_PARAMS_OUT_CH = 3
|
| 41 |
+
VAE_PARAMS_CH = 128
|
| 42 |
+
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
|
| 43 |
+
VAE_PARAMS_NUM_RES_BLOCKS = 2
|
| 44 |
+
|
| 45 |
+
# V2
|
| 46 |
+
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
|
| 47 |
+
V2_UNET_PARAMS_CONTEXT_DIM = 1024
|
| 48 |
+
# V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True
|
| 49 |
+
|
| 50 |
+
# reference models used to load the Diffusers configs
|
| 51 |
+
DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
|
| 52 |
+
DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# region StableDiffusion -> Diffusers conversion code
|
| 56 |
+
# copied from convert_original_stable_diffusion_to_diffusers and modified (ASL 2.0)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def shave_segments(path, n_shave_prefix_segments=1):
|
| 60 |
+
"""
|
| 61 |
+
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
| 62 |
+
"""
|
| 63 |
+
if n_shave_prefix_segments >= 0:
|
| 64 |
+
return ".".join(path.split(".")[n_shave_prefix_segments:])
|
| 65 |
+
else:
|
| 66 |
+
return ".".join(path.split(".")[:n_shave_prefix_segments])
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
| 70 |
+
"""
|
| 71 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
| 72 |
+
"""
|
| 73 |
+
mapping = []
|
| 74 |
+
for old_item in old_list:
|
| 75 |
+
new_item = old_item.replace("in_layers.0", "norm1")
|
| 76 |
+
new_item = new_item.replace("in_layers.2", "conv1")
|
| 77 |
+
|
| 78 |
+
new_item = new_item.replace("out_layers.0", "norm2")
|
| 79 |
+
new_item = new_item.replace("out_layers.3", "conv2")
|
| 80 |
+
|
| 81 |
+
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
|
| 82 |
+
new_item = new_item.replace("skip_connection", "conv_shortcut")
|
| 83 |
+
|
| 84 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
| 85 |
+
|
| 86 |
+
mapping.append({"old": old_item, "new": new_item})
|
| 87 |
+
|
| 88 |
+
return mapping
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
| 92 |
+
"""
|
| 93 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
| 94 |
+
"""
|
| 95 |
+
mapping = []
|
| 96 |
+
for old_item in old_list:
|
| 97 |
+
new_item = old_item
|
| 98 |
+
|
| 99 |
+
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
|
| 100 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
| 101 |
+
|
| 102 |
+
mapping.append({"old": old_item, "new": new_item})
|
| 103 |
+
|
| 104 |
+
return mapping
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
| 108 |
+
"""
|
| 109 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
| 110 |
+
"""
|
| 111 |
+
mapping = []
|
| 112 |
+
for old_item in old_list:
|
| 113 |
+
new_item = old_item
|
| 114 |
+
|
| 115 |
+
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
| 116 |
+
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
| 117 |
+
|
| 118 |
+
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
| 119 |
+
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
| 120 |
+
|
| 121 |
+
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
| 122 |
+
|
| 123 |
+
mapping.append({"old": old_item, "new": new_item})
|
| 124 |
+
|
| 125 |
+
return mapping
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
| 129 |
+
"""
|
| 130 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
| 131 |
+
"""
|
| 132 |
+
mapping = []
|
| 133 |
+
for old_item in old_list:
|
| 134 |
+
new_item = old_item
|
| 135 |
+
|
| 136 |
+
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
| 137 |
+
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
| 138 |
+
|
| 139 |
+
if diffusers.__version__ < "0.17.0":
|
| 140 |
+
new_item = new_item.replace("q.weight", "query.weight")
|
| 141 |
+
new_item = new_item.replace("q.bias", "query.bias")
|
| 142 |
+
|
| 143 |
+
new_item = new_item.replace("k.weight", "key.weight")
|
| 144 |
+
new_item = new_item.replace("k.bias", "key.bias")
|
| 145 |
+
|
| 146 |
+
new_item = new_item.replace("v.weight", "value.weight")
|
| 147 |
+
new_item = new_item.replace("v.bias", "value.bias")
|
| 148 |
+
|
| 149 |
+
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
| 150 |
+
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
| 151 |
+
else:
|
| 152 |
+
new_item = new_item.replace("q.weight", "to_q.weight")
|
| 153 |
+
new_item = new_item.replace("q.bias", "to_q.bias")
|
| 154 |
+
|
| 155 |
+
new_item = new_item.replace("k.weight", "to_k.weight")
|
| 156 |
+
new_item = new_item.replace("k.bias", "to_k.bias")
|
| 157 |
+
|
| 158 |
+
new_item = new_item.replace("v.weight", "to_v.weight")
|
| 159 |
+
new_item = new_item.replace("v.bias", "to_v.bias")
|
| 160 |
+
|
| 161 |
+
new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
|
| 162 |
+
new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
|
| 163 |
+
|
| 164 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
| 165 |
+
|
| 166 |
+
mapping.append({"old": old_item, "new": new_item})
|
| 167 |
+
|
| 168 |
+
return mapping
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def assign_to_checkpoint(
|
| 172 |
+
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
| 173 |
+
):
|
| 174 |
+
"""
|
| 175 |
+
This does the final conversion step: take locally converted weights and apply a global renaming
|
| 176 |
+
to them. It splits attention layers, and takes into account additional replacements
|
| 177 |
+
that may arise.
|
| 178 |
+
|
| 179 |
+
Assigns the weights to the new checkpoint.
|
| 180 |
+
"""
|
| 181 |
+
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
| 182 |
+
|
| 183 |
+
# Splits the attention layers into three variables.
|
| 184 |
+
if attention_paths_to_split is not None:
|
| 185 |
+
for path, path_map in attention_paths_to_split.items():
|
| 186 |
+
old_tensor = old_checkpoint[path]
|
| 187 |
+
channels = old_tensor.shape[0] // 3
|
| 188 |
+
|
| 189 |
+
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
| 190 |
+
|
| 191 |
+
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
| 192 |
+
|
| 193 |
+
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
| 194 |
+
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
| 195 |
+
|
| 196 |
+
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
| 197 |
+
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
| 198 |
+
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
| 199 |
+
|
| 200 |
+
for path in paths:
|
| 201 |
+
new_path = path["new"]
|
| 202 |
+
|
| 203 |
+
# These have already been assigned
|
| 204 |
+
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
| 205 |
+
continue
|
| 206 |
+
|
| 207 |
+
# Global renaming happens here
|
| 208 |
+
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
| 209 |
+
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
| 210 |
+
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
|
| 211 |
+
|
| 212 |
+
if additional_replacements is not None:
|
| 213 |
+
for replacement in additional_replacements:
|
| 214 |
+
new_path = new_path.replace(replacement["old"], replacement["new"])
|
| 215 |
+
|
| 216 |
+
# proj_attn.weight has to be converted from conv 1D to linear
|
| 217 |
+
reshaping = False
|
| 218 |
+
if diffusers.__version__ < "0.17.0":
|
| 219 |
+
if "proj_attn.weight" in new_path:
|
| 220 |
+
reshaping = True
|
| 221 |
+
else:
|
| 222 |
+
if ".attentions." in new_path and ".0.to_" in new_path and old_checkpoint[path["old"]].ndim > 2:
|
| 223 |
+
reshaping = True
|
| 224 |
+
|
| 225 |
+
if reshaping:
|
| 226 |
+
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
|
| 227 |
+
else:
|
| 228 |
+
checkpoint[new_path] = old_checkpoint[path["old"]]
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def conv_attn_to_linear(checkpoint):
|
| 232 |
+
keys = list(checkpoint.keys())
|
| 233 |
+
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
| 234 |
+
for key in keys:
|
| 235 |
+
if ".".join(key.split(".")[-2:]) in attn_keys:
|
| 236 |
+
if checkpoint[key].ndim > 2:
|
| 237 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
| 238 |
+
elif "proj_attn.weight" in key:
|
| 239 |
+
if checkpoint[key].ndim > 2:
|
| 240 |
+
checkpoint[key] = checkpoint[key][:, :, 0]
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def linear_transformer_to_conv(checkpoint):
|
| 244 |
+
keys = list(checkpoint.keys())
|
| 245 |
+
tf_keys = ["proj_in.weight", "proj_out.weight"]
|
| 246 |
+
for key in keys:
|
| 247 |
+
if ".".join(key.split(".")[-2:]) in tf_keys:
|
| 248 |
+
if checkpoint[key].ndim == 2:
|
| 249 |
+
checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
| 253 |
+
"""
|
| 254 |
+
Takes a state dict and a config, and returns a converted checkpoint.
|
| 255 |
+
"""
|
| 256 |
+
|
| 257 |
+
# extract state_dict for UNet
|
| 258 |
+
unet_state_dict = {}
|
| 259 |
+
unet_key = "model.diffusion_model."
|
| 260 |
+
keys = list(checkpoint.keys())
|
| 261 |
+
for key in keys:
|
| 262 |
+
if key.startswith(unet_key):
|
| 263 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
| 264 |
+
|
| 265 |
+
new_checkpoint = {}
|
| 266 |
+
|
| 267 |
+
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
|
| 268 |
+
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
|
| 269 |
+
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
|
| 270 |
+
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
|
| 271 |
+
|
| 272 |
+
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
|
| 273 |
+
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
|
| 274 |
+
|
| 275 |
+
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
|
| 276 |
+
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
|
| 277 |
+
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
|
| 278 |
+
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
|
| 279 |
+
|
| 280 |
+
# Retrieves the keys for the input blocks only
|
| 281 |
+
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
| 282 |
+
input_blocks = {
|
| 283 |
+
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in range(num_input_blocks)
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
# Retrieves the keys for the middle blocks only
|
| 287 |
+
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
| 288 |
+
middle_blocks = {
|
| 289 |
+
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in range(num_middle_blocks)
|
| 290 |
+
}
|
| 291 |
+
|
| 292 |
+
# Retrieves the keys for the output blocks only
|
| 293 |
+
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
| 294 |
+
output_blocks = {
|
| 295 |
+
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] for layer_id in range(num_output_blocks)
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
for i in range(1, num_input_blocks):
|
| 299 |
+
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
| 300 |
+
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
| 301 |
+
|
| 302 |
+
resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key]
|
| 303 |
+
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
| 304 |
+
|
| 305 |
+
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
|
| 306 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
| 307 |
+
f"input_blocks.{i}.0.op.weight"
|
| 308 |
+
)
|
| 309 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")
|
| 310 |
+
|
| 311 |
+
paths = renew_resnet_paths(resnets)
|
| 312 |
+
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
| 313 |
+
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
| 314 |
+
|
| 315 |
+
if len(attentions):
|
| 316 |
+
paths = renew_attention_paths(attentions)
|
| 317 |
+
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
| 318 |
+
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
| 319 |
+
|
| 320 |
+
resnet_0 = middle_blocks[0]
|
| 321 |
+
attentions = middle_blocks[1]
|
| 322 |
+
resnet_1 = middle_blocks[2]
|
| 323 |
+
|
| 324 |
+
resnet_0_paths = renew_resnet_paths(resnet_0)
|
| 325 |
+
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
|
| 326 |
+
|
| 327 |
+
resnet_1_paths = renew_resnet_paths(resnet_1)
|
| 328 |
+
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
|
| 329 |
+
|
| 330 |
+
attentions_paths = renew_attention_paths(attentions)
|
| 331 |
+
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
| 332 |
+
assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
| 333 |
+
|
| 334 |
+
for i in range(num_output_blocks):
|
| 335 |
+
block_id = i // (config["layers_per_block"] + 1)
|
| 336 |
+
layer_in_block_id = i % (config["layers_per_block"] + 1)
|
| 337 |
+
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
| 338 |
+
output_block_list = {}
|
| 339 |
+
|
| 340 |
+
for layer in output_block_layers:
|
| 341 |
+
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
| 342 |
+
if layer_id in output_block_list:
|
| 343 |
+
output_block_list[layer_id].append(layer_name)
|
| 344 |
+
else:
|
| 345 |
+
output_block_list[layer_id] = [layer_name]
|
| 346 |
+
|
| 347 |
+
if len(output_block_list) > 1:
|
| 348 |
+
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
| 349 |
+
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
| 350 |
+
|
| 351 |
+
resnet_0_paths = renew_resnet_paths(resnets)
|
| 352 |
+
paths = renew_resnet_paths(resnets)
|
| 353 |
+
|
| 354 |
+
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
| 355 |
+
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
| 356 |
+
|
| 357 |
+
# original:
|
| 358 |
+
# if ["conv.weight", "conv.bias"] in output_block_list.values():
|
| 359 |
+
# index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
|
| 360 |
+
|
| 361 |
+
# avoid depending on the order of bias and weight; there is probably a better way to do this
|
| 362 |
+
for l in output_block_list.values():
|
| 363 |
+
l.sort()
|
| 364 |
+
|
| 365 |
+
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
| 366 |
+
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
| 367 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
| 368 |
+
f"output_blocks.{i}.{index}.conv.bias"
|
| 369 |
+
]
|
| 370 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
| 371 |
+
f"output_blocks.{i}.{index}.conv.weight"
|
| 372 |
+
]
|
| 373 |
+
|
| 374 |
+
# Clear attentions as they have been attributed above.
|
| 375 |
+
if len(attentions) == 2:
|
| 376 |
+
attentions = []
|
| 377 |
+
|
| 378 |
+
if len(attentions):
|
| 379 |
+
paths = renew_attention_paths(attentions)
|
| 380 |
+
meta_path = {
|
| 381 |
+
"old": f"output_blocks.{i}.1",
|
| 382 |
+
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
| 383 |
+
}
|
| 384 |
+
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
| 385 |
+
else:
|
| 386 |
+
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
| 387 |
+
for path in resnet_0_paths:
|
| 388 |
+
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
| 389 |
+
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
| 390 |
+
|
| 391 |
+
new_checkpoint[new_path] = unet_state_dict[old_path]
|
| 392 |
+
|
| 393 |
+
# in SD v2, the 1x1 conv2d layers were changed to linear layers
|
| 394 |
+
# the Diffusers side was mistakenly left as conv2d, so conversion is required
|
| 395 |
+
if v2 and not config.get("use_linear_projection", False):
|
| 396 |
+
linear_transformer_to_conv(new_checkpoint)
|
| 397 |
+
|
| 398 |
+
return new_checkpoint
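A hedged sketch of how this UNet conversion is typically driven end to end; the checkpoint path is a placeholder, and `load_file` / `UNet2DConditionModel` refer to the imports already made at the top of this module:

# sketch only: convert an SD-format .safetensors checkpoint into a Diffusers-style UNet
state_dict = load_file("model.safetensors")  # placeholder path
unet_config = create_unet_diffusers_config(v2=False)
converted = convert_ldm_unet_checkpoint(v2=False, checkpoint=state_dict, config=unet_config)
unet = UNet2DConditionModel(**unet_config)
info = unet.load_state_dict(converted)  # `info` reports any missing or unexpected keys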
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def convert_ldm_vae_checkpoint(checkpoint, config):
|
| 402 |
+
# extract state dict for VAE
|
| 403 |
+
vae_state_dict = {}
|
| 404 |
+
vae_key = "first_stage_model."
|
| 405 |
+
keys = list(checkpoint.keys())
|
| 406 |
+
for key in keys:
|
| 407 |
+
if key.startswith(vae_key):
|
| 408 |
+
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
| 409 |
+
# if len(vae_state_dict) == 0:
|
| 410 |
+
# # the passed checkpoint is a VAE state_dict, not a checkpoint loaded from a .ckpt
|
| 411 |
+
# vae_state_dict = checkpoint
|
| 412 |
+
|
| 413 |
+
new_checkpoint = {}
|
| 414 |
+
|
| 415 |
+
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
| 416 |
+
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
|
| 417 |
+
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
|
| 418 |
+
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
|
| 419 |
+
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
|
| 420 |
+
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
|
| 421 |
+
|
| 422 |
+
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
|
| 423 |
+
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
|
| 424 |
+
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
|
| 425 |
+
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
|
| 426 |
+
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
|
| 427 |
+
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
|
| 428 |
+
|
| 429 |
+
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
|
| 430 |
+
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
|
| 431 |
+
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
|
| 432 |
+
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
|
| 433 |
+
|
| 434 |
+
# Retrieves the keys for the encoder down blocks only
|
| 435 |
+
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
|
| 436 |
+
down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)}
|
| 437 |
+
|
| 438 |
+
# Retrieves the keys for the decoder up blocks only
|
| 439 |
+
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
|
| 440 |
+
up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
|
| 441 |
+
|
| 442 |
+
for i in range(num_down_blocks):
|
| 443 |
+
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
| 444 |
+
|
| 445 |
+
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
| 446 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
| 447 |
+
f"encoder.down.{i}.downsample.conv.weight"
|
| 448 |
+
)
|
| 449 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
|
| 450 |
+
f"encoder.down.{i}.downsample.conv.bias"
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
paths = renew_vae_resnet_paths(resnets)
|
| 454 |
+
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
| 455 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
| 456 |
+
|
| 457 |
+
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
| 458 |
+
num_mid_res_blocks = 2
|
| 459 |
+
for i in range(1, num_mid_res_blocks + 1):
|
| 460 |
+
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
| 461 |
+
|
| 462 |
+
paths = renew_vae_resnet_paths(resnets)
|
| 463 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
| 464 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
| 465 |
+
|
| 466 |
+
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
| 467 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
| 468 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
| 469 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
| 470 |
+
conv_attn_to_linear(new_checkpoint)
|
| 471 |
+
|
| 472 |
+
for i in range(num_up_blocks):
|
| 473 |
+
block_id = num_up_blocks - 1 - i
|
| 474 |
+
resnets = [key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key]
|
| 475 |
+
|
| 476 |
+
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
| 477 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
| 478 |
+
f"decoder.up.{block_id}.upsample.conv.weight"
|
| 479 |
+
]
|
| 480 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
| 481 |
+
f"decoder.up.{block_id}.upsample.conv.bias"
|
| 482 |
+
]
|
| 483 |
+
|
| 484 |
+
paths = renew_vae_resnet_paths(resnets)
|
| 485 |
+
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
| 486 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
| 487 |
+
|
| 488 |
+
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
| 489 |
+
num_mid_res_blocks = 2
|
| 490 |
+
for i in range(1, num_mid_res_blocks + 1):
|
| 491 |
+
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
| 492 |
+
|
| 493 |
+
paths = renew_vae_resnet_paths(resnets)
|
| 494 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
| 495 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
| 496 |
+
|
| 497 |
+
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
| 498 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
| 499 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
| 500 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
| 501 |
+
conv_attn_to_linear(new_checkpoint)
|
| 502 |
+
return new_checkpoint
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False):
|
| 506 |
+
"""
|
| 507 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
| 508 |
+
"""
|
| 509 |
+
# unet_params = original_config.model.params.unet_config.params
|
| 510 |
+
|
| 511 |
+
block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
|
| 512 |
+
|
| 513 |
+
down_block_types = []
|
| 514 |
+
resolution = 1
|
| 515 |
+
for i in range(len(block_out_channels)):
|
| 516 |
+
block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
|
| 517 |
+
down_block_types.append(block_type)
|
| 518 |
+
if i != len(block_out_channels) - 1:
|
| 519 |
+
resolution *= 2
|
| 520 |
+
|
| 521 |
+
up_block_types = []
|
| 522 |
+
for i in range(len(block_out_channels)):
|
| 523 |
+
block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
|
| 524 |
+
up_block_types.append(block_type)
|
| 525 |
+
resolution //= 2
|
| 526 |
+
|
| 527 |
+
config = dict(
|
| 528 |
+
sample_size=UNET_PARAMS_IMAGE_SIZE,
|
| 529 |
+
in_channels=UNET_PARAMS_IN_CHANNELS,
|
| 530 |
+
out_channels=UNET_PARAMS_OUT_CHANNELS,
|
| 531 |
+
down_block_types=tuple(down_block_types),
|
| 532 |
+
up_block_types=tuple(up_block_types),
|
| 533 |
+
block_out_channels=tuple(block_out_channels),
|
| 534 |
+
layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
|
| 535 |
+
cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
|
| 536 |
+
attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
|
| 537 |
+
# use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION,
|
| 538 |
+
)
|
| 539 |
+
if v2 and use_linear_projection_in_v2:
|
| 540 |
+
config["use_linear_projection"] = True
|
| 541 |
+
|
| 542 |
+
return config
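For reference, the v1 configuration this helper produces, spelled out from the constants above (directly derivable from the code, shown here for readability):

config = create_unet_diffusers_config(v2=False)
assert config["block_out_channels"] == (320, 640, 1280, 1280)
assert config["down_block_types"] == (
    "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D",
)
assert config["up_block_types"] == (
    "UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",
)
assert config["cross_attention_dim"] == 768 and config["attention_head_dim"] == 8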
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
def create_vae_diffusers_config():
|
| 546 |
+
"""
|
| 547 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
| 548 |
+
"""
|
| 549 |
+
# vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
| 550 |
+
# _ = original_config.model.params.first_stage_config.params.embed_dim
|
| 551 |
+
block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
|
| 552 |
+
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
| 553 |
+
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
| 554 |
+
|
| 555 |
+
config = dict(
|
| 556 |
+
sample_size=VAE_PARAMS_RESOLUTION,
|
| 557 |
+
in_channels=VAE_PARAMS_IN_CHANNELS,
|
| 558 |
+
out_channels=VAE_PARAMS_OUT_CH,
|
| 559 |
+
down_block_types=tuple(down_block_types),
|
| 560 |
+
up_block_types=tuple(up_block_types),
|
| 561 |
+
block_out_channels=tuple(block_out_channels),
|
| 562 |
+
latent_channels=VAE_PARAMS_Z_CHANNELS,
|
| 563 |
+
layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
|
| 564 |
+
)
|
| 565 |
+
return config
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
def convert_ldm_clip_checkpoint_v1(checkpoint):
|
| 569 |
+
keys = list(checkpoint.keys())
|
| 570 |
+
text_model_dict = {}
|
| 571 |
+
for key in keys:
|
| 572 |
+
if key.startswith("cond_stage_model.transformer"):
|
| 573 |
+
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
| 574 |
+
|
| 575 |
+
# remove position_ids for newer transformers versions, which otherwise cause an error :(
|
| 576 |
+
if "text_model.embeddings.position_ids" in text_model_dict:
|
| 577 |
+
text_model_dict.pop("text_model.embeddings.position_ids")
|
| 578 |
+
|
| 579 |
+
return text_model_dict
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
|
| 583 |
+
# the key layout is annoyingly different!
|
| 584 |
+
def convert_key(key):
|
| 585 |
+
if not key.startswith("cond_stage_model"):
|
| 586 |
+
return None
|
| 587 |
+
|
| 588 |
+
# common conversion
|
| 589 |
+
key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
|
| 590 |
+
key = key.replace("cond_stage_model.model.", "text_model.")
|
| 591 |
+
|
| 592 |
+
if "resblocks" in key:
|
| 593 |
+
# resblocks conversion
|
| 594 |
+
key = key.replace(".resblocks.", ".layers.")
|
| 595 |
+
if ".ln_" in key:
|
| 596 |
+
key = key.replace(".ln_", ".layer_norm")
|
| 597 |
+
elif ".mlp." in key:
|
| 598 |
+
key = key.replace(".c_fc.", ".fc1.")
|
| 599 |
+
key = key.replace(".c_proj.", ".fc2.")
|
| 600 |
+
elif ".attn.out_proj" in key:
|
| 601 |
+
key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
|
| 602 |
+
elif ".attn.in_proj" in key:
|
| 603 |
+
key = None # special case, handled separately below
|
| 604 |
+
else:
|
| 605 |
+
raise ValueError(f"unexpected key in SD: {key}")
|
| 606 |
+
elif ".positional_embedding" in key:
|
| 607 |
+
key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
|
| 608 |
+
elif ".text_projection" in key:
|
| 609 |
+
key = None # not used???
|
| 610 |
+
elif ".logit_scale" in key:
|
| 611 |
+
key = None # not used???
|
| 612 |
+
elif ".token_embedding" in key:
|
| 613 |
+
key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
|
| 614 |
+
elif ".ln_final" in key:
|
| 615 |
+
key = key.replace(".ln_final", ".final_layer_norm")
|
| 616 |
+
return key
|
| 617 |
+
|
| 618 |
+
keys = list(checkpoint.keys())
|
| 619 |
+
new_sd = {}
|
| 620 |
+
for key in keys:
|
| 621 |
+
# remove resblocks 23
|
| 622 |
+
if ".resblocks.23." in key:
|
| 623 |
+
continue
|
| 624 |
+
new_key = convert_key(key)
|
| 625 |
+
if new_key is None:
|
| 626 |
+
continue
|
| 627 |
+
new_sd[new_key] = checkpoint[key]
|
| 628 |
+
|
| 629 |
+
# convert the attn weights
|
| 630 |
+
for key in keys:
|
| 631 |
+
if ".resblocks.23." in key:
|
| 632 |
+
continue
|
| 633 |
+
if ".resblocks" in key and ".attn.in_proj_" in key:
|
| 634 |
+
# split into three (q, k, v)
|
| 635 |
+
values = torch.chunk(checkpoint[key], 3)
|
| 636 |
+
|
| 637 |
+
key_suffix = ".weight" if "weight" in key else ".bias"
|
| 638 |
+
key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
|
| 639 |
+
key_pfx = key_pfx.replace("_weight", "")
|
| 640 |
+
key_pfx = key_pfx.replace("_bias", "")
|
| 641 |
+
key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
|
| 642 |
+
new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
|
| 643 |
+
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
|
| 644 |
+
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
|
| 645 |
+
|
| 646 |
+
# rename or add position_ids
|
| 647 |
+
ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
|
| 648 |
+
if ANOTHER_POSITION_IDS_KEY in new_sd:
|
| 649 |
+
# waifu diffusion v1.4
|
| 650 |
+
position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
|
| 651 |
+
del new_sd[ANOTHER_POSITION_IDS_KEY]
|
| 652 |
+
else:
|
| 653 |
+
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
|
| 654 |
+
|
| 655 |
+
new_sd["text_model.embeddings.position_ids"] = position_ids
|
| 656 |
+
return new_sd
|
| 657 |
+
|
| 658 |
+
|
| 659 |
+
# endregion
|
| 660 |
+
|
| 661 |
+
|
| 662 |
+
# region Diffusers -> StableDiffusion conversion code
|
| 663 |
+
# copied from convert_diffusers_to_original_stable_diffusion and modified (ASL 2.0)
|
| 664 |
+
|
| 665 |
+
|
| 666 |
+
def conv_transformer_to_linear(checkpoint):
|
| 667 |
+
keys = list(checkpoint.keys())
|
| 668 |
+
tf_keys = ["proj_in.weight", "proj_out.weight"]
|
| 669 |
+
for key in keys:
|
| 670 |
+
if ".".join(key.split(".")[-2:]) in tf_keys:
|
| 671 |
+
if checkpoint[key].ndim > 2:
|
| 672 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
| 673 |
+
|
| 674 |
+
|
| 675 |
+
def convert_unet_state_dict_to_sd(v2, unet_state_dict):
|
| 676 |
+
unet_conversion_map = [
|
| 677 |
+
# (stable-diffusion, HF Diffusers)
|
| 678 |
+
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
| 679 |
+
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
| 680 |
+
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
| 681 |
+
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
| 682 |
+
("input_blocks.0.0.weight", "conv_in.weight"),
|
| 683 |
+
("input_blocks.0.0.bias", "conv_in.bias"),
|
| 684 |
+
("out.0.weight", "conv_norm_out.weight"),
|
| 685 |
+
("out.0.bias", "conv_norm_out.bias"),
|
| 686 |
+
("out.2.weight", "conv_out.weight"),
|
| 687 |
+
("out.2.bias", "conv_out.bias"),
|
| 688 |
+
]
|
| 689 |
+
|
| 690 |
+
unet_conversion_map_resnet = [
|
| 691 |
+
# (stable-diffusion, HF Diffusers)
|
| 692 |
+
("in_layers.0", "norm1"),
|
| 693 |
+
("in_layers.2", "conv1"),
|
| 694 |
+
("out_layers.0", "norm2"),
|
| 695 |
+
("out_layers.3", "conv2"),
|
| 696 |
+
("emb_layers.1", "time_emb_proj"),
|
| 697 |
+
("skip_connection", "conv_shortcut"),
|
| 698 |
+
]
|
| 699 |
+
|
| 700 |
+
unet_conversion_map_layer = []
|
| 701 |
+
for i in range(4):
|
| 702 |
+
# loop over downblocks/upblocks
|
| 703 |
+
|
| 704 |
+
for j in range(2):
|
| 705 |
+
# loop over resnets/attentions for downblocks
|
| 706 |
+
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
| 707 |
+
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
| 708 |
+
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
| 709 |
+
|
| 710 |
+
if i < 3:
|
| 711 |
+
# no attention layers in down_blocks.3
|
| 712 |
+
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
| 713 |
+
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
| 714 |
+
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
| 715 |
+
|
| 716 |
+
for j in range(3):
|
| 717 |
+
# loop over resnets/attentions for upblocks
|
| 718 |
+
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
| 719 |
+
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
| 720 |
+
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
| 721 |
+
|
| 722 |
+
if i > 0:
|
| 723 |
+
# no attention layers in up_blocks.0
|
| 724 |
+
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
| 725 |
+
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
| 726 |
+
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
| 727 |
+
|
| 728 |
+
if i < 3:
|
| 729 |
+
# no downsample in down_blocks.3
|
| 730 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
| 731 |
+
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
| 732 |
+
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
| 733 |
+
|
| 734 |
+
# no upsample in up_blocks.3
|
| 735 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
| 736 |
+
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
|
| 737 |
+
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
| 738 |
+
|
| 739 |
+
hf_mid_atn_prefix = "mid_block.attentions.0."
|
| 740 |
+
sd_mid_atn_prefix = "middle_block.1."
|
| 741 |
+
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
| 742 |
+
|
| 743 |
+
for j in range(2):
|
| 744 |
+
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
| 745 |
+
sd_mid_res_prefix = f"middle_block.{2*j}."
|
| 746 |
+
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
| 747 |
+
|
| 748 |
+
# buyer beware: this is a *brittle* function,
|
| 749 |
+
# and correct output requires that all of these pieces interact in
|
| 750 |
+
# the exact order in which I have arranged them.
|
| 751 |
+
mapping = {k: k for k in unet_state_dict.keys()}
|
| 752 |
+
for sd_name, hf_name in unet_conversion_map:
|
| 753 |
+
mapping[hf_name] = sd_name
|
| 754 |
+
for k, v in mapping.items():
|
| 755 |
+
if "resnets" in k:
|
| 756 |
+
for sd_part, hf_part in unet_conversion_map_resnet:
|
| 757 |
+
v = v.replace(hf_part, sd_part)
|
| 758 |
+
mapping[k] = v
|
| 759 |
+
for k, v in mapping.items():
|
| 760 |
+
for sd_part, hf_part in unet_conversion_map_layer:
|
| 761 |
+
v = v.replace(hf_part, sd_part)
|
| 762 |
+
mapping[k] = v
|
| 763 |
+
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
|
| 764 |
+
|
| 765 |
+
if v2:
|
| 766 |
+
conv_transformer_to_linear(new_state_dict)
|
| 767 |
+
|
| 768 |
+
return new_state_dict
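# Minimal sketch of how the mapping tables above compose for a single key; the
# two mini-maps below are excerpts of unet_conversion_map_resnet and
# unet_conversion_map_layer, and the key is only an illustrative example.
def _demo_unet_key_mapping():
    resnet_map_excerpt = [("in_layers.0", "norm1")]
    layer_map_excerpt = [("input_blocks.1.0.", "down_blocks.0.resnets.0.")]
    v = "down_blocks.0.resnets.0.norm1.weight"
    for sd_part, hf_part in resnet_map_excerpt:
        v = v.replace(hf_part, sd_part)
    for sd_part, hf_part in layer_map_excerpt:
        v = v.replace(hf_part, sd_part)
    return v  # -> "input_blocks.1.0.in_layers.0.weight"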
|
| 769 |
+
|
| 770 |
+
|
| 771 |
+
def controlnet_conversion_map():
|
| 772 |
+
unet_conversion_map = [
|
| 773 |
+
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
| 774 |
+
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
| 775 |
+
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
| 776 |
+
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
| 777 |
+
("input_blocks.0.0.weight", "conv_in.weight"),
|
| 778 |
+
("input_blocks.0.0.bias", "conv_in.bias"),
|
| 779 |
+
("middle_block_out.0.weight", "controlnet_mid_block.weight"),
|
| 780 |
+
("middle_block_out.0.bias", "controlnet_mid_block.bias"),
|
| 781 |
+
]
|
| 782 |
+
|
| 783 |
+
unet_conversion_map_resnet = [
|
| 784 |
+
("in_layers.0", "norm1"),
|
| 785 |
+
("in_layers.2", "conv1"),
|
| 786 |
+
("out_layers.0", "norm2"),
|
| 787 |
+
("out_layers.3", "conv2"),
|
| 788 |
+
("emb_layers.1", "time_emb_proj"),
|
| 789 |
+
("skip_connection", "conv_shortcut"),
|
| 790 |
+
]
|
| 791 |
+
|
| 792 |
+
unet_conversion_map_layer = []
|
| 793 |
+
for i in range(4):
|
| 794 |
+
for j in range(2):
|
| 795 |
+
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
| 796 |
+
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
| 797 |
+
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
| 798 |
+
|
| 799 |
+
if i < 3:
|
| 800 |
+
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
| 801 |
+
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
| 802 |
+
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
| 803 |
+
|
| 804 |
+
if i < 3:
|
| 805 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
| 806 |
+
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
| 807 |
+
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
| 808 |
+
|
| 809 |
+
hf_mid_atn_prefix = "mid_block.attentions.0."
|
| 810 |
+
sd_mid_atn_prefix = "middle_block.1."
|
| 811 |
+
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
| 812 |
+
|
| 813 |
+
for j in range(2):
|
| 814 |
+
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
| 815 |
+
sd_mid_res_prefix = f"middle_block.{2*j}."
|
| 816 |
+
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
| 817 |
+
|
| 818 |
+
controlnet_cond_embedding_names = ["conv_in"] + [f"blocks.{i}" for i in range(6)] + ["conv_out"]
|
| 819 |
+
for i, hf_prefix in enumerate(controlnet_cond_embedding_names):
|
| 820 |
+
hf_prefix = f"controlnet_cond_embedding.{hf_prefix}."
|
| 821 |
+
sd_prefix = f"input_hint_block.{i*2}."
|
| 822 |
+
unet_conversion_map_layer.append((sd_prefix, hf_prefix))
|
| 823 |
+
|
| 824 |
+
for i in range(12):
|
| 825 |
+
hf_prefix = f"controlnet_down_blocks.{i}."
|
| 826 |
+
sd_prefix = f"zero_convs.{i}.0."
|
| 827 |
+
unet_conversion_map_layer.append((sd_prefix, hf_prefix))
|
| 828 |
+
|
| 829 |
+
return unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer
|
| 830 |
+
|
| 831 |
+
|
| 832 |
+
def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
|
| 833 |
+
unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
|
| 834 |
+
|
| 835 |
+
mapping = {k: k for k in controlnet_state_dict.keys()}
|
| 836 |
+
for sd_name, diffusers_name in unet_conversion_map:
|
| 837 |
+
mapping[diffusers_name] = sd_name
|
| 838 |
+
for k, v in mapping.items():
|
| 839 |
+
if "resnets" in k:
|
| 840 |
+
for sd_part, diffusers_part in unet_conversion_map_resnet:
|
| 841 |
+
v = v.replace(diffusers_part, sd_part)
|
| 842 |
+
mapping[k] = v
|
| 843 |
+
for k, v in mapping.items():
|
| 844 |
+
for sd_part, diffusers_part in unet_conversion_map_layer:
|
| 845 |
+
v = v.replace(diffusers_part, sd_part)
|
| 846 |
+
mapping[k] = v
|
| 847 |
+
new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
|
| 848 |
+
return new_state_dict
|
| 849 |
+
|
| 850 |
+
|
| 851 |
+
def convert_controlnet_state_dict_to_diffusers(controlnet_state_dict):
|
| 852 |
+
unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
|
| 853 |
+
|
| 854 |
+
mapping = {k: k for k in controlnet_state_dict.keys()}
|
| 855 |
+
for sd_name, diffusers_name in unet_conversion_map:
|
| 856 |
+
mapping[sd_name] = diffusers_name
|
| 857 |
+
for k, v in mapping.items():
|
| 858 |
+
for sd_part, diffusers_part in unet_conversion_map_layer:
|
| 859 |
+
v = v.replace(sd_part, diffusers_part)
|
| 860 |
+
mapping[k] = v
|
| 861 |
+
for k, v in mapping.items():
|
| 862 |
+
if "resnets" in v:
|
| 863 |
+
for sd_part, diffusers_part in unet_conversion_map_resnet:
|
| 864 |
+
v = v.replace(sd_part, diffusers_part)
|
| 865 |
+
mapping[k] = v
|
| 866 |
+
new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
|
| 867 |
+
return new_state_dict
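# Usage sketch only: the two ControlNet converters are built from the same
# mapping tables and are meant to invert each other on the keys they cover.
# `controlnet_state_dict` is assumed to be a Diffusers-format ControlNet state
# dict obtained elsewhere; nothing is asserted here about edge cases.
def _controlnet_keys_roundtrip(controlnet_state_dict):
    sd_format = convert_controlnet_state_dict_to_sd(controlnet_state_dict)
    back = convert_controlnet_state_dict_to_diffusers(sd_format)
    return set(back.keys()) == set(controlnet_state_dict.keys())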
|
| 868 |
+
|
| 869 |
+
|
| 870 |
+
# ================#
|
| 871 |
+
# VAE Conversion #
|
| 872 |
+
# ================#
|
| 873 |
+
|
| 874 |
+
|
| 875 |
+
def reshape_weight_for_sd(w):
|
| 876 |
+
# convert HF linear weights to SD conv2d weights
|
| 877 |
+
return w.reshape(*w.shape, 1, 1)
|
| 878 |
+
|
| 879 |
+
|
| 880 |
+
def convert_vae_state_dict(vae_state_dict):
|
| 881 |
+
vae_conversion_map = [
|
| 882 |
+
# (stable-diffusion, HF Diffusers)
|
| 883 |
+
("nin_shortcut", "conv_shortcut"),
|
| 884 |
+
("norm_out", "conv_norm_out"),
|
| 885 |
+
("mid.attn_1.", "mid_block.attentions.0."),
|
| 886 |
+
]
|
| 887 |
+
|
| 888 |
+
for i in range(4):
|
| 889 |
+
# down_blocks have two resnets
|
| 890 |
+
for j in range(2):
|
| 891 |
+
hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
|
| 892 |
+
sd_down_prefix = f"encoder.down.{i}.block.{j}."
|
| 893 |
+
vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
|
| 894 |
+
|
| 895 |
+
if i < 3:
|
| 896 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
|
| 897 |
+
sd_downsample_prefix = f"down.{i}.downsample."
|
| 898 |
+
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
|
| 899 |
+
|
| 900 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
| 901 |
+
sd_upsample_prefix = f"up.{3-i}.upsample."
|
| 902 |
+
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
|
| 903 |
+
|
| 904 |
+
# up_blocks have three resnets
|
| 905 |
+
# also, up blocks in hf are numbered in reverse from sd
|
| 906 |
+
for j in range(3):
|
| 907 |
+
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
|
| 908 |
+
sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
|
| 909 |
+
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
|
| 910 |
+
|
| 911 |
+
# this part accounts for mid blocks in both the encoder and the decoder
|
| 912 |
+
for i in range(2):
|
| 913 |
+
hf_mid_res_prefix = f"mid_block.resnets.{i}."
|
| 914 |
+
sd_mid_res_prefix = f"mid.block_{i+1}."
|
| 915 |
+
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
| 916 |
+
|
| 917 |
+
if diffusers.__version__ < "0.17.0":
|
| 918 |
+
vae_conversion_map_attn = [
|
| 919 |
+
# (stable-diffusion, HF Diffusers)
|
| 920 |
+
("norm.", "group_norm."),
|
| 921 |
+
("q.", "query."),
|
| 922 |
+
("k.", "key."),
|
| 923 |
+
("v.", "value."),
|
| 924 |
+
("proj_out.", "proj_attn."),
|
| 925 |
+
]
|
| 926 |
+
else:
|
| 927 |
+
vae_conversion_map_attn = [
|
| 928 |
+
# (stable-diffusion, HF Diffusers)
|
| 929 |
+
("norm.", "group_norm."),
|
| 930 |
+
("q.", "to_q."),
|
| 931 |
+
("k.", "to_k."),
|
| 932 |
+
("v.", "to_v."),
|
| 933 |
+
("proj_out.", "to_out.0."),
|
| 934 |
+
]
|
| 935 |
+
|
| 936 |
+
mapping = {k: k for k in vae_state_dict.keys()}
|
| 937 |
+
for k, v in mapping.items():
|
| 938 |
+
for sd_part, hf_part in vae_conversion_map:
|
| 939 |
+
v = v.replace(hf_part, sd_part)
|
| 940 |
+
mapping[k] = v
|
| 941 |
+
for k, v in mapping.items():
|
| 942 |
+
if "attentions" in k:
|
| 943 |
+
for sd_part, hf_part in vae_conversion_map_attn:
|
| 944 |
+
v = v.replace(hf_part, sd_part)
|
| 945 |
+
mapping[k] = v
|
| 946 |
+
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
|
| 947 |
+
weights_to_convert = ["q", "k", "v", "proj_out"]
|
| 948 |
+
for k, v in new_state_dict.items():
|
| 949 |
+
for weight_name in weights_to_convert:
|
| 950 |
+
if f"mid.attn_1.{weight_name}.weight" in k:
|
| 951 |
+
# logger.info(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
|
| 952 |
+
new_state_dict[k] = reshape_weight_for_sd(v)
|
| 953 |
+
|
| 954 |
+
return new_state_dict
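# Shape sketch for the reshaping step above: the HF VAE stores the mid-block
# attention projections as Linear weights [out, in], while the SD checkpoint
# expects Conv2d-style [out, in, 1, 1]. The 512 channels are illustrative.
def _demo_reshape_weight_for_sd():
    w = torch.zeros(512, 512)
    return reshape_weight_for_sd(w).shape  # torch.Size([512, 512, 1, 1])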
|
| 955 |
+
|
| 956 |
+
|
| 957 |
+
# endregion
|
| 958 |
+
|
| 959 |
+
# region custom model loading / saving helpers
|
| 960 |
+
|
| 961 |
+
|
| 962 |
+
def is_safetensors(path):
|
| 963 |
+
return os.path.splitext(path)[1].lower() == ".safetensors"
|
| 964 |
+
|
| 965 |
+
|
| 966 |
+
def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
|
| 967 |
+
# handle models whose text encoder is stored in a different layout (keys without 'text_model')
|
| 968 |
+
TEXT_ENCODER_KEY_REPLACEMENTS = [
|
| 969 |
+
("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."),
|
| 970 |
+
("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."),
|
| 971 |
+
("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."),
|
| 972 |
+
]
|
| 973 |
+
|
| 974 |
+
if is_safetensors(ckpt_path):
|
| 975 |
+
checkpoint = None
|
| 976 |
+
state_dict = load_file(ckpt_path)  # , device)  # passing the device here may cause an error
|
| 977 |
+
else:
|
| 978 |
+
checkpoint = torch.load(ckpt_path, map_location=device)
|
| 979 |
+
if "state_dict" in checkpoint:
|
| 980 |
+
state_dict = checkpoint["state_dict"]
|
| 981 |
+
else:
|
| 982 |
+
state_dict = checkpoint
|
| 983 |
+
checkpoint = None
|
| 984 |
+
|
| 985 |
+
key_reps = []
|
| 986 |
+
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
|
| 987 |
+
for key in state_dict.keys():
|
| 988 |
+
if key.startswith(rep_from):
|
| 989 |
+
new_key = rep_to + key[len(rep_from) :]
|
| 990 |
+
key_reps.append((key, new_key))
|
| 991 |
+
|
| 992 |
+
for key, new_key in key_reps:
|
| 993 |
+
state_dict[new_key] = state_dict[key]
|
| 994 |
+
del state_dict[key]
|
| 995 |
+
|
| 996 |
+
return checkpoint, state_dict
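# Usage sketch with a hypothetical file name: both .ckpt and .safetensors are
# handled, and the text encoder keys are renamed to the 'text_model' layout.
def _demo_load_checkpoint(path="model.safetensors"):
    checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(path)
    # checkpoint is None for safetensors / state_dict-only files
    return checkpoint is None, len(state_dict)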
|
| 997 |
+
|
| 998 |
+
|
| 999 |
+
# TODO the behavior of the dtype argument is questionable and should be verified; it is unconfirmed whether the text_encoder can be created with the requested dtype
|
| 1000 |
+
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=True):
|
| 1001 |
+
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
|
| 1002 |
+
|
| 1003 |
+
# Convert the UNet2DConditionModel model.
|
| 1004 |
+
unet_config = create_unet_diffusers_config(v2, unet_use_linear_projection_in_v2)
|
| 1005 |
+
converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
|
| 1006 |
+
|
| 1007 |
+
unet = UNet2DConditionModel(**unet_config).to(device)
|
| 1008 |
+
info = unet.load_state_dict(converted_unet_checkpoint)
|
| 1009 |
+
logger.info(f"loading u-net: {info}")
|
| 1010 |
+
|
| 1011 |
+
# Convert the VAE model.
|
| 1012 |
+
vae_config = create_vae_diffusers_config()
|
| 1013 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
|
| 1014 |
+
|
| 1015 |
+
vae = AutoencoderKL(**vae_config).to(device)
|
| 1016 |
+
info = vae.load_state_dict(converted_vae_checkpoint)
|
| 1017 |
+
logger.info(f"loading vae: {info}")
|
| 1018 |
+
|
| 1019 |
+
# convert text_model
|
| 1020 |
+
if v2:
|
| 1021 |
+
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
|
| 1022 |
+
cfg = CLIPTextConfig(
|
| 1023 |
+
vocab_size=49408,
|
| 1024 |
+
hidden_size=1024,
|
| 1025 |
+
intermediate_size=4096,
|
| 1026 |
+
num_hidden_layers=23,
|
| 1027 |
+
num_attention_heads=16,
|
| 1028 |
+
max_position_embeddings=77,
|
| 1029 |
+
hidden_act="gelu",
|
| 1030 |
+
layer_norm_eps=1e-05,
|
| 1031 |
+
dropout=0.0,
|
| 1032 |
+
attention_dropout=0.0,
|
| 1033 |
+
initializer_range=0.02,
|
| 1034 |
+
initializer_factor=1.0,
|
| 1035 |
+
pad_token_id=1,
|
| 1036 |
+
bos_token_id=0,
|
| 1037 |
+
eos_token_id=2,
|
| 1038 |
+
model_type="clip_text_model",
|
| 1039 |
+
projection_dim=512,
|
| 1040 |
+
torch_dtype="float32",
|
| 1041 |
+
transformers_version="4.25.0.dev0",
|
| 1042 |
+
)
|
| 1043 |
+
text_model = CLIPTextModel._from_config(cfg)
|
| 1044 |
+
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
| 1045 |
+
else:
|
| 1046 |
+
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
|
| 1047 |
+
|
| 1048 |
+
# logging.set_verbosity_error() # don't show annoying warning
|
| 1049 |
+
# text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
|
| 1050 |
+
# logging.set_verbosity_warning()
|
| 1051 |
+
# logger.info(f"config: {text_model.config}")
|
| 1052 |
+
cfg = CLIPTextConfig(
|
| 1053 |
+
vocab_size=49408,
|
| 1054 |
+
hidden_size=768,
|
| 1055 |
+
intermediate_size=3072,
|
| 1056 |
+
num_hidden_layers=12,
|
| 1057 |
+
num_attention_heads=12,
|
| 1058 |
+
max_position_embeddings=77,
|
| 1059 |
+
hidden_act="quick_gelu",
|
| 1060 |
+
layer_norm_eps=1e-05,
|
| 1061 |
+
dropout=0.0,
|
| 1062 |
+
attention_dropout=0.0,
|
| 1063 |
+
initializer_range=0.02,
|
| 1064 |
+
initializer_factor=1.0,
|
| 1065 |
+
pad_token_id=1,
|
| 1066 |
+
bos_token_id=0,
|
| 1067 |
+
eos_token_id=2,
|
| 1068 |
+
model_type="clip_text_model",
|
| 1069 |
+
projection_dim=768,
|
| 1070 |
+
torch_dtype="float32",
|
| 1071 |
+
)
|
| 1072 |
+
text_model = CLIPTextModel._from_config(cfg)
|
| 1073 |
+
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
| 1074 |
+
logger.info(f"loading text encoder: {info}")
|
| 1075 |
+
|
| 1076 |
+
return text_model, vae, unet
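# Usage sketch with a hypothetical local path; v2=False selects the SD 1.x
# CLIP and U-Net configuration built above.
def _demo_load_models(ckpt_path="sd-v1-5.safetensors"):
    text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(
        False, ckpt_path, device="cpu"
    )
    return type(text_encoder).__name__, type(vae).__name__, type(unet).__name__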
|
| 1077 |
+
|
| 1078 |
+
|
| 1079 |
+
def get_model_version_str_for_sd1_sd2(v2, v_parameterization):
|
| 1080 |
+
# only for reference
|
| 1081 |
+
version_str = "sd"
|
| 1082 |
+
if v2:
|
| 1083 |
+
version_str += "_v2"
|
| 1084 |
+
else:
|
| 1085 |
+
version_str += "_v1"
|
| 1086 |
+
if v_parameterization:
|
| 1087 |
+
version_str += "_v"
|
| 1088 |
+
return version_str
|
| 1089 |
+
|
| 1090 |
+
|
| 1091 |
+
def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
|
| 1092 |
+
def convert_key(key):
|
| 1093 |
+
# drop position_ids
|
| 1094 |
+
if ".position_ids" in key:
|
| 1095 |
+
return None
|
| 1096 |
+
|
| 1097 |
+
# common
|
| 1098 |
+
key = key.replace("text_model.encoder.", "transformer.")
|
| 1099 |
+
key = key.replace("text_model.", "")
|
| 1100 |
+
if "layers" in key:
|
| 1101 |
+
# resblocks conversion
|
| 1102 |
+
key = key.replace(".layers.", ".resblocks.")
|
| 1103 |
+
if ".layer_norm" in key:
|
| 1104 |
+
key = key.replace(".layer_norm", ".ln_")
|
| 1105 |
+
elif ".mlp." in key:
|
| 1106 |
+
key = key.replace(".fc1.", ".c_fc.")
|
| 1107 |
+
key = key.replace(".fc2.", ".c_proj.")
|
| 1108 |
+
elif ".self_attn.out_proj" in key:
|
| 1109 |
+
key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
|
| 1110 |
+
elif ".self_attn." in key:
|
| 1111 |
+
key = None  # special case; handled separately below
|
| 1112 |
+
else:
|
| 1113 |
+
raise ValueError(f"unexpected key in DiffUsers model: {key}")
|
| 1114 |
+
elif ".position_embedding" in key:
|
| 1115 |
+
key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
|
| 1116 |
+
elif ".token_embedding" in key:
|
| 1117 |
+
key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
|
| 1118 |
+
elif "final_layer_norm" in key:
|
| 1119 |
+
key = key.replace("final_layer_norm", "ln_final")
|
| 1120 |
+
return key
|
| 1121 |
+
|
| 1122 |
+
keys = list(checkpoint.keys())
|
| 1123 |
+
new_sd = {}
|
| 1124 |
+
for key in keys:
|
| 1125 |
+
new_key = convert_key(key)
|
| 1126 |
+
if new_key is None:
|
| 1127 |
+
continue
|
| 1128 |
+
new_sd[new_key] = checkpoint[key]
|
| 1129 |
+
|
| 1130 |
+
# attnの変換
|
| 1131 |
+
for key in keys:
|
| 1132 |
+
if "layers" in key and "q_proj" in key:
|
| 1133 |
+
# concatenate the three projections
|
| 1134 |
+
key_q = key
|
| 1135 |
+
key_k = key.replace("q_proj", "k_proj")
|
| 1136 |
+
key_v = key.replace("q_proj", "v_proj")
|
| 1137 |
+
|
| 1138 |
+
value_q = checkpoint[key_q]
|
| 1139 |
+
value_k = checkpoint[key_k]
|
| 1140 |
+
value_v = checkpoint[key_v]
|
| 1141 |
+
value = torch.cat([value_q, value_k, value_v])
|
| 1142 |
+
|
| 1143 |
+
new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
|
| 1144 |
+
new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
|
| 1145 |
+
new_sd[new_key] = value
|
| 1146 |
+
|
| 1147 |
+
# whether to fabricate dummy weights for the final layer etc.
|
| 1148 |
+
if make_dummy_weights:
|
| 1149 |
+
logger.info("make dummy weights for resblock.23, text_projection and logit scale.")
|
| 1150 |
+
keys = list(new_sd.keys())
|
| 1151 |
+
for key in keys:
|
| 1152 |
+
if key.startswith("transformer.resblocks.22."):
|
| 1153 |
+
new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる
|
| 1154 |
+
|
| 1155 |
+
# create weights that are not included in the Diffusers model
|
| 1156 |
+
new_sd["text_projection"] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
|
| 1157 |
+
new_sd["logit_scale"] = torch.tensor(1)
|
| 1158 |
+
|
| 1159 |
+
return new_sd
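# Shape sketch for the q/k/v concatenation above: the open-CLIP style
# checkpoint stores a single in_proj weight of shape [3 * hidden, hidden],
# built by stacking the HF q/k/v projections in that order (hidden=1024 for v2).
def _demo_in_proj_shape(hidden=1024):
    q = torch.zeros(hidden, hidden)
    k = torch.zeros(hidden, hidden)
    v = torch.zeros(hidden, hidden)
    return torch.cat([q, k, v]).shape  # torch.Size([3072, 1024])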
|
| 1160 |
+
|
| 1161 |
+
|
| 1162 |
+
def save_stable_diffusion_checkpoint(
|
| 1163 |
+
v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, metadata, save_dtype=None, vae=None
|
| 1164 |
+
):
|
| 1165 |
+
if ckpt_path is not None:
|
| 1166 |
+
# read epoch/step from it; also, e.g. when the VAE is not in memory, load the checkpoint again including the VAE
|
| 1167 |
+
checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
| 1168 |
+
if checkpoint is None:  # safetensors, or a ckpt that is just a state_dict
|
| 1169 |
+
checkpoint = {}
|
| 1170 |
+
strict = False
|
| 1171 |
+
else:
|
| 1172 |
+
strict = True
|
| 1173 |
+
if "state_dict" in state_dict:
|
| 1174 |
+
del state_dict["state_dict"]
|
| 1175 |
+
else:
|
| 1176 |
+
# build a new checkpoint from scratch
|
| 1177 |
+
assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
|
| 1178 |
+
checkpoint = {}
|
| 1179 |
+
state_dict = {}
|
| 1180 |
+
strict = False
|
| 1181 |
+
|
| 1182 |
+
def update_sd(prefix, sd):
|
| 1183 |
+
for k, v in sd.items():
|
| 1184 |
+
key = prefix + k
|
| 1185 |
+
assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
|
| 1186 |
+
if save_dtype is not None:
|
| 1187 |
+
v = v.detach().clone().to("cpu").to(save_dtype)
|
| 1188 |
+
state_dict[key] = v
|
| 1189 |
+
|
| 1190 |
+
# Convert the UNet model
|
| 1191 |
+
unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
|
| 1192 |
+
update_sd("model.diffusion_model.", unet_state_dict)
|
| 1193 |
+
|
| 1194 |
+
# Convert the text encoder model
|
| 1195 |
+
if v2:
|
| 1196 |
+
make_dummy = ckpt_path is None  # with no reference checkpoint, fill in dummy weights (e.g. clone the last layer from the previous one)
|
| 1197 |
+
text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
|
| 1198 |
+
update_sd("cond_stage_model.model.", text_enc_dict)
|
| 1199 |
+
else:
|
| 1200 |
+
text_enc_dict = text_encoder.state_dict()
|
| 1201 |
+
update_sd("cond_stage_model.transformer.", text_enc_dict)
|
| 1202 |
+
|
| 1203 |
+
# Convert the VAE
|
| 1204 |
+
if vae is not None:
|
| 1205 |
+
vae_dict = convert_vae_state_dict(vae.state_dict())
|
| 1206 |
+
update_sd("first_stage_model.", vae_dict)
|
| 1207 |
+
|
| 1208 |
+
# Put together new checkpoint
|
| 1209 |
+
key_count = len(state_dict.keys())
|
| 1210 |
+
new_ckpt = {"state_dict": state_dict}
|
| 1211 |
+
|
| 1212 |
+
# epoch and global_step are sometimes not int
|
| 1213 |
+
try:
|
| 1214 |
+
if "epoch" in checkpoint:
|
| 1215 |
+
epochs += checkpoint["epoch"]
|
| 1216 |
+
if "global_step" in checkpoint:
|
| 1217 |
+
steps += checkpoint["global_step"]
|
| 1218 |
+
except:
|
| 1219 |
+
pass
|
| 1220 |
+
|
| 1221 |
+
new_ckpt["epoch"] = epochs
|
| 1222 |
+
new_ckpt["global_step"] = steps
|
| 1223 |
+
|
| 1224 |
+
if is_safetensors(output_file):
|
| 1225 |
+
# TODO consider removing non-Tensor values from the dict
|
| 1226 |
+
save_file(state_dict, output_file, metadata)
|
| 1227 |
+
else:
|
| 1228 |
+
torch.save(new_ckpt, output_file)
|
| 1229 |
+
|
| 1230 |
+
return key_count
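# Usage sketch with a hypothetical output path: saving without a reference
# checkpoint (ckpt_path=None) requires passing the VAE explicitly, as asserted
# above; the return value is the number of saved keys.
def _demo_save_checkpoint(text_encoder, unet, vae):
    return save_stable_diffusion_checkpoint(
        False, "output.safetensors", text_encoder, unet,
        ckpt_path=None, epochs=0, steps=0, metadata={}, vae=vae,
    )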
|
| 1231 |
+
|
| 1232 |
+
|
| 1233 |
+
def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
|
| 1234 |
+
if pretrained_model_name_or_path is None:
|
| 1235 |
+
# load default settings for v1/v2
|
| 1236 |
+
if v2:
|
| 1237 |
+
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
|
| 1238 |
+
else:
|
| 1239 |
+
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
|
| 1240 |
+
|
| 1241 |
+
scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
|
| 1242 |
+
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
|
| 1243 |
+
if vae is None:
|
| 1244 |
+
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
|
| 1245 |
+
|
| 1246 |
+
# original U-Net cannot be saved, so we need to convert it to the Diffusers version
|
| 1247 |
+
# TODO this consumes a lot of memory
|
| 1248 |
+
diffusers_unet = diffusers.UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")
|
| 1249 |
+
diffusers_unet.load_state_dict(unet.state_dict())
|
| 1250 |
+
|
| 1251 |
+
pipeline = StableDiffusionPipeline(
|
| 1252 |
+
unet=diffusers_unet,
|
| 1253 |
+
text_encoder=text_encoder,
|
| 1254 |
+
vae=vae,
|
| 1255 |
+
scheduler=scheduler,
|
| 1256 |
+
tokenizer=tokenizer,
|
| 1257 |
+
safety_checker=None,
|
| 1258 |
+
feature_extractor=None,
|
| 1259 |
+
requires_safety_checker=None,
|
| 1260 |
+
)
|
| 1261 |
+
pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
|
| 1262 |
+
|
| 1263 |
+
|
| 1264 |
+
VAE_PREFIX = "first_stage_model."
|
| 1265 |
+
|
| 1266 |
+
|
| 1267 |
+
def load_vae(vae_id, dtype):
|
| 1268 |
+
logger.info(f"load VAE: {vae_id}")
|
| 1269 |
+
if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
|
| 1270 |
+
# Diffusers local/remote
|
| 1271 |
+
try:
|
| 1272 |
+
vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
|
| 1273 |
+
except EnvironmentError as e:
|
| 1274 |
+
logger.error(f"exception occurs in loading vae: {e}")
|
| 1275 |
+
logger.error("retry with subfolder='vae'")
|
| 1276 |
+
vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
|
| 1277 |
+
return vae
|
| 1278 |
+
|
| 1279 |
+
# local
|
| 1280 |
+
vae_config = create_vae_diffusers_config()
|
| 1281 |
+
|
| 1282 |
+
if vae_id.endswith(".bin"):
|
| 1283 |
+
# SD 1.5 VAE on Huggingface
|
| 1284 |
+
converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
|
| 1285 |
+
else:
|
| 1286 |
+
# StableDiffusion
|
| 1287 |
+
vae_model = load_file(vae_id, "cpu") if is_safetensors(vae_id) else torch.load(vae_id, map_location="cpu")
|
| 1288 |
+
vae_sd = vae_model["state_dict"] if "state_dict" in vae_model else vae_model
|
| 1289 |
+
|
| 1290 |
+
# vae only or full model
|
| 1291 |
+
full_model = False
|
| 1292 |
+
for vae_key in vae_sd:
|
| 1293 |
+
if vae_key.startswith(VAE_PREFIX):
|
| 1294 |
+
full_model = True
|
| 1295 |
+
break
|
| 1296 |
+
if not full_model:
|
| 1297 |
+
sd = {}
|
| 1298 |
+
for key, value in vae_sd.items():
|
| 1299 |
+
sd[VAE_PREFIX + key] = value
|
| 1300 |
+
vae_sd = sd
|
| 1301 |
+
del sd
|
| 1302 |
+
|
| 1303 |
+
# Convert the VAE model.
|
| 1304 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
|
| 1305 |
+
|
| 1306 |
+
vae = AutoencoderKL(**vae_config)
|
| 1307 |
+
vae.load_state_dict(converted_vae_checkpoint)
|
| 1308 |
+
return vae
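# Usage sketch with a hypothetical local file; a Diffusers directory or hub id
# would instead take the from_pretrained branch above.
def _demo_load_vae():
    vae = load_vae("my_vae.safetensors", dtype=torch.float32)
    return type(vae).__name__  # "AutoencoderKL"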
|
| 1309 |
+
|
| 1310 |
+
|
| 1311 |
+
# endregion
|
| 1312 |
+
|
| 1313 |
+
|
| 1314 |
+
def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
|
| 1315 |
+
max_width, max_height = max_reso
|
| 1316 |
+
max_area = max_width * max_height
|
| 1317 |
+
|
| 1318 |
+
resos = set()
|
| 1319 |
+
|
| 1320 |
+
width = int(math.sqrt(max_area) // divisible) * divisible
|
| 1321 |
+
resos.add((width, width))
|
| 1322 |
+
|
| 1323 |
+
width = min_size
|
| 1324 |
+
while width <= max_size:
|
| 1325 |
+
height = min(max_size, int((max_area // width) // divisible) * divisible)
|
| 1326 |
+
if height >= min_size:
|
| 1327 |
+
resos.add((width, height))
|
| 1328 |
+
resos.add((height, width))
|
| 1329 |
+
|
| 1330 |
+
# # make additional resos
|
| 1331 |
+
# if width >= height and width - divisible >= min_size:
|
| 1332 |
+
# resos.add((width - divisible, height))
|
| 1333 |
+
# resos.add((height, width - divisible))
|
| 1334 |
+
# if height >= width and height - divisible >= min_size:
|
| 1335 |
+
# resos.add((width, height - divisible))
|
| 1336 |
+
# resos.add((height - divisible, width))
|
| 1337 |
+
|
| 1338 |
+
width += divisible
|
| 1339 |
+
|
| 1340 |
+
resos = list(resos)
|
| 1341 |
+
resos.sort()
|
| 1342 |
+
return resos
|
| 1343 |
+
|
| 1344 |
+
|
| 1345 |
+
if __name__ == "__main__":
|
| 1346 |
+
resos = make_bucket_resolutions((512, 768))
|
| 1347 |
+
logger.info(f"{len(resos)}")
|
| 1348 |
+
logger.info(f"{resos}")
|
| 1349 |
+
aspect_ratios = [w / h for w, h in resos]
|
| 1350 |
+
logger.info(f"{aspect_ratios}")
|
| 1351 |
+
|
| 1352 |
+
ars = set()
|
| 1353 |
+
for ar in aspect_ratios:
|
| 1354 |
+
if ar in ars:
|
| 1355 |
+
logger.error(f"error! duplicate ar: {ar}")
|
| 1356 |
+
ars.add(ar)
|
library/original_unet.py
ADDED
|
@@ -0,0 +1,1919 @@
| 1 |
+
# Diffusers 0.10.2からStable Diffusionに必要な部分だけを持ってくる
|
| 2 |
+
# 条件分岐等で不要な部分は削除している
|
| 3 |
+
# コードの多くはDiffusersからコピーしている
|
| 4 |
+
# 制約として、モデルのstate_dictがDiffusers 0.10.2のものと同じ形式である必要がある
|
| 5 |
+
|
| 6 |
+
# Copy from Diffusers 0.10.2 for Stable Diffusion. Most of the code is copied from Diffusers.
|
| 7 |
+
# Unnecessary parts are deleted by condition branching.
|
| 8 |
+
# As a constraint, the state_dict of the model must be in the same format as that of Diffusers 0.10.2
|
| 9 |
+
|
| 10 |
+
"""
|
| 11 |
+
Differences between v1.5 and v2.1:
|
| 12 |
+
- attention_head_dim is an int (1.5) or a list[int] (2.1)
|
| 13 |
+
- cross_attention_dim is 768 (1.5) or 1024 (2.1)
|
| 14 |
+
- use_linear_projection: true is absent (= False, 1.5) or present (2.1)
|
| 15 |
+
- upcast_attention is False (1.5) or True (2.1)
|
| 16 |
+
- (the following can probably be ignored)
|
| 17 |
+
- sample_size is 64 or 96
|
| 18 |
+
- dual_cross_attention is present or absent
|
| 19 |
+
- num_class_embeds is present or absent
|
| 20 |
+
- only_cross_attention is present or absent
|
| 21 |
+
|
| 22 |
+
v1.5
|
| 23 |
+
{
|
| 24 |
+
"_class_name": "UNet2DConditionModel",
|
| 25 |
+
"_diffusers_version": "0.6.0",
|
| 26 |
+
"act_fn": "silu",
|
| 27 |
+
"attention_head_dim": 8,
|
| 28 |
+
"block_out_channels": [
|
| 29 |
+
320,
|
| 30 |
+
640,
|
| 31 |
+
1280,
|
| 32 |
+
1280
|
| 33 |
+
],
|
| 34 |
+
"center_input_sample": false,
|
| 35 |
+
"cross_attention_dim": 768,
|
| 36 |
+
"down_block_types": [
|
| 37 |
+
"CrossAttnDownBlock2D",
|
| 38 |
+
"CrossAttnDownBlock2D",
|
| 39 |
+
"CrossAttnDownBlock2D",
|
| 40 |
+
"DownBlock2D"
|
| 41 |
+
],
|
| 42 |
+
"downsample_padding": 1,
|
| 43 |
+
"flip_sin_to_cos": true,
|
| 44 |
+
"freq_shift": 0,
|
| 45 |
+
"in_channels": 4,
|
| 46 |
+
"layers_per_block": 2,
|
| 47 |
+
"mid_block_scale_factor": 1,
|
| 48 |
+
"norm_eps": 1e-05,
|
| 49 |
+
"norm_num_groups": 32,
|
| 50 |
+
"out_channels": 4,
|
| 51 |
+
"sample_size": 64,
|
| 52 |
+
"up_block_types": [
|
| 53 |
+
"UpBlock2D",
|
| 54 |
+
"CrossAttnUpBlock2D",
|
| 55 |
+
"CrossAttnUpBlock2D",
|
| 56 |
+
"CrossAttnUpBlock2D"
|
| 57 |
+
]
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
v2.1
|
| 61 |
+
{
|
| 62 |
+
"_class_name": "UNet2DConditionModel",
|
| 63 |
+
"_diffusers_version": "0.10.0.dev0",
|
| 64 |
+
"act_fn": "silu",
|
| 65 |
+
"attention_head_dim": [
|
| 66 |
+
5,
|
| 67 |
+
10,
|
| 68 |
+
20,
|
| 69 |
+
20
|
| 70 |
+
],
|
| 71 |
+
"block_out_channels": [
|
| 72 |
+
320,
|
| 73 |
+
640,
|
| 74 |
+
1280,
|
| 75 |
+
1280
|
| 76 |
+
],
|
| 77 |
+
"center_input_sample": false,
|
| 78 |
+
"cross_attention_dim": 1024,
|
| 79 |
+
"down_block_types": [
|
| 80 |
+
"CrossAttnDownBlock2D",
|
| 81 |
+
"CrossAttnDownBlock2D",
|
| 82 |
+
"CrossAttnDownBlock2D",
|
| 83 |
+
"DownBlock2D"
|
| 84 |
+
],
|
| 85 |
+
"downsample_padding": 1,
|
| 86 |
+
"dual_cross_attention": false,
|
| 87 |
+
"flip_sin_to_cos": true,
|
| 88 |
+
"freq_shift": 0,
|
| 89 |
+
"in_channels": 4,
|
| 90 |
+
"layers_per_block": 2,
|
| 91 |
+
"mid_block_scale_factor": 1,
|
| 92 |
+
"norm_eps": 1e-05,
|
| 93 |
+
"norm_num_groups": 32,
|
| 94 |
+
"num_class_embeds": null,
|
| 95 |
+
"only_cross_attention": false,
|
| 96 |
+
"out_channels": 4,
|
| 97 |
+
"sample_size": 96,
|
| 98 |
+
"up_block_types": [
|
| 99 |
+
"UpBlock2D",
|
| 100 |
+
"CrossAttnUpBlock2D",
|
| 101 |
+
"CrossAttnUpBlock2D",
|
| 102 |
+
"CrossAttnUpBlock2D"
|
| 103 |
+
],
|
| 104 |
+
"use_linear_projection": true,
|
| 105 |
+
"upcast_attention": true
|
| 106 |
+
}
|
| 107 |
+
"""
|
| 108 |
+
|
| 109 |
+
import math
|
| 110 |
+
from types import SimpleNamespace
|
| 111 |
+
from typing import Dict, Optional, Tuple, Union
|
| 112 |
+
import torch
|
| 113 |
+
from torch import nn
|
| 114 |
+
from torch.nn import functional as F
|
| 115 |
+
from einops import rearrange
|
| 116 |
+
from library.utils import setup_logging
|
| 117 |
+
setup_logging()
|
| 118 |
+
import logging
|
| 119 |
+
logger = logging.getLogger(__name__)
|
| 120 |
+
|
| 121 |
+
BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280)
|
| 122 |
+
TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0]
|
| 123 |
+
TIME_EMBED_DIM = BLOCK_OUT_CHANNELS[0] * 4
|
| 124 |
+
IN_CHANNELS: int = 4
|
| 125 |
+
OUT_CHANNELS: int = 4
|
| 126 |
+
LAYERS_PER_BLOCK: int = 2
|
| 127 |
+
LAYERS_PER_BLOCK_UP: int = LAYERS_PER_BLOCK + 1
|
| 128 |
+
TIME_EMBED_FLIP_SIN_TO_COS: bool = True
|
| 129 |
+
TIME_EMBED_FREQ_SHIFT: int = 0
|
| 130 |
+
NORM_GROUPS: int = 32
|
| 131 |
+
NORM_EPS: float = 1e-5
|
| 132 |
+
TRANSFORMER_NORM_NUM_GROUPS = 32
|
| 133 |
+
|
| 134 |
+
DOWN_BLOCK_TYPES = ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"]
|
| 135 |
+
UP_BLOCK_TYPES = ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"]
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
# region memory efficient attention
|
| 139 |
+
|
| 140 |
+
# CrossAttention that uses FlashAttention
|
| 141 |
+
# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
|
| 142 |
+
# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
|
| 143 |
+
|
| 144 |
+
# constants
|
| 145 |
+
|
| 146 |
+
EPSILON = 1e-6
|
| 147 |
+
|
| 148 |
+
# helper functions
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def exists(val):
|
| 152 |
+
return val is not None
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def default(val, d):
|
| 156 |
+
return val if exists(val) else d
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
# flash attention forwards and backwards
|
| 160 |
+
|
| 161 |
+
# https://arxiv.org/abs/2205.14135
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class FlashAttentionFunction(torch.autograd.Function):
|
| 165 |
+
@staticmethod
|
| 166 |
+
@torch.no_grad()
|
| 167 |
+
def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
|
| 168 |
+
"""Algorithm 2 in the paper"""
|
| 169 |
+
|
| 170 |
+
device = q.device
|
| 171 |
+
dtype = q.dtype
|
| 172 |
+
max_neg_value = -torch.finfo(q.dtype).max
|
| 173 |
+
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
| 174 |
+
|
| 175 |
+
o = torch.zeros_like(q)
|
| 176 |
+
all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
|
| 177 |
+
all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
|
| 178 |
+
|
| 179 |
+
scale = q.shape[-1] ** -0.5
|
| 180 |
+
|
| 181 |
+
if not exists(mask):
|
| 182 |
+
mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
|
| 183 |
+
else:
|
| 184 |
+
mask = rearrange(mask, "b n -> b 1 1 n")
|
| 185 |
+
mask = mask.split(q_bucket_size, dim=-1)
|
| 186 |
+
|
| 187 |
+
row_splits = zip(
|
| 188 |
+
q.split(q_bucket_size, dim=-2),
|
| 189 |
+
o.split(q_bucket_size, dim=-2),
|
| 190 |
+
mask,
|
| 191 |
+
all_row_sums.split(q_bucket_size, dim=-2),
|
| 192 |
+
all_row_maxes.split(q_bucket_size, dim=-2),
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
|
| 196 |
+
q_start_index = ind * q_bucket_size - qk_len_diff
|
| 197 |
+
|
| 198 |
+
col_splits = zip(
|
| 199 |
+
k.split(k_bucket_size, dim=-2),
|
| 200 |
+
v.split(k_bucket_size, dim=-2),
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
for k_ind, (kc, vc) in enumerate(col_splits):
|
| 204 |
+
k_start_index = k_ind * k_bucket_size
|
| 205 |
+
|
| 206 |
+
attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
|
| 207 |
+
|
| 208 |
+
if exists(row_mask):
|
| 209 |
+
attn_weights.masked_fill_(~row_mask, max_neg_value)
|
| 210 |
+
|
| 211 |
+
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
| 212 |
+
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
|
| 213 |
+
q_start_index - k_start_index + 1
|
| 214 |
+
)
|
| 215 |
+
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
| 216 |
+
|
| 217 |
+
block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
|
| 218 |
+
attn_weights -= block_row_maxes
|
| 219 |
+
exp_weights = torch.exp(attn_weights)
|
| 220 |
+
|
| 221 |
+
if exists(row_mask):
|
| 222 |
+
exp_weights.masked_fill_(~row_mask, 0.0)
|
| 223 |
+
|
| 224 |
+
block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
|
| 225 |
+
|
| 226 |
+
new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
|
| 227 |
+
|
| 228 |
+
exp_values = torch.einsum("... i j, ... j d -> ... i d", exp_weights, vc)
|
| 229 |
+
|
| 230 |
+
exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
|
| 231 |
+
exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
|
| 232 |
+
|
| 233 |
+
new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
|
| 234 |
+
|
| 235 |
+
oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
|
| 236 |
+
|
| 237 |
+
row_maxes.copy_(new_row_maxes)
|
| 238 |
+
row_sums.copy_(new_row_sums)
|
| 239 |
+
|
| 240 |
+
ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
|
| 241 |
+
ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
|
| 242 |
+
|
| 243 |
+
return o
|
| 244 |
+
|
| 245 |
+
@staticmethod
|
| 246 |
+
@torch.no_grad()
|
| 247 |
+
def backward(ctx, do):
|
| 248 |
+
"""Algorithm 4 in the paper"""
|
| 249 |
+
|
| 250 |
+
causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
|
| 251 |
+
q, k, v, o, l, m = ctx.saved_tensors
|
| 252 |
+
|
| 253 |
+
device = q.device
|
| 254 |
+
|
| 255 |
+
max_neg_value = -torch.finfo(q.dtype).max
|
| 256 |
+
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
| 257 |
+
|
| 258 |
+
dq = torch.zeros_like(q)
|
| 259 |
+
dk = torch.zeros_like(k)
|
| 260 |
+
dv = torch.zeros_like(v)
|
| 261 |
+
|
| 262 |
+
row_splits = zip(
|
| 263 |
+
q.split(q_bucket_size, dim=-2),
|
| 264 |
+
o.split(q_bucket_size, dim=-2),
|
| 265 |
+
do.split(q_bucket_size, dim=-2),
|
| 266 |
+
mask,
|
| 267 |
+
l.split(q_bucket_size, dim=-2),
|
| 268 |
+
m.split(q_bucket_size, dim=-2),
|
| 269 |
+
dq.split(q_bucket_size, dim=-2),
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
|
| 273 |
+
q_start_index = ind * q_bucket_size - qk_len_diff
|
| 274 |
+
|
| 275 |
+
col_splits = zip(
|
| 276 |
+
k.split(k_bucket_size, dim=-2),
|
| 277 |
+
v.split(k_bucket_size, dim=-2),
|
| 278 |
+
dk.split(k_bucket_size, dim=-2),
|
| 279 |
+
dv.split(k_bucket_size, dim=-2),
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
|
| 283 |
+
k_start_index = k_ind * k_bucket_size
|
| 284 |
+
|
| 285 |
+
attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
|
| 286 |
+
|
| 287 |
+
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
| 288 |
+
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
|
| 289 |
+
q_start_index - k_start_index + 1
|
| 290 |
+
)
|
| 291 |
+
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
| 292 |
+
|
| 293 |
+
exp_attn_weights = torch.exp(attn_weights - mc)
|
| 294 |
+
|
| 295 |
+
if exists(row_mask):
|
| 296 |
+
exp_attn_weights.masked_fill_(~row_mask, 0.0)
|
| 297 |
+
|
| 298 |
+
p = exp_attn_weights / lc
|
| 299 |
+
|
| 300 |
+
dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
|
| 301 |
+
dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)
|
| 302 |
+
|
| 303 |
+
D = (doc * oc).sum(dim=-1, keepdims=True)
|
| 304 |
+
ds = p * scale * (dp - D)
|
| 305 |
+
|
| 306 |
+
dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
|
| 307 |
+
dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)
|
| 308 |
+
|
| 309 |
+
dqc.add_(dq_chunk)
|
| 310 |
+
dkc.add_(dk_chunk)
|
| 311 |
+
dvc.add_(dv_chunk)
|
| 312 |
+
|
| 313 |
+
return dq, dk, dv, None, None, None, None
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
# endregion
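# Usage sketch only: q, k and v are assumed to be shaped (batch, heads, seq, dim);
# mask=None and causal=False disable masking, and the bucket sizes are
# illustrative values rather than the ones used elsewhere in this repository.
def _demo_flash_attention(q, k, v):
    return FlashAttentionFunction.apply(q, k, v, None, False, 512, 1024)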
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def get_parameter_dtype(parameter: torch.nn.Module):
|
| 320 |
+
return next(parameter.parameters()).dtype
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def get_parameter_device(parameter: torch.nn.Module):
|
| 324 |
+
return next(parameter.parameters()).device
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def get_timestep_embedding(
|
| 328 |
+
timesteps: torch.Tensor,
|
| 329 |
+
embedding_dim: int,
|
| 330 |
+
flip_sin_to_cos: bool = False,
|
| 331 |
+
downscale_freq_shift: float = 1,
|
| 332 |
+
scale: float = 1,
|
| 333 |
+
max_period: int = 10000,
|
| 334 |
+
):
|
| 335 |
+
"""
|
| 336 |
+
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
| 337 |
+
|
| 338 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
| 339 |
+
These may be fractional.
|
| 340 |
+
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
|
| 341 |
+
embeddings. :return: an [N x dim] Tensor of positional embeddings.
|
| 342 |
+
"""
|
| 343 |
+
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
| 344 |
+
|
| 345 |
+
half_dim = embedding_dim // 2
|
| 346 |
+
exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
|
| 347 |
+
exponent = exponent / (half_dim - downscale_freq_shift)
|
| 348 |
+
|
| 349 |
+
emb = torch.exp(exponent)
|
| 350 |
+
emb = timesteps[:, None].float() * emb[None, :]
|
| 351 |
+
|
| 352 |
+
# scale embeddings
|
| 353 |
+
emb = scale * emb
|
| 354 |
+
|
| 355 |
+
# concat sine and cosine embeddings
|
| 356 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
| 357 |
+
|
| 358 |
+
# flip sine and cosine embeddings
|
| 359 |
+
if flip_sin_to_cos:
|
| 360 |
+
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
| 361 |
+
|
| 362 |
+
# zero pad
|
| 363 |
+
if embedding_dim % 2 == 1:
|
| 364 |
+
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
| 365 |
+
return emb
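# Shape sketch: three timesteps embedded into TIMESTEP_INPUT_DIM (=320)
# channels, using the sin/cos constants defined above.
def _demo_timestep_embedding():
    t = torch.tensor([0, 500, 999])
    emb = get_timestep_embedding(
        t, TIMESTEP_INPUT_DIM,
        flip_sin_to_cos=TIME_EMBED_FLIP_SIN_TO_COS,
        downscale_freq_shift=TIME_EMBED_FREQ_SHIFT,
    )
    return emb.shape  # torch.Size([3, 320])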
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
# Deep Shrink: this helper is intentionally not shared with other modules, to minimize dependencies.
|
| 369 |
+
def resize_like(x, target, mode="bicubic", align_corners=False):
|
| 370 |
+
org_dtype = x.dtype
|
| 371 |
+
if org_dtype == torch.bfloat16:
|
| 372 |
+
x = x.to(torch.float32)
|
| 373 |
+
|
| 374 |
+
if x.shape[-2:] != target.shape[-2:]:
|
| 375 |
+
if mode == "nearest":
|
| 376 |
+
x = F.interpolate(x, size=target.shape[-2:], mode=mode)
|
| 377 |
+
else:
|
| 378 |
+
x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners)
|
| 379 |
+
|
| 380 |
+
if org_dtype == torch.bfloat16:
|
| 381 |
+
x = x.to(org_dtype)
|
| 382 |
+
return x
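# Usage sketch: upsample a feature map to match a larger target, as done for
# the Deep Shrink path; the sizes below are illustrative.
def _demo_resize_like():
    x = torch.zeros(1, 4, 32, 32)
    target = torch.zeros(1, 4, 64, 64)
    return resize_like(x, target).shape  # torch.Size([1, 4, 64, 64])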
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
class SampleOutput:
|
| 386 |
+
def __init__(self, sample):
|
| 387 |
+
self.sample = sample
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
class TimestepEmbedding(nn.Module):
|
| 391 |
+
def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None):
|
| 392 |
+
super().__init__()
|
| 393 |
+
|
| 394 |
+
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
| 395 |
+
self.act = None
|
| 396 |
+
if act_fn == "silu":
|
| 397 |
+
self.act = nn.SiLU()
|
| 398 |
+
elif act_fn == "mish":
|
| 399 |
+
self.act = nn.Mish()
|
| 400 |
+
|
| 401 |
+
if out_dim is not None:
|
| 402 |
+
time_embed_dim_out = out_dim
|
| 403 |
+
else:
|
| 404 |
+
time_embed_dim_out = time_embed_dim
|
| 405 |
+
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
|
| 406 |
+
|
| 407 |
+
def forward(self, sample):
|
| 408 |
+
sample = self.linear_1(sample)
|
| 409 |
+
|
| 410 |
+
if self.act is not None:
|
| 411 |
+
sample = self.act(sample)
|
| 412 |
+
|
| 413 |
+
sample = self.linear_2(sample)
|
| 414 |
+
return sample
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
class Timesteps(nn.Module):
|
| 418 |
+
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
|
| 419 |
+
super().__init__()
|
| 420 |
+
self.num_channels = num_channels
|
| 421 |
+
self.flip_sin_to_cos = flip_sin_to_cos
|
| 422 |
+
self.downscale_freq_shift = downscale_freq_shift
|
| 423 |
+
|
| 424 |
+
def forward(self, timesteps):
|
| 425 |
+
t_emb = get_timestep_embedding(
|
| 426 |
+
timesteps,
|
| 427 |
+
self.num_channels,
|
| 428 |
+
flip_sin_to_cos=self.flip_sin_to_cos,
|
| 429 |
+
downscale_freq_shift=self.downscale_freq_shift,
|
| 430 |
+
)
|
| 431 |
+
return t_emb
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
class ResnetBlock2D(nn.Module):
|
| 435 |
+
def __init__(
|
| 436 |
+
self,
|
| 437 |
+
in_channels,
|
| 438 |
+
out_channels,
|
| 439 |
+
):
|
| 440 |
+
super().__init__()
|
| 441 |
+
self.in_channels = in_channels
|
| 442 |
+
self.out_channels = out_channels
|
| 443 |
+
|
| 444 |
+
self.norm1 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=in_channels, eps=NORM_EPS, affine=True)
|
| 445 |
+
|
| 446 |
+
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
| 447 |
+
|
| 448 |
+
self.time_emb_proj = torch.nn.Linear(TIME_EMBED_DIM, out_channels)
|
| 449 |
+
|
| 450 |
+
self.norm2 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=out_channels, eps=NORM_EPS, affine=True)
|
| 451 |
+
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
| 452 |
+
|
| 453 |
+
# if non_linearity == "swish":
|
| 454 |
+
self.nonlinearity = lambda x: F.silu(x)
|
| 455 |
+
|
| 456 |
+
self.use_in_shortcut = self.in_channels != self.out_channels
|
| 457 |
+
|
| 458 |
+
self.conv_shortcut = None
|
| 459 |
+
if self.use_in_shortcut:
|
| 460 |
+
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
| 461 |
+
|
| 462 |
+
def forward(self, input_tensor, temb):
|
| 463 |
+
hidden_states = input_tensor
|
| 464 |
+
|
| 465 |
+
hidden_states = self.norm1(hidden_states)
|
| 466 |
+
hidden_states = self.nonlinearity(hidden_states)
|
| 467 |
+
|
| 468 |
+
hidden_states = self.conv1(hidden_states)
|
| 469 |
+
|
| 470 |
+
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
|
| 471 |
+
hidden_states = hidden_states + temb
|
| 472 |
+
|
| 473 |
+
hidden_states = self.norm2(hidden_states)
|
| 474 |
+
hidden_states = self.nonlinearity(hidden_states)
|
| 475 |
+
|
| 476 |
+
hidden_states = self.conv2(hidden_states)
|
| 477 |
+
|
| 478 |
+
if self.conv_shortcut is not None:
|
| 479 |
+
input_tensor = self.conv_shortcut(input_tensor)
|
| 480 |
+
|
| 481 |
+
output_tensor = input_tensor + hidden_states
|
| 482 |
+
|
| 483 |
+
return output_tensor
|
| 484 |
+
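    # Summary of the forward pass above: GroupNorm -> SiLU -> 3x3 conv, add the projected time
    # embedding (broadcast over H and W), GroupNorm -> SiLU -> 3x3 conv, then a residual add
    # through a 1x1 shortcut conv whenever in_channels != out_channels.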
|
| 485 |
+
|
| 486 |
+
class DownBlock2D(nn.Module):
|
| 487 |
+
def __init__(
|
| 488 |
+
self,
|
| 489 |
+
in_channels: int,
|
| 490 |
+
out_channels: int,
|
| 491 |
+
add_downsample=True,
|
| 492 |
+
):
|
| 493 |
+
super().__init__()
|
| 494 |
+
|
| 495 |
+
self.has_cross_attention = False
|
| 496 |
+
resnets = []
|
| 497 |
+
|
| 498 |
+
for i in range(LAYERS_PER_BLOCK):
|
| 499 |
+
in_channels = in_channels if i == 0 else out_channels
|
| 500 |
+
resnets.append(
|
| 501 |
+
ResnetBlock2D(
|
| 502 |
+
in_channels=in_channels,
|
| 503 |
+
out_channels=out_channels,
|
| 504 |
+
)
|
| 505 |
+
)
|
| 506 |
+
self.resnets = nn.ModuleList(resnets)
|
| 507 |
+
|
| 508 |
+
if add_downsample:
|
| 509 |
+
            self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels=out_channels)])  # ModuleList so the downsampler conv is registered in the state_dict
|
| 510 |
+
else:
|
| 511 |
+
self.downsamplers = None
|
| 512 |
+
|
| 513 |
+
self.gradient_checkpointing = False
|
| 514 |
+
|
| 515 |
+
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
| 516 |
+
pass
|
| 517 |
+
|
| 518 |
+
def set_use_sdpa(self, sdpa):
|
| 519 |
+
pass
|
| 520 |
+
|
| 521 |
+
def forward(self, hidden_states, temb=None):
|
| 522 |
+
output_states = ()
|
| 523 |
+
|
| 524 |
+
for resnet in self.resnets:
|
| 525 |
+
if self.training and self.gradient_checkpointing:
|
| 526 |
+
|
| 527 |
+
def create_custom_forward(module):
|
| 528 |
+
def custom_forward(*inputs):
|
| 529 |
+
return module(*inputs)
|
| 530 |
+
|
| 531 |
+
return custom_forward
|
| 532 |
+
|
| 533 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
| 534 |
+
else:
|
| 535 |
+
hidden_states = resnet(hidden_states, temb)
|
| 536 |
+
|
| 537 |
+
output_states += (hidden_states,)
|
| 538 |
+
|
| 539 |
+
if self.downsamplers is not None:
|
| 540 |
+
for downsampler in self.downsamplers:
|
| 541 |
+
hidden_states = downsampler(hidden_states)
|
| 542 |
+
|
| 543 |
+
output_states += (hidden_states,)
|
| 544 |
+
|
| 545 |
+
return hidden_states, output_states
|
| 546 |
+
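    # Note on the gradient-checkpointing path above: torch.utils.checkpoint.checkpoint discards
    # intermediate activations and recomputes the wrapped resnet forward during backward,
    # trading compute for memory; the create_custom_forward closure only exists to give
    # checkpoint a plain callable that takes its tensors positionally.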
|
| 547 |
+
|
| 548 |
+
class Downsample2D(nn.Module):
|
| 549 |
+
def __init__(self, channels, out_channels):
|
| 550 |
+
super().__init__()
|
| 551 |
+
|
| 552 |
+
self.channels = channels
|
| 553 |
+
self.out_channels = out_channels
|
| 554 |
+
|
| 555 |
+
self.conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=2, padding=1)
|
| 556 |
+
|
| 557 |
+
def forward(self, hidden_states):
|
| 558 |
+
assert hidden_states.shape[1] == self.channels
|
| 559 |
+
hidden_states = self.conv(hidden_states)
|
| 560 |
+
|
| 561 |
+
return hidden_states
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
class CrossAttention(nn.Module):
|
| 565 |
+
def __init__(
|
| 566 |
+
self,
|
| 567 |
+
query_dim: int,
|
| 568 |
+
cross_attention_dim: Optional[int] = None,
|
| 569 |
+
heads: int = 8,
|
| 570 |
+
dim_head: int = 64,
|
| 571 |
+
upcast_attention: bool = False,
|
| 572 |
+
):
|
| 573 |
+
super().__init__()
|
| 574 |
+
inner_dim = dim_head * heads
|
| 575 |
+
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
| 576 |
+
self.upcast_attention = upcast_attention
|
| 577 |
+
|
| 578 |
+
self.scale = dim_head**-0.5
|
| 579 |
+
self.heads = heads
|
| 580 |
+
|
| 581 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
| 582 |
+
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False)
|
| 583 |
+
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False)
|
| 584 |
+
|
| 585 |
+
self.to_out = nn.ModuleList([])
|
| 586 |
+
self.to_out.append(nn.Linear(inner_dim, query_dim))
|
| 587 |
+
# no dropout here
|
| 588 |
+
|
| 589 |
+
self.use_memory_efficient_attention_xformers = False
|
| 590 |
+
self.use_memory_efficient_attention_mem_eff = False
|
| 591 |
+
self.use_sdpa = False
|
| 592 |
+
|
| 593 |
+
# Attention processor
|
| 594 |
+
self.processor = None
|
| 595 |
+
|
| 596 |
+
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
| 597 |
+
self.use_memory_efficient_attention_xformers = xformers
|
| 598 |
+
self.use_memory_efficient_attention_mem_eff = mem_eff
|
| 599 |
+
|
| 600 |
+
def set_use_sdpa(self, sdpa):
|
| 601 |
+
self.use_sdpa = sdpa
|
| 602 |
+
|
| 603 |
+
def reshape_heads_to_batch_dim(self, tensor):
|
| 604 |
+
batch_size, seq_len, dim = tensor.shape
|
| 605 |
+
head_size = self.heads
|
| 606 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
| 607 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
|
| 608 |
+
return tensor
|
| 609 |
+
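    # Shape walkthrough (illustrative): with heads=8, a (B, N, 512) tensor is split into 8 heads
    # of dim 64 and flattened to (B*8, N, 64); reshape_batch_dim_to_heads below is the inverse.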
|
| 610 |
+
def reshape_batch_dim_to_heads(self, tensor):
|
| 611 |
+
batch_size, seq_len, dim = tensor.shape
|
| 612 |
+
head_size = self.heads
|
| 613 |
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
| 614 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
| 615 |
+
return tensor
|
| 616 |
+
|
| 617 |
+
    def set_processor(self, processor):
|
| 618 |
+
        self.processor = processor
|
| 619 |
+
|
| 620 |
+
def get_processor(self):
|
| 621 |
+
return self.processor
|
| 622 |
+
|
| 623 |
+
def forward(self, hidden_states, context=None, mask=None, **kwargs):
|
| 624 |
+
if self.processor is not None:
|
| 625 |
+
(
|
| 626 |
+
hidden_states,
|
| 627 |
+
encoder_hidden_states,
|
| 628 |
+
attention_mask,
|
| 629 |
+
) = translate_attention_names_from_diffusers(
|
| 630 |
+
hidden_states=hidden_states, context=context, mask=mask, **kwargs
|
| 631 |
+
)
|
| 632 |
+
return self.processor(
|
| 633 |
+
attn=self,
|
| 634 |
+
hidden_states=hidden_states,
|
| 635 |
+
encoder_hidden_states=context,
|
| 636 |
+
attention_mask=mask,
|
| 637 |
+
**kwargs
|
| 638 |
+
)
|
| 639 |
+
if self.use_memory_efficient_attention_xformers:
|
| 640 |
+
return self.forward_memory_efficient_xformers(hidden_states, context, mask)
|
| 641 |
+
if self.use_memory_efficient_attention_mem_eff:
|
| 642 |
+
return self.forward_memory_efficient_mem_eff(hidden_states, context, mask)
|
| 643 |
+
if self.use_sdpa:
|
| 644 |
+
return self.forward_sdpa(hidden_states, context, mask)
|
| 645 |
+
|
| 646 |
+
query = self.to_q(hidden_states)
|
| 647 |
+
context = context if context is not None else hidden_states
|
| 648 |
+
key = self.to_k(context)
|
| 649 |
+
value = self.to_v(context)
|
| 650 |
+
|
| 651 |
+
query = self.reshape_heads_to_batch_dim(query)
|
| 652 |
+
key = self.reshape_heads_to_batch_dim(key)
|
| 653 |
+
value = self.reshape_heads_to_batch_dim(value)
|
| 654 |
+
|
| 655 |
+
hidden_states = self._attention(query, key, value)
|
| 656 |
+
|
| 657 |
+
# linear proj
|
| 658 |
+
hidden_states = self.to_out[0](hidden_states)
|
| 659 |
+
# hidden_states = self.to_out[1](hidden_states) # no dropout
|
| 660 |
+
return hidden_states
|
| 661 |
+
|
| 662 |
+
def _attention(self, query, key, value):
|
| 663 |
+
if self.upcast_attention:
|
| 664 |
+
query = query.float()
|
| 665 |
+
key = key.float()
|
| 666 |
+
|
| 667 |
+
attention_scores = torch.baddbmm(
|
| 668 |
+
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
|
| 669 |
+
query,
|
| 670 |
+
key.transpose(-1, -2),
|
| 671 |
+
beta=0,
|
| 672 |
+
alpha=self.scale,
|
| 673 |
+
)
|
| 674 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
| 675 |
+
|
| 676 |
+
# cast back to the original dtype
|
| 677 |
+
attention_probs = attention_probs.to(value.dtype)
|
| 678 |
+
|
| 679 |
+
# compute attention output
|
| 680 |
+
hidden_states = torch.bmm(attention_probs, value)
|
| 681 |
+
|
| 682 |
+
# reshape hidden_states
|
| 683 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
| 684 |
+
return hidden_states
|
| 685 |
+
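    # The baddbmm call above computes beta * input + alpha * (query @ key^T); with beta=0 the
    # empty tensor only fixes the output shape/dtype, so this is the usual scaled dot-product
    # attention: softmax(Q K^T / sqrt(dim_head)) @ V, optionally upcast to float32.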
|
| 686 |
+
# TODO support Hypernetworks
|
| 687 |
+
def forward_memory_efficient_xformers(self, x, context=None, mask=None):
|
| 688 |
+
import xformers.ops
|
| 689 |
+
|
| 690 |
+
h = self.heads
|
| 691 |
+
q_in = self.to_q(x)
|
| 692 |
+
context = context if context is not None else x
|
| 693 |
+
context = context.to(x.dtype)
|
| 694 |
+
k_in = self.to_k(context)
|
| 695 |
+
v_in = self.to_v(context)
|
| 696 |
+
|
| 697 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
|
| 698 |
+
del q_in, k_in, v_in
|
| 699 |
+
|
| 700 |
+
q = q.contiguous()
|
| 701 |
+
k = k.contiguous()
|
| 702 |
+
v = v.contiguous()
|
| 703 |
+
        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)  # xformers picks the best available kernel
|
| 704 |
+
|
| 705 |
+
out = rearrange(out, "b n h d -> b n (h d)", h=h)
|
| 706 |
+
|
| 707 |
+
out = self.to_out[0](out)
|
| 708 |
+
return out
|
| 709 |
+
|
| 710 |
+
def forward_memory_efficient_mem_eff(self, x, context=None, mask=None):
|
| 711 |
+
flash_func = FlashAttentionFunction
|
| 712 |
+
|
| 713 |
+
q_bucket_size = 512
|
| 714 |
+
k_bucket_size = 1024
|
| 715 |
+
|
| 716 |
+
h = self.heads
|
| 717 |
+
q = self.to_q(x)
|
| 718 |
+
context = context if context is not None else x
|
| 719 |
+
context = context.to(x.dtype)
|
| 720 |
+
k = self.to_k(context)
|
| 721 |
+
v = self.to_v(context)
|
| 722 |
+
del context, x
|
| 723 |
+
|
| 724 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
|
| 725 |
+
|
| 726 |
+
out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
|
| 727 |
+
|
| 728 |
+
out = rearrange(out, "b h n d -> b n (h d)")
|
| 729 |
+
|
| 730 |
+
out = self.to_out[0](out)
|
| 731 |
+
return out
|
| 732 |
+
|
| 733 |
+
def forward_sdpa(self, x, context=None, mask=None):
|
| 734 |
+
h = self.heads
|
| 735 |
+
q_in = self.to_q(x)
|
| 736 |
+
context = context if context is not None else x
|
| 737 |
+
context = context.to(x.dtype)
|
| 738 |
+
k_in = self.to_k(context)
|
| 739 |
+
v_in = self.to_v(context)
|
| 740 |
+
|
| 741 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in))
|
| 742 |
+
del q_in, k_in, v_in
|
| 743 |
+
|
| 744 |
+
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
| 745 |
+
|
| 746 |
+
out = rearrange(out, "b h n d -> b n (h d)", h=h)
|
| 747 |
+
|
| 748 |
+
out = self.to_out[0](out)
|
| 749 |
+
return out
|
| 750 |
+
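    # Note: the memory-efficient paths use different layouts on purpose;
    # xformers.ops.memory_efficient_attention expects (B, N, heads, dim) tensors, while the
    # FlashAttentionFunction and F.scaled_dot_product_attention paths work on (B, heads, N, dim).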
|
| 751 |
+
def translate_attention_names_from_diffusers(
|
| 752 |
+
hidden_states: torch.FloatTensor,
|
| 753 |
+
context: Optional[torch.FloatTensor] = None,
|
| 754 |
+
mask: Optional[torch.FloatTensor] = None,
|
| 755 |
+
# HF naming
|
| 756 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
| 757 |
+
attention_mask: Optional[torch.FloatTensor] = None
|
| 758 |
+
):
|
| 759 |
+
# translate from hugging face diffusers
|
| 760 |
+
context = context if context is not None else encoder_hidden_states
|
| 761 |
+
|
| 762 |
+
# translate from hugging face diffusers
|
| 763 |
+
mask = mask if mask is not None else attention_mask
|
| 764 |
+
|
| 765 |
+
return hidden_states, context, mask
|
| 766 |
+
|
| 767 |
+
# feedforward
|
| 768 |
+
class GEGLU(nn.Module):
|
| 769 |
+
r"""
|
| 770 |
+
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
|
| 771 |
+
|
| 772 |
+
Parameters:
|
| 773 |
+
dim_in (`int`): The number of channels in the input.
|
| 774 |
+
dim_out (`int`): The number of channels in the output.
|
| 775 |
+
"""
|
| 776 |
+
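    # GEGLU: the single Linear produces 2 * dim_out features, split into a value half and a
    # gate half; the output is value * GELU(gate).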
|
| 777 |
+
def __init__(self, dim_in: int, dim_out: int):
|
| 778 |
+
super().__init__()
|
| 779 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
| 780 |
+
|
| 781 |
+
def gelu(self, gate):
|
| 782 |
+
if gate.device.type != "mps":
|
| 783 |
+
return F.gelu(gate)
|
| 784 |
+
# mps: gelu is not implemented for float16
|
| 785 |
+
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
|
| 786 |
+
|
| 787 |
+
def forward(self, hidden_states):
|
| 788 |
+
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
|
| 789 |
+
return hidden_states * self.gelu(gate)
|
| 790 |
+
|
| 791 |
+
|
| 792 |
+
class FeedForward(nn.Module):
|
| 793 |
+
def __init__(
|
| 794 |
+
self,
|
| 795 |
+
dim: int,
|
| 796 |
+
):
|
| 797 |
+
super().__init__()
|
| 798 |
+
inner_dim = int(dim * 4) # mult is always 4
|
| 799 |
+
|
| 800 |
+
self.net = nn.ModuleList([])
|
| 801 |
+
# project in
|
| 802 |
+
self.net.append(GEGLU(dim, inner_dim))
|
| 803 |
+
# project dropout
|
| 804 |
+
self.net.append(nn.Identity()) # nn.Dropout(0)) # dummy for dropout with 0
|
| 805 |
+
# project out
|
| 806 |
+
self.net.append(nn.Linear(inner_dim, dim))
|
| 807 |
+
|
| 808 |
+
def forward(self, hidden_states):
|
| 809 |
+
for module in self.net:
|
| 810 |
+
hidden_states = module(hidden_states)
|
| 811 |
+
return hidden_states
|
| 812 |
+
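    # The net above is GEGLU(dim -> 4 * dim), an Identity placeholder standing in for
    # Dropout(p=0), and a Linear(4 * dim -> dim) projection back to the model width.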
|
| 813 |
+
|
| 814 |
+
class BasicTransformerBlock(nn.Module):
|
| 815 |
+
def __init__(
|
| 816 |
+
self, dim: int, num_attention_heads: int, attention_head_dim: int, cross_attention_dim: int, upcast_attention: bool = False
|
| 817 |
+
):
|
| 818 |
+
super().__init__()
|
| 819 |
+
|
| 820 |
+
# 1. Self-Attn
|
| 821 |
+
self.attn1 = CrossAttention(
|
| 822 |
+
query_dim=dim,
|
| 823 |
+
cross_attention_dim=None,
|
| 824 |
+
heads=num_attention_heads,
|
| 825 |
+
dim_head=attention_head_dim,
|
| 826 |
+
upcast_attention=upcast_attention,
|
| 827 |
+
)
|
| 828 |
+
self.ff = FeedForward(dim)
|
| 829 |
+
|
| 830 |
+
# 2. Cross-Attn
|
| 831 |
+
self.attn2 = CrossAttention(
|
| 832 |
+
query_dim=dim,
|
| 833 |
+
cross_attention_dim=cross_attention_dim,
|
| 834 |
+
heads=num_attention_heads,
|
| 835 |
+
dim_head=attention_head_dim,
|
| 836 |
+
upcast_attention=upcast_attention,
|
| 837 |
+
)
|
| 838 |
+
|
| 839 |
+
self.norm1 = nn.LayerNorm(dim)
|
| 840 |
+
self.norm2 = nn.LayerNorm(dim)
|
| 841 |
+
|
| 842 |
+
# 3. Feed-forward
|
| 843 |
+
self.norm3 = nn.LayerNorm(dim)
|
| 844 |
+
|
| 845 |
+
def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool):
|
| 846 |
+
self.attn1.set_use_memory_efficient_attention(xformers, mem_eff)
|
| 847 |
+
self.attn2.set_use_memory_efficient_attention(xformers, mem_eff)
|
| 848 |
+
|
| 849 |
+
def set_use_sdpa(self, sdpa: bool):
|
| 850 |
+
self.attn1.set_use_sdpa(sdpa)
|
| 851 |
+
self.attn2.set_use_sdpa(sdpa)
|
| 852 |
+
|
| 853 |
+
def forward(self, hidden_states, context=None, timestep=None):
|
| 854 |
+
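        # Pre-LayerNorm residual layout:
        #   x = x + SelfAttn(LN1(x)); x = x + CrossAttn(LN2(x), context); x = x + FF(LN3(x))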
# 1. Self-Attention
|
| 855 |
+
norm_hidden_states = self.norm1(hidden_states)
|
| 856 |
+
|
| 857 |
+
hidden_states = self.attn1(norm_hidden_states) + hidden_states
|
| 858 |
+
|
| 859 |
+
# 2. Cross-Attention
|
| 860 |
+
norm_hidden_states = self.norm2(hidden_states)
|
| 861 |
+
hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
|
| 862 |
+
|
| 863 |
+
# 3. Feed-forward
|
| 864 |
+
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
| 865 |
+
|
| 866 |
+
return hidden_states
|
| 867 |
+
|
| 868 |
+
|
| 869 |
+
class Transformer2DModel(nn.Module):
|
| 870 |
+
def __init__(
|
| 871 |
+
self,
|
| 872 |
+
num_attention_heads: int = 16,
|
| 873 |
+
attention_head_dim: int = 88,
|
| 874 |
+
in_channels: Optional[int] = None,
|
| 875 |
+
cross_attention_dim: Optional[int] = None,
|
| 876 |
+
use_linear_projection: bool = False,
|
| 877 |
+
upcast_attention: bool = False,
|
| 878 |
+
):
|
| 879 |
+
super().__init__()
|
| 880 |
+
self.in_channels = in_channels
|
| 881 |
+
self.num_attention_heads = num_attention_heads
|
| 882 |
+
self.attention_head_dim = attention_head_dim
|
| 883 |
+
inner_dim = num_attention_heads * attention_head_dim
|
| 884 |
+
self.use_linear_projection = use_linear_projection
|
| 885 |
+
|
| 886 |
+
self.norm = torch.nn.GroupNorm(num_groups=TRANSFORMER_NORM_NUM_GROUPS, num_channels=in_channels, eps=1e-6, affine=True)
|
| 887 |
+
|
| 888 |
+
if use_linear_projection:
|
| 889 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
| 890 |
+
else:
|
| 891 |
+
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
| 892 |
+
|
| 893 |
+
self.transformer_blocks = nn.ModuleList(
|
| 894 |
+
[
|
| 895 |
+
BasicTransformerBlock(
|
| 896 |
+
inner_dim,
|
| 897 |
+
num_attention_heads,
|
| 898 |
+
attention_head_dim,
|
| 899 |
+
cross_attention_dim=cross_attention_dim,
|
| 900 |
+
upcast_attention=upcast_attention,
|
| 901 |
+
)
|
| 902 |
+
]
|
| 903 |
+
)
|
| 904 |
+
|
| 905 |
+
if use_linear_projection:
|
| 906 |
+
self.proj_out = nn.Linear(in_channels, inner_dim)
|
| 907 |
+
else:
|
| 908 |
+
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
| 909 |
+
|
| 910 |
+
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
| 911 |
+
for transformer in self.transformer_blocks:
|
| 912 |
+
transformer.set_use_memory_efficient_attention(xformers, mem_eff)
|
| 913 |
+
|
| 914 |
+
def set_use_sdpa(self, sdpa):
|
| 915 |
+
for transformer in self.transformer_blocks:
|
| 916 |
+
transformer.set_use_sdpa(sdpa)
|
| 917 |
+
|
| 918 |
+
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
|
| 919 |
+
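        # The block operates on tokens: (B, C, H, W) is flattened to (B, H*W, C). With
        # use_linear_projection the projection is a Linear applied after flattening; otherwise
        # it is a 1x1 Conv2d applied before flattening, and the output path mirrors that order.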
# 1. Input
|
| 920 |
+
batch, _, height, weight = hidden_states.shape
|
| 921 |
+
residual = hidden_states
|
| 922 |
+
|
| 923 |
+
hidden_states = self.norm(hidden_states)
|
| 924 |
+
if not self.use_linear_projection:
|
| 925 |
+
hidden_states = self.proj_in(hidden_states)
|
| 926 |
+
inner_dim = hidden_states.shape[1]
|
| 927 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
| 928 |
+
else:
|
| 929 |
+
inner_dim = hidden_states.shape[1]
|
| 930 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
| 931 |
+
hidden_states = self.proj_in(hidden_states)
|
| 932 |
+
|
| 933 |
+
# 2. Blocks
|
| 934 |
+
for block in self.transformer_blocks:
|
| 935 |
+
hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
|
| 936 |
+
|
| 937 |
+
# 3. Output
|
| 938 |
+
if not self.use_linear_projection:
|
| 939 |
+
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
| 940 |
+
hidden_states = self.proj_out(hidden_states)
|
| 941 |
+
else:
|
| 942 |
+
hidden_states = self.proj_out(hidden_states)
|
| 943 |
+
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
| 944 |
+
|
| 945 |
+
output = hidden_states + residual
|
| 946 |
+
|
| 947 |
+
if not return_dict:
|
| 948 |
+
return (output,)
|
| 949 |
+
|
| 950 |
+
return SampleOutput(sample=output)
|
| 951 |
+
|
| 952 |
+
|
| 953 |
+
class CrossAttnDownBlock2D(nn.Module):
|
| 954 |
+
def __init__(
|
| 955 |
+
self,
|
| 956 |
+
in_channels: int,
|
| 957 |
+
out_channels: int,
|
| 958 |
+
add_downsample=True,
|
| 959 |
+
cross_attention_dim=1280,
|
| 960 |
+
attn_num_head_channels=1,
|
| 961 |
+
use_linear_projection=False,
|
| 962 |
+
upcast_attention=False,
|
| 963 |
+
):
|
| 964 |
+
super().__init__()
|
| 965 |
+
self.has_cross_attention = True
|
| 966 |
+
resnets = []
|
| 967 |
+
attentions = []
|
| 968 |
+
|
| 969 |
+
self.attn_num_head_channels = attn_num_head_channels
|
| 970 |
+
|
| 971 |
+
for i in range(LAYERS_PER_BLOCK):
|
| 972 |
+
in_channels = in_channels if i == 0 else out_channels
|
| 973 |
+
|
| 974 |
+
resnets.append(ResnetBlock2D(in_channels=in_channels, out_channels=out_channels))
|
| 975 |
+
attentions.append(
|
| 976 |
+
Transformer2DModel(
|
| 977 |
+
attn_num_head_channels,
|
| 978 |
+
out_channels // attn_num_head_channels,
|
| 979 |
+
in_channels=out_channels,
|
| 980 |
+
cross_attention_dim=cross_attention_dim,
|
| 981 |
+
use_linear_projection=use_linear_projection,
|
| 982 |
+
upcast_attention=upcast_attention,
|
| 983 |
+
)
|
| 984 |
+
)
|
| 985 |
+
self.attentions = nn.ModuleList(attentions)
|
| 986 |
+
self.resnets = nn.ModuleList(resnets)
|
| 987 |
+
|
| 988 |
+
if add_downsample:
|
| 989 |
+
self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels)])
|
| 990 |
+
else:
|
| 991 |
+
self.downsamplers = None
|
| 992 |
+
|
| 993 |
+
self.gradient_checkpointing = False
|
| 994 |
+
|
| 995 |
+
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
| 996 |
+
for attn in self.attentions:
|
| 997 |
+
attn.set_use_memory_efficient_attention(xformers, mem_eff)
|
| 998 |
+
|
| 999 |
+
def set_use_sdpa(self, sdpa):
|
| 1000 |
+
for attn in self.attentions:
|
| 1001 |
+
attn.set_use_sdpa(sdpa)
|
| 1002 |
+
|
| 1003 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
| 1004 |
+
output_states = ()
|
| 1005 |
+
|
| 1006 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
| 1007 |
+
if self.training and self.gradient_checkpointing:
|
| 1008 |
+
|
| 1009 |
+
def create_custom_forward(module, return_dict=None):
|
| 1010 |
+
def custom_forward(*inputs):
|
| 1011 |
+
if return_dict is not None:
|
| 1012 |
+
return module(*inputs, return_dict=return_dict)
|
| 1013 |
+
else:
|
| 1014 |
+
return module(*inputs)
|
| 1015 |
+
|
| 1016 |
+
return custom_forward
|
| 1017 |
+
|
| 1018 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
| 1019 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 1020 |
+
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
|
| 1021 |
+
)[0]
|
| 1022 |
+
else:
|
| 1023 |
+
hidden_states = resnet(hidden_states, temb)
|
| 1024 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
| 1025 |
+
|
| 1026 |
+
output_states += (hidden_states,)
|
| 1027 |
+
|
| 1028 |
+
if self.downsamplers is not None:
|
| 1029 |
+
for downsampler in self.downsamplers:
|
| 1030 |
+
hidden_states = downsampler(hidden_states)
|
| 1031 |
+
|
| 1032 |
+
output_states += (hidden_states,)
|
| 1033 |
+
|
| 1034 |
+
return hidden_states, output_states
|
| 1035 |
+
|
| 1036 |
+
|
| 1037 |
+
class UNetMidBlock2DCrossAttn(nn.Module):
|
| 1038 |
+
def __init__(
|
| 1039 |
+
self,
|
| 1040 |
+
in_channels: int,
|
| 1041 |
+
attn_num_head_channels=1,
|
| 1042 |
+
cross_attention_dim=1280,
|
| 1043 |
+
use_linear_projection=False,
|
| 1044 |
+
):
|
| 1045 |
+
super().__init__()
|
| 1046 |
+
|
| 1047 |
+
self.has_cross_attention = True
|
| 1048 |
+
self.attn_num_head_channels = attn_num_head_channels
|
| 1049 |
+
|
| 1050 |
+
# Middle block has two resnets and one attention
|
| 1051 |
+
resnets = [
|
| 1052 |
+
ResnetBlock2D(
|
| 1053 |
+
in_channels=in_channels,
|
| 1054 |
+
out_channels=in_channels,
|
| 1055 |
+
),
|
| 1056 |
+
ResnetBlock2D(
|
| 1057 |
+
in_channels=in_channels,
|
| 1058 |
+
out_channels=in_channels,
|
| 1059 |
+
),
|
| 1060 |
+
]
|
| 1061 |
+
attentions = [
|
| 1062 |
+
Transformer2DModel(
|
| 1063 |
+
attn_num_head_channels,
|
| 1064 |
+
in_channels // attn_num_head_channels,
|
| 1065 |
+
in_channels=in_channels,
|
| 1066 |
+
cross_attention_dim=cross_attention_dim,
|
| 1067 |
+
use_linear_projection=use_linear_projection,
|
| 1068 |
+
)
|
| 1069 |
+
]
|
| 1070 |
+
|
| 1071 |
+
self.attentions = nn.ModuleList(attentions)
|
| 1072 |
+
self.resnets = nn.ModuleList(resnets)
|
| 1073 |
+
|
| 1074 |
+
self.gradient_checkpointing = False
|
| 1075 |
+
|
| 1076 |
+
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
| 1077 |
+
for attn in self.attentions:
|
| 1078 |
+
attn.set_use_memory_efficient_attention(xformers, mem_eff)
|
| 1079 |
+
|
| 1080 |
+
def set_use_sdpa(self, sdpa):
|
| 1081 |
+
for attn in self.attentions:
|
| 1082 |
+
attn.set_use_sdpa(sdpa)
|
| 1083 |
+
|
| 1084 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
| 1085 |
+
for i, resnet in enumerate(self.resnets):
|
| 1086 |
+
attn = None if i == 0 else self.attentions[i - 1]
|
| 1087 |
+
|
| 1088 |
+
if self.training and self.gradient_checkpointing:
|
| 1089 |
+
|
| 1090 |
+
def create_custom_forward(module, return_dict=None):
|
| 1091 |
+
def custom_forward(*inputs):
|
| 1092 |
+
if return_dict is not None:
|
| 1093 |
+
return module(*inputs, return_dict=return_dict)
|
| 1094 |
+
else:
|
| 1095 |
+
return module(*inputs)
|
| 1096 |
+
|
| 1097 |
+
return custom_forward
|
| 1098 |
+
|
| 1099 |
+
if attn is not None:
|
| 1100 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 1101 |
+
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
|
| 1102 |
+
)[0]
|
| 1103 |
+
|
| 1104 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
| 1105 |
+
else:
|
| 1106 |
+
if attn is not None:
|
| 1107 |
+
hidden_states = attn(hidden_states, encoder_hidden_states).sample
|
| 1108 |
+
hidden_states = resnet(hidden_states, temb)
|
| 1109 |
+
|
| 1110 |
+
return hidden_states
|
| 1111 |
+
|
| 1112 |
+
|
| 1113 |
+
class Upsample2D(nn.Module):
|
| 1114 |
+
def __init__(self, channels, out_channels):
|
| 1115 |
+
super().__init__()
|
| 1116 |
+
self.channels = channels
|
| 1117 |
+
self.out_channels = out_channels
|
| 1118 |
+
self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
|
| 1119 |
+
|
| 1120 |
+
def forward(self, hidden_states, output_size):
|
| 1121 |
+
assert hidden_states.shape[1] == self.channels
|
| 1122 |
+
|
| 1123 |
+
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
| 1124 |
+
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
|
| 1125 |
+
# https://github.com/pytorch/pytorch/issues/86679
|
| 1126 |
+
dtype = hidden_states.dtype
|
| 1127 |
+
if dtype == torch.bfloat16:
|
| 1128 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 1129 |
+
|
| 1130 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
| 1131 |
+
if hidden_states.shape[0] >= 64:
|
| 1132 |
+
hidden_states = hidden_states.contiguous()
|
| 1133 |
+
|
| 1134 |
+
# if `output_size` is passed we force the interpolation output size and do not make use of `scale_factor=2`
|
| 1135 |
+
if output_size is None:
|
| 1136 |
+
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
|
| 1137 |
+
else:
|
| 1138 |
+
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
|
| 1139 |
+
|
| 1140 |
+
# If the input is bfloat16, we cast back to bfloat16
|
| 1141 |
+
if dtype == torch.bfloat16:
|
| 1142 |
+
hidden_states = hidden_states.to(dtype)
|
| 1143 |
+
|
| 1144 |
+
hidden_states = self.conv(hidden_states)
|
| 1145 |
+
|
| 1146 |
+
return hidden_states
|
| 1147 |
+
|
| 1148 |
+
|
| 1149 |
+
class UpBlock2D(nn.Module):
|
| 1150 |
+
def __init__(
|
| 1151 |
+
self,
|
| 1152 |
+
in_channels: int,
|
| 1153 |
+
prev_output_channel: int,
|
| 1154 |
+
out_channels: int,
|
| 1155 |
+
add_upsample=True,
|
| 1156 |
+
):
|
| 1157 |
+
super().__init__()
|
| 1158 |
+
|
| 1159 |
+
self.has_cross_attention = False
|
| 1160 |
+
resnets = []
|
| 1161 |
+
|
| 1162 |
+
for i in range(LAYERS_PER_BLOCK_UP):
|
| 1163 |
+
res_skip_channels = in_channels if (i == LAYERS_PER_BLOCK_UP - 1) else out_channels
|
| 1164 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
| 1165 |
+
|
| 1166 |
+
resnets.append(
|
| 1167 |
+
ResnetBlock2D(
|
| 1168 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
| 1169 |
+
out_channels=out_channels,
|
| 1170 |
+
)
|
| 1171 |
+
)
|
| 1172 |
+
|
| 1173 |
+
self.resnets = nn.ModuleList(resnets)
|
| 1174 |
+
|
| 1175 |
+
if add_upsample:
|
| 1176 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)])
|
| 1177 |
+
else:
|
| 1178 |
+
self.upsamplers = None
|
| 1179 |
+
|
| 1180 |
+
self.gradient_checkpointing = False
|
| 1181 |
+
|
| 1182 |
+
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
| 1183 |
+
pass
|
| 1184 |
+
|
| 1185 |
+
def set_use_sdpa(self, sdpa):
|
| 1186 |
+
pass
|
| 1187 |
+
|
| 1188 |
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
|
| 1189 |
+
for resnet in self.resnets:
|
| 1190 |
+
# pop res hidden states
|
| 1191 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
| 1192 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
| 1193 |
+
|
| 1194 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
| 1195 |
+
|
| 1196 |
+
if self.training and self.gradient_checkpointing:
|
| 1197 |
+
|
| 1198 |
+
def create_custom_forward(module):
|
| 1199 |
+
def custom_forward(*inputs):
|
| 1200 |
+
return module(*inputs)
|
| 1201 |
+
|
| 1202 |
+
return custom_forward
|
| 1203 |
+
|
| 1204 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
| 1205 |
+
else:
|
| 1206 |
+
hidden_states = resnet(hidden_states, temb)
|
| 1207 |
+
|
| 1208 |
+
if self.upsamplers is not None:
|
| 1209 |
+
for upsampler in self.upsamplers:
|
| 1210 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
| 1211 |
+
|
| 1212 |
+
return hidden_states
|
| 1213 |
+
|
| 1214 |
+
|
| 1215 |
+
class CrossAttnUpBlock2D(nn.Module):
|
| 1216 |
+
def __init__(
|
| 1217 |
+
self,
|
| 1218 |
+
in_channels: int,
|
| 1219 |
+
out_channels: int,
|
| 1220 |
+
prev_output_channel: int,
|
| 1221 |
+
attn_num_head_channels=1,
|
| 1222 |
+
cross_attention_dim=1280,
|
| 1223 |
+
add_upsample=True,
|
| 1224 |
+
use_linear_projection=False,
|
| 1225 |
+
upcast_attention=False,
|
| 1226 |
+
):
|
| 1227 |
+
super().__init__()
|
| 1228 |
+
resnets = []
|
| 1229 |
+
attentions = []
|
| 1230 |
+
|
| 1231 |
+
self.has_cross_attention = True
|
| 1232 |
+
self.attn_num_head_channels = attn_num_head_channels
|
| 1233 |
+
|
| 1234 |
+
for i in range(LAYERS_PER_BLOCK_UP):
|
| 1235 |
+
res_skip_channels = in_channels if (i == LAYERS_PER_BLOCK_UP - 1) else out_channels
|
| 1236 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
| 1237 |
+
|
| 1238 |
+
resnets.append(
|
| 1239 |
+
ResnetBlock2D(
|
| 1240 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
| 1241 |
+
out_channels=out_channels,
|
| 1242 |
+
)
|
| 1243 |
+
)
|
| 1244 |
+
attentions.append(
|
| 1245 |
+
Transformer2DModel(
|
| 1246 |
+
attn_num_head_channels,
|
| 1247 |
+
out_channels // attn_num_head_channels,
|
| 1248 |
+
in_channels=out_channels,
|
| 1249 |
+
cross_attention_dim=cross_attention_dim,
|
| 1250 |
+
use_linear_projection=use_linear_projection,
|
| 1251 |
+
upcast_attention=upcast_attention,
|
| 1252 |
+
)
|
| 1253 |
+
)
|
| 1254 |
+
|
| 1255 |
+
self.attentions = nn.ModuleList(attentions)
|
| 1256 |
+
self.resnets = nn.ModuleList(resnets)
|
| 1257 |
+
|
| 1258 |
+
if add_upsample:
|
| 1259 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)])
|
| 1260 |
+
else:
|
| 1261 |
+
self.upsamplers = None
|
| 1262 |
+
|
| 1263 |
+
self.gradient_checkpointing = False
|
| 1264 |
+
|
| 1265 |
+
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
| 1266 |
+
for attn in self.attentions:
|
| 1267 |
+
attn.set_use_memory_efficient_attention(xformers, mem_eff)
|
| 1268 |
+
|
| 1269 |
+
def set_use_sdpa(self, sdpa):
|
| 1270 |
+
for attn in self.attentions:
|
| 1271 |
+
attn.set_use_sdpa(sdpa)
|
| 1272 |
+
|
| 1273 |
+
def forward(
|
| 1274 |
+
self,
|
| 1275 |
+
hidden_states,
|
| 1276 |
+
res_hidden_states_tuple,
|
| 1277 |
+
temb=None,
|
| 1278 |
+
encoder_hidden_states=None,
|
| 1279 |
+
upsample_size=None,
|
| 1280 |
+
):
|
| 1281 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
| 1282 |
+
# pop res hidden states
|
| 1283 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
| 1284 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
| 1285 |
+
|
| 1286 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
| 1287 |
+
|
| 1288 |
+
if self.training and self.gradient_checkpointing:
|
| 1289 |
+
|
| 1290 |
+
def create_custom_forward(module, return_dict=None):
|
| 1291 |
+
def custom_forward(*inputs):
|
| 1292 |
+
if return_dict is not None:
|
| 1293 |
+
return module(*inputs, return_dict=return_dict)
|
| 1294 |
+
else:
|
| 1295 |
+
return module(*inputs)
|
| 1296 |
+
|
| 1297 |
+
return custom_forward
|
| 1298 |
+
|
| 1299 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
| 1300 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 1301 |
+
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
|
| 1302 |
+
)[0]
|
| 1303 |
+
else:
|
| 1304 |
+
hidden_states = resnet(hidden_states, temb)
|
| 1305 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
| 1306 |
+
|
| 1307 |
+
if self.upsamplers is not None:
|
| 1308 |
+
for upsampler in self.upsamplers:
|
| 1309 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
| 1310 |
+
|
| 1311 |
+
return hidden_states
|
| 1312 |
+
|
| 1313 |
+
|
| 1314 |
+
def get_down_block(
|
| 1315 |
+
down_block_type,
|
| 1316 |
+
in_channels,
|
| 1317 |
+
out_channels,
|
| 1318 |
+
add_downsample,
|
| 1319 |
+
attn_num_head_channels,
|
| 1320 |
+
cross_attention_dim,
|
| 1321 |
+
use_linear_projection,
|
| 1322 |
+
upcast_attention,
|
| 1323 |
+
):
|
| 1324 |
+
if down_block_type == "DownBlock2D":
|
| 1325 |
+
return DownBlock2D(
|
| 1326 |
+
in_channels=in_channels,
|
| 1327 |
+
out_channels=out_channels,
|
| 1328 |
+
add_downsample=add_downsample,
|
| 1329 |
+
)
|
| 1330 |
+
elif down_block_type == "CrossAttnDownBlock2D":
|
| 1331 |
+
return CrossAttnDownBlock2D(
|
| 1332 |
+
in_channels=in_channels,
|
| 1333 |
+
out_channels=out_channels,
|
| 1334 |
+
add_downsample=add_downsample,
|
| 1335 |
+
cross_attention_dim=cross_attention_dim,
|
| 1336 |
+
attn_num_head_channels=attn_num_head_channels,
|
| 1337 |
+
use_linear_projection=use_linear_projection,
|
| 1338 |
+
upcast_attention=upcast_attention,
|
| 1339 |
+
)
|
| 1340 |
+
|
| 1341 |
+
|
| 1342 |
+
def get_up_block(
|
| 1343 |
+
up_block_type,
|
| 1344 |
+
in_channels,
|
| 1345 |
+
out_channels,
|
| 1346 |
+
prev_output_channel,
|
| 1347 |
+
add_upsample,
|
| 1348 |
+
attn_num_head_channels,
|
| 1349 |
+
cross_attention_dim=None,
|
| 1350 |
+
use_linear_projection=False,
|
| 1351 |
+
upcast_attention=False,
|
| 1352 |
+
):
|
| 1353 |
+
if up_block_type == "UpBlock2D":
|
| 1354 |
+
return UpBlock2D(
|
| 1355 |
+
in_channels=in_channels,
|
| 1356 |
+
prev_output_channel=prev_output_channel,
|
| 1357 |
+
out_channels=out_channels,
|
| 1358 |
+
add_upsample=add_upsample,
|
| 1359 |
+
)
|
| 1360 |
+
elif up_block_type == "CrossAttnUpBlock2D":
|
| 1361 |
+
return CrossAttnUpBlock2D(
|
| 1362 |
+
in_channels=in_channels,
|
| 1363 |
+
out_channels=out_channels,
|
| 1364 |
+
prev_output_channel=prev_output_channel,
|
| 1365 |
+
attn_num_head_channels=attn_num_head_channels,
|
| 1366 |
+
cross_attention_dim=cross_attention_dim,
|
| 1367 |
+
add_upsample=add_upsample,
|
| 1368 |
+
use_linear_projection=use_linear_projection,
|
| 1369 |
+
upcast_attention=upcast_attention,
|
| 1370 |
+
)
|
| 1371 |
+
|
| 1372 |
+
|
| 1373 |
+
class UNet2DConditionModel(nn.Module):
|
| 1374 |
+
_supports_gradient_checkpointing = True
|
| 1375 |
+
|
| 1376 |
+
def __init__(
|
| 1377 |
+
self,
|
| 1378 |
+
sample_size: Optional[int] = None,
|
| 1379 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
| 1380 |
+
cross_attention_dim: int = 1280,
|
| 1381 |
+
use_linear_projection: bool = False,
|
| 1382 |
+
upcast_attention: bool = False,
|
| 1383 |
+
**kwargs,
|
| 1384 |
+
):
|
| 1385 |
+
super().__init__()
|
| 1386 |
+
assert sample_size is not None, "sample_size must be specified"
|
| 1387 |
+
logger.info(
|
| 1388 |
+
f"UNet2DConditionModel: {sample_size}, {attention_head_dim}, {cross_attention_dim}, {use_linear_projection}, {upcast_attention}"
|
| 1389 |
+
)
|
| 1390 |
+
|
| 1391 |
+
        # defined here so they can be referenced from outside this class
|
| 1392 |
+
self.in_channels = IN_CHANNELS
|
| 1393 |
+
self.out_channels = OUT_CHANNELS
|
| 1394 |
+
|
| 1395 |
+
self.sample_size = sample_size
|
| 1396 |
+
self.prepare_config(sample_size=sample_size)
|
| 1397 |
+
|
| 1398 |
+
        # the way modules are held cannot be changed, because that would change the state_dict format
|
| 1399 |
+
|
| 1400 |
+
# input
|
| 1401 |
+
self.conv_in = nn.Conv2d(IN_CHANNELS, BLOCK_OUT_CHANNELS[0], kernel_size=3, padding=(1, 1))
|
| 1402 |
+
|
| 1403 |
+
# time
|
| 1404 |
+
self.time_proj = Timesteps(BLOCK_OUT_CHANNELS[0], TIME_EMBED_FLIP_SIN_TO_COS, TIME_EMBED_FREQ_SHIFT)
|
| 1405 |
+
|
| 1406 |
+
self.time_embedding = TimestepEmbedding(TIMESTEP_INPUT_DIM, TIME_EMBED_DIM)
|
| 1407 |
+
|
| 1408 |
+
self.down_blocks = nn.ModuleList([])
|
| 1409 |
+
self.mid_block = None
|
| 1410 |
+
self.up_blocks = nn.ModuleList([])
|
| 1411 |
+
|
| 1412 |
+
if isinstance(attention_head_dim, int):
|
| 1413 |
+
attention_head_dim = (attention_head_dim,) * 4
|
| 1414 |
+
|
| 1415 |
+
# down
|
| 1416 |
+
output_channel = BLOCK_OUT_CHANNELS[0]
|
| 1417 |
+
for i, down_block_type in enumerate(DOWN_BLOCK_TYPES):
|
| 1418 |
+
input_channel = output_channel
|
| 1419 |
+
output_channel = BLOCK_OUT_CHANNELS[i]
|
| 1420 |
+
is_final_block = i == len(BLOCK_OUT_CHANNELS) - 1
|
| 1421 |
+
|
| 1422 |
+
down_block = get_down_block(
|
| 1423 |
+
down_block_type,
|
| 1424 |
+
in_channels=input_channel,
|
| 1425 |
+
out_channels=output_channel,
|
| 1426 |
+
add_downsample=not is_final_block,
|
| 1427 |
+
attn_num_head_channels=attention_head_dim[i],
|
| 1428 |
+
cross_attention_dim=cross_attention_dim,
|
| 1429 |
+
use_linear_projection=use_linear_projection,
|
| 1430 |
+
upcast_attention=upcast_attention,
|
| 1431 |
+
)
|
| 1432 |
+
self.down_blocks.append(down_block)
|
| 1433 |
+
|
| 1434 |
+
# mid
|
| 1435 |
+
self.mid_block = UNetMidBlock2DCrossAttn(
|
| 1436 |
+
in_channels=BLOCK_OUT_CHANNELS[-1],
|
| 1437 |
+
attn_num_head_channels=attention_head_dim[-1],
|
| 1438 |
+
cross_attention_dim=cross_attention_dim,
|
| 1439 |
+
use_linear_projection=use_linear_projection,
|
| 1440 |
+
)
|
| 1441 |
+
|
| 1442 |
+
# count how many layers upsample the images
|
| 1443 |
+
self.num_upsamplers = 0
|
| 1444 |
+
|
| 1445 |
+
# up
|
| 1446 |
+
reversed_block_out_channels = list(reversed(BLOCK_OUT_CHANNELS))
|
| 1447 |
+
reversed_attention_head_dim = list(reversed(attention_head_dim))
|
| 1448 |
+
output_channel = reversed_block_out_channels[0]
|
| 1449 |
+
for i, up_block_type in enumerate(UP_BLOCK_TYPES):
|
| 1450 |
+
is_final_block = i == len(BLOCK_OUT_CHANNELS) - 1
|
| 1451 |
+
|
| 1452 |
+
prev_output_channel = output_channel
|
| 1453 |
+
output_channel = reversed_block_out_channels[i]
|
| 1454 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(BLOCK_OUT_CHANNELS) - 1)]
|
| 1455 |
+
|
| 1456 |
+
# add upsample block for all BUT final layer
|
| 1457 |
+
if not is_final_block:
|
| 1458 |
+
add_upsample = True
|
| 1459 |
+
self.num_upsamplers += 1
|
| 1460 |
+
else:
|
| 1461 |
+
add_upsample = False
|
| 1462 |
+
|
| 1463 |
+
up_block = get_up_block(
|
| 1464 |
+
up_block_type,
|
| 1465 |
+
in_channels=input_channel,
|
| 1466 |
+
out_channels=output_channel,
|
| 1467 |
+
prev_output_channel=prev_output_channel,
|
| 1468 |
+
add_upsample=add_upsample,
|
| 1469 |
+
attn_num_head_channels=reversed_attention_head_dim[i],
|
| 1470 |
+
cross_attention_dim=cross_attention_dim,
|
| 1471 |
+
use_linear_projection=use_linear_projection,
|
| 1472 |
+
upcast_attention=upcast_attention,
|
| 1473 |
+
)
|
| 1474 |
+
self.up_blocks.append(up_block)
|
| 1475 |
+
prev_output_channel = output_channel
|
| 1476 |
+
|
| 1477 |
+
# out
|
| 1478 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=BLOCK_OUT_CHANNELS[0], num_groups=NORM_GROUPS, eps=NORM_EPS)
|
| 1479 |
+
self.conv_act = nn.SiLU()
|
| 1480 |
+
self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
|
| 1481 |
+
|
| 1482 |
+
# region diffusers compatibility
|
| 1483 |
+
def prepare_config(self, *args, **kwargs):
|
| 1484 |
+
self.config = SimpleNamespace(**kwargs)
|
| 1485 |
+
|
| 1486 |
+
@property
|
| 1487 |
+
def dtype(self) -> torch.dtype:
|
| 1488 |
+
# `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
| 1489 |
+
return get_parameter_dtype(self)
|
| 1490 |
+
|
| 1491 |
+
@property
|
| 1492 |
+
def device(self) -> torch.device:
|
| 1493 |
+
# `torch.device`: The device on which the module is (assuming that all the module parameters are on the same device).
|
| 1494 |
+
return get_parameter_device(self)
|
| 1495 |
+
|
| 1496 |
+
def set_attention_slice(self, slice_size):
|
| 1497 |
+
raise NotImplementedError("Attention slicing is not supported for this model.")
|
| 1498 |
+
|
| 1499 |
+
def is_gradient_checkpointing(self) -> bool:
|
| 1500 |
+
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
|
| 1501 |
+
|
| 1502 |
+
def enable_gradient_checkpointing(self):
|
| 1503 |
+
self.set_gradient_checkpointing(value=True)
|
| 1504 |
+
|
| 1505 |
+
def disable_gradient_checkpointing(self):
|
| 1506 |
+
self.set_gradient_checkpointing(value=False)
|
| 1507 |
+
|
| 1508 |
+
def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> None:
|
| 1509 |
+
modules = self.down_blocks + [self.mid_block] + self.up_blocks
|
| 1510 |
+
for module in modules:
|
| 1511 |
+
module.set_use_memory_efficient_attention(xformers, mem_eff)
|
| 1512 |
+
|
| 1513 |
+
def set_use_sdpa(self, sdpa: bool) -> None:
|
| 1514 |
+
modules = self.down_blocks + [self.mid_block] + self.up_blocks
|
| 1515 |
+
for module in modules:
|
| 1516 |
+
module.set_use_sdpa(sdpa)
|
| 1517 |
+
|
| 1518 |
+
def set_gradient_checkpointing(self, value=False):
|
| 1519 |
+
modules = self.down_blocks + [self.mid_block] + self.up_blocks
|
| 1520 |
+
for module in modules:
|
| 1521 |
+
logger.info(f"{module.__class__.__name__} {module.gradient_checkpointing} -> {value}")
|
| 1522 |
+
module.gradient_checkpointing = value
|
| 1523 |
+
|
| 1524 |
+
# endregion
|
| 1525 |
+
|
| 1526 |
+
def forward(
|
| 1527 |
+
self,
|
| 1528 |
+
sample: torch.FloatTensor,
|
| 1529 |
+
timestep: Union[torch.Tensor, float, int],
|
| 1530 |
+
encoder_hidden_states: torch.Tensor,
|
| 1531 |
+
class_labels: Optional[torch.Tensor] = None,
|
| 1532 |
+
return_dict: bool = True,
|
| 1533 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
| 1534 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
| 1535 |
+
) -> Union[Dict, Tuple]:
|
| 1536 |
+
r"""
|
| 1537 |
+
Args:
|
| 1538 |
+
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
| 1539 |
+
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
| 1540 |
+
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
| 1541 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 1542 |
+
Whether or not to return a dict instead of a plain tuple.
|
| 1543 |
+
|
| 1544 |
+
Returns:
|
| 1545 |
+
`SampleOutput` or `tuple`:
|
| 1546 |
+
`SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
|
| 1547 |
+
"""
|
| 1548 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
| 1549 |
+
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
| 1550 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
| 1551 |
+
# on the fly if necessary.
|
| 1552 |
+
        # By default, the sample size must be a multiple of 2 ** (number of upsamplers), i.e. a multiple of 64
|
| 1553 |
+
        # To support other sizes as well, the upsample size is overridden where necessary
|
| 1554 |
+
        # Image quality probably suffers in that case, so it is better to keep sizes divisible by 64
|
| 1555 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
| 1556 |
+
|
| 1557 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
| 1558 |
+
        # when the size is not divisible by 64, pass the target size to the upsamplers
|
| 1559 |
+
forward_upsample_size = False
|
| 1560 |
+
upsample_size = None
|
| 1561 |
+
|
| 1562 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
| 1563 |
+
# logger.info("Forward upsample size to force interpolation output size.")
|
| 1564 |
+
forward_upsample_size = True
|
| 1565 |
+
|
| 1566 |
+
# 1. time
|
| 1567 |
+
timesteps = timestep
|
| 1568 |
+
        timesteps = self.handle_unusual_timesteps(sample, timesteps)  # only normalizes unusual timestep inputs
|
| 1569 |
+
|
| 1570 |
+
t_emb = self.time_proj(timesteps)
|
| 1571 |
+
|
| 1572 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
| 1573 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
| 1574 |
+
# there might be better ways to encapsulate this.
|
| 1575 |
+
        # timesteps contains no weights, so it always comes back as a float32 tensor
|
| 1576 |
+
        # but time_embedding may be running in fp16, so the cast is needed here
|
| 1577 |
+
        # it might be simpler to do the cast inside time_proj instead
|
| 1578 |
+
t_emb = t_emb.to(dtype=self.dtype)
|
| 1579 |
+
emb = self.time_embedding(t_emb)
|
| 1580 |
+
|
| 1581 |
+
# 2. pre-process
|
| 1582 |
+
sample = self.conv_in(sample)
|
| 1583 |
+
|
| 1584 |
+
down_block_res_samples = (sample,)
|
| 1585 |
+
for downsample_block in self.down_blocks:
|
| 1586 |
+
            # the down blocks could be made to always take encoder_hidden_states in forward,
|
| 1587 |
+
            # but branching here is probably easier to follow
|
| 1588 |
+
if downsample_block.has_cross_attention:
|
| 1589 |
+
sample, res_samples = downsample_block(
|
| 1590 |
+
hidden_states=sample,
|
| 1591 |
+
temb=emb,
|
| 1592 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 1593 |
+
)
|
| 1594 |
+
else:
|
| 1595 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
| 1596 |
+
|
| 1597 |
+
down_block_res_samples += res_samples
|
| 1598 |
+
|
| 1599 |
+
        # add the ControlNet outputs to the skip connections
|
| 1600 |
+
if down_block_additional_residuals is not None:
|
| 1601 |
+
down_block_res_samples = list(down_block_res_samples)
|
| 1602 |
+
for i in range(len(down_block_res_samples)):
|
| 1603 |
+
down_block_res_samples[i] += down_block_additional_residuals[i]
|
| 1604 |
+
down_block_res_samples = tuple(down_block_res_samples)
|
| 1605 |
+
|
| 1606 |
+
# 4. mid
|
| 1607 |
+
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
|
| 1608 |
+
|
| 1609 |
+
        # add the ControlNet output to the mid block
|
| 1610 |
+
if mid_block_additional_residual is not None:
|
| 1611 |
+
sample += mid_block_additional_residual
|
| 1612 |
+
|
| 1613 |
+
# 5. up
|
| 1614 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
| 1615 |
+
is_final_block = i == len(self.up_blocks) - 1
|
| 1616 |
+
|
| 1617 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
| 1618 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection
|
| 1619 |
+
|
| 1620 |
+
# if we have not reached the final block and need to forward the upsample size, we do it here
|
| 1621 |
+
            # as noted above, pass upsample_size to every block except the final one
|
| 1622 |
+
if not is_final_block and forward_upsample_size:
|
| 1623 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
| 1624 |
+
|
| 1625 |
+
if upsample_block.has_cross_attention:
|
| 1626 |
+
sample = upsample_block(
|
| 1627 |
+
hidden_states=sample,
|
| 1628 |
+
temb=emb,
|
| 1629 |
+
res_hidden_states_tuple=res_samples,
|
| 1630 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 1631 |
+
upsample_size=upsample_size,
|
| 1632 |
+
)
|
| 1633 |
+
else:
|
| 1634 |
+
sample = upsample_block(
|
| 1635 |
+
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
| 1636 |
+
)
|
| 1637 |
+
|
| 1638 |
+
# 6. post-process
|
| 1639 |
+
sample = self.conv_norm_out(sample)
|
| 1640 |
+
sample = self.conv_act(sample)
|
| 1641 |
+
sample = self.conv_out(sample)
|
| 1642 |
+
|
| 1643 |
+
if not return_dict:
|
| 1644 |
+
return (sample,)
|
| 1645 |
+
|
| 1646 |
+
return SampleOutput(sample=sample)
|
| 1647 |
+
|
| 1648 |
+
def handle_unusual_timesteps(self, sample, timesteps):
|
| 1649 |
+
r"""
|
| 1650 |
+
        If timesteps is not a Tensor, convert it to one, and broadcast it to the batch size in a way that is compatible with ONNX/Core ML.
|
| 1651 |
+
"""
|
| 1652 |
+
if not torch.is_tensor(timesteps):
|
| 1653 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
| 1654 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
| 1655 |
+
is_mps = sample.device.type == "mps"
|
| 1656 |
+
if isinstance(timesteps, float):
|
| 1657 |
+
dtype = torch.float32 if is_mps else torch.float64
|
| 1658 |
+
else:
|
| 1659 |
+
dtype = torch.int32 if is_mps else torch.int64
|
| 1660 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
| 1661 |
+
elif len(timesteps.shape) == 0:
|
| 1662 |
+
timesteps = timesteps[None].to(sample.device)
|
| 1663 |
+
|
| 1664 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 1665 |
+
timesteps = timesteps.expand(sample.shape[0])
|
| 1666 |
+
|
| 1667 |
+
return timesteps
|
| 1668 |
+
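    # Illustrative usage (hypothetical shapes; channel counts come from the SD constants at the
    # top of this file, and cross_attention_dim depends on the text encoder, e.g. 768 for SD 1.x):
    #   unet = UNet2DConditionModel(sample_size=64, attention_head_dim=8, cross_attention_dim=768)
    #   noise_pred = unet(latents, timesteps, text_embeddings).sample
    #   # latents: (B, 4, 64, 64), timesteps: (B,), text_embeddings: (B, 77, 768) -> (B, 4, 64, 64)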
|
| 1669 |
+
|
| 1670 |
+
class InferUNet2DConditionModel:
|
| 1671 |
+
def __init__(self, original_unet: UNet2DConditionModel):
|
| 1672 |
+
self.delegate = original_unet
|
| 1673 |
+
|
| 1674 |
+
# override original model's forward method: because forward is not called by `__call__`
|
| 1675 |
+
# overriding `__call__` is not enough, because nn.Module.forward has a special handling
|
| 1676 |
+
self.delegate.forward = self.forward
|
| 1677 |
+
|
| 1678 |
+
# override original model's up blocks' forward method
|
| 1679 |
+
for up_block in self.delegate.up_blocks:
|
| 1680 |
+
if up_block.__class__.__name__ == "UpBlock2D":
|
| 1681 |
+
|
| 1682 |
+
def resnet_wrapper(func, block):
|
| 1683 |
+
def forward(*args, **kwargs):
|
| 1684 |
+
return func(block, *args, **kwargs)
|
| 1685 |
+
|
| 1686 |
+
return forward
|
| 1687 |
+
|
| 1688 |
+
up_block.forward = resnet_wrapper(self.up_block_forward, up_block)
|
| 1689 |
+
|
| 1690 |
+
elif up_block.__class__.__name__ == "CrossAttnUpBlock2D":
|
| 1691 |
+
|
| 1692 |
+
def cross_attn_up_wrapper(func, block):
|
| 1693 |
+
def forward(*args, **kwargs):
|
| 1694 |
+
return func(block, *args, **kwargs)
|
| 1695 |
+
|
| 1696 |
+
return forward
|
| 1697 |
+
|
| 1698 |
+
up_block.forward = cross_attn_up_wrapper(self.cross_attn_up_block_forward, up_block)
|
| 1699 |
+
|
| 1700 |
+
# Deep Shrink
|
| 1701 |
+
self.ds_depth_1 = None
|
| 1702 |
+
self.ds_depth_2 = None
|
| 1703 |
+
self.ds_timesteps_1 = None
|
| 1704 |
+
self.ds_timesteps_2 = None
|
| 1705 |
+
self.ds_ratio = None
|
| 1706 |
+
|
| 1707 |
+
# call original model's methods
|
| 1708 |
+
def __getattr__(self, name):
|
| 1709 |
+
return getattr(self.delegate, name)
|
| 1710 |
+
|
| 1711 |
+
def __call__(self, *args, **kwargs):
|
| 1712 |
+
return self.delegate(*args, **kwargs)
|
| 1713 |
+
|
| 1714 |
+
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
|
| 1715 |
+
if ds_depth_1 is None:
|
| 1716 |
+
logger.info("Deep Shrink is disabled.")
|
| 1717 |
+
self.ds_depth_1 = None
|
| 1718 |
+
self.ds_timesteps_1 = None
|
| 1719 |
+
self.ds_depth_2 = None
|
| 1720 |
+
self.ds_timesteps_2 = None
|
| 1721 |
+
self.ds_ratio = None
|
| 1722 |
+
else:
|
| 1723 |
+
logger.info(
|
| 1724 |
+
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
|
| 1725 |
+
)
|
| 1726 |
+
self.ds_depth_1 = ds_depth_1
|
| 1727 |
+
self.ds_timesteps_1 = ds_timesteps_1
|
| 1728 |
+
self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
|
| 1729 |
+
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
|
| 1730 |
+
self.ds_ratio = ds_ratio
|
| 1731 |
+
|
| 1732 |
+
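    # Deep Shrink: while the timestep is above the configured thresholds, the hidden states are
    # downscaled by ds_ratio (bicubic) at the configured down-block depths; the up-block
    # forwards below then call resize_like so the skip connections match spatially again.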
def up_block_forward(self, _self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
|
| 1733 |
+
for resnet in _self.resnets:
|
| 1734 |
+
# pop res hidden states
|
| 1735 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
| 1736 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
| 1737 |
+
|
| 1738 |
+
# Deep Shrink
|
| 1739 |
+
if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
|
| 1740 |
+
hidden_states = resize_like(hidden_states, res_hidden_states)
|
| 1741 |
+
|
| 1742 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
| 1743 |
+
hidden_states = resnet(hidden_states, temb)
|
| 1744 |
+
|
| 1745 |
+
if _self.upsamplers is not None:
|
| 1746 |
+
for upsampler in _self.upsamplers:
|
| 1747 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
| 1748 |
+
|
| 1749 |
+
return hidden_states
|
| 1750 |
+
|
| 1751 |
+
def cross_attn_up_block_forward(
|
| 1752 |
+
self,
|
| 1753 |
+
_self,
|
| 1754 |
+
hidden_states,
|
| 1755 |
+
res_hidden_states_tuple,
|
| 1756 |
+
temb=None,
|
| 1757 |
+
encoder_hidden_states=None,
|
| 1758 |
+
upsample_size=None,
|
| 1759 |
+
):
|
| 1760 |
+
for resnet, attn in zip(_self.resnets, _self.attentions):
|
| 1761 |
+
# pop res hidden states
|
| 1762 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
| 1763 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
| 1764 |
+
|
| 1765 |
+
# Deep Shrink
|
| 1766 |
+
if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
|
| 1767 |
+
hidden_states = resize_like(hidden_states, res_hidden_states)
|
| 1768 |
+
|
| 1769 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
| 1770 |
+
hidden_states = resnet(hidden_states, temb)
|
| 1771 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
| 1772 |
+
|
| 1773 |
+
if _self.upsamplers is not None:
|
| 1774 |
+
for upsampler in _self.upsamplers:
|
| 1775 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
| 1776 |
+
|
| 1777 |
+
return hidden_states
|
| 1778 |
+
|
| 1779 |
+
def forward(
|
| 1780 |
+
self,
|
| 1781 |
+
sample: torch.FloatTensor,
|
| 1782 |
+
timestep: Union[torch.Tensor, float, int],
|
| 1783 |
+
encoder_hidden_states: torch.Tensor,
|
| 1784 |
+
class_labels: Optional[torch.Tensor] = None,
|
| 1785 |
+
return_dict: bool = True,
|
| 1786 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
| 1787 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
| 1788 |
+
) -> Union[Dict, Tuple]:
|
| 1789 |
+
r"""
|
| 1790 |
+
        Current implementation is a copy of `UNet2DConditionModel.forward()` with Deep Shrink added.
|
| 1791 |
+
"""
|
| 1792 |
+
|
| 1793 |
+
r"""
|
| 1794 |
+
Args:
|
| 1795 |
+
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
| 1796 |
+
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
| 1797 |
+
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
| 1798 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 1799 |
+
Whether or not to return a dict instead of a plain tuple.
|
| 1800 |
+
|
| 1801 |
+
Returns:
|
| 1802 |
+
`SampleOutput` or `tuple`:
|
| 1803 |
+
`SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
|
| 1804 |
+
"""
|
| 1805 |
+
|
| 1806 |
+
_self = self.delegate
|
| 1807 |
+
|
| 1808 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
| 1809 |
+
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
| 1810 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
| 1811 |
+
# on the fly if necessary.
|
| 1812 |
+
        # By default, the sample size must be a multiple of 2 ** (number of upsamplers), i.e. a multiple of 64
|
| 1813 |
+
        # To support other sizes as well, the upsample size is overridden where necessary
|
| 1814 |
+
        # Image quality probably suffers in that case, so it is better to keep sizes divisible by 64
|
| 1815 |
+
default_overall_up_factor = 2**_self.num_upsamplers
|
| 1816 |
+
|
| 1817 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
| 1818 |
+
        # when the size is not divisible by 64, pass the target size to the upsamplers
|
| 1819 |
+
forward_upsample_size = False
|
| 1820 |
+
upsample_size = None
|
| 1821 |
+
|
| 1822 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
| 1823 |
+
# logger.info("Forward upsample size to force interpolation output size.")
|
| 1824 |
+
forward_upsample_size = True
|
| 1825 |
+
|
| 1826 |
+
# 1. time
|
| 1827 |
+
timesteps = timestep
|
| 1828 |
+
        timesteps = _self.handle_unusual_timesteps(sample, timesteps)  # only normalizes unusual timestep inputs
|
| 1829 |
+
|
| 1830 |
+
t_emb = _self.time_proj(timesteps)
|
| 1831 |
+
|
| 1832 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
| 1833 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
| 1834 |
+
# there might be better ways to encapsulate this.
|
| 1835 |
+
        # timesteps contains no weights, so it always comes back as a float32 tensor
|
| 1836 |
+
        # but time_embedding may be running in fp16, so the cast is needed here
|
| 1837 |
+
# it might be simpler to just cast inside time_proj
|
| 1838 |
+
t_emb = t_emb.to(dtype=_self.dtype)
|
| 1839 |
+
emb = _self.time_embedding(t_emb)
|
| 1840 |
+
|
| 1841 |
+
# 2. pre-process
|
| 1842 |
+
sample = _self.conv_in(sample)
|
| 1843 |
+
|
| 1844 |
+
down_block_res_samples = (sample,)
|
| 1845 |
+
for depth, downsample_block in enumerate(_self.down_blocks):
|
| 1846 |
+
# Deep Shrink
|
| 1847 |
+
if self.ds_depth_1 is not None:
|
| 1848 |
+
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
|
| 1849 |
+
self.ds_depth_2 is not None
|
| 1850 |
+
and depth == self.ds_depth_2
|
| 1851 |
+
and timesteps[0] < self.ds_timesteps_1
|
| 1852 |
+
and timesteps[0] >= self.ds_timesteps_2
|
| 1853 |
+
):
|
| 1854 |
+
org_dtype = sample.dtype
|
| 1855 |
+
if org_dtype == torch.bfloat16:
|
| 1856 |
+
sample = sample.to(torch.float32)
|
| 1857 |
+
sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
|
| 1858 |
+
|
| 1859 |
+
# the down blocks could be changed to always take encoder_hidden_states in forward,
|
| 1860 |
+
# but this branching is probably easier to follow
|
| 1861 |
+
if downsample_block.has_cross_attention:
|
| 1862 |
+
sample, res_samples = downsample_block(
|
| 1863 |
+
hidden_states=sample,
|
| 1864 |
+
temb=emb,
|
| 1865 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 1866 |
+
)
|
| 1867 |
+
else:
|
| 1868 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
| 1869 |
+
|
| 1870 |
+
down_block_res_samples += res_samples
|
| 1871 |
+
|
| 1872 |
+
# add the ControlNet outputs to the skip connections
|
| 1873 |
+
if down_block_additional_residuals is not None:
|
| 1874 |
+
down_block_res_samples = list(down_block_res_samples)
|
| 1875 |
+
for i in range(len(down_block_res_samples)):
|
| 1876 |
+
down_block_res_samples[i] += down_block_additional_residuals[i]
|
| 1877 |
+
down_block_res_samples = tuple(down_block_res_samples)
|
| 1878 |
+
|
| 1879 |
+
# 4. mid
|
| 1880 |
+
sample = _self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
|
| 1881 |
+
|
| 1882 |
+
# add the ControlNet mid-block output
|
| 1883 |
+
if mid_block_additional_residual is not None:
|
| 1884 |
+
sample += mid_block_additional_residual
|
| 1885 |
+
|
| 1886 |
+
# 5. up
|
| 1887 |
+
for i, upsample_block in enumerate(_self.up_blocks):
|
| 1888 |
+
is_final_block = i == len(_self.up_blocks) - 1
|
| 1889 |
+
|
| 1890 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
| 1891 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection
|
| 1892 |
+
|
| 1893 |
+
# if we have not reached the final block and need to forward the upsample size, we do it here
|
| 1894 |
+
# as noted above, pass upsample_size to every block except the final one
|
| 1895 |
+
if not is_final_block and forward_upsample_size:
|
| 1896 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
| 1897 |
+
|
| 1898 |
+
if upsample_block.has_cross_attention:
|
| 1899 |
+
sample = upsample_block(
|
| 1900 |
+
hidden_states=sample,
|
| 1901 |
+
temb=emb,
|
| 1902 |
+
res_hidden_states_tuple=res_samples,
|
| 1903 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 1904 |
+
upsample_size=upsample_size,
|
| 1905 |
+
)
|
| 1906 |
+
else:
|
| 1907 |
+
sample = upsample_block(
|
| 1908 |
+
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
| 1909 |
+
)
|
| 1910 |
+
|
| 1911 |
+
# 6. post-process
|
| 1912 |
+
sample = _self.conv_norm_out(sample)
|
| 1913 |
+
sample = _self.conv_act(sample)
|
| 1914 |
+
sample = _self.conv_out(sample)
|
| 1915 |
+
|
| 1916 |
+
if not return_dict:
|
| 1917 |
+
return (sample,)
|
| 1918 |
+
|
| 1919 |
+
return SampleOutput(sample=sample)
|
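The Deep Shrink branch in the loop above only fires at a configured depth while the timestep is still early (large), shrinking the hidden states with bicubic interpolation; `resize_like` on the up path restores matching resolutions. Below is a minimal, self-contained sketch of that gating, reusing the attribute names `ds_depth_1`, `ds_timesteps_1` and `ds_ratio`; it is an illustration of the idea, not the wrapper class itself.

```python
import torch
import torch.nn.functional as F

def deep_shrink_maybe_downscale(sample, depth, timestep, ds_depth_1, ds_timesteps_1, ds_ratio):
    """Downscale `sample` at the configured depth while the timestep is still large (early steps)."""
    if ds_depth_1 is not None and depth == ds_depth_1 and timestep >= ds_timesteps_1:
        org_dtype = sample.dtype
        if org_dtype == torch.bfloat16:
            sample = sample.float()  # mirror the code above, which interpolates in fp32 for bf16 inputs
        sample = F.interpolate(sample, scale_factor=ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
    return sample

# toy usage: a 64x64 feature map shrunk to 50% at depth 0 for timesteps >= 650
x = torch.randn(1, 320, 64, 64)
print(deep_shrink_maybe_downscale(x, depth=0, timestep=800,
                                  ds_depth_1=0, ds_timesteps_1=650, ds_ratio=0.5).shape)  # (1, 320, 32, 32)
```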
library/sai_model_spec.py
ADDED
|
@@ -0,0 +1,309 @@
| 1 |
+
# based on https://github.com/Stability-AI/ModelSpec
|
| 2 |
+
import datetime
|
| 3 |
+
import hashlib
|
| 4 |
+
from io import BytesIO
|
| 5 |
+
import os
|
| 6 |
+
from typing import List, Optional, Tuple, Union
|
| 7 |
+
import safetensors
|
| 8 |
+
from library.utils import setup_logging
|
| 9 |
+
setup_logging()
|
| 10 |
+
import logging
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
r"""
|
| 14 |
+
# Metadata Example
|
| 15 |
+
metadata = {
|
| 16 |
+
# === Must ===
|
| 17 |
+
"modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
|
| 18 |
+
"modelspec.architecture": "stable-diffusion-xl-v1-base", # Architecture, reference the ID of the original model of the arch to match the ID
|
| 19 |
+
"modelspec.implementation": "sgm",
|
| 20 |
+
"modelspec.title": "Example Model Version 1.0", # Clean, human-readable title. May use your own phrasing/language/etc
|
| 21 |
+
# === Should ===
|
| 22 |
+
"modelspec.author": "Example Corp", # Your name or company name
|
| 23 |
+
"modelspec.description": "This is my example model to show you how to do it!", # Describe the model in your own words/language/etc. Focus on what users need to know
|
| 24 |
+
"modelspec.date": "2023-07-20", # ISO-8601 compliant date of when the model was created
|
| 25 |
+
# === Can ===
|
| 26 |
+
"modelspec.license": "ExampleLicense-1.0", # eg CreativeML Open RAIL, etc.
|
| 27 |
+
"modelspec.usage_hint": "Use keyword 'example'" # In your own language, very short hints about how the user should use the model
|
| 28 |
+
}
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
BASE_METADATA = {
|
| 32 |
+
# === Must ===
|
| 33 |
+
"modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
|
| 34 |
+
"modelspec.architecture": None,
|
| 35 |
+
"modelspec.implementation": None,
|
| 36 |
+
"modelspec.title": None,
|
| 37 |
+
"modelspec.resolution": None,
|
| 38 |
+
# === Should ===
|
| 39 |
+
"modelspec.description": None,
|
| 40 |
+
"modelspec.author": None,
|
| 41 |
+
"modelspec.date": None,
|
| 42 |
+
# === Can ===
|
| 43 |
+
"modelspec.license": None,
|
| 44 |
+
"modelspec.tags": None,
|
| 45 |
+
"modelspec.merged_from": None,
|
| 46 |
+
"modelspec.prediction_type": None,
|
| 47 |
+
"modelspec.timestep_range": None,
|
| 48 |
+
"modelspec.encoder_layer": None,
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
# define separately only the keys that are actually used
|
| 52 |
+
MODELSPEC_TITLE = "modelspec.title"
|
| 53 |
+
|
| 54 |
+
ARCH_SD_V1 = "stable-diffusion-v1"
|
| 55 |
+
ARCH_SD_V2_512 = "stable-diffusion-v2-512"
|
| 56 |
+
ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
|
| 57 |
+
ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
|
| 58 |
+
|
| 59 |
+
ADAPTER_LORA = "lora"
|
| 60 |
+
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
|
| 61 |
+
|
| 62 |
+
IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
|
| 63 |
+
IMPL_DIFFUSERS = "diffusers"
|
| 64 |
+
|
| 65 |
+
PRED_TYPE_EPSILON = "epsilon"
|
| 66 |
+
PRED_TYPE_V = "v"
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def load_bytes_in_safetensors(tensors):
|
| 70 |
+
bytes = safetensors.torch.save(tensors)
|
| 71 |
+
b = BytesIO(bytes)
|
| 72 |
+
|
| 73 |
+
b.seek(0)
|
| 74 |
+
header = b.read(8)
|
| 75 |
+
n = int.from_bytes(header, "little")
|
| 76 |
+
|
| 77 |
+
offset = n + 8
|
| 78 |
+
b.seek(offset)
|
| 79 |
+
|
| 80 |
+
return b.read()
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def precalculate_safetensors_hashes(state_dict):
|
| 84 |
+
# calculate each tensor one by one to reduce memory usage
|
| 85 |
+
hash_sha256 = hashlib.sha256()
|
| 86 |
+
for tensor in state_dict.values():
|
| 87 |
+
single_tensor_sd = {"tensor": tensor}
|
| 88 |
+
bytes_for_tensor = load_bytes_in_safetensors(single_tensor_sd)
|
| 89 |
+
hash_sha256.update(bytes_for_tensor)
|
| 90 |
+
|
| 91 |
+
return f"0x{hash_sha256.hexdigest()}"
|
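`precalculate_safetensors_hashes` hashes the tensors one at a time from memory; the reference snippet quoted at the end of this module instead hashes the packed tensor data of the saved file after its JSON header. A small sketch of that file-based variant is shown below (the path is a placeholder); like the reference, it reads the remainder of the file into RAM for simplicity.

```python
import hashlib
import json
import struct

def modelspec_file_hash(path: str) -> str:
    """sha256 of the tensor data section of a .safetensors file (JSON header excluded)."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))  # 8-byte little-endian header length
        json.loads(f.read(header_len))                   # JSON header with tensor shapes/offsets
        digest = hashlib.sha256(f.read()).hexdigest()    # everything after the header is packed tensor data
    return f"0x{digest}"

# metadata["modelspec.hash_sha256"] = modelspec_file_hash("model.safetensors")  # placeholder path
```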
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def update_hash_sha256(metadata: dict, state_dict: dict):
|
| 95 |
+
raise NotImplementedError
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def build_metadata(
|
| 99 |
+
state_dict: Optional[dict],
|
| 100 |
+
v2: bool,
|
| 101 |
+
v_parameterization: bool,
|
| 102 |
+
sdxl: bool,
|
| 103 |
+
lora: bool,
|
| 104 |
+
textual_inversion: bool,
|
| 105 |
+
timestamp: float,
|
| 106 |
+
title: Optional[str] = None,
|
| 107 |
+
reso: Optional[Union[int, Tuple[int, int]]] = None,
|
| 108 |
+
is_stable_diffusion_ckpt: Optional[bool] = None,
|
| 109 |
+
author: Optional[str] = None,
|
| 110 |
+
description: Optional[str] = None,
|
| 111 |
+
license: Optional[str] = None,
|
| 112 |
+
tags: Optional[str] = None,
|
| 113 |
+
merged_from: Optional[str] = None,
|
| 114 |
+
timesteps: Optional[Tuple[int, int]] = None,
|
| 115 |
+
clip_skip: Optional[int] = None,
|
| 116 |
+
):
|
| 117 |
+
# if state_dict is None, hash is not calculated
|
| 118 |
+
|
| 119 |
+
metadata = {}
|
| 120 |
+
metadata.update(BASE_METADATA)
|
| 121 |
+
|
| 122 |
+
# TODO implement once a correct hash calculation method that does not consume extra memory is found
|
| 123 |
+
# if state_dict is not None:
|
| 124 |
+
# hash = precalculate_safetensors_hashes(state_dict)
|
| 125 |
+
# metadata["modelspec.hash_sha256"] = hash
|
| 126 |
+
|
| 127 |
+
if sdxl:
|
| 128 |
+
arch = ARCH_SD_XL_V1_BASE
|
| 129 |
+
elif v2:
|
| 130 |
+
if v_parameterization:
|
| 131 |
+
arch = ARCH_SD_V2_768_V
|
| 132 |
+
else:
|
| 133 |
+
arch = ARCH_SD_V2_512
|
| 134 |
+
else:
|
| 135 |
+
arch = ARCH_SD_V1
|
| 136 |
+
|
| 137 |
+
if lora:
|
| 138 |
+
arch += f"/{ADAPTER_LORA}"
|
| 139 |
+
elif textual_inversion:
|
| 140 |
+
arch += f"/{ADAPTER_TEXTUAL_INVERSION}"
|
| 141 |
+
|
| 142 |
+
metadata["modelspec.architecture"] = arch
|
| 143 |
+
|
| 144 |
+
if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
|
| 145 |
+
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
|
| 146 |
+
|
| 147 |
+
if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
|
| 148 |
+
# Stable Diffusion ckpt, TI, SDXL LoRA
|
| 149 |
+
impl = IMPL_STABILITY_AI
|
| 150 |
+
else:
|
| 151 |
+
# v1/v2 LoRA or Diffusers
|
| 152 |
+
impl = IMPL_DIFFUSERS
|
| 153 |
+
metadata["modelspec.implementation"] = impl
|
| 154 |
+
|
| 155 |
+
if title is None:
|
| 156 |
+
if lora:
|
| 157 |
+
title = "LoRA"
|
| 158 |
+
elif textual_inversion:
|
| 159 |
+
title = "TextualInversion"
|
| 160 |
+
else:
|
| 161 |
+
title = "Checkpoint"
|
| 162 |
+
title += f"@{timestamp}"
|
| 163 |
+
metadata[MODELSPEC_TITLE] = title
|
| 164 |
+
|
| 165 |
+
if author is not None:
|
| 166 |
+
metadata["modelspec.author"] = author
|
| 167 |
+
else:
|
| 168 |
+
del metadata["modelspec.author"]
|
| 169 |
+
|
| 170 |
+
if description is not None:
|
| 171 |
+
metadata["modelspec.description"] = description
|
| 172 |
+
else:
|
| 173 |
+
del metadata["modelspec.description"]
|
| 174 |
+
|
| 175 |
+
if merged_from is not None:
|
| 176 |
+
metadata["modelspec.merged_from"] = merged_from
|
| 177 |
+
else:
|
| 178 |
+
del metadata["modelspec.merged_from"]
|
| 179 |
+
|
| 180 |
+
if license is not None:
|
| 181 |
+
metadata["modelspec.license"] = license
|
| 182 |
+
else:
|
| 183 |
+
del metadata["modelspec.license"]
|
| 184 |
+
|
| 185 |
+
if tags is not None:
|
| 186 |
+
metadata["modelspec.tags"] = tags
|
| 187 |
+
else:
|
| 188 |
+
del metadata["modelspec.tags"]
|
| 189 |
+
|
| 190 |
+
# remove microsecond from time
|
| 191 |
+
int_ts = int(timestamp)
|
| 192 |
+
|
| 193 |
+
# time to iso-8601 compliant date
|
| 194 |
+
date = datetime.datetime.fromtimestamp(int_ts).isoformat()
|
| 195 |
+
metadata["modelspec.date"] = date
|
| 196 |
+
|
| 197 |
+
if reso is not None:
|
| 198 |
+
# comma separated to tuple
|
| 199 |
+
if isinstance(reso, str):
|
| 200 |
+
reso = tuple(map(int, reso.split(",")))
|
| 201 |
+
if len(reso) == 1:
|
| 202 |
+
reso = (reso[0], reso[0])
|
| 203 |
+
else:
|
| 204 |
+
# resolution is defined in dataset, so use default
|
| 205 |
+
if sdxl:
|
| 206 |
+
reso = 1024
|
| 207 |
+
elif v2 and v_parameterization:
|
| 208 |
+
reso = 768
|
| 209 |
+
else:
|
| 210 |
+
reso = 512
|
| 211 |
+
if isinstance(reso, int):
|
| 212 |
+
reso = (reso, reso)
|
| 213 |
+
|
| 214 |
+
metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"
|
| 215 |
+
|
| 216 |
+
if v_parameterization:
|
| 217 |
+
metadata["modelspec.prediction_type"] = PRED_TYPE_V
|
| 218 |
+
else:
|
| 219 |
+
metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON
|
| 220 |
+
|
| 221 |
+
if timesteps is not None:
|
| 222 |
+
if isinstance(timesteps, str) or isinstance(timesteps, int):
|
| 223 |
+
timesteps = (timesteps, timesteps)
|
| 224 |
+
if len(timesteps) == 1:
|
| 225 |
+
timesteps = (timesteps[0], timesteps[0])
|
| 226 |
+
metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}"
|
| 227 |
+
else:
|
| 228 |
+
del metadata["modelspec.timestep_range"]
|
| 229 |
+
|
| 230 |
+
if clip_skip is not None:
|
| 231 |
+
metadata["modelspec.encoder_layer"] = f"{clip_skip}"
|
| 232 |
+
else:
|
| 233 |
+
del metadata["modelspec.encoder_layer"]
|
| 234 |
+
|
| 235 |
+
# # assert all values are filled
|
| 236 |
+
# assert all([v is not None for v in metadata.values()]), metadata
|
| 237 |
+
if not all([v is not None for v in metadata.values()]):
|
| 238 |
+
logger.error(f"Internal error: some metadata values are None: {metadata}")
|
| 239 |
+
|
| 240 |
+
return metadata
|
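A hedged usage sketch of `build_metadata` for an SDXL LoRA follows; the state dict and file name are hypothetical, and the metadata is attached via `safetensors.torch.save_file`, which accepts a string-to-string metadata dict.

```python
import time
import torch
from safetensors.torch import save_file
from library import sai_model_spec

state_dict = {"lora.weight": torch.zeros(4, 4)}  # stand-in for a trained SDXL LoRA state dict

metadata = sai_model_spec.build_metadata(
    None,                # skip hash calculation
    False, False, True,  # v2, v_parameterization, sdxl
    True, False,         # lora, textual_inversion
    time.time(),
    title="my-example-lora",
    reso=1024,
)
save_file(state_dict, "my-example-lora.safetensors", metadata=metadata)
```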
| 241 |
+
|
| 242 |
+
|
| 243 |
+
# region utils
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def get_title(metadata: dict) -> Optional[str]:
|
| 247 |
+
return metadata.get(MODELSPEC_TITLE, None)
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def load_metadata_from_safetensors(model: str) -> dict:
|
| 251 |
+
if not model.endswith(".safetensors"):
|
| 252 |
+
return {}
|
| 253 |
+
|
| 254 |
+
with safetensors.safe_open(model, framework="pt") as f:
|
| 255 |
+
metadata = f.metadata()
|
| 256 |
+
if metadata is None:
|
| 257 |
+
metadata = {}
|
| 258 |
+
return metadata
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def build_merged_from(models: List[str]) -> str:
|
| 262 |
+
def get_title(model: str):
|
| 263 |
+
metadata = load_metadata_from_safetensors(model)
|
| 264 |
+
title = metadata.get(MODELSPEC_TITLE, None)
|
| 265 |
+
if title is None:
|
| 266 |
+
title = os.path.splitext(os.path.basename(model))[0] # use filename
|
| 267 |
+
return title
|
| 268 |
+
|
| 269 |
+
titles = [get_title(model) for model in models]
|
| 270 |
+
return ", ".join(titles)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
# endregion
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
r"""
|
| 277 |
+
if __name__ == "__main__":
|
| 278 |
+
import argparse
|
| 279 |
+
import torch
|
| 280 |
+
from safetensors.torch import load_file
|
| 281 |
+
from library import train_util
|
| 282 |
+
|
| 283 |
+
parser = argparse.ArgumentParser()
|
| 284 |
+
parser.add_argument("--ckpt", type=str, required=True)
|
| 285 |
+
args = parser.parse_args()
|
| 286 |
+
|
| 287 |
+
print(f"Loading {args.ckpt}")
|
| 288 |
+
state_dict = load_file(args.ckpt)
|
| 289 |
+
|
| 290 |
+
print(f"Calculating metadata")
|
| 291 |
+
metadata = get(state_dict, False, False, False, False, "sgm", False, False, "title", "date", 256, 1000, 0)
|
| 292 |
+
print(metadata)
|
| 293 |
+
del state_dict
|
| 294 |
+
|
| 295 |
+
# by reference implementation
|
| 296 |
+
with open(args.ckpt, mode="rb") as file_data:
|
| 297 |
+
file_hash = hashlib.sha256()
|
| 298 |
+
head_len = struct.unpack("Q", file_data.read(8)) # int64 header length prefix
|
| 299 |
+
header = json.loads(file_data.read(head_len[0])) # header itself, json string
|
| 300 |
+
content = (
|
| 301 |
+
file_data.read()
|
| 302 |
+
) # All other content is tightly packed tensors. Copy to RAM for simplicity, but you can avoid this read with a more careful FS-dependent impl.
|
| 303 |
+
file_hash.update(content)
|
| 304 |
+
# ===== Update the hash for modelspec =====
|
| 305 |
+
by_ref = f"0x{file_hash.hexdigest()}"
|
| 306 |
+
print(by_ref)
|
| 307 |
+
print("is same?", by_ref == metadata["modelspec.hash_sha256"])
|
| 308 |
+
|
| 309 |
+
"""
|
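To round the module off, here is a short sketch (with placeholder file names) of reading the metadata back and of recording provenance for a merge with `build_merged_from`.

```python
from library import sai_model_spec

meta = sai_model_spec.load_metadata_from_safetensors("my-example-lora.safetensors")  # placeholder path
print(meta.get("modelspec.title"), meta.get("modelspec.resolution"))

# comma-separated titles (or file names, when no title is stored) of the merge sources
print(sai_model_spec.build_merged_from(["base_sdxl.safetensors", "my-example-lora.safetensors"]))
```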
library/sdxl_lpw_stable_diffusion.py
ADDED
|
@@ -0,0 +1,1347 @@
| 1 |
+
# copy from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py
|
| 2 |
+
# and modify to support SD2.x
|
| 3 |
+
|
| 4 |
+
import inspect
|
| 5 |
+
import re
|
| 6 |
+
from typing import Callable, List, Optional, Union
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import PIL.Image
|
| 10 |
+
import torch
|
| 11 |
+
from packaging import version
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
| 14 |
+
|
| 15 |
+
from diffusers import SchedulerMixin, StableDiffusionPipeline
|
| 16 |
+
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
| 17 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
|
| 18 |
+
from diffusers.utils import logging
|
| 19 |
+
from PIL import Image
|
| 20 |
+
|
| 21 |
+
from library import sdxl_model_util, sdxl_train_util, train_util
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
from diffusers.utils import PIL_INTERPOLATION
|
| 26 |
+
except ImportError:
|
| 27 |
+
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
| 28 |
+
PIL_INTERPOLATION = {
|
| 29 |
+
"linear": PIL.Image.Resampling.BILINEAR,
|
| 30 |
+
"bilinear": PIL.Image.Resampling.BILINEAR,
|
| 31 |
+
"bicubic": PIL.Image.Resampling.BICUBIC,
|
| 32 |
+
"lanczos": PIL.Image.Resampling.LANCZOS,
|
| 33 |
+
"nearest": PIL.Image.Resampling.NEAREST,
|
| 34 |
+
}
|
| 35 |
+
else:
|
| 36 |
+
PIL_INTERPOLATION = {
|
| 37 |
+
"linear": PIL.Image.LINEAR,
|
| 38 |
+
"bilinear": PIL.Image.BILINEAR,
|
| 39 |
+
"bicubic": PIL.Image.BICUBIC,
|
| 40 |
+
"lanczos": PIL.Image.LANCZOS,
|
| 41 |
+
"nearest": PIL.Image.NEAREST,
|
| 42 |
+
}
|
| 43 |
+
# ------------------------------------------------------------------------------
|
| 44 |
+
|
| 45 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 46 |
+
|
| 47 |
+
re_attention = re.compile(
|
| 48 |
+
r"""
|
| 49 |
+
\\\(|
|
| 50 |
+
\\\)|
|
| 51 |
+
\\\[|
|
| 52 |
+
\\]|
|
| 53 |
+
\\\\|
|
| 54 |
+
\\|
|
| 55 |
+
\(|
|
| 56 |
+
\[|
|
| 57 |
+
:([+-]?[.\d]+)\)|
|
| 58 |
+
\)|
|
| 59 |
+
]|
|
| 60 |
+
[^\\()\[\]:]+|
|
| 61 |
+
:
|
| 62 |
+
""",
|
| 63 |
+
re.X,
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def parse_prompt_attention(text):
|
| 68 |
+
"""
|
| 69 |
+
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
|
| 70 |
+
Accepted tokens are:
|
| 71 |
+
(abc) - increases attention to abc by a multiplier of 1.1
|
| 72 |
+
(abc:3.12) - increases attention to abc by a multiplier of 3.12
|
| 73 |
+
[abc] - decreases attention to abc by a multiplier of 1.1
|
| 74 |
+
\( - literal character '('
|
| 75 |
+
\[ - literal character '['
|
| 76 |
+
\) - literal character ')'
|
| 77 |
+
\] - literal character ']'
|
| 78 |
+
\\ - literal character '\'
|
| 79 |
+
anything else - just text
|
| 80 |
+
>>> parse_prompt_attention('normal text')
|
| 81 |
+
[['normal text', 1.0]]
|
| 82 |
+
>>> parse_prompt_attention('an (important) word')
|
| 83 |
+
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
|
| 84 |
+
>>> parse_prompt_attention('(unbalanced')
|
| 85 |
+
[['unbalanced', 1.1]]
|
| 86 |
+
>>> parse_prompt_attention('\(literal\]')
|
| 87 |
+
[['(literal]', 1.0]]
|
| 88 |
+
>>> parse_prompt_attention('(unnecessary)(parens)')
|
| 89 |
+
[['unnecessaryparens', 1.1]]
|
| 90 |
+
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
|
| 91 |
+
[['a ', 1.0],
|
| 92 |
+
['house', 1.5730000000000004],
|
| 93 |
+
[' ', 1.1],
|
| 94 |
+
['on', 1.0],
|
| 95 |
+
[' a ', 1.1],
|
| 96 |
+
['hill', 0.55],
|
| 97 |
+
[', sun, ', 1.1],
|
| 98 |
+
['sky', 1.4641000000000006],
|
| 99 |
+
['.', 1.1]]
|
| 100 |
+
"""
|
| 101 |
+
|
| 102 |
+
res = []
|
| 103 |
+
round_brackets = []
|
| 104 |
+
square_brackets = []
|
| 105 |
+
|
| 106 |
+
round_bracket_multiplier = 1.1
|
| 107 |
+
square_bracket_multiplier = 1 / 1.1
|
| 108 |
+
|
| 109 |
+
def multiply_range(start_position, multiplier):
|
| 110 |
+
for p in range(start_position, len(res)):
|
| 111 |
+
res[p][1] *= multiplier
|
| 112 |
+
|
| 113 |
+
for m in re_attention.finditer(text):
|
| 114 |
+
text = m.group(0)
|
| 115 |
+
weight = m.group(1)
|
| 116 |
+
|
| 117 |
+
if text.startswith("\\"):
|
| 118 |
+
res.append([text[1:], 1.0])
|
| 119 |
+
elif text == "(":
|
| 120 |
+
round_brackets.append(len(res))
|
| 121 |
+
elif text == "[":
|
| 122 |
+
square_brackets.append(len(res))
|
| 123 |
+
elif weight is not None and len(round_brackets) > 0:
|
| 124 |
+
multiply_range(round_brackets.pop(), float(weight))
|
| 125 |
+
elif text == ")" and len(round_brackets) > 0:
|
| 126 |
+
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
| 127 |
+
elif text == "]" and len(square_brackets) > 0:
|
| 128 |
+
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
| 129 |
+
else:
|
| 130 |
+
res.append([text, 1.0])
|
| 131 |
+
|
| 132 |
+
for pos in round_brackets:
|
| 133 |
+
multiply_range(pos, round_bracket_multiplier)
|
| 134 |
+
|
| 135 |
+
for pos in square_brackets:
|
| 136 |
+
multiply_range(pos, square_bracket_multiplier)
|
| 137 |
+
|
| 138 |
+
if len(res) == 0:
|
| 139 |
+
res = [["", 1.0]]
|
| 140 |
+
|
| 141 |
+
# merge runs of identical weights
|
| 142 |
+
i = 0
|
| 143 |
+
while i + 1 < len(res):
|
| 144 |
+
if res[i][1] == res[i + 1][1]:
|
| 145 |
+
res[i][0] += res[i + 1][0]
|
| 146 |
+
res.pop(i + 1)
|
| 147 |
+
else:
|
| 148 |
+
i += 1
|
| 149 |
+
|
| 150 |
+
return res
|
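As the doctest above shows, an explicit weight multiplies with every enclosing group: `(((house:1.3))` ends up at 1.3 * 1.1 * 1.1 ≈ 1.573. A quick check of that arithmetic, assuming `parse_prompt_attention` from this module is in scope:

```python
pairs = parse_prompt_attention("a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).")
house_weight = dict(map(tuple, pairs))["house"]
assert abs(house_weight - 1.3 * 1.1 * 1.1) < 1e-6  # explicit weight times two enclosing round brackets
```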
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
|
| 154 |
+
r"""
|
| 155 |
+
Tokenize a list of prompts and return its tokens with weights of each token.
|
| 156 |
+
|
| 157 |
+
No padding, starting or ending token is included.
|
| 158 |
+
"""
|
| 159 |
+
tokens = []
|
| 160 |
+
weights = []
|
| 161 |
+
truncated = False
|
| 162 |
+
for text in prompt:
|
| 163 |
+
texts_and_weights = parse_prompt_attention(text)
|
| 164 |
+
text_token = []
|
| 165 |
+
text_weight = []
|
| 166 |
+
for word, weight in texts_and_weights:
|
| 167 |
+
# tokenize and discard the starting and the ending token
|
| 168 |
+
token = pipe.tokenizer(word).input_ids[1:-1]
|
| 169 |
+
text_token += token
|
| 170 |
+
# copy the weight by length of token
|
| 171 |
+
text_weight += [weight] * len(token)
|
| 172 |
+
# stop if the text is too long (longer than truncation limit)
|
| 173 |
+
if len(text_token) > max_length:
|
| 174 |
+
truncated = True
|
| 175 |
+
break
|
| 176 |
+
# truncate
|
| 177 |
+
if len(text_token) > max_length:
|
| 178 |
+
truncated = True
|
| 179 |
+
text_token = text_token[:max_length]
|
| 180 |
+
text_weight = text_weight[:max_length]
|
| 181 |
+
tokens.append(text_token)
|
| 182 |
+
weights.append(text_weight)
|
| 183 |
+
if truncated:
|
| 184 |
+
logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
|
| 185 |
+
return tokens, weights
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
|
| 189 |
+
r"""
|
| 190 |
+
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
|
| 191 |
+
"""
|
| 192 |
+
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
|
| 193 |
+
weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
|
| 194 |
+
for i in range(len(tokens)):
|
| 195 |
+
tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i]))
|
| 196 |
+
if no_boseos_middle:
|
| 197 |
+
weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
|
| 198 |
+
else:
|
| 199 |
+
w = []
|
| 200 |
+
if len(weights[i]) == 0:
|
| 201 |
+
w = [1.0] * weights_length
|
| 202 |
+
else:
|
| 203 |
+
for j in range(max_embeddings_multiples):
|
| 204 |
+
w.append(1.0) # weight for starting token in this chunk
|
| 205 |
+
w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
|
| 206 |
+
w.append(1.0) # weight for ending token in this chunk
|
| 207 |
+
w += [1.0] * (weights_length - len(w))
|
| 208 |
+
weights[i] = w[:]
|
| 209 |
+
|
| 210 |
+
return tokens, weights
|
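The padding above works on CLIP windows of `chunk_length` 77: each window carries 75 payload tokens plus BOS/EOS, so `max_embeddings_multiples = 3` gives a padded length of 3 * 75 + 2 = 227. A small illustration, with made-up token ids and SD1.x-style special tokens (an assumption for the example, where pad == eos):

```python
chunk_length = 77
max_embeddings_multiples = 3
max_length = (chunk_length - 2) * max_embeddings_multiples + 2  # 227

bos, eos, pad = 49406, 49407, 49407  # CLIP tokenizer ids where pad == eos (the "v1" branch above)
tokens, weights = pad_tokens_and_weights(
    [[320, 1125]], [[1.0, 1.3]], max_length, bos, eos, pad,
    no_boseos_middle=True, chunk_length=chunk_length,
)
print(len(tokens[0]), len(weights[0]))  # 227 227
```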
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def get_hidden_states(text_encoder, input_ids, is_sdxl_text_encoder2: bool, eos_token_id, device):
|
| 214 |
+
if not is_sdxl_text_encoder2:
|
| 215 |
+
# text_encoder1: same as SD1/2
|
| 216 |
+
enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True)
|
| 217 |
+
hidden_states = enc_out["hidden_states"][11]
|
| 218 |
+
pool = None
|
| 219 |
+
else:
|
| 220 |
+
# text_encoder2
|
| 221 |
+
enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True)
|
| 222 |
+
hidden_states = enc_out["hidden_states"][-2] # penuultimate layer
|
| 223 |
+
# pool = enc_out["text_embeds"]
|
| 224 |
+
pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], input_ids, eos_token_id)
|
| 225 |
+
hidden_states = hidden_states.to(device)
|
| 226 |
+
if pool is not None:
|
| 227 |
+
pool = pool.to(device)
|
| 228 |
+
return hidden_states, pool
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def get_unweighted_text_embeddings(
|
| 232 |
+
pipe: StableDiffusionPipeline,
|
| 233 |
+
text_input: torch.Tensor,
|
| 234 |
+
chunk_length: int,
|
| 235 |
+
clip_skip: int,
|
| 236 |
+
eos: int,
|
| 237 |
+
pad: int,
|
| 238 |
+
is_sdxl_text_encoder2: bool,
|
| 239 |
+
no_boseos_middle: Optional[bool] = True,
|
| 240 |
+
):
|
| 241 |
+
"""
|
| 242 |
+
When the tokens are longer than the capacity of the text encoder,
|
| 243 |
+
they should be split into chunks and sent to the text encoder individually.
|
| 244 |
+
"""
|
| 245 |
+
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
|
| 246 |
+
text_pool = None
|
| 247 |
+
if max_embeddings_multiples > 1:
|
| 248 |
+
text_embeddings = []
|
| 249 |
+
for i in range(max_embeddings_multiples):
|
| 250 |
+
# extract the i-th chunk
|
| 251 |
+
text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
|
| 252 |
+
|
| 253 |
+
# cover the head and the tail by the starting and the ending tokens
|
| 254 |
+
text_input_chunk[:, 0] = text_input[0, 0]
|
| 255 |
+
if pad == eos: # v1
|
| 256 |
+
text_input_chunk[:, -1] = text_input[0, -1]
|
| 257 |
+
else: # v2
|
| 258 |
+
for j in range(len(text_input_chunk)):
|
| 259 |
+
if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad:  # the chunk ends with a normal token
|
| 260 |
+
text_input_chunk[j, -1] = eos
|
| 261 |
+
if text_input_chunk[j, 1] == pad:  # only BOS, the rest is PAD
|
| 262 |
+
text_input_chunk[j, 1] = eos
|
| 263 |
+
|
| 264 |
+
text_embedding, current_text_pool = get_hidden_states(
|
| 265 |
+
pipe.text_encoder, text_input_chunk, is_sdxl_text_encoder2, eos, pipe.device
|
| 266 |
+
)
|
| 267 |
+
if text_pool is None:
|
| 268 |
+
text_pool = current_text_pool
|
| 269 |
+
|
| 270 |
+
if no_boseos_middle:
|
| 271 |
+
if i == 0:
|
| 272 |
+
# discard the ending token
|
| 273 |
+
text_embedding = text_embedding[:, :-1]
|
| 274 |
+
elif i == max_embeddings_multiples - 1:
|
| 275 |
+
# discard the starting token
|
| 276 |
+
text_embedding = text_embedding[:, 1:]
|
| 277 |
+
else:
|
| 278 |
+
# discard both starting and ending tokens
|
| 279 |
+
text_embedding = text_embedding[:, 1:-1]
|
| 280 |
+
|
| 281 |
+
text_embeddings.append(text_embedding)
|
| 282 |
+
text_embeddings = torch.concat(text_embeddings, axis=1)
|
| 283 |
+
else:
|
| 284 |
+
text_embeddings, text_pool = get_hidden_states(pipe.text_encoder, text_input, is_sdxl_text_encoder2, eos, pipe.device)
|
| 285 |
+
return text_embeddings, text_pool
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def get_weighted_text_embeddings(
|
| 289 |
+
pipe, # : SdxlStableDiffusionLongPromptWeightingPipeline,
|
| 290 |
+
prompt: Union[str, List[str]],
|
| 291 |
+
uncond_prompt: Optional[Union[str, List[str]]] = None,
|
| 292 |
+
max_embeddings_multiples: Optional[int] = 3,
|
| 293 |
+
no_boseos_middle: Optional[bool] = False,
|
| 294 |
+
skip_parsing: Optional[bool] = False,
|
| 295 |
+
skip_weighting: Optional[bool] = False,
|
| 296 |
+
clip_skip=None,
|
| 297 |
+
is_sdxl_text_encoder2=False,
|
| 298 |
+
):
|
| 299 |
+
r"""
|
| 300 |
+
Prompts can be assigned with local weights using brackets. For example,
|
| 301 |
+
prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
|
| 302 |
+
and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
|
| 303 |
+
|
| 304 |
+
Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
|
| 305 |
+
|
| 306 |
+
Args:
|
| 307 |
+
pipe (`StableDiffusionPipeline`):
|
| 308 |
+
Pipe to provide access to the tokenizer and the text encoder.
|
| 309 |
+
prompt (`str` or `List[str]`):
|
| 310 |
+
The prompt or prompts to guide the image generation.
|
| 311 |
+
uncond_prompt (`str` or `List[str]`):
|
| 312 |
+
The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
|
| 313 |
+
is provided, the embeddings of prompt and uncond_prompt are concatenated.
|
| 314 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
| 315 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
| 316 |
+
no_boseos_middle (`bool`, *optional*, defaults to `False`):
|
| 317 |
+
If the text tokens span multiple chunks of the text encoder capacity, whether to keep the starting and
|
| 318 |
+
ending tokens in each of the middle chunks.
|
| 319 |
+
skip_parsing (`bool`, *optional*, defaults to `False`):
|
| 320 |
+
Skip the parsing of brackets.
|
| 321 |
+
skip_weighting (`bool`, *optional*, defaults to `False`):
|
| 322 |
+
Skip the weighting. When the parsing is skipped, it is forced True.
|
| 323 |
+
"""
|
| 324 |
+
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
| 325 |
+
if isinstance(prompt, str):
|
| 326 |
+
prompt = [prompt]
|
| 327 |
+
|
| 328 |
+
if not skip_parsing:
|
| 329 |
+
prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
|
| 330 |
+
if uncond_prompt is not None:
|
| 331 |
+
if isinstance(uncond_prompt, str):
|
| 332 |
+
uncond_prompt = [uncond_prompt]
|
| 333 |
+
uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
|
| 334 |
+
else:
|
| 335 |
+
prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
|
| 336 |
+
prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
|
| 337 |
+
if uncond_prompt is not None:
|
| 338 |
+
if isinstance(uncond_prompt, str):
|
| 339 |
+
uncond_prompt = [uncond_prompt]
|
| 340 |
+
uncond_tokens = [
|
| 341 |
+
token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
|
| 342 |
+
]
|
| 343 |
+
uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
|
| 344 |
+
|
| 345 |
+
# round up the longest length of tokens to a multiple of (model_max_length - 2)
|
| 346 |
+
max_length = max([len(token) for token in prompt_tokens])
|
| 347 |
+
if uncond_prompt is not None:
|
| 348 |
+
max_length = max(max_length, max([len(token) for token in uncond_tokens]))
|
| 349 |
+
|
| 350 |
+
max_embeddings_multiples = min(
|
| 351 |
+
max_embeddings_multiples,
|
| 352 |
+
(max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
|
| 353 |
+
)
|
| 354 |
+
max_embeddings_multiples = max(1, max_embeddings_multiples)
|
| 355 |
+
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
| 356 |
+
|
| 357 |
+
# pad the length of tokens and weights
|
| 358 |
+
bos = pipe.tokenizer.bos_token_id
|
| 359 |
+
eos = pipe.tokenizer.eos_token_id
|
| 360 |
+
pad = pipe.tokenizer.pad_token_id
|
| 361 |
+
prompt_tokens, prompt_weights = pad_tokens_and_weights(
|
| 362 |
+
prompt_tokens,
|
| 363 |
+
prompt_weights,
|
| 364 |
+
max_length,
|
| 365 |
+
bos,
|
| 366 |
+
eos,
|
| 367 |
+
pad,
|
| 368 |
+
no_boseos_middle=no_boseos_middle,
|
| 369 |
+
chunk_length=pipe.tokenizer.model_max_length,
|
| 370 |
+
)
|
| 371 |
+
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
|
| 372 |
+
if uncond_prompt is not None:
|
| 373 |
+
uncond_tokens, uncond_weights = pad_tokens_and_weights(
|
| 374 |
+
uncond_tokens,
|
| 375 |
+
uncond_weights,
|
| 376 |
+
max_length,
|
| 377 |
+
bos,
|
| 378 |
+
eos,
|
| 379 |
+
pad,
|
| 380 |
+
no_boseos_middle=no_boseos_middle,
|
| 381 |
+
chunk_length=pipe.tokenizer.model_max_length,
|
| 382 |
+
)
|
| 383 |
+
uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
|
| 384 |
+
|
| 385 |
+
# get the embeddings
|
| 386 |
+
text_embeddings, text_pool = get_unweighted_text_embeddings(
|
| 387 |
+
pipe,
|
| 388 |
+
prompt_tokens,
|
| 389 |
+
pipe.tokenizer.model_max_length,
|
| 390 |
+
clip_skip,
|
| 391 |
+
eos,
|
| 392 |
+
pad,
|
| 393 |
+
is_sdxl_text_encoder2,
|
| 394 |
+
no_boseos_middle=no_boseos_middle,
|
| 395 |
+
)
|
| 396 |
+
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
|
| 397 |
+
|
| 398 |
+
if uncond_prompt is not None:
|
| 399 |
+
uncond_embeddings, uncond_pool = get_unweighted_text_embeddings(
|
| 400 |
+
pipe,
|
| 401 |
+
uncond_tokens,
|
| 402 |
+
pipe.tokenizer.model_max_length,
|
| 403 |
+
clip_skip,
|
| 404 |
+
eos,
|
| 405 |
+
pad,
|
| 406 |
+
is_sdxl_text_encoder2,
|
| 407 |
+
no_boseos_middle=no_boseos_middle,
|
| 408 |
+
)
|
| 409 |
+
uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
|
| 410 |
+
|
| 411 |
+
# assign weights to the prompts and normalize in the sense of mean
|
| 412 |
+
# TODO: should we normalize by chunk or in a whole (current implementation)?
|
| 413 |
+
if (not skip_parsing) and (not skip_weighting):
|
| 414 |
+
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
| 415 |
+
text_embeddings *= prompt_weights.unsqueeze(-1)
|
| 416 |
+
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
| 417 |
+
text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
| 418 |
+
if uncond_prompt is not None:
|
| 419 |
+
previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
| 420 |
+
uncond_embeddings *= uncond_weights.unsqueeze(-1)
|
| 421 |
+
current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
| 422 |
+
uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
| 423 |
+
|
| 424 |
+
if uncond_prompt is not None:
|
| 425 |
+
return text_embeddings, text_pool, uncond_embeddings, uncond_pool
|
| 426 |
+
return text_embeddings, text_pool, None, None
|
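A hedged usage sketch: assuming `pipe` is an instance of the `SdxlStableDiffusionLongPromptWeightingPipeline` defined later in this file (so its tokenizer and text encoder are already wired up), weighted conditional and unconditional embeddings can be obtained as follows.

```python
text_embeddings, text_pool, uncond_embeddings, uncond_pool = get_weighted_text_embeddings(
    pipe,                                                    # assumed pipeline instance
    prompt="a (very beautiful:1.2) landscape, (watercolor)",
    uncond_prompt="lowres, blurry",
    max_embeddings_multiples=3,
    is_sdxl_text_encoder2=True,                              # also returns the pooled output
)
print(text_embeddings.shape)  # (1, sequence_length, hidden_dim)
print(text_pool.shape)        # (1, projection_dim) pooled output from text encoder 2
```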
| 427 |
+
|
| 428 |
+
|
| 429 |
+
def preprocess_image(image):
|
| 430 |
+
w, h = image.size
|
| 431 |
+
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
| 432 |
+
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
|
| 433 |
+
image = np.array(image).astype(np.float32) / 255.0
|
| 434 |
+
image = image[None].transpose(0, 3, 1, 2)
|
| 435 |
+
image = torch.from_numpy(image)
|
| 436 |
+
return 2.0 * image - 1.0
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
def preprocess_mask(mask, scale_factor=8):
|
| 440 |
+
mask = mask.convert("L")
|
| 441 |
+
w, h = mask.size
|
| 442 |
+
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
| 443 |
+
mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
|
| 444 |
+
mask = np.array(mask).astype(np.float32) / 255.0
|
| 445 |
+
mask = np.tile(mask, (4, 1, 1))
|
| 446 |
+
mask = mask[None].transpose(0, 1, 2, 3)  # add a batch dimension; the identity transpose is a no-op kept from the original implementation
|
| 447 |
+
mask = 1 - mask # repaint white, keep black
|
| 448 |
+
mask = torch.from_numpy(mask)
|
| 449 |
+
return mask
|
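A short sketch of the two preprocessors above used together for inpainting (paths are placeholders): the init image becomes a (1, 3, H, W) tensor in [-1, 1], and the mask becomes a (1, 4, H//8, W//8) tensor where 1 means keep and 0 means repaint.

```python
from PIL import Image

init_image = Image.open("init.png").convert("RGB")  # placeholder path
mask_image = Image.open("mask.png")                 # white = repaint, black = keep

image_tensor = preprocess_image(init_image)
mask_tensor = preprocess_mask(mask_image)           # scale_factor defaults to 8
print(image_tensor.shape, mask_tensor.shape)
```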
| 450 |
+
|
| 451 |
+
|
| 452 |
+
def prepare_controlnet_image(
|
| 453 |
+
image: PIL.Image.Image,
|
| 454 |
+
width: int,
|
| 455 |
+
height: int,
|
| 456 |
+
batch_size: int,
|
| 457 |
+
num_images_per_prompt: int,
|
| 458 |
+
device: torch.device,
|
| 459 |
+
dtype: torch.dtype,
|
| 460 |
+
do_classifier_free_guidance: bool = False,
|
| 461 |
+
guess_mode: bool = False,
|
| 462 |
+
):
|
| 463 |
+
if not isinstance(image, torch.Tensor):
|
| 464 |
+
if isinstance(image, PIL.Image.Image):
|
| 465 |
+
image = [image]
|
| 466 |
+
|
| 467 |
+
if isinstance(image[0], PIL.Image.Image):
|
| 468 |
+
images = []
|
| 469 |
+
|
| 470 |
+
for image_ in image:
|
| 471 |
+
image_ = image_.convert("RGB")
|
| 472 |
+
image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
|
| 473 |
+
image_ = np.array(image_)
|
| 474 |
+
image_ = image_[None, :]
|
| 475 |
+
images.append(image_)
|
| 476 |
+
|
| 477 |
+
image = images
|
| 478 |
+
|
| 479 |
+
image = np.concatenate(image, axis=0)
|
| 480 |
+
image = np.array(image).astype(np.float32) / 255.0
|
| 481 |
+
image = image.transpose(0, 3, 1, 2)
|
| 482 |
+
image = torch.from_numpy(image)
|
| 483 |
+
elif isinstance(image[0], torch.Tensor):
|
| 484 |
+
image = torch.cat(image, dim=0)
|
| 485 |
+
|
| 486 |
+
image_batch_size = image.shape[0]
|
| 487 |
+
|
| 488 |
+
if image_batch_size == 1:
|
| 489 |
+
repeat_by = batch_size
|
| 490 |
+
else:
|
| 491 |
+
# image batch size is the same as prompt batch size
|
| 492 |
+
repeat_by = num_images_per_prompt
|
| 493 |
+
|
| 494 |
+
image = image.repeat_interleave(repeat_by, dim=0)
|
| 495 |
+
|
| 496 |
+
image = image.to(device=device, dtype=dtype)
|
| 497 |
+
|
| 498 |
+
if do_classifier_free_guidance and not guess_mode:
|
| 499 |
+
image = torch.cat([image] * 2)
|
| 500 |
+
|
| 501 |
+
return image
|
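And a hedged sketch of `prepare_controlnet_image` for a single prompt with classifier-free guidance; the hint image path and the CUDA device are assumptions for the example. With CFG enabled (and guess mode off) the hint is duplicated, so the result has batch size 2.

```python
import PIL.Image
import torch

hint = PIL.Image.open("canny_hint.png")  # placeholder ControlNet hint image
controlnet_cond = prepare_controlnet_image(
    hint, width=1024, height=1024,
    batch_size=1, num_images_per_prompt=1,
    device=torch.device("cuda"), dtype=torch.float16,
    do_classifier_free_guidance=True,
)
print(controlnet_cond.shape)  # torch.Size([2, 3, 1024, 1024])
```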
| 502 |
+
|
| 503 |
+
|
| 504 |
+
class SdxlStableDiffusionLongPromptWeightingPipeline:
|
| 505 |
+
r"""
|
| 506 |
+
Pipeline for text-to-image generation using Stable Diffusion without a token length limit, with support for parsing
|
| 507 |
+
weighting in prompt.
|
| 508 |
+
|
| 509 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 510 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 511 |
+
|
| 512 |
+
Args:
|
| 513 |
+
vae ([`AutoencoderKL`]):
|
| 514 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
| 515 |
+
text_encoder ([`CLIPTextModel`]):
|
| 516 |
+
Frozen text-encoder. Stable Diffusion uses the text portion of
|
| 517 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
| 518 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
| 519 |
+
tokenizer (`CLIPTokenizer`):
|
| 520 |
+
Tokenizer of class
|
| 521 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
| 522 |
+
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
| 523 |
+
scheduler ([`SchedulerMixin`]):
|
| 524 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
| 525 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
| 526 |
+
safety_checker ([`StableDiffusionSafetyChecker`]):
|
| 527 |
+
Classification module that estimates whether generated images could be considered offensive or harmful.
|
| 528 |
+
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
|
| 529 |
+
feature_extractor ([`CLIPFeatureExtractor`]):
|
| 530 |
+
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
| 531 |
+
"""
|
| 532 |
+
|
| 533 |
+
# if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
|
| 534 |
+
|
| 535 |
+
def __init__(
|
| 536 |
+
self,
|
| 537 |
+
vae: AutoencoderKL,
|
| 538 |
+
text_encoder: List[CLIPTextModel],
|
| 539 |
+
tokenizer: List[CLIPTokenizer],
|
| 540 |
+
unet: UNet2DConditionModel,
|
| 541 |
+
scheduler: SchedulerMixin,
|
| 542 |
+
# clip_skip: int,
|
| 543 |
+
safety_checker: StableDiffusionSafetyChecker,
|
| 544 |
+
feature_extractor: CLIPFeatureExtractor,
|
| 545 |
+
requires_safety_checker: bool = True,
|
| 546 |
+
clip_skip: int = 1,
|
| 547 |
+
):
|
| 548 |
+
# clip skip is ignored currently
|
| 549 |
+
self.tokenizer = tokenizer[0]
|
| 550 |
+
self.text_encoder = text_encoder[0]
|
| 551 |
+
self.unet = unet
|
| 552 |
+
self.scheduler = scheduler
|
| 553 |
+
self.safety_checker = safety_checker
|
| 554 |
+
self.feature_extractor = feature_extractor
|
| 555 |
+
self.requires_safety_checker = requires_safety_checker
|
| 556 |
+
self.vae = vae
|
| 557 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
| 558 |
+
self.progress_bar = lambda x: tqdm(x, leave=False)
|
| 559 |
+
|
| 560 |
+
self.clip_skip = clip_skip
|
| 561 |
+
self.tokenizers = tokenizer
|
| 562 |
+
self.text_encoders = text_encoder
|
| 563 |
+
|
| 564 |
+
# self.__init__additional__()
|
| 565 |
+
|
| 566 |
+
# def __init__additional__(self):
|
| 567 |
+
# if not hasattr(self, "vae_scale_factor"):
|
| 568 |
+
# setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
|
| 569 |
+
|
| 570 |
+
def to(self, device=None, dtype=None):
|
| 571 |
+
if device is not None:
|
| 572 |
+
self.device = device
|
| 573 |
+
# self.vae.to(device=self.device)
|
| 574 |
+
if dtype is not None:
|
| 575 |
+
self.dtype = dtype
|
| 576 |
+
|
| 577 |
+
# do not move Text Encoders to device, because Text Encoder should be on CPU
|
| 578 |
+
|
| 579 |
+
@property
|
| 580 |
+
def _execution_device(self):
|
| 581 |
+
r"""
|
| 582 |
+
Returns the device on which the pipeline's models will be executed. After calling
|
| 583 |
+
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
| 584 |
+
hooks.
|
| 585 |
+
"""
|
| 586 |
+
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
|
| 587 |
+
return self.device
|
| 588 |
+
for module in self.unet.modules():
|
| 589 |
+
if (
|
| 590 |
+
hasattr(module, "_hf_hook")
|
| 591 |
+
and hasattr(module._hf_hook, "execution_device")
|
| 592 |
+
and module._hf_hook.execution_device is not None
|
| 593 |
+
):
|
| 594 |
+
return torch.device(module._hf_hook.execution_device)
|
| 595 |
+
return self.device
|
| 596 |
+
|
| 597 |
+
def _encode_prompt(
|
| 598 |
+
self,
|
| 599 |
+
prompt,
|
| 600 |
+
device,
|
| 601 |
+
num_images_per_prompt,
|
| 602 |
+
do_classifier_free_guidance,
|
| 603 |
+
negative_prompt,
|
| 604 |
+
max_embeddings_multiples,
|
| 605 |
+
is_sdxl_text_encoder2,
|
| 606 |
+
):
|
| 607 |
+
r"""
|
| 608 |
+
Encodes the prompt into text encoder hidden states.
|
| 609 |
+
|
| 610 |
+
Args:
|
| 611 |
+
prompt (`str` or `list(int)`):
|
| 612 |
+
prompt to be encoded
|
| 613 |
+
device: (`torch.device`):
|
| 614 |
+
torch device
|
| 615 |
+
num_images_per_prompt (`int`):
|
| 616 |
+
number of images that should be generated per prompt
|
| 617 |
+
do_classifier_free_guidance (`bool`):
|
| 618 |
+
whether to use classifier free guidance or not
|
| 619 |
+
negative_prompt (`str` or `List[str]`):
|
| 620 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
| 621 |
+
if `guidance_scale` is less than `1`).
|
| 622 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
| 623 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
| 624 |
+
"""
|
| 625 |
+
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
| 626 |
+
|
| 627 |
+
if negative_prompt is None:
|
| 628 |
+
negative_prompt = [""] * batch_size
|
| 629 |
+
elif isinstance(negative_prompt, str):
|
| 630 |
+
negative_prompt = [negative_prompt] * batch_size
|
| 631 |
+
if batch_size != len(negative_prompt):
|
| 632 |
+
raise ValueError(
|
| 633 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 634 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 635 |
+
" the batch size of `prompt`."
|
| 636 |
+
)
|
| 637 |
+
|
| 638 |
+
text_embeddings, text_pool, uncond_embeddings, uncond_pool = get_weighted_text_embeddings(
|
| 639 |
+
pipe=self,
|
| 640 |
+
prompt=prompt,
|
| 641 |
+
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
|
| 642 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
| 643 |
+
clip_skip=self.clip_skip,
|
| 644 |
+
is_sdxl_text_encoder2=is_sdxl_text_encoder2,
|
| 645 |
+
)
|
| 646 |
+
bs_embed, seq_len, _ = text_embeddings.shape
|
| 647 |
+
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)  # duplicate the embeddings for each image per prompt
|
| 648 |
+
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
| 649 |
+
if text_pool is not None:
|
| 650 |
+
text_pool = text_pool.repeat(1, num_images_per_prompt)
|
| 651 |
+
text_pool = text_pool.view(bs_embed * num_images_per_prompt, -1)
|
| 652 |
+
|
| 653 |
+
if do_classifier_free_guidance:
|
| 654 |
+
bs_embed, seq_len, _ = uncond_embeddings.shape
|
| 655 |
+
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
|
| 656 |
+
uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
| 657 |
+
if uncond_pool is not None:
|
| 658 |
+
uncond_pool = uncond_pool.repeat(1, num_images_per_prompt)
|
| 659 |
+
uncond_pool = uncond_pool.view(bs_embed * num_images_per_prompt, -1)
|
| 660 |
+
|
| 661 |
+
return text_embeddings, text_pool, uncond_embeddings, uncond_pool
|
| 662 |
+
|
| 663 |
+
return text_embeddings, text_pool, None, None
|
| 664 |
+
|
| 665 |
+
def check_inputs(self, prompt, height, width, strength, callback_steps):
|
| 666 |
+
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
| 667 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 668 |
+
|
| 669 |
+
if strength < 0 or strength > 1:
|
| 670 |
+
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
| 671 |
+
|
| 672 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 673 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 674 |
+
|
| 675 |
+
if (callback_steps is None) or (
|
| 676 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
| 677 |
+
):
|
| 678 |
+
raise ValueError(
|
| 679 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}."
|
| 680 |
+
)
|
| 681 |
+
|
| 682 |
+
def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
|
| 683 |
+
if is_text2img:
|
| 684 |
+
return self.scheduler.timesteps.to(device), num_inference_steps
|
| 685 |
+
else:
|
| 686 |
+
# get the original timestep using init_timestep
|
| 687 |
+
offset = self.scheduler.config.get("steps_offset", 0)
|
| 688 |
+
init_timestep = int(num_inference_steps * strength) + offset
|
| 689 |
+
init_timestep = min(init_timestep, num_inference_steps)
|
| 690 |
+
|
| 691 |
+
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
| 692 |
+
timesteps = self.scheduler.timesteps[t_start:].to(device)
|
| 693 |
+
return timesteps, num_inference_steps - t_start
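# Illustrative worked example (assuming steps_offset == 0): on the img2img path,
# num_inference_steps=50 and strength=0.8 give
#   init_timestep = int(50 * 0.8) + 0 = 40
#   t_start       = max(50 - 40 + 0, 0) = 10
# so only the last 40 of the 50 scheduler timesteps are actually run; strength=1.0
# would run all 50 steps, which matches the `strength` docstring.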
|
| 694 |
+
|
| 695 |
+
def run_safety_checker(self, image, device, dtype):
|
| 696 |
+
if self.safety_checker is not None:
|
| 697 |
+
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
|
| 698 |
+
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype))
|
| 699 |
+
else:
|
| 700 |
+
has_nsfw_concept = None
|
| 701 |
+
return image, has_nsfw_concept
|
| 702 |
+
|
| 703 |
+
def decode_latents(self, latents):
|
| 704 |
+
with torch.no_grad():
|
| 705 |
+
latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents
|
| 706 |
+
|
| 707 |
+
# print("post_quant_conv dtype:", self.vae.post_quant_conv.weight.dtype) # torch.float32
|
| 708 |
+
# x = torch.nn.functional.conv2d(latents, self.vae.post_quant_conv.weight.detach(), stride=1, padding=0)
|
| 709 |
+
# print("latents dtype:", latents.dtype, "x dtype:", x.dtype) # torch.float32, torch.float16
|
| 710 |
+
# self.vae.to("cpu")
|
| 711 |
+
# self.vae.set_use_memory_efficient_attention_xformers(False)
|
| 712 |
+
# image = self.vae.decode(latents.to("cpu")).sample
|
| 713 |
+
|
| 714 |
+
image = self.vae.decode(latents.to(self.vae.dtype)).sample
|
| 715 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
| 716 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
| 717 |
+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
| 718 |
+
return image
|
| 719 |
+
|
| 720 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 721 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 722 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 723 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
| 724 |
+
# and should be between [0, 1]
|
| 725 |
+
|
| 726 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 727 |
+
extra_step_kwargs = {}
|
| 728 |
+
if accepts_eta:
|
| 729 |
+
extra_step_kwargs["eta"] = eta
|
| 730 |
+
|
| 731 |
+
# check if the scheduler accepts generator
|
| 732 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 733 |
+
if accepts_generator:
|
| 734 |
+
extra_step_kwargs["generator"] = generator
|
| 735 |
+
return extra_step_kwargs
|
| 736 |
+
|
| 737 |
+
def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None):
|
| 738 |
+
if image is None:
|
| 739 |
+
shape = (
|
| 740 |
+
batch_size,
|
| 741 |
+
self.unet.in_channels,
|
| 742 |
+
height // self.vae_scale_factor,
|
| 743 |
+
width // self.vae_scale_factor,
|
| 744 |
+
)
|
| 745 |
+
|
| 746 |
+
if latents is None:
|
| 747 |
+
if device.type == "mps":
|
| 748 |
+
# randn does not work reproducibly on mps
|
| 749 |
+
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
| 750 |
+
else:
|
| 751 |
+
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
| 752 |
+
else:
|
| 753 |
+
if latents.shape != shape:
|
| 754 |
+
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
| 755 |
+
latents = latents.to(device)
|
| 756 |
+
|
| 757 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 758 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 759 |
+
return latents, None, None
|
| 760 |
+
else:
|
| 761 |
+
init_latent_dist = self.vae.encode(image).latent_dist
|
| 762 |
+
init_latents = init_latent_dist.sample(generator=generator)
|
| 763 |
+
init_latents = sdxl_model_util.VAE_SCALE_FACTOR * init_latents
|
| 764 |
+
init_latents = torch.cat([init_latents] * batch_size, dim=0)
|
| 765 |
+
init_latents_orig = init_latents
|
| 766 |
+
shape = init_latents.shape
|
| 767 |
+
|
| 768 |
+
# add noise to latents using the timesteps
|
| 769 |
+
if device.type == "mps":
|
| 770 |
+
noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
| 771 |
+
else:
|
| 772 |
+
noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
| 773 |
+
latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
| 774 |
+
return latents, init_latents_orig, noise
|
| 775 |
+
|
| 776 |
+
@torch.no_grad()
|
| 777 |
+
def __call__(
|
| 778 |
+
self,
|
| 779 |
+
prompt: Union[str, List[str]],
|
| 780 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 781 |
+
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
|
| 782 |
+
mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
|
| 783 |
+
height: int = 512,
|
| 784 |
+
width: int = 512,
|
| 785 |
+
num_inference_steps: int = 50,
|
| 786 |
+
guidance_scale: float = 7.5,
|
| 787 |
+
strength: float = 0.8,
|
| 788 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 789 |
+
eta: float = 0.0,
|
| 790 |
+
generator: Optional[torch.Generator] = None,
|
| 791 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 792 |
+
max_embeddings_multiples: Optional[int] = 3,
|
| 793 |
+
output_type: Optional[str] = "pil",
|
| 794 |
+
return_dict: bool = True,
|
| 795 |
+
controlnet=None,
|
| 796 |
+
controlnet_image=None,
|
| 797 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
| 798 |
+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
| 799 |
+
callback_steps: int = 1,
|
| 800 |
+
):
|
| 801 |
+
r"""
|
| 802 |
+
Function invoked when calling the pipeline for generation.
|
| 803 |
+
|
| 804 |
+
Args:
|
| 805 |
+
prompt (`str` or `List[str]`):
|
| 806 |
+
The prompt or prompts to guide the image generation.
|
| 807 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 808 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
| 809 |
+
if `guidance_scale` is less than `1`).
|
| 810 |
+
image (`torch.FloatTensor` or `PIL.Image.Image`):
|
| 811 |
+
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
| 812 |
+
process.
|
| 813 |
+
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
|
| 814 |
+
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
|
| 815 |
+
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
|
| 816 |
+
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
|
| 817 |
+
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
|
| 818 |
+
height (`int`, *optional*, defaults to 512):
|
| 819 |
+
The height in pixels of the generated image.
|
| 820 |
+
width (`int`, *optional*, defaults to 512):
|
| 821 |
+
The width in pixels of the generated image.
|
| 822 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 823 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 824 |
+
expense of slower inference.
|
| 825 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
| 826 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
| 827 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
| 828 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
| 829 |
+
1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
|
| 830 |
+
usually at the expense of lower image quality.
|
| 831 |
+
strength (`float`, *optional*, defaults to 0.8):
|
| 832 |
+
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
|
| 833 |
+
`image` will be used as a starting point, adding more noise to it the larger the `strength`. The
|
| 834 |
+
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
|
| 835 |
+
noise will be maximum and the denoising process will run for the full number of iterations specified in
|
| 836 |
+
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
|
| 837 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 838 |
+
The number of images to generate per prompt.
|
| 839 |
+
eta (`float`, *optional*, defaults to 0.0):
|
| 840 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
| 841 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
| 842 |
+
generator (`torch.Generator`, *optional*):
|
| 843 |
+
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
| 844 |
+
deterministic.
|
| 845 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 846 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 847 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 848 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
| 849 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
| 850 |
+
The max multiple length of prompt embeddings compared to the max output length of the text encoder.
|
| 851 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 852 |
+
The output format of the generated image. Choose between
|
| 853 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 854 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 855 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
| 856 |
+
plain tuple.
|
| 857 |
+
controlnet (`diffusers.ControlNetModel`, *optional*):
|
| 858 |
+
A controlnet model to be used for the inference. If not provided, controlnet will be disabled.
|
| 859 |
+
controlnet_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
|
| 860 |
+
`Image`, or tensor representing an image batch, to be used as the starting point for the controlnet
|
| 861 |
+
inference.
|
| 862 |
+
callback (`Callable`, *optional*):
|
| 863 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
| 864 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
| 865 |
+
is_cancelled_callback (`Callable`, *optional*):
|
| 866 |
+
A function that will be called every `callback_steps` steps during inference. If the function returns
|
| 867 |
+
`True`, the inference will be cancelled.
|
| 868 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
| 869 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
| 870 |
+
called at every step.
|
| 871 |
+
|
| 872 |
+
Returns:
|
| 873 |
+
`None` if cancelled by `is_cancelled_callback`,
|
| 874 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
| 875 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
|
| 876 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
| 877 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
| 878 |
+
(nsfw) content, according to the `safety_checker`.
|
| 879 |
+
"""
|
| 880 |
+
if controlnet is not None and controlnet_image is None:
|
| 881 |
+
raise ValueError("controlnet_image must be provided if controlnet is not None.")
|
| 882 |
+
|
| 883 |
+
# 0. Default height and width to unet
|
| 884 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
| 885 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
| 886 |
+
|
| 887 |
+
# 1. Check inputs. Raise error if not correct
|
| 888 |
+
self.check_inputs(prompt, height, width, strength, callback_steps)
|
| 889 |
+
|
| 890 |
+
# 2. Define call parameters
|
| 891 |
+
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
| 892 |
+
device = self._execution_device
|
| 893 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 894 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 895 |
+
# corresponds to doing no classifier free guidance.
|
| 896 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 897 |
+
|
| 898 |
+
# 3. Encode input prompt
|
| 899 |
+
# 実装を簡単にするためにtokenizer/text encoderを切り替えて二回呼び出す
|
| 900 |
+
# To simplify the implementation, switch the tokenizer/text encoder and call it twice
|
| 901 |
+
text_embeddings_list = []
|
| 902 |
+
text_pool = None
|
| 903 |
+
uncond_embeddings_list = []
|
| 904 |
+
uncond_pool = None
|
| 905 |
+
for i in range(len(self.tokenizers)):
|
| 906 |
+
self.tokenizer = self.tokenizers[i]
|
| 907 |
+
self.text_encoder = self.text_encoders[i]
|
| 908 |
+
|
| 909 |
+
text_embeddings, tp1, uncond_embeddings, up1 = self._encode_prompt(
|
| 910 |
+
prompt,
|
| 911 |
+
device,
|
| 912 |
+
num_images_per_prompt,
|
| 913 |
+
do_classifier_free_guidance,
|
| 914 |
+
negative_prompt,
|
| 915 |
+
max_embeddings_multiples,
|
| 916 |
+
is_sdxl_text_encoder2=i == 1,
|
| 917 |
+
)
|
| 918 |
+
text_embeddings_list.append(text_embeddings)
|
| 919 |
+
uncond_embeddings_list.append(uncond_embeddings)
|
| 920 |
+
|
| 921 |
+
if tp1 is not None:
|
| 922 |
+
text_pool = tp1
|
| 923 |
+
if up1 is not None:
|
| 924 |
+
uncond_pool = up1
|
| 925 |
+
|
| 926 |
+
unet_dtype = self.unet.dtype
|
| 927 |
+
dtype = unet_dtype
|
| 928 |
+
if hasattr(dtype, "itemsize") and dtype.itemsize == 1: # fp8
|
| 929 |
+
dtype = torch.float16
|
| 930 |
+
self.unet.to(dtype)
|
| 931 |
+
|
| 932 |
+
# 4. Preprocess image and mask
|
| 933 |
+
if isinstance(image, PIL.Image.Image):
|
| 934 |
+
image = preprocess_image(image)
|
| 935 |
+
if image is not None:
|
| 936 |
+
image = image.to(device=self.device, dtype=dtype)
|
| 937 |
+
if isinstance(mask_image, PIL.Image.Image):
|
| 938 |
+
mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
|
| 939 |
+
if mask_image is not None:
|
| 940 |
+
mask = mask_image.to(device=self.device, dtype=dtype)
|
| 941 |
+
mask = torch.cat([mask] * batch_size * num_images_per_prompt)
|
| 942 |
+
else:
|
| 943 |
+
mask = None
|
| 944 |
+
|
| 945 |
+
# ControlNet is not working yet in SDXL, but keep the code here for future use
|
| 946 |
+
if controlnet_image is not None:
|
| 947 |
+
controlnet_image = prepare_controlnet_image(
|
| 948 |
+
controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False
|
| 949 |
+
)
|
| 950 |
+
|
| 951 |
+
# 5. set timesteps
|
| 952 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 953 |
+
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
|
| 954 |
+
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
| 955 |
+
|
| 956 |
+
# 6. Prepare latent variables
|
| 957 |
+
latents, init_latents_orig, noise = self.prepare_latents(
|
| 958 |
+
image,
|
| 959 |
+
latent_timestep,
|
| 960 |
+
batch_size * num_images_per_prompt,
|
| 961 |
+
height,
|
| 962 |
+
width,
|
| 963 |
+
dtype,
|
| 964 |
+
device,
|
| 965 |
+
generator,
|
| 966 |
+
latents,
|
| 967 |
+
)
|
| 968 |
+
|
| 969 |
+
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 970 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 971 |
+
|
| 972 |
+
# create size embs and concat embeddings for SDXL
|
| 973 |
+
orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(dtype)
|
| 974 |
+
crop_size = torch.zeros_like(orig_size)
|
| 975 |
+
target_size = orig_size
|
| 976 |
+
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(dtype)
|
| 977 |
+
|
| 978 |
+
# make conditionings
|
| 979 |
+
if do_classifier_free_guidance:
|
| 980 |
+
text_embeddings = torch.cat(text_embeddings_list, dim=2)
|
| 981 |
+
uncond_embeddings = torch.cat(uncond_embeddings_list, dim=2)
|
| 982 |
+
text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(dtype)
|
| 983 |
+
|
| 984 |
+
cond_vector = torch.cat([text_pool, embs], dim=1)
|
| 985 |
+
uncond_vector = torch.cat([uncond_pool, embs], dim=1)
|
| 986 |
+
vector_embedding = torch.cat([uncond_vector, cond_vector]).to(dtype)
|
| 987 |
+
else:
|
| 988 |
+
text_embedding = torch.cat(text_embeddings_list, dim=2).to(dtype)
|
| 989 |
+
vector_embedding = torch.cat([text_pool, embs], dim=1).to(dtype)
|
| 990 |
+
|
| 991 |
+
# 8. Denoising loop
|
| 992 |
+
for i, t in enumerate(self.progress_bar(timesteps)):
|
| 993 |
+
# expand the latents if we are doing classifier free guidance
|
| 994 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 995 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 996 |
+
|
| 997 |
+
unet_additional_args = {}
|
| 998 |
+
if controlnet is not None:
|
| 999 |
+
down_block_res_samples, mid_block_res_sample = controlnet(
|
| 1000 |
+
latent_model_input,
|
| 1001 |
+
t,
|
| 1002 |
+
encoder_hidden_states=text_embeddings,
|
| 1003 |
+
controlnet_cond=controlnet_image,
|
| 1004 |
+
conditioning_scale=1.0,
|
| 1005 |
+
guess_mode=False,
|
| 1006 |
+
return_dict=False,
|
| 1007 |
+
)
|
| 1008 |
+
unet_additional_args["down_block_additional_residuals"] = down_block_res_samples
|
| 1009 |
+
unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample
|
| 1010 |
+
|
| 1011 |
+
# predict the noise residual
|
| 1012 |
+
noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding)
|
| 1013 |
+
noise_pred = noise_pred.to(dtype) # U-Net changes dtype in LoRA training
|
| 1014 |
+
|
| 1015 |
+
# perform guidance
|
| 1016 |
+
if do_classifier_free_guidance:
|
| 1017 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 1018 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 1019 |
+
|
| 1020 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 1021 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
| 1022 |
+
|
| 1023 |
+
if mask is not None:
|
| 1024 |
+
# masking
|
| 1025 |
+
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
|
| 1026 |
+
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
| 1027 |
+
|
| 1028 |
+
# call the callback, if provided
|
| 1029 |
+
if i % callback_steps == 0:
|
| 1030 |
+
if callback is not None:
|
| 1031 |
+
callback(i, t, latents)
|
| 1032 |
+
if is_cancelled_callback is not None and is_cancelled_callback():
|
| 1033 |
+
return None
|
| 1034 |
+
|
| 1035 |
+
self.unet.to(unet_dtype)
|
| 1036 |
+
return latents
|
| 1037 |
+
|
| 1038 |
+
def latents_to_image(self, latents):
|
| 1039 |
+
# 9. Post-processing
|
| 1040 |
+
image = self.decode_latents(latents.to(self.vae.dtype))
|
| 1041 |
+
image = self.numpy_to_pil(image)
|
| 1042 |
+
return image
|
| 1043 |
+
|
| 1044 |
+
# copy from pil_utils.py
|
| 1045 |
+
def numpy_to_pil(self, images: np.ndarray) -> List[Image.Image]:
|
| 1046 |
+
"""
|
| 1047 |
+
Convert a numpy image or a batch of images to a PIL image.
|
| 1048 |
+
"""
|
| 1049 |
+
if images.ndim == 3:
|
| 1050 |
+
images = images[None, ...]
|
| 1051 |
+
images = (images * 255).round().astype("uint8")
|
| 1052 |
+
if images.shape[-1] == 1:
|
| 1053 |
+
# special case for grayscale (single channel) images
|
| 1054 |
+
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
|
| 1055 |
+
else:
|
| 1056 |
+
pil_images = [Image.fromarray(image) for image in images]
|
| 1057 |
+
|
| 1058 |
+
return pil_images
|
| 1059 |
+
|
| 1060 |
+
def text2img(
|
| 1061 |
+
self,
|
| 1062 |
+
prompt: Union[str, List[str]],
|
| 1063 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 1064 |
+
height: int = 512,
|
| 1065 |
+
width: int = 512,
|
| 1066 |
+
num_inference_steps: int = 50,
|
| 1067 |
+
guidance_scale: float = 7.5,
|
| 1068 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 1069 |
+
eta: float = 0.0,
|
| 1070 |
+
generator: Optional[torch.Generator] = None,
|
| 1071 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 1072 |
+
max_embeddings_multiples: Optional[int] = 3,
|
| 1073 |
+
output_type: Optional[str] = "pil",
|
| 1074 |
+
return_dict: bool = True,
|
| 1075 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
| 1076 |
+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
| 1077 |
+
callback_steps: int = 1,
|
| 1078 |
+
):
|
| 1079 |
+
r"""
|
| 1080 |
+
Function for text-to-image generation.
|
| 1081 |
+
Args:
|
| 1082 |
+
prompt (`str` or `List[str]`):
|
| 1083 |
+
The prompt or prompts to guide the image generation.
|
| 1084 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 1085 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
| 1086 |
+
if `guidance_scale` is less than `1`).
|
| 1087 |
+
height (`int`, *optional*, defaults to 512):
|
| 1088 |
+
The height in pixels of the generated image.
|
| 1089 |
+
width (`int`, *optional*, defaults to 512):
|
| 1090 |
+
The width in pixels of the generated image.
|
| 1091 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 1092 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 1093 |
+
expense of slower inference.
|
| 1094 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
| 1095 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
| 1096 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
| 1097 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
| 1098 |
+
1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
|
| 1099 |
+
usually at the expense of lower image quality.
|
| 1100 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 1101 |
+
The number of images to generate per prompt.
|
| 1102 |
+
eta (`float`, *optional*, defaults to 0.0):
|
| 1103 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
| 1104 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
| 1105 |
+
generator (`torch.Generator`, *optional*):
|
| 1106 |
+
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
| 1107 |
+
deterministic.
|
| 1108 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 1109 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 1110 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 1111 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
| 1112 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
| 1113 |
+
The max multiple length of prompt embeddings compared to the max output length of the text encoder.
|
| 1114 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 1115 |
+
The output format of the generated image. Choose between
|
| 1116 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 1117 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 1118 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
| 1119 |
+
plain tuple.
|
| 1120 |
+
callback (`Callable`, *optional*):
|
| 1121 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
| 1122 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
| 1123 |
+
is_cancelled_callback (`Callable`, *optional*):
|
| 1124 |
+
A function that will be called every `callback_steps` steps during inference. If the function returns
|
| 1125 |
+
`True`, the inference will be cancelled.
|
| 1126 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
| 1127 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
| 1128 |
+
called at every step.
|
| 1129 |
+
Returns:
|
| 1130 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
| 1131 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
|
| 1132 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
| 1133 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
| 1134 |
+
(nsfw) content, according to the `safety_checker`.
|
| 1135 |
+
"""
|
| 1136 |
+
return self.__call__(
|
| 1137 |
+
prompt=prompt,
|
| 1138 |
+
negative_prompt=negative_prompt,
|
| 1139 |
+
height=height,
|
| 1140 |
+
width=width,
|
| 1141 |
+
num_inference_steps=num_inference_steps,
|
| 1142 |
+
guidance_scale=guidance_scale,
|
| 1143 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 1144 |
+
eta=eta,
|
| 1145 |
+
generator=generator,
|
| 1146 |
+
latents=latents,
|
| 1147 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
| 1148 |
+
output_type=output_type,
|
| 1149 |
+
return_dict=return_dict,
|
| 1150 |
+
callback=callback,
|
| 1151 |
+
is_cancelled_callback=is_cancelled_callback,
|
| 1152 |
+
callback_steps=callback_steps,
|
| 1153 |
+
)
|
| 1154 |
+
|
| 1155 |
+
def img2img(
|
| 1156 |
+
self,
|
| 1157 |
+
image: Union[torch.FloatTensor, PIL.Image.Image],
|
| 1158 |
+
prompt: Union[str, List[str]],
|
| 1159 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 1160 |
+
strength: float = 0.8,
|
| 1161 |
+
num_inference_steps: Optional[int] = 50,
|
| 1162 |
+
guidance_scale: Optional[float] = 7.5,
|
| 1163 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 1164 |
+
eta: Optional[float] = 0.0,
|
| 1165 |
+
generator: Optional[torch.Generator] = None,
|
| 1166 |
+
max_embeddings_multiples: Optional[int] = 3,
|
| 1167 |
+
output_type: Optional[str] = "pil",
|
| 1168 |
+
return_dict: bool = True,
|
| 1169 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
| 1170 |
+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
| 1171 |
+
callback_steps: int = 1,
|
| 1172 |
+
):
|
| 1173 |
+
r"""
|
| 1174 |
+
Function for image-to-image generation.
|
| 1175 |
+
Args:
|
| 1176 |
+
image (`torch.FloatTensor` or `PIL.Image.Image`):
|
| 1177 |
+
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
| 1178 |
+
process.
|
| 1179 |
+
prompt (`str` or `List[str]`):
|
| 1180 |
+
The prompt or prompts to guide the image generation.
|
| 1181 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 1182 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
| 1183 |
+
if `guidance_scale` is less than `1`).
|
| 1184 |
+
strength (`float`, *optional*, defaults to 0.8):
|
| 1185 |
+
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
|
| 1186 |
+
`image` will be used as a starting point, adding more noise to it the larger the `strength`. The
|
| 1187 |
+
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
|
| 1188 |
+
noise will be maximum and the denoising process will run for the full number of iterations specified in
|
| 1189 |
+
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
|
| 1190 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 1191 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 1192 |
+
expense of slower inference. This parameter will be modulated by `strength`.
|
| 1193 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
| 1194 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
| 1195 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
| 1196 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
| 1197 |
+
1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
|
| 1198 |
+
usually at the expense of lower image quality.
|
| 1199 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 1200 |
+
The number of images to generate per prompt.
|
| 1201 |
+
eta (`float`, *optional*, defaults to 0.0):
|
| 1202 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
| 1203 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
| 1204 |
+
generator (`torch.Generator`, *optional*):
|
| 1205 |
+
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
| 1206 |
+
deterministic.
|
| 1207 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
| 1208 |
+
The max multiple length of prompt embeddings compared to the max output length of the text encoder.
|
| 1209 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 1210 |
+
The output format of the generated image. Choose between
|
| 1211 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 1212 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 1213 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
| 1214 |
+
plain tuple.
|
| 1215 |
+
callback (`Callable`, *optional*):
|
| 1216 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
| 1217 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
| 1218 |
+
is_cancelled_callback (`Callable`, *optional*):
|
| 1219 |
+
A function that will be called every `callback_steps` steps during inference. If the function returns
|
| 1220 |
+
`True`, the inference will be cancelled.
|
| 1221 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
| 1222 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
| 1223 |
+
called at every step.
|
| 1224 |
+
Returns:
|
| 1225 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
| 1226 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
|
| 1227 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
| 1228 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
| 1229 |
+
(nsfw) content, according to the `safety_checker`.
|
| 1230 |
+
"""
|
| 1231 |
+
return self.__call__(
|
| 1232 |
+
prompt=prompt,
|
| 1233 |
+
negative_prompt=negative_prompt,
|
| 1234 |
+
image=image,
|
| 1235 |
+
num_inference_steps=num_inference_steps,
|
| 1236 |
+
guidance_scale=guidance_scale,
|
| 1237 |
+
strength=strength,
|
| 1238 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 1239 |
+
eta=eta,
|
| 1240 |
+
generator=generator,
|
| 1241 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
| 1242 |
+
output_type=output_type,
|
| 1243 |
+
return_dict=return_dict,
|
| 1244 |
+
callback=callback,
|
| 1245 |
+
is_cancelled_callback=is_cancelled_callback,
|
| 1246 |
+
callback_steps=callback_steps,
|
| 1247 |
+
)
|
| 1248 |
+
|
| 1249 |
+
def inpaint(
|
| 1250 |
+
self,
|
| 1251 |
+
image: Union[torch.FloatTensor, PIL.Image.Image],
|
| 1252 |
+
mask_image: Union[torch.FloatTensor, PIL.Image.Image],
|
| 1253 |
+
prompt: Union[str, List[str]],
|
| 1254 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 1255 |
+
strength: float = 0.8,
|
| 1256 |
+
num_inference_steps: Optional[int] = 50,
|
| 1257 |
+
guidance_scale: Optional[float] = 7.5,
|
| 1258 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 1259 |
+
eta: Optional[float] = 0.0,
|
| 1260 |
+
generator: Optional[torch.Generator] = None,
|
| 1261 |
+
max_embeddings_multiples: Optional[int] = 3,
|
| 1262 |
+
output_type: Optional[str] = "pil",
|
| 1263 |
+
return_dict: bool = True,
|
| 1264 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
| 1265 |
+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
| 1266 |
+
callback_steps: int = 1,
|
| 1267 |
+
):
|
| 1268 |
+
r"""
|
| 1269 |
+
Function for inpainting.
|
| 1270 |
+
Args:
|
| 1271 |
+
image (`torch.FloatTensor` or `PIL.Image.Image`):
|
| 1272 |
+
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
| 1273 |
+
process. This is the image whose masked region will be inpainted.
|
| 1274 |
+
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
|
| 1275 |
+
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
|
| 1276 |
+
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
|
| 1277 |
+
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
|
| 1278 |
+
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
|
| 1279 |
+
prompt (`str` or `List[str]`):
|
| 1280 |
+
The prompt or prompts to guide the image generation.
|
| 1281 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 1282 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
| 1283 |
+
if `guidance_scale` is less than `1`).
|
| 1284 |
+
strength (`float`, *optional*, defaults to 0.8):
|
| 1285 |
+
Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
|
| 1286 |
+
is 1, the denoising process will be run on the masked area for the full number of iterations specified
|
| 1287 |
+
in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
|
| 1288 |
+
noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
|
| 1289 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 1290 |
+
The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
|
| 1291 |
+
the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
|
| 1292 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
| 1293 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
| 1294 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
| 1295 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
| 1296 |
+
1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
|
| 1297 |
+
usually at the expense of lower image quality.
|
| 1298 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 1299 |
+
The number of images to generate per prompt.
|
| 1300 |
+
eta (`float`, *optional*, defaults to 0.0):
|
| 1301 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
| 1302 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
| 1303 |
+
generator (`torch.Generator`, *optional*):
|
| 1304 |
+
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
| 1305 |
+
deterministic.
|
| 1306 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
| 1307 |
+
The max multiple length of prompt embeddings compared to the max output length of the text encoder.
|
| 1308 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 1309 |
+
The output format of the generated image. Choose between
|
| 1310 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 1311 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 1312 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
| 1313 |
+
plain tuple.
|
| 1314 |
+
callback (`Callable`, *optional*):
|
| 1315 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
| 1316 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
| 1317 |
+
is_cancelled_callback (`Callable`, *optional*):
|
| 1318 |
+
A function that will be called every `callback_steps` steps during inference. If the function returns
|
| 1319 |
+
`True`, the inference will be cancelled.
|
| 1320 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
| 1321 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
| 1322 |
+
called at every step.
|
| 1323 |
+
Returns:
|
| 1324 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
| 1325 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
|
| 1326 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
| 1327 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
| 1328 |
+
(nsfw) content, according to the `safety_checker`.
|
| 1329 |
+
"""
|
| 1330 |
+
return self.__call__(
|
| 1331 |
+
prompt=prompt,
|
| 1332 |
+
negative_prompt=negative_prompt,
|
| 1333 |
+
image=image,
|
| 1334 |
+
mask_image=mask_image,
|
| 1335 |
+
num_inference_steps=num_inference_steps,
|
| 1336 |
+
guidance_scale=guidance_scale,
|
| 1337 |
+
strength=strength,
|
| 1338 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 1339 |
+
eta=eta,
|
| 1340 |
+
generator=generator,
|
| 1341 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
| 1342 |
+
output_type=output_type,
|
| 1343 |
+
return_dict=return_dict,
|
| 1344 |
+
callback=callback,
|
| 1345 |
+
is_cancelled_callback=is_cancelled_callback,
|
| 1346 |
+
callback_steps=callback_steps,
|
| 1347 |
+
)
|
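The pipeline's `__call__`, `text2img`, `img2img` and `inpaint` methods above return latents rather than decoded images, so callers decode them with `latents_to_image`. A minimal usage sketch (illustrative only; it assumes `pipe` is an already-constructed instance of the pipeline class defined in this file, and the prompt strings and output path are placeholders):

latents = pipe.text2img(
    prompt="a photo of an astronaut riding a horse",  # placeholder prompt
    negative_prompt="low quality",                    # placeholder negative prompt
    height=1024,
    width=1024,
    num_inference_steps=30,
    guidance_scale=7.5,
)
if latents is not None:  # __call__ returns None when is_cancelled_callback fires
    images = pipe.latents_to_image(latents)  # VAE decode + conversion to PIL
    images[0].save("output.png")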
library/sdxl_model_util.py
ADDED
|
@@ -0,0 +1,583 @@
| 1 |
+
import torch
|
| 2 |
+
import safetensors
|
| 3 |
+
from accelerate import init_empty_weights
|
| 4 |
+
from accelerate.utils.modeling import set_module_tensor_to_device
|
| 5 |
+
from safetensors.torch import load_file, save_file
|
| 6 |
+
from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
|
| 7 |
+
from typing import List
|
| 8 |
+
from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
|
| 9 |
+
from library import model_util
|
| 10 |
+
from library import sdxl_original_unet
|
| 11 |
+
from .utils import setup_logging
|
| 12 |
+
|
| 13 |
+
setup_logging()
|
| 14 |
+
import logging
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
VAE_SCALE_FACTOR = 0.13025
|
| 19 |
+
MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0"
|
| 20 |
+
|
| 21 |
+
# Diffusersの設定を読み込むための参照モデル / reference model for loading the Diffusers config
|
| 22 |
+
DIFFUSERS_REF_MODEL_ID_SDXL = "stabilityai/stable-diffusion-xl-base-1.0"
|
| 23 |
+
|
| 24 |
+
DIFFUSERS_SDXL_UNET_CONFIG = {
|
| 25 |
+
"act_fn": "silu",
|
| 26 |
+
"addition_embed_type": "text_time",
|
| 27 |
+
"addition_embed_type_num_heads": 64,
|
| 28 |
+
"addition_time_embed_dim": 256,
|
| 29 |
+
"attention_head_dim": [5, 10, 20],
|
| 30 |
+
"block_out_channels": [320, 640, 1280],
|
| 31 |
+
"center_input_sample": False,
|
| 32 |
+
"class_embed_type": None,
|
| 33 |
+
"class_embeddings_concat": False,
|
| 34 |
+
"conv_in_kernel": 3,
|
| 35 |
+
"conv_out_kernel": 3,
|
| 36 |
+
"cross_attention_dim": 2048,
|
| 37 |
+
"cross_attention_norm": None,
|
| 38 |
+
"down_block_types": ["DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"],
|
| 39 |
+
"downsample_padding": 1,
|
| 40 |
+
"dual_cross_attention": False,
|
| 41 |
+
"encoder_hid_dim": None,
|
| 42 |
+
"encoder_hid_dim_type": None,
|
| 43 |
+
"flip_sin_to_cos": True,
|
| 44 |
+
"freq_shift": 0,
|
| 45 |
+
"in_channels": 4,
|
| 46 |
+
"layers_per_block": 2,
|
| 47 |
+
"mid_block_only_cross_attention": None,
|
| 48 |
+
"mid_block_scale_factor": 1,
|
| 49 |
+
"mid_block_type": "UNetMidBlock2DCrossAttn",
|
| 50 |
+
"norm_eps": 1e-05,
|
| 51 |
+
"norm_num_groups": 32,
|
| 52 |
+
"num_attention_heads": None,
|
| 53 |
+
"num_class_embeds": None,
|
| 54 |
+
"only_cross_attention": False,
|
| 55 |
+
"out_channels": 4,
|
| 56 |
+
"projection_class_embeddings_input_dim": 2816,
|
| 57 |
+
"resnet_out_scale_factor": 1.0,
|
| 58 |
+
"resnet_skip_time_act": False,
|
| 59 |
+
"resnet_time_scale_shift": "default",
|
| 60 |
+
"sample_size": 128,
|
| 61 |
+
"time_cond_proj_dim": None,
|
| 62 |
+
"time_embedding_act_fn": None,
|
| 63 |
+
"time_embedding_dim": None,
|
| 64 |
+
"time_embedding_type": "positional",
|
| 65 |
+
"timestep_post_act": None,
|
| 66 |
+
"transformer_layers_per_block": [1, 2, 10],
|
| 67 |
+
"up_block_types": ["CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"],
|
| 68 |
+
"upcast_attention": False,
|
| 69 |
+
"use_linear_projection": True,
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
|
| 74 |
+
SDXL_KEY_PREFIX = "conditioner.embedders.1.model."
|
| 75 |
+
|
| 76 |
+
# SD2のと、基本的には同じ。logit_scaleを後で使うので、それを追加で返す / basically the same as for SD2; logit_scale is needed later, so it is returned as well
|
| 77 |
+
# logit_scaleはcheckpointの保存時に使用する / logit_scale is used when saving the checkpoint
|
| 78 |
+
def convert_key(key):
|
| 79 |
+
# common conversion
|
| 80 |
+
key = key.replace(SDXL_KEY_PREFIX + "transformer.", "text_model.encoder.")
|
| 81 |
+
key = key.replace(SDXL_KEY_PREFIX, "text_model.")
|
| 82 |
+
|
| 83 |
+
if "resblocks" in key:
|
| 84 |
+
# resblocks conversion
|
| 85 |
+
key = key.replace(".resblocks.", ".layers.")
|
| 86 |
+
if ".ln_" in key:
|
| 87 |
+
key = key.replace(".ln_", ".layer_norm")
|
| 88 |
+
elif ".mlp." in key:
|
| 89 |
+
key = key.replace(".c_fc.", ".fc1.")
|
| 90 |
+
key = key.replace(".c_proj.", ".fc2.")
|
| 91 |
+
elif ".attn.out_proj" in key:
|
| 92 |
+
key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
|
| 93 |
+
elif ".attn.in_proj" in key:
|
| 94 |
+
key = None # 特殊なので後で処理する / special case, handled separately below
|
| 95 |
+
else:
|
| 96 |
+
raise ValueError(f"unexpected key in SD: {key}")
|
| 97 |
+
elif ".positional_embedding" in key:
|
| 98 |
+
key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
|
| 99 |
+
elif ".text_projection" in key:
|
| 100 |
+
key = key.replace("text_model.text_projection", "text_projection.weight")
|
| 101 |
+
elif ".logit_scale" in key:
|
| 102 |
+
key = None # 後で処理する / handled separately below
|
| 103 |
+
elif ".token_embedding" in key:
|
| 104 |
+
key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
|
| 105 |
+
elif ".ln_final" in key:
|
| 106 |
+
key = key.replace(".ln_final", ".final_layer_norm")
|
| 107 |
+
# ckpt from comfy has this key: text_model.encoder.text_model.embeddings.position_ids
|
| 108 |
+
elif ".embeddings.position_ids" in key:
|
| 109 |
+
key = None # remove this key: position_ids is not used in newer transformers
|
| 110 |
+
return key
|
| 111 |
+
|
| 112 |
+
keys = list(checkpoint.keys())
|
| 113 |
+
new_sd = {}
|
| 114 |
+
for key in keys:
|
| 115 |
+
new_key = convert_key(key)
|
| 116 |
+
if new_key is None:
|
| 117 |
+
continue
|
| 118 |
+
new_sd[new_key] = checkpoint[key]
|
| 119 |
+
|
| 120 |
+
# attnの変換 / convert attention weights
|
| 121 |
+
for key in keys:
|
| 122 |
+
if ".resblocks" in key and ".attn.in_proj_" in key:
|
| 123 |
+
# 三つに分割 / split into three (q, k, v)
|
| 124 |
+
values = torch.chunk(checkpoint[key], 3)
|
| 125 |
+
|
| 126 |
+
key_suffix = ".weight" if "weight" in key else ".bias"
|
| 127 |
+
key_pfx = key.replace(SDXL_KEY_PREFIX + "transformer.resblocks.", "text_model.encoder.layers.")
|
| 128 |
+
key_pfx = key_pfx.replace("_weight", "")
|
| 129 |
+
key_pfx = key_pfx.replace("_bias", "")
|
| 130 |
+
key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
|
| 131 |
+
new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
|
| 132 |
+
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
|
| 133 |
+
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
|
| 134 |
+
|
| 135 |
+
# logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す / logit_scale is not part of the Diffusers model, but it is returned separately so it can be restored when saving
|
| 136 |
+
logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)
|
| 137 |
+
|
| 138 |
+
# temporary workaround for text_projection.weight.weight for Playground-v2
|
| 139 |
+
if "text_projection.weight.weight" in new_sd:
|
| 140 |
+
logger.info("convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight")
|
| 141 |
+
new_sd["text_projection.weight"] = new_sd["text_projection.weight.weight"]
|
| 142 |
+
del new_sd["text_projection.weight.weight"]
|
| 143 |
+
|
| 144 |
+
return new_sd, logit_scale
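# Illustrative example of the conversion above: the checkpoint key
#   "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.weight"
# is renamed to
#   "text_model.encoder.layers.0.layer_norm1.weight",
# while each fused "attn.in_proj_weight"/"attn.in_proj_bias" tensor is chunked
# into separate q_proj/k_proj/v_proj tensors for the HuggingFace CLIP model.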
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
# load state_dict without allocating new tensors
|
| 148 |
+
def _load_state_dict_on_device(model, state_dict, device, dtype=None):
|
| 149 |
+
# dtype will use fp32 as default
|
| 150 |
+
missing_keys = list(model.state_dict().keys() - state_dict.keys())
|
| 151 |
+
unexpected_keys = list(state_dict.keys() - model.state_dict().keys())
|
| 152 |
+
|
| 153 |
+
# similar to model.load_state_dict()
|
| 154 |
+
if not missing_keys and not unexpected_keys:
|
| 155 |
+
for k in list(state_dict.keys()):
|
| 156 |
+
set_module_tensor_to_device(model, k, device, value=state_dict.pop(k), dtype=dtype)
|
| 157 |
+
return "<All keys matched successfully>"
|
| 158 |
+
|
| 159 |
+
# error_msgs
|
| 160 |
+
error_msgs: List[str] = []
|
| 161 |
+
if missing_keys:
|
| 162 |
+
error_msgs.insert(0, "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys)))
|
| 163 |
+
if unexpected_keys:
|
| 164 |
+
error_msgs.insert(0, "Unexpected key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in unexpected_keys)))
|
| 165 |
+
|
| 166 |
+
raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)))
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None, disable_mmap=False):
|
| 170 |
+
# model_version is reserved for future use
|
| 171 |
+
# dtype is used for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching
|
| 172 |
+
|
| 173 |
+
# Load the state dict
|
| 174 |
+
if model_util.is_safetensors(ckpt_path):
|
| 175 |
+
checkpoint = None
|
| 176 |
+
if disable_mmap:
|
| 177 |
+
state_dict = safetensors.torch.load(open(ckpt_path, "rb").read())
|
| 178 |
+
else:
|
| 179 |
+
try:
|
| 180 |
+
state_dict = load_file(ckpt_path, device=map_location)
|
| 181 |
+
except:
|
| 182 |
+
state_dict = load_file(ckpt_path) # prevent device invalid Error
|
| 183 |
+
epoch = None
|
| 184 |
+
global_step = None
|
| 185 |
+
else:
|
| 186 |
+
checkpoint = torch.load(ckpt_path, map_location=map_location)
|
| 187 |
+
if "state_dict" in checkpoint:
|
| 188 |
+
state_dict = checkpoint["state_dict"]
|
| 189 |
+
epoch = checkpoint.get("epoch", 0)
|
| 190 |
+
global_step = checkpoint.get("global_step", 0)
|
| 191 |
+
else:
|
| 192 |
+
state_dict = checkpoint
|
| 193 |
+
epoch = 0
|
| 194 |
+
global_step = 0
|
| 195 |
+
checkpoint = None
|
| 196 |
+
|
| 197 |
+
# U-Net
|
| 198 |
+
logger.info("building U-Net")
|
| 199 |
+
with init_empty_weights():
|
| 200 |
+
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
|
| 201 |
+
|
| 202 |
+
logger.info("loading U-Net from checkpoint")
|
| 203 |
+
unet_sd = {}
|
| 204 |
+
for k in list(state_dict.keys()):
|
| 205 |
+
if k.startswith("model.diffusion_model."):
|
| 206 |
+
unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
|
| 207 |
+
info = _load_state_dict_on_device(unet, unet_sd, device=map_location, dtype=dtype)
|
| 208 |
+
logger.info(f"U-Net: {info}")
|
| 209 |
+
|
| 210 |
+
# Text Encoders
|
| 211 |
+
logger.info("building text encoders")
|
| 212 |
+
|
| 213 |
+
# Text Encoder 1 is the same as Stability AI's SDXL
|
| 214 |
+
text_model1_cfg = CLIPTextConfig(
|
| 215 |
+
vocab_size=49408,
|
| 216 |
+
hidden_size=768,
|
| 217 |
+
intermediate_size=3072,
|
| 218 |
+
num_hidden_layers=12,
|
| 219 |
+
num_attention_heads=12,
|
| 220 |
+
max_position_embeddings=77,
|
| 221 |
+
hidden_act="quick_gelu",
|
| 222 |
+
layer_norm_eps=1e-05,
|
| 223 |
+
dropout=0.0,
|
| 224 |
+
attention_dropout=0.0,
|
| 225 |
+
initializer_range=0.02,
|
| 226 |
+
initializer_factor=1.0,
|
| 227 |
+
pad_token_id=1,
|
| 228 |
+
bos_token_id=0,
|
| 229 |
+
eos_token_id=2,
|
| 230 |
+
model_type="clip_text_model",
|
| 231 |
+
projection_dim=768,
|
| 232 |
+
# torch_dtype="float32",
|
| 233 |
+
# transformers_version="4.25.0.dev0",
|
| 234 |
+
)
|
| 235 |
+
with init_empty_weights():
|
| 236 |
+
text_model1 = CLIPTextModel._from_config(text_model1_cfg)
|
| 237 |
+
|
| 238 |
+
# Text Encoder 2 is different from Stability AI's SDXL. SDXL uses open clip, but we use the model from HuggingFace.
|
| 239 |
+
# Note: Tokenizer from HuggingFace is different from SDXL. We must use open clip's tokenizer.
|
| 240 |
+
text_model2_cfg = CLIPTextConfig(
|
| 241 |
+
vocab_size=49408,
|
| 242 |
+
hidden_size=1280,
|
| 243 |
+
intermediate_size=5120,
|
| 244 |
+
num_hidden_layers=32,
|
| 245 |
+
num_attention_heads=20,
|
| 246 |
+
max_position_embeddings=77,
|
| 247 |
+
hidden_act="gelu",
|
| 248 |
+
layer_norm_eps=1e-05,
|
| 249 |
+
dropout=0.0,
|
| 250 |
+
attention_dropout=0.0,
|
| 251 |
+
initializer_range=0.02,
|
| 252 |
+
initializer_factor=1.0,
|
| 253 |
+
pad_token_id=1,
|
| 254 |
+
bos_token_id=0,
|
| 255 |
+
eos_token_id=2,
|
| 256 |
+
model_type="clip_text_model",
|
| 257 |
+
projection_dim=1280,
|
| 258 |
+
# torch_dtype="float32",
|
| 259 |
+
# transformers_version="4.25.0.dev0",
|
| 260 |
+
)
|
| 261 |
+
with init_empty_weights():
|
| 262 |
+
text_model2 = CLIPTextModelWithProjection(text_model2_cfg)
|
| 263 |
+
|
| 264 |
+
logger.info("loading text encoders from checkpoint")
|
| 265 |
+
te1_sd = {}
|
| 266 |
+
te2_sd = {}
|
| 267 |
+
for k in list(state_dict.keys()):
|
| 268 |
+
if k.startswith("conditioner.embedders.0.transformer."):
|
| 269 |
+
te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k)
|
| 270 |
+
elif k.startswith("conditioner.embedders.1.model."):
|
| 271 |
+
te2_sd[k] = state_dict.pop(k)
|
| 272 |
+
|
| 273 |
+
# 最新の transformers では position_ids を含むとエラーになるので削除 / remove position_ids for latest transformers
|
| 274 |
+
if "text_model.embeddings.position_ids" in te1_sd:
|
| 275 |
+
te1_sd.pop("text_model.embeddings.position_ids")
|
| 276 |
+
|
| 277 |
+
info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32
|
| 278 |
+
logger.info(f"text encoder 1: {info1}")
|
| 279 |
+
|
| 280 |
+
converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
|
| 281 |
+
info2 = _load_state_dict_on_device(text_model2, converted_sd, device=map_location) # remain fp32
|
| 282 |
+
logger.info(f"text encoder 2: {info2}")
|
| 283 |
+
|
| 284 |
+
# prepare vae
|
| 285 |
+
logger.info("building VAE")
|
| 286 |
+
vae_config = model_util.create_vae_diffusers_config()
|
| 287 |
+
with init_empty_weights():
|
| 288 |
+
vae = AutoencoderKL(**vae_config)
|
| 289 |
+
|
| 290 |
+
logger.info("loading VAE from checkpoint")
|
| 291 |
+
converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config)
|
| 292 |
+
info = _load_state_dict_on_device(vae, converted_vae_checkpoint, device=map_location, dtype=dtype)
|
| 293 |
+
logger.info(f"VAE: {info}")
|
| 294 |
+
|
| 295 |
+
ckpt_info = (epoch, global_step) if epoch is not None else None
|
| 296 |
+
return text_model1, text_model2, vae, unet, logit_scale, ckpt_info
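# Illustrative usage (the checkpoint path below is a placeholder):
#   te1, te2, vae, unet, logit_scale, ckpt_info = load_models_from_sdxl_checkpoint(
#       MODEL_VERSION_SDXL_BASE_V1_0, "sd_xl_base_1.0.safetensors", map_location="cpu"
#   )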
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def make_unet_conversion_map():
|
| 300 |
+
unet_conversion_map_layer = []
|
| 301 |
+
|
| 302 |
+
for i in range(3): # num_blocks is 3 in sdxl
|
| 303 |
+
# loop over downblocks/upblocks
|
| 304 |
+
for j in range(2):
|
| 305 |
+
# loop over resnets/attentions for downblocks
|
| 306 |
+
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
| 307 |
+
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
| 308 |
+
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
| 309 |
+
|
| 310 |
+
if i < 3:
|
| 311 |
+
# no attention layers in down_blocks.3
|
| 312 |
+
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
| 313 |
+
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
| 314 |
+
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
| 315 |
+
|
| 316 |
+
for j in range(3):
|
| 317 |
+
# loop over resnets/attentions for upblocks
|
| 318 |
+
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
| 319 |
+
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
| 320 |
+
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
| 321 |
+
|
| 322 |
+
# if i > 0: commentout for sdxl
|
| 323 |
+
# no attention layers in up_blocks.0
|
| 324 |
+
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
| 325 |
+
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
| 326 |
+
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
| 327 |
+
|
| 328 |
+
if i < 3:
|
| 329 |
+
# no downsample in down_blocks.3
|
| 330 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
| 331 |
+
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
| 332 |
+
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
| 333 |
+
|
| 334 |
+
# no upsample in up_blocks.3
|
| 335 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
| 336 |
+
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
|
| 337 |
+
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
| 338 |
+
|
| 339 |
+
hf_mid_atn_prefix = "mid_block.attentions.0."
|
| 340 |
+
sd_mid_atn_prefix = "middle_block.1."
|
| 341 |
+
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
| 342 |
+
|
| 343 |
+
for j in range(2):
|
| 344 |
+
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
| 345 |
+
sd_mid_res_prefix = f"middle_block.{2*j}."
|
| 346 |
+
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
| 347 |
+
|
| 348 |
+
unet_conversion_map_resnet = [
|
| 349 |
+
# (stable-diffusion, HF Diffusers)
|
| 350 |
+
("in_layers.0.", "norm1."),
|
| 351 |
+
("in_layers.2.", "conv1."),
|
| 352 |
+
("out_layers.0.", "norm2."),
|
| 353 |
+
("out_layers.3.", "conv2."),
|
| 354 |
+
("emb_layers.1.", "time_emb_proj."),
|
| 355 |
+
("skip_connection.", "conv_shortcut."),
|
| 356 |
+
]
|
| 357 |
+
|
| 358 |
+
unet_conversion_map = []
|
| 359 |
+
for sd, hf in unet_conversion_map_layer:
|
| 360 |
+
if "resnets" in hf:
|
| 361 |
+
for sd_res, hf_res in unet_conversion_map_resnet:
|
| 362 |
+
unet_conversion_map.append((sd + sd_res, hf + hf_res))
|
| 363 |
+
else:
|
| 364 |
+
unet_conversion_map.append((sd, hf))
|
| 365 |
+
|
| 366 |
+
for j in range(2):
|
| 367 |
+
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
|
| 368 |
+
sd_time_embed_prefix = f"time_embed.{j*2}."
|
| 369 |
+
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
|
| 370 |
+
|
| 371 |
+
for j in range(2):
|
| 372 |
+
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
|
| 373 |
+
sd_label_embed_prefix = f"label_emb.0.{j*2}."
|
| 374 |
+
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
|
| 375 |
+
|
| 376 |
+
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
|
| 377 |
+
unet_conversion_map.append(("out.0.", "conv_norm_out."))
|
| 378 |
+
unet_conversion_map.append(("out.2.", "conv_out."))
|
| 379 |
+
|
| 380 |
+
return unet_conversion_map
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def convert_diffusers_unet_state_dict_to_sdxl(du_sd):
|
| 384 |
+
unet_conversion_map = make_unet_conversion_map()
|
| 385 |
+
|
| 386 |
+
conversion_map = {hf: sd for sd, hf in unet_conversion_map}
|
| 387 |
+
return convert_unet_state_dict(du_sd, conversion_map)
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def convert_unet_state_dict(src_sd, conversion_map):
|
| 391 |
+
converted_sd = {}
|
| 392 |
+
for src_key, value in src_sd.items():
|
| 393 |
+
# さすがに全部回すのは時間がかかるので右から要素を削りつつprefixを探す
|
| 394 |
+
src_key_fragments = src_key.split(".")[:-1] # remove weight/bias
|
| 395 |
+
while len(src_key_fragments) > 0:
|
| 396 |
+
src_key_prefix = ".".join(src_key_fragments) + "."
|
| 397 |
+
if src_key_prefix in conversion_map:
|
| 398 |
+
converted_prefix = conversion_map[src_key_prefix]
|
| 399 |
+
converted_key = converted_prefix + src_key[len(src_key_prefix) :]
|
| 400 |
+
converted_sd[converted_key] = value
|
| 401 |
+
break
|
| 402 |
+
src_key_fragments.pop(-1)
|
| 403 |
+
assert len(src_key_fragments) > 0, f"key {src_key} not found in conversion map"
|
| 404 |
+
|
| 405 |
+
return converted_sd
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
def convert_sdxl_unet_state_dict_to_diffusers(sd):
|
| 409 |
+
unet_conversion_map = make_unet_conversion_map()
|
| 410 |
+
|
| 411 |
+
conversion_dict = {sd: hf for sd, hf in unet_conversion_map}
|
| 412 |
+
return convert_unet_state_dict(sd, conversion_dict)
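
# Editor's note (illustrative sketch, not part of the original module): the two converters above are
# inverses over the prefix map returned by make_unet_conversion_map(). Assuming `sdxl_sd` is a U-Net
# state dict in SDXL (sgm) key layout, a round trip should preserve the key set:
#
#   du_sd = convert_sdxl_unet_state_dict_to_diffusers(sdxl_sd)   # e.g. "input_blocks.1.0.in_layers.0." -> "down_blocks.0.resnets.0.norm1."
#   back = convert_diffusers_unet_state_dict_to_sdxl(du_sd)      # maps the Diffusers-style keys back
#   assert set(back.keys()) == set(sdxl_sd.keys())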


def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, logit_scale):
    def convert_key(key):
        # position_idsの除去 / remove position_ids
        if ".position_ids" in key:
            return None

        # common
        key = key.replace("text_model.encoder.", "transformer.")
        key = key.replace("text_model.", "")
        if "layers" in key:
            # resblocks conversion
            key = key.replace(".layers.", ".resblocks.")
            if ".layer_norm" in key:
                key = key.replace(".layer_norm", ".ln_")
            elif ".mlp." in key:
                key = key.replace(".fc1.", ".c_fc.")
                key = key.replace(".fc2.", ".c_proj.")
            elif ".self_attn.out_proj" in key:
                key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
            elif ".self_attn." in key:
                key = None  # 特殊なので後で処理する / q/k/v need special handling, done later
            else:
                raise ValueError(f"unexpected key in DiffUsers model: {key}")
        elif ".position_embedding" in key:
            key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
        elif ".token_embedding" in key:
            key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
        elif "text_projection" in key:  # no dot in key
            key = key.replace("text_projection.weight", "text_projection")
        elif "final_layer_norm" in key:
            key = key.replace("final_layer_norm", "ln_final")
        return key

    keys = list(checkpoint.keys())
    new_sd = {}
    for key in keys:
        new_key = convert_key(key)
        if new_key is None:
            continue
        new_sd[new_key] = checkpoint[key]

    # attnの変換 / convert attention (q/k/v) weights
    for key in keys:
        if "layers" in key and "q_proj" in key:
            # 三つを結合 / concatenate the three projections
            key_q = key
            key_k = key.replace("q_proj", "k_proj")
            key_v = key.replace("q_proj", "v_proj")

            value_q = checkpoint[key_q]
            value_k = checkpoint[key_k]
            value_v = checkpoint[key_v]
            value = torch.cat([value_q, value_k, value_v])

            new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
            new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
            new_sd[new_key] = value

    if logit_scale is not None:
        new_sd["logit_scale"] = logit_scale

    return new_sd
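
# Editor's note (illustrative, derived from the replacements above): this function maps
# transformers-style CLIP keys back to the open_clip-style keys stored under
# "conditioner.embedders.1.model." in SDXL checkpoints, fusing the q/k/v projections. For example:
#   "text_model.encoder.layers.0.self_attn.q_proj.weight" (+ k_proj/v_proj, concatenated)
#       -> "transformer.resblocks.0.attn.in_proj_weight"
#   "text_model.final_layer_norm.weight" -> "ln_final.weight"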


def save_stable_diffusion_checkpoint(
    output_file,
    text_encoder1,
    text_encoder2,
    unet,
    epochs,
    steps,
    ckpt_info,
    vae,
    logit_scale,
    metadata,
    save_dtype=None,
):
    state_dict = {}

    def update_sd(prefix, sd):
        for k, v in sd.items():
            key = prefix + k
            if save_dtype is not None:
                v = v.detach().clone().to("cpu").to(save_dtype)
            state_dict[key] = v

    # Convert the UNet model
    update_sd("model.diffusion_model.", unet.state_dict())

    # Convert the text encoders
    update_sd("conditioner.embedders.0.transformer.", text_encoder1.state_dict())

    text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(text_encoder2.state_dict(), logit_scale)
    update_sd("conditioner.embedders.1.model.", text_enc2_dict)

    # Convert the VAE
    vae_dict = model_util.convert_vae_state_dict(vae.state_dict())
    update_sd("first_stage_model.", vae_dict)

    # Put together new checkpoint
    key_count = len(state_dict.keys())
    new_ckpt = {"state_dict": state_dict}

    # epoch and global_step are sometimes not int
    if ckpt_info is not None:
        epochs += ckpt_info[0]
        steps += ckpt_info[1]

    new_ckpt["epoch"] = epochs
    new_ckpt["global_step"] = steps

    if model_util.is_safetensors(output_file):
        save_file(state_dict, output_file, metadata)
    else:
        torch.save(new_ckpt, output_file)

    return key_count
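
# Illustrative call (editor's sketch; the output file name and metadata are hypothetical,
# the argument names follow the signature above):
#   key_count = save_stable_diffusion_checkpoint(
#       "sdxl_finetuned.safetensors", text_encoder1, text_encoder2, unet,
#       epochs=1, steps=1000, ckpt_info=None, vae=vae, logit_scale=logit_scale,
#       metadata={"title": "example"}, save_dtype=torch.float16,
#   )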


def save_diffusers_checkpoint(
    output_dir, text_encoder1, text_encoder2, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False, save_dtype=None
):
    from diffusers import StableDiffusionXLPipeline

    # convert U-Net
    unet_sd = unet.state_dict()
    du_unet_sd = convert_sdxl_unet_state_dict_to_diffusers(unet_sd)

    diffusers_unet = UNet2DConditionModel(**DIFFUSERS_SDXL_UNET_CONFIG)
    if save_dtype is not None:
        diffusers_unet.to(save_dtype)
    diffusers_unet.load_state_dict(du_unet_sd)

    # create pipeline to save
    if pretrained_model_name_or_path is None:
        pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_SDXL

    scheduler = EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
    tokenizer1 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
    tokenizer2 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer_2")
    if vae is None:
        vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")

    # prevent local path from being saved
    def remove_name_or_path(model):
        if hasattr(model, "config"):
            model.config._name_or_path = None

    remove_name_or_path(diffusers_unet)
    remove_name_or_path(text_encoder1)
    remove_name_or_path(text_encoder2)
    remove_name_or_path(scheduler)
    remove_name_or_path(tokenizer1)
    remove_name_or_path(tokenizer2)
    remove_name_or_path(vae)

    pipeline = StableDiffusionXLPipeline(
        unet=diffusers_unet,
        text_encoder=text_encoder1,
        text_encoder_2=text_encoder2,
        vae=vae,
        scheduler=scheduler,
        tokenizer=tokenizer1,
        tokenizer_2=tokenizer2,
    )
    if save_dtype is not None:
        pipeline.to(None, save_dtype)
    pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
library/sdxl_original_unet.py
ADDED
@@ -0,0 +1,1286 @@
# U-Net for sd_xl_base, based on the Diffusers code / Diffusersのコードをベースとした sd_xl_baseのU-Net
# The state dict format is aligned with SDXL / state dictの形式をSDXLに合わせてある

"""
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
  adm_in_channels: 2816
  num_classes: sequential
  use_checkpoint: True
  in_channels: 4
  out_channels: 4
  model_channels: 320
  attention_resolutions: [4, 2]
  num_res_blocks: 2
  channel_mult: [1, 2, 4]
  num_head_channels: 64
  use_spatial_transformer: True
  use_linear_in_transformer: True
  transformer_depth: [1, 2, 10]  # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
  context_dim: 2048
  spatial_transformer_attn_type: softmax-xformers
  legacy: False
"""

import math
from types import SimpleNamespace
from typing import Any, Optional
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from .utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)

IN_CHANNELS: int = 4
OUT_CHANNELS: int = 4
ADM_IN_CHANNELS: int = 2816
CONTEXT_DIM: int = 2048
MODEL_CHANNELS: int = 320
TIME_EMBED_DIM = 320 * 4

USE_REENTRANT = True

# region memory efficient attention

# CrossAttention that uses FlashAttention / FlashAttentionを使うCrossAttention
# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE

# constants

EPSILON = 1e-6

# helper functions


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


# flash attention forwards and backwards

# https://arxiv.org/abs/2205.14135


class FlashAttentionFunction(torch.autograd.Function):
    @staticmethod
    @torch.no_grad()
    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
        """Algorithm 2 in the paper"""

        device = q.device
        dtype = q.dtype
        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        o = torch.zeros_like(q)
        all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
        all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)

        scale = q.shape[-1] ** -0.5

        if not exists(mask):
            mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
        else:
            mask = rearrange(mask, "b n -> b 1 1 n")
            mask = mask.split(q_bucket_size, dim=-1)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            mask,
            all_row_sums.split(q_bucket_size, dim=-2),
            all_row_maxes.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale

                if exists(row_mask):
                    attn_weights.masked_fill_(~row_mask, max_neg_value)

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
                        q_start_index - k_start_index + 1
                    )
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
                attn_weights -= block_row_maxes
                exp_weights = torch.exp(attn_weights)

                if exists(row_mask):
                    exp_weights.masked_fill_(~row_mask, 0.0)

                block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)

                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

                exp_values = torch.einsum("... i j, ... j d -> ... i d", exp_weights, vc)

                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
                exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)

                new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums

                oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)

                row_maxes.copy_(new_row_maxes)
                row_sums.copy_(new_row_sums)

        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
        ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)

        return o

    @staticmethod
    @torch.no_grad()
    def backward(ctx, do):
        """Algorithm 4 in the paper"""

        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
        q, k, v, o, l, m = ctx.saved_tensors

        device = q.device

        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            do.split(q_bucket_size, dim=-2),
            mask,
            l.split(q_bucket_size, dim=-2),
            m.split(q_bucket_size, dim=-2),
            dq.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
                dk.split(k_bucket_size, dim=-2),
                dv.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
                        q_start_index - k_start_index + 1
                    )
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                exp_attn_weights = torch.exp(attn_weights - mc)

                if exists(row_mask):
                    exp_attn_weights.masked_fill_(~row_mask, 0.0)

                p = exp_attn_weights / lc

                dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
                dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)

                D = (doc * oc).sum(dim=-1, keepdims=True)
                ds = p * scale * (dp - D)

                dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
                dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)

                dqc.add_(dq_chunk)
                dkc.add_(dk_chunk)
                dvc.add_(dv_chunk)

        return dq, dk, dv, None, None, None, None


# endregion
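
# Editor's sketch (illustrative only): FlashAttentionFunction expects q/k/v shaped
# (batch, heads, seq, dim_head) and is invoked through its autograd entry point, e.g.
#   q = torch.randn(2, 8, 4096, 64); k = torch.randn(2, 8, 77, 64); v = torch.randn(2, 8, 77, 64)
#   out = FlashAttentionFunction.apply(q, k, v, None, False, 512, 1024)  # mask=None, causal=False, bucket sizes
# which mirrors how forward_memory_efficient_mem_eff() calls it further below.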


def get_parameter_dtype(parameter: torch.nn.Module):
    return next(parameter.parameters()).dtype


def get_parameter_device(parameter: torch.nn.Module):
    return next(parameter.parameters()).device


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
    embeddings. :return: an [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings: flipped from Diffusers original ver because always flip_sin_to_cos=True
    emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
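
# Editor's sketch (illustrative only): for a batch of timesteps the embedding is cosine-first, e.g.
#   emb = get_timestep_embedding(torch.tensor([0, 999]), MODEL_CHANNELS)  # -> shape (2, 320)
# where the first 160 columns are cosines and the last 160 are sines of the scaled timestep frequencies.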


# Deep Shrink: this helper is deliberately not shared with other modules, to minimize dependencies.
def resize_like(x, target, mode="bicubic", align_corners=False):
    org_dtype = x.dtype
    if org_dtype == torch.bfloat16:
        x = x.to(torch.float32)

    if x.shape[-2:] != target.shape[-2:]:
        if mode == "nearest":
            x = F.interpolate(x, size=target.shape[-2:], mode=mode)
        else:
            x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners)

    if org_dtype == torch.bfloat16:
        x = x.to(org_dtype)
    return x


class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        if self.weight.dtype != torch.float32:
            return super().forward(x)
        return super().forward(x.float()).type(x.dtype)


class ResnetBlock2D(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.in_layers = nn.Sequential(
            GroupNorm32(32, in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
        )

        self.emb_layers = nn.Sequential(nn.SiLU(), nn.Linear(TIME_EMBED_DIM, out_channels))

        self.out_layers = nn.Sequential(
            GroupNorm32(32, out_channels),
            nn.SiLU(),
            nn.Identity(),  # to make state_dict compatible with original model
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
        )

        if in_channels != out_channels:
            self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.skip_connection = nn.Identity()

        self.gradient_checkpointing = False

    def forward_body(self, x, emb):
        h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        h = h + emb_out[:, :, None, None]
        h = self.out_layers(h)
        x = self.skip_connection(x)
        return x + h

    def forward(self, x, emb):
        if self.training and self.gradient_checkpointing:
            # logger.info("ResnetBlock2D: gradient_checkpointing")

            def create_custom_forward(func):
                def custom_forward(*inputs):
                    return func(*inputs)

                return custom_forward

            x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x, emb, use_reentrant=USE_REENTRANT)
        else:
            x = self.forward_body(x, emb)

        return x


class Downsample2D(nn.Module):
    def __init__(self, channels, out_channels):
        super().__init__()

        self.channels = channels
        self.out_channels = out_channels

        self.op = nn.Conv2d(self.channels, self.out_channels, 3, stride=2, padding=1)

        self.gradient_checkpointing = False

    def forward_body(self, hidden_states):
        assert hidden_states.shape[1] == self.channels
        hidden_states = self.op(hidden_states)

        return hidden_states

    def forward(self, hidden_states):
        if self.training and self.gradient_checkpointing:
            # logger.info("Downsample2D: gradient_checkpointing")

            def create_custom_forward(func):
                def custom_forward(*inputs):
                    return func(*inputs)

                return custom_forward

            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.forward_body), hidden_states, use_reentrant=USE_REENTRANT
            )
        else:
            hidden_states = self.forward_body(hidden_states)

        return hidden_states


class CrossAttention(nn.Module):
    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        upcast_attention: bool = False,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.upcast_attention = upcast_attention

        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False)

        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(inner_dim, query_dim))
        # no dropout here

        self.use_memory_efficient_attention_xformers = False
        self.use_memory_efficient_attention_mem_eff = False
        self.use_sdpa = False

    def set_use_memory_efficient_attention(self, xformers, mem_eff):
        self.use_memory_efficient_attention_xformers = xformers
        self.use_memory_efficient_attention_mem_eff = mem_eff

    def set_use_sdpa(self, sdpa):
        self.use_sdpa = sdpa

    def reshape_heads_to_batch_dim(self, tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor

    def reshape_batch_dim_to_heads(self, tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def forward(self, hidden_states, context=None, mask=None):
        if self.use_memory_efficient_attention_xformers:
            return self.forward_memory_efficient_xformers(hidden_states, context, mask)
        if self.use_memory_efficient_attention_mem_eff:
            return self.forward_memory_efficient_mem_eff(hidden_states, context, mask)
        if self.use_sdpa:
            return self.forward_sdpa(hidden_states, context, mask)

        query = self.to_q(hidden_states)
        context = context if context is not None else hidden_states
        key = self.to_k(context)
        value = self.to_v(context)

        query = self.reshape_heads_to_batch_dim(query)
        key = self.reshape_heads_to_batch_dim(key)
        value = self.reshape_heads_to_batch_dim(value)

        hidden_states = self._attention(query, key, value)

        # linear proj
        hidden_states = self.to_out[0](hidden_states)
        # hidden_states = self.to_out[1](hidden_states)  # no dropout
        return hidden_states

    def _attention(self, query, key, value):
        if self.upcast_attention:
            query = query.float()
            key = key.float()

        attention_scores = torch.baddbmm(
            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
            query,
            key.transpose(-1, -2),
            beta=0,
            alpha=self.scale,
        )
        attention_probs = attention_scores.softmax(dim=-1)

        # cast back to the original dtype
        attention_probs = attention_probs.to(value.dtype)

        # compute attention output
        hidden_states = torch.bmm(attention_probs, value)

        # reshape hidden_states
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return hidden_states

    # TODO support Hypernetworks
    def forward_memory_efficient_xformers(self, x, context=None, mask=None):
        import xformers.ops

        h = self.heads
        q_in = self.to_q(x)
        context = context if context is not None else x
        context = context.to(x.dtype)
        k_in = self.to_k(context)
        v_in = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
        del q_in, k_in, v_in

        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)  # 最適なのを選んでくれる / lets xformers pick the best kernel
        del q, k, v

        out = rearrange(out, "b n h d -> b n (h d)", h=h)

        out = self.to_out[0](out)
        return out

    def forward_memory_efficient_mem_eff(self, x, context=None, mask=None):
        flash_func = FlashAttentionFunction

        q_bucket_size = 512
        k_bucket_size = 1024

        h = self.heads
        q = self.to_q(x)
        context = context if context is not None else x
        context = context.to(x.dtype)
        k = self.to_k(context)
        v = self.to_v(context)
        del context, x

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))

        out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)

        out = rearrange(out, "b h n d -> b n (h d)")

        out = self.to_out[0](out)
        return out

    def forward_sdpa(self, x, context=None, mask=None):
        h = self.heads
        q_in = self.to_q(x)
        context = context if context is not None else x
        context = context.to(x.dtype)
        k_in = self.to_k(context)
        v_in = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in))
        del q_in, k_in, v_in

        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)

        out = rearrange(out, "b h n d -> b n (h d)", h=h)

        out = self.to_out[0](out)
        return out
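
# Editor's sketch (illustrative only; shapes chosen to match the level-1 SDXL blocks defined below):
# a CrossAttention layer attending from 640-channel image tokens to a 2048-dim text context, with SDPA enabled:
#   attn = CrossAttention(query_dim=640, cross_attention_dim=2048, heads=10, dim_head=64)
#   attn.set_use_sdpa(True)
#   out = attn(torch.randn(1, 4096, 640), context=torch.randn(1, 77, 2048))  # -> (1, 4096, 640)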


# feedforward
class GEGLU(nn.Module):
    r"""
    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def gelu(self, gate):
        if gate.device.type != "mps":
            return F.gelu(gate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

    def forward(self, hidden_states):
        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return hidden_states * self.gelu(gate)


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
    ):
        super().__init__()
        inner_dim = int(dim * 4)  # mult is always 4

        self.net = nn.ModuleList([])
        # project in
        self.net.append(GEGLU(dim, inner_dim))
        # project dropout
        self.net.append(nn.Identity())  # nn.Dropout(0))  # dummy for dropout with 0
        # project out
        self.net.append(nn.Linear(inner_dim, dim))

    def forward(self, hidden_states):
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states


class BasicTransformerBlock(nn.Module):
    def __init__(
        self, dim: int, num_attention_heads: int, attention_head_dim: int, cross_attention_dim: int, upcast_attention: bool = False
    ):
        super().__init__()

        self.gradient_checkpointing = False

        # 1. Self-Attn
        self.attn1 = CrossAttention(
            query_dim=dim,
            cross_attention_dim=None,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            upcast_attention=upcast_attention,
        )
        self.ff = FeedForward(dim)

        # 2. Cross-Attn
        self.attn2 = CrossAttention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            upcast_attention=upcast_attention,
        )

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim)

    def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool):
        self.attn1.set_use_memory_efficient_attention(xformers, mem_eff)
        self.attn2.set_use_memory_efficient_attention(xformers, mem_eff)

    def set_use_sdpa(self, sdpa: bool):
        self.attn1.set_use_sdpa(sdpa)
        self.attn2.set_use_sdpa(sdpa)

    def forward_body(self, hidden_states, context=None, timestep=None):
        # 1. Self-Attention
        norm_hidden_states = self.norm1(hidden_states)

        hidden_states = self.attn1(norm_hidden_states) + hidden_states

        # 2. Cross-Attention
        norm_hidden_states = self.norm2(hidden_states)
        hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states

        # 3. Feed-forward
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

        return hidden_states

    def forward(self, hidden_states, context=None, timestep=None):
        if self.training and self.gradient_checkpointing:
            # logger.info("BasicTransformerBlock: checkpointing")

            def create_custom_forward(func):
                def custom_forward(*inputs):
                    return func(*inputs)

                return custom_forward

            output = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.forward_body), hidden_states, context, timestep, use_reentrant=USE_REENTRANT
            )
        else:
            output = self.forward_body(hidden_states, context, timestep)

        return output


class Transformer2DModel(nn.Module):
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        cross_attention_dim: Optional[int] = None,
        use_linear_projection: bool = False,
        upcast_attention: bool = False,
        num_transformer_layers: int = 1,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim
        self.use_linear_projection = use_linear_projection

        self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        # self.norm = GroupNorm32(32, in_channels, eps=1e-6, affine=True)

        if use_linear_projection:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        blocks = []
        for _ in range(num_transformer_layers):
            blocks.append(
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    cross_attention_dim=cross_attention_dim,
                    upcast_attention=upcast_attention,
                )
            )

        self.transformer_blocks = nn.ModuleList(blocks)

        if use_linear_projection:
            self.proj_out = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

        self.gradient_checkpointing = False

    def set_use_memory_efficient_attention(self, xformers, mem_eff):
        for transformer in self.transformer_blocks:
            transformer.set_use_memory_efficient_attention(xformers, mem_eff)

    def set_use_sdpa(self, sdpa):
        for transformer in self.transformer_blocks:
            transformer.set_use_sdpa(sdpa)

    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None):
        # 1. Input
        batch, _, height, weight = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        if not self.use_linear_projection:
            hidden_states = self.proj_in(hidden_states)
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
        else:
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
            hidden_states = self.proj_in(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)

        # 3. Output
        if not self.use_linear_projection:
            hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
            hidden_states = self.proj_out(hidden_states)
        else:
            hidden_states = self.proj_out(hidden_states)
            hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()

        output = hidden_states + residual

        return output
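
# Editor's note (illustrative only): Transformer2DModel keeps the spatial shape. With the level-1
# configuration used below (inner_dim = 10 * 64 = 640, use_linear_projection=True), an input of shape
# (B, 640, H, W) is flattened to (B, H*W, 640), run through 2 BasicTransformerBlocks against a
# 2048-dim text context, and reshaped back to (B, 640, H, W) before the residual is added.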


class Upsample2D(nn.Module):
    def __init__(self, channels, out_channels):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels
        self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)

        self.gradient_checkpointing = False

    def forward_body(self, hidden_states, output_size=None):
        assert hidden_states.shape[1] == self.channels

        # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
        # https://github.com/pytorch/pytorch/issues/86679
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        hidden_states = self.conv(hidden_states)

        return hidden_states

    def forward(self, hidden_states, output_size=None):
        if self.training and self.gradient_checkpointing:
            # logger.info("Upsample2D: gradient_checkpointing")

            def create_custom_forward(func):
                def custom_forward(*inputs):
                    return func(*inputs)

                return custom_forward

            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(self.forward_body), hidden_states, output_size, use_reentrant=USE_REENTRANT
            )
        else:
            hidden_states = self.forward_body(hidden_states, output_size)

        return hidden_states


class SdxlUNet2DConditionModel(nn.Module):
    _supports_gradient_checkpointing = True

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__()

        self.in_channels = IN_CHANNELS
        self.out_channels = OUT_CHANNELS
        self.model_channels = MODEL_CHANNELS
        self.time_embed_dim = TIME_EMBED_DIM
        self.adm_in_channels = ADM_IN_CHANNELS

        self.gradient_checkpointing = False
        # self.sample_size = sample_size

        # time embedding
        self.time_embed = nn.Sequential(
            nn.Linear(self.model_channels, self.time_embed_dim),
            nn.SiLU(),
            nn.Linear(self.time_embed_dim, self.time_embed_dim),
        )

        # label embedding
        self.label_emb = nn.Sequential(
            nn.Sequential(
                nn.Linear(self.adm_in_channels, self.time_embed_dim),
                nn.SiLU(),
                nn.Linear(self.time_embed_dim, self.time_embed_dim),
            )
        )

        # input
        self.input_blocks = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(self.in_channels, self.model_channels, kernel_size=3, padding=(1, 1)),
                )
            ]
        )

        # level 0
        for i in range(2):
            layers = [
                ResnetBlock2D(
                    in_channels=1 * self.model_channels,
                    out_channels=1 * self.model_channels,
                ),
            ]
            self.input_blocks.append(nn.ModuleList(layers))

        self.input_blocks.append(
            nn.Sequential(
                Downsample2D(
                    channels=1 * self.model_channels,
                    out_channels=1 * self.model_channels,
                ),
            )
        )

        # level 1
        for i in range(2):
            layers = [
                ResnetBlock2D(
                    in_channels=(1 if i == 0 else 2) * self.model_channels,
                    out_channels=2 * self.model_channels,
                ),
                Transformer2DModel(
                    num_attention_heads=2 * self.model_channels // 64,
                    attention_head_dim=64,
                    in_channels=2 * self.model_channels,
                    num_transformer_layers=2,
                    use_linear_projection=True,
                    cross_attention_dim=2048,
                ),
            ]
            self.input_blocks.append(nn.ModuleList(layers))

        self.input_blocks.append(
            nn.Sequential(
                Downsample2D(
                    channels=2 * self.model_channels,
                    out_channels=2 * self.model_channels,
                ),
            )
        )

        # level 2
        for i in range(2):
            layers = [
                ResnetBlock2D(
                    in_channels=(2 if i == 0 else 4) * self.model_channels,
                    out_channels=4 * self.model_channels,
                ),
                Transformer2DModel(
                    num_attention_heads=4 * self.model_channels // 64,
                    attention_head_dim=64,
                    in_channels=4 * self.model_channels,
                    num_transformer_layers=10,
                    use_linear_projection=True,
                    cross_attention_dim=2048,
                ),
            ]
            self.input_blocks.append(nn.ModuleList(layers))

        # mid
        self.middle_block = nn.ModuleList(
            [
                ResnetBlock2D(
                    in_channels=4 * self.model_channels,
                    out_channels=4 * self.model_channels,
                ),
                Transformer2DModel(
                    num_attention_heads=4 * self.model_channels // 64,
                    attention_head_dim=64,
                    in_channels=4 * self.model_channels,
                    num_transformer_layers=10,
                    use_linear_projection=True,
                    cross_attention_dim=2048,
                ),
                ResnetBlock2D(
                    in_channels=4 * self.model_channels,
                    out_channels=4 * self.model_channels,
                ),
            ]
        )

        # output
        self.output_blocks = nn.ModuleList([])

        # level 2
        for i in range(3):
            layers = [
                ResnetBlock2D(
                    in_channels=4 * self.model_channels + (4 if i <= 1 else 2) * self.model_channels,
                    out_channels=4 * self.model_channels,
                ),
                Transformer2DModel(
                    num_attention_heads=4 * self.model_channels // 64,
                    attention_head_dim=64,
                    in_channels=4 * self.model_channels,
                    num_transformer_layers=10,
                    use_linear_projection=True,
                    cross_attention_dim=2048,
                ),
            ]
            if i == 2:
                layers.append(
                    Upsample2D(
                        channels=4 * self.model_channels,
                        out_channels=4 * self.model_channels,
                    )
                )

            self.output_blocks.append(nn.ModuleList(layers))

        # level 1
        for i in range(3):
            layers = [
                ResnetBlock2D(
                    in_channels=2 * self.model_channels + (4 if i == 0 else (2 if i == 1 else 1)) * self.model_channels,
                    out_channels=2 * self.model_channels,
                ),
                Transformer2DModel(
                    num_attention_heads=2 * self.model_channels // 64,
                    attention_head_dim=64,
                    in_channels=2 * self.model_channels,
                    num_transformer_layers=2,
                    use_linear_projection=True,
                    cross_attention_dim=2048,
                ),
            ]
            if i == 2:
                layers.append(
                    Upsample2D(
                        channels=2 * self.model_channels,
                        out_channels=2 * self.model_channels,
                    )
                )

            self.output_blocks.append(nn.ModuleList(layers))

        # level 0
        for i in range(3):
            layers = [
                ResnetBlock2D(
                    in_channels=1 * self.model_channels + (2 if i == 0 else 1) * self.model_channels,
                    out_channels=1 * self.model_channels,
                ),
            ]

            self.output_blocks.append(nn.ModuleList(layers))

        # output
        self.out = nn.ModuleList(
            [GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)]
        )

    # region diffusers compatibility
    def prepare_config(self):
        self.config = SimpleNamespace()

    @property
    def dtype(self) -> torch.dtype:
        # `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
        return get_parameter_dtype(self)

    @property
    def device(self) -> torch.device:
        # `torch.device`: The device on which the module is (assuming that all the module parameters are on the same device).
        return get_parameter_device(self)

    def set_attention_slice(self, slice_size):
        raise NotImplementedError("Attention slicing is not supported for this model.")

    def is_gradient_checkpointing(self) -> bool:
        return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())

    def enable_gradient_checkpointing(self):
        self.gradient_checkpointing = True
        self.set_gradient_checkpointing(value=True)

    def disable_gradient_checkpointing(self):
        self.gradient_checkpointing = False
        self.set_gradient_checkpointing(value=False)

    def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> None:
        blocks = self.input_blocks + [self.middle_block] + self.output_blocks
        for block in blocks:
            for module in block:
                if hasattr(module, "set_use_memory_efficient_attention"):
                    # logger.info(module.__class__.__name__)
                    module.set_use_memory_efficient_attention(xformers, mem_eff)

    def set_use_sdpa(self, sdpa: bool) -> None:
        blocks = self.input_blocks + [self.middle_block] + self.output_blocks
        for block in blocks:
            for module in block:
                if hasattr(module, "set_use_sdpa"):
                    module.set_use_sdpa(sdpa)

    def set_gradient_checkpointing(self, value=False):
        blocks = self.input_blocks + [self.middle_block] + self.output_blocks
        for block in blocks:
|
| 1067 |
+
for module in block.modules():
|
| 1068 |
+
if hasattr(module, "gradient_checkpointing"):
|
| 1069 |
+
# logger.info(f"{module.__class__.__name__} {module.gradient_checkpointing} -> {value}")
|
| 1070 |
+
module.gradient_checkpointing = value
|
| 1071 |
+
|
| 1072 |
+
# endregion
|
| 1073 |
+
|
| 1074 |
+
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
|
| 1075 |
+
# broadcast timesteps to batch dimension
|
| 1076 |
+
timesteps = timesteps.expand(x.shape[0])
|
| 1077 |
+
|
| 1078 |
+
hs = []
|
| 1079 |
+
t_emb = get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0) # , repeat_only=False)
|
| 1080 |
+
t_emb = t_emb.to(x.dtype)
|
| 1081 |
+
emb = self.time_embed(t_emb)
|
| 1082 |
+
|
| 1083 |
+
assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
|
| 1084 |
+
assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
|
| 1085 |
+
# assert x.dtype == self.dtype
|
| 1086 |
+
emb = emb + self.label_emb(y)
|
| 1087 |
+
|
| 1088 |
+
def call_module(module, h, emb, context):
|
| 1089 |
+
x = h
|
| 1090 |
+
for layer in module:
|
| 1091 |
+
# logger.info(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
|
| 1092 |
+
if isinstance(layer, ResnetBlock2D):
|
| 1093 |
+
x = layer(x, emb)
|
| 1094 |
+
elif isinstance(layer, Transformer2DModel):
|
| 1095 |
+
x = layer(x, context)
|
| 1096 |
+
else:
|
| 1097 |
+
x = layer(x)
|
| 1098 |
+
return x
|
| 1099 |
+
|
| 1100 |
+
# h = x.type(self.dtype)
|
| 1101 |
+
h = x
|
| 1102 |
+
|
| 1103 |
+
for module in self.input_blocks:
|
| 1104 |
+
h = call_module(module, h, emb, context)
|
| 1105 |
+
hs.append(h)
|
| 1106 |
+
|
| 1107 |
+
h = call_module(self.middle_block, h, emb, context)
|
| 1108 |
+
|
| 1109 |
+
for module in self.output_blocks:
|
| 1110 |
+
h = torch.cat([h, hs.pop()], dim=1)
|
| 1111 |
+
h = call_module(module, h, emb, context)
|
| 1112 |
+
|
| 1113 |
+
h = h.type(x.dtype)
|
| 1114 |
+
h = call_module(self.out, h, emb, context)
|
| 1115 |
+
|
| 1116 |
+
return h
|
| 1117 |
+
|
| 1118 |
+
|
| 1119 |
+
class InferSdxlUNet2DConditionModel:
|
| 1120 |
+
def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs):
|
| 1121 |
+
self.delegate = original_unet
|
| 1122 |
+
|
| 1123 |
+
# override original model's forward method: because forward is not called by `__call__`
|
| 1124 |
+
# overriding `__call__` is not enough, because nn.Module.forward has a special handling
|
| 1125 |
+
self.delegate.forward = self.forward
|
| 1126 |
+
|
| 1127 |
+
# Deep Shrink
|
| 1128 |
+
self.ds_depth_1 = None
|
| 1129 |
+
self.ds_depth_2 = None
|
| 1130 |
+
self.ds_timesteps_1 = None
|
| 1131 |
+
self.ds_timesteps_2 = None
|
| 1132 |
+
self.ds_ratio = None
|
| 1133 |
+
|
| 1134 |
+
# call original model's methods
|
| 1135 |
+
def __getattr__(self, name):
|
| 1136 |
+
return getattr(self.delegate, name)
|
| 1137 |
+
|
| 1138 |
+
def __call__(self, *args, **kwargs):
|
| 1139 |
+
return self.delegate(*args, **kwargs)
|
| 1140 |
+
|
| 1141 |
+
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
|
| 1142 |
+
if ds_depth_1 is None:
|
| 1143 |
+
logger.info("Deep Shrink is disabled.")
|
| 1144 |
+
self.ds_depth_1 = None
|
| 1145 |
+
self.ds_timesteps_1 = None
|
| 1146 |
+
self.ds_depth_2 = None
|
| 1147 |
+
self.ds_timesteps_2 = None
|
| 1148 |
+
self.ds_ratio = None
|
| 1149 |
+
else:
|
| 1150 |
+
logger.info(
|
| 1151 |
+
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
|
| 1152 |
+
)
|
| 1153 |
+
self.ds_depth_1 = ds_depth_1
|
| 1154 |
+
self.ds_timesteps_1 = ds_timesteps_1
|
| 1155 |
+
self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
|
| 1156 |
+
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
|
| 1157 |
+
self.ds_ratio = ds_ratio
|
| 1158 |
+
|
| 1159 |
+
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
|
| 1160 |
+
r"""
|
| 1161 |
+
current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink.
|
| 1162 |
+
"""
|
| 1163 |
+
_self = self.delegate
|
| 1164 |
+
|
| 1165 |
+
# broadcast timesteps to batch dimension
|
| 1166 |
+
timesteps = timesteps.expand(x.shape[0])
|
| 1167 |
+
|
| 1168 |
+
hs = []
|
| 1169 |
+
t_emb = get_timestep_embedding(timesteps, _self.model_channels, downscale_freq_shift=0) # , repeat_only=False)
|
| 1170 |
+
t_emb = t_emb.to(x.dtype)
|
| 1171 |
+
emb = _self.time_embed(t_emb)
|
| 1172 |
+
|
| 1173 |
+
assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
|
| 1174 |
+
assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
|
| 1175 |
+
# assert x.dtype == _self.dtype
|
| 1176 |
+
emb = emb + _self.label_emb(y)
|
| 1177 |
+
|
| 1178 |
+
def call_module(module, h, emb, context):
|
| 1179 |
+
x = h
|
| 1180 |
+
for layer in module:
|
| 1181 |
+
# print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
|
| 1182 |
+
if isinstance(layer, ResnetBlock2D):
|
| 1183 |
+
x = layer(x, emb)
|
| 1184 |
+
elif isinstance(layer, Transformer2DModel):
|
| 1185 |
+
x = layer(x, context)
|
| 1186 |
+
else:
|
| 1187 |
+
x = layer(x)
|
| 1188 |
+
return x
|
| 1189 |
+
|
| 1190 |
+
# h = x.type(self.dtype)
|
| 1191 |
+
h = x
|
| 1192 |
+
|
| 1193 |
+
for depth, module in enumerate(_self.input_blocks):
|
| 1194 |
+
# Deep Shrink
|
| 1195 |
+
if self.ds_depth_1 is not None:
|
| 1196 |
+
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
|
| 1197 |
+
self.ds_depth_2 is not None
|
| 1198 |
+
and depth == self.ds_depth_2
|
| 1199 |
+
and timesteps[0] < self.ds_timesteps_1
|
| 1200 |
+
and timesteps[0] >= self.ds_timesteps_2
|
| 1201 |
+
):
|
| 1202 |
+
# print("downsample", h.shape, self.ds_ratio)
|
| 1203 |
+
org_dtype = h.dtype
|
| 1204 |
+
if org_dtype == torch.bfloat16:
|
| 1205 |
+
h = h.to(torch.float32)
|
| 1206 |
+
h = F.interpolate(h, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
|
| 1207 |
+
|
| 1208 |
+
h = call_module(module, h, emb, context)
|
| 1209 |
+
hs.append(h)
|
| 1210 |
+
|
| 1211 |
+
h = call_module(_self.middle_block, h, emb, context)
|
| 1212 |
+
|
| 1213 |
+
for module in _self.output_blocks:
|
| 1214 |
+
# Deep Shrink
|
| 1215 |
+
if self.ds_depth_1 is not None:
|
| 1216 |
+
if hs[-1].shape[-2:] != h.shape[-2:]:
|
| 1217 |
+
# print("upsample", h.shape, hs[-1].shape)
|
| 1218 |
+
h = resize_like(h, hs[-1])
|
| 1219 |
+
|
| 1220 |
+
h = torch.cat([h, hs.pop()], dim=1)
|
| 1221 |
+
h = call_module(module, h, emb, context)
|
| 1222 |
+
|
| 1223 |
+
# Deep Shrink: in case of depth 0
|
| 1224 |
+
if self.ds_depth_1 == 0 and h.shape[-2:] != x.shape[-2:]:
|
| 1225 |
+
# print("upsample", h.shape, x.shape)
|
| 1226 |
+
h = resize_like(h, x)
|
| 1227 |
+
|
| 1228 |
+
h = h.type(x.dtype)
|
| 1229 |
+
h = call_module(_self.out, h, emb, context)
|
| 1230 |
+
|
| 1231 |
+
return h
|
| 1232 |
+
|
| 1233 |
+
|
| 1234 |
+
if __name__ == "__main__":
|
| 1235 |
+
import time
|
| 1236 |
+
|
| 1237 |
+
logger.info("create unet")
|
| 1238 |
+
unet = SdxlUNet2DConditionModel()
|
| 1239 |
+
|
| 1240 |
+
unet.to("cuda")
|
| 1241 |
+
unet.set_use_memory_efficient_attention(True, False)
|
| 1242 |
+
unet.set_gradient_checkpointing(True)
|
| 1243 |
+
unet.train()
|
| 1244 |
+
|
| 1245 |
+
# pseudo training loop to check memory usage
|
| 1246 |
+
logger.info("preparing optimizer")
|
| 1247 |
+
|
| 1248 |
+
# optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working
|
| 1249 |
+
|
| 1250 |
+
# import bitsandbytes
|
| 1251 |
+
# optimizer = bitsandbytes.adam.Adam8bit(unet.parameters(), lr=1e-3) # not working
|
| 1252 |
+
# optimizer = bitsandbytes.optim.RMSprop8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2
|
| 1253 |
+
# optimizer=bitsandbytes.optim.Adagrad8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2
|
| 1254 |
+
|
| 1255 |
+
import transformers
|
| 1256 |
+
|
| 1257 |
+
optimizer = transformers.optimization.Adafactor(unet.parameters(), relative_step=True) # working at 22.2GB with torch2
|
| 1258 |
+
|
| 1259 |
+
scaler = torch.cuda.amp.GradScaler(enabled=True)
|
| 1260 |
+
|
| 1261 |
+
logger.info("start training")
|
| 1262 |
+
steps = 10
|
| 1263 |
+
batch_size = 1
|
| 1264 |
+
|
| 1265 |
+
for step in range(steps):
|
| 1266 |
+
logger.info(f"step {step}")
|
| 1267 |
+
if step == 1:
|
| 1268 |
+
time_start = time.perf_counter()
|
| 1269 |
+
|
| 1270 |
+
x = torch.randn(batch_size, 4, 128, 128).cuda() # 1024x1024
|
| 1271 |
+
t = torch.randint(low=0, high=10, size=(batch_size,), device="cuda")
|
| 1272 |
+
ctx = torch.randn(batch_size, 77, 2048).cuda()
|
| 1273 |
+
y = torch.randn(batch_size, ADM_IN_CHANNELS).cuda()
|
| 1274 |
+
|
| 1275 |
+
with torch.cuda.amp.autocast(enabled=True):
|
| 1276 |
+
output = unet(x, t, ctx, y)
|
| 1277 |
+
target = torch.randn_like(output)
|
| 1278 |
+
loss = torch.nn.functional.mse_loss(output, target)
|
| 1279 |
+
|
| 1280 |
+
scaler.scale(loss).backward()
|
| 1281 |
+
scaler.step(optimizer)
|
| 1282 |
+
scaler.update()
|
| 1283 |
+
optimizer.zero_grad(set_to_none=True)
|
| 1284 |
+
|
| 1285 |
+
time_end = time.perf_counter()
|
| 1286 |
+
logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")
|
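A minimal usage sketch (not part of the uploaded file) of how the wrapper above is meant to be used: wrap a SdxlUNet2DConditionModel in InferSdxlUNet2DConditionModel and enable Deep Shrink for inference. ADM_IN_CHANNELS is assumed to be defined earlier in sdxl_original_unet.py, as in the memory-check loop above; the shapes mirror that loop.

# illustrative sketch, not part of sdxl_original_unet.py
import torch
from library.sdxl_original_unet import ADM_IN_CHANNELS, InferSdxlUNet2DConditionModel, SdxlUNet2DConditionModel

unet = SdxlUNet2DConditionModel().to("cuda").eval()
infer_unet = InferSdxlUNet2DConditionModel(unet)

# shrink hidden states by half at input depth 3 while timestep >= 650 (values are illustrative)
infer_unet.set_deep_shrink(ds_depth_1=3, ds_timesteps_1=650, ds_ratio=0.5)

with torch.no_grad():
    x = torch.randn(1, 4, 128, 128, device="cuda")      # latents for a 1024x1024 image
    t = torch.full((1,), 999, device="cuda")             # current timestep
    ctx = torch.randn(1, 77, 2048, device="cuda")        # text encoder hidden states
    y = torch.randn(1, ADM_IN_CHANNELS, device="cuda")   # pooled text embedding + size embeddings
    noise_pred = infer_unet(x, t, ctx, y)                # routed through the Deep Shrink forward via the delegate
print(noise_pred.shape)  # torch.Size([1, 4, 128, 128])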
library/sdxl_train_util.py
ADDED
|
@@ -0,0 +1,381 @@
|
| 1 |
+
import argparse
|
| 2 |
+
import math
|
| 3 |
+
import os
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from library.device_utils import init_ipex, clean_memory_on_device
|
| 8 |
+
|
| 9 |
+
init_ipex()
|
| 10 |
+
|
| 11 |
+
from accelerate import init_empty_weights
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
from transformers import CLIPTokenizer
|
| 14 |
+
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
|
| 15 |
+
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
|
| 16 |
+
from .utils import setup_logging
|
| 17 |
+
|
| 18 |
+
setup_logging()
|
| 19 |
+
import logging
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
|
| 24 |
+
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
| 25 |
+
|
| 26 |
+
# DEFAULT_NOISE_OFFSET = 0.0357
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
| 30 |
+
model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16
|
| 31 |
+
for pi in range(accelerator.state.num_processes):
|
| 32 |
+
if pi == accelerator.state.local_process_index:
|
| 33 |
+
logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
| 34 |
+
|
| 35 |
+
(
|
| 36 |
+
load_stable_diffusion_format,
|
| 37 |
+
text_encoder1,
|
| 38 |
+
text_encoder2,
|
| 39 |
+
vae,
|
| 40 |
+
unet,
|
| 41 |
+
logit_scale,
|
| 42 |
+
ckpt_info,
|
| 43 |
+
) = _load_target_model(
|
| 44 |
+
args.pretrained_model_name_or_path,
|
| 45 |
+
args.vae,
|
| 46 |
+
model_version,
|
| 47 |
+
weight_dtype,
|
| 48 |
+
accelerator.device if args.lowram else "cpu",
|
| 49 |
+
model_dtype,
|
| 50 |
+
args.disable_mmap_load_safetensors,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
# work on low-ram device
|
| 54 |
+
if args.lowram:
|
| 55 |
+
text_encoder1.to(accelerator.device)
|
| 56 |
+
text_encoder2.to(accelerator.device)
|
| 57 |
+
unet.to(accelerator.device)
|
| 58 |
+
vae.to(accelerator.device)
|
| 59 |
+
|
| 60 |
+
clean_memory_on_device(accelerator.device)
|
| 61 |
+
accelerator.wait_for_everyone()
|
| 62 |
+
|
| 63 |
+
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _load_target_model(
|
| 67 |
+
name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None, disable_mmap=False
|
| 68 |
+
):
|
| 69 |
+
# model_dtype only works with full fp16/bf16
|
| 70 |
+
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
| 71 |
+
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
| 72 |
+
|
| 73 |
+
if load_stable_diffusion_format:
|
| 74 |
+
logger.info(f"load StableDiffusion checkpoint: {name_or_path}")
|
| 75 |
+
(
|
| 76 |
+
text_encoder1,
|
| 77 |
+
text_encoder2,
|
| 78 |
+
vae,
|
| 79 |
+
unet,
|
| 80 |
+
logit_scale,
|
| 81 |
+
ckpt_info,
|
| 82 |
+
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype, disable_mmap)
|
| 83 |
+
else:
|
| 84 |
+
# Diffusers model is loaded to CPU
|
| 85 |
+
from diffusers import StableDiffusionXLPipeline
|
| 86 |
+
|
| 87 |
+
variant = "fp16" if weight_dtype == torch.float16 else None
|
| 88 |
+
logger.info(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
|
| 89 |
+
try:
|
| 90 |
+
try:
|
| 91 |
+
pipe = StableDiffusionXLPipeline.from_pretrained(
|
| 92 |
+
name_or_path, torch_dtype=model_dtype, variant=variant, tokenizer=None
|
| 93 |
+
)
|
| 94 |
+
except EnvironmentError as ex:
|
| 95 |
+
if variant is not None:
|
| 96 |
+
logger.info("try to load fp32 model")
|
| 97 |
+
pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=None, tokenizer=None)
|
| 98 |
+
else:
|
| 99 |
+
raise ex
|
| 100 |
+
except EnvironmentError as ex:
|
| 101 |
+
logger.error(
|
| 102 |
+
f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}"
|
| 103 |
+
)
|
| 104 |
+
raise ex
|
| 105 |
+
|
| 106 |
+
text_encoder1 = pipe.text_encoder
|
| 107 |
+
text_encoder2 = pipe.text_encoder_2
|
| 108 |
+
|
| 109 |
+
# convert to fp32 for caching text encoder outputs
|
| 110 |
+
if text_encoder1.dtype != torch.float32:
|
| 111 |
+
text_encoder1 = text_encoder1.to(dtype=torch.float32)
|
| 112 |
+
if text_encoder2.dtype != torch.float32:
|
| 113 |
+
text_encoder2 = text_encoder2.to(dtype=torch.float32)
|
| 114 |
+
|
| 115 |
+
vae = pipe.vae
|
| 116 |
+
unet = pipe.unet
|
| 117 |
+
del pipe
|
| 118 |
+
|
| 119 |
+
# Diffusers U-Net to original U-Net
|
| 120 |
+
state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
|
| 121 |
+
with init_empty_weights():
|
| 122 |
+
unet = sdxl_original_unet.SdxlUNet2DConditionModel() # overwrite unet
|
| 123 |
+
sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device, dtype=model_dtype)
|
| 124 |
+
logger.info("U-Net converted to original U-Net")
|
| 125 |
+
|
| 126 |
+
logit_scale = None
|
| 127 |
+
ckpt_info = None
|
| 128 |
+
|
| 129 |
+
# load the additional VAE
|
| 130 |
+
if vae_path is not None:
|
| 131 |
+
vae = model_util.load_vae(vae_path, weight_dtype)
|
| 132 |
+
logger.info("additional VAE loaded")
|
| 133 |
+
|
| 134 |
+
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def load_tokenizers(args: argparse.Namespace):
|
| 138 |
+
logger.info("prepare tokenizers")
|
| 139 |
+
|
| 140 |
+
original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH]
|
| 141 |
+
tokeniers = []
|
| 142 |
+
for i, original_path in enumerate(original_paths):
|
| 143 |
+
tokenizer: CLIPTokenizer = None
|
| 144 |
+
if args.tokenizer_cache_dir:
|
| 145 |
+
local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
|
| 146 |
+
if os.path.exists(local_tokenizer_path):
|
| 147 |
+
logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
|
| 148 |
+
tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)
|
| 149 |
+
|
| 150 |
+
if tokenizer is None:
|
| 151 |
+
tokenizer = CLIPTokenizer.from_pretrained(original_path)
|
| 152 |
+
|
| 153 |
+
if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
|
| 154 |
+
logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
|
| 155 |
+
tokenizer.save_pretrained(local_tokenizer_path)
|
| 156 |
+
|
| 157 |
+
if i == 1:
|
| 158 |
+
tokenizer.pad_token_id = 0 # fix pad token id to make same as open clip tokenizer
|
| 159 |
+
|
| 160 |
+
tokeniers.append(tokenizer)
|
| 161 |
+
|
| 162 |
+
if hasattr(args, "max_token_length") and args.max_token_length is not None:
|
| 163 |
+
logger.info(f"update token length: {args.max_token_length}")
|
| 164 |
+
|
| 165 |
+
return tokeniers
|
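A minimal sketch of calling the tokenizer helper above (hypothetical args namespace; only the attributes the function reads are set, and both tokenizers are downloaded from Hugging Face on first use):

# illustrative sketch, not part of sdxl_train_util.py
import argparse
from library import sdxl_train_util

args = argparse.Namespace(tokenizer_cache_dir=None, max_token_length=None)
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
print(tokenizer2.pad_token_id)  # 0, aligned with the open clip tokenizer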
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def match_mixed_precision(args, weight_dtype):
|
| 169 |
+
if args.full_fp16:
|
| 170 |
+
assert (
|
| 171 |
+
weight_dtype == torch.float16
|
| 172 |
+
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
| 173 |
+
return weight_dtype
|
| 174 |
+
elif args.full_bf16:
|
| 175 |
+
assert (
|
| 176 |
+
weight_dtype == torch.bfloat16
|
| 177 |
+
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
|
| 178 |
+
return weight_dtype
|
| 179 |
+
else:
|
| 180 |
+
return None
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def timestep_embedding(timesteps, dim, max_period=10000):
|
| 184 |
+
"""
|
| 185 |
+
Create sinusoidal timestep embeddings.
|
| 186 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
| 187 |
+
These may be fractional.
|
| 188 |
+
:param dim: the dimension of the output.
|
| 189 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
| 190 |
+
:return: an [N x dim] Tensor of positional embeddings.
|
| 191 |
+
"""
|
| 192 |
+
half = dim // 2
|
| 193 |
+
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
|
| 194 |
+
device=timesteps.device
|
| 195 |
+
)
|
| 196 |
+
args = timesteps[:, None].float() * freqs[None]
|
| 197 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 198 |
+
if dim % 2:
|
| 199 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 200 |
+
return embedding
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def get_timestep_embedding(x, outdim):
|
| 204 |
+
assert len(x.shape) == 2
|
| 205 |
+
b, dims = x.shape[0], x.shape[1]
|
| 206 |
+
x = torch.flatten(x)
|
| 207 |
+
emb = timestep_embedding(x, outdim)
|
| 208 |
+
emb = torch.reshape(emb, (b, dims * outdim))
|
| 209 |
+
return emb
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def get_size_embeddings(orig_size, crop_size, target_size, device):
|
| 213 |
+
emb1 = get_timestep_embedding(orig_size, 256)
|
| 214 |
+
emb2 = get_timestep_embedding(crop_size, 256)
|
| 215 |
+
emb3 = get_timestep_embedding(target_size, 256)
|
| 216 |
+
vector = torch.cat([emb1, emb2, emb3], dim=1).to(device)
|
| 217 |
+
return vector
|
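For reference, a small worked example (a sketch, not part of the file) of the shapes produced by the helpers above: each (height, width) pair is flattened, embedded with 256-dimensional sinusoidal embeddings, and reshaped, so the three concatenated size embeddings give a 3 * 2 * 256 = 1536-dimensional vector per sample.

# illustrative shape check, not part of sdxl_train_util.py
import torch
from library.sdxl_train_util import get_size_embeddings

orig_size = torch.tensor([[1024, 1024]])    # (B, 2): original height/width
crop_size = torch.tensor([[0, 0]])          # (B, 2): crop top/left
target_size = torch.tensor([[1024, 1024]])  # (B, 2): target height/width

emb = get_size_embeddings(orig_size, crop_size, target_size, "cpu")
print(emb.shape)  # torch.Size([1, 1536])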
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def save_sd_model_on_train_end(
|
| 221 |
+
args: argparse.Namespace,
|
| 222 |
+
src_path: str,
|
| 223 |
+
save_stable_diffusion_format: bool,
|
| 224 |
+
use_safetensors: bool,
|
| 225 |
+
save_dtype: torch.dtype,
|
| 226 |
+
epoch: int,
|
| 227 |
+
global_step: int,
|
| 228 |
+
text_encoder1,
|
| 229 |
+
text_encoder2,
|
| 230 |
+
unet,
|
| 231 |
+
vae,
|
| 232 |
+
logit_scale,
|
| 233 |
+
ckpt_info,
|
| 234 |
+
):
|
| 235 |
+
def sd_saver(ckpt_file, epoch_no, global_step):
|
| 236 |
+
sai_metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True)
|
| 237 |
+
sdxl_model_util.save_stable_diffusion_checkpoint(
|
| 238 |
+
ckpt_file,
|
| 239 |
+
text_encoder1,
|
| 240 |
+
text_encoder2,
|
| 241 |
+
unet,
|
| 242 |
+
epoch_no,
|
| 243 |
+
global_step,
|
| 244 |
+
ckpt_info,
|
| 245 |
+
vae,
|
| 246 |
+
logit_scale,
|
| 247 |
+
sai_metadata,
|
| 248 |
+
save_dtype,
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
def diffusers_saver(out_dir):
|
| 252 |
+
sdxl_model_util.save_diffusers_checkpoint(
|
| 253 |
+
out_dir,
|
| 254 |
+
text_encoder1,
|
| 255 |
+
text_encoder2,
|
| 256 |
+
unet,
|
| 257 |
+
src_path,
|
| 258 |
+
vae,
|
| 259 |
+
use_safetensors=use_safetensors,
|
| 260 |
+
save_dtype=save_dtype,
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
train_util.save_sd_model_on_train_end_common(
|
| 264 |
+
args, save_stable_diffusion_format, use_safetensors, epoch, global_step, sd_saver, diffusers_saver
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
# Saving at epoch end and at step intervals share the same arguments (the metadata includes epoch/step), so they are unified here.
|
| 269 |
+
# on_epoch_end: True when saving at the end of an epoch, False when saving after a number of steps
|
| 270 |
+
def save_sd_model_on_epoch_end_or_stepwise(
|
| 271 |
+
args: argparse.Namespace,
|
| 272 |
+
on_epoch_end: bool,
|
| 273 |
+
accelerator,
|
| 274 |
+
src_path,
|
| 275 |
+
save_stable_diffusion_format: bool,
|
| 276 |
+
use_safetensors: bool,
|
| 277 |
+
save_dtype: torch.dtype,
|
| 278 |
+
epoch: int,
|
| 279 |
+
num_train_epochs: int,
|
| 280 |
+
global_step: int,
|
| 281 |
+
text_encoder1,
|
| 282 |
+
text_encoder2,
|
| 283 |
+
unet,
|
| 284 |
+
vae,
|
| 285 |
+
logit_scale,
|
| 286 |
+
ckpt_info,
|
| 287 |
+
):
|
| 288 |
+
def sd_saver(ckpt_file, epoch_no, global_step):
|
| 289 |
+
sai_metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True)
|
| 290 |
+
sdxl_model_util.save_stable_diffusion_checkpoint(
|
| 291 |
+
ckpt_file,
|
| 292 |
+
text_encoder1,
|
| 293 |
+
text_encoder2,
|
| 294 |
+
unet,
|
| 295 |
+
epoch_no,
|
| 296 |
+
global_step,
|
| 297 |
+
ckpt_info,
|
| 298 |
+
vae,
|
| 299 |
+
logit_scale,
|
| 300 |
+
sai_metadata,
|
| 301 |
+
save_dtype,
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
def diffusers_saver(out_dir):
|
| 305 |
+
sdxl_model_util.save_diffusers_checkpoint(
|
| 306 |
+
out_dir,
|
| 307 |
+
text_encoder1,
|
| 308 |
+
text_encoder2,
|
| 309 |
+
unet,
|
| 310 |
+
src_path,
|
| 311 |
+
vae,
|
| 312 |
+
use_safetensors=use_safetensors,
|
| 313 |
+
save_dtype=save_dtype,
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
|
| 317 |
+
args,
|
| 318 |
+
on_epoch_end,
|
| 319 |
+
accelerator,
|
| 320 |
+
save_stable_diffusion_format,
|
| 321 |
+
use_safetensors,
|
| 322 |
+
epoch,
|
| 323 |
+
num_train_epochs,
|
| 324 |
+
global_step,
|
| 325 |
+
sd_saver,
|
| 326 |
+
diffusers_saver,
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
|
| 331 |
+
parser.add_argument(
|
| 332 |
+
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
|
| 333 |
+
)
|
| 334 |
+
parser.add_argument(
|
| 335 |
+
"--cache_text_encoder_outputs_to_disk",
|
| 336 |
+
action="store_true",
|
| 337 |
+
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
|
| 338 |
+
)
|
| 339 |
+
parser.add_argument(
|
| 340 |
+
"--disable_mmap_load_safetensors",
|
| 341 |
+
action="store_true",
|
| 342 |
+
help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる",
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
|
| 347 |
+
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
|
| 348 |
+
if args.v_parameterization:
|
| 349 |
+
logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
|
| 350 |
+
|
| 351 |
+
if args.clip_skip is not None:
|
| 352 |
+
logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
|
| 353 |
+
|
| 354 |
+
# if args.multires_noise_iterations:
|
| 355 |
+
# logger.info(
|
| 356 |
+
# f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります"
|
| 357 |
+
# )
|
| 358 |
+
# else:
|
| 359 |
+
# if args.noise_offset is None:
|
| 360 |
+
# args.noise_offset = DEFAULT_NOISE_OFFSET
|
| 361 |
+
# elif args.noise_offset != DEFAULT_NOISE_OFFSET:
|
| 362 |
+
# logger.info(
|
| 363 |
+
# f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています"
|
| 364 |
+
# )
|
| 365 |
+
# logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
|
| 366 |
+
|
| 367 |
+
assert (
|
| 368 |
+
not hasattr(args, "weighted_captions") or not args.weighted_captions
|
| 369 |
+
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
|
| 370 |
+
|
| 371 |
+
if supportTextEncoderCaching:
|
| 372 |
+
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
| 373 |
+
args.cache_text_encoder_outputs = True
|
| 374 |
+
logger.warning(
|
| 375 |
+
"cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
|
| 376 |
+
+ "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def sample_images(*args, **kwargs):
|
| 381 |
+
return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
|
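As a rough sketch (assuming a minimal parser that only defines the flags verify_sdxl_training_args reads directly; real training scripts add these via train_util's argument helpers), this is how the two helpers above are typically wired together:

# illustrative sketch, not part of sdxl_train_util.py
import argparse
from library import sdxl_train_util

parser = argparse.ArgumentParser()
parser.add_argument("--v2", action="store_true")                  # assumed present; normally added by train_util
parser.add_argument("--v_parameterization", action="store_true")
parser.add_argument("--clip_skip", type=int, default=None)
sdxl_train_util.add_sdxl_training_arguments(parser)               # adds the SDXL-specific caching flags

args = parser.parse_args([])
sdxl_train_util.verify_sdxl_training_args(args)                   # asserts not v2, warns on v_parameterization / clip_skip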
library/slicing_vae.py
ADDED
|
@@ -0,0 +1,682 @@
|
| 1 |
+
# Modified from Diffusers to reduce VRAM usage
|
| 2 |
+
|
| 3 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
from typing import Optional, Tuple, Union
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 25 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 26 |
+
from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
|
| 27 |
+
from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution
|
| 28 |
+
from diffusers.models.autoencoder_kl import AutoencoderKLOutput
|
| 29 |
+
from .utils import setup_logging
|
| 30 |
+
setup_logging()
|
| 31 |
+
import logging
|
| 32 |
+
logger = logging.getLogger(__name__)
|
| 33 |
+
|
| 34 |
+
def slice_h(x, num_slices):
|
| 35 |
+
# slice with pad 1 both sides: to eliminate side effect of padding of conv2d
|
| 36 |
+
# Conv2dのpaddingの副作用を排除するために、両側にpad 1しながらHをスライスする
|
| 37 |
+
# works with either NCHW or NHWC layouts
|
| 38 |
+
size = (x.shape[2] + num_slices - 1) // num_slices
|
| 39 |
+
sliced = []
|
| 40 |
+
for i in range(num_slices):
|
| 41 |
+
if i == 0:
|
| 42 |
+
sliced.append(x[:, :, : size + 1, :])
|
| 43 |
+
else:
|
| 44 |
+
end = size * (i + 1) + 1
|
| 45 |
+
if x.shape[2] - end < 3: # if the last slice is too thin for conv2d, use the rest of the tensor
|
| 46 |
+
end = x.shape[2]
|
| 47 |
+
sliced.append(x[:, :, size * i - 1 : end, :])
|
| 48 |
+
if end >= x.shape[2]:
|
| 49 |
+
break
|
| 50 |
+
return sliced
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def cat_h(sliced):
|
| 54 |
+
# concatenate the slices, dropping the overlap rows added for padding
|
| 55 |
+
cat = []
|
| 56 |
+
for i, x in enumerate(sliced):
|
| 57 |
+
if i == 0:
|
| 58 |
+
cat.append(x[:, :, :-1, :])
|
| 59 |
+
elif i == len(sliced) - 1:
|
| 60 |
+
cat.append(x[:, :, 1:, :])
|
| 61 |
+
else:
|
| 62 |
+
cat.append(x[:, :, 1:-1, :])
|
| 63 |
+
del x
|
| 64 |
+
x = torch.cat(cat, dim=2)
|
| 65 |
+
return x
|
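To make the slicing convention concrete, here is a small round-trip sketch (not part of the file): slice_h keeps a one-row overlap at each inner boundary so the later Conv2d padding has valid neighbours, and cat_h trims those overlap rows again, so slicing followed by concatenation (with no convolution in between) reproduces the original tensor.

# illustrative round-trip check, not part of slicing_vae.py
import torch
from library.slicing_vae import cat_h, slice_h

x = torch.arange(2 * 3 * 16 * 5, dtype=torch.float32).reshape(2, 3, 16, 5)  # NCHW
slices = slice_h(x, num_slices=4)   # inner slices carry one extra row on each side
restored = cat_h(slices)            # the overlap rows are trimmed when concatenating
print(torch.equal(restored, x))     # True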
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def resblock_forward(_self, num_slices, input_tensor, temb, **kwargs):
|
| 69 |
+
assert _self.upsample is None and _self.downsample is None
|
| 70 |
+
assert _self.norm1.num_groups == _self.norm2.num_groups
|
| 71 |
+
assert temb is None
|
| 72 |
+
|
| 73 |
+
# make sure norms are on cpu
|
| 74 |
+
org_device = input_tensor.device
|
| 75 |
+
cpu_device = torch.device("cpu")
|
| 76 |
+
_self.norm1.to(cpu_device)
|
| 77 |
+
_self.norm2.to(cpu_device)
|
| 78 |
+
|
| 79 |
+
# workaround: GroupNorm does not run in fp16 on CPU
|
| 80 |
+
org_dtype = input_tensor.dtype
|
| 81 |
+
if org_dtype == torch.float16:
|
| 82 |
+
_self.norm1.to(torch.float32)
|
| 83 |
+
_self.norm2.to(torch.float32)
|
| 84 |
+
|
| 85 |
+
# move all tensors to the CPU
|
| 86 |
+
input_tensor = input_tensor.to(cpu_device)
|
| 87 |
+
hidden_states = input_tensor
|
| 88 |
+
|
| 89 |
+
# this seems to produce different results, so it is left disabled
|
| 90 |
+
# def sliced_norm1(norm, x):
|
| 91 |
+
# num_div = 4 if up_block_idx <= 2 else x.shape[1] // norm.num_groups
|
| 92 |
+
# sliced_tensor = torch.chunk(x, num_div, dim=1)
|
| 93 |
+
# sliced_weight = torch.chunk(norm.weight, num_div, dim=0)
|
| 94 |
+
# sliced_bias = torch.chunk(norm.bias, num_div, dim=0)
|
| 95 |
+
# logger.info(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape)
|
| 96 |
+
# normed_tensor = []
|
| 97 |
+
# for i in range(num_div):
|
| 98 |
+
# n = torch.group_norm(sliced_tensor[i], norm.num_groups, sliced_weight[i], sliced_bias[i], norm.eps)
|
| 99 |
+
# normed_tensor.append(n)
|
| 100 |
+
# del n
|
| 101 |
+
# x = torch.cat(normed_tensor, dim=1)
|
| 102 |
+
# return num_div, x
|
| 103 |
+
|
| 104 |
+
# Splitting the norm changes the result, so only the norm is left unsliced. Running it on GPU runs out of VRAM, so it is computed on CPU; fortunately it is not too slow there.
|
| 105 |
+
if org_dtype == torch.float16:
|
| 106 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 107 |
+
hidden_states = _self.norm1(hidden_states) # run on cpu
|
| 108 |
+
if org_dtype == torch.float16:
|
| 109 |
+
hidden_states = hidden_states.to(torch.float16)
|
| 110 |
+
|
| 111 |
+
sliced = slice_h(hidden_states, num_slices)
|
| 112 |
+
del hidden_states
|
| 113 |
+
|
| 114 |
+
for i in range(len(sliced)):
|
| 115 |
+
x = sliced[i]
|
| 116 |
+
sliced[i] = None
|
| 117 |
+
|
| 118 |
+
# move only the slice being processed to the GPU; the same pattern is used below
|
| 119 |
+
x = x.to(org_device)
|
| 120 |
+
x = _self.nonlinearity(x)
|
| 121 |
+
x = _self.conv1(x)
|
| 122 |
+
x = x.to(cpu_device)
|
| 123 |
+
sliced[i] = x
|
| 124 |
+
del x
|
| 125 |
+
|
| 126 |
+
hidden_states = cat_h(sliced)
|
| 127 |
+
del sliced
|
| 128 |
+
|
| 129 |
+
if org_dtype == torch.float16:
|
| 130 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 131 |
+
hidden_states = _self.norm2(hidden_states) # run on cpu
|
| 132 |
+
if org_dtype == torch.float16:
|
| 133 |
+
hidden_states = hidden_states.to(torch.float16)
|
| 134 |
+
|
| 135 |
+
sliced = slice_h(hidden_states, num_slices)
|
| 136 |
+
del hidden_states
|
| 137 |
+
|
| 138 |
+
for i in range(len(sliced)):
|
| 139 |
+
x = sliced[i]
|
| 140 |
+
sliced[i] = None
|
| 141 |
+
|
| 142 |
+
x = x.to(org_device)
|
| 143 |
+
x = _self.nonlinearity(x)
|
| 144 |
+
x = _self.dropout(x)
|
| 145 |
+
x = _self.conv2(x)
|
| 146 |
+
x = x.to(cpu_device)
|
| 147 |
+
sliced[i] = x
|
| 148 |
+
del x
|
| 149 |
+
|
| 150 |
+
hidden_states = cat_h(sliced)
|
| 151 |
+
del sliced
|
| 152 |
+
|
| 153 |
+
# make shortcut
|
| 154 |
+
if _self.conv_shortcut is not None:
|
| 155 |
+
sliced = list(torch.chunk(input_tensor, num_slices, dim=2)) # no padding in conv_shortcut パディングがないので普通にスライスする
|
| 156 |
+
del input_tensor
|
| 157 |
+
|
| 158 |
+
for i in range(len(sliced)):
|
| 159 |
+
x = sliced[i]
|
| 160 |
+
sliced[i] = None
|
| 161 |
+
|
| 162 |
+
x = x.to(org_device)
|
| 163 |
+
x = _self.conv_shortcut(x)
|
| 164 |
+
x = x.to(cpu_device)
|
| 165 |
+
sliced[i] = x
|
| 166 |
+
del x
|
| 167 |
+
|
| 168 |
+
input_tensor = torch.cat(sliced, dim=2)
|
| 169 |
+
del sliced
|
| 170 |
+
|
| 171 |
+
output_tensor = (input_tensor + hidden_states) / _self.output_scale_factor
|
| 172 |
+
|
| 173 |
+
output_tensor = output_tensor.to(org_device) # the next layer computes on the GPU
|
| 174 |
+
return output_tensor
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class SlicingEncoder(nn.Module):
|
| 178 |
+
def __init__(
|
| 179 |
+
self,
|
| 180 |
+
in_channels=3,
|
| 181 |
+
out_channels=3,
|
| 182 |
+
down_block_types=("DownEncoderBlock2D",),
|
| 183 |
+
block_out_channels=(64,),
|
| 184 |
+
layers_per_block=2,
|
| 185 |
+
norm_num_groups=32,
|
| 186 |
+
act_fn="silu",
|
| 187 |
+
double_z=True,
|
| 188 |
+
num_slices=2,
|
| 189 |
+
):
|
| 190 |
+
super().__init__()
|
| 191 |
+
self.layers_per_block = layers_per_block
|
| 192 |
+
|
| 193 |
+
self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
|
| 194 |
+
|
| 195 |
+
self.mid_block = None
|
| 196 |
+
self.down_blocks = nn.ModuleList([])
|
| 197 |
+
|
| 198 |
+
# down
|
| 199 |
+
output_channel = block_out_channels[0]
|
| 200 |
+
for i, down_block_type in enumerate(down_block_types):
|
| 201 |
+
input_channel = output_channel
|
| 202 |
+
output_channel = block_out_channels[i]
|
| 203 |
+
is_final_block = i == len(block_out_channels) - 1
|
| 204 |
+
|
| 205 |
+
down_block = get_down_block(
|
| 206 |
+
down_block_type,
|
| 207 |
+
num_layers=self.layers_per_block,
|
| 208 |
+
in_channels=input_channel,
|
| 209 |
+
out_channels=output_channel,
|
| 210 |
+
add_downsample=not is_final_block,
|
| 211 |
+
resnet_eps=1e-6,
|
| 212 |
+
downsample_padding=0,
|
| 213 |
+
resnet_act_fn=act_fn,
|
| 214 |
+
resnet_groups=norm_num_groups,
|
| 215 |
+
attention_head_dim=output_channel,
|
| 216 |
+
temb_channels=None,
|
| 217 |
+
)
|
| 218 |
+
self.down_blocks.append(down_block)
|
| 219 |
+
|
| 220 |
+
# mid
|
| 221 |
+
self.mid_block = UNetMidBlock2D(
|
| 222 |
+
in_channels=block_out_channels[-1],
|
| 223 |
+
resnet_eps=1e-6,
|
| 224 |
+
resnet_act_fn=act_fn,
|
| 225 |
+
output_scale_factor=1,
|
| 226 |
+
resnet_time_scale_shift="default",
|
| 227 |
+
attention_head_dim=block_out_channels[-1],
|
| 228 |
+
resnet_groups=norm_num_groups,
|
| 229 |
+
temb_channels=None,
|
| 230 |
+
)
|
| 231 |
+
self.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) # use Diffusers' xformers attention for now
|
| 232 |
+
|
| 233 |
+
# out
|
| 234 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
|
| 235 |
+
self.conv_act = nn.SiLU()
|
| 236 |
+
|
| 237 |
+
conv_out_channels = 2 * out_channels if double_z else out_channels
|
| 238 |
+
self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
|
| 239 |
+
|
| 240 |
+
# replace forward of ResBlocks
|
| 241 |
+
def wrapper(func, module, num_slices):
|
| 242 |
+
def forward(*args, **kwargs):
|
| 243 |
+
return func(module, num_slices, *args, **kwargs)
|
| 244 |
+
|
| 245 |
+
return forward
|
| 246 |
+
|
| 247 |
+
self.num_slices = num_slices
|
| 248 |
+
div = num_slices / (2 ** (len(self.down_blocks) - 1)) # deeper blocks need fewer slices, so reduce the divisor accordingly
|
| 249 |
+
# logger.info(f"initial divisor: {div}")
|
| 250 |
+
if div >= 2:
|
| 251 |
+
div = int(div)
|
| 252 |
+
for resnet in self.mid_block.resnets:
|
| 253 |
+
resnet.forward = wrapper(resblock_forward, resnet, div)
|
| 254 |
+
# midblock doesn't have downsample
|
| 255 |
+
|
| 256 |
+
for i, down_block in enumerate(self.down_blocks[::-1]):
|
| 257 |
+
if div >= 2:
|
| 258 |
+
div = int(div)
|
| 259 |
+
# logger.info(f"down block: {i} divisor: {div}")
|
| 260 |
+
for resnet in down_block.resnets:
|
| 261 |
+
resnet.forward = wrapper(resblock_forward, resnet, div)
|
| 262 |
+
if down_block.downsamplers is not None:
|
| 263 |
+
# logger.info("has downsample")
|
| 264 |
+
for downsample in down_block.downsamplers:
|
| 265 |
+
downsample.forward = wrapper(self.downsample_forward, downsample, div * 2)
|
| 266 |
+
div *= 2
|
| 267 |
+
|
| 268 |
+
def forward(self, x):
|
| 269 |
+
sample = x
|
| 270 |
+
del x
|
| 271 |
+
|
| 272 |
+
org_device = sample.device
|
| 273 |
+
cpu_device = torch.device("cpu")
|
| 274 |
+
|
| 275 |
+
# sample = self.conv_in(sample)
|
| 276 |
+
sample = sample.to(cpu_device)
|
| 277 |
+
sliced = slice_h(sample, self.num_slices)
|
| 278 |
+
del sample
|
| 279 |
+
|
| 280 |
+
for i in range(len(sliced)):
|
| 281 |
+
x = sliced[i]
|
| 282 |
+
sliced[i] = None
|
| 283 |
+
|
| 284 |
+
x = x.to(org_device)
|
| 285 |
+
x = self.conv_in(x)
|
| 286 |
+
x = x.to(cpu_device)
|
| 287 |
+
sliced[i] = x
|
| 288 |
+
del x
|
| 289 |
+
|
| 290 |
+
sample = cat_h(sliced)
|
| 291 |
+
del sliced
|
| 292 |
+
|
| 293 |
+
sample = sample.to(org_device)
|
| 294 |
+
|
| 295 |
+
# down
|
| 296 |
+
for down_block in self.down_blocks:
|
| 297 |
+
sample = down_block(sample)
|
| 298 |
+
|
| 299 |
+
# middle
|
| 300 |
+
sample = self.mid_block(sample)
|
| 301 |
+
|
| 302 |
+
# post-process
|
| 303 |
+
# this could also be made memory-efficient, but it probably does not use much memory, so it is left as is
|
| 304 |
+
sample = self.conv_norm_out(sample)
|
| 305 |
+
sample = self.conv_act(sample)
|
| 306 |
+
sample = self.conv_out(sample)
|
| 307 |
+
|
| 308 |
+
return sample
|
| 309 |
+
|
| 310 |
+
def downsample_forward(self, _self, num_slices, hidden_states):
|
| 311 |
+
assert hidden_states.shape[1] == _self.channels
|
| 312 |
+
assert _self.use_conv and _self.padding == 0
|
| 313 |
+
logger.info(f"downsample forward {num_slices} {hidden_states.shape}")
|
| 314 |
+
|
| 315 |
+
org_device = hidden_states.device
|
| 316 |
+
cpu_device = torch.device("cpu")
|
| 317 |
+
|
| 318 |
+
hidden_states = hidden_states.to(cpu_device)
|
| 319 |
+
pad = (0, 1, 0, 1)
|
| 320 |
+
hidden_states = torch.nn.functional.pad(hidden_states, pad, mode="constant", value=0)
|
| 321 |
+
|
| 322 |
+
# slice with even number because of stride 2
|
| 323 |
+
# strideが2なので偶数でスライスする
|
| 324 |
+
# slice with pad 1 both sides: to eliminate side effect of padding of conv2d
|
| 325 |
+
size = (hidden_states.shape[2] + num_slices - 1) // num_slices
|
| 326 |
+
size = size + 1 if size % 2 == 1 else size
|
| 327 |
+
|
| 328 |
+
sliced = []
|
| 329 |
+
for i in range(num_slices):
|
| 330 |
+
if i == 0:
|
| 331 |
+
sliced.append(hidden_states[:, :, : size + 1, :])
|
| 332 |
+
else:
|
| 333 |
+
end = size * (i + 1) + 1
|
| 334 |
+
if hidden_states.shape[2] - end < 4: # if the last slice is too small, use the rest of the tensor
|
| 335 |
+
end = hidden_states.shape[2]
|
| 336 |
+
sliced.append(hidden_states[:, :, size * i - 1 : end, :])
|
| 337 |
+
if end >= hidden_states.shape[2]:
|
| 338 |
+
break
|
| 339 |
+
del hidden_states
|
| 340 |
+
|
| 341 |
+
for i in range(len(sliced)):
|
| 342 |
+
x = sliced[i]
|
| 343 |
+
sliced[i] = None
|
| 344 |
+
|
| 345 |
+
x = x.to(org_device)
|
| 346 |
+
x = _self.conv(x)
|
| 347 |
+
x = x.to(cpu_device)
|
| 348 |
+
|
| 349 |
+
# this part has a different style because Copilot wrote it
|
| 350 |
+
if i == 0:
|
| 351 |
+
hidden_states = x
|
| 352 |
+
else:
|
| 353 |
+
hidden_states = torch.cat([hidden_states, x], dim=2)
|
| 354 |
+
|
| 355 |
+
hidden_states = hidden_states.to(org_device)
|
| 356 |
+
# logger.info(f"downsample forward done {hidden_states.shape}")
|
| 357 |
+
return hidden_states
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
class SlicingDecoder(nn.Module):
|
| 361 |
+
def __init__(
|
| 362 |
+
self,
|
| 363 |
+
in_channels=3,
|
| 364 |
+
out_channels=3,
|
| 365 |
+
up_block_types=("UpDecoderBlock2D",),
|
| 366 |
+
block_out_channels=(64,),
|
| 367 |
+
layers_per_block=2,
|
| 368 |
+
norm_num_groups=32,
|
| 369 |
+
act_fn="silu",
|
| 370 |
+
num_slices=2,
|
| 371 |
+
):
|
| 372 |
+
super().__init__()
|
| 373 |
+
self.layers_per_block = layers_per_block
|
| 374 |
+
|
| 375 |
+
self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
|
| 376 |
+
|
| 377 |
+
self.mid_block = None
|
| 378 |
+
self.up_blocks = nn.ModuleList([])
|
| 379 |
+
|
| 380 |
+
# mid
|
| 381 |
+
self.mid_block = UNetMidBlock2D(
|
| 382 |
+
in_channels=block_out_channels[-1],
|
| 383 |
+
resnet_eps=1e-6,
|
| 384 |
+
resnet_act_fn=act_fn,
|
| 385 |
+
output_scale_factor=1,
|
| 386 |
+
resnet_time_scale_shift="default",
|
| 387 |
+
attention_head_dim=block_out_channels[-1],
|
| 388 |
+
resnet_groups=norm_num_groups,
|
| 389 |
+
temb_channels=None,
|
| 390 |
+
)
|
| 391 |
+
self.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) # use Diffusers' xformers attention for now
|
| 392 |
+
|
| 393 |
+
# up
|
| 394 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
| 395 |
+
output_channel = reversed_block_out_channels[0]
|
| 396 |
+
for i, up_block_type in enumerate(up_block_types):
|
| 397 |
+
prev_output_channel = output_channel
|
| 398 |
+
output_channel = reversed_block_out_channels[i]
|
| 399 |
+
|
| 400 |
+
is_final_block = i == len(block_out_channels) - 1
|
| 401 |
+
|
| 402 |
+
up_block = get_up_block(
|
| 403 |
+
up_block_type,
|
| 404 |
+
num_layers=self.layers_per_block + 1,
|
| 405 |
+
in_channels=prev_output_channel,
|
| 406 |
+
out_channels=output_channel,
|
| 407 |
+
prev_output_channel=None,
|
| 408 |
+
add_upsample=not is_final_block,
|
| 409 |
+
resnet_eps=1e-6,
|
| 410 |
+
resnet_act_fn=act_fn,
|
| 411 |
+
resnet_groups=norm_num_groups,
|
| 412 |
+
attention_head_dim=output_channel,
|
| 413 |
+
temb_channels=None,
|
| 414 |
+
)
|
| 415 |
+
self.up_blocks.append(up_block)
|
| 416 |
+
prev_output_channel = output_channel
|
| 417 |
+
|
| 418 |
+
# out
|
| 419 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
|
| 420 |
+
self.conv_act = nn.SiLU()
|
| 421 |
+
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
|
| 422 |
+
|
| 423 |
+
# replace forward of ResBlocks
|
| 424 |
+
def wrapper(func, module, num_slices):
|
| 425 |
+
def forward(*args, **kwargs):
|
| 426 |
+
return func(module, num_slices, *args, **kwargs)
|
| 427 |
+
|
| 428 |
+
return forward
|
| 429 |
+
|
| 430 |
+
self.num_slices = num_slices
|
| 431 |
+
div = num_slices / (2 ** (len(self.up_blocks) - 1))
|
| 432 |
+
logger.info(f"initial divisor: {div}")
|
| 433 |
+
if div >= 2:
|
| 434 |
+
div = int(div)
|
| 435 |
+
for resnet in self.mid_block.resnets:
|
| 436 |
+
resnet.forward = wrapper(resblock_forward, resnet, div)
|
| 437 |
+
# midblock doesn't have upsample
|
| 438 |
+
|
| 439 |
+
for i, up_block in enumerate(self.up_blocks):
|
| 440 |
+
if div >= 2:
|
| 441 |
+
div = int(div)
|
| 442 |
+
# logger.info(f"up block: {i} divisor: {div}")
|
| 443 |
+
for resnet in up_block.resnets:
|
| 444 |
+
resnet.forward = wrapper(resblock_forward, resnet, div)
|
| 445 |
+
if up_block.upsamplers is not None:
|
| 446 |
+
# logger.info("has upsample")
|
| 447 |
+
for upsample in up_block.upsamplers:
|
| 448 |
+
upsample.forward = wrapper(self.upsample_forward, upsample, div * 2)
|
| 449 |
+
div *= 2
|
| 450 |
+
|
| 451 |
+
def forward(self, z):
|
| 452 |
+
sample = z
|
| 453 |
+
del z
|
| 454 |
+
sample = self.conv_in(sample)
|
| 455 |
+
|
| 456 |
+
# middle
|
| 457 |
+
sample = self.mid_block(sample)
|
| 458 |
+
|
| 459 |
+
# up
|
| 460 |
+
for i, up_block in enumerate(self.up_blocks):
|
| 461 |
+
sample = up_block(sample)
|
| 462 |
+
|
| 463 |
+
# post-process
|
| 464 |
+
sample = self.conv_norm_out(sample)
|
| 465 |
+
sample = self.conv_act(sample)
|
| 466 |
+
|
| 467 |
+
# conv_out with slicing because of VRAM usage
|
| 468 |
+
# conv_outはとてもVRAM使うのでスライスして対応
|
| 469 |
+
org_device = sample.device
|
| 470 |
+
cpu_device = torch.device("cpu")
|
| 471 |
+
sample = sample.to(cpu_device)
|
| 472 |
+
|
| 473 |
+
sliced = slice_h(sample, self.num_slices)
|
| 474 |
+
del sample
|
| 475 |
+
for i in range(len(sliced)):
|
| 476 |
+
x = sliced[i]
|
| 477 |
+
sliced[i] = None
|
| 478 |
+
|
| 479 |
+
x = x.to(org_device)
|
| 480 |
+
x = self.conv_out(x)
|
| 481 |
+
x = x.to(cpu_device)
|
| 482 |
+
sliced[i] = x
|
| 483 |
+
sample = cat_h(sliced)
|
| 484 |
+
del sliced
|
| 485 |
+
|
| 486 |
+
sample = sample.to(org_device)
|
| 487 |
+
return sample
|
| 488 |
+
|
| 489 |
+
def upsample_forward(self, _self, num_slices, hidden_states, output_size=None):
|
| 490 |
+
assert hidden_states.shape[1] == _self.channels
|
| 491 |
+
assert _self.use_conv_transpose == False and _self.use_conv
|
| 492 |
+
|
| 493 |
+
org_dtype = hidden_states.dtype
|
| 494 |
+
org_device = hidden_states.device
|
| 495 |
+
cpu_device = torch.device("cpu")
|
| 496 |
+
|
| 497 |
+
hidden_states = hidden_states.to(cpu_device)
|
| 498 |
+
sliced = slice_h(hidden_states, num_slices)
|
| 499 |
+
del hidden_states
|
| 500 |
+
|
| 501 |
+
for i in range(len(sliced)):
|
| 502 |
+
x = sliced[i]
|
| 503 |
+
sliced[i] = None
|
| 504 |
+
|
| 505 |
+
x = x.to(org_device)
|
| 506 |
+
|
| 507 |
+
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
| 508 |
+
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
|
| 509 |
+
# https://github.com/pytorch/pytorch/issues/86679
|
| 510 |
+
# maybe PyTorch 2 will fix this...
|
| 511 |
+
if org_dtype == torch.bfloat16:
|
| 512 |
+
x = x.to(torch.float32)
|
| 513 |
+
|
| 514 |
+
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
| 515 |
+
|
| 516 |
+
if org_dtype == torch.bfloat16:
|
| 517 |
+
x = x.to(org_dtype)
|
| 518 |
+
|
| 519 |
+
x = _self.conv(x)
|
| 520 |
+
|
| 521 |
+
# the tensor has been upsampled, so the overlap to trim is 2 rows
|
| 522 |
+
if i == 0:
|
| 523 |
+
x = x[:, :, :-2, :]
|
| 524 |
+
elif i == num_slices - 1:
|
| 525 |
+
x = x[:, :, 2:, :]
|
| 526 |
+
else:
|
| 527 |
+
x = x[:, :, 2:-2, :]
|
| 528 |
+
|
| 529 |
+
x = x.to(cpu_device)
|
| 530 |
+
sliced[i] = x
|
| 531 |
+
del x
|
| 532 |
+
|
| 533 |
+
hidden_states = torch.cat(sliced, dim=2)
|
| 534 |
+
# logger.info(f"us hidden_states {hidden_states.shape}")
|
| 535 |
+
del sliced
|
| 536 |
+
|
| 537 |
+
hidden_states = hidden_states.to(org_device)
|
| 538 |
+
return hidden_states
|
| 539 |
+
|
| 540 |
+
|
class SlicingAutoencoderKL(ModelMixin, ConfigMixin):
    r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma
    and Max Welling.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
    implements for all models (such as downloading or saving).

    Parameters:
        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to :obj:`("DownEncoderBlock2D",)`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to :obj:`("UpDecoderBlock2D",)`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to :obj:`(64,)`):
            Tuple of block output channels.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space.
        sample_size (`int`, *optional*, defaults to `32`): TODO
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        num_slices: int = 16,
    ):
        super().__init__()

        # pass init params to Encoder
        self.encoder = SlicingEncoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            double_z=True,
            num_slices=num_slices,
        )

        # pass init params to Decoder
        self.decoder = SlicingDecoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            num_slices=num_slices,
        )

        self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
        self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
        self.use_slicing = False

    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        z = self.post_quant_conv(z)
        dec = self.decoder(z)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    # note: this is slicing along the batch dimension, which is easy to confuse with the height slicing above
    def enable_slicing(self):
        r"""
        Enable sliced VAE decoding.

        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
        steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.use_slicing = True

    def disable_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
        decoding in one step.
        """
        self.use_slicing = False

    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        if self.use_slicing and z.shape[0] > 1:
            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
            decoded = torch.cat(decoded_slices)
        else:
            decoded = self._decode(z).sample

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)

    def forward(
        self,
        sample: torch.FloatTensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Args:
            sample (`torch.FloatTensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
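For orientation, a sketch of how this class might be exercised end to end. The tiny configuration below is an illustrative assumption (real Stable Diffusion VAEs use a deeper block layout), and only the import path follows this upload's file list:

```python
import torch
from library.slicing_vae import SlicingAutoencoderKL

# hypothetical small config, for illustration only
vae = SlicingAutoencoderKL(
    down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D"),
    up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D"),
    block_out_channels=(32, 64),
    layers_per_block=1,
    latent_channels=4,
    num_slices=2,
)
vae.eval()

with torch.no_grad():
    images = torch.randn(2, 3, 64, 64)                 # NCHW input batch
    latents = vae.encode(images).latent_dist.sample()  # height-sliced encoder keeps peak VRAM low
    vae.enable_slicing()                               # additionally decode one batch element at a time
    recon = vae.decode(latents).sample                 # reconstruction at the original spatial size
```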
library/train_util.py
ADDED
The diff for this file is too large to render. See raw diff.

library/utils.py
ADDED
@@ -0,0 +1,287 @@
import logging
import sys
import threading
import torch
from torchvision import transforms
from typing import *
from diffusers import EulerAncestralDiscreteScheduler
import diffusers.schedulers.scheduling_euler_ancestral_discrete
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput
import cv2
from PIL import Image
import numpy as np


def fire_in_thread(f, *args, **kwargs):
    threading.Thread(target=f, args=args, kwargs=kwargs).start()


def add_logging_arguments(parser):
    parser.add_argument(
        "--console_log_level",
        type=str,
        default=None,
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        help="Set the logging level, default is INFO / ログレベルを設定する。デフォルトはINFO",
    )
    parser.add_argument(
        "--console_log_file",
        type=str,
        default=None,
        help="Log to a file instead of stderr / 標準エラー出力ではなくファイルにログを出力する",
    )
    parser.add_argument("--console_log_simple", action="store_true", help="Simple log output / シンプルなログ出力")

def setup_logging(args=None, log_level=None, reset=False):
    if logging.root.handlers:
        if reset:
            # remove all handlers
            for handler in logging.root.handlers[:]:
                logging.root.removeHandler(handler)
        else:
            return

    # log_level can be set by the caller or by the args, the caller has priority. If not set, use INFO
    if log_level is None and args is not None:
        log_level = args.console_log_level
    if log_level is None:
        log_level = "INFO"
    log_level = getattr(logging, log_level)

    msg_init = None
    if args is not None and args.console_log_file:
        handler = logging.FileHandler(args.console_log_file, mode="w")
    else:
        handler = None
        if not args or not args.console_log_simple:
            try:
                from rich.console import Console
                from rich.logging import RichHandler

                handler = RichHandler(console=Console(stderr=True))
            except ImportError:
                # print("rich is not installed, using basic logging")
                msg_init = "rich is not installed, using basic logging"

        if handler is None:
            handler = logging.StreamHandler(sys.stdout)  # same as print
            handler.propagate = False

    formatter = logging.Formatter(
        fmt="%(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    handler.setFormatter(formatter)
    logging.root.setLevel(log_level)
    logging.root.addHandler(handler)

    if msg_init is not None:
        logger = logging.getLogger(__name__)
        logger.info(msg_init)


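These two helpers are meant to be used together from a script's argument parser; a brief sketch (the parser description and the chosen log level are arbitrary):

```python
import argparse
import logging

from library.utils import add_logging_arguments, setup_logging

parser = argparse.ArgumentParser(description="example script")
add_logging_arguments(parser)
args = parser.parse_args(["--console_log_level", "DEBUG"])

setup_logging(args)  # installs a RichHandler if rich is available, otherwise a plain StreamHandler
logger = logging.getLogger(__name__)
logger.debug("logging configured")
```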
def pil_resize(image, size, interpolation=Image.LANCZOS):
    has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False

    if has_alpha:
        pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA))
    else:
        pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

    resized_pil = pil_image.resize(size, interpolation)

    # Convert back to cv2 format
    if has_alpha:
        resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGBA2BGRA)
    else:
        resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR)

    return resized_cv2


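Note that `pil_resize` takes and returns OpenCV-style BGR(A) numpy arrays but resizes through Pillow (Lanczos by default), and `size` follows PIL's `(width, height)` convention. A quick sketch with a synthetic image:

```python
import numpy as np
from library.utils import pil_resize

bgr = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)  # H x W x C in BGR order
resized = pil_resize(bgr, (320, 240))                                # (width, height), as in PIL
assert resized.shape == (240, 320, 3)
```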
# TODO make inf_utils.py


# region Gradual Latent hires fix


class GradualLatent:
    def __init__(
        self,
        ratio,
        start_timesteps,
        every_n_steps,
        ratio_step,
        s_noise=1.0,
        gaussian_blur_ksize=None,
        gaussian_blur_sigma=0.5,
        gaussian_blur_strength=0.5,
        unsharp_target_x=True,
    ):
        self.ratio = ratio
        self.start_timesteps = start_timesteps
        self.every_n_steps = every_n_steps
        self.ratio_step = ratio_step
        self.s_noise = s_noise
        self.gaussian_blur_ksize = gaussian_blur_ksize
        self.gaussian_blur_sigma = gaussian_blur_sigma
        self.gaussian_blur_strength = gaussian_blur_strength
        self.unsharp_target_x = unsharp_target_x

    def __str__(self) -> str:
        return (
            f"GradualLatent(ratio={self.ratio}, start_timesteps={self.start_timesteps}, "
            + f"every_n_steps={self.every_n_steps}, ratio_step={self.ratio_step}, s_noise={self.s_noise}, "
            + f"gaussian_blur_ksize={self.gaussian_blur_ksize}, gaussian_blur_sigma={self.gaussian_blur_sigma}, gaussian_blur_strength={self.gaussian_blur_strength}, "
            + f"unsharp_target_x={self.unsharp_target_x})"
        )

    def apply_unshark_mask(self, x: torch.Tensor):
        if self.gaussian_blur_ksize is None:
            return x
        blurred = transforms.functional.gaussian_blur(x, self.gaussian_blur_ksize, self.gaussian_blur_sigma)
        # mask = torch.sigmoid((x - blurred) * self.gaussian_blur_strength)
        mask = (x - blurred) * self.gaussian_blur_strength
        sharpened = x + mask
        return sharpened

    def interpolate(self, x: torch.Tensor, resized_size, unsharp=True):
        org_dtype = x.dtype
        if org_dtype == torch.bfloat16:
            x = x.float()

        x = torch.nn.functional.interpolate(x, size=resized_size, mode="bicubic", align_corners=False).to(dtype=org_dtype)

        # apply unsharp mask / アンシャープマスクを適用する
        if unsharp and self.gaussian_blur_ksize:
            x = self.apply_unshark_mask(x)

        return x


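`GradualLatent` only bundles the parameters of the "Gradual Latent" hires fix together with the resize/unsharp helpers; the scheduler subclass below consumes it. A sketch of constructing one (the numbers are illustrative assumptions, not recommended settings):

```python
from library.utils import GradualLatent

# illustrative values only; the fields are consumed by the sampling loop and scheduler, not validated here
gradual_latent = GradualLatent(
    ratio=0.5,
    start_timesteps=1000,
    every_n_steps=3,
    ratio_step=0.125,
    s_noise=1.0,
    gaussian_blur_ksize=3,  # a non-None ksize enables apply_unshark_mask() inside interpolate()
)
print(gradual_latent)       # __str__ lists every setting
```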
class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.resized_size = None
        self.gradual_latent = None

    def set_gradual_latent_params(self, size, gradual_latent: GradualLatent):
        self.resized_size = size
        self.gradual_latent = gradual_latent

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from the learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a
                [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.

        Returns:
            [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`,
                [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
                otherwise a tuple is returned where the first element is the sample tensor.

        """

        if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EulerAncestralDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if not self.is_scale_input_called:
            # logger.warning(
            print(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma * model_output
        elif self.config.prediction_type == "v_prediction":
            # * c_out + input * c_skip
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        elif self.config.prediction_type == "sample":
            raise NotImplementedError("prediction_type not implemented yet: sample")
        else:
            raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`")

        sigma_from = self.sigmas[self.step_index]
        sigma_to = self.sigmas[self.step_index + 1]
        sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
        sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma

        dt = sigma_down - sigma

        device = model_output.device
        if self.resized_size is None:
            prev_sample = sample + derivative * dt

            noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
                model_output.shape, dtype=model_output.dtype, device=device, generator=generator
            )
            s_noise = 1.0
        else:
            print("resized_size", self.resized_size, "model_output.shape", model_output.shape, "sample.shape", sample.shape)
            s_noise = self.gradual_latent.s_noise

            if self.gradual_latent.unsharp_target_x:
                prev_sample = sample + derivative * dt
                prev_sample = self.gradual_latent.interpolate(prev_sample, self.resized_size)
            else:
                sample = self.gradual_latent.interpolate(sample, self.resized_size)
                derivative = self.gradual_latent.interpolate(derivative, self.resized_size, unsharp=False)
                prev_sample = sample + derivative * dt

            noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
                (model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1]),
                dtype=model_output.dtype,
                device=device,
                generator=generator,
            )

        prev_sample = prev_sample + noise * sigma_up * s_noise

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return EulerAncestralDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)


# endregion
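Finally, a sketch of how the subclassed scheduler might be wired up. The constructor arguments are illustrative assumptions, and the per-step resizing is normally driven by the sampling loop, which decides when to call `set_gradual_latent_params`:

```python
from library.utils import EulerAncestralDiscreteSchedulerGL, GradualLatent

# hypothetical wiring; in practice the pipeline owns the sampling loop
scheduler = EulerAncestralDiscreteSchedulerGL(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
scheduler.set_timesteps(30)

gl = GradualLatent(ratio=0.5, start_timesteps=1000, every_n_steps=3, ratio_step=0.125)
scheduler.set_gradual_latent_params((96, 96), gl)  # step() reads this as the (height, width) of the next latent

# with resized_size set, step() interpolates the latent (and draws noise) at 96x96;
# leaving resized_size as None gives the standard Euler-ancestral update
```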