MegaTronX committed on
Commit ef00dae · verified · 1 Parent(s): 084df85

Upload 25 files

library/__init__.py ADDED
File without changes
library/adafactor_fused.py ADDED
@@ -0,0 +1,106 @@
+ import math
+ import torch
+ from transformers import Adafactor
+
+ @torch.no_grad()
+ def adafactor_step_param(self, p, group):
+     if p.grad is None:
+         return
+     grad = p.grad
+     if grad.dtype in {torch.float16, torch.bfloat16}:
+         grad = grad.float()
+     if grad.is_sparse:
+         raise RuntimeError("Adafactor does not support sparse gradients.")
+
+     state = self.state[p]
+     grad_shape = grad.shape
+
+     factored, use_first_moment = Adafactor._get_options(group, grad_shape)
+     # State Initialization
+     if len(state) == 0:
+         state["step"] = 0
+
+         if use_first_moment:
+             # Exponential moving average of gradient values
+             state["exp_avg"] = torch.zeros_like(grad)
+         if factored:
+             state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
+             state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
+         else:
+             state["exp_avg_sq"] = torch.zeros_like(grad)
+
+         state["RMS"] = 0
+     else:
+         if use_first_moment:
+             state["exp_avg"] = state["exp_avg"].to(grad)
+         if factored:
+             state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
+             state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
+         else:
+             state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
+
+     p_data_fp32 = p
+     if p.dtype in {torch.float16, torch.bfloat16}:
+         p_data_fp32 = p_data_fp32.float()
+
+     state["step"] += 1
+     state["RMS"] = Adafactor._rms(p_data_fp32)
+     lr = Adafactor._get_lr(group, state)
+
+     beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
+     update = (grad ** 2) + group["eps"][0]
+     if factored:
+         exp_avg_sq_row = state["exp_avg_sq_row"]
+         exp_avg_sq_col = state["exp_avg_sq_col"]
+
+         exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
+         exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
+
+         # Approximation of exponential moving average of square of gradient
+         update = Adafactor._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
+         update.mul_(grad)
+     else:
+         exp_avg_sq = state["exp_avg_sq"]
+
+         exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
+         update = exp_avg_sq.rsqrt().mul_(grad)
+
+     update.div_((Adafactor._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
+     update.mul_(lr)
+
+     if use_first_moment:
+         exp_avg = state["exp_avg"]
+         exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
+         update = exp_avg
+
+     if group["weight_decay"] != 0:
+         p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
+
+     p_data_fp32.add_(-update)
+
+     if p.dtype in {torch.float16, torch.bfloat16}:
+         p.copy_(p_data_fp32)
+
+
+ @torch.no_grad()
+ def adafactor_step(self, closure=None):
+     """
+     Performs a single optimization step
+
+     Arguments:
+         closure (callable, optional): A closure that reevaluates the model
+             and returns the loss.
+     """
+     loss = None
+     if closure is not None:
+         loss = closure()
+
+     for group in self.param_groups:
+         for p in group["params"]:
+             adafactor_step_param(self, p, group)
+
+     return loss
+
+ def patch_adafactor_fused(optimizer: Adafactor):
+     optimizer.step_param = adafactor_step_param.__get__(optimizer)
+     optimizer.step = adafactor_step.__get__(optimizer)
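
For reference, a minimal usage sketch of the fused Adafactor patch above (not one of the uploaded files): it assumes `transformers` is installed and that this package is importable as `library`. `patch_adafactor_fused` attaches `step_param`, so each parameter can be updated and its gradient freed individually instead of through a single `optimizer.step()` call.

import torch
from transformers import Adafactor
from library.adafactor_fused import patch_adafactor_fused

model = torch.nn.Linear(8, 8)
# lr is set manually, so relative_step/scale_parameter are disabled
optimizer = Adafactor(model.parameters(), lr=1e-4, scale_parameter=False, relative_step=False)
patch_adafactor_fused(optimizer)  # adds optimizer.step_param and replaces optimizer.step

loss = model(torch.randn(4, 8)).pow(2).mean()
loss.backward()
for group in optimizer.param_groups:
    for p in group["params"]:
        optimizer.step_param(p, group)  # update this parameter only
        p.grad = None                   # free its gradient immediately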
library/attention_processors.py ADDED
@@ -0,0 +1,227 @@
+ import math
+ from typing import Any
+ from einops import rearrange
+ import torch
+ from diffusers.models.attention_processor import Attention
+
+
+ # flash attention forwards and backwards
+
+ # https://arxiv.org/abs/2205.14135
+
+ EPSILON = 1e-6
+
+
+ class FlashAttentionFunction(torch.autograd.function.Function):
+     @staticmethod
+     @torch.no_grad()
+     def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
+         """Algorithm 2 in the paper"""
+
+         device = q.device
+         dtype = q.dtype
+         max_neg_value = -torch.finfo(q.dtype).max
+         qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
+
+         o = torch.zeros_like(q)
+         all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
+         all_row_maxes = torch.full(
+             (*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device
+         )
+
+         scale = q.shape[-1] ** -0.5
+
+         if mask is None:
+             mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
+         else:
+             mask = rearrange(mask, "b n -> b 1 1 n")
+             mask = mask.split(q_bucket_size, dim=-1)
+
+         row_splits = zip(
+             q.split(q_bucket_size, dim=-2),
+             o.split(q_bucket_size, dim=-2),
+             mask,
+             all_row_sums.split(q_bucket_size, dim=-2),
+             all_row_maxes.split(q_bucket_size, dim=-2),
+         )
+
+         for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
+             q_start_index = ind * q_bucket_size - qk_len_diff
+
+             col_splits = zip(
+                 k.split(k_bucket_size, dim=-2),
+                 v.split(k_bucket_size, dim=-2),
+             )
+
+             for k_ind, (kc, vc) in enumerate(col_splits):
+                 k_start_index = k_ind * k_bucket_size
+
+                 attn_weights = (
+                     torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
+                 )
+
+                 if row_mask is not None:
+                     attn_weights.masked_fill_(~row_mask, max_neg_value)
+
+                 if causal and q_start_index < (k_start_index + k_bucket_size - 1):
+                     causal_mask = torch.ones(
+                         (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
+                     ).triu(q_start_index - k_start_index + 1)
+                     attn_weights.masked_fill_(causal_mask, max_neg_value)
+
+                 block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
+                 attn_weights -= block_row_maxes
+                 exp_weights = torch.exp(attn_weights)
+
+                 if row_mask is not None:
+                     exp_weights.masked_fill_(~row_mask, 0.0)
+
+                 block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(
+                     min=EPSILON
+                 )
+
+                 new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
+
+                 exp_values = torch.einsum(
+                     "... i j, ... j d -> ... i d", exp_weights, vc
+                 )
+
+                 exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
+                 exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
+
+                 new_row_sums = (
+                     exp_row_max_diff * row_sums
+                     + exp_block_row_max_diff * block_row_sums
+                 )
+
+                 oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_(
+                     (exp_block_row_max_diff / new_row_sums) * exp_values
+                 )
+
+                 row_maxes.copy_(new_row_maxes)
+                 row_sums.copy_(new_row_sums)
+
+         ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
+         ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
+
+         return o
+
+     @staticmethod
+     @torch.no_grad()
+     def backward(ctx, do):
+         """Algorithm 4 in the paper"""
+
+         causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
+         q, k, v, o, l, m = ctx.saved_tensors
+
+         device = q.device
+
+         max_neg_value = -torch.finfo(q.dtype).max
+         qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
+
+         dq = torch.zeros_like(q)
+         dk = torch.zeros_like(k)
+         dv = torch.zeros_like(v)
+
+         row_splits = zip(
+             q.split(q_bucket_size, dim=-2),
+             o.split(q_bucket_size, dim=-2),
+             do.split(q_bucket_size, dim=-2),
+             mask,
+             l.split(q_bucket_size, dim=-2),
+             m.split(q_bucket_size, dim=-2),
+             dq.split(q_bucket_size, dim=-2),
+         )
+
+         for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
+             q_start_index = ind * q_bucket_size - qk_len_diff
+
+             col_splits = zip(
+                 k.split(k_bucket_size, dim=-2),
+                 v.split(k_bucket_size, dim=-2),
+                 dk.split(k_bucket_size, dim=-2),
+                 dv.split(k_bucket_size, dim=-2),
+             )
+
+             for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
+                 k_start_index = k_ind * k_bucket_size
+
+                 attn_weights = (
+                     torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
+                 )
+
+                 if causal and q_start_index < (k_start_index + k_bucket_size - 1):
+                     causal_mask = torch.ones(
+                         (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
+                     ).triu(q_start_index - k_start_index + 1)
+                     attn_weights.masked_fill_(causal_mask, max_neg_value)
+
+                 exp_attn_weights = torch.exp(attn_weights - mc)
+
+                 if row_mask is not None:
+                     exp_attn_weights.masked_fill_(~row_mask, 0.0)
+
+                 p = exp_attn_weights / lc
+
+                 dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
+                 dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)
+
+                 D = (doc * oc).sum(dim=-1, keepdims=True)
+                 ds = p * scale * (dp - D)
+
+                 dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
+                 dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)
+
+                 dqc.add_(dq_chunk)
+                 dkc.add_(dk_chunk)
+                 dvc.add_(dv_chunk)
+
+         return dq, dk, dv, None, None, None, None
+
+
+ class FlashAttnProcessor:
+     def __call__(
+         self,
+         attn: Attention,
+         hidden_states,
+         encoder_hidden_states=None,
+         attention_mask=None,
+     ) -> Any:
+         q_bucket_size = 512
+         k_bucket_size = 1024
+
+         h = attn.heads
+         q = attn.to_q(hidden_states)
+
+         encoder_hidden_states = (
+             encoder_hidden_states
+             if encoder_hidden_states is not None
+             else hidden_states
+         )
+         encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype)
+
+         if hasattr(attn, "hypernetwork") and attn.hypernetwork is not None:
+             context_k, context_v = attn.hypernetwork.forward(
+                 hidden_states, encoder_hidden_states
+             )
+             context_k = context_k.to(hidden_states.dtype)
+             context_v = context_v.to(hidden_states.dtype)
+         else:
+             context_k = encoder_hidden_states
+             context_v = encoder_hidden_states
+
+         k = attn.to_k(context_k)
+         v = attn.to_v(context_v)
+         del encoder_hidden_states, hidden_states
+
+         q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
+
+         out = FlashAttentionFunction.apply(
+             q, k, v, attention_mask, False, q_bucket_size, k_bucket_size
+         )
+
+         out = rearrange(out, "b h n d -> b n (h d)")
+
+         out = attn.to_out[0](out)
+         out = attn.to_out[1](out)
+         return out
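
For reference, a hedged sketch of attaching `FlashAttnProcessor` to a diffusers `Attention` block (not one of the uploaded files): the standalone block and its sizes are illustrative assumptions; in training code the processor would more typically be set on a whole UNet via `set_attn_processor`.

import torch
from diffusers.models.attention_processor import Attention
from library.attention_processors import FlashAttnProcessor

attn = Attention(query_dim=64, heads=4, dim_head=16)  # toy self-attention block
attn.set_processor(FlashAttnProcessor())              # route forward() through the processor above

hidden_states = torch.randn(2, 128, 64)  # (batch, sequence, query_dim)
out = attn(hidden_states)                # chunked attention via FlashAttentionFunction
print(out.shape)                         # torch.Size([2, 128, 64])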
library/config_util.py ADDED
@@ -0,0 +1,721 @@
1
+ import argparse
2
+ from dataclasses import (
3
+ asdict,
4
+ dataclass,
5
+ )
6
+ import functools
7
+ import random
8
+ from textwrap import dedent, indent
9
+ import json
10
+ from pathlib import Path
11
+
12
+ # from toolz import curry
13
+ from typing import (
14
+ List,
15
+ Optional,
16
+ Sequence,
17
+ Tuple,
18
+ Union,
19
+ )
20
+
21
+ import toml
22
+ import voluptuous
23
+ from voluptuous import (
24
+ Any,
25
+ ExactSequence,
26
+ MultipleInvalid,
27
+ Object,
28
+ Required,
29
+ Schema,
30
+ )
31
+ from transformers import CLIPTokenizer
32
+
33
+ from . import train_util
34
+ from .train_util import (
35
+ DreamBoothSubset,
36
+ FineTuningSubset,
37
+ ControlNetSubset,
38
+ DreamBoothDataset,
39
+ FineTuningDataset,
40
+ ControlNetDataset,
41
+ DatasetGroup,
42
+ )
43
+ from .utils import setup_logging
44
+
45
+ setup_logging()
46
+ import logging
47
+
48
+ logger = logging.getLogger(__name__)
49
+
50
+
51
+ def add_config_arguments(parser: argparse.ArgumentParser):
52
+ parser.add_argument(
53
+ "--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル"
54
+ )
55
+
56
+
57
+ # TODO: inherit Params class in Subset, Dataset
58
+
59
+
60
+ @dataclass
61
+ class BaseSubsetParams:
62
+ image_dir: Optional[str] = None
63
+ num_repeats: int = 1
64
+ shuffle_caption: bool = False
65
+ caption_separator: str = (",",)
66
+ keep_tokens: int = 0
67
+ keep_tokens_separator: str = (None,)
68
+ secondary_separator: Optional[str] = None
69
+ enable_wildcard: bool = False
70
+ color_aug: bool = False
71
+ flip_aug: bool = False
72
+ face_crop_aug_range: Optional[Tuple[float, float]] = None
73
+ random_crop: bool = False
74
+ caption_prefix: Optional[str] = None
75
+ caption_suffix: Optional[str] = None
76
+ caption_dropout_rate: float = 0.0
77
+ caption_dropout_every_n_epochs: int = 0
78
+ caption_tag_dropout_rate: float = 0.0
79
+ token_warmup_min: int = 1
80
+ token_warmup_step: float = 0
81
+
82
+
83
+ @dataclass
84
+ class DreamBoothSubsetParams(BaseSubsetParams):
85
+ is_reg: bool = False
86
+ class_tokens: Optional[str] = None
87
+ caption_extension: str = ".caption"
88
+ cache_info: bool = False
89
+ alpha_mask: bool = False
90
+
91
+
92
+ @dataclass
93
+ class FineTuningSubsetParams(BaseSubsetParams):
94
+ metadata_file: Optional[str] = None
95
+ alpha_mask: bool = False
96
+
97
+
98
+ @dataclass
99
+ class ControlNetSubsetParams(BaseSubsetParams):
100
+ conditioning_data_dir: str = None
101
+ caption_extension: str = ".caption"
102
+ cache_info: bool = False
103
+
104
+
105
+ @dataclass
106
+ class BaseDatasetParams:
107
+ tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None
108
+ max_token_length: int = None
109
+ resolution: Optional[Tuple[int, int]] = None
110
+ network_multiplier: float = 1.0
111
+ debug_dataset: bool = False
112
+
113
+
114
+ @dataclass
115
+ class DreamBoothDatasetParams(BaseDatasetParams):
116
+ batch_size: int = 1
117
+ enable_bucket: bool = False
118
+ min_bucket_reso: int = 256
119
+ max_bucket_reso: int = 1024
120
+ bucket_reso_steps: int = 64
121
+ bucket_no_upscale: bool = False
122
+ prior_loss_weight: float = 1.0
123
+
124
+
125
+ @dataclass
126
+ class FineTuningDatasetParams(BaseDatasetParams):
127
+ batch_size: int = 1
128
+ enable_bucket: bool = False
129
+ min_bucket_reso: int = 256
130
+ max_bucket_reso: int = 1024
131
+ bucket_reso_steps: int = 64
132
+ bucket_no_upscale: bool = False
133
+
134
+
135
+ @dataclass
136
+ class ControlNetDatasetParams(BaseDatasetParams):
137
+ batch_size: int = 1
138
+ enable_bucket: bool = False
139
+ min_bucket_reso: int = 256
140
+ max_bucket_reso: int = 1024
141
+ bucket_reso_steps: int = 64
142
+ bucket_no_upscale: bool = False
143
+
144
+
145
+ @dataclass
146
+ class SubsetBlueprint:
147
+ params: Union[DreamBoothSubsetParams, FineTuningSubsetParams]
148
+
149
+
150
+ @dataclass
151
+ class DatasetBlueprint:
152
+ is_dreambooth: bool
153
+ is_controlnet: bool
154
+ params: Union[DreamBoothDatasetParams, FineTuningDatasetParams]
155
+ subsets: Sequence[SubsetBlueprint]
156
+
157
+
158
+ @dataclass
159
+ class DatasetGroupBlueprint:
160
+ datasets: Sequence[DatasetBlueprint]
161
+
162
+
163
+ @dataclass
164
+ class Blueprint:
165
+ dataset_group: DatasetGroupBlueprint
166
+
167
+
168
+ class ConfigSanitizer:
169
+ # @curry
170
+ @staticmethod
171
+ def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
172
+ Schema(ExactSequence([klass, klass]))(value)
173
+ return tuple(value)
174
+
175
+ # @curry
176
+ @staticmethod
177
+ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
178
+ Schema(Any(klass, ExactSequence([klass, klass])))(value)
179
+ try:
180
+ Schema(klass)(value)
181
+ return (value, value)
182
+ except:
183
+ return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
184
+
185
+ # subset schema
186
+ SUBSET_ASCENDABLE_SCHEMA = {
187
+ "color_aug": bool,
188
+ "face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float),
189
+ "flip_aug": bool,
190
+ "num_repeats": int,
191
+ "random_crop": bool,
192
+ "shuffle_caption": bool,
193
+ "keep_tokens": int,
194
+ "keep_tokens_separator": str,
195
+ "secondary_separator": str,
196
+ "caption_separator": str,
197
+ "enable_wildcard": bool,
198
+ "token_warmup_min": int,
199
+ "token_warmup_step": Any(float, int),
200
+ "caption_prefix": str,
201
+ "caption_suffix": str,
202
+ }
203
+ # DO means DropOut
204
+ DO_SUBSET_ASCENDABLE_SCHEMA = {
205
+ "caption_dropout_every_n_epochs": int,
206
+ "caption_dropout_rate": Any(float, int),
207
+ "caption_tag_dropout_rate": Any(float, int),
208
+ }
209
+ # DB means DreamBooth
210
+ DB_SUBSET_ASCENDABLE_SCHEMA = {
211
+ "caption_extension": str,
212
+ "class_tokens": str,
213
+ "cache_info": bool,
214
+ }
215
+ DB_SUBSET_DISTINCT_SCHEMA = {
216
+ Required("image_dir"): str,
217
+ "is_reg": bool,
218
+ "alpha_mask": bool,
219
+ }
220
+ # FT means FineTuning
221
+ FT_SUBSET_DISTINCT_SCHEMA = {
222
+ Required("metadata_file"): str,
223
+ "image_dir": str,
224
+ "alpha_mask": bool,
225
+ }
226
+ CN_SUBSET_ASCENDABLE_SCHEMA = {
227
+ "caption_extension": str,
228
+ "cache_info": bool,
229
+ }
230
+ CN_SUBSET_DISTINCT_SCHEMA = {
231
+ Required("image_dir"): str,
232
+ Required("conditioning_data_dir"): str,
233
+ }
234
+
235
+ # datasets schema
236
+ DATASET_ASCENDABLE_SCHEMA = {
237
+ "batch_size": int,
238
+ "bucket_no_upscale": bool,
239
+ "bucket_reso_steps": int,
240
+ "enable_bucket": bool,
241
+ "max_bucket_reso": int,
242
+ "min_bucket_reso": int,
243
+ "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
244
+ "network_multiplier": float,
245
+ }
246
+
247
+ # options handled by argparse but not handled by user config
248
+ ARGPARSE_SPECIFIC_SCHEMA = {
249
+ "debug_dataset": bool,
250
+ "max_token_length": Any(None, int),
251
+ "prior_loss_weight": Any(float, int),
252
+ }
253
+ # for handling default None value of argparse
254
+ ARGPARSE_NULLABLE_OPTNAMES = [
255
+ "face_crop_aug_range",
256
+ "resolution",
257
+ ]
258
+ # prepare map because option name may differ among argparse and user config
259
+ ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = {
260
+ "train_batch_size": "batch_size",
261
+ "dataset_repeats": "num_repeats",
262
+ }
263
+
264
+ def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
265
+ assert support_dreambooth or support_finetuning or support_controlnet, (
266
+ "Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more."
267
+ + " / DreamBooth モードか fine tuning モードか controlnet モードのどれも指定されていません。1つ以上指定してください。"
268
+ )
269
+
270
+ self.db_subset_schema = self.__merge_dict(
271
+ self.SUBSET_ASCENDABLE_SCHEMA,
272
+ self.DB_SUBSET_DISTINCT_SCHEMA,
273
+ self.DB_SUBSET_ASCENDABLE_SCHEMA,
274
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
275
+ )
276
+
277
+ self.ft_subset_schema = self.__merge_dict(
278
+ self.SUBSET_ASCENDABLE_SCHEMA,
279
+ self.FT_SUBSET_DISTINCT_SCHEMA,
280
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
281
+ )
282
+
283
+ self.cn_subset_schema = self.__merge_dict(
284
+ self.SUBSET_ASCENDABLE_SCHEMA,
285
+ self.CN_SUBSET_DISTINCT_SCHEMA,
286
+ self.CN_SUBSET_ASCENDABLE_SCHEMA,
287
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
288
+ )
289
+
290
+ self.db_dataset_schema = self.__merge_dict(
291
+ self.DATASET_ASCENDABLE_SCHEMA,
292
+ self.SUBSET_ASCENDABLE_SCHEMA,
293
+ self.DB_SUBSET_ASCENDABLE_SCHEMA,
294
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
295
+ {"subsets": [self.db_subset_schema]},
296
+ )
297
+
298
+ self.ft_dataset_schema = self.__merge_dict(
299
+ self.DATASET_ASCENDABLE_SCHEMA,
300
+ self.SUBSET_ASCENDABLE_SCHEMA,
301
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
302
+ {"subsets": [self.ft_subset_schema]},
303
+ )
304
+
305
+ self.cn_dataset_schema = self.__merge_dict(
306
+ self.DATASET_ASCENDABLE_SCHEMA,
307
+ self.SUBSET_ASCENDABLE_SCHEMA,
308
+ self.CN_SUBSET_ASCENDABLE_SCHEMA,
309
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
310
+ {"subsets": [self.cn_subset_schema]},
311
+ )
312
+
313
+ if support_dreambooth and support_finetuning:
314
+
315
+ def validate_flex_dataset(dataset_config: dict):
316
+ subsets_config = dataset_config.get("subsets", [])
317
+
318
+ if support_controlnet and all(["conditioning_data_dir" in subset for subset in subsets_config]):
319
+ return Schema(self.cn_dataset_schema)(dataset_config)
320
+ # check dataset meets FT style
321
+ # NOTE: all FT subsets should have "metadata_file"
322
+ elif all(["metadata_file" in subset for subset in subsets_config]):
323
+ return Schema(self.ft_dataset_schema)(dataset_config)
324
+ # check dataset meets DB style
325
+ # NOTE: all DB subsets should have no "metadata_file"
326
+ elif all(["metadata_file" not in subset for subset in subsets_config]):
327
+ return Schema(self.db_dataset_schema)(dataset_config)
328
+ else:
329
+ raise voluptuous.Invalid(
330
+ "DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuningのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。"
331
+ )
332
+
333
+ self.dataset_schema = validate_flex_dataset
334
+ elif support_dreambooth:
335
+ if support_controlnet:
336
+ self.dataset_schema = self.cn_dataset_schema
337
+ else:
338
+ self.dataset_schema = self.db_dataset_schema
339
+ elif support_finetuning:
340
+ self.dataset_schema = self.ft_dataset_schema
341
+ elif support_controlnet:
342
+ self.dataset_schema = self.cn_dataset_schema
343
+
344
+ self.general_schema = self.__merge_dict(
345
+ self.DATASET_ASCENDABLE_SCHEMA,
346
+ self.SUBSET_ASCENDABLE_SCHEMA,
347
+ self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {},
348
+ self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {},
349
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
350
+ )
351
+
352
+ self.user_config_validator = Schema(
353
+ {
354
+ "general": self.general_schema,
355
+ "datasets": [self.dataset_schema],
356
+ }
357
+ )
358
+
359
+ self.argparse_schema = self.__merge_dict(
360
+ self.general_schema,
361
+ self.ARGPARSE_SPECIFIC_SCHEMA,
362
+ {optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES},
363
+ {a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()},
364
+ )
365
+
366
+ self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)
367
+
368
+ def sanitize_user_config(self, user_config: dict) -> dict:
369
+ try:
370
+ return self.user_config_validator(user_config)
371
+ except MultipleInvalid:
372
+ # TODO: エラー発生時のメッセージをわかりやすくする
373
+ logger.error("Invalid user config / ユーザ設定の形式が正しくないようです")
374
+ raise
375
+
376
+ # NOTE: In nature, argument parser result is not needed to be sanitize
377
+ # However this will help us to detect program bug
378
+ def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
379
+ try:
380
+ return self.argparse_config_validator(argparse_namespace)
381
+ except MultipleInvalid:
382
+ # XXX: this should be a bug
383
+ logger.error(
384
+ "Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
385
+ )
386
+ raise
387
+
388
+ # NOTE: value would be overwritten by latter dict if there is already the same key
389
+ @staticmethod
390
+ def __merge_dict(*dict_list: dict) -> dict:
391
+ merged = {}
392
+ for schema in dict_list:
393
+ # merged |= schema
394
+ for k, v in schema.items():
395
+ merged[k] = v
396
+ return merged
397
+
398
+
399
+ class BlueprintGenerator:
400
+ BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {}
401
+
402
+ def __init__(self, sanitizer: ConfigSanitizer):
403
+ self.sanitizer = sanitizer
404
+
405
+ # runtime_params is for parameters which is only configurable on runtime, such as tokenizer
406
+ def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
407
+ sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
408
+ sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)
409
+
410
+ # convert argparse namespace to dict like config
411
+ # NOTE: it is ok to have extra entries in dict
412
+ optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME
413
+ argparse_config = {
414
+ optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items()
415
+ }
416
+
417
+ general_config = sanitized_user_config.get("general", {})
418
+
419
+ dataset_blueprints = []
420
+ for dataset_config in sanitized_user_config.get("datasets", []):
421
+ # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets
422
+ subsets = dataset_config.get("subsets", [])
423
+ is_dreambooth = all(["metadata_file" not in subset for subset in subsets])
424
+ is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets])
425
+ if is_controlnet:
426
+ subset_params_klass = ControlNetSubsetParams
427
+ dataset_params_klass = ControlNetDatasetParams
428
+ elif is_dreambooth:
429
+ subset_params_klass = DreamBoothSubsetParams
430
+ dataset_params_klass = DreamBoothDatasetParams
431
+ else:
432
+ subset_params_klass = FineTuningSubsetParams
433
+ dataset_params_klass = FineTuningDatasetParams
434
+
435
+ subset_blueprints = []
436
+ for subset_config in subsets:
437
+ params = self.generate_params_by_fallbacks(
438
+ subset_params_klass, [subset_config, dataset_config, general_config, argparse_config, runtime_params]
439
+ )
440
+ subset_blueprints.append(SubsetBlueprint(params))
441
+
442
+ params = self.generate_params_by_fallbacks(
443
+ dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params]
444
+ )
445
+ dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints))
446
+
447
+ dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)
448
+
449
+ return Blueprint(dataset_group_blueprint)
450
+
451
+ @staticmethod
452
+ def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
453
+ name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
454
+ search_value = BlueprintGenerator.search_value
455
+ default_params = asdict(param_klass())
456
+ param_names = default_params.keys()
457
+
458
+ params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names}
459
+
460
+ return param_klass(**params)
461
+
462
+ @staticmethod
463
+ def search_value(key: str, fallbacks: Sequence[dict], default_value=None):
464
+ for cand in fallbacks:
465
+ value = cand.get(key)
466
+ if value is not None:
467
+ return value
468
+
469
+ return default_value
470
+
471
+
472
+ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
473
+ datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
474
+
475
+ for dataset_blueprint in dataset_group_blueprint.datasets:
476
+ if dataset_blueprint.is_controlnet:
477
+ subset_klass = ControlNetSubset
478
+ dataset_klass = ControlNetDataset
479
+ elif dataset_blueprint.is_dreambooth:
480
+ subset_klass = DreamBoothSubset
481
+ dataset_klass = DreamBoothDataset
482
+ else:
483
+ subset_klass = FineTuningSubset
484
+ dataset_klass = FineTuningDataset
485
+
486
+ subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
487
+ dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
488
+ datasets.append(dataset)
489
+
490
+ # print info
491
+ info = ""
492
+ for i, dataset in enumerate(datasets):
493
+ is_dreambooth = isinstance(dataset, DreamBoothDataset)
494
+ is_controlnet = isinstance(dataset, ControlNetDataset)
495
+ info += dedent(
496
+ f"""\
497
+ [Dataset {i}]
498
+ batch_size: {dataset.batch_size}
499
+ resolution: {(dataset.width, dataset.height)}
500
+ enable_bucket: {dataset.enable_bucket}
501
+ network_multiplier: {dataset.network_multiplier}
502
+ """
503
+ )
504
+
505
+ if dataset.enable_bucket:
506
+ info += indent(
507
+ dedent(
508
+ f"""\
509
+ min_bucket_reso: {dataset.min_bucket_reso}
510
+ max_bucket_reso: {dataset.max_bucket_reso}
511
+ bucket_reso_steps: {dataset.bucket_reso_steps}
512
+ bucket_no_upscale: {dataset.bucket_no_upscale}
513
+ \n"""
514
+ ),
515
+ " ",
516
+ )
517
+ else:
518
+ info += "\n"
519
+
520
+ for j, subset in enumerate(dataset.subsets):
521
+ info += indent(
522
+ dedent(
523
+ f"""\
524
+ [Subset {j} of Dataset {i}]
525
+ image_dir: "{subset.image_dir}"
526
+ image_count: {subset.img_count}
527
+ num_repeats: {subset.num_repeats}
528
+ shuffle_caption: {subset.shuffle_caption}
529
+ keep_tokens: {subset.keep_tokens}
530
+ keep_tokens_separator: {subset.keep_tokens_separator}
531
+ caption_separator: {subset.caption_separator}
532
+ secondary_separator: {subset.secondary_separator}
533
+ enable_wildcard: {subset.enable_wildcard}
534
+ caption_dropout_rate: {subset.caption_dropout_rate}
535
+ caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
536
+ caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
537
+ caption_prefix: {subset.caption_prefix}
538
+ caption_suffix: {subset.caption_suffix}
539
+ color_aug: {subset.color_aug}
540
+ flip_aug: {subset.flip_aug}
541
+ face_crop_aug_range: {subset.face_crop_aug_range}
542
+ random_crop: {subset.random_crop}
543
+ token_warmup_min: {subset.token_warmup_min},
544
+ token_warmup_step: {subset.token_warmup_step},
545
+ alpha_mask: {subset.alpha_mask},
546
+ """
547
+ ),
548
+ " ",
549
+ )
550
+
551
+ if is_dreambooth:
552
+ info += indent(
553
+ dedent(
554
+ f"""\
555
+ is_reg: {subset.is_reg}
556
+ class_tokens: {subset.class_tokens}
557
+ caption_extension: {subset.caption_extension}
558
+ \n"""
559
+ ),
560
+ " ",
561
+ )
562
+ elif not is_controlnet:
563
+ info += indent(
564
+ dedent(
565
+ f"""\
566
+ metadata_file: {subset.metadata_file}
567
+ \n"""
568
+ ),
569
+ " ",
570
+ )
571
+
572
+ logger.info(f"{info}")
573
+
574
+ # make buckets first because it determines the length of dataset
575
+ # and set the same seed for all datasets
576
+ seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
577
+ for i, dataset in enumerate(datasets):
578
+ logger.info(f"[Dataset {i}]")
579
+ dataset.make_buckets()
580
+ dataset.set_seed(seed)
581
+
582
+ return DatasetGroup(datasets)
583
+
584
+
585
+ def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
586
+ def extract_dreambooth_params(name: str) -> Tuple[int, str]:
587
+ tokens = name.split("_")
588
+ try:
589
+ n_repeats = int(tokens[0])
590
+ except ValueError as e:
591
+ logger.warning(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}")
592
+ return 0, ""
593
+ caption_by_folder = "_".join(tokens[1:])
594
+ return n_repeats, caption_by_folder
595
+
596
+ def generate(base_dir: Optional[str], is_reg: bool):
597
+ if base_dir is None:
598
+ return []
599
+
600
+ base_dir: Path = Path(base_dir)
601
+ if not base_dir.is_dir():
602
+ return []
603
+
604
+ subsets_config = []
605
+ for subdir in base_dir.iterdir():
606
+ if not subdir.is_dir():
607
+ continue
608
+
609
+ num_repeats, class_tokens = extract_dreambooth_params(subdir.name)
610
+ if num_repeats < 1:
611
+ continue
612
+
613
+ subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens}
614
+ subsets_config.append(subset_config)
615
+
616
+ return subsets_config
617
+
618
+ subsets_config = []
619
+ subsets_config += generate(train_data_dir, False)
620
+ subsets_config += generate(reg_data_dir, True)
621
+
622
+ return subsets_config
623
+
624
+
625
+ def generate_controlnet_subsets_config_by_subdirs(
626
+ train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt"
627
+ ):
628
+ def generate(base_dir: Optional[str]):
629
+ if base_dir is None:
630
+ return []
631
+
632
+ base_dir: Path = Path(base_dir)
633
+ if not base_dir.is_dir():
634
+ return []
635
+
636
+ subsets_config = []
637
+ subset_config = {
638
+ "image_dir": train_data_dir,
639
+ "conditioning_data_dir": conditioning_data_dir,
640
+ "caption_extension": caption_extension,
641
+ "num_repeats": 1,
642
+ }
643
+ subsets_config.append(subset_config)
644
+
645
+ return subsets_config
646
+
647
+ subsets_config = []
648
+ subsets_config += generate(train_data_dir)
649
+
650
+ return subsets_config
651
+
652
+
653
+ def load_user_config(file: str) -> dict:
654
+ file: Path = Path(file)
655
+ if not file.is_file():
656
+ raise ValueError(f"file not found / ファイルが見つかりません: {file}")
657
+
658
+ if file.name.lower().endswith(".json"):
659
+ try:
660
+ with open(file, "r") as f:
661
+ config = json.load(f)
662
+ except Exception:
663
+ logger.error(
664
+ f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
665
+ )
666
+ raise
667
+ elif file.name.lower().endswith(".toml"):
668
+ try:
669
+ config = toml.load(file)
670
+ except Exception:
671
+ logger.error(
672
+ f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
673
+ )
674
+ raise
675
+ else:
676
+ raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
677
+
678
+ return config
679
+
680
+
681
+ # for config test
682
+ if __name__ == "__main__":
683
+ parser = argparse.ArgumentParser()
684
+ parser.add_argument("--support_dreambooth", action="store_true")
685
+ parser.add_argument("--support_finetuning", action="store_true")
686
+ parser.add_argument("--support_controlnet", action="store_true")
687
+ parser.add_argument("--support_dropout", action="store_true")
688
+ parser.add_argument("dataset_config")
689
+ config_args, remain = parser.parse_known_args()
690
+
691
+ parser = argparse.ArgumentParser()
692
+ train_util.add_dataset_arguments(
693
+ parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout
694
+ )
695
+ train_util.add_training_arguments(parser, config_args.support_dreambooth)
696
+ argparse_namespace = parser.parse_args(remain)
697
+ train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)
698
+
699
+ logger.info("[argparse_namespace]")
700
+ logger.info(f"{vars(argparse_namespace)}")
701
+
702
+ user_config = load_user_config(config_args.dataset_config)
703
+
704
+ logger.info("")
705
+ logger.info("[user_config]")
706
+ logger.info(f"{user_config}")
707
+
708
+ sanitizer = ConfigSanitizer(
709
+ config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout
710
+ )
711
+ sanitized_user_config = sanitizer.sanitize_user_config(user_config)
712
+
713
+ logger.info("")
714
+ logger.info("[sanitized_user_config]")
715
+ logger.info(f"{sanitized_user_config}")
716
+
717
+ blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
718
+
719
+ logger.info("")
720
+ logger.info("[blueprint]")
721
+ logger.info(f"{blueprint}")
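
For reference, a minimal sketch of the configuration pipeline in this file (not one of the uploaded files): the DreamBooth-style config, the image directory, the class tokens, and the stand-in argparse namespace are placeholder assumptions; real training scripts read the config with `load_user_config` and pass the fully parsed CLI namespace.

import argparse
from library.config_util import BlueprintGenerator, ConfigSanitizer

# In-memory stand-in for a TOML/JSON dataset config file
user_config = {
    "general": {"enable_bucket": True, "resolution": 512},
    "datasets": [
        {"subsets": [{"image_dir": "/path/to/train_images", "class_tokens": "sks person", "num_repeats": 10}]}
    ],
}

sanitizer = ConfigSanitizer(support_dreambooth=True, support_finetuning=False, support_controlnet=False, support_dropout=True)
# Stand-in for the parsed command-line arguments of a training script
args = argparse.Namespace(train_batch_size=2, max_token_length=None)
blueprint = BlueprintGenerator(sanitizer).generate(user_config, args)
print(blueprint.dataset_group.datasets[0].subsets[0].params.num_repeats)  # 10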
library/custom_train_functions.py ADDED
@@ -0,0 +1,559 @@
1
+ import torch
2
+ import argparse
3
+ import random
4
+ import re
5
+ from typing import List, Optional, Union
6
+ from .utils import setup_logging
7
+
8
+ setup_logging()
9
+ import logging
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ def prepare_scheduler_for_custom_training(noise_scheduler, device):
15
+ if hasattr(noise_scheduler, "all_snr"):
16
+ return
17
+
18
+ alphas_cumprod = noise_scheduler.alphas_cumprod
19
+ sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
20
+ sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
21
+ alpha = sqrt_alphas_cumprod
22
+ sigma = sqrt_one_minus_alphas_cumprod
23
+ all_snr = (alpha / sigma) ** 2
24
+
25
+ noise_scheduler.all_snr = all_snr.to(device)
26
+
27
+
28
+ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
29
+ # fix beta: zero terminal SNR
30
+ logger.info(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")
31
+
32
+ def enforce_zero_terminal_snr(betas):
33
+ # Convert betas to alphas_bar_sqrt
34
+ alphas = 1 - betas
35
+ alphas_bar = alphas.cumprod(0)
36
+ alphas_bar_sqrt = alphas_bar.sqrt()
37
+
38
+ # Store old values.
39
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
40
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
41
+ # Shift so last timestep is zero.
42
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
43
+ # Scale so first timestep is back to old value.
44
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
45
+
46
+ # Convert alphas_bar_sqrt to betas
47
+ alphas_bar = alphas_bar_sqrt**2
48
+ alphas = alphas_bar[1:] / alphas_bar[:-1]
49
+ alphas = torch.cat([alphas_bar[0:1], alphas])
50
+ betas = 1 - alphas
51
+ return betas
52
+
53
+ betas = noise_scheduler.betas
54
+ betas = enforce_zero_terminal_snr(betas)
55
+ alphas = 1.0 - betas
56
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
57
+
58
+ # logger.info(f"original: {noise_scheduler.betas}")
59
+ # logger.info(f"fixed: {betas}")
60
+
61
+ noise_scheduler.betas = betas
62
+ noise_scheduler.alphas = alphas
63
+ noise_scheduler.alphas_cumprod = alphas_cumprod
64
+
65
+
66
+ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
67
+ snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
68
+ min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
69
+ if v_prediction:
70
+ snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
71
+ else:
72
+ snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
73
+ loss = loss * snr_weight
74
+ return loss
75
+
76
+
77
+ def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
78
+ scale = get_snr_scale(timesteps, noise_scheduler)
79
+ loss = loss * scale
80
+ return loss
81
+
82
+
83
+ def get_snr_scale(timesteps, noise_scheduler):
84
+ snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
85
+ snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
86
+ scale = snr_t / (snr_t + 1)
87
+ # # show debug info
88
+ # logger.info(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
89
+ return scale
90
+
91
+
92
+ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
93
+ scale = get_snr_scale(timesteps, noise_scheduler)
94
+ # logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
95
+ loss = loss + loss / scale * v_pred_like_loss
96
+ return loss
97
+
98
+
99
+ def apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False):
100
+ snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
101
+ snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
102
+ if v_prediction:
103
+ weight = 1 / (snr_t + 1)
104
+ else:
105
+ weight = 1 / torch.sqrt(snr_t)
106
+ loss = weight * loss
107
+ return loss
108
+
109
+
110
+ # TODO train_utilと分散しているのでどちらかに寄せる
111
+
112
+
113
+ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
114
+ parser.add_argument(
115
+ "--min_snr_gamma",
116
+ type=float,
117
+ default=None,
118
+ help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
119
+ )
120
+ parser.add_argument(
121
+ "--scale_v_pred_loss_like_noise_pred",
122
+ action="store_true",
123
+ help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
124
+ )
125
+ parser.add_argument(
126
+ "--v_pred_like_loss",
127
+ type=float,
128
+ default=None,
129
+ help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する",
130
+ )
131
+ parser.add_argument(
132
+ "--debiased_estimation_loss",
133
+ action="store_true",
134
+ help="debiased estimation loss / debiased estimation loss",
135
+ )
136
+ if support_weighted_captions:
137
+ parser.add_argument(
138
+ "--weighted_captions",
139
+ action="store_true",
140
+ default=False,
141
+ help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
142
+ )
143
+
144
+
145
+ re_attention = re.compile(
146
+ r"""
147
+ \\\(|
148
+ \\\)|
149
+ \\\[|
150
+ \\]|
151
+ \\\\|
152
+ \\|
153
+ \(|
154
+ \[|
155
+ :([+-]?[.\d]+)\)|
156
+ \)|
157
+ ]|
158
+ [^\\()\[\]:]+|
159
+ :
160
+ """,
161
+ re.X,
162
+ )
163
+
164
+
165
+ def parse_prompt_attention(text):
166
+ """
167
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
168
+ Accepted tokens are:
169
+ (abc) - increases attention to abc by a multiplier of 1.1
170
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
171
+ [abc] - decreases attention to abc by a multiplier of 1.1
172
+ \( - literal character '('
173
+ \[ - literal character '['
174
+ \) - literal character ')'
175
+ \] - literal character ']'
176
+ \\ - literal character '\'
177
+ anything else - just text
178
+ >>> parse_prompt_attention('normal text')
179
+ [['normal text', 1.0]]
180
+ >>> parse_prompt_attention('an (important) word')
181
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
182
+ >>> parse_prompt_attention('(unbalanced')
183
+ [['unbalanced', 1.1]]
184
+ >>> parse_prompt_attention('\(literal\]')
185
+ [['(literal]', 1.0]]
186
+ >>> parse_prompt_attention('(unnecessary)(parens)')
187
+ [['unnecessaryparens', 1.1]]
188
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
189
+ [['a ', 1.0],
190
+ ['house', 1.5730000000000004],
191
+ [' ', 1.1],
192
+ ['on', 1.0],
193
+ [' a ', 1.1],
194
+ ['hill', 0.55],
195
+ [', sun, ', 1.1],
196
+ ['sky', 1.4641000000000006],
197
+ ['.', 1.1]]
198
+ """
199
+
200
+ res = []
201
+ round_brackets = []
202
+ square_brackets = []
203
+
204
+ round_bracket_multiplier = 1.1
205
+ square_bracket_multiplier = 1 / 1.1
206
+
207
+ def multiply_range(start_position, multiplier):
208
+ for p in range(start_position, len(res)):
209
+ res[p][1] *= multiplier
210
+
211
+ for m in re_attention.finditer(text):
212
+ text = m.group(0)
213
+ weight = m.group(1)
214
+
215
+ if text.startswith("\\"):
216
+ res.append([text[1:], 1.0])
217
+ elif text == "(":
218
+ round_brackets.append(len(res))
219
+ elif text == "[":
220
+ square_brackets.append(len(res))
221
+ elif weight is not None and len(round_brackets) > 0:
222
+ multiply_range(round_brackets.pop(), float(weight))
223
+ elif text == ")" and len(round_brackets) > 0:
224
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
225
+ elif text == "]" and len(square_brackets) > 0:
226
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
227
+ else:
228
+ res.append([text, 1.0])
229
+
230
+ for pos in round_brackets:
231
+ multiply_range(pos, round_bracket_multiplier)
232
+
233
+ for pos in square_brackets:
234
+ multiply_range(pos, square_bracket_multiplier)
235
+
236
+ if len(res) == 0:
237
+ res = [["", 1.0]]
238
+
239
+ # merge runs of identical weights
240
+ i = 0
241
+ while i + 1 < len(res):
242
+ if res[i][1] == res[i + 1][1]:
243
+ res[i][0] += res[i + 1][0]
244
+ res.pop(i + 1)
245
+ else:
246
+ i += 1
247
+
248
+ return res
249
+
250
+
251
+ def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
252
+ r"""
253
+ Tokenize a list of prompts and return its tokens with weights of each token.
254
+
255
+ No padding, starting or ending token is included.
256
+ """
257
+ tokens = []
258
+ weights = []
259
+ truncated = False
260
+ for text in prompt:
261
+ texts_and_weights = parse_prompt_attention(text)
262
+ text_token = []
263
+ text_weight = []
264
+ for word, weight in texts_and_weights:
265
+ # tokenize and discard the starting and the ending token
266
+ token = tokenizer(word).input_ids[1:-1]
267
+ text_token += token
268
+ # copy the weight by length of token
269
+ text_weight += [weight] * len(token)
270
+ # stop if the text is too long (longer than truncation limit)
271
+ if len(text_token) > max_length:
272
+ truncated = True
273
+ break
274
+ # truncate
275
+ if len(text_token) > max_length:
276
+ truncated = True
277
+ text_token = text_token[:max_length]
278
+ text_weight = text_weight[:max_length]
279
+ tokens.append(text_token)
280
+ weights.append(text_weight)
281
+ if truncated:
282
+ logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
283
+ return tokens, weights
284
+
285
+
286
+ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
287
+ r"""
288
+ Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
289
+ """
290
+ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
291
+ weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
292
+ for i in range(len(tokens)):
293
+ tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
294
+ if no_boseos_middle:
295
+ weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
296
+ else:
297
+ w = []
298
+ if len(weights[i]) == 0:
299
+ w = [1.0] * weights_length
300
+ else:
301
+ for j in range(max_embeddings_multiples):
302
+ w.append(1.0) # weight for starting token in this chunk
303
+ w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
304
+ w.append(1.0) # weight for ending token in this chunk
305
+ w += [1.0] * (weights_length - len(w))
306
+ weights[i] = w[:]
307
+
308
+ return tokens, weights
309
+
310
+
311
+ def get_unweighted_text_embeddings(
312
+ tokenizer,
313
+ text_encoder,
314
+ text_input: torch.Tensor,
315
+ chunk_length: int,
316
+ clip_skip: int,
317
+ eos: int,
318
+ pad: int,
319
+ no_boseos_middle: Optional[bool] = True,
320
+ ):
321
+ """
322
+ When the length of tokens is a multiple of the capacity of the text encoder,
323
+ it should be split into chunks and sent to the text encoder individually.
324
+ """
325
+ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
326
+ if max_embeddings_multiples > 1:
327
+ text_embeddings = []
328
+ for i in range(max_embeddings_multiples):
329
+ # extract the i-th chunk
330
+ text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
331
+
332
+ # cover the head and the tail by the starting and the ending tokens
333
+ text_input_chunk[:, 0] = text_input[0, 0]
334
+ if pad == eos: # v1
335
+ text_input_chunk[:, -1] = text_input[0, -1]
336
+ else: # v2
337
+ for j in range(len(text_input_chunk)):
338
+ if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある
339
+ text_input_chunk[j, -1] = eos
340
+ if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD
341
+ text_input_chunk[j, 1] = eos
342
+
343
+ if clip_skip is None or clip_skip == 1:
344
+ text_embedding = text_encoder(text_input_chunk)[0]
345
+ else:
346
+ enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
347
+ text_embedding = enc_out["hidden_states"][-clip_skip]
348
+ text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
349
+
350
+ if no_boseos_middle:
351
+ if i == 0:
352
+ # discard the ending token
353
+ text_embedding = text_embedding[:, :-1]
354
+ elif i == max_embeddings_multiples - 1:
355
+ # discard the starting token
356
+ text_embedding = text_embedding[:, 1:]
357
+ else:
358
+ # discard both starting and ending tokens
359
+ text_embedding = text_embedding[:, 1:-1]
360
+
361
+ text_embeddings.append(text_embedding)
362
+ text_embeddings = torch.concat(text_embeddings, axis=1)
363
+ else:
364
+ if clip_skip is None or clip_skip == 1:
365
+ text_embeddings = text_encoder(text_input)[0]
366
+ else:
367
+ enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
368
+ text_embeddings = enc_out["hidden_states"][-clip_skip]
369
+ text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
370
+ return text_embeddings
371
+
372
+
373
+ def get_weighted_text_embeddings(
374
+ tokenizer,
375
+ text_encoder,
376
+ prompt: Union[str, List[str]],
377
+ device,
378
+ max_embeddings_multiples: Optional[int] = 3,
379
+ no_boseos_middle: Optional[bool] = False,
380
+ clip_skip=None,
381
+ ):
382
+ r"""
383
+ Prompts can be assigned with local weights using brackets. For example,
384
+ prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
385
+ and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
386
+
387
+ Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
388
+
389
+ Args:
390
+ prompt (`str` or `List[str]`):
391
+ The prompt or prompts to guide the image generation.
392
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
393
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
394
+ no_boseos_middle (`bool`, *optional*, defaults to `False`):
395
+ If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
396
+ ending token in each of the chunk in the middle.
397
+ skip_parsing (`bool`, *optional*, defaults to `False`):
398
+ Skip the parsing of brackets.
399
+ skip_weighting (`bool`, *optional*, defaults to `False`):
400
+ Skip the weighting. When the parsing is skipped, it is forced True.
401
+ """
402
+ max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
403
+ if isinstance(prompt, str):
404
+ prompt = [prompt]
405
+
406
+ prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2)
407
+
408
+ # round up the longest length of tokens to a multiple of (model_max_length - 2)
409
+ max_length = max([len(token) for token in prompt_tokens])
410
+
411
+ max_embeddings_multiples = min(
412
+ max_embeddings_multiples,
413
+ (max_length - 1) // (tokenizer.model_max_length - 2) + 1,
414
+ )
415
+ max_embeddings_multiples = max(1, max_embeddings_multiples)
416
+ max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
417
+
418
+ # pad the length of tokens and weights
419
+ bos = tokenizer.bos_token_id
420
+ eos = tokenizer.eos_token_id
421
+ pad = tokenizer.pad_token_id
422
+ prompt_tokens, prompt_weights = pad_tokens_and_weights(
423
+ prompt_tokens,
424
+ prompt_weights,
425
+ max_length,
426
+ bos,
427
+ eos,
428
+ no_boseos_middle=no_boseos_middle,
429
+ chunk_length=tokenizer.model_max_length,
430
+ )
431
+ prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device)
432
+
433
+ # get the embeddings
434
+ text_embeddings = get_unweighted_text_embeddings(
435
+ tokenizer,
436
+ text_encoder,
437
+ prompt_tokens,
438
+ tokenizer.model_max_length,
439
+ clip_skip,
440
+ eos,
441
+ pad,
442
+ no_boseos_middle=no_boseos_middle,
443
+ )
444
+ prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device)
445
+
446
+ # assign weights to the prompts and normalize in the sense of mean
447
+ previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
448
+ text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1)
449
+ current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
450
+ text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
451
+
452
+ return text_embeddings
453
+
454
+
455
+ # https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
456
+ def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
457
+ b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant!
458
+ u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
459
+ for i in range(iterations):
460
+ r = random.random() * 2 + 2 # Rather than always going 2x,
461
+ wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
462
+ noise += u(torch.randn(b, c, wn, hn).to(device)) * discount**i
463
+ if wn == 1 or hn == 1:
464
+ break # Lowest resolution is 1x1
465
+ return noise / noise.std() # Scaled back to roughly unit variance
466
+
467
+
468
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
469
+ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
470
+ if noise_offset is None:
471
+ return noise
472
+ if adaptive_noise_scale is not None:
473
+ # latent shape: (batch_size, channels, height, width)
474
+ # abs mean value for each channel
475
+ latent_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True))
476
+
477
+ # multiply adaptive noise scale to the mean value and add it to the noise offset
478
+ noise_offset = noise_offset + adaptive_noise_scale * latent_mean
479
+ noise_offset = torch.clamp(noise_offset, 0.0, None) # in case of adaptive noise scale is negative
480
+
481
+ noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
482
+ return noise
483
+
484
+
485
+ def apply_masked_loss(loss, batch):
486
+ if "conditioning_images" in batch:
487
+ # conditioning image is -1 to 1. we need to convert it to 0 to 1
488
+ mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
489
+ mask_image = mask_image / 2 + 0.5
490
+ # print(f"conditioning_image: {mask_image.shape}")
491
+ elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
492
+ # alpha mask is 0 to 1
493
+ mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension
494
+ # print(f"mask_image: {mask_image.shape}, {mask_image.mean()}")
495
+ else:
496
+ return loss
497
+
498
+ # resize to the same size as the loss
499
+ mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
500
+ loss = loss * mask_image
501
+ return loss
502
+
503
+
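A small sketch of the inputs apply_masked_loss expects: an un-reduced, per-pixel loss and a batch dict carrying either conditioning images or alpha masks (the tensor sizes below are made up):

import torch
import torch.nn.functional as F

model_pred = torch.randn(2, 4, 64, 64)
target = torch.randn(2, 4, 64, 64)
loss = F.mse_loss(model_pred, target, reduction="none")    # keep the spatial dimensions

batch = {"alpha_masks": torch.rand(2, 512, 512)}           # 0..1 mask per sample
loss = apply_masked_loss(loss, batch)                      # mask is resized to the loss resolution
loss = loss.mean()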
504
+ """
505
+ ##########################################
506
+ # Perlin Noise
507
+ def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
508
+ delta = (res[0] / shape[0], res[1] / shape[1])
509
+ d = (shape[0] // res[0], shape[1] // res[1])
510
+
511
+ grid = (
512
+ torch.stack(
513
+ torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)),
514
+ dim=-1,
515
+ )
516
+ % 1
517
+ )
518
+ angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
519
+ gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
520
+
521
+ tile_grads = (
522
+ lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
523
+ .repeat_interleave(d[0], 0)
524
+ .repeat_interleave(d[1], 1)
525
+ )
526
+ dot = lambda grad, shift: (
527
+ torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1)
528
+ * grad[: shape[0], : shape[1]]
529
+ ).sum(dim=-1)
530
+
531
+ n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
532
+ n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
533
+ n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
534
+ n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
535
+ t = fade(grid[: shape[0], : shape[1]])
536
+ return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
537
+
538
+
539
+ def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
540
+ noise = torch.zeros(shape, device=device)
541
+ frequency = 1
542
+ amplitude = 1
543
+ for _ in range(octaves):
544
+ noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1]))
545
+ frequency *= 2
546
+ amplitude *= persistence
547
+ return noise
548
+
549
+
550
+ def perlin_noise(noise, device, octaves):
551
+ _, c, w, h = noise.shape
552
+ perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
553
+ noise_perlin = []
554
+ for _ in range(c):
555
+ noise_perlin.append(perlin())
556
+ noise_perlin = torch.stack(noise_perlin).unsqueeze(0) # (1, c, w, h)
557
+ noise += noise_perlin # broadcast for each batch
558
+ return noise / noise.std() # Scaled back to roughly unit variance
559
+ """
library/deepspeed_utils.py ADDED
@@ -0,0 +1,139 @@
1
+ import os
2
+ import argparse
3
+ import torch
4
+ from accelerate import DeepSpeedPlugin, Accelerator
5
+
6
+ from .utils import setup_logging
7
+
8
+ setup_logging()
9
+ import logging
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ def add_deepspeed_arguments(parser: argparse.ArgumentParser):
15
+ # DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed
16
+ parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training")
17
+ parser.add_argument("--zero_stage", type=int, default=2, choices=[0, 1, 2, 3], help="Possible options are 0,1,2,3.")
18
+ parser.add_argument(
19
+ "--offload_optimizer_device",
20
+ type=str,
21
+ default=None,
22
+ choices=[None, "cpu", "nvme"],
23
+ help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.",
24
+ )
25
+ parser.add_argument(
26
+ "--offload_optimizer_nvme_path",
27
+ type=str,
28
+ default=None,
29
+ help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
30
+ )
31
+ parser.add_argument(
32
+ "--offload_param_device",
33
+ type=str,
34
+ default=None,
35
+ choices=[None, "cpu", "nvme"],
36
+ help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.",
37
+ )
38
+ parser.add_argument(
39
+ "--offload_param_nvme_path",
40
+ type=str,
41
+ default=None,
42
+ help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
43
+ )
44
+ parser.add_argument(
45
+ "--zero3_init_flag",
46
+ action="store_true",
47
+ help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models. "
48
+ "Only applicable with ZeRO Stage-3.",
49
+ )
50
+ parser.add_argument(
51
+ "--zero3_save_16bit_model",
52
+ action="store_true",
53
+ help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.",
54
+ )
55
+ parser.add_argument(
56
+ "--fp16_master_weights_and_gradients",
57
+ action="store_true",
58
+ help="fp16_master_weights_and_gradients requires the optimizer to support keeping fp16 master weights and gradients while keeping the optimizer states in fp32.",
59
+ )
60
+
61
+
62
+ def prepare_deepspeed_args(args: argparse.Namespace):
63
+ if not args.deepspeed:
64
+ return
65
+
66
+ # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
67
+ args.max_data_loader_n_workers = 1
68
+
69
+
70
+ def prepare_deepspeed_plugin(args: argparse.Namespace):
71
+ if not args.deepspeed:
72
+ return None
73
+
74
+ try:
75
+ import deepspeed
76
+ except ImportError as e:
77
+ logger.error(
78
+ "deepspeed is not installed. Please install it in your environment with the following command: DS_BUILD_OPS=0 pip install deepspeed"
79
+ )
80
+ exit(1)
81
+
82
+ deepspeed_plugin = DeepSpeedPlugin(
83
+ zero_stage=args.zero_stage,
84
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
85
+ gradient_clipping=args.max_grad_norm,
86
+ offload_optimizer_device=args.offload_optimizer_device,
87
+ offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
88
+ offload_param_device=args.offload_param_device,
89
+ offload_param_nvme_path=args.offload_param_nvme_path,
90
+ zero3_init_flag=args.zero3_init_flag,
91
+ zero3_save_16bit_model=args.zero3_save_16bit_model,
92
+ )
93
+ deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
94
+ deepspeed_plugin.deepspeed_config["train_batch_size"] = (
95
+ args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
96
+ )
97
+ deepspeed_plugin.set_mixed_precision(args.mixed_precision)
98
+ if args.mixed_precision.lower() == "fp16":
99
+ deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
100
+ if args.full_fp16 or args.fp16_master_weights_and_gradients:
101
+ if args.offload_optimizer_device == "cpu" and args.zero_stage == 2:
102
+ deepspeed_plugin.deepspeed_config["fp16"]["fp16_master_weights_and_grads"] = True
103
+ logger.info("[DeepSpeed] full fp16 enabled.")
104
+ else:
105
+ logger.info(
106
+ "[DeepSpeed] full fp16 / fp16_master_weights_and_grads is currently only supported with ZeRO-Offload and DeepSpeedCPUAdam on ZeRO stage 2."
107
+ )
108
+
109
+ if args.offload_optimizer_device is not None:
110
+ logger.info("[DeepSpeed] start to manually build cpu_adam.")
111
+ deepspeed.ops.op_builder.CPUAdamBuilder().load()
112
+ logger.info("[DeepSpeed] building cpu_adam done.")
113
+
114
+ return deepspeed_plugin
115
+
116
+
117
+ # The accelerate library does not support multiple models with DeepSpeed, so we wrap them all into a single module.
118
+ def prepare_deepspeed_model(args: argparse.Namespace, **models):
119
+ # remove None from models
120
+ models = {k: v for k, v in models.items() if v is not None}
121
+
122
+ class DeepSpeedWrapper(torch.nn.Module):
123
+ def __init__(self, **kw_models) -> None:
124
+ super().__init__()
125
+ self.models = torch.nn.ModuleDict()
126
+
127
+ for key, model in kw_models.items():
128
+ if isinstance(model, list):
129
+ model = torch.nn.ModuleList(model)
130
+ assert isinstance(
131
+ model, torch.nn.Module
132
+ ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
133
+ self.models.update(torch.nn.ModuleDict({key: model}))
134
+
135
+ def get_models(self):
136
+ return self.models
137
+
138
+ ds_model = DeepSpeedWrapper(**models)
139
+ return ds_model
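A hedged sketch of how these helpers might be wired into a training script. The extra attributes set on args below (gradient_accumulation_steps, max_grad_norm, train_batch_size, mixed_precision, full_fp16) are normally added by the training-argument parser elsewhere in the repo; they are filled in by hand here only to keep the sketch self-contained, and the two Linear layers stand in for real models:

import argparse
import torch
from accelerate import Accelerator

parser = argparse.ArgumentParser()
add_deepspeed_arguments(parser)
args = parser.parse_args([])                       # defaults only, for illustration
args.gradient_accumulation_steps, args.max_grad_norm = 1, 1.0
args.train_batch_size, args.mixed_precision, args.full_fp16 = 1, "fp16", False

prepare_deepspeed_args(args)                       # no-op unless --deepspeed is set
deepspeed_plugin = prepare_deepspeed_plugin(args)  # None unless --deepspeed is set
accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)

# accelerate's DeepSpeed path takes a single model, so wrap every trainable piece:
unet, text_encoder = torch.nn.Linear(4, 4), torch.nn.Linear(4, 4)   # placeholders
ds_model = prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)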
library/device_utils.py ADDED
@@ -0,0 +1,84 @@
1
+ import functools
2
+ import gc
3
+
4
+ import torch
5
+
6
+ try:
7
+ HAS_CUDA = torch.cuda.is_available()
8
+ except Exception:
9
+ HAS_CUDA = False
10
+
11
+ try:
12
+ HAS_MPS = torch.backends.mps.is_available()
13
+ except Exception:
14
+ HAS_MPS = False
15
+
16
+ try:
17
+ import intel_extension_for_pytorch as ipex # noqa
18
+
19
+ HAS_XPU = torch.xpu.is_available()
20
+ except Exception:
21
+ HAS_XPU = False
22
+
23
+
24
+ def clean_memory():
25
+ gc.collect()
26
+ if HAS_CUDA:
27
+ torch.cuda.empty_cache()
28
+ if HAS_XPU:
29
+ torch.xpu.empty_cache()
30
+ if HAS_MPS:
31
+ torch.mps.empty_cache()
32
+
33
+
34
+ def clean_memory_on_device(device: torch.device):
35
+ r"""
36
+ Clean memory on the specified device, will be called from training scripts.
37
+ """
38
+ gc.collect()
39
+
40
+ # device may "cuda" or "cuda:0", so we need to check the type of device
41
+ if device.type == "cuda":
42
+ torch.cuda.empty_cache()
43
+ if device.type == "xpu":
44
+ torch.xpu.empty_cache()
45
+ if device.type == "mps":
46
+ torch.mps.empty_cache()
47
+
48
+
49
+ @functools.lru_cache(maxsize=None)
50
+ def get_preferred_device() -> torch.device:
51
+ r"""
52
+ Do not call this function from training scripts. Use accelerator.device instead.
53
+ """
54
+ if HAS_CUDA:
55
+ device = torch.device("cuda")
56
+ elif HAS_XPU:
57
+ device = torch.device("xpu")
58
+ elif HAS_MPS:
59
+ device = torch.device("mps")
60
+ else:
61
+ device = torch.device("cpu")
62
+ print(f"get_preferred_device() -> {device}")
63
+ return device
64
+
65
+
66
+ def init_ipex():
67
+ """
68
+ Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`.
69
+
70
+ This function should run right after importing torch and before doing anything else.
71
+
72
+ If IPEX is not available, this function does nothing.
73
+ """
74
+ try:
75
+ if HAS_XPU:
76
+ from library.ipex import ipex_init
77
+
78
+ is_initialized, error_message = ipex_init()
79
+ if not is_initialized:
80
+ print("failed to initialize ipex:", error_message)
81
+ else:
82
+ return
83
+ except Exception as e:
84
+ print("failed to initialize ipex:", e)
library/huggingface_util.py ADDED
@@ -0,0 +1,84 @@
1
+ from typing import Union, BinaryIO
2
+ from huggingface_hub import HfApi
3
+ from pathlib import Path
4
+ import argparse
5
+ import os
6
+ from library.utils import fire_in_thread
7
+ from library.utils import setup_logging
8
+ setup_logging()
9
+ import logging
10
+ logger = logging.getLogger(__name__)
11
+
12
+ def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None):
13
+ api = HfApi(
14
+ token=token,
15
+ )
16
+ try:
17
+ api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
18
+ return True
19
+ except Exception:  # treat any failure (e.g. repo not found) as "does not exist"
20
+ return False
21
+
22
+
23
+ def upload(
24
+ args: argparse.Namespace,
25
+ src: Union[str, Path, bytes, BinaryIO],
26
+ dest_suffix: str = "",
27
+ force_sync_upload: bool = False,
28
+ ):
29
+ repo_id = args.huggingface_repo_id
30
+ repo_type = args.huggingface_repo_type
31
+ token = args.huggingface_token
32
+ path_in_repo = args.huggingface_path_in_repo + dest_suffix if args.huggingface_path_in_repo is not None else None
33
+ private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public"
34
+ api = HfApi(token=token)
35
+ if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
36
+ try:
37
+ api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
38
+ except Exception as e: # RepositoryNotFoundError has been observed; catch broadly in case other errors occur
39
+ logger.error("===========================================")
40
+ logger.error(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
41
+ logger.error("===========================================")
42
+
43
+ is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir())
44
+
45
+ def uploader():
46
+ try:
47
+ if is_folder:
48
+ api.upload_folder(
49
+ repo_id=repo_id,
50
+ repo_type=repo_type,
51
+ folder_path=src,
52
+ path_in_repo=path_in_repo,
53
+ )
54
+ else:
55
+ api.upload_file(
56
+ repo_id=repo_id,
57
+ repo_type=repo_type,
58
+ path_or_fileobj=src,
59
+ path_in_repo=path_in_repo,
60
+ )
61
+ except Exception as e: # RuntimeError has been observed; catch broadly in case other errors occur
62
+ logger.error("===========================================")
63
+ logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
64
+ logger.error("===========================================")
65
+
66
+ if args.async_upload and not force_sync_upload:
67
+ fire_in_thread(uploader)
68
+ else:
69
+ uploader()
70
+
71
+
72
+ def list_dir(
73
+ repo_id: str,
74
+ subfolder: str,
75
+ repo_type: str,
76
+ revision: str = "main",
77
+ token: str = None,
78
+ ):
79
+ api = HfApi(
80
+ token=token,
81
+ )
82
+ repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
83
+ file_list = [file for file in repo_info.siblings if file.rfilename.startswith(subfolder)]
84
+ return file_list
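The upload helper reads its configuration from an argparse namespace; a sketch using the attribute names referenced above (the repo id, token and paths are placeholders):

import argparse

args = argparse.Namespace(
    huggingface_repo_id="your-name/your-model",   # placeholder
    huggingface_repo_type="model",
    huggingface_token=None,                       # or a real token
    huggingface_path_in_repo="checkpoints",
    huggingface_repo_visibility=None,             # anything other than "public" creates a private repo
    async_upload=False,
)

# uploads a file or a whole folder; dest_suffix is appended to path_in_repo
upload(args, "output/my_model.safetensors", dest_suffix="/step-1000")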
library/hypernetwork.py ADDED
@@ -0,0 +1,223 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from diffusers.models.attention_processor import (
4
+ Attention,
5
+ AttnProcessor2_0,
6
+ SlicedAttnProcessor,
7
+ XFormersAttnProcessor
8
+ )
9
+
10
+ try:
11
+ import xformers.ops
12
+ except Exception:  # xformers is optional
13
+ xformers = None
14
+
15
+
16
+ loaded_networks = []
17
+
18
+
19
+ def apply_single_hypernetwork(
20
+ hypernetwork, hidden_states, encoder_hidden_states
21
+ ):
22
+ context_k, context_v = hypernetwork.forward(hidden_states, encoder_hidden_states)
23
+ return context_k, context_v
24
+
25
+
26
+ def apply_hypernetworks(context_k, context_v, layer=None):
27
+ if len(loaded_networks) == 0:
28
+ return context_v, context_v
29
+ for hypernetwork in loaded_networks:
30
+ context_k, context_v = hypernetwork.forward(context_k, context_v)
31
+
32
+ context_k = context_k.to(dtype=context_k.dtype)
33
+ context_v = context_v.to(dtype=context_k.dtype)
34
+
35
+ return context_k, context_v
36
+
37
+
38
+
39
+ def xformers_forward(
40
+ self: XFormersAttnProcessor,
41
+ attn: Attention,
42
+ hidden_states: torch.Tensor,
43
+ encoder_hidden_states: torch.Tensor = None,
44
+ attention_mask: torch.Tensor = None,
45
+ ):
46
+ batch_size, sequence_length, _ = (
47
+ hidden_states.shape
48
+ if encoder_hidden_states is None
49
+ else encoder_hidden_states.shape
50
+ )
51
+
52
+ attention_mask = attn.prepare_attention_mask(
53
+ attention_mask, sequence_length, batch_size
54
+ )
55
+
56
+ query = attn.to_q(hidden_states)
57
+
58
+ if encoder_hidden_states is None:
59
+ encoder_hidden_states = hidden_states
60
+ elif attn.norm_cross:
61
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
62
+
63
+ context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
64
+
65
+ key = attn.to_k(context_k)
66
+ value = attn.to_v(context_v)
67
+
68
+ query = attn.head_to_batch_dim(query).contiguous()
69
+ key = attn.head_to_batch_dim(key).contiguous()
70
+ value = attn.head_to_batch_dim(value).contiguous()
71
+
72
+ hidden_states = xformers.ops.memory_efficient_attention(
73
+ query,
74
+ key,
75
+ value,
76
+ attn_bias=attention_mask,
77
+ op=self.attention_op,
78
+ scale=attn.scale,
79
+ )
80
+ hidden_states = hidden_states.to(query.dtype)
81
+ hidden_states = attn.batch_to_head_dim(hidden_states)
82
+
83
+ # linear proj
84
+ hidden_states = attn.to_out[0](hidden_states)
85
+ # dropout
86
+ hidden_states = attn.to_out[1](hidden_states)
87
+ return hidden_states
88
+
89
+
90
+ def sliced_attn_forward(
91
+ self: SlicedAttnProcessor,
92
+ attn: Attention,
93
+ hidden_states: torch.Tensor,
94
+ encoder_hidden_states: torch.Tensor = None,
95
+ attention_mask: torch.Tensor = None,
96
+ ):
97
+ batch_size, sequence_length, _ = (
98
+ hidden_states.shape
99
+ if encoder_hidden_states is None
100
+ else encoder_hidden_states.shape
101
+ )
102
+ attention_mask = attn.prepare_attention_mask(
103
+ attention_mask, sequence_length, batch_size
104
+ )
105
+
106
+ query = attn.to_q(hidden_states)
107
+ dim = query.shape[-1]
108
+ query = attn.head_to_batch_dim(query)
109
+
110
+ if encoder_hidden_states is None:
111
+ encoder_hidden_states = hidden_states
112
+ elif attn.norm_cross:
113
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
114
+
115
+ context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
116
+
117
+ key = attn.to_k(context_k)
118
+ value = attn.to_v(context_v)
119
+ key = attn.head_to_batch_dim(key)
120
+ value = attn.head_to_batch_dim(value)
121
+
122
+ batch_size_attention, query_tokens, _ = query.shape
123
+ hidden_states = torch.zeros(
124
+ (batch_size_attention, query_tokens, dim // attn.heads),
125
+ device=query.device,
126
+ dtype=query.dtype,
127
+ )
128
+
129
+ for i in range(batch_size_attention // self.slice_size):
130
+ start_idx = i * self.slice_size
131
+ end_idx = (i + 1) * self.slice_size
132
+
133
+ query_slice = query[start_idx:end_idx]
134
+ key_slice = key[start_idx:end_idx]
135
+ attn_mask_slice = (
136
+ attention_mask[start_idx:end_idx] if attention_mask is not None else None
137
+ )
138
+
139
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
140
+
141
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
142
+
143
+ hidden_states[start_idx:end_idx] = attn_slice
144
+
145
+ hidden_states = attn.batch_to_head_dim(hidden_states)
146
+
147
+ # linear proj
148
+ hidden_states = attn.to_out[0](hidden_states)
149
+ # dropout
150
+ hidden_states = attn.to_out[1](hidden_states)
151
+
152
+ return hidden_states
153
+
154
+
155
+ def v2_0_forward(
156
+ self: AttnProcessor2_0,
157
+ attn: Attention,
158
+ hidden_states,
159
+ encoder_hidden_states=None,
160
+ attention_mask=None,
161
+ ):
162
+ batch_size, sequence_length, _ = (
163
+ hidden_states.shape
164
+ if encoder_hidden_states is None
165
+ else encoder_hidden_states.shape
166
+ )
167
+ inner_dim = hidden_states.shape[-1]
168
+
169
+ if attention_mask is not None:
170
+ attention_mask = attn.prepare_attention_mask(
171
+ attention_mask, sequence_length, batch_size
172
+ )
173
+ # scaled_dot_product_attention expects attention_mask shape to be
174
+ # (batch, heads, source_length, target_length)
175
+ attention_mask = attention_mask.view(
176
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
177
+ )
178
+
179
+ query = attn.to_q(hidden_states)
180
+
181
+ if encoder_hidden_states is None:
182
+ encoder_hidden_states = hidden_states
183
+ elif attn.norm_cross:
184
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
185
+
186
+ context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
187
+
188
+ key = attn.to_k(context_k)
189
+ value = attn.to_v(context_v)
190
+
191
+ head_dim = inner_dim // attn.heads
192
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
193
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
194
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
195
+
196
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
197
+ # TODO: add support for attn.scale when we move to Torch 2.1
198
+ hidden_states = F.scaled_dot_product_attention(
199
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
200
+ )
201
+
202
+ hidden_states = hidden_states.transpose(1, 2).reshape(
203
+ batch_size, -1, attn.heads * head_dim
204
+ )
205
+ hidden_states = hidden_states.to(query.dtype)
206
+
207
+ # linear proj
208
+ hidden_states = attn.to_out[0](hidden_states)
209
+ # dropout
210
+ hidden_states = attn.to_out[1](hidden_states)
211
+ return hidden_states
212
+
213
+
214
+ def replace_attentions_for_hypernetwork():
215
+ import diffusers.models.attention_processor
216
+
217
+ diffusers.models.attention_processor.XFormersAttnProcessor.__call__ = (
218
+ xformers_forward
219
+ )
220
+ diffusers.models.attention_processor.SlicedAttnProcessor.__call__ = (
221
+ sliced_attn_forward
222
+ )
223
+ diffusers.models.attention_processor.AttnProcessor2_0.__call__ = v2_0_forward
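A hedged sketch of how these monkey patches would be activated. The Hypernetwork class itself is not part of this file, so the object below is purely illustrative; the only contract visible above is a forward(context_k, context_v) method and membership in loaded_networks:

import torch
from library import hypernetwork

class DummyHypernetwork(torch.nn.Module):          # stand-in; the real class lives elsewhere
    def forward(self, context_k, context_v):
        return context_k, context_v                # identity, just to show the expected signature

hypernetwork.loaded_networks.append(DummyHypernetwork())

# patch diffusers' attention processors so to_k / to_v inputs pass through the hypernetworks
hypernetwork.replace_attentions_for_hypernetwork()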
library/ipex/__init__.py ADDED
@@ -0,0 +1,180 @@
1
+ import os
2
+ import sys
3
+ import contextlib
4
+ import torch
5
+ import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
6
+ from .hijacks import ipex_hijacks
7
+
8
+ # pylint: disable=protected-access, missing-function-docstring, line-too-long
9
+
10
+ def ipex_init(): # pylint: disable=too-many-statements
11
+ try:
12
+ if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked:
13
+ return True, "Skipping IPEX hijack"
14
+ else:
15
+ # Replace cuda with xpu:
16
+ torch.cuda.current_device = torch.xpu.current_device
17
+ torch.cuda.current_stream = torch.xpu.current_stream
18
+ torch.cuda.device = torch.xpu.device
19
+ torch.cuda.device_count = torch.xpu.device_count
20
+ torch.cuda.device_of = torch.xpu.device_of
21
+ torch.cuda.get_device_name = torch.xpu.get_device_name
22
+ torch.cuda.get_device_properties = torch.xpu.get_device_properties
23
+ torch.cuda.init = torch.xpu.init
24
+ torch.cuda.is_available = torch.xpu.is_available
25
+ torch.cuda.is_initialized = torch.xpu.is_initialized
26
+ torch.cuda.is_current_stream_capturing = lambda: False
27
+ torch.cuda.set_device = torch.xpu.set_device
28
+ torch.cuda.stream = torch.xpu.stream
29
+ torch.cuda.synchronize = torch.xpu.synchronize
30
+ torch.cuda.Event = torch.xpu.Event
31
+ torch.cuda.Stream = torch.xpu.Stream
32
+ torch.cuda.FloatTensor = torch.xpu.FloatTensor
33
+ torch.Tensor.cuda = torch.Tensor.xpu
34
+ torch.Tensor.is_cuda = torch.Tensor.is_xpu
35
+ torch.nn.Module.cuda = torch.nn.Module.xpu
36
+ torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
37
+ torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
38
+ torch.cuda._initialized = torch.xpu.lazy_init._initialized
39
+ torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
40
+ torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
41
+ torch.cuda._tls = torch.xpu.lazy_init._tls
42
+ torch.cuda.threading = torch.xpu.lazy_init.threading
43
+ torch.cuda.traceback = torch.xpu.lazy_init.traceback
44
+ torch.cuda.Optional = torch.xpu.Optional
45
+ torch.cuda.__cached__ = torch.xpu.__cached__
46
+ torch.cuda.__loader__ = torch.xpu.__loader__
47
+ torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
48
+ torch.cuda.Tuple = torch.xpu.Tuple
49
+ torch.cuda.streams = torch.xpu.streams
50
+ torch.cuda._lazy_new = torch.xpu._lazy_new
51
+ torch.cuda.FloatStorage = torch.xpu.FloatStorage
52
+ torch.cuda.Any = torch.xpu.Any
53
+ torch.cuda.__doc__ = torch.xpu.__doc__
54
+ torch.cuda.default_generators = torch.xpu.default_generators
55
+ torch.cuda.HalfTensor = torch.xpu.HalfTensor
56
+ torch.cuda._get_device_index = torch.xpu._get_device_index
57
+ torch.cuda.__path__ = torch.xpu.__path__
58
+ torch.cuda.Device = torch.xpu.Device
59
+ torch.cuda.IntTensor = torch.xpu.IntTensor
60
+ torch.cuda.ByteStorage = torch.xpu.ByteStorage
61
+ torch.cuda.set_stream = torch.xpu.set_stream
62
+ torch.cuda.BoolStorage = torch.xpu.BoolStorage
63
+ torch.cuda.os = torch.xpu.os
64
+ torch.cuda.torch = torch.xpu.torch
65
+ torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
66
+ torch.cuda.Union = torch.xpu.Union
67
+ torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
68
+ torch.cuda.ShortTensor = torch.xpu.ShortTensor
69
+ torch.cuda.LongTensor = torch.xpu.LongTensor
70
+ torch.cuda.IntStorage = torch.xpu.IntStorage
71
+ torch.cuda.LongStorage = torch.xpu.LongStorage
72
+ torch.cuda.__annotations__ = torch.xpu.__annotations__
73
+ torch.cuda.__package__ = torch.xpu.__package__
74
+ torch.cuda.__builtins__ = torch.xpu.__builtins__
75
+ torch.cuda.CharTensor = torch.xpu.CharTensor
76
+ torch.cuda.List = torch.xpu.List
77
+ torch.cuda._lazy_init = torch.xpu._lazy_init
78
+ torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
79
+ torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
80
+ torch.cuda.ByteTensor = torch.xpu.ByteTensor
81
+ torch.cuda.StreamContext = torch.xpu.StreamContext
82
+ torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
83
+ torch.cuda.ShortStorage = torch.xpu.ShortStorage
84
+ torch.cuda._lazy_call = torch.xpu._lazy_call
85
+ torch.cuda.HalfStorage = torch.xpu.HalfStorage
86
+ torch.cuda.random = torch.xpu.random
87
+ torch.cuda._device = torch.xpu._device
88
+ torch.cuda.classproperty = torch.xpu.classproperty
89
+ torch.cuda.__name__ = torch.xpu.__name__
90
+ torch.cuda._device_t = torch.xpu._device_t
91
+ torch.cuda.warnings = torch.xpu.warnings
92
+ torch.cuda.__spec__ = torch.xpu.__spec__
93
+ torch.cuda.BoolTensor = torch.xpu.BoolTensor
94
+ torch.cuda.CharStorage = torch.xpu.CharStorage
95
+ torch.cuda.__file__ = torch.xpu.__file__
96
+ torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
97
+ # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
98
+
99
+ # Memory:
100
+ torch.cuda.memory = torch.xpu.memory
101
+ if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
102
+ torch.xpu.empty_cache = lambda: None
103
+ torch.cuda.empty_cache = torch.xpu.empty_cache
104
+ torch.cuda.memory_stats = torch.xpu.memory_stats
105
+ torch.cuda.memory_summary = torch.xpu.memory_summary
106
+ torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
107
+ torch.cuda.memory_allocated = torch.xpu.memory_allocated
108
+ torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
109
+ torch.cuda.memory_reserved = torch.xpu.memory_reserved
110
+ torch.cuda.memory_cached = torch.xpu.memory_reserved
111
+ torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved
112
+ torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved
113
+ torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats
114
+ torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats
115
+ torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats
116
+ torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
117
+ torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
118
+
119
+ # RNG:
120
+ torch.cuda.get_rng_state = torch.xpu.get_rng_state
121
+ torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
122
+ torch.cuda.set_rng_state = torch.xpu.set_rng_state
123
+ torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all
124
+ torch.cuda.manual_seed = torch.xpu.manual_seed
125
+ torch.cuda.manual_seed_all = torch.xpu.manual_seed_all
126
+ torch.cuda.seed = torch.xpu.seed
127
+ torch.cuda.seed_all = torch.xpu.seed_all
128
+ torch.cuda.initial_seed = torch.xpu.initial_seed
129
+
130
+ # AMP:
131
+ torch.cuda.amp = torch.xpu.amp
132
+ torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
133
+ torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype
134
+
135
+ if not hasattr(torch.cuda.amp, "common"):
136
+ torch.cuda.amp.common = contextlib.nullcontext()
137
+ torch.cuda.amp.common.amp_definitely_not_available = lambda: False
138
+
139
+ try:
140
+ torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
141
+ except Exception: # pylint: disable=broad-exception-caught
142
+ try:
143
+ from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
144
+ gradscaler_init()
145
+ torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
146
+ except Exception: # pylint: disable=broad-exception-caught
147
+ torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
148
+
149
+ # C
150
+ torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
151
+ ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
152
+ ipex._C._DeviceProperties.major = 2024
153
+ ipex._C._DeviceProperties.minor = 0
154
+
155
+ # Fix functions with ipex:
156
+ torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
157
+ torch._utils._get_available_device_type = lambda: "xpu"
158
+ torch.has_cuda = True
159
+ torch.cuda.has_half = True
160
+ torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
161
+ torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
162
+ torch.backends.cuda.is_built = lambda *args, **kwargs: True
163
+ torch.version.cuda = "12.1"
164
+ torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1]
165
+ torch.cuda.get_device_properties.major = 12
166
+ torch.cuda.get_device_properties.minor = 1
167
+ torch.cuda.ipc_collect = lambda *args, **kwargs: None
168
+ torch.cuda.utilization = lambda *args, **kwargs: 0
169
+
170
+ ipex_hijacks()
171
+ if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None:
172
+ try:
173
+ from .diffusers import ipex_diffusers
174
+ ipex_diffusers()
175
+ except Exception: # pylint: disable=broad-exception-caught
176
+ pass
177
+ torch.cuda.is_xpu_hijacked = True
178
+ except Exception as e:
179
+ return False, e
180
+ return True, None
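The initializer is meant to run right after importing torch and before anything touches the GPU; a sketch (mirroring library.device_utils.init_ipex, and only meaningful where intel_extension_for_pytorch is installed):

import torch
from library.ipex import ipex_init

ok, err = ipex_init()
if not ok:
    print("IPEX hijack failed:", err)
else:
    # from here on, torch.cuda.* calls are transparently routed to torch.xpu.*
    print(torch.cuda.get_device_name())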
library/ipex/attention.py ADDED
@@ -0,0 +1,177 @@
1
+ import os
2
+ import torch
3
+ import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
4
+ from functools import cache
5
+
6
+ # pylint: disable=protected-access, missing-function-docstring, line-too-long
7
+
8
+ # ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers
9
+
10
+ sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4))
11
+ attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
12
+
13
+ # Halve the slice size until a slice fits within attention_slice_rate
14
+ @cache
15
+ def find_slice_size(slice_size, slice_block_size):
16
+ while (slice_size * slice_block_size) > attention_slice_rate:
17
+ slice_size = slice_size // 2
18
+ if slice_size <= 1:
19
+ slice_size = 1
20
+ break
21
+ return slice_size
22
+
23
+ # Find slice sizes for SDPA
24
+ @cache
25
+ def find_sdpa_slice_sizes(query_shape, query_element_size):
26
+ if len(query_shape) == 3:
27
+ batch_size_attention, query_tokens, shape_three = query_shape
28
+ shape_four = 1
29
+ else:
30
+ batch_size_attention, query_tokens, shape_three, shape_four = query_shape
31
+
32
+ slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
33
+ block_size = batch_size_attention * slice_block_size
34
+
35
+ split_slice_size = batch_size_attention
36
+ split_2_slice_size = query_tokens
37
+ split_3_slice_size = shape_three
38
+
39
+ do_split = False
40
+ do_split_2 = False
41
+ do_split_3 = False
42
+
43
+ if block_size > sdpa_slice_trigger_rate:
44
+ do_split = True
45
+ split_slice_size = find_slice_size(split_slice_size, slice_block_size)
46
+ if split_slice_size * slice_block_size > attention_slice_rate:
47
+ slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
48
+ do_split_2 = True
49
+ split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
50
+ if split_2_slice_size * slice_2_block_size > attention_slice_rate:
51
+ slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
52
+ do_split_3 = True
53
+ split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
54
+
55
+ return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
56
+
57
+ # Find slice sizes for BMM
58
+ @cache
59
+ def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape):
60
+ batch_size_attention, input_tokens, mat2_atten_shape = input_shape[0], input_shape[1], mat2_shape[2]
61
+ slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size
62
+ block_size = batch_size_attention * slice_block_size
63
+
64
+ split_slice_size = batch_size_attention
65
+ split_2_slice_size = input_tokens
66
+ split_3_slice_size = mat2_atten_shape
67
+
68
+ do_split = False
69
+ do_split_2 = False
70
+ do_split_3 = False
71
+
72
+ if block_size > attention_slice_rate:
73
+ do_split = True
74
+ split_slice_size = find_slice_size(split_slice_size, slice_block_size)
75
+ if split_slice_size * slice_block_size > attention_slice_rate:
76
+ slice_2_block_size = split_slice_size * mat2_atten_shape / 1024 / 1024 * input_element_size
77
+ do_split_2 = True
78
+ split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
79
+ if split_2_slice_size * slice_2_block_size > attention_slice_rate:
80
+ slice_3_block_size = split_slice_size * split_2_slice_size / 1024 / 1024 * input_element_size
81
+ do_split_3 = True
82
+ split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
83
+
84
+ return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
85
+
86
+
87
+ original_torch_bmm = torch.bmm
88
+ def torch_bmm_32_bit(input, mat2, *, out=None):
89
+ if input.device.type != "xpu":
90
+ return original_torch_bmm(input, mat2, out=out)
91
+ do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape)
92
+
93
+ # Slice BMM
94
+ if do_split:
95
+ batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2]
96
+ hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
97
+ for i in range(batch_size_attention // split_slice_size):
98
+ start_idx = i * split_slice_size
99
+ end_idx = (i + 1) * split_slice_size
100
+ if do_split_2:
101
+ for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
102
+ start_idx_2 = i2 * split_2_slice_size
103
+ end_idx_2 = (i2 + 1) * split_2_slice_size
104
+ if do_split_3:
105
+ for i3 in range(mat2_atten_shape // split_3_slice_size): # pylint: disable=invalid-name
106
+ start_idx_3 = i3 * split_3_slice_size
107
+ end_idx_3 = (i3 + 1) * split_3_slice_size
108
+ hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm(
109
+ input[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
110
+ mat2[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
111
+ out=out
112
+ )
113
+ else:
114
+ hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
115
+ input[start_idx:end_idx, start_idx_2:end_idx_2],
116
+ mat2[start_idx:end_idx, start_idx_2:end_idx_2],
117
+ out=out
118
+ )
119
+ else:
120
+ hidden_states[start_idx:end_idx] = original_torch_bmm(
121
+ input[start_idx:end_idx],
122
+ mat2[start_idx:end_idx],
123
+ out=out
124
+ )
125
+ torch.xpu.synchronize(input.device)
126
+ else:
127
+ return original_torch_bmm(input, mat2, out=out)
128
+ return hidden_states
129
+
130
+ original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
131
+ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
132
+ if query.device.type != "xpu":
133
+ return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
134
+ do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size())
135
+
136
+ # Slice SDPA
137
+ if do_split:
138
+ batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
139
+ hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
140
+ for i in range(batch_size_attention // split_slice_size):
141
+ start_idx = i * split_slice_size
142
+ end_idx = (i + 1) * split_slice_size
143
+ if do_split_2:
144
+ for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
145
+ start_idx_2 = i2 * split_2_slice_size
146
+ end_idx_2 = (i2 + 1) * split_2_slice_size
147
+ if do_split_3:
148
+ for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
149
+ start_idx_3 = i3 * split_3_slice_size
150
+ end_idx_3 = (i3 + 1) * split_3_slice_size
151
+ hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention(
152
+ query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
153
+ key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
154
+ value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
155
+ attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
156
+ dropout_p=dropout_p, is_causal=is_causal, **kwargs
157
+ )
158
+ else:
159
+ hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
160
+ query[start_idx:end_idx, start_idx_2:end_idx_2],
161
+ key[start_idx:end_idx, start_idx_2:end_idx_2],
162
+ value[start_idx:end_idx, start_idx_2:end_idx_2],
163
+ attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
164
+ dropout_p=dropout_p, is_causal=is_causal, **kwargs
165
+ )
166
+ else:
167
+ hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
168
+ query[start_idx:end_idx],
169
+ key[start_idx:end_idx],
170
+ value[start_idx:end_idx],
171
+ attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
172
+ dropout_p=dropout_p, is_causal=is_causal, **kwargs
173
+ )
174
+ torch.xpu.synchronize(query.device)
175
+ else:
176
+ return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
177
+ return hidden_states
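These sliced variants are drop-in replacements for the stock torch entry points. The actual assignment is done by the ipex hijacks module (not shown in this view), but conceptually it amounts to the sketch below; note that the environment variables are read at import time, so they must be set first:

import os
os.environ.setdefault("IPEX_SDPA_SLICE_TRIGGER_RATE", "4")   # threshold before SDPA is sliced
os.environ.setdefault("IPEX_ATTENTION_SLICE_RATE", "4")      # per-slice budget (see the arithmetic above)

import torch
from library.ipex import attention as ipex_attention

# route the hot attention ops through the slicing wrappers (they fall back on non-XPU tensors)
torch.bmm = ipex_attention.torch_bmm_32_bit
torch.nn.functional.scaled_dot_product_attention = ipex_attention.scaled_dot_product_attention_32_bit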
library/ipex/diffusers.py ADDED
@@ -0,0 +1,312 @@
1
+ import os
2
+ import torch
3
+ import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
4
+ import diffusers #0.24.0 # pylint: disable=import-error
5
+ from diffusers.models.attention_processor import Attention
6
+ from diffusers.utils import USE_PEFT_BACKEND
7
+ from functools import cache
8
+
9
+ # pylint: disable=protected-access, missing-function-docstring, line-too-long
10
+
11
+ attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
12
+
13
+ @cache
14
+ def find_slice_size(slice_size, slice_block_size):
15
+ while (slice_size * slice_block_size) > attention_slice_rate:
16
+ slice_size = slice_size // 2
17
+ if slice_size <= 1:
18
+ slice_size = 1
19
+ break
20
+ return slice_size
21
+
22
+ @cache
23
+ def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None):
24
+ if len(query_shape) == 3:
25
+ batch_size_attention, query_tokens, shape_three = query_shape
26
+ shape_four = 1
27
+ else:
28
+ batch_size_attention, query_tokens, shape_three, shape_four = query_shape
29
+ if slice_size is not None:
30
+ batch_size_attention = slice_size
31
+
32
+ slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
33
+ block_size = batch_size_attention * slice_block_size
34
+
35
+ split_slice_size = batch_size_attention
36
+ split_2_slice_size = query_tokens
37
+ split_3_slice_size = shape_three
38
+
39
+ do_split = False
40
+ do_split_2 = False
41
+ do_split_3 = False
42
+
43
+ if query_device_type != "xpu":
44
+ return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
45
+
46
+ if block_size > attention_slice_rate:
47
+ do_split = True
48
+ split_slice_size = find_slice_size(split_slice_size, slice_block_size)
49
+ if split_slice_size * slice_block_size > attention_slice_rate:
50
+ slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
51
+ do_split_2 = True
52
+ split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
53
+ if split_2_slice_size * slice_2_block_size > attention_slice_rate:
54
+ slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
55
+ do_split_3 = True
56
+ split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
57
+
58
+ return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
59
+
60
+ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
61
+ r"""
62
+ Processor for implementing sliced attention.
63
+
64
+ Args:
65
+ slice_size (`int`, *optional*):
66
+ The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
67
+ `attention_head_dim` must be a multiple of the `slice_size`.
68
+ """
69
+
70
+ def __init__(self, slice_size):
71
+ self.slice_size = slice_size
72
+
73
+ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
74
+ encoder_hidden_states=None, attention_mask=None) -> torch.FloatTensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
75
+
76
+ residual = hidden_states
77
+
78
+ input_ndim = hidden_states.ndim
79
+
80
+ if input_ndim == 4:
81
+ batch_size, channel, height, width = hidden_states.shape
82
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
83
+
84
+ batch_size, sequence_length, _ = (
85
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
86
+ )
87
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
88
+
89
+ if attn.group_norm is not None:
90
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
91
+
92
+ query = attn.to_q(hidden_states)
93
+ dim = query.shape[-1]
94
+ query = attn.head_to_batch_dim(query)
95
+
96
+ if encoder_hidden_states is None:
97
+ encoder_hidden_states = hidden_states
98
+ elif attn.norm_cross:
99
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
100
+
101
+ key = attn.to_k(encoder_hidden_states)
102
+ value = attn.to_v(encoder_hidden_states)
103
+ key = attn.head_to_batch_dim(key)
104
+ value = attn.head_to_batch_dim(value)
105
+
106
+ batch_size_attention, query_tokens, shape_three = query.shape
107
+ hidden_states = torch.zeros(
108
+ (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
109
+ )
110
+
111
+ ####################################################################
112
+ # ARC GPUs can't allocate more than 4GB to a single block, Slice it:
113
+ _, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size)
114
+
115
+ for i in range(batch_size_attention // split_slice_size):
116
+ start_idx = i * split_slice_size
117
+ end_idx = (i + 1) * split_slice_size
118
+ if do_split_2:
119
+ for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
120
+ start_idx_2 = i2 * split_2_slice_size
121
+ end_idx_2 = (i2 + 1) * split_2_slice_size
122
+ if do_split_3:
123
+ for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
124
+ start_idx_3 = i3 * split_3_slice_size
125
+ end_idx_3 = (i3 + 1) * split_3_slice_size
126
+
127
+ query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
128
+ key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
129
+ attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
130
+
131
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
132
+ del query_slice
133
+ del key_slice
134
+ del attn_mask_slice
135
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
136
+
137
+ hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
138
+ del attn_slice
139
+ else:
140
+ query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
141
+ key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
142
+ attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
143
+
144
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
145
+ del query_slice
146
+ del key_slice
147
+ del attn_mask_slice
148
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
149
+
150
+ hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
151
+ del attn_slice
152
+ torch.xpu.synchronize(query.device)
153
+ else:
154
+ query_slice = query[start_idx:end_idx]
155
+ key_slice = key[start_idx:end_idx]
156
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
157
+
158
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
159
+ del query_slice
160
+ del key_slice
161
+ del attn_mask_slice
162
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
163
+
164
+ hidden_states[start_idx:end_idx] = attn_slice
165
+ del attn_slice
166
+ ####################################################################
167
+
168
+ hidden_states = attn.batch_to_head_dim(hidden_states)
169
+
170
+ # linear proj
171
+ hidden_states = attn.to_out[0](hidden_states)
172
+ # dropout
173
+ hidden_states = attn.to_out[1](hidden_states)
174
+
175
+ if input_ndim == 4:
176
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
177
+
178
+ if attn.residual_connection:
179
+ hidden_states = hidden_states + residual
180
+
181
+ hidden_states = hidden_states / attn.rescale_output_factor
182
+
183
+ return hidden_states
184
+
185
+
186
+ class AttnProcessor:
187
+ r"""
188
+ Default processor for performing attention-related computations.
189
+ """
190
+
191
+ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
192
+ encoder_hidden_states=None, attention_mask=None,
193
+ temb=None, scale: float = 1.0) -> torch.Tensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
194
+
195
+ residual = hidden_states
196
+
197
+ args = () if USE_PEFT_BACKEND else (scale,)
198
+
199
+ if attn.spatial_norm is not None:
200
+ hidden_states = attn.spatial_norm(hidden_states, temb)
201
+
202
+ input_ndim = hidden_states.ndim
203
+
204
+ if input_ndim == 4:
205
+ batch_size, channel, height, width = hidden_states.shape
206
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
207
+
208
+ batch_size, sequence_length, _ = (
209
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
210
+ )
211
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
212
+
213
+ if attn.group_norm is not None:
214
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
215
+
216
+ query = attn.to_q(hidden_states, *args)
217
+
218
+ if encoder_hidden_states is None:
219
+ encoder_hidden_states = hidden_states
220
+ elif attn.norm_cross:
221
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
222
+
223
+ key = attn.to_k(encoder_hidden_states, *args)
224
+ value = attn.to_v(encoder_hidden_states, *args)
225
+
226
+ query = attn.head_to_batch_dim(query)
227
+ key = attn.head_to_batch_dim(key)
228
+ value = attn.head_to_batch_dim(value)
229
+
230
+ ####################################################################
231
+ # ARC GPUs can't allocate more than 4GB to a single block, Slice it:
232
+ batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
233
+ hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
234
+ do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type)
235
+
236
+ if do_split:
237
+ for i in range(batch_size_attention // split_slice_size):
238
+ start_idx = i * split_slice_size
239
+ end_idx = (i + 1) * split_slice_size
240
+ if do_split_2:
241
+ for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
242
+ start_idx_2 = i2 * split_2_slice_size
243
+ end_idx_2 = (i2 + 1) * split_2_slice_size
244
+ if do_split_3:
245
+ for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
246
+ start_idx_3 = i3 * split_3_slice_size
247
+ end_idx_3 = (i3 + 1) * split_3_slice_size
248
+
249
+ query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
250
+ key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
251
+ attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
252
+
253
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
254
+ del query_slice
255
+ del key_slice
256
+ del attn_mask_slice
257
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
258
+
259
+ hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
260
+ del attn_slice
261
+ else:
262
+ query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
263
+ key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
264
+ attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
265
+
266
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
267
+ del query_slice
268
+ del key_slice
269
+ del attn_mask_slice
270
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
271
+
272
+ hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
273
+ del attn_slice
274
+ else:
275
+ query_slice = query[start_idx:end_idx]
276
+ key_slice = key[start_idx:end_idx]
277
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
278
+
279
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
280
+ del query_slice
281
+ del key_slice
282
+ del attn_mask_slice
283
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
284
+
285
+ hidden_states[start_idx:end_idx] = attn_slice
286
+ del attn_slice
287
+ torch.xpu.synchronize(query.device)
288
+ else:
289
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
290
+ hidden_states = torch.bmm(attention_probs, value)
291
+ ####################################################################
292
+ hidden_states = attn.batch_to_head_dim(hidden_states)
293
+
294
+ # linear proj
295
+ hidden_states = attn.to_out[0](hidden_states, *args)
296
+ # dropout
297
+ hidden_states = attn.to_out[1](hidden_states)
298
+
299
+ if input_ndim == 4:
300
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
301
+
302
+ if attn.residual_connection:
303
+ hidden_states = hidden_states + residual
304
+
305
+ hidden_states = hidden_states / attn.rescale_output_factor
306
+
307
+ return hidden_states
308
+
309
+ def ipex_diffusers():
310
+ #ARC GPUs can't allocate more than 4GB to a single block:
311
+ diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor
312
+ diffusers.models.attention_processor.AttnProcessor = AttnProcessor
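Activation is a single call that swaps diffusers' processor classes for the sliced versions above. ipex_init in library/ipex/__init__.py already calls it automatically when the device lacks fp64 support or IPEX_FORCE_ATTENTION_SLICE is set, so doing it by hand is only needed in custom setups:

from library.ipex.diffusers import ipex_diffusers

ipex_diffusers()   # diffusers' AttnProcessor / SlicedAttnProcessor now use the sliced code paths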
library/ipex/gradscaler.py ADDED
@@ -0,0 +1,183 @@
1
+ from collections import defaultdict
2
+ import torch
3
+ import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
4
+ import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import
5
+
6
+ # pylint: disable=protected-access, missing-function-docstring, line-too-long
7
+
8
+ device_supports_fp64 = torch.xpu.has_fp64_dtype()
9
+ OptState = ipex.cpu.autocast._grad_scaler.OptState
10
+ _MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
11
+ _refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
12
+
13
+ def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument
14
+ per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
15
+ per_device_found_inf = _MultiDeviceReplicator(found_inf)
16
+
17
+ # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
18
+ # There could be hundreds of grads, so we'd like to iterate through them just once.
19
+ # However, we don't know their devices or dtypes in advance.
20
+
21
+ # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
22
+ # Google says mypy struggles with defaultdicts type annotations.
23
+ per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
24
+ # sync grad to master weight
25
+ if hasattr(optimizer, "sync_grad"):
26
+ optimizer.sync_grad()
27
+ with torch.no_grad():
28
+ for group in optimizer.param_groups:
29
+ for param in group["params"]:
30
+ if param.grad is None:
31
+ continue
32
+ if (not allow_fp16) and param.grad.dtype == torch.float16:
33
+ raise ValueError("Attempting to unscale FP16 gradients.")
34
+ if param.grad.is_sparse:
35
+ # is_coalesced() == False means the sparse grad has values with duplicate indices.
36
+ # coalesce() deduplicates indices and adds all values that have the same index.
37
+ # For scaled fp16 values, there's a good chance coalescing will cause overflow,
38
+ # so we should check the coalesced _values().
39
+ if param.grad.dtype is torch.float16:
40
+ param.grad = param.grad.coalesce()
41
+ to_unscale = param.grad._values()
42
+ else:
43
+ to_unscale = param.grad
44
+
45
+ # TODO: is there a way to split by device and dtype without appending in the inner loop?
46
+ to_unscale = to_unscale.to("cpu")
47
+ per_device_and_dtype_grads[to_unscale.device][
48
+ to_unscale.dtype
49
+ ].append(to_unscale)
50
+
51
+ for _, per_dtype_grads in per_device_and_dtype_grads.items():
52
+ for grads in per_dtype_grads.values():
53
+ core._amp_foreach_non_finite_check_and_unscale_(
54
+ grads,
55
+ per_device_found_inf.get("cpu"),
56
+ per_device_inv_scale.get("cpu"),
57
+ )
58
+
59
+ return per_device_found_inf._per_device_tensors
60
+
61
+ def unscale_(self, optimizer):
62
+ """
63
+ Divides ("unscales") the optimizer's gradient tensors by the scale factor.
64
+ :meth:`unscale_` is optional, serving cases where you need to
65
+ :ref:`modify or inspect gradients<working-with-unscaled-gradients>`
66
+ between the backward pass(es) and :meth:`step`.
67
+ If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
68
+ Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
69
+ ...
70
+ scaler.scale(loss).backward()
71
+ scaler.unscale_(optimizer)
72
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
73
+ scaler.step(optimizer)
74
+ scaler.update()
75
+ Args:
76
+ optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
77
+ .. warning::
78
+ :meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
79
+ and only after all gradients for that optimizer's assigned parameters have been accumulated.
80
+ Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
81
+ .. warning::
82
+ :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
83
+ """
84
+ if not self._enabled:
85
+ return
86
+
87
+ self._check_scale_growth_tracker("unscale_")
88
+
89
+ optimizer_state = self._per_optimizer_states[id(optimizer)]
90
+
91
+ if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise
92
+ raise RuntimeError(
93
+ "unscale_() has already been called on this optimizer since the last update()."
94
+ )
95
+ elif optimizer_state["stage"] is OptState.STEPPED:
96
+ raise RuntimeError("unscale_() is being called after step().")
97
+
98
+ # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
99
+ assert self._scale is not None
100
+ if device_supports_fp64:
101
+ inv_scale = self._scale.double().reciprocal().float()
102
+ else:
103
+ inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
104
+ found_inf = torch.full(
105
+ (1,), 0.0, dtype=torch.float32, device=self._scale.device
106
+ )
107
+
108
+ optimizer_state["found_inf_per_device"] = self._unscale_grads_(
109
+ optimizer, inv_scale, found_inf, False
110
+ )
111
+ optimizer_state["stage"] = OptState.UNSCALED
112
+
113
+ def update(self, new_scale=None):
114
+ """
115
+ Updates the scale factor.
116
+ If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
117
+ to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
118
+ the scale is multiplied by ``growth_factor`` to increase it.
119
+ Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
120
+ used directly, it's used to fill GradScaler's internal scale tensor. So if
121
+ ``new_scale`` was a tensor, later in-place changes to that tensor will not further
122
+ affect the scale GradScaler uses internally.)
123
+ Args:
124
+ new_scale (float or :class:`torch.FloatTensor`, optional, default=None): New scale factor.
125
+ .. warning::
126
+ :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
127
+ been invoked for all optimizers used this iteration.
128
+ """
129
+ if not self._enabled:
130
+ return
131
+
132
+ _scale, _growth_tracker = self._check_scale_growth_tracker("update")
133
+
134
+ if new_scale is not None:
135
+ # Accept a new user-defined scale.
136
+ if isinstance(new_scale, float):
137
+ self._scale.fill_(new_scale) # type: ignore[union-attr]
138
+ else:
139
+ reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False."
140
+ assert isinstance(new_scale, torch.FloatTensor), reason # type: ignore[attr-defined]
141
+ assert new_scale.numel() == 1, reason
142
+ assert new_scale.requires_grad is False, reason
143
+ self._scale.copy_(new_scale) # type: ignore[union-attr]
144
+ else:
145
+ # Consume shared inf/nan data collected from optimizers to update the scale.
146
+ # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
147
+ found_infs = [
148
+ found_inf.to(device="cpu", non_blocking=True)
149
+ for state in self._per_optimizer_states.values()
150
+ for found_inf in state["found_inf_per_device"].values()
151
+ ]
152
+
153
+ assert len(found_infs) > 0, "No inf checks were recorded prior to update."
154
+
155
+ found_inf_combined = found_infs[0]
156
+ if len(found_infs) > 1:
157
+ for i in range(1, len(found_infs)):
158
+ found_inf_combined += found_infs[i]
159
+
160
+ to_device = _scale.device
161
+ _scale = _scale.to("cpu")
162
+ _growth_tracker = _growth_tracker.to("cpu")
163
+
164
+ core._amp_update_scale_(
165
+ _scale,
166
+ _growth_tracker,
167
+ found_inf_combined,
168
+ self._growth_factor,
169
+ self._backoff_factor,
170
+ self._growth_interval,
171
+ )
172
+
173
+ _scale = _scale.to(to_device)
174
+ _growth_tracker = _growth_tracker.to(to_device)
175
+ # To prepare for next iteration, clear the data collected from optimizers this iteration.
176
+ self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
177
+
178
+ def gradscaler_init():
179
+ torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
180
+ torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_
181
+ torch.xpu.amp.GradScaler.unscale_ = unscale_
182
+ torch.xpu.amp.GradScaler.update = update
183
+ return torch.xpu.amp.GradScaler
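
gradscaler_init() rebinds torch.xpu.amp.GradScaler to IPEX's CPU GradScaler and attaches the patched _unscale_grads_, unscale_ and update above, so the inf/NaN check and scale update run through the CPU core ops. A minimal sketch of how the patched scaler might be used, not part of the upload; the import path, tiny model and data are placeholders:

    import torch
    import intel_extension_for_pytorch as ipex  # noqa: F401  # registers the "xpu" device
    from library.ipex.gradscaler import gradscaler_init  # path assumed from this upload's layout

    GradScaler = gradscaler_init()
    scaler = GradScaler()
    model = torch.nn.Linear(8, 1).to("xpu")
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

    x = torch.randn(4, 8, device="xpu")
    with torch.autocast("xpu", dtype=torch.bfloat16):
        loss = model(x).pow(2).mean()
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)                      # patched: grads are checked/unscaled via CPU ops
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()                                 # patched: _amp_update_scale_ runs on CPU tensors
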
library/ipex/hijacks.py ADDED
@@ -0,0 +1,313 @@
1
+ import os
2
+ from functools import wraps
3
+ from contextlib import nullcontext
4
+ import torch
5
+ import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
6
+ import numpy as np
7
+
8
+ device_supports_fp64 = torch.xpu.has_fp64_dtype()
9
+
10
+ # pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
11
+
12
+ class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
13
+ def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
14
+ if isinstance(device_ids, list) and len(device_ids) > 1:
15
+ print("IPEX backend doesn't support DataParallel on multiple XPU devices")
16
+ return module.to("xpu")
17
+
18
+ def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
19
+ return nullcontext()
20
+
21
+ @property
22
+ def is_cuda(self):
23
+ return self.device.type == 'xpu' or self.device.type == 'cuda'
24
+
25
+ def check_device(device):
26
+ return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
27
+
28
+ def return_xpu(device):
29
+ return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
30
+
31
+
32
+ # Autocast
33
+ original_autocast_init = torch.amp.autocast_mode.autocast.__init__
34
+ @wraps(torch.amp.autocast_mode.autocast.__init__)
35
+ def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None):
36
+ if device_type == "cuda":
37
+ return original_autocast_init(self, device_type="xpu", dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
38
+ else:
39
+ return original_autocast_init(self, device_type=device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
40
+
41
+ # Latent Antialias CPU Offload:
42
+ original_interpolate = torch.nn.functional.interpolate
43
+ @wraps(torch.nn.functional.interpolate)
44
+ def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
45
+ if antialias or align_corners is not None or mode == 'bicubic':
46
+ return_device = tensor.device
47
+ return_dtype = tensor.dtype
48
+ return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
49
+ align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias).to(return_device, dtype=return_dtype)
50
+ else:
51
+ return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
52
+ align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
53
+
54
+
55
+ # Diffusers Float64 (Alchemist GPUs don't support 64 bit):
56
+ original_from_numpy = torch.from_numpy
57
+ @wraps(torch.from_numpy)
58
+ def from_numpy(ndarray):
59
+ if ndarray.dtype == float:
60
+ return original_from_numpy(ndarray.astype('float32'))
61
+ else:
62
+ return original_from_numpy(ndarray)
63
+
64
+ original_as_tensor = torch.as_tensor
65
+ @wraps(torch.as_tensor)
66
+ def as_tensor(data, dtype=None, device=None):
67
+ if check_device(device):
68
+ device = return_xpu(device)
69
+ if isinstance(data, np.ndarray) and data.dtype == float and not (
70
+ (isinstance(device, torch.device) and device.type == "cpu") or (isinstance(device, str) and "cpu" in device)):
71
+ return original_as_tensor(data, dtype=torch.float32, device=device)
72
+ else:
73
+ return original_as_tensor(data, dtype=dtype, device=device)
74
+
75
+
76
+ if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None:
77
+ original_torch_bmm = torch.bmm
78
+ original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
79
+ else:
80
+ # 32 bit attention workarounds for Alchemist:
81
+ try:
82
+ from .attention import torch_bmm_32_bit as original_torch_bmm
83
+ from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention
84
+ except Exception: # pylint: disable=broad-exception-caught
85
+ original_torch_bmm = torch.bmm
86
+ original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
87
+
88
+
89
+ # Data Type Errors:
90
+ @wraps(torch.bmm)
91
+ def torch_bmm(input, mat2, *, out=None):
92
+ if input.dtype != mat2.dtype:
93
+ mat2 = mat2.to(input.dtype)
94
+ return original_torch_bmm(input, mat2, out=out)
95
+
96
+ @wraps(torch.nn.functional.scaled_dot_product_attention)
97
+ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
98
+ if query.dtype != key.dtype:
99
+ key = key.to(dtype=query.dtype)
100
+ if query.dtype != value.dtype:
101
+ value = value.to(dtype=query.dtype)
102
+ if attn_mask is not None and query.dtype != attn_mask.dtype:
103
+ attn_mask = attn_mask.to(dtype=query.dtype)
104
+ return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
105
+
106
+ # A1111 FP16
107
+ original_functional_group_norm = torch.nn.functional.group_norm
108
+ @wraps(torch.nn.functional.group_norm)
109
+ def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
110
+ if weight is not None and input.dtype != weight.data.dtype:
111
+ input = input.to(dtype=weight.data.dtype)
112
+ if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype:
113
+ bias.data = bias.data.to(dtype=weight.data.dtype)
114
+ return original_functional_group_norm(input, num_groups, weight=weight, bias=bias, eps=eps)
115
+
116
+ # A1111 BF16
117
+ original_functional_layer_norm = torch.nn.functional.layer_norm
118
+ @wraps(torch.nn.functional.layer_norm)
119
+ def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
120
+ if weight is not None and input.dtype != weight.data.dtype:
121
+ input = input.to(dtype=weight.data.dtype)
122
+ if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype:
123
+ bias.data = bias.data.to(dtype=weight.data.dtype)
124
+ return original_functional_layer_norm(input, normalized_shape, weight=weight, bias=bias, eps=eps)
125
+
126
+ # Training
127
+ original_functional_linear = torch.nn.functional.linear
128
+ @wraps(torch.nn.functional.linear)
129
+ def functional_linear(input, weight, bias=None):
130
+ if input.dtype != weight.data.dtype:
131
+ input = input.to(dtype=weight.data.dtype)
132
+ if bias is not None and bias.data.dtype != weight.data.dtype:
133
+ bias.data = bias.data.to(dtype=weight.data.dtype)
134
+ return original_functional_linear(input, weight, bias=bias)
135
+
136
+ original_functional_conv2d = torch.nn.functional.conv2d
137
+ @wraps(torch.nn.functional.conv2d)
138
+ def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
139
+ if input.dtype != weight.data.dtype:
140
+ input = input.to(dtype=weight.data.dtype)
141
+ if bias is not None and bias.data.dtype != weight.data.dtype:
142
+ bias.data = bias.data.to(dtype=weight.data.dtype)
143
+ return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
144
+
145
+ # A1111 Embedding BF16
146
+ original_torch_cat = torch.cat
147
+ @wraps(torch.cat)
148
+ def torch_cat(tensor, *args, **kwargs):
149
+ if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
150
+ return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
151
+ else:
152
+ return original_torch_cat(tensor, *args, **kwargs)
153
+
154
+ # SwinIR BF16:
155
+ original_functional_pad = torch.nn.functional.pad
156
+ @wraps(torch.nn.functional.pad)
157
+ def functional_pad(input, pad, mode='constant', value=None):
158
+ if mode == 'reflect' and input.dtype == torch.bfloat16:
159
+ return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16)
160
+ else:
161
+ return original_functional_pad(input, pad, mode=mode, value=value)
162
+
163
+
164
+ original_torch_tensor = torch.tensor
165
+ @wraps(torch.tensor)
166
+ def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
167
+ if check_device(device):
168
+ device = return_xpu(device)
169
+ if not device_supports_fp64:
170
+ if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device):
171
+ if dtype == torch.float64:
172
+ dtype = torch.float32
173
+ elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)):
174
+ dtype = torch.float32
175
+ return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs)
176
+
177
+ original_Tensor_to = torch.Tensor.to
178
+ @wraps(torch.Tensor.to)
179
+ def Tensor_to(self, device=None, *args, **kwargs):
180
+ if check_device(device):
181
+ return original_Tensor_to(self, return_xpu(device), *args, **kwargs)
182
+ else:
183
+ return original_Tensor_to(self, device, *args, **kwargs)
184
+
185
+ original_Tensor_cuda = torch.Tensor.cuda
186
+ @wraps(torch.Tensor.cuda)
187
+ def Tensor_cuda(self, device=None, *args, **kwargs):
188
+ if check_device(device):
189
+ return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs)
190
+ else:
191
+ return original_Tensor_cuda(self, device, *args, **kwargs)
192
+
193
+ original_Tensor_pin_memory = torch.Tensor.pin_memory
194
+ @wraps(torch.Tensor.pin_memory)
195
+ def Tensor_pin_memory(self, device=None, *args, **kwargs):
196
+ if device is None:
197
+ device = "xpu"
198
+ if check_device(device):
199
+ return original_Tensor_pin_memory(self, return_xpu(device), *args, **kwargs)
200
+ else:
201
+ return original_Tensor_pin_memory(self, device, *args, **kwargs)
202
+
203
+ original_UntypedStorage_init = torch.UntypedStorage.__init__
204
+ @wraps(torch.UntypedStorage.__init__)
205
+ def UntypedStorage_init(*args, device=None, **kwargs):
206
+ if check_device(device):
207
+ return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs)
208
+ else:
209
+ return original_UntypedStorage_init(*args, device=device, **kwargs)
210
+
211
+ original_UntypedStorage_cuda = torch.UntypedStorage.cuda
212
+ @wraps(torch.UntypedStorage.cuda)
213
+ def UntypedStorage_cuda(self, device=None, *args, **kwargs):
214
+ if check_device(device):
215
+ return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs)
216
+ else:
217
+ return original_UntypedStorage_cuda(self, device, *args, **kwargs)
218
+
219
+ original_torch_empty = torch.empty
220
+ @wraps(torch.empty)
221
+ def torch_empty(*args, device=None, **kwargs):
222
+ if check_device(device):
223
+ return original_torch_empty(*args, device=return_xpu(device), **kwargs)
224
+ else:
225
+ return original_torch_empty(*args, device=device, **kwargs)
226
+
227
+ original_torch_randn = torch.randn
228
+ @wraps(torch.randn)
229
+ def torch_randn(*args, device=None, dtype=None, **kwargs):
230
+ if dtype == bytes:
231
+ dtype = None
232
+ if check_device(device):
233
+ return original_torch_randn(*args, device=return_xpu(device), dtype=dtype, **kwargs)  # forward the sanitized dtype
234
+ else:
235
+ return original_torch_randn(*args, device=device, dtype=dtype, **kwargs)
236
+
237
+ original_torch_ones = torch.ones
238
+ @wraps(torch.ones)
239
+ def torch_ones(*args, device=None, **kwargs):
240
+ if check_device(device):
241
+ return original_torch_ones(*args, device=return_xpu(device), **kwargs)
242
+ else:
243
+ return original_torch_ones(*args, device=device, **kwargs)
244
+
245
+ original_torch_zeros = torch.zeros
246
+ @wraps(torch.zeros)
247
+ def torch_zeros(*args, device=None, **kwargs):
248
+ if check_device(device):
249
+ return original_torch_zeros(*args, device=return_xpu(device), **kwargs)
250
+ else:
251
+ return original_torch_zeros(*args, device=device, **kwargs)
252
+
253
+ original_torch_linspace = torch.linspace
254
+ @wraps(torch.linspace)
255
+ def torch_linspace(*args, device=None, **kwargs):
256
+ if check_device(device):
257
+ return original_torch_linspace(*args, device=return_xpu(device), **kwargs)
258
+ else:
259
+ return original_torch_linspace(*args, device=device, **kwargs)
260
+
261
+ original_torch_Generator = torch.Generator
262
+ @wraps(torch.Generator)
263
+ def torch_Generator(device=None):
264
+ if check_device(device):
265
+ return original_torch_Generator(return_xpu(device))
266
+ else:
267
+ return original_torch_Generator(device)
268
+
269
+ original_torch_load = torch.load
270
+ @wraps(torch.load)
271
+ def torch_load(f, map_location=None, *args, **kwargs):
272
+ if map_location is None:
273
+ map_location = "xpu"
274
+ if check_device(map_location):
275
+ return original_torch_load(f, *args, map_location=return_xpu(map_location), **kwargs)
276
+ else:
277
+ return original_torch_load(f, *args, map_location=map_location, **kwargs)
278
+
279
+
280
+ # Hijack Functions:
281
+ def ipex_hijacks():
282
+ torch.tensor = torch_tensor
283
+ torch.Tensor.to = Tensor_to
284
+ torch.Tensor.cuda = Tensor_cuda
285
+ torch.Tensor.pin_memory = Tensor_pin_memory
286
+ torch.UntypedStorage.__init__ = UntypedStorage_init
287
+ torch.UntypedStorage.cuda = UntypedStorage_cuda
288
+ torch.empty = torch_empty
289
+ torch.randn = torch_randn
290
+ torch.ones = torch_ones
291
+ torch.zeros = torch_zeros
292
+ torch.linspace = torch_linspace
293
+ torch.Generator = torch_Generator
294
+ torch.load = torch_load
295
+
296
+ torch.backends.cuda.sdp_kernel = return_null_context
297
+ torch.nn.DataParallel = DummyDataParallel
298
+ torch.UntypedStorage.is_cuda = is_cuda
299
+ torch.amp.autocast_mode.autocast.__init__ = autocast_init
300
+
301
+ torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
302
+ torch.nn.functional.group_norm = functional_group_norm
303
+ torch.nn.functional.layer_norm = functional_layer_norm
304
+ torch.nn.functional.linear = functional_linear
305
+ torch.nn.functional.conv2d = functional_conv2d
306
+ torch.nn.functional.interpolate = interpolate
307
+ torch.nn.functional.pad = functional_pad
308
+
309
+ torch.bmm = torch_bmm
310
+ torch.cat = torch_cat
311
+ if not device_supports_fp64:
312
+ torch.from_numpy = from_numpy
313
+ torch.as_tensor = as_tensor
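
ipex_hijacks() is the entry point for this module: it replaces the tensor constructors, Tensor.to/.cuda, autocast and the functional ops listed above so that code written against CUDA transparently lands on the XPU. A minimal sketch, not part of the upload; the import path is assumed from this upload's layout:

    import torch
    import intel_extension_for_pytorch as ipex  # noqa: F401  # registers the "xpu" device
    from library.ipex.hijacks import ipex_hijacks  # path assumed from this upload

    ipex_hijacks()
    x = torch.randn(2, 3, device="cuda")  # check_device()/return_xpu() reroute "cuda" -> "xpu"
    y = torch.ones(2, 3).to("cuda")       # Tensor.to("cuda") also lands on the XPU
    print((x + y).device)                 # expected: xpu:0
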
library/lpw_stable_diffusion.py ADDED
@@ -0,0 +1,1233 @@
1
+ # copied from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py
2
+ # and modified to support SD2.x
3
+
4
+ import inspect
5
+ import re
6
+ from typing import Callable, List, Optional, Union
7
+
8
+ import numpy as np
9
+ import PIL.Image
10
+ import torch
11
+ from packaging import version
12
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
13
+
14
+ import diffusers
15
+ from diffusers import SchedulerMixin, StableDiffusionPipeline
16
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
17
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
18
+ from diffusers.utils import logging
19
+
20
+ try:
21
+ from diffusers.utils import PIL_INTERPOLATION
22
+ except ImportError:
23
+ if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
24
+ PIL_INTERPOLATION = {
25
+ "linear": PIL.Image.Resampling.BILINEAR,
26
+ "bilinear": PIL.Image.Resampling.BILINEAR,
27
+ "bicubic": PIL.Image.Resampling.BICUBIC,
28
+ "lanczos": PIL.Image.Resampling.LANCZOS,
29
+ "nearest": PIL.Image.Resampling.NEAREST,
30
+ }
31
+ else:
32
+ PIL_INTERPOLATION = {
33
+ "linear": PIL.Image.LINEAR,
34
+ "bilinear": PIL.Image.BILINEAR,
35
+ "bicubic": PIL.Image.BICUBIC,
36
+ "lanczos": PIL.Image.LANCZOS,
37
+ "nearest": PIL.Image.NEAREST,
38
+ }
39
+ # ------------------------------------------------------------------------------
40
+
41
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
42
+
43
+ re_attention = re.compile(
44
+ r"""
45
+ \\\(|
46
+ \\\)|
47
+ \\\[|
48
+ \\]|
49
+ \\\\|
50
+ \\|
51
+ \(|
52
+ \[|
53
+ :([+-]?[.\d]+)\)|
54
+ \)|
55
+ ]|
56
+ [^\\()\[\]:]+|
57
+ :
58
+ """,
59
+ re.X,
60
+ )
61
+
62
+
63
+ def parse_prompt_attention(text):
64
+ """
65
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
66
+ Accepted tokens are:
67
+ (abc) - increases attention to abc by a multiplier of 1.1
68
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
69
+ [abc] - decreases attention to abc by a multiplier of 1.1
70
+ \( - literal character '('
71
+ \[ - literal character '['
72
+ \) - literal character ')'
73
+ \] - literal character ']'
74
+ \\ - literal character '\'
75
+ anything else - just text
76
+ >>> parse_prompt_attention('normal text')
77
+ [['normal text', 1.0]]
78
+ >>> parse_prompt_attention('an (important) word')
79
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
80
+ >>> parse_prompt_attention('(unbalanced')
81
+ [['unbalanced', 1.1]]
82
+ >>> parse_prompt_attention('\(literal\]')
83
+ [['(literal]', 1.0]]
84
+ >>> parse_prompt_attention('(unnecessary)(parens)')
85
+ [['unnecessaryparens', 1.1]]
86
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
87
+ [['a ', 1.0],
88
+ ['house', 1.5730000000000004],
89
+ [' ', 1.1],
90
+ ['on', 1.0],
91
+ [' a ', 1.1],
92
+ ['hill', 0.55],
93
+ [', sun, ', 1.1],
94
+ ['sky', 1.4641000000000006],
95
+ ['.', 1.1]]
96
+ """
97
+
98
+ res = []
99
+ round_brackets = []
100
+ square_brackets = []
101
+
102
+ round_bracket_multiplier = 1.1
103
+ square_bracket_multiplier = 1 / 1.1
104
+
105
+ def multiply_range(start_position, multiplier):
106
+ for p in range(start_position, len(res)):
107
+ res[p][1] *= multiplier
108
+
109
+ for m in re_attention.finditer(text):
110
+ text = m.group(0)
111
+ weight = m.group(1)
112
+
113
+ if text.startswith("\\"):
114
+ res.append([text[1:], 1.0])
115
+ elif text == "(":
116
+ round_brackets.append(len(res))
117
+ elif text == "[":
118
+ square_brackets.append(len(res))
119
+ elif weight is not None and len(round_brackets) > 0:
120
+ multiply_range(round_brackets.pop(), float(weight))
121
+ elif text == ")" and len(round_brackets) > 0:
122
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
123
+ elif text == "]" and len(square_brackets) > 0:
124
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
125
+ else:
126
+ res.append([text, 1.0])
127
+
128
+ for pos in round_brackets:
129
+ multiply_range(pos, round_bracket_multiplier)
130
+
131
+ for pos in square_brackets:
132
+ multiply_range(pos, square_bracket_multiplier)
133
+
134
+ if len(res) == 0:
135
+ res = [["", 1.0]]
136
+
137
+ # merge runs of identical weights
138
+ i = 0
139
+ while i + 1 < len(res):
140
+ if res[i][1] == res[i + 1][1]:
141
+ res[i][0] += res[i + 1][0]
142
+ res.pop(i + 1)
143
+ else:
144
+ i += 1
145
+
146
+ return res
147
+
148
+
149
+ def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
150
+ r"""
151
+ Tokenize a list of prompts and return its tokens with weights of each token.
152
+
153
+ No padding, starting or ending token is included.
154
+ """
155
+ tokens = []
156
+ weights = []
157
+ truncated = False
158
+ for text in prompt:
159
+ texts_and_weights = parse_prompt_attention(text)
160
+ text_token = []
161
+ text_weight = []
162
+ for word, weight in texts_and_weights:
163
+ # tokenize and discard the starting and the ending token
164
+ token = pipe.tokenizer(word).input_ids[1:-1]
165
+ text_token += token
166
+ # copy the weight by length of token
167
+ text_weight += [weight] * len(token)
168
+ # stop if the text is too long (longer than truncation limit)
169
+ if len(text_token) > max_length:
170
+ truncated = True
171
+ break
172
+ # truncate
173
+ if len(text_token) > max_length:
174
+ truncated = True
175
+ text_token = text_token[:max_length]
176
+ text_weight = text_weight[:max_length]
177
+ tokens.append(text_token)
178
+ weights.append(text_weight)
179
+ if truncated:
180
+ logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
181
+ return tokens, weights
182
+
183
+
184
+ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
185
+ r"""
186
+ Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
187
+ """
188
+ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
189
+ weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
190
+ for i in range(len(tokens)):
191
+ tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
192
+ if no_boseos_middle:
193
+ weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
194
+ else:
195
+ w = []
196
+ if len(weights[i]) == 0:
197
+ w = [1.0] * weights_length
198
+ else:
199
+ for j in range(max_embeddings_multiples):
200
+ w.append(1.0) # weight for starting token in this chunk
201
+ w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
202
+ w.append(1.0) # weight for ending token in this chunk
203
+ w += [1.0] * (weights_length - len(w))
204
+ weights[i] = w[:]
205
+
206
+ return tokens, weights
207
+
208
+
209
+ def get_unweighted_text_embeddings(
210
+ pipe: StableDiffusionPipeline,
211
+ text_input: torch.Tensor,
212
+ chunk_length: int,
213
+ clip_skip: int,
214
+ eos: int,
215
+ pad: int,
216
+ no_boseos_middle: Optional[bool] = True,
217
+ ):
218
+ """
219
+ When the tokenized text is longer than the capacity of the text encoder,
220
+ it is split into chunks and each chunk is sent to the text encoder individually.
221
+ """
222
+ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
223
+ if max_embeddings_multiples > 1:
224
+ text_embeddings = []
225
+ for i in range(max_embeddings_multiples):
226
+ # extract the i-th chunk
227
+ text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
228
+
229
+ # cover the head and the tail by the starting and the ending tokens
230
+ text_input_chunk[:, 0] = text_input[0, 0]
231
+ if pad == eos: # v1
232
+ text_input_chunk[:, -1] = text_input[0, -1]
233
+ else: # v2
234
+ for j in range(len(text_input_chunk)):
235
+ if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # the last token is an ordinary token, not EOS/PAD
236
+ text_input_chunk[j, -1] = eos
237
+ if text_input_chunk[j, 1] == pad: # only BOS, the rest is PAD
238
+ text_input_chunk[j, 1] = eos
239
+
240
+ if clip_skip is None or clip_skip == 1:
241
+ text_embedding = pipe.text_encoder(text_input_chunk)[0]
242
+ else:
243
+ enc_out = pipe.text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
244
+ text_embedding = enc_out["hidden_states"][-clip_skip]
245
+ text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding)
246
+
247
+ if no_boseos_middle:
248
+ if i == 0:
249
+ # discard the ending token
250
+ text_embedding = text_embedding[:, :-1]
251
+ elif i == max_embeddings_multiples - 1:
252
+ # discard the starting token
253
+ text_embedding = text_embedding[:, 1:]
254
+ else:
255
+ # discard both starting and ending tokens
256
+ text_embedding = text_embedding[:, 1:-1]
257
+
258
+ text_embeddings.append(text_embedding)
259
+ text_embeddings = torch.concat(text_embeddings, axis=1)
260
+ else:
261
+ if clip_skip is None or clip_skip == 1:
262
+ text_embeddings = pipe.text_encoder(text_input)[0]
263
+ else:
264
+ enc_out = pipe.text_encoder(text_input, output_hidden_states=True, return_dict=True)
265
+ text_embeddings = enc_out["hidden_states"][-clip_skip]
266
+ text_embeddings = pipe.text_encoder.text_model.final_layer_norm(text_embeddings)
267
+ return text_embeddings
268
+
269
+
270
+ def get_weighted_text_embeddings(
271
+ pipe: StableDiffusionPipeline,
272
+ prompt: Union[str, List[str]],
273
+ uncond_prompt: Optional[Union[str, List[str]]] = None,
274
+ max_embeddings_multiples: Optional[int] = 3,
275
+ no_boseos_middle: Optional[bool] = False,
276
+ skip_parsing: Optional[bool] = False,
277
+ skip_weighting: Optional[bool] = False,
278
+ clip_skip=None,
279
+ ):
280
+ r"""
281
+ Prompts can be assigned with local weights using brackets. For example,
282
+ prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
283
+ and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
284
+
285
+ Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
286
+
287
+ Args:
288
+ pipe (`StableDiffusionPipeline`):
289
+ Pipe to provide access to the tokenizer and the text encoder.
290
+ prompt (`str` or `List[str]`):
291
+ The prompt or prompts to guide the image generation.
292
+ uncond_prompt (`str` or `List[str]`):
293
+ The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
294
+ is provided, the embeddings of prompt and uncond_prompt are concatenated.
295
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
296
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
297
+ no_boseos_middle (`bool`, *optional*, defaults to `False`):
298
+ If the tokenized text is longer than the capacity of the text encoder, whether to keep the starting and
299
+ ending tokens in each of the middle chunks.
300
+ skip_parsing (`bool`, *optional*, defaults to `False`):
301
+ Skip the parsing of brackets.
302
+ skip_weighting (`bool`, *optional*, defaults to `False`):
303
+ Skip the weighting. When the parsing is skipped, it is forced True.
304
+ """
305
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
306
+ if isinstance(prompt, str):
307
+ prompt = [prompt]
308
+
309
+ if not skip_parsing:
310
+ prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
311
+ if uncond_prompt is not None:
312
+ if isinstance(uncond_prompt, str):
313
+ uncond_prompt = [uncond_prompt]
314
+ uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
315
+ else:
316
+ prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
317
+ prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
318
+ if uncond_prompt is not None:
319
+ if isinstance(uncond_prompt, str):
320
+ uncond_prompt = [uncond_prompt]
321
+ uncond_tokens = [
322
+ token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
323
+ ]
324
+ uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
325
+
326
+ # round up the longest length of tokens to a multiple of (model_max_length - 2)
327
+ max_length = max([len(token) for token in prompt_tokens])
328
+ if uncond_prompt is not None:
329
+ max_length = max(max_length, max([len(token) for token in uncond_tokens]))
330
+
331
+ max_embeddings_multiples = min(
332
+ max_embeddings_multiples,
333
+ (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
334
+ )
335
+ max_embeddings_multiples = max(1, max_embeddings_multiples)
336
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
337
+
338
+ # pad the length of tokens and weights
339
+ bos = pipe.tokenizer.bos_token_id
340
+ eos = pipe.tokenizer.eos_token_id
341
+ pad = pipe.tokenizer.pad_token_id
342
+ prompt_tokens, prompt_weights = pad_tokens_and_weights(
343
+ prompt_tokens,
344
+ prompt_weights,
345
+ max_length,
346
+ bos,
347
+ eos,
348
+ no_boseos_middle=no_boseos_middle,
349
+ chunk_length=pipe.tokenizer.model_max_length,
350
+ )
351
+ prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
352
+ if uncond_prompt is not None:
353
+ uncond_tokens, uncond_weights = pad_tokens_and_weights(
354
+ uncond_tokens,
355
+ uncond_weights,
356
+ max_length,
357
+ bos,
358
+ eos,
359
+ no_boseos_middle=no_boseos_middle,
360
+ chunk_length=pipe.tokenizer.model_max_length,
361
+ )
362
+ uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
363
+
364
+ # get the embeddings
365
+ text_embeddings = get_unweighted_text_embeddings(
366
+ pipe,
367
+ prompt_tokens,
368
+ pipe.tokenizer.model_max_length,
369
+ clip_skip,
370
+ eos,
371
+ pad,
372
+ no_boseos_middle=no_boseos_middle,
373
+ )
374
+ prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
375
+ if uncond_prompt is not None:
376
+ uncond_embeddings = get_unweighted_text_embeddings(
377
+ pipe,
378
+ uncond_tokens,
379
+ pipe.tokenizer.model_max_length,
380
+ clip_skip,
381
+ eos,
382
+ pad,
383
+ no_boseos_middle=no_boseos_middle,
384
+ )
385
+ uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
386
+
387
+ # assign weights to the prompts and normalize in the sense of mean
388
+ # TODO: should we normalize by chunk or in a whole (current implementation)?
389
+ if (not skip_parsing) and (not skip_weighting):
390
+ previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
391
+ text_embeddings *= prompt_weights.unsqueeze(-1)
392
+ current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
393
+ text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
394
+ if uncond_prompt is not None:
395
+ previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
396
+ uncond_embeddings *= uncond_weights.unsqueeze(-1)
397
+ current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
398
+ uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
399
+
400
+ if uncond_prompt is not None:
401
+ return text_embeddings, uncond_embeddings
402
+ return text_embeddings, None
403
+
404
+
405
+ def preprocess_image(image):
406
+ w, h = image.size
407
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
408
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
409
+ image = np.array(image).astype(np.float32) / 255.0
410
+ image = image[None].transpose(0, 3, 1, 2)
411
+ image = torch.from_numpy(image)
412
+ return 2.0 * image - 1.0
413
+
414
+
415
+ def preprocess_mask(mask, scale_factor=8):
416
+ mask = mask.convert("L")
417
+ w, h = mask.size
418
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
419
+ mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
420
+ mask = np.array(mask).astype(np.float32) / 255.0
421
+ mask = np.tile(mask, (4, 1, 1))
422
+ mask = mask[None].transpose(0, 1, 2, 3) # identity transpose; effectively a no-op kept from the original implementation
423
+ mask = 1 - mask # repaint white, keep black
424
+ mask = torch.from_numpy(mask)
425
+ return mask
426
+
427
+
428
+ def prepare_controlnet_image(
429
+ image: PIL.Image.Image,
430
+ width: int,
431
+ height: int,
432
+ batch_size: int,
433
+ num_images_per_prompt: int,
434
+ device: torch.device,
435
+ dtype: torch.dtype,
436
+ do_classifier_free_guidance: bool = False,
437
+ guess_mode: bool = False,
438
+ ):
439
+ if not isinstance(image, torch.Tensor):
440
+ if isinstance(image, PIL.Image.Image):
441
+ image = [image]
442
+
443
+ if isinstance(image[0], PIL.Image.Image):
444
+ images = []
445
+
446
+ for image_ in image:
447
+ image_ = image_.convert("RGB")
448
+ image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
449
+ image_ = np.array(image_)
450
+ image_ = image_[None, :]
451
+ images.append(image_)
452
+
453
+ image = images
454
+
455
+ image = np.concatenate(image, axis=0)
456
+ image = np.array(image).astype(np.float32) / 255.0
457
+ image = image.transpose(0, 3, 1, 2)
458
+ image = torch.from_numpy(image)
459
+ elif isinstance(image[0], torch.Tensor):
460
+ image = torch.cat(image, dim=0)
461
+
462
+ image_batch_size = image.shape[0]
463
+
464
+ if image_batch_size == 1:
465
+ repeat_by = batch_size
466
+ else:
467
+ # image batch size is the same as prompt batch size
468
+ repeat_by = num_images_per_prompt
469
+
470
+ image = image.repeat_interleave(repeat_by, dim=0)
471
+
472
+ image = image.to(device=device, dtype=dtype)
473
+
474
+ if do_classifier_free_guidance and not guess_mode:
475
+ image = torch.cat([image] * 2)
476
+
477
+ return image
478
+
479
+
480
+ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
481
+ r"""
482
+ Pipeline for text-to-image generation using Stable Diffusion without a token length limit, with support for
483
+ parsing weighting in the prompt.
484
+
485
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
486
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
487
+
488
+ Args:
489
+ vae ([`AutoencoderKL`]):
490
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
491
+ text_encoder ([`CLIPTextModel`]):
492
+ Frozen text-encoder. Stable Diffusion uses the text portion of
493
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
494
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
495
+ tokenizer (`CLIPTokenizer`):
496
+ Tokenizer of class
497
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
498
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
499
+ scheduler ([`SchedulerMixin`]):
500
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
501
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
502
+ safety_checker ([`StableDiffusionSafetyChecker`]):
503
+ Classification module that estimates whether generated images could be considered offensive or harmful.
504
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
505
+ feature_extractor ([`CLIPFeatureExtractor`]):
506
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
507
+ """
508
+
509
+ # if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
510
+
511
+ def __init__(
512
+ self,
513
+ vae: AutoencoderKL,
514
+ text_encoder: CLIPTextModel,
515
+ tokenizer: CLIPTokenizer,
516
+ unet: UNet2DConditionModel,
517
+ scheduler: SchedulerMixin,
518
+ # clip_skip: int,
519
+ safety_checker: StableDiffusionSafetyChecker,
520
+ feature_extractor: CLIPFeatureExtractor,
521
+ requires_safety_checker: bool = True,
522
+ image_encoder: CLIPVisionModelWithProjection = None,
523
+ clip_skip: int = 1,
524
+ ):
525
+ super().__init__(
526
+ vae=vae,
527
+ text_encoder=text_encoder,
528
+ tokenizer=tokenizer,
529
+ unet=unet,
530
+ scheduler=scheduler,
531
+ safety_checker=safety_checker,
532
+ feature_extractor=feature_extractor,
533
+ requires_safety_checker=requires_safety_checker,
534
+ image_encoder=image_encoder,
535
+ )
536
+ self.custom_clip_skip = clip_skip
537
+ self.__init__additional__()
538
+
539
+ def __init__additional__(self):
540
+ if not hasattr(self, "vae_scale_factor"):
541
+ setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
542
+
543
+ @property
544
+ def _execution_device(self):
545
+ r"""
546
+ Returns the device on which the pipeline's models will be executed. After calling
547
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
548
+ hooks.
549
+ """
550
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
551
+ return self.device
552
+ for module in self.unet.modules():
553
+ if (
554
+ hasattr(module, "_hf_hook")
555
+ and hasattr(module._hf_hook, "execution_device")
556
+ and module._hf_hook.execution_device is not None
557
+ ):
558
+ return torch.device(module._hf_hook.execution_device)
559
+ return self.device
560
+
561
+ def _encode_prompt(
562
+ self,
563
+ prompt,
564
+ device,
565
+ num_images_per_prompt,
566
+ do_classifier_free_guidance,
567
+ negative_prompt,
568
+ max_embeddings_multiples,
569
+ ):
570
+ r"""
571
+ Encodes the prompt into text encoder hidden states.
572
+
573
+ Args:
574
+ prompt (`str` or `list(int)`):
575
+ prompt to be encoded
576
+ device: (`torch.device`):
577
+ torch device
578
+ num_images_per_prompt (`int`):
579
+ number of images that should be generated per prompt
580
+ do_classifier_free_guidance (`bool`):
581
+ whether to use classifier free guidance or not
582
+ negative_prompt (`str` or `List[str]`):
583
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
584
+ if `guidance_scale` is less than `1`).
585
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
586
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
587
+ """
588
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
589
+
590
+ if negative_prompt is None:
591
+ negative_prompt = [""] * batch_size
592
+ elif isinstance(negative_prompt, str):
593
+ negative_prompt = [negative_prompt] * batch_size
594
+ if batch_size != len(negative_prompt):
595
+ raise ValueError(
596
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
597
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
598
+ " the batch size of `prompt`."
599
+ )
600
+
601
+ text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
602
+ pipe=self,
603
+ prompt=prompt,
604
+ uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
605
+ max_embeddings_multiples=max_embeddings_multiples,
606
+ clip_skip=self.custom_clip_skip,
607
+ )
608
+ bs_embed, seq_len, _ = text_embeddings.shape
609
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
610
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
611
+
612
+ if do_classifier_free_guidance:
613
+ bs_embed, seq_len, _ = uncond_embeddings.shape
614
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
615
+ uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
616
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
617
+
618
+ return text_embeddings
619
+
620
+ def check_inputs(self, prompt, height, width, strength, callback_steps):
621
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
622
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
623
+
624
+ if strength < 0 or strength > 1:
625
+ raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
626
+
627
+ if height % 8 != 0 or width % 8 != 0:
628
+ logger.info(f'{height} {width}')
629
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
630
+
631
+ if (callback_steps is None) or (
632
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
633
+ ):
634
+ raise ValueError(
635
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}."
636
+ )
637
+
638
+ def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
639
+ if is_text2img:
640
+ return self.scheduler.timesteps.to(device), num_inference_steps
641
+ else:
642
+ # get the original timestep using init_timestep
643
+ offset = self.scheduler.config.get("steps_offset", 0)
644
+ init_timestep = int(num_inference_steps * strength) + offset
645
+ init_timestep = min(init_timestep, num_inference_steps)
646
+
647
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
648
+ timesteps = self.scheduler.timesteps[t_start:].to(device)
649
+ return timesteps, num_inference_steps - t_start
650
+
651
+ def run_safety_checker(self, image, device, dtype):
652
+ if self.safety_checker is not None:
653
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
654
+ image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype))
655
+ else:
656
+ has_nsfw_concept = None
657
+ return image, has_nsfw_concept
658
+
659
+ def decode_latents(self, latents):
660
+ latents = 1 / 0.18215 * latents
661
+ image = self.vae.decode(latents).sample
662
+ image = (image / 2 + 0.5).clamp(0, 1)
663
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
664
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
665
+ return image
666
+
667
+ def prepare_extra_step_kwargs(self, generator, eta):
668
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
669
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
670
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
671
+ # and should be between [0, 1]
672
+
673
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
674
+ extra_step_kwargs = {}
675
+ if accepts_eta:
676
+ extra_step_kwargs["eta"] = eta
677
+
678
+ # check if the scheduler accepts generator
679
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
680
+ if accepts_generator:
681
+ extra_step_kwargs["generator"] = generator
682
+ return extra_step_kwargs
683
+
684
+ def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None):
685
+ if image is None:
686
+ shape = (
687
+ batch_size,
688
+ self.unet.in_channels,
689
+ height // self.vae_scale_factor,
690
+ width // self.vae_scale_factor,
691
+ )
692
+
693
+ if latents is None:
694
+ if device.type == "mps":
695
+ # randn does not work reproducibly on mps
696
+ latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
697
+ else:
698
+ latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
699
+ else:
700
+ if latents.shape != shape:
701
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
702
+ latents = latents.to(device)
703
+
704
+ # scale the initial noise by the standard deviation required by the scheduler
705
+ latents = latents * self.scheduler.init_noise_sigma
706
+ return latents, None, None
707
+ else:
708
+ init_latent_dist = self.vae.encode(image).latent_dist
709
+ init_latents = init_latent_dist.sample(generator=generator)
710
+ init_latents = 0.18215 * init_latents
711
+ init_latents = torch.cat([init_latents] * batch_size, dim=0)
712
+ init_latents_orig = init_latents
713
+ shape = init_latents.shape
714
+
715
+ # add noise to latents using the timesteps
716
+ if device.type == "mps":
717
+ noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
718
+ else:
719
+ noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
720
+ latents = self.scheduler.add_noise(init_latents, noise, timestep)
721
+ return latents, init_latents_orig, noise
722
+
723
+ @torch.no_grad()
724
+ def __call__(
725
+ self,
726
+ prompt: Union[str, List[str]],
727
+ negative_prompt: Optional[Union[str, List[str]]] = None,
728
+ image: Union[torch.FloatTensor, PIL.Image.Image] = None,
729
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
730
+ height: int = 512,
731
+ width: int = 512,
732
+ num_inference_steps: int = 50,
733
+ guidance_scale: float = 7.5,
734
+ strength: float = 0.8,
735
+ num_images_per_prompt: Optional[int] = 1,
736
+ eta: float = 0.0,
737
+ generator: Optional[torch.Generator] = None,
738
+ latents: Optional[torch.FloatTensor] = None,
739
+ max_embeddings_multiples: Optional[int] = 3,
740
+ output_type: Optional[str] = "pil",
741
+ return_dict: bool = True,
742
+ controlnet=None,
743
+ controlnet_image=None,
744
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
745
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
746
+ callback_steps: int = 1,
747
+ ):
748
+ r"""
749
+ Function invoked when calling the pipeline for generation.
750
+
751
+ Args:
752
+ prompt (`str` or `List[str]`):
753
+ The prompt or prompts to guide the image generation.
754
+ negative_prompt (`str` or `List[str]`, *optional*):
755
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
756
+ if `guidance_scale` is less than `1`).
757
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
758
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
759
+ process.
760
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
761
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
762
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
763
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
764
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
765
+ height (`int`, *optional*, defaults to 512):
766
+ The height in pixels of the generated image.
767
+ width (`int`, *optional*, defaults to 512):
768
+ The width in pixels of the generated image.
769
+ num_inference_steps (`int`, *optional*, defaults to 50):
770
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
771
+ expense of slower inference.
772
+ guidance_scale (`float`, *optional*, defaults to 7.5):
773
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
774
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
775
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
776
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
777
+ usually at the expense of lower image quality.
778
+ strength (`float`, *optional*, defaults to 0.8):
779
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
780
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
781
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
782
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
783
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
784
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
785
+ The number of images to generate per prompt.
786
+ eta (`float`, *optional*, defaults to 0.0):
787
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
788
+ [`schedulers.DDIMScheduler`], will be ignored for others.
789
+ generator (`torch.Generator`, *optional*):
790
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
791
+ deterministic.
792
+ latents (`torch.FloatTensor`, *optional*):
793
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
794
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
795
+ tensor will be generated by sampling using the supplied random `generator`.
796
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
797
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
798
+ output_type (`str`, *optional*, defaults to `"pil"`):
799
+ The output format of the generated image. Choose between
800
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
801
+ return_dict (`bool`, *optional*, defaults to `True`):
802
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
803
+ plain tuple.
804
+ controlnet (`diffusers.ControlNetModel`, *optional*):
805
+ A controlnet model to be used for the inference. If not provided, controlnet will be disabled.
806
+ controlnet_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
807
+ `Image`, or tensor representing an image batch, to be used as the starting point for the controlnet
808
+ inference.
809
+ callback (`Callable`, *optional*):
810
+ A function that will be called every `callback_steps` steps during inference. The function will be
811
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
812
+ is_cancelled_callback (`Callable`, *optional*):
813
+ A function that will be called every `callback_steps` steps during inference. If the function returns
814
+ `True`, the inference will be cancelled.
815
+ callback_steps (`int`, *optional*, defaults to 1):
816
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
817
+ called at every step.
818
+
819
+ Returns:
820
+ `None` if cancelled by `is_cancelled_callback`,
821
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
822
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
823
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
824
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
825
+ (nsfw) content, according to the `safety_checker`.
826
+ """
827
+ if controlnet is not None and controlnet_image is None:
828
+ raise ValueError("controlnet_image must be provided if controlnet is not None.")
829
+
830
+ # 0. Default height and width to unet
831
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
832
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
833
+
834
+ # 1. Check inputs. Raise error if not correct
835
+ self.check_inputs(prompt, height, width, strength, callback_steps)
836
+
837
+ # 2. Define call parameters
838
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
839
+ device = self._execution_device
840
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
841
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
842
+ # corresponds to doing no classifier free guidance.
843
+ do_classifier_free_guidance = guidance_scale > 1.0
844
+
845
+ # 3. Encode input prompt
846
+ text_embeddings = self._encode_prompt(
847
+ prompt,
848
+ device,
849
+ num_images_per_prompt,
850
+ do_classifier_free_guidance,
851
+ negative_prompt,
852
+ max_embeddings_multiples,
853
+ )
854
+ dtype = text_embeddings.dtype
855
+
856
+ # 4. Preprocess image and mask
857
+ if isinstance(image, PIL.Image.Image):
858
+ image = preprocess_image(image)
859
+ if image is not None:
860
+ image = image.to(device=self.device, dtype=dtype)
861
+ if isinstance(mask_image, PIL.Image.Image):
862
+ mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
863
+ if mask_image is not None:
864
+ mask = mask_image.to(device=self.device, dtype=dtype)
865
+ mask = torch.cat([mask] * batch_size * num_images_per_prompt)
866
+ else:
867
+ mask = None
868
+
869
+ if controlnet_image is not None:
870
+ controlnet_image = prepare_controlnet_image(
871
+ controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False
872
+ )
873
+
874
+ # 5. set timesteps
875
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
876
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
877
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
878
+
879
+ # 6. Prepare latent variables
880
+ latents, init_latents_orig, noise = self.prepare_latents(
881
+ image,
882
+ latent_timestep,
883
+ batch_size * num_images_per_prompt,
884
+ height,
885
+ width,
886
+ dtype,
887
+ device,
888
+ generator,
889
+ latents,
890
+ )
891
+
892
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
893
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
894
+
895
+ # 8. Denoising loop
896
+ for i, t in enumerate(self.progress_bar(timesteps)):
897
+ # expand the latents if we are doing classifier free guidance
898
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
899
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
900
+
901
+ unet_additional_args = {}
902
+ if controlnet is not None:
903
+ down_block_res_samples, mid_block_res_sample = controlnet(
904
+ latent_model_input,
905
+ t,
906
+ encoder_hidden_states=text_embeddings,
907
+ controlnet_cond=controlnet_image,
908
+ conditioning_scale=1.0,
909
+ guess_mode=False,
910
+ return_dict=False,
911
+ )
912
+ unet_additional_args["down_block_additional_residuals"] = down_block_res_samples
913
+ unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample
914
+
915
+ # predict the noise residual
916
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, **unet_additional_args).sample
917
+
918
+ # perform guidance
919
+ if do_classifier_free_guidance:
920
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
921
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
922
+
923
+ # compute the previous noisy sample x_t -> x_t-1
924
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
925
+
926
+ if mask is not None:
927
+ # masking
928
+ init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
929
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
930
+
931
+ # call the callback, if provided
932
+ if i % callback_steps == 0:
933
+ if callback is not None:
934
+ callback(i, t, latents)
935
+ if is_cancelled_callback is not None and is_cancelled_callback():
936
+ return None
937
+
938
+ return latents
939
+
940
+ def latents_to_image(self, latents):
941
+ # 9. Post-processing
942
+ image = self.decode_latents(latents.to(self.vae.dtype))
943
+ image = self.numpy_to_pil(image)
944
+ return image
945
+
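Note that, as the loop above shows, `__call__` returns the raw latents (or `None` when cancelled via `is_cancelled_callback`), not decoded images; `latents_to_image` performs the VAE decode. A minimal usage sketch, assuming `pipe` is an instance of this pipeline that has already been loaded elsewhere (the prompt, seed, and file name are placeholders):

import torch

generator = torch.Generator(device="cuda").manual_seed(42)  # fixed seed for reproducibility
latents = pipe(
    prompt="a photo of an astronaut riding a horse",
    num_inference_steps=30,
    guidance_scale=7.5,
    generator=generator,
)
if latents is not None:  # None means is_cancelled_callback stopped the run
    images = pipe.latents_to_image(latents)  # decode latents with the VAE and convert to PIL
    images[0].save("out.png")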
946
+ def text2img(
947
+ self,
948
+ prompt: Union[str, List[str]],
949
+ negative_prompt: Optional[Union[str, List[str]]] = None,
950
+ height: int = 512,
951
+ width: int = 512,
952
+ num_inference_steps: int = 50,
953
+ guidance_scale: float = 7.5,
954
+ num_images_per_prompt: Optional[int] = 1,
955
+ eta: float = 0.0,
956
+ generator: Optional[torch.Generator] = None,
957
+ latents: Optional[torch.FloatTensor] = None,
958
+ max_embeddings_multiples: Optional[int] = 3,
959
+ output_type: Optional[str] = "pil",
960
+ return_dict: bool = True,
961
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
962
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
963
+ callback_steps: int = 1,
964
+ ):
965
+ r"""
966
+ Function for text-to-image generation.
967
+ Args:
968
+ prompt (`str` or `List[str]`):
969
+ The prompt or prompts to guide the image generation.
970
+ negative_prompt (`str` or `List[str]`, *optional*):
971
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
972
+ if `guidance_scale` is less than `1`).
973
+ height (`int`, *optional*, defaults to 512):
974
+ The height in pixels of the generated image.
975
+ width (`int`, *optional*, defaults to 512):
976
+ The width in pixels of the generated image.
977
+ num_inference_steps (`int`, *optional*, defaults to 50):
978
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
979
+ expense of slower inference.
980
+ guidance_scale (`float`, *optional*, defaults to 7.5):
981
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
982
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
983
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
984
+ 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
985
+ usually at the expense of lower image quality.
986
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
987
+ The number of images to generate per prompt.
988
+ eta (`float`, *optional*, defaults to 0.0):
989
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
990
+ [`schedulers.DDIMScheduler`], will be ignored for others.
991
+ generator (`torch.Generator`, *optional*):
992
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
993
+ deterministic.
994
+ latents (`torch.FloatTensor`, *optional*):
995
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
996
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
997
+ tensor will be generated by sampling using the supplied random `generator`.
998
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
999
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
1000
+ output_type (`str`, *optional*, defaults to `"pil"`):
1001
+ The output format of the generated image. Choose between
1002
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1003
+ return_dict (`bool`, *optional*, defaults to `True`):
1004
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1005
+ plain tuple.
1006
+ callback (`Callable`, *optional*):
1007
+ A function that will be called every `callback_steps` steps during inference. The function will be
1008
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1009
+ is_cancelled_callback (`Callable`, *optional*):
1010
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1011
+ `True`, the inference will be cancelled.
1012
+ callback_steps (`int`, *optional*, defaults to 1):
1013
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1014
+ called at every step.
1015
+ Returns:
1016
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1017
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1018
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1019
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1020
+ (nsfw) content, according to the `safety_checker`.
1021
+ """
1022
+ return self.__call__(
1023
+ prompt=prompt,
1024
+ negative_prompt=negative_prompt,
1025
+ height=height,
1026
+ width=width,
1027
+ num_inference_steps=num_inference_steps,
1028
+ guidance_scale=guidance_scale,
1029
+ num_images_per_prompt=num_images_per_prompt,
1030
+ eta=eta,
1031
+ generator=generator,
1032
+ latents=latents,
1033
+ max_embeddings_multiples=max_embeddings_multiples,
1034
+ output_type=output_type,
1035
+ return_dict=return_dict,
1036
+ callback=callback,
1037
+ is_cancelled_callback=is_cancelled_callback,
1038
+ callback_steps=callback_steps,
1039
+ )
1040
+
1041
+ def img2img(
1042
+ self,
1043
+ image: Union[torch.FloatTensor, PIL.Image.Image],
1044
+ prompt: Union[str, List[str]],
1045
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1046
+ strength: float = 0.8,
1047
+ num_inference_steps: Optional[int] = 50,
1048
+ guidance_scale: Optional[float] = 7.5,
1049
+ num_images_per_prompt: Optional[int] = 1,
1050
+ eta: Optional[float] = 0.0,
1051
+ generator: Optional[torch.Generator] = None,
1052
+ max_embeddings_multiples: Optional[int] = 3,
1053
+ output_type: Optional[str] = "pil",
1054
+ return_dict: bool = True,
1055
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1056
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
1057
+ callback_steps: int = 1,
1058
+ ):
1059
+ r"""
1060
+ Function for image-to-image generation.
1061
+ Args:
1062
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
1063
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
1064
+ process.
1065
+ prompt (`str` or `List[str]`):
1066
+ The prompt or prompts to guide the image generation.
1067
+ negative_prompt (`str` or `List[str]`, *optional*):
1068
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1069
+ if `guidance_scale` is less than `1`).
1070
+ strength (`float`, *optional*, defaults to 0.8):
1071
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
1072
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
1073
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
1074
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
1075
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
1076
+ num_inference_steps (`int`, *optional*, defaults to 50):
1077
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1078
+ expense of slower inference. This parameter will be modulated by `strength`.
1079
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1080
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1081
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1082
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1083
+ 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
1084
+ usually at the expense of lower image quality.
1085
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1086
+ The number of images to generate per prompt.
1087
+ eta (`float`, *optional*, defaults to 0.0):
1088
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1089
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1090
+ generator (`torch.Generator`, *optional*):
1091
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1092
+ deterministic.
1093
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1094
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
1095
+ output_type (`str`, *optional*, defaults to `"pil"`):
1096
+ The output format of the generated image. Choose between
1097
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1098
+ return_dict (`bool`, *optional*, defaults to `True`):
1099
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1100
+ plain tuple.
1101
+ callback (`Callable`, *optional*):
1102
+ A function that will be called every `callback_steps` steps during inference. The function will be
1103
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1104
+ is_cancelled_callback (`Callable`, *optional*):
1105
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1106
+ `True`, the inference will be cancelled.
1107
+ callback_steps (`int`, *optional*, defaults to 1):
1108
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1109
+ called at every step.
1110
+ Returns:
1111
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1112
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1113
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1114
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1115
+ (nsfw) content, according to the `safety_checker`.
1116
+ """
1117
+ return self.__call__(
1118
+ prompt=prompt,
1119
+ negative_prompt=negative_prompt,
1120
+ image=image,
1121
+ num_inference_steps=num_inference_steps,
1122
+ guidance_scale=guidance_scale,
1123
+ strength=strength,
1124
+ num_images_per_prompt=num_images_per_prompt,
1125
+ eta=eta,
1126
+ generator=generator,
1127
+ max_embeddings_multiples=max_embeddings_multiples,
1128
+ output_type=output_type,
1129
+ return_dict=return_dict,
1130
+ callback=callback,
1131
+ is_cancelled_callback=is_cancelled_callback,
1132
+ callback_steps=callback_steps,
1133
+ )
1134
+
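The `strength` parameter is documented above as modulating `num_inference_steps`; the exact scheduling lives in `get_timesteps`, which is not part of this hunk. As a rough sketch of the usual diffusers convention (an assumption here, not a quote of this class's implementation):

num_inference_steps = 50
strength = 0.8
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 40
t_start = max(num_inference_steps - init_timestep, 0)                          # 10
# i.e. roughly strength * num_inference_steps of the scheduled steps are actually run,
# starting part-way through the noise schedule rather than from pure noise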
1135
+ def inpaint(
1136
+ self,
1137
+ image: Union[torch.FloatTensor, PIL.Image.Image],
1138
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image],
1139
+ prompt: Union[str, List[str]],
1140
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1141
+ strength: float = 0.8,
1142
+ num_inference_steps: Optional[int] = 50,
1143
+ guidance_scale: Optional[float] = 7.5,
1144
+ num_images_per_prompt: Optional[int] = 1,
1145
+ eta: Optional[float] = 0.0,
1146
+ generator: Optional[torch.Generator] = None,
1147
+ max_embeddings_multiples: Optional[int] = 3,
1148
+ output_type: Optional[str] = "pil",
1149
+ return_dict: bool = True,
1150
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1151
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
1152
+ callback_steps: int = 1,
1153
+ ):
1154
+ r"""
1155
+ Function for inpainting.
1156
+ Args:
1157
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
1158
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
1159
+ process. This is the image whose masked region will be inpainted.
1160
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
1161
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
1162
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
1163
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
1164
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
1165
+ prompt (`str` or `List[str]`):
1166
+ The prompt or prompts to guide the image generation.
1167
+ negative_prompt (`str` or `List[str]`, *optional*):
1168
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1169
+ if `guidance_scale` is less than `1`).
1170
+ strength (`float`, *optional*, defaults to 0.8):
1171
+ Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
1172
+ is 1, the denoising process will be run on the masked area for the full number of iterations specified
1173
+ in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
1174
+ noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
1175
+ num_inference_steps (`int`, *optional*, defaults to 50):
1176
+ The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
1177
+ the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
1178
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1179
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1180
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1181
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1182
+ 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
1183
+ usually at the expense of lower image quality.
1184
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1185
+ The number of images to generate per prompt.
1186
+ eta (`float`, *optional*, defaults to 0.0):
1187
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1188
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1189
+ generator (`torch.Generator`, *optional*):
1190
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1191
+ deterministic.
1192
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1193
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
1194
+ output_type (`str`, *optional*, defaults to `"pil"`):
1195
+ The output format of the generated image. Choose between
1196
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1197
+ return_dict (`bool`, *optional*, defaults to `True`):
1198
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1199
+ plain tuple.
1200
+ callback (`Callable`, *optional*):
1201
+ A function that will be called every `callback_steps` steps during inference. The function will be
1202
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1203
+ is_cancelled_callback (`Callable`, *optional*):
1204
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1205
+ `True`, the inference will be cancelled.
1206
+ callback_steps (`int`, *optional*, defaults to 1):
1207
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1208
+ called at every step.
1209
+ Returns:
1210
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1211
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1212
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1213
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1214
+ (nsfw) content, according to the `safety_checker`.
1215
+ """
1216
+ return self.__call__(
1217
+ prompt=prompt,
1218
+ negative_prompt=negative_prompt,
1219
+ image=image,
1220
+ mask_image=mask_image,
1221
+ num_inference_steps=num_inference_steps,
1222
+ guidance_scale=guidance_scale,
1223
+ strength=strength,
1224
+ num_images_per_prompt=num_images_per_prompt,
1225
+ eta=eta,
1226
+ generator=generator,
1227
+ max_embeddings_multiples=max_embeddings_multiples,
1228
+ output_type=output_type,
1229
+ return_dict=return_dict,
1230
+ callback=callback,
1231
+ is_cancelled_callback=is_cancelled_callback,
1232
+ callback_steps=callback_steps,
1233
+ )
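An illustrative inpaint call, assuming the same hypothetical `pipe` instance as in the sketch above; the image paths and prompt are placeholders:

from PIL import Image

init_image = Image.open("photo.png").convert("RGB").resize((512, 512))
mask = Image.open("mask.png").convert("L").resize((512, 512))  # white = repaint, black = keep
latents = pipe.inpaint(
    image=init_image,
    mask_image=mask,
    prompt="a wooden bench in a park",
    strength=0.75,
    num_inference_steps=50,
)
if latents is not None:
    pipe.latents_to_image(latents)[0].save("inpainted.png")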
library/model_util.py ADDED
@@ -0,0 +1,1356 @@
1
+ # v1: split from train_db_fixed.py.
2
+ # v2: support safetensors
3
+
4
+ import math
5
+ import os
6
+
7
+ import torch
8
+ from library.device_utils import init_ipex
9
+ init_ipex()
10
+
11
+ import diffusers
12
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
13
+ from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel
14
+ from safetensors.torch import load_file, save_file
15
+ from library.original_unet import UNet2DConditionModel
16
+ from library.utils import setup_logging
17
+ setup_logging()
18
+ import logging
19
+ logger = logging.getLogger(__name__)
20
+
21
+ # model parameters for the Diffusers version of Stable Diffusion
22
+ NUM_TRAIN_TIMESTEPS = 1000
23
+ BETA_START = 0.00085
24
+ BETA_END = 0.0120
25
+
26
+ UNET_PARAMS_MODEL_CHANNELS = 320
27
+ UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
28
+ UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
29
+ UNET_PARAMS_IMAGE_SIZE = 64 # fixed from old invalid value `32`
30
+ UNET_PARAMS_IN_CHANNELS = 4
31
+ UNET_PARAMS_OUT_CHANNELS = 4
32
+ UNET_PARAMS_NUM_RES_BLOCKS = 2
33
+ UNET_PARAMS_CONTEXT_DIM = 768
34
+ UNET_PARAMS_NUM_HEADS = 8
35
+ # UNET_PARAMS_USE_LINEAR_PROJECTION = False
36
+
37
+ VAE_PARAMS_Z_CHANNELS = 4
38
+ VAE_PARAMS_RESOLUTION = 256
39
+ VAE_PARAMS_IN_CHANNELS = 3
40
+ VAE_PARAMS_OUT_CH = 3
41
+ VAE_PARAMS_CH = 128
42
+ VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
43
+ VAE_PARAMS_NUM_RES_BLOCKS = 2
44
+
45
+ # V2
46
+ V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
47
+ V2_UNET_PARAMS_CONTEXT_DIM = 1024
48
+ # V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True
49
+
50
+ # reference models used to load the Diffusers configs
51
+ DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
52
+ DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
53
+
54
+
55
+ # region StableDiffusion -> Diffusers conversion code
56
+ # copied from convert_original_stable_diffusion_to_diffusers and modified (ASL 2.0)
57
+
58
+
59
+ def shave_segments(path, n_shave_prefix_segments=1):
60
+ """
61
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
62
+ """
63
+ if n_shave_prefix_segments >= 0:
64
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
65
+ else:
66
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
67
+
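For reference, a couple of worked examples of `shave_segments` (the keys are illustrative):

# shave_segments("input_blocks.3.0.op.weight", 2)   -> "0.op.weight"         (drops the first two segments)
# shave_segments("input_blocks.3.0.op.weight", -1)  -> "input_blocks.3.0.op"  (drops the last segment)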
68
+
69
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
70
+ """
71
+ Updates paths inside resnets to the new naming scheme (local renaming)
72
+ """
73
+ mapping = []
74
+ for old_item in old_list:
75
+ new_item = old_item.replace("in_layers.0", "norm1")
76
+ new_item = new_item.replace("in_layers.2", "conv1")
77
+
78
+ new_item = new_item.replace("out_layers.0", "norm2")
79
+ new_item = new_item.replace("out_layers.3", "conv2")
80
+
81
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
82
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
83
+
84
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
85
+
86
+ mapping.append({"old": old_item, "new": new_item})
87
+
88
+ return mapping
89
+
90
+
91
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
92
+ """
93
+ Updates paths inside resnets to the new naming scheme (local renaming)
94
+ """
95
+ mapping = []
96
+ for old_item in old_list:
97
+ new_item = old_item
98
+
99
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
100
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
101
+
102
+ mapping.append({"old": old_item, "new": new_item})
103
+
104
+ return mapping
105
+
106
+
107
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
108
+ """
109
+ Updates paths inside attentions to the new naming scheme (local renaming)
110
+ """
111
+ mapping = []
112
+ for old_item in old_list:
113
+ new_item = old_item
114
+
115
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
116
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
117
+
118
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
119
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
120
+
121
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
122
+
123
+ mapping.append({"old": old_item, "new": new_item})
124
+
125
+ return mapping
126
+
127
+
128
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
129
+ """
130
+ Updates paths inside attentions to the new naming scheme (local renaming)
131
+ """
132
+ mapping = []
133
+ for old_item in old_list:
134
+ new_item = old_item
135
+
136
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
137
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
138
+
139
+ if diffusers.__version__ < "0.17.0":
140
+ new_item = new_item.replace("q.weight", "query.weight")
141
+ new_item = new_item.replace("q.bias", "query.bias")
142
+
143
+ new_item = new_item.replace("k.weight", "key.weight")
144
+ new_item = new_item.replace("k.bias", "key.bias")
145
+
146
+ new_item = new_item.replace("v.weight", "value.weight")
147
+ new_item = new_item.replace("v.bias", "value.bias")
148
+
149
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
150
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
151
+ else:
152
+ new_item = new_item.replace("q.weight", "to_q.weight")
153
+ new_item = new_item.replace("q.bias", "to_q.bias")
154
+
155
+ new_item = new_item.replace("k.weight", "to_k.weight")
156
+ new_item = new_item.replace("k.bias", "to_k.bias")
157
+
158
+ new_item = new_item.replace("v.weight", "to_v.weight")
159
+ new_item = new_item.replace("v.bias", "to_v.bias")
160
+
161
+ new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
162
+ new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
163
+
164
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
165
+
166
+ mapping.append({"old": old_item, "new": new_item})
167
+
168
+ return mapping
169
+
170
+
171
+ def assign_to_checkpoint(
172
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
173
+ ):
174
+ """
175
+ This does the final conversion step: take locally converted weights and apply a global renaming
176
+ to them. It splits attention layers, and takes into account additional replacements
177
+ that may arise.
178
+
179
+ Assigns the weights to the new checkpoint.
180
+ """
181
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
182
+
183
+ # Splits the attention layers into three variables.
184
+ if attention_paths_to_split is not None:
185
+ for path, path_map in attention_paths_to_split.items():
186
+ old_tensor = old_checkpoint[path]
187
+ channels = old_tensor.shape[0] // 3
188
+
189
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
190
+
191
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
192
+
193
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
194
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
195
+
196
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
197
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
198
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
199
+
200
+ for path in paths:
201
+ new_path = path["new"]
202
+
203
+ # These have already been assigned
204
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
205
+ continue
206
+
207
+ # Global renaming happens here
208
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
209
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
210
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
211
+
212
+ if additional_replacements is not None:
213
+ for replacement in additional_replacements:
214
+ new_path = new_path.replace(replacement["old"], replacement["new"])
215
+
216
+ # proj_attn.weight has to be converted from conv 1D to linear
217
+ reshaping = False
218
+ if diffusers.__version__ < "0.17.0":
219
+ if "proj_attn.weight" in new_path:
220
+ reshaping = True
221
+ else:
222
+ if ".attentions." in new_path and ".0.to_" in new_path and old_checkpoint[path["old"]].ndim > 2:
223
+ reshaping = True
224
+
225
+ if reshaping:
226
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
227
+ else:
228
+ checkpoint[new_path] = old_checkpoint[path["old"]]
229
+
230
+
231
+ def conv_attn_to_linear(checkpoint):
232
+ keys = list(checkpoint.keys())
233
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
234
+ for key in keys:
235
+ if ".".join(key.split(".")[-2:]) in attn_keys:
236
+ if checkpoint[key].ndim > 2:
237
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
238
+ elif "proj_attn.weight" in key:
239
+ if checkpoint[key].ndim > 2:
240
+ checkpoint[key] = checkpoint[key][:, :, 0]
241
+
242
+
243
+ def linear_transformer_to_conv(checkpoint):
244
+ keys = list(checkpoint.keys())
245
+ tf_keys = ["proj_in.weight", "proj_out.weight"]
246
+ for key in keys:
247
+ if ".".join(key.split(".")[-2:]) in tf_keys:
248
+ if checkpoint[key].ndim == 2:
249
+ checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
250
+
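For reference, the `unsqueeze` calls above turn a linear weight into an equivalent 1x1 convolution weight; the 320-channel shape below is just an example:

import torch

w = torch.zeros(320, 320)        # linear proj_in.weight / proj_out.weight
w = w.unsqueeze(2).unsqueeze(2)  # -> shape (320, 320, 1, 1), i.e. a 1x1 conv kernel
assert w.shape == (320, 320, 1, 1)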
251
+
252
+ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
253
+ """
254
+ Takes a state dict and a config, and returns a converted checkpoint.
255
+ """
256
+
257
+ # extract state_dict for UNet
258
+ unet_state_dict = {}
259
+ unet_key = "model.diffusion_model."
260
+ keys = list(checkpoint.keys())
261
+ for key in keys:
262
+ if key.startswith(unet_key):
263
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
264
+
265
+ new_checkpoint = {}
266
+
267
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
268
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
269
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
270
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
271
+
272
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
273
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
274
+
275
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
276
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
277
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
278
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
279
+
280
+ # Retrieves the keys for the input blocks only
281
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
282
+ input_blocks = {
283
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in range(num_input_blocks)
284
+ }
285
+
286
+ # Retrieves the keys for the middle blocks only
287
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
288
+ middle_blocks = {
289
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in range(num_middle_blocks)
290
+ }
291
+
292
+ # Retrieves the keys for the output blocks only
293
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
294
+ output_blocks = {
295
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] for layer_id in range(num_output_blocks)
296
+ }
297
+
298
+ for i in range(1, num_input_blocks):
299
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
300
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
301
+
302
+ resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key]
303
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
304
+
305
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
306
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
307
+ f"input_blocks.{i}.0.op.weight"
308
+ )
309
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")
310
+
311
+ paths = renew_resnet_paths(resnets)
312
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
313
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
314
+
315
+ if len(attentions):
316
+ paths = renew_attention_paths(attentions)
317
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
318
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
319
+
320
+ resnet_0 = middle_blocks[0]
321
+ attentions = middle_blocks[1]
322
+ resnet_1 = middle_blocks[2]
323
+
324
+ resnet_0_paths = renew_resnet_paths(resnet_0)
325
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
326
+
327
+ resnet_1_paths = renew_resnet_paths(resnet_1)
328
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
329
+
330
+ attentions_paths = renew_attention_paths(attentions)
331
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
332
+ assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
333
+
334
+ for i in range(num_output_blocks):
335
+ block_id = i // (config["layers_per_block"] + 1)
336
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
337
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
338
+ output_block_list = {}
339
+
340
+ for layer in output_block_layers:
341
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
342
+ if layer_id in output_block_list:
343
+ output_block_list[layer_id].append(layer_name)
344
+ else:
345
+ output_block_list[layer_id] = [layer_name]
346
+
347
+ if len(output_block_list) > 1:
348
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
349
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
350
+
351
+ resnet_0_paths = renew_resnet_paths(resnets)
352
+ paths = renew_resnet_paths(resnets)
353
+
354
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
355
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
356
+
357
+ # original:
358
+ # if ["conv.weight", "conv.bias"] in output_block_list.values():
359
+ # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
360
+
361
+ # avoid depending on the order of bias and weight; there is probably a better way
362
+ for l in output_block_list.values():
363
+ l.sort()
364
+
365
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
366
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
367
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
368
+ f"output_blocks.{i}.{index}.conv.bias"
369
+ ]
370
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
371
+ f"output_blocks.{i}.{index}.conv.weight"
372
+ ]
373
+
374
+ # Clear attentions as they have been attributed above.
375
+ if len(attentions) == 2:
376
+ attentions = []
377
+
378
+ if len(attentions):
379
+ paths = renew_attention_paths(attentions)
380
+ meta_path = {
381
+ "old": f"output_blocks.{i}.1",
382
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
383
+ }
384
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
385
+ else:
386
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
387
+ for path in resnet_0_paths:
388
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
389
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
390
+
391
+ new_checkpoint[new_path] = unet_state_dict[old_path]
392
+
393
+ # in SD v2, the 1x1 conv2d has been replaced with a linear layer
394
+ # the Diffusers side was mistakenly left as conv2d, so conversion is required
395
+ if v2 and not config.get("use_linear_projection", False):
396
+ linear_transformer_to_conv(new_checkpoint)
397
+
398
+ return new_checkpoint
399
+
400
+
401
+ def convert_ldm_vae_checkpoint(checkpoint, config):
402
+ # extract state dict for VAE
403
+ vae_state_dict = {}
404
+ vae_key = "first_stage_model."
405
+ keys = list(checkpoint.keys())
406
+ for key in keys:
407
+ if key.startswith(vae_key):
408
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
409
+ # if len(vae_state_dict) == 0:
410
+ # # the passed checkpoint is a VAE state_dict, not a checkpoint loaded from a .ckpt
411
+ # vae_state_dict = checkpoint
412
+
413
+ new_checkpoint = {}
414
+
415
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
416
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
417
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
418
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
419
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
420
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
421
+
422
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
423
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
424
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
425
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
426
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
427
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
428
+
429
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
430
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
431
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
432
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
433
+
434
+ # Retrieves the keys for the encoder down blocks only
435
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
436
+ down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)}
437
+
438
+ # Retrieves the keys for the decoder up blocks only
439
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
440
+ up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
441
+
442
+ for i in range(num_down_blocks):
443
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
444
+
445
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
446
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
447
+ f"encoder.down.{i}.downsample.conv.weight"
448
+ )
449
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
450
+ f"encoder.down.{i}.downsample.conv.bias"
451
+ )
452
+
453
+ paths = renew_vae_resnet_paths(resnets)
454
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
455
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
456
+
457
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
458
+ num_mid_res_blocks = 2
459
+ for i in range(1, num_mid_res_blocks + 1):
460
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
461
+
462
+ paths = renew_vae_resnet_paths(resnets)
463
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
464
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
465
+
466
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
467
+ paths = renew_vae_attention_paths(mid_attentions)
468
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
469
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
470
+ conv_attn_to_linear(new_checkpoint)
471
+
472
+ for i in range(num_up_blocks):
473
+ block_id = num_up_blocks - 1 - i
474
+ resnets = [key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key]
475
+
476
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
477
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
478
+ f"decoder.up.{block_id}.upsample.conv.weight"
479
+ ]
480
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
481
+ f"decoder.up.{block_id}.upsample.conv.bias"
482
+ ]
483
+
484
+ paths = renew_vae_resnet_paths(resnets)
485
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
486
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
487
+
488
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
489
+ num_mid_res_blocks = 2
490
+ for i in range(1, num_mid_res_blocks + 1):
491
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
492
+
493
+ paths = renew_vae_resnet_paths(resnets)
494
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
495
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
496
+
497
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
498
+ paths = renew_vae_attention_paths(mid_attentions)
499
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
500
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
501
+ conv_attn_to_linear(new_checkpoint)
502
+ return new_checkpoint
503
+
504
+
505
+ def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False):
506
+ """
507
+ Creates a config for the diffusers based on the config of the LDM model.
508
+ """
509
+ # unet_params = original_config.model.params.unet_config.params
510
+
511
+ block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
512
+
513
+ down_block_types = []
514
+ resolution = 1
515
+ for i in range(len(block_out_channels)):
516
+ block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
517
+ down_block_types.append(block_type)
518
+ if i != len(block_out_channels) - 1:
519
+ resolution *= 2
520
+
521
+ up_block_types = []
522
+ for i in range(len(block_out_channels)):
523
+ block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
524
+ up_block_types.append(block_type)
525
+ resolution //= 2
526
+
527
+ config = dict(
528
+ sample_size=UNET_PARAMS_IMAGE_SIZE,
529
+ in_channels=UNET_PARAMS_IN_CHANNELS,
530
+ out_channels=UNET_PARAMS_OUT_CHANNELS,
531
+ down_block_types=tuple(down_block_types),
532
+ up_block_types=tuple(up_block_types),
533
+ block_out_channels=tuple(block_out_channels),
534
+ layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
535
+ cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
536
+ attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
537
+ # use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION,
538
+ )
539
+ if v2 and use_linear_projection_in_v2:
540
+ config["use_linear_projection"] = True
541
+
542
+ return config
543
+
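Traced by hand from the constants above, `create_unet_diffusers_config(v2=False)` resolves to the standard SD 1.x layout (only the derived fields are shown here):

# down_block_types    = ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")
# up_block_types      = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
# block_out_channels  = (320, 640, 1280, 1280)
# cross_attention_dim = 768, attention_head_dim = 8, layers_per_block = 2, sample_size = 64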
544
+
545
+ def create_vae_diffusers_config():
546
+ """
547
+ Creates a config for the diffusers based on the config of the LDM model.
548
+ """
549
+ # vae_params = original_config.model.params.first_stage_config.params.ddconfig
550
+ # _ = original_config.model.params.first_stage_config.params.embed_dim
551
+ block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
552
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
553
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
554
+
555
+ config = dict(
556
+ sample_size=VAE_PARAMS_RESOLUTION,
557
+ in_channels=VAE_PARAMS_IN_CHANNELS,
558
+ out_channels=VAE_PARAMS_OUT_CH,
559
+ down_block_types=tuple(down_block_types),
560
+ up_block_types=tuple(up_block_types),
561
+ block_out_channels=tuple(block_out_channels),
562
+ latent_channels=VAE_PARAMS_Z_CHANNELS,
563
+ layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
564
+ )
565
+ return config
566
+
567
+
568
+ def convert_ldm_clip_checkpoint_v1(checkpoint):
569
+ keys = list(checkpoint.keys())
570
+ text_model_dict = {}
571
+ for key in keys:
572
+ if key.startswith("cond_stage_model.transformer"):
573
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
574
+
575
+ # remove position_ids for newer transformer, which causes error :(
576
+ if "text_model.embeddings.position_ids" in text_model_dict:
577
+ text_model_dict.pop("text_model.embeddings.position_ids")
578
+
579
+ return text_model_dict
580
+
581
+
582
+ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
583
+ # the key layout is annoyingly different!
584
+ def convert_key(key):
585
+ if not key.startswith("cond_stage_model"):
586
+ return None
587
+
588
+ # common conversion
589
+ key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
590
+ key = key.replace("cond_stage_model.model.", "text_model.")
591
+
592
+ if "resblocks" in key:
593
+ # resblocks conversion
594
+ key = key.replace(".resblocks.", ".layers.")
595
+ if ".ln_" in key:
596
+ key = key.replace(".ln_", ".layer_norm")
597
+ elif ".mlp." in key:
598
+ key = key.replace(".c_fc.", ".fc1.")
599
+ key = key.replace(".c_proj.", ".fc2.")
600
+ elif ".attn.out_proj" in key:
601
+ key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
602
+ elif ".attn.in_proj" in key:
603
+ key = None # special case, handled later
604
+ else:
605
+ raise ValueError(f"unexpected key in SD: {key}")
606
+ elif ".positional_embedding" in key:
607
+ key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
608
+ elif ".text_projection" in key:
609
+ key = None # not used???
610
+ elif ".logit_scale" in key:
611
+ key = None # not used???
612
+ elif ".token_embedding" in key:
613
+ key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
614
+ elif ".ln_final" in key:
615
+ key = key.replace(".ln_final", ".final_layer_norm")
616
+ return key
617
+
618
+ keys = list(checkpoint.keys())
619
+ new_sd = {}
620
+ for key in keys:
621
+ # remove resblocks 23
622
+ if ".resblocks.23." in key:
623
+ continue
624
+ new_key = convert_key(key)
625
+ if new_key is None:
626
+ continue
627
+ new_sd[new_key] = checkpoint[key]
628
+
629
+ # convert the attn weights
630
+ for key in keys:
631
+ if ".resblocks.23." in key:
632
+ continue
633
+ if ".resblocks" in key and ".attn.in_proj_" in key:
634
+ # split into three parts (q, k, v)
635
+ values = torch.chunk(checkpoint[key], 3)
636
+
637
+ key_suffix = ".weight" if "weight" in key else ".bias"
638
+ key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
639
+ key_pfx = key_pfx.replace("_weight", "")
640
+ key_pfx = key_pfx.replace("_bias", "")
641
+ key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
642
+ new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
643
+ new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
644
+ new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
645
+
646
+ # rename or add position_ids
647
+ ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
648
+ if ANOTHER_POSITION_IDS_KEY in new_sd:
649
+ # waifu diffusion v1.4
650
+ position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
651
+ del new_sd[ANOTHER_POSITION_IDS_KEY]
652
+ else:
653
+ position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
654
+
655
+ new_sd["text_model.embeddings.position_ids"] = position_ids
656
+ return new_sd
657
+
658
+
659
+ # endregion
660
+
661
+
662
+ # region Diffusers -> StableDiffusion conversion code
663
+ # copied from convert_diffusers_to_original_stable_diffusion and modified (ASL 2.0)
664
+
665
+
666
+ def conv_transformer_to_linear(checkpoint):
667
+ keys = list(checkpoint.keys())
668
+ tf_keys = ["proj_in.weight", "proj_out.weight"]
669
+ for key in keys:
670
+ if ".".join(key.split(".")[-2:]) in tf_keys:
671
+ if checkpoint[key].ndim > 2:
672
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
673
+
674
+
675
+ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
676
+ unet_conversion_map = [
677
+ # (stable-diffusion, HF Diffusers)
678
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
679
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
680
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
681
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
682
+ ("input_blocks.0.0.weight", "conv_in.weight"),
683
+ ("input_blocks.0.0.bias", "conv_in.bias"),
684
+ ("out.0.weight", "conv_norm_out.weight"),
685
+ ("out.0.bias", "conv_norm_out.bias"),
686
+ ("out.2.weight", "conv_out.weight"),
687
+ ("out.2.bias", "conv_out.bias"),
688
+ ]
689
+
690
+ unet_conversion_map_resnet = [
691
+ # (stable-diffusion, HF Diffusers)
692
+ ("in_layers.0", "norm1"),
693
+ ("in_layers.2", "conv1"),
694
+ ("out_layers.0", "norm2"),
695
+ ("out_layers.3", "conv2"),
696
+ ("emb_layers.1", "time_emb_proj"),
697
+ ("skip_connection", "conv_shortcut"),
698
+ ]
699
+
700
+ unet_conversion_map_layer = []
701
+ for i in range(4):
702
+ # loop over downblocks/upblocks
703
+
704
+ for j in range(2):
705
+ # loop over resnets/attentions for downblocks
706
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
707
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
708
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
709
+
710
+ if i < 3:
711
+ # no attention layers in down_blocks.3
712
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
713
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
714
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
715
+
716
+ for j in range(3):
717
+ # loop over resnets/attentions for upblocks
718
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
719
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
720
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
721
+
722
+ if i > 0:
723
+ # no attention layers in up_blocks.0
724
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
725
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
726
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
727
+
728
+ if i < 3:
729
+ # no downsample in down_blocks.3
730
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
731
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
732
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
733
+
734
+ # no upsample in up_blocks.3
735
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
736
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
737
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
738
+
739
+ hf_mid_atn_prefix = "mid_block.attentions.0."
740
+ sd_mid_atn_prefix = "middle_block.1."
741
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
742
+
743
+ for j in range(2):
744
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
745
+ sd_mid_res_prefix = f"middle_block.{2*j}."
746
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
747
+
748
+ # buyer beware: this is a *brittle* function,
749
+ # and correct output requires that all of these pieces interact in
750
+ # the exact order in which I have arranged them.
751
+ mapping = {k: k for k in unet_state_dict.keys()}
752
+ for sd_name, hf_name in unet_conversion_map:
753
+ mapping[hf_name] = sd_name
754
+ for k, v in mapping.items():
755
+ if "resnets" in k:
756
+ for sd_part, hf_part in unet_conversion_map_resnet:
757
+ v = v.replace(hf_part, sd_part)
758
+ mapping[k] = v
759
+ for k, v in mapping.items():
760
+ for sd_part, hf_part in unet_conversion_map_layer:
761
+ v = v.replace(hf_part, sd_part)
762
+ mapping[k] = v
763
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
764
+
765
+ if v2:
766
+ conv_transformer_to_linear(new_state_dict)
767
+
768
+ return new_state_dict
769
+
770
+
771
+ def controlnet_conversion_map():
772
+ unet_conversion_map = [
773
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
774
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
775
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
776
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
777
+ ("input_blocks.0.0.weight", "conv_in.weight"),
778
+ ("input_blocks.0.0.bias", "conv_in.bias"),
779
+ ("middle_block_out.0.weight", "controlnet_mid_block.weight"),
780
+ ("middle_block_out.0.bias", "controlnet_mid_block.bias"),
781
+ ]
782
+
783
+ unet_conversion_map_resnet = [
784
+ ("in_layers.0", "norm1"),
785
+ ("in_layers.2", "conv1"),
786
+ ("out_layers.0", "norm2"),
787
+ ("out_layers.3", "conv2"),
788
+ ("emb_layers.1", "time_emb_proj"),
789
+ ("skip_connection", "conv_shortcut"),
790
+ ]
791
+
792
+ unet_conversion_map_layer = []
793
+ for i in range(4):
794
+ for j in range(2):
795
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
796
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
797
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
798
+
799
+ if i < 3:
800
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
801
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
802
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
803
+
804
+ if i < 3:
805
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
806
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
807
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
808
+
809
+ hf_mid_atn_prefix = "mid_block.attentions.0."
810
+ sd_mid_atn_prefix = "middle_block.1."
811
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
812
+
813
+ for j in range(2):
814
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
815
+ sd_mid_res_prefix = f"middle_block.{2*j}."
816
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
817
+
818
+ controlnet_cond_embedding_names = ["conv_in"] + [f"blocks.{i}" for i in range(6)] + ["conv_out"]
819
+ for i, hf_prefix in enumerate(controlnet_cond_embedding_names):
820
+ hf_prefix = f"controlnet_cond_embedding.{hf_prefix}."
821
+ sd_prefix = f"input_hint_block.{i*2}."
822
+ unet_conversion_map_layer.append((sd_prefix, hf_prefix))
823
+
824
+ for i in range(12):
825
+ hf_prefix = f"controlnet_down_blocks.{i}."
826
+ sd_prefix = f"zero_convs.{i}.0."
827
+ unet_conversion_map_layer.append((sd_prefix, hf_prefix))
828
+
829
+ return unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer
830
+
831
+
832
+ def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
833
+ unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
834
+
835
+ mapping = {k: k for k in controlnet_state_dict.keys()}
836
+ for sd_name, diffusers_name in unet_conversion_map:
837
+ mapping[diffusers_name] = sd_name
838
+ for k, v in mapping.items():
839
+ if "resnets" in k:
840
+ for sd_part, diffusers_part in unet_conversion_map_resnet:
841
+ v = v.replace(diffusers_part, sd_part)
842
+ mapping[k] = v
843
+ for k, v in mapping.items():
844
+ for sd_part, diffusers_part in unet_conversion_map_layer:
845
+ v = v.replace(diffusers_part, sd_part)
846
+ mapping[k] = v
847
+ new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
848
+ return new_state_dict
849
+
850
+
851
+ def convert_controlnet_state_dict_to_diffusers(controlnet_state_dict):
852
+ unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
853
+
854
+ mapping = {k: k for k in controlnet_state_dict.keys()}
855
+ for sd_name, diffusers_name in unet_conversion_map:
856
+ mapping[sd_name] = diffusers_name
857
+ for k, v in mapping.items():
858
+ for sd_part, diffusers_part in unet_conversion_map_layer:
859
+ v = v.replace(sd_part, diffusers_part)
860
+ mapping[k] = v
861
+ for k, v in mapping.items():
862
+ if "resnets" in v:
863
+ for sd_part, diffusers_part in unet_conversion_map_resnet:
864
+ v = v.replace(sd_part, diffusers_part)
865
+ mapping[k] = v
866
+ new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
867
+ return new_state_dict
868
+
869
+
870
+ # ================#
871
+ # VAE Conversion #
872
+ # ================#
873
+
874
+
875
+ def reshape_weight_for_sd(w):
876
+ # convert HF linear weights to SD conv2d weights
877
+ return w.reshape(*w.shape, 1, 1)
878
+
879
+
880
+ def convert_vae_state_dict(vae_state_dict):
881
+ vae_conversion_map = [
882
+ # (stable-diffusion, HF Diffusers)
883
+ ("nin_shortcut", "conv_shortcut"),
884
+ ("norm_out", "conv_norm_out"),
885
+ ("mid.attn_1.", "mid_block.attentions.0."),
886
+ ]
887
+
888
+ for i in range(4):
889
+ # down_blocks have two resnets
890
+ for j in range(2):
891
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
892
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
893
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
894
+
895
+ if i < 3:
896
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
897
+ sd_downsample_prefix = f"down.{i}.downsample."
898
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
899
+
900
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
901
+ sd_upsample_prefix = f"up.{3-i}.upsample."
902
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
903
+
904
+ # up_blocks have three resnets
905
+ # also, up blocks in hf are numbered in reverse from sd
906
+ for j in range(3):
907
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
908
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
909
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
910
+
911
+ # this part accounts for mid blocks in both the encoder and the decoder
912
+ for i in range(2):
913
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
914
+ sd_mid_res_prefix = f"mid.block_{i+1}."
915
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
916
+
917
+ if diffusers.__version__ < "0.17.0":
918
+ vae_conversion_map_attn = [
919
+ # (stable-diffusion, HF Diffusers)
920
+ ("norm.", "group_norm."),
921
+ ("q.", "query."),
922
+ ("k.", "key."),
923
+ ("v.", "value."),
924
+ ("proj_out.", "proj_attn."),
925
+ ]
926
+ else:
927
+ vae_conversion_map_attn = [
928
+ # (stable-diffusion, HF Diffusers)
929
+ ("norm.", "group_norm."),
930
+ ("q.", "to_q."),
931
+ ("k.", "to_k."),
932
+ ("v.", "to_v."),
933
+ ("proj_out.", "to_out.0."),
934
+ ]
935
+
936
+ mapping = {k: k for k in vae_state_dict.keys()}
937
+ for k, v in mapping.items():
938
+ for sd_part, hf_part in vae_conversion_map:
939
+ v = v.replace(hf_part, sd_part)
940
+ mapping[k] = v
941
+ for k, v in mapping.items():
942
+ if "attentions" in k:
943
+ for sd_part, hf_part in vae_conversion_map_attn:
944
+ v = v.replace(hf_part, sd_part)
945
+ mapping[k] = v
946
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
947
+ weights_to_convert = ["q", "k", "v", "proj_out"]
948
+ for k, v in new_state_dict.items():
949
+ for weight_name in weights_to_convert:
950
+ if f"mid.attn_1.{weight_name}.weight" in k:
951
+ # logger.info(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
952
+ new_state_dict[k] = reshape_weight_for_sd(v)
953
+
954
+ return new_state_dict
955
+
956
+
957
+ # endregion
958
+
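For orientation, here is a minimal usage sketch of the VAE conversion above. It assumes the functions in this file are in scope; the Diffusers repo id is only an illustrative example, not part of this commit.

from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")  # any Diffusers AutoencoderKL works
sd_vae = convert_vae_state_dict(vae.state_dict())
# the mid-block attention projections are reshaped to 1x1 convs to match the SD layout
print(sd_vae["encoder.mid.attn_1.q.weight"].shape)  # e.g. torch.Size([512, 512, 1, 1])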
959
+ # region custom model loading/saving etc.
960
+
961
+
962
+ def is_safetensors(path):
963
+ return os.path.splitext(path)[1].lower() == ".safetensors"
964
+
965
+
966
+ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
967
+ # handle models that store the text encoder in a different format (missing 'text_model')
968
+ TEXT_ENCODER_KEY_REPLACEMENTS = [
969
+ ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."),
970
+ ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."),
971
+ ("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."),
972
+ ]
973
+
974
+ if is_safetensors(ckpt_path):
975
+ checkpoint = None
976
+ state_dict = load_file(ckpt_path) # , device) # may cause an error
977
+ else:
978
+ checkpoint = torch.load(ckpt_path, map_location=device)
979
+ if "state_dict" in checkpoint:
980
+ state_dict = checkpoint["state_dict"]
981
+ else:
982
+ state_dict = checkpoint
983
+ checkpoint = None
984
+
985
+ key_reps = []
986
+ for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
987
+ for key in state_dict.keys():
988
+ if key.startswith(rep_from):
989
+ new_key = rep_to + key[len(rep_from) :]
990
+ key_reps.append((key, new_key))
991
+
992
+ for key, new_key in key_reps:
993
+ state_dict[new_key] = state_dict[key]
994
+ del state_dict[key]
995
+
996
+ return checkpoint, state_dict
997
+
998
+
999
+ # TODO the behavior of the dtype argument is questionable and needs checking; whether the text_encoder can be created with the specified dtype is unverified
1000
+ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=True):
1001
+ _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
1002
+
1003
+ # Convert the UNet2DConditionModel model.
1004
+ unet_config = create_unet_diffusers_config(v2, unet_use_linear_projection_in_v2)
1005
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
1006
+
1007
+ unet = UNet2DConditionModel(**unet_config).to(device)
1008
+ info = unet.load_state_dict(converted_unet_checkpoint)
1009
+ logger.info(f"loading u-net: {info}")
1010
+
1011
+ # Convert the VAE model.
1012
+ vae_config = create_vae_diffusers_config()
1013
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
1014
+
1015
+ vae = AutoencoderKL(**vae_config).to(device)
1016
+ info = vae.load_state_dict(converted_vae_checkpoint)
1017
+ logger.info(f"loading vae: {info}")
1018
+
1019
+ # convert text_model
1020
+ if v2:
1021
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
1022
+ cfg = CLIPTextConfig(
1023
+ vocab_size=49408,
1024
+ hidden_size=1024,
1025
+ intermediate_size=4096,
1026
+ num_hidden_layers=23,
1027
+ num_attention_heads=16,
1028
+ max_position_embeddings=77,
1029
+ hidden_act="gelu",
1030
+ layer_norm_eps=1e-05,
1031
+ dropout=0.0,
1032
+ attention_dropout=0.0,
1033
+ initializer_range=0.02,
1034
+ initializer_factor=1.0,
1035
+ pad_token_id=1,
1036
+ bos_token_id=0,
1037
+ eos_token_id=2,
1038
+ model_type="clip_text_model",
1039
+ projection_dim=512,
1040
+ torch_dtype="float32",
1041
+ transformers_version="4.25.0.dev0",
1042
+ )
1043
+ text_model = CLIPTextModel._from_config(cfg)
1044
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
1045
+ else:
1046
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
1047
+
1048
+ # logging.set_verbosity_error() # don't show annoying warning
1049
+ # text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
1050
+ # logging.set_verbosity_warning()
1051
+ # logger.info(f"config: {text_model.config}")
1052
+ cfg = CLIPTextConfig(
1053
+ vocab_size=49408,
1054
+ hidden_size=768,
1055
+ intermediate_size=3072,
1056
+ num_hidden_layers=12,
1057
+ num_attention_heads=12,
1058
+ max_position_embeddings=77,
1059
+ hidden_act="quick_gelu",
1060
+ layer_norm_eps=1e-05,
1061
+ dropout=0.0,
1062
+ attention_dropout=0.0,
1063
+ initializer_range=0.02,
1064
+ initializer_factor=1.0,
1065
+ pad_token_id=1,
1066
+ bos_token_id=0,
1067
+ eos_token_id=2,
1068
+ model_type="clip_text_model",
1069
+ projection_dim=768,
1070
+ torch_dtype="float32",
1071
+ )
1072
+ text_model = CLIPTextModel._from_config(cfg)
1073
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
1074
+ logger.info(f"loading text encoder: {info}")
1075
+
1076
+ return text_model, vae, unet
1077
+
1078
+
1079
+ def get_model_version_str_for_sd1_sd2(v2, v_parameterization):
1080
+ # only for reference
1081
+ version_str = "sd"
1082
+ if v2:
1083
+ version_str += "_v2"
1084
+ else:
1085
+ version_str += "_v1"
1086
+ if v_parameterization:
1087
+ version_str += "_v"
1088
+ return version_str
1089
+
1090
+
1091
+ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
1092
+ def convert_key(key):
1093
+ # remove position_ids
1094
+ if ".position_ids" in key:
1095
+ return None
1096
+
1097
+ # common
1098
+ key = key.replace("text_model.encoder.", "transformer.")
1099
+ key = key.replace("text_model.", "")
1100
+ if "layers" in key:
1101
+ # resblocks conversion
1102
+ key = key.replace(".layers.", ".resblocks.")
1103
+ if ".layer_norm" in key:
1104
+ key = key.replace(".layer_norm", ".ln_")
1105
+ elif ".mlp." in key:
1106
+ key = key.replace(".fc1.", ".c_fc.")
1107
+ key = key.replace(".fc2.", ".c_proj.")
1108
+ elif ".self_attn.out_proj" in key:
1109
+ key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
1110
+ elif ".self_attn." in key:
1111
+ key = None # special case, handled separately below
1112
+ else:
1113
+ raise ValueError(f"unexpected key in DiffUsers model: {key}")
1114
+ elif ".position_embedding" in key:
1115
+ key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
1116
+ elif ".token_embedding" in key:
1117
+ key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
1118
+ elif "final_layer_norm" in key:
1119
+ key = key.replace("final_layer_norm", "ln_final")
1120
+ return key
1121
+
1122
+ keys = list(checkpoint.keys())
1123
+ new_sd = {}
1124
+ for key in keys:
1125
+ new_key = convert_key(key)
1126
+ if new_key is None:
1127
+ continue
1128
+ new_sd[new_key] = checkpoint[key]
1129
+
1130
+ # convert the attn (attention) weights
1131
+ for key in keys:
1132
+ if "layers" in key and "q_proj" in key:
1133
+ # concatenate the three (q, k, v)
1134
+ key_q = key
1135
+ key_k = key.replace("q_proj", "k_proj")
1136
+ key_v = key.replace("q_proj", "v_proj")
1137
+
1138
+ value_q = checkpoint[key_q]
1139
+ value_k = checkpoint[key_k]
1140
+ value_v = checkpoint[key_v]
1141
+ value = torch.cat([value_q, value_k, value_v])
1142
+
1143
+ new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
1144
+ new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
1145
+ new_sd[new_key] = value
1146
+
1147
+ # whether to fabricate the last layer and other missing weights
1148
+ if make_dummy_weights:
1149
+ logger.info("make dummy weights for resblock.23, text_projection and logit scale.")
1150
+ keys = list(new_sd.keys())
1151
+ for key in keys:
1152
+ if key.startswith("transformer.resblocks.22."):
1153
+ new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # must clone, otherwise saving with safetensors fails
1154
+
1155
+ # create weights that are not included in the Diffusers model
1156
+ new_sd["text_projection"] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
1157
+ new_sd["logit_scale"] = torch.tensor(1)
1158
+
1159
+ return new_sd
1160
+
1161
+
1162
+ def save_stable_diffusion_checkpoint(
1163
+ v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, metadata, save_dtype=None, vae=None
1164
+ ):
1165
+ if ckpt_path is not None:
1166
+ # read epoch/step from the checkpoint; also reload it (including the VAE) for cases such as when the VAE is not in memory
1167
+ checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
1168
+ if checkpoint is None: # safetensors file, or a ckpt containing only a state_dict
1169
+ checkpoint = {}
1170
+ strict = False
1171
+ else:
1172
+ strict = True
1173
+ if "state_dict" in state_dict:
1174
+ del state_dict["state_dict"]
1175
+ else:
1176
+ # create from scratch
1177
+ assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
1178
+ checkpoint = {}
1179
+ state_dict = {}
1180
+ strict = False
1181
+
1182
+ def update_sd(prefix, sd):
1183
+ for k, v in sd.items():
1184
+ key = prefix + k
1185
+ assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
1186
+ if save_dtype is not None:
1187
+ v = v.detach().clone().to("cpu").to(save_dtype)
1188
+ state_dict[key] = v
1189
+
1190
+ # Convert the UNet model
1191
+ unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
1192
+ update_sd("model.diffusion_model.", unet_state_dict)
1193
+
1194
+ # Convert the text encoder model
1195
+ if v2:
1196
+ make_dummy = ckpt_path is None # if there is no reference checkpoint, insert dummy weights, e.g. fabricate the last layer by duplicating the previous one
1197
+ text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
1198
+ update_sd("cond_stage_model.model.", text_enc_dict)
1199
+ else:
1200
+ text_enc_dict = text_encoder.state_dict()
1201
+ update_sd("cond_stage_model.transformer.", text_enc_dict)
1202
+
1203
+ # Convert the VAE
1204
+ if vae is not None:
1205
+ vae_dict = convert_vae_state_dict(vae.state_dict())
1206
+ update_sd("first_stage_model.", vae_dict)
1207
+
1208
+ # Put together new checkpoint
1209
+ key_count = len(state_dict.keys())
1210
+ new_ckpt = {"state_dict": state_dict}
1211
+
1212
+ # epoch and global_step are sometimes not int
1213
+ try:
1214
+ if "epoch" in checkpoint:
1215
+ epochs += checkpoint["epoch"]
1216
+ if "global_step" in checkpoint:
1217
+ steps += checkpoint["global_step"]
1218
+ except:
1219
+ pass
1220
+
1221
+ new_ckpt["epoch"] = epochs
1222
+ new_ckpt["global_step"] = steps
1223
+
1224
+ if is_safetensors(output_file):
1225
+ # TODO should non-Tensor values be removed from the dict?
1226
+ save_file(state_dict, output_file, metadata)
1227
+ else:
1228
+ torch.save(new_ckpt, output_file)
1229
+
1230
+ return key_count
1231
+
1232
+
1233
+ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
1234
+ if pretrained_model_name_or_path is None:
1235
+ # load default settings for v1/v2
1236
+ if v2:
1237
+ pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
1238
+ else:
1239
+ pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
1240
+
1241
+ scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
1242
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
1243
+ if vae is None:
1244
+ vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
1245
+
1246
+ # original U-Net cannot be saved, so we need to convert it to the Diffusers version
1247
+ # TODO this consumes a lot of memory
1248
+ diffusers_unet = diffusers.UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")
1249
+ diffusers_unet.load_state_dict(unet.state_dict())
1250
+
1251
+ pipeline = StableDiffusionPipeline(
1252
+ unet=diffusers_unet,
1253
+ text_encoder=text_encoder,
1254
+ vae=vae,
1255
+ scheduler=scheduler,
1256
+ tokenizer=tokenizer,
1257
+ safety_checker=None,
1258
+ feature_extractor=None,
1259
+ requires_safety_checker=None,
1260
+ )
1261
+ pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
1262
+
1263
+
1264
+ VAE_PREFIX = "first_stage_model."
1265
+
1266
+
1267
+ def load_vae(vae_id, dtype):
1268
+ logger.info(f"load VAE: {vae_id}")
1269
+ if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
1270
+ # Diffusers local/remote
1271
+ try:
1272
+ vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
1273
+ except EnvironmentError as e:
1274
+ logger.error(f"exception occurs in loading vae: {e}")
1275
+ logger.error("retry with subfolder='vae'")
1276
+ vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
1277
+ return vae
1278
+
1279
+ # local
1280
+ vae_config = create_vae_diffusers_config()
1281
+
1282
+ if vae_id.endswith(".bin"):
1283
+ # SD 1.5 VAE on Huggingface
1284
+ converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
1285
+ else:
1286
+ # StableDiffusion
1287
+ vae_model = load_file(vae_id, "cpu") if is_safetensors(vae_id) else torch.load(vae_id, map_location="cpu")
1288
+ vae_sd = vae_model["state_dict"] if "state_dict" in vae_model else vae_model
1289
+
1290
+ # vae only or full model
1291
+ full_model = False
1292
+ for vae_key in vae_sd:
1293
+ if vae_key.startswith(VAE_PREFIX):
1294
+ full_model = True
1295
+ break
1296
+ if not full_model:
1297
+ sd = {}
1298
+ for key, value in vae_sd.items():
1299
+ sd[VAE_PREFIX + key] = value
1300
+ vae_sd = sd
1301
+ del sd
1302
+
1303
+ # Convert the VAE model.
1304
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
1305
+
1306
+ vae = AutoencoderKL(**vae_config)
1307
+ vae.load_state_dict(converted_vae_checkpoint)
1308
+ return vae
1309
+
1310
+
1311
+ # endregion
1312
+
1313
+
1314
+ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
1315
+ max_width, max_height = max_reso
1316
+ max_area = max_width * max_height
1317
+
1318
+ resos = set()
1319
+
1320
+ width = int(math.sqrt(max_area) // divisible) * divisible
1321
+ resos.add((width, width))
1322
+
1323
+ width = min_size
1324
+ while width <= max_size:
1325
+ height = min(max_size, int((max_area // width) // divisible) * divisible)
1326
+ if height >= min_size:
1327
+ resos.add((width, height))
1328
+ resos.add((height, width))
1329
+
1330
+ # # make additional resos
1331
+ # if width >= height and width - divisible >= min_size:
1332
+ # resos.add((width - divisible, height))
1333
+ # resos.add((height, width - divisible))
1334
+ # if height >= width and height - divisible >= min_size:
1335
+ # resos.add((width, height - divisible))
1336
+ # resos.add((height - divisible, width))
1337
+
1338
+ width += divisible
1339
+
1340
+ resos = list(resos)
1341
+ resos.sort()
1342
+ return resos
1343
+
1344
+
1345
+ if __name__ == "__main__":
1346
+ resos = make_bucket_resolutions((512, 768))
1347
+ logger.info(f"{len(resos)}")
1348
+ logger.info(f"{resos}")
1349
+ aspect_ratios = [w / h for w, h in resos]
1350
+ logger.info(f"{aspect_ratios}")
1351
+
1352
+ ars = set()
1353
+ for ar in aspect_ratios:
1354
+ if ar in ars:
1355
+ logger.error(f"error! duplicate ar: {ar}")
1356
+ ars.add(ar)
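The conversion helpers above are typically used together. The following is a minimal sketch of loading an SD 1.x checkpoint into Diffusers-style modules and writing it back out with these functions; the file paths and metadata are placeholders, not part of this commit.

import torch

text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(
    v2=False, ckpt_path="input_model.safetensors", device="cpu"
)

key_count = save_stable_diffusion_checkpoint(
    v2=False,
    output_file="output_model.safetensors",
    text_encoder=text_encoder,
    unet=unet,
    ckpt_path="input_model.safetensors",  # reference checkpoint for epoch/step and the VAE fallback
    epochs=0,
    steps=0,
    metadata={"title": "converted"},
    save_dtype=torch.float16,
    vae=vae,
)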
library/original_unet.py ADDED
@@ -0,0 +1,1919 @@
1
+ # Diffusers 0.10.2からStable Diffusionに必要な部分だけを持ってくる
2
+ # 条件分岐等で不要な部分は削除している
3
+ # コードの多くはDiffusersからコピーしている
4
+ # 制約として、モデルのstate_dictがDiffusers 0.10.2のものと同じ形式である必要がある
5
+
6
+ # Copy from Diffusers 0.10.2 for Stable Diffusion. Most of the code is copied from Diffusers.
7
+ # Unnecessary parts are deleted by condition branching.
8
+ # As a constraint, the state_dict of the model must be in the same format as that of Diffusers 0.10.2
9
+
10
+ """
11
+ The differences between v1.5 and v2.1 are:
12
+ - attention_head_dim is an int (1.5) or a list[int] (2.1)
13
+ - cross_attention_dim is 768 (1.5) or 1024 (2.1)
14
+ - use_linear_projection: the key is absent (= False, 1.5) or present and true (2.1)
15
+ - upcast_attention is False (1.5) or True (2.1)
16
+ - (the following can probably be ignored)
17
+ - sample_size is 64 (1.5) or 96 (2.1)
18
+ - dual_cross_attention is present or absent
19
+ - num_class_embeds is present or absent
20
+ - only_cross_attention is present or absent
21
+
22
+ v1.5
23
+ {
24
+ "_class_name": "UNet2DConditionModel",
25
+ "_diffusers_version": "0.6.0",
26
+ "act_fn": "silu",
27
+ "attention_head_dim": 8,
28
+ "block_out_channels": [
29
+ 320,
30
+ 640,
31
+ 1280,
32
+ 1280
33
+ ],
34
+ "center_input_sample": false,
35
+ "cross_attention_dim": 768,
36
+ "down_block_types": [
37
+ "CrossAttnDownBlock2D",
38
+ "CrossAttnDownBlock2D",
39
+ "CrossAttnDownBlock2D",
40
+ "DownBlock2D"
41
+ ],
42
+ "downsample_padding": 1,
43
+ "flip_sin_to_cos": true,
44
+ "freq_shift": 0,
45
+ "in_channels": 4,
46
+ "layers_per_block": 2,
47
+ "mid_block_scale_factor": 1,
48
+ "norm_eps": 1e-05,
49
+ "norm_num_groups": 32,
50
+ "out_channels": 4,
51
+ "sample_size": 64,
52
+ "up_block_types": [
53
+ "UpBlock2D",
54
+ "CrossAttnUpBlock2D",
55
+ "CrossAttnUpBlock2D",
56
+ "CrossAttnUpBlock2D"
57
+ ]
58
+ }
59
+
60
+ v2.1
61
+ {
62
+ "_class_name": "UNet2DConditionModel",
63
+ "_diffusers_version": "0.10.0.dev0",
64
+ "act_fn": "silu",
65
+ "attention_head_dim": [
66
+ 5,
67
+ 10,
68
+ 20,
69
+ 20
70
+ ],
71
+ "block_out_channels": [
72
+ 320,
73
+ 640,
74
+ 1280,
75
+ 1280
76
+ ],
77
+ "center_input_sample": false,
78
+ "cross_attention_dim": 1024,
79
+ "down_block_types": [
80
+ "CrossAttnDownBlock2D",
81
+ "CrossAttnDownBlock2D",
82
+ "CrossAttnDownBlock2D",
83
+ "DownBlock2D"
84
+ ],
85
+ "downsample_padding": 1,
86
+ "dual_cross_attention": false,
87
+ "flip_sin_to_cos": true,
88
+ "freq_shift": 0,
89
+ "in_channels": 4,
90
+ "layers_per_block": 2,
91
+ "mid_block_scale_factor": 1,
92
+ "norm_eps": 1e-05,
93
+ "norm_num_groups": 32,
94
+ "num_class_embeds": null,
95
+ "only_cross_attention": false,
96
+ "out_channels": 4,
97
+ "sample_size": 96,
98
+ "up_block_types": [
99
+ "UpBlock2D",
100
+ "CrossAttnUpBlock2D",
101
+ "CrossAttnUpBlock2D",
102
+ "CrossAttnUpBlock2D"
103
+ ],
104
+ "use_linear_projection": true,
105
+ "upcast_attention": true
106
+ }
107
+ """
108
+
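To make the above concrete, here is a hedged sketch of how those config differences surface as constructor arguments for the blocks defined later in this file. The specific instantiation is illustrative only (it assumes the classes below are in scope); it is not taken from this commit.

v2 = True  # SD 2.1-style settings
cross_attention_dim = 1024 if v2 else 768
use_linear_projection = v2   # the v1.5 config has no use_linear_projection key (i.e. False)
upcast_attention = v2        # False for v1.5, True for v2.1
attn_num_head_channels = 5 if v2 else 8  # first entry of attention_head_dim (2.1) vs. the scalar value (1.5)

block = CrossAttnDownBlock2D(
    in_channels=320,
    out_channels=320,
    add_downsample=True,
    cross_attention_dim=cross_attention_dim,
    attn_num_head_channels=attn_num_head_channels,
    use_linear_projection=use_linear_projection,
    upcast_attention=upcast_attention,
)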
109
+ import math
110
+ from types import SimpleNamespace
111
+ from typing import Dict, Optional, Tuple, Union
112
+ import torch
113
+ from torch import nn
114
+ from torch.nn import functional as F
115
+ from einops import rearrange
116
+ from library.utils import setup_logging
117
+ setup_logging()
118
+ import logging
119
+ logger = logging.getLogger(__name__)
120
+
121
+ BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280)
122
+ TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0]
123
+ TIME_EMBED_DIM = BLOCK_OUT_CHANNELS[0] * 4
124
+ IN_CHANNELS: int = 4
125
+ OUT_CHANNELS: int = 4
126
+ LAYERS_PER_BLOCK: int = 2
127
+ LAYERS_PER_BLOCK_UP: int = LAYERS_PER_BLOCK + 1
128
+ TIME_EMBED_FLIP_SIN_TO_COS: bool = True
129
+ TIME_EMBED_FREQ_SHIFT: int = 0
130
+ NORM_GROUPS: int = 32
131
+ NORM_EPS: float = 1e-5
132
+ TRANSFORMER_NORM_NUM_GROUPS = 32
133
+
134
+ DOWN_BLOCK_TYPES = ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"]
135
+ UP_BLOCK_TYPES = ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"]
136
+
137
+
138
+ # region memory efficient attention
139
+
140
+ # CrossAttention using FlashAttention
141
+ # based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
142
+ # LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
143
+
144
+ # constants
145
+
146
+ EPSILON = 1e-6
147
+
148
+ # helper functions
149
+
150
+
151
+ def exists(val):
152
+ return val is not None
153
+
154
+
155
+ def default(val, d):
156
+ return val if exists(val) else d
157
+
158
+
159
+ # flash attention forwards and backwards
160
+
161
+ # https://arxiv.org/abs/2205.14135
162
+
163
+
164
+ class FlashAttentionFunction(torch.autograd.Function):
165
+ @staticmethod
166
+ @torch.no_grad()
167
+ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
168
+ """Algorithm 2 in the paper"""
169
+
170
+ device = q.device
171
+ dtype = q.dtype
172
+ max_neg_value = -torch.finfo(q.dtype).max
173
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
174
+
175
+ o = torch.zeros_like(q)
176
+ all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
177
+ all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
178
+
179
+ scale = q.shape[-1] ** -0.5
180
+
181
+ if not exists(mask):
182
+ mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
183
+ else:
184
+ mask = rearrange(mask, "b n -> b 1 1 n")
185
+ mask = mask.split(q_bucket_size, dim=-1)
186
+
187
+ row_splits = zip(
188
+ q.split(q_bucket_size, dim=-2),
189
+ o.split(q_bucket_size, dim=-2),
190
+ mask,
191
+ all_row_sums.split(q_bucket_size, dim=-2),
192
+ all_row_maxes.split(q_bucket_size, dim=-2),
193
+ )
194
+
195
+ for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
196
+ q_start_index = ind * q_bucket_size - qk_len_diff
197
+
198
+ col_splits = zip(
199
+ k.split(k_bucket_size, dim=-2),
200
+ v.split(k_bucket_size, dim=-2),
201
+ )
202
+
203
+ for k_ind, (kc, vc) in enumerate(col_splits):
204
+ k_start_index = k_ind * k_bucket_size
205
+
206
+ attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
207
+
208
+ if exists(row_mask):
209
+ attn_weights.masked_fill_(~row_mask, max_neg_value)
210
+
211
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
212
+ causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
213
+ q_start_index - k_start_index + 1
214
+ )
215
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
216
+
217
+ block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
218
+ attn_weights -= block_row_maxes
219
+ exp_weights = torch.exp(attn_weights)
220
+
221
+ if exists(row_mask):
222
+ exp_weights.masked_fill_(~row_mask, 0.0)
223
+
224
+ block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
225
+
226
+ new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
227
+
228
+ exp_values = torch.einsum("... i j, ... j d -> ... i d", exp_weights, vc)
229
+
230
+ exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
231
+ exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
232
+
233
+ new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
234
+
235
+ oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
236
+
237
+ row_maxes.copy_(new_row_maxes)
238
+ row_sums.copy_(new_row_sums)
239
+
240
+ ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
241
+ ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
242
+
243
+ return o
244
+
245
+ @staticmethod
246
+ @torch.no_grad()
247
+ def backward(ctx, do):
248
+ """Algorithm 4 in the paper"""
249
+
250
+ causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
251
+ q, k, v, o, l, m = ctx.saved_tensors
252
+
253
+ device = q.device
254
+
255
+ max_neg_value = -torch.finfo(q.dtype).max
256
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
257
+
258
+ dq = torch.zeros_like(q)
259
+ dk = torch.zeros_like(k)
260
+ dv = torch.zeros_like(v)
261
+
262
+ row_splits = zip(
263
+ q.split(q_bucket_size, dim=-2),
264
+ o.split(q_bucket_size, dim=-2),
265
+ do.split(q_bucket_size, dim=-2),
266
+ mask,
267
+ l.split(q_bucket_size, dim=-2),
268
+ m.split(q_bucket_size, dim=-2),
269
+ dq.split(q_bucket_size, dim=-2),
270
+ )
271
+
272
+ for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
273
+ q_start_index = ind * q_bucket_size - qk_len_diff
274
+
275
+ col_splits = zip(
276
+ k.split(k_bucket_size, dim=-2),
277
+ v.split(k_bucket_size, dim=-2),
278
+ dk.split(k_bucket_size, dim=-2),
279
+ dv.split(k_bucket_size, dim=-2),
280
+ )
281
+
282
+ for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
283
+ k_start_index = k_ind * k_bucket_size
284
+
285
+ attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
286
+
287
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
288
+ causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
289
+ q_start_index - k_start_index + 1
290
+ )
291
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
292
+
293
+ exp_attn_weights = torch.exp(attn_weights - mc)
294
+
295
+ if exists(row_mask):
296
+ exp_attn_weights.masked_fill_(~row_mask, 0.0)
297
+
298
+ p = exp_attn_weights / lc
299
+
300
+ dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
301
+ dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)
302
+
303
+ D = (doc * oc).sum(dim=-1, keepdims=True)
304
+ ds = p * scale * (dp - D)
305
+
306
+ dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
307
+ dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)
308
+
309
+ dqc.add_(dq_chunk)
310
+ dkc.add_(dk_chunk)
311
+ dvc.add_(dv_chunk)
312
+
313
+ return dq, dk, dv, None, None, None, None
314
+
315
+
316
+ # endregion
317
+
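A minimal sketch of how FlashAttentionFunction is invoked; it mirrors the call made in forward_memory_efficient_mem_eff further down. The shapes and bucket sizes here are illustrative, not prescribed by this file.

import torch

b, h, n, d = 2, 8, 4096, 40           # batch, heads, sequence length, head dim (illustrative)
q = torch.randn(b, h, n, d)
k = torch.randn(b, h, n, d)
v = torch.randn(b, h, n, d)

# mask=None, causal=False, q_bucket_size=512, k_bucket_size=1024 (same values used below)
out = FlashAttentionFunction.apply(q, k, v, None, False, 512, 1024)
print(out.shape)  # torch.Size([2, 8, 4096, 40])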
318
+
319
+ def get_parameter_dtype(parameter: torch.nn.Module):
320
+ return next(parameter.parameters()).dtype
321
+
322
+
323
+ def get_parameter_device(parameter: torch.nn.Module):
324
+ return next(parameter.parameters()).device
325
+
326
+
327
+ def get_timestep_embedding(
328
+ timesteps: torch.Tensor,
329
+ embedding_dim: int,
330
+ flip_sin_to_cos: bool = False,
331
+ downscale_freq_shift: float = 1,
332
+ scale: float = 1,
333
+ max_period: int = 10000,
334
+ ):
335
+ """
336
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
337
+
338
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
339
+ These may be fractional.
340
+ :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
341
+ embeddings. :return: an [N x dim] Tensor of positional embeddings.
342
+ """
343
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
344
+
345
+ half_dim = embedding_dim // 2
346
+ exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
347
+ exponent = exponent / (half_dim - downscale_freq_shift)
348
+
349
+ emb = torch.exp(exponent)
350
+ emb = timesteps[:, None].float() * emb[None, :]
351
+
352
+ # scale embeddings
353
+ emb = scale * emb
354
+
355
+ # concat sine and cosine embeddings
356
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
357
+
358
+ # flip sine and cosine embeddings
359
+ if flip_sin_to_cos:
360
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
361
+
362
+ # zero pad
363
+ if embedding_dim % 2 == 1:
364
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
365
+ return emb
366
+
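A small worked example of the timestep embedding as parameterized by the constants above (TIMESTEP_INPUT_DIM = 320, flip_sin_to_cos = True, freq_shift = 0); the timestep values themselves are arbitrary.

import torch

timesteps = torch.tensor([0, 500, 999])  # arbitrary example timesteps
emb = get_timestep_embedding(
    timesteps, TIMESTEP_INPUT_DIM,
    flip_sin_to_cos=TIME_EMBED_FLIP_SIN_TO_COS,
    downscale_freq_shift=TIME_EMBED_FREQ_SHIFT,
)
print(emb.shape)  # torch.Size([3, 320]); cos in the first 160 dims, sin in the last 160 after the flip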
367
+
368
+ # Deep Shrink: this function is deliberately not factored out into a shared module, to minimize dependencies.
369
+ def resize_like(x, target, mode="bicubic", align_corners=False):
370
+ org_dtype = x.dtype
371
+ if org_dtype == torch.bfloat16:
372
+ x = x.to(torch.float32)
373
+
374
+ if x.shape[-2:] != target.shape[-2:]:
375
+ if mode == "nearest":
376
+ x = F.interpolate(x, size=target.shape[-2:], mode=mode)
377
+ else:
378
+ x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners)
379
+
380
+ if org_dtype == torch.bfloat16:
381
+ x = x.to(org_dtype)
382
+ return x
383
+
384
+
385
+ class SampleOutput:
386
+ def __init__(self, sample):
387
+ self.sample = sample
388
+
389
+
390
+ class TimestepEmbedding(nn.Module):
391
+ def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None):
392
+ super().__init__()
393
+
394
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim)
395
+ self.act = None
396
+ if act_fn == "silu":
397
+ self.act = nn.SiLU()
398
+ elif act_fn == "mish":
399
+ self.act = nn.Mish()
400
+
401
+ if out_dim is not None:
402
+ time_embed_dim_out = out_dim
403
+ else:
404
+ time_embed_dim_out = time_embed_dim
405
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
406
+
407
+ def forward(self, sample):
408
+ sample = self.linear_1(sample)
409
+
410
+ if self.act is not None:
411
+ sample = self.act(sample)
412
+
413
+ sample = self.linear_2(sample)
414
+ return sample
415
+
416
+
417
+ class Timesteps(nn.Module):
418
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
419
+ super().__init__()
420
+ self.num_channels = num_channels
421
+ self.flip_sin_to_cos = flip_sin_to_cos
422
+ self.downscale_freq_shift = downscale_freq_shift
423
+
424
+ def forward(self, timesteps):
425
+ t_emb = get_timestep_embedding(
426
+ timesteps,
427
+ self.num_channels,
428
+ flip_sin_to_cos=self.flip_sin_to_cos,
429
+ downscale_freq_shift=self.downscale_freq_shift,
430
+ )
431
+ return t_emb
432
+
433
+
434
+ class ResnetBlock2D(nn.Module):
435
+ def __init__(
436
+ self,
437
+ in_channels,
438
+ out_channels,
439
+ ):
440
+ super().__init__()
441
+ self.in_channels = in_channels
442
+ self.out_channels = out_channels
443
+
444
+ self.norm1 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=in_channels, eps=NORM_EPS, affine=True)
445
+
446
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
447
+
448
+ self.time_emb_proj = torch.nn.Linear(TIME_EMBED_DIM, out_channels)
449
+
450
+ self.norm2 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=out_channels, eps=NORM_EPS, affine=True)
451
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
452
+
453
+ # if non_linearity == "swish":
454
+ self.nonlinearity = lambda x: F.silu(x)
455
+
456
+ self.use_in_shortcut = self.in_channels != self.out_channels
457
+
458
+ self.conv_shortcut = None
459
+ if self.use_in_shortcut:
460
+ self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
461
+
462
+ def forward(self, input_tensor, temb):
463
+ hidden_states = input_tensor
464
+
465
+ hidden_states = self.norm1(hidden_states)
466
+ hidden_states = self.nonlinearity(hidden_states)
467
+
468
+ hidden_states = self.conv1(hidden_states)
469
+
470
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
471
+ hidden_states = hidden_states + temb
472
+
473
+ hidden_states = self.norm2(hidden_states)
474
+ hidden_states = self.nonlinearity(hidden_states)
475
+
476
+ hidden_states = self.conv2(hidden_states)
477
+
478
+ if self.conv_shortcut is not None:
479
+ input_tensor = self.conv_shortcut(input_tensor)
480
+
481
+ output_tensor = input_tensor + hidden_states
482
+
483
+ return output_tensor
484
+
485
+
486
+ class DownBlock2D(nn.Module):
487
+ def __init__(
488
+ self,
489
+ in_channels: int,
490
+ out_channels: int,
491
+ add_downsample=True,
492
+ ):
493
+ super().__init__()
494
+
495
+ self.has_cross_attention = False
496
+ resnets = []
497
+
498
+ for i in range(LAYERS_PER_BLOCK):
499
+ in_channels = in_channels if i == 0 else out_channels
500
+ resnets.append(
501
+ ResnetBlock2D(
502
+ in_channels=in_channels,
503
+ out_channels=out_channels,
504
+ )
505
+ )
506
+ self.resnets = nn.ModuleList(resnets)
507
+
508
+ if add_downsample:
509
+ self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels=out_channels)]) # ModuleList so the downsampler parameters are registered (matches CrossAttnDownBlock2D)
510
+ else:
511
+ self.downsamplers = None
512
+
513
+ self.gradient_checkpointing = False
514
+
515
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
516
+ pass
517
+
518
+ def set_use_sdpa(self, sdpa):
519
+ pass
520
+
521
+ def forward(self, hidden_states, temb=None):
522
+ output_states = ()
523
+
524
+ for resnet in self.resnets:
525
+ if self.training and self.gradient_checkpointing:
526
+
527
+ def create_custom_forward(module):
528
+ def custom_forward(*inputs):
529
+ return module(*inputs)
530
+
531
+ return custom_forward
532
+
533
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
534
+ else:
535
+ hidden_states = resnet(hidden_states, temb)
536
+
537
+ output_states += (hidden_states,)
538
+
539
+ if self.downsamplers is not None:
540
+ for downsampler in self.downsamplers:
541
+ hidden_states = downsampler(hidden_states)
542
+
543
+ output_states += (hidden_states,)
544
+
545
+ return hidden_states, output_states
546
+
547
+
548
+ class Downsample2D(nn.Module):
549
+ def __init__(self, channels, out_channels):
550
+ super().__init__()
551
+
552
+ self.channels = channels
553
+ self.out_channels = out_channels
554
+
555
+ self.conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=2, padding=1)
556
+
557
+ def forward(self, hidden_states):
558
+ assert hidden_states.shape[1] == self.channels
559
+ hidden_states = self.conv(hidden_states)
560
+
561
+ return hidden_states
562
+
563
+
564
+ class CrossAttention(nn.Module):
565
+ def __init__(
566
+ self,
567
+ query_dim: int,
568
+ cross_attention_dim: Optional[int] = None,
569
+ heads: int = 8,
570
+ dim_head: int = 64,
571
+ upcast_attention: bool = False,
572
+ ):
573
+ super().__init__()
574
+ inner_dim = dim_head * heads
575
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
576
+ self.upcast_attention = upcast_attention
577
+
578
+ self.scale = dim_head**-0.5
579
+ self.heads = heads
580
+
581
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
582
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False)
583
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False)
584
+
585
+ self.to_out = nn.ModuleList([])
586
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
587
+ # no dropout here
588
+
589
+ self.use_memory_efficient_attention_xformers = False
590
+ self.use_memory_efficient_attention_mem_eff = False
591
+ self.use_sdpa = False
592
+
593
+ # Attention processor
594
+ self.processor = None
595
+
596
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
597
+ self.use_memory_efficient_attention_xformers = xformers
598
+ self.use_memory_efficient_attention_mem_eff = mem_eff
599
+
600
+ def set_use_sdpa(self, sdpa):
601
+ self.use_sdpa = sdpa
602
+
603
+ def reshape_heads_to_batch_dim(self, tensor):
604
+ batch_size, seq_len, dim = tensor.shape
605
+ head_size = self.heads
606
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
607
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
608
+ return tensor
609
+
610
+ def reshape_batch_dim_to_heads(self, tensor):
611
+ batch_size, seq_len, dim = tensor.shape
612
+ head_size = self.heads
613
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
614
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
615
+ return tensor
616
+
617
+ def set_processor(self):
618
+ return self.processor
619
+
620
+ def get_processor(self):
621
+ return self.processor
622
+
623
+ def forward(self, hidden_states, context=None, mask=None, **kwargs):
624
+ if self.processor is not None:
625
+ (
626
+ hidden_states,
627
+ encoder_hidden_states,
628
+ attention_mask,
629
+ ) = translate_attention_names_from_diffusers(
630
+ hidden_states=hidden_states, context=context, mask=mask, **kwargs
631
+ )
632
+ return self.processor(
633
+ attn=self,
634
+ hidden_states=hidden_states,
635
+ encoder_hidden_states=context,
636
+ attention_mask=mask,
637
+ **kwargs
638
+ )
639
+ if self.use_memory_efficient_attention_xformers:
640
+ return self.forward_memory_efficient_xformers(hidden_states, context, mask)
641
+ if self.use_memory_efficient_attention_mem_eff:
642
+ return self.forward_memory_efficient_mem_eff(hidden_states, context, mask)
643
+ if self.use_sdpa:
644
+ return self.forward_sdpa(hidden_states, context, mask)
645
+
646
+ query = self.to_q(hidden_states)
647
+ context = context if context is not None else hidden_states
648
+ key = self.to_k(context)
649
+ value = self.to_v(context)
650
+
651
+ query = self.reshape_heads_to_batch_dim(query)
652
+ key = self.reshape_heads_to_batch_dim(key)
653
+ value = self.reshape_heads_to_batch_dim(value)
654
+
655
+ hidden_states = self._attention(query, key, value)
656
+
657
+ # linear proj
658
+ hidden_states = self.to_out[0](hidden_states)
659
+ # hidden_states = self.to_out[1](hidden_states) # no dropout
660
+ return hidden_states
661
+
662
+ def _attention(self, query, key, value):
663
+ if self.upcast_attention:
664
+ query = query.float()
665
+ key = key.float()
666
+
667
+ attention_scores = torch.baddbmm(
668
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
669
+ query,
670
+ key.transpose(-1, -2),
671
+ beta=0,
672
+ alpha=self.scale,
673
+ )
674
+ attention_probs = attention_scores.softmax(dim=-1)
675
+
676
+ # cast back to the original dtype
677
+ attention_probs = attention_probs.to(value.dtype)
678
+
679
+ # compute attention output
680
+ hidden_states = torch.bmm(attention_probs, value)
681
+
682
+ # reshape hidden_states
683
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
684
+ return hidden_states
685
+
686
+ # TODO support Hypernetworks
687
+ def forward_memory_efficient_xformers(self, x, context=None, mask=None):
688
+ import xformers.ops
689
+
690
+ h = self.heads
691
+ q_in = self.to_q(x)
692
+ context = context if context is not None else x
693
+ context = context.to(x.dtype)
694
+ k_in = self.to_k(context)
695
+ v_in = self.to_v(context)
696
+
697
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
698
+ del q_in, k_in, v_in
699
+
700
+ q = q.contiguous()
701
+ k = k.contiguous()
702
+ v = v.contiguous()
703
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # xformers picks the best available kernel automatically
704
+
705
+ out = rearrange(out, "b n h d -> b n (h d)", h=h)
706
+
707
+ out = self.to_out[0](out)
708
+ return out
709
+
710
+ def forward_memory_efficient_mem_eff(self, x, context=None, mask=None):
711
+ flash_func = FlashAttentionFunction
712
+
713
+ q_bucket_size = 512
714
+ k_bucket_size = 1024
715
+
716
+ h = self.heads
717
+ q = self.to_q(x)
718
+ context = context if context is not None else x
719
+ context = context.to(x.dtype)
720
+ k = self.to_k(context)
721
+ v = self.to_v(context)
722
+ del context, x
723
+
724
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
725
+
726
+ out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
727
+
728
+ out = rearrange(out, "b h n d -> b n (h d)")
729
+
730
+ out = self.to_out[0](out)
731
+ return out
732
+
733
+ def forward_sdpa(self, x, context=None, mask=None):
734
+ h = self.heads
735
+ q_in = self.to_q(x)
736
+ context = context if context is not None else x
737
+ context = context.to(x.dtype)
738
+ k_in = self.to_k(context)
739
+ v_in = self.to_v(context)
740
+
741
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in))
742
+ del q_in, k_in, v_in
743
+
744
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
745
+
746
+ out = rearrange(out, "b h n d -> b n (h d)", h=h)
747
+
748
+ out = self.to_out[0](out)
749
+ return out
750
+
751
+ def translate_attention_names_from_diffusers(
752
+ hidden_states: torch.FloatTensor,
753
+ context: Optional[torch.FloatTensor] = None,
754
+ mask: Optional[torch.FloatTensor] = None,
755
+ # HF naming
756
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
757
+ attention_mask: Optional[torch.FloatTensor] = None
758
+ ):
759
+ # translate from hugging face diffusers
760
+ context = context if context is not None else encoder_hidden_states
761
+
762
+ # translate from hugging face diffusers
763
+ mask = mask if mask is not None else attention_mask
764
+
765
+ return hidden_states, context, mask
766
+
767
+ # feedforward
768
+ class GEGLU(nn.Module):
769
+ r"""
770
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
771
+
772
+ Parameters:
773
+ dim_in (`int`): The number of channels in the input.
774
+ dim_out (`int`): The number of channels in the output.
775
+ """
776
+
777
+ def __init__(self, dim_in: int, dim_out: int):
778
+ super().__init__()
779
+ self.proj = nn.Linear(dim_in, dim_out * 2)
780
+
781
+ def gelu(self, gate):
782
+ if gate.device.type != "mps":
783
+ return F.gelu(gate)
784
+ # mps: gelu is not implemented for float16
785
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
786
+
787
+ def forward(self, hidden_states):
788
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
789
+ return hidden_states * self.gelu(gate)
790
+
791
+
792
+ class FeedForward(nn.Module):
793
+ def __init__(
794
+ self,
795
+ dim: int,
796
+ ):
797
+ super().__init__()
798
+ inner_dim = int(dim * 4) # mult is always 4
799
+
800
+ self.net = nn.ModuleList([])
801
+ # project in
802
+ self.net.append(GEGLU(dim, inner_dim))
803
+ # project dropout
804
+ self.net.append(nn.Identity()) # nn.Dropout(0)) # dummy for dropout with 0
805
+ # project out
806
+ self.net.append(nn.Linear(inner_dim, dim))
807
+
808
+ def forward(self, hidden_states):
809
+ for module in self.net:
810
+ hidden_states = module(hidden_states)
811
+ return hidden_states
812
+
813
+
814
+ class BasicTransformerBlock(nn.Module):
815
+ def __init__(
816
+ self, dim: int, num_attention_heads: int, attention_head_dim: int, cross_attention_dim: int, upcast_attention: bool = False
817
+ ):
818
+ super().__init__()
819
+
820
+ # 1. Self-Attn
821
+ self.attn1 = CrossAttention(
822
+ query_dim=dim,
823
+ cross_attention_dim=None,
824
+ heads=num_attention_heads,
825
+ dim_head=attention_head_dim,
826
+ upcast_attention=upcast_attention,
827
+ )
828
+ self.ff = FeedForward(dim)
829
+
830
+ # 2. Cross-Attn
831
+ self.attn2 = CrossAttention(
832
+ query_dim=dim,
833
+ cross_attention_dim=cross_attention_dim,
834
+ heads=num_attention_heads,
835
+ dim_head=attention_head_dim,
836
+ upcast_attention=upcast_attention,
837
+ )
838
+
839
+ self.norm1 = nn.LayerNorm(dim)
840
+ self.norm2 = nn.LayerNorm(dim)
841
+
842
+ # 3. Feed-forward
843
+ self.norm3 = nn.LayerNorm(dim)
844
+
845
+ def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool):
846
+ self.attn1.set_use_memory_efficient_attention(xformers, mem_eff)
847
+ self.attn2.set_use_memory_efficient_attention(xformers, mem_eff)
848
+
849
+ def set_use_sdpa(self, sdpa: bool):
850
+ self.attn1.set_use_sdpa(sdpa)
851
+ self.attn2.set_use_sdpa(sdpa)
852
+
853
+ def forward(self, hidden_states, context=None, timestep=None):
854
+ # 1. Self-Attention
855
+ norm_hidden_states = self.norm1(hidden_states)
856
+
857
+ hidden_states = self.attn1(norm_hidden_states) + hidden_states
858
+
859
+ # 2. Cross-Attention
860
+ norm_hidden_states = self.norm2(hidden_states)
861
+ hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
862
+
863
+ # 3. Feed-forward
864
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
865
+
866
+ return hidden_states
867
+
868
+
869
+ class Transformer2DModel(nn.Module):
870
+ def __init__(
871
+ self,
872
+ num_attention_heads: int = 16,
873
+ attention_head_dim: int = 88,
874
+ in_channels: Optional[int] = None,
875
+ cross_attention_dim: Optional[int] = None,
876
+ use_linear_projection: bool = False,
877
+ upcast_attention: bool = False,
878
+ ):
879
+ super().__init__()
880
+ self.in_channels = in_channels
881
+ self.num_attention_heads = num_attention_heads
882
+ self.attention_head_dim = attention_head_dim
883
+ inner_dim = num_attention_heads * attention_head_dim
884
+ self.use_linear_projection = use_linear_projection
885
+
886
+ self.norm = torch.nn.GroupNorm(num_groups=TRANSFORMER_NORM_NUM_GROUPS, num_channels=in_channels, eps=1e-6, affine=True)
887
+
888
+ if use_linear_projection:
889
+ self.proj_in = nn.Linear(in_channels, inner_dim)
890
+ else:
891
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
892
+
893
+ self.transformer_blocks = nn.ModuleList(
894
+ [
895
+ BasicTransformerBlock(
896
+ inner_dim,
897
+ num_attention_heads,
898
+ attention_head_dim,
899
+ cross_attention_dim=cross_attention_dim,
900
+ upcast_attention=upcast_attention,
901
+ )
902
+ ]
903
+ )
904
+
905
+ if use_linear_projection:
906
+ self.proj_out = nn.Linear(in_channels, inner_dim)
907
+ else:
908
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
909
+
910
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
911
+ for transformer in self.transformer_blocks:
912
+ transformer.set_use_memory_efficient_attention(xformers, mem_eff)
913
+
914
+ def set_use_sdpa(self, sdpa):
915
+ for transformer in self.transformer_blocks:
916
+ transformer.set_use_sdpa(sdpa)
917
+
918
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
919
+ # 1. Input
920
+ batch, _, height, weight = hidden_states.shape
921
+ residual = hidden_states
922
+
923
+ hidden_states = self.norm(hidden_states)
924
+ if not self.use_linear_projection:
925
+ hidden_states = self.proj_in(hidden_states)
926
+ inner_dim = hidden_states.shape[1]
927
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
928
+ else:
929
+ inner_dim = hidden_states.shape[1]
930
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
931
+ hidden_states = self.proj_in(hidden_states)
932
+
933
+ # 2. Blocks
934
+ for block in self.transformer_blocks:
935
+ hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
936
+
937
+ # 3. Output
938
+ if not self.use_linear_projection:
939
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
940
+ hidden_states = self.proj_out(hidden_states)
941
+ else:
942
+ hidden_states = self.proj_out(hidden_states)
943
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
944
+
945
+ output = hidden_states + residual
946
+
947
+ if not return_dict:
948
+ return (output,)
949
+
950
+ return SampleOutput(sample=output)
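+ # Illustrative note: with use_linear_projection=False the 1x1 conv projection runs in NCHW and the
+ # tensor is flattened to (batch, height * width, inner_dim) afterwards, whereas with
+ # use_linear_projection=True the flatten happens first and an nn.Linear is used; the output path
+ # mirrors the same ordering before the residual is added.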
951
+
952
+
953
+ class CrossAttnDownBlock2D(nn.Module):
954
+ def __init__(
955
+ self,
956
+ in_channels: int,
957
+ out_channels: int,
958
+ add_downsample=True,
959
+ cross_attention_dim=1280,
960
+ attn_num_head_channels=1,
961
+ use_linear_projection=False,
962
+ upcast_attention=False,
963
+ ):
964
+ super().__init__()
965
+ self.has_cross_attention = True
966
+ resnets = []
967
+ attentions = []
968
+
969
+ self.attn_num_head_channels = attn_num_head_channels
970
+
971
+ for i in range(LAYERS_PER_BLOCK):
972
+ in_channels = in_channels if i == 0 else out_channels
973
+
974
+ resnets.append(ResnetBlock2D(in_channels=in_channels, out_channels=out_channels))
975
+ attentions.append(
976
+ Transformer2DModel(
977
+ attn_num_head_channels,
978
+ out_channels // attn_num_head_channels,
979
+ in_channels=out_channels,
980
+ cross_attention_dim=cross_attention_dim,
981
+ use_linear_projection=use_linear_projection,
982
+ upcast_attention=upcast_attention,
983
+ )
984
+ )
985
+ self.attentions = nn.ModuleList(attentions)
986
+ self.resnets = nn.ModuleList(resnets)
987
+
988
+ if add_downsample:
989
+ self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels)])
990
+ else:
991
+ self.downsamplers = None
992
+
993
+ self.gradient_checkpointing = False
994
+
995
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
996
+ for attn in self.attentions:
997
+ attn.set_use_memory_efficient_attention(xformers, mem_eff)
998
+
999
+ def set_use_sdpa(self, sdpa):
1000
+ for attn in self.attentions:
1001
+ attn.set_use_sdpa(sdpa)
1002
+
1003
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
1004
+ output_states = ()
1005
+
1006
+ for resnet, attn in zip(self.resnets, self.attentions):
1007
+ if self.training and self.gradient_checkpointing:
1008
+
1009
+ def create_custom_forward(module, return_dict=None):
1010
+ def custom_forward(*inputs):
1011
+ if return_dict is not None:
1012
+ return module(*inputs, return_dict=return_dict)
1013
+ else:
1014
+ return module(*inputs)
1015
+
1016
+ return custom_forward
1017
+
1018
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
1019
+ hidden_states = torch.utils.checkpoint.checkpoint(
1020
+ create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
1021
+ )[0]
1022
+ else:
1023
+ hidden_states = resnet(hidden_states, temb)
1024
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
1025
+
1026
+ output_states += (hidden_states,)
1027
+
1028
+ if self.downsamplers is not None:
1029
+ for downsampler in self.downsamplers:
1030
+ hidden_states = downsampler(hidden_states)
1031
+
1032
+ output_states += (hidden_states,)
1033
+
1034
+ return hidden_states, output_states
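+ # Illustrative note: the create_custom_forward wrapper above is needed because
+ # torch.utils.checkpoint.checkpoint passes only the positional inputs through to the wrapped
+ # module; the closure bakes in return_dict=False so the attention block returns a plain tuple
+ # whose first element is the hidden-states tensor.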
1035
+
1036
+
1037
+ class UNetMidBlock2DCrossAttn(nn.Module):
1038
+ def __init__(
1039
+ self,
1040
+ in_channels: int,
1041
+ attn_num_head_channels=1,
1042
+ cross_attention_dim=1280,
1043
+ use_linear_projection=False,
1044
+ ):
1045
+ super().__init__()
1046
+
1047
+ self.has_cross_attention = True
1048
+ self.attn_num_head_channels = attn_num_head_channels
1049
+
1050
+ # Middle block has two resnets and one attention
1051
+ resnets = [
1052
+ ResnetBlock2D(
1053
+ in_channels=in_channels,
1054
+ out_channels=in_channels,
1055
+ ),
1056
+ ResnetBlock2D(
1057
+ in_channels=in_channels,
1058
+ out_channels=in_channels,
1059
+ ),
1060
+ ]
1061
+ attentions = [
1062
+ Transformer2DModel(
1063
+ attn_num_head_channels,
1064
+ in_channels // attn_num_head_channels,
1065
+ in_channels=in_channels,
1066
+ cross_attention_dim=cross_attention_dim,
1067
+ use_linear_projection=use_linear_projection,
1068
+ )
1069
+ ]
1070
+
1071
+ self.attentions = nn.ModuleList(attentions)
1072
+ self.resnets = nn.ModuleList(resnets)
1073
+
1074
+ self.gradient_checkpointing = False
1075
+
1076
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
1077
+ for attn in self.attentions:
1078
+ attn.set_use_memory_efficient_attention(xformers, mem_eff)
1079
+
1080
+ def set_use_sdpa(self, sdpa):
1081
+ for attn in self.attentions:
1082
+ attn.set_use_sdpa(sdpa)
1083
+
1084
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
1085
+ for i, resnet in enumerate(self.resnets):
1086
+ attn = None if i == 0 else self.attentions[i - 1]
1087
+
1088
+ if self.training and self.gradient_checkpointing:
1089
+
1090
+ def create_custom_forward(module, return_dict=None):
1091
+ def custom_forward(*inputs):
1092
+ if return_dict is not None:
1093
+ return module(*inputs, return_dict=return_dict)
1094
+ else:
1095
+ return module(*inputs)
1096
+
1097
+ return custom_forward
1098
+
1099
+ if attn is not None:
1100
+ hidden_states = torch.utils.checkpoint.checkpoint(
1101
+ create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
1102
+ )[0]
1103
+
1104
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
1105
+ else:
1106
+ if attn is not None:
1107
+ hidden_states = attn(hidden_states, encoder_hidden_states).sample
1108
+ hidden_states = resnet(hidden_states, temb)
1109
+
1110
+ return hidden_states
1111
+
1112
+
1113
+ class Upsample2D(nn.Module):
1114
+ def __init__(self, channels, out_channels):
1115
+ super().__init__()
1116
+ self.channels = channels
1117
+ self.out_channels = out_channels
1118
+ self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
1119
+
1120
+ def forward(self, hidden_states, output_size):
1121
+ assert hidden_states.shape[1] == self.channels
1122
+
1123
+ # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
1124
+ # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
1125
+ # https://github.com/pytorch/pytorch/issues/86679
1126
+ dtype = hidden_states.dtype
1127
+ if dtype == torch.bfloat16:
1128
+ hidden_states = hidden_states.to(torch.float32)
1129
+
1130
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
1131
+ if hidden_states.shape[0] >= 64:
1132
+ hidden_states = hidden_states.contiguous()
1133
+
1134
+ # if `output_size` is passed we force the interpolation output size and do not make use of `scale_factor=2`
1135
+ if output_size is None:
1136
+ hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
1137
+ else:
1138
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
1139
+
1140
+ # If the input is bfloat16, we cast back to bfloat16
1141
+ if dtype == torch.bfloat16:
1142
+ hidden_states = hidden_states.to(dtype)
1143
+
1144
+ hidden_states = self.conv(hidden_states)
1145
+
1146
+ return hidden_states
1147
+
1148
+
1149
+ class UpBlock2D(nn.Module):
1150
+ def __init__(
1151
+ self,
1152
+ in_channels: int,
1153
+ prev_output_channel: int,
1154
+ out_channels: int,
1155
+ add_upsample=True,
1156
+ ):
1157
+ super().__init__()
1158
+
1159
+ self.has_cross_attention = False
1160
+ resnets = []
1161
+
1162
+ for i in range(LAYERS_PER_BLOCK_UP):
1163
+ res_skip_channels = in_channels if (i == LAYERS_PER_BLOCK_UP - 1) else out_channels
1164
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1165
+
1166
+ resnets.append(
1167
+ ResnetBlock2D(
1168
+ in_channels=resnet_in_channels + res_skip_channels,
1169
+ out_channels=out_channels,
1170
+ )
1171
+ )
1172
+
1173
+ self.resnets = nn.ModuleList(resnets)
1174
+
1175
+ if add_upsample:
1176
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)])
1177
+ else:
1178
+ self.upsamplers = None
1179
+
1180
+ self.gradient_checkpointing = False
1181
+
1182
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
1183
+ pass
1184
+
1185
+ def set_use_sdpa(self, sdpa):
1186
+ pass
1187
+
1188
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
1189
+ for resnet in self.resnets:
1190
+ # pop res hidden states
1191
+ res_hidden_states = res_hidden_states_tuple[-1]
1192
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1193
+
1194
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1195
+
1196
+ if self.training and self.gradient_checkpointing:
1197
+
1198
+ def create_custom_forward(module):
1199
+ def custom_forward(*inputs):
1200
+ return module(*inputs)
1201
+
1202
+ return custom_forward
1203
+
1204
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
1205
+ else:
1206
+ hidden_states = resnet(hidden_states, temb)
1207
+
1208
+ if self.upsamplers is not None:
1209
+ for upsampler in self.upsamplers:
1210
+ hidden_states = upsampler(hidden_states, upsample_size)
1211
+
1212
+ return hidden_states
1213
+
1214
+
1215
+ class CrossAttnUpBlock2D(nn.Module):
1216
+ def __init__(
1217
+ self,
1218
+ in_channels: int,
1219
+ out_channels: int,
1220
+ prev_output_channel: int,
1221
+ attn_num_head_channels=1,
1222
+ cross_attention_dim=1280,
1223
+ add_upsample=True,
1224
+ use_linear_projection=False,
1225
+ upcast_attention=False,
1226
+ ):
1227
+ super().__init__()
1228
+ resnets = []
1229
+ attentions = []
1230
+
1231
+ self.has_cross_attention = True
1232
+ self.attn_num_head_channels = attn_num_head_channels
1233
+
1234
+ for i in range(LAYERS_PER_BLOCK_UP):
1235
+ res_skip_channels = in_channels if (i == LAYERS_PER_BLOCK_UP - 1) else out_channels
1236
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1237
+
1238
+ resnets.append(
1239
+ ResnetBlock2D(
1240
+ in_channels=resnet_in_channels + res_skip_channels,
1241
+ out_channels=out_channels,
1242
+ )
1243
+ )
1244
+ attentions.append(
1245
+ Transformer2DModel(
1246
+ attn_num_head_channels,
1247
+ out_channels // attn_num_head_channels,
1248
+ in_channels=out_channels,
1249
+ cross_attention_dim=cross_attention_dim,
1250
+ use_linear_projection=use_linear_projection,
1251
+ upcast_attention=upcast_attention,
1252
+ )
1253
+ )
1254
+
1255
+ self.attentions = nn.ModuleList(attentions)
1256
+ self.resnets = nn.ModuleList(resnets)
1257
+
1258
+ if add_upsample:
1259
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)])
1260
+ else:
1261
+ self.upsamplers = None
1262
+
1263
+ self.gradient_checkpointing = False
1264
+
1265
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
1266
+ for attn in self.attentions:
1267
+ attn.set_use_memory_efficient_attention(xformers, mem_eff)
1268
+
1269
+ def set_use_sdpa(self, sdpa):
1270
+ for attn in self.attentions:
1271
+ attn.set_use_sdpa(sdpa)
1272
+
1273
+ def forward(
1274
+ self,
1275
+ hidden_states,
1276
+ res_hidden_states_tuple,
1277
+ temb=None,
1278
+ encoder_hidden_states=None,
1279
+ upsample_size=None,
1280
+ ):
1281
+ for resnet, attn in zip(self.resnets, self.attentions):
1282
+ # pop res hidden states
1283
+ res_hidden_states = res_hidden_states_tuple[-1]
1284
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1285
+
1286
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1287
+
1288
+ if self.training and self.gradient_checkpointing:
1289
+
1290
+ def create_custom_forward(module, return_dict=None):
1291
+ def custom_forward(*inputs):
1292
+ if return_dict is not None:
1293
+ return module(*inputs, return_dict=return_dict)
1294
+ else:
1295
+ return module(*inputs)
1296
+
1297
+ return custom_forward
1298
+
1299
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
1300
+ hidden_states = torch.utils.checkpoint.checkpoint(
1301
+ create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
1302
+ )[0]
1303
+ else:
1304
+ hidden_states = resnet(hidden_states, temb)
1305
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
1306
+
1307
+ if self.upsamplers is not None:
1308
+ for upsampler in self.upsamplers:
1309
+ hidden_states = upsampler(hidden_states, upsample_size)
1310
+
1311
+ return hidden_states
1312
+
1313
+
1314
+ def get_down_block(
1315
+ down_block_type,
1316
+ in_channels,
1317
+ out_channels,
1318
+ add_downsample,
1319
+ attn_num_head_channels,
1320
+ cross_attention_dim,
1321
+ use_linear_projection,
1322
+ upcast_attention,
1323
+ ):
1324
+ if down_block_type == "DownBlock2D":
1325
+ return DownBlock2D(
1326
+ in_channels=in_channels,
1327
+ out_channels=out_channels,
1328
+ add_downsample=add_downsample,
1329
+ )
1330
+ elif down_block_type == "CrossAttnDownBlock2D":
1331
+ return CrossAttnDownBlock2D(
1332
+ in_channels=in_channels,
1333
+ out_channels=out_channels,
1334
+ add_downsample=add_downsample,
1335
+ cross_attention_dim=cross_attention_dim,
1336
+ attn_num_head_channels=attn_num_head_channels,
1337
+ use_linear_projection=use_linear_projection,
1338
+ upcast_attention=upcast_attention,
1339
+ )
1340
+
1341
+
1342
+ def get_up_block(
1343
+ up_block_type,
1344
+ in_channels,
1345
+ out_channels,
1346
+ prev_output_channel,
1347
+ add_upsample,
1348
+ attn_num_head_channels,
1349
+ cross_attention_dim=None,
1350
+ use_linear_projection=False,
1351
+ upcast_attention=False,
1352
+ ):
1353
+ if up_block_type == "UpBlock2D":
1354
+ return UpBlock2D(
1355
+ in_channels=in_channels,
1356
+ prev_output_channel=prev_output_channel,
1357
+ out_channels=out_channels,
1358
+ add_upsample=add_upsample,
1359
+ )
1360
+ elif up_block_type == "CrossAttnUpBlock2D":
1361
+ return CrossAttnUpBlock2D(
1362
+ in_channels=in_channels,
1363
+ out_channels=out_channels,
1364
+ prev_output_channel=prev_output_channel,
1365
+ attn_num_head_channels=attn_num_head_channels,
1366
+ cross_attention_dim=cross_attention_dim,
1367
+ add_upsample=add_upsample,
1368
+ use_linear_projection=use_linear_projection,
1369
+ upcast_attention=upcast_attention,
1370
+ )
1371
+
1372
+
1373
+ class UNet2DConditionModel(nn.Module):
1374
+ _supports_gradient_checkpointing = True
1375
+
1376
+ def __init__(
1377
+ self,
1378
+ sample_size: Optional[int] = None,
1379
+ attention_head_dim: Union[int, Tuple[int]] = 8,
1380
+ cross_attention_dim: int = 1280,
1381
+ use_linear_projection: bool = False,
1382
+ upcast_attention: bool = False,
1383
+ **kwargs,
1384
+ ):
1385
+ super().__init__()
1386
+ assert sample_size is not None, "sample_size must be specified"
1387
+ logger.info(
1388
+ f"UNet2DConditionModel: {sample_size}, {attention_head_dim}, {cross_attention_dim}, {use_linear_projection}, {upcast_attention}"
1389
+ )
1390
+
1391
+ # defined here so they can be referenced from outside
1392
+ self.in_channels = IN_CHANNELS
1393
+ self.out_channels = OUT_CHANNELS
1394
+
1395
+ self.sample_size = sample_size
1396
+ self.prepare_config(sample_size=sample_size)
1397
+
1398
+ # the module layout cannot be changed because it would change the state_dict format
1399
+
1400
+ # input
1401
+ self.conv_in = nn.Conv2d(IN_CHANNELS, BLOCK_OUT_CHANNELS[0], kernel_size=3, padding=(1, 1))
1402
+
1403
+ # time
1404
+ self.time_proj = Timesteps(BLOCK_OUT_CHANNELS[0], TIME_EMBED_FLIP_SIN_TO_COS, TIME_EMBED_FREQ_SHIFT)
1405
+
1406
+ self.time_embedding = TimestepEmbedding(TIMESTEP_INPUT_DIM, TIME_EMBED_DIM)
1407
+
1408
+ self.down_blocks = nn.ModuleList([])
1409
+ self.mid_block = None
1410
+ self.up_blocks = nn.ModuleList([])
1411
+
1412
+ if isinstance(attention_head_dim, int):
1413
+ attention_head_dim = (attention_head_dim,) * 4
1414
+
1415
+ # down
1416
+ output_channel = BLOCK_OUT_CHANNELS[0]
1417
+ for i, down_block_type in enumerate(DOWN_BLOCK_TYPES):
1418
+ input_channel = output_channel
1419
+ output_channel = BLOCK_OUT_CHANNELS[i]
1420
+ is_final_block = i == len(BLOCK_OUT_CHANNELS) - 1
1421
+
1422
+ down_block = get_down_block(
1423
+ down_block_type,
1424
+ in_channels=input_channel,
1425
+ out_channels=output_channel,
1426
+ add_downsample=not is_final_block,
1427
+ attn_num_head_channels=attention_head_dim[i],
1428
+ cross_attention_dim=cross_attention_dim,
1429
+ use_linear_projection=use_linear_projection,
1430
+ upcast_attention=upcast_attention,
1431
+ )
1432
+ self.down_blocks.append(down_block)
1433
+
1434
+ # mid
1435
+ self.mid_block = UNetMidBlock2DCrossAttn(
1436
+ in_channels=BLOCK_OUT_CHANNELS[-1],
1437
+ attn_num_head_channels=attention_head_dim[-1],
1438
+ cross_attention_dim=cross_attention_dim,
1439
+ use_linear_projection=use_linear_projection,
1440
+ )
1441
+
1442
+ # count how many layers upsample the images
1443
+ self.num_upsamplers = 0
1444
+
1445
+ # up
1446
+ reversed_block_out_channels = list(reversed(BLOCK_OUT_CHANNELS))
1447
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
1448
+ output_channel = reversed_block_out_channels[0]
1449
+ for i, up_block_type in enumerate(UP_BLOCK_TYPES):
1450
+ is_final_block = i == len(BLOCK_OUT_CHANNELS) - 1
1451
+
1452
+ prev_output_channel = output_channel
1453
+ output_channel = reversed_block_out_channels[i]
1454
+ input_channel = reversed_block_out_channels[min(i + 1, len(BLOCK_OUT_CHANNELS) - 1)]
1455
+
1456
+ # add upsample block for all BUT final layer
1457
+ if not is_final_block:
1458
+ add_upsample = True
1459
+ self.num_upsamplers += 1
1460
+ else:
1461
+ add_upsample = False
1462
+
1463
+ up_block = get_up_block(
1464
+ up_block_type,
1465
+ in_channels=input_channel,
1466
+ out_channels=output_channel,
1467
+ prev_output_channel=prev_output_channel,
1468
+ add_upsample=add_upsample,
1469
+ attn_num_head_channels=reversed_attention_head_dim[i],
1470
+ cross_attention_dim=cross_attention_dim,
1471
+ use_linear_projection=use_linear_projection,
1472
+ upcast_attention=upcast_attention,
1473
+ )
1474
+ self.up_blocks.append(up_block)
1475
+ prev_output_channel = output_channel
1476
+
1477
+ # out
1478
+ self.conv_norm_out = nn.GroupNorm(num_channels=BLOCK_OUT_CHANNELS[0], num_groups=NORM_GROUPS, eps=NORM_EPS)
1479
+ self.conv_act = nn.SiLU()
1480
+ self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
1481
+
1482
+ # region diffusers compatibility
1483
+ def prepare_config(self, *args, **kwargs):
1484
+ self.config = SimpleNamespace(**kwargs)
1485
+
1486
+ @property
1487
+ def dtype(self) -> torch.dtype:
1488
+ # `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
1489
+ return get_parameter_dtype(self)
1490
+
1491
+ @property
1492
+ def device(self) -> torch.device:
1493
+ # `torch.device`: The device on which the module is (assuming that all the module parameters are on the same device).
1494
+ return get_parameter_device(self)
1495
+
1496
+ def set_attention_slice(self, slice_size):
1497
+ raise NotImplementedError("Attention slicing is not supported for this model.")
1498
+
1499
+ def is_gradient_checkpointing(self) -> bool:
1500
+ return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
1501
+
1502
+ def enable_gradient_checkpointing(self):
1503
+ self.set_gradient_checkpointing(value=True)
1504
+
1505
+ def disable_gradient_checkpointing(self):
1506
+ self.set_gradient_checkpointing(value=False)
1507
+
1508
+ def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> None:
1509
+ modules = self.down_blocks + [self.mid_block] + self.up_blocks
1510
+ for module in modules:
1511
+ module.set_use_memory_efficient_attention(xformers, mem_eff)
1512
+
1513
+ def set_use_sdpa(self, sdpa: bool) -> None:
1514
+ modules = self.down_blocks + [self.mid_block] + self.up_blocks
1515
+ for module in modules:
1516
+ module.set_use_sdpa(sdpa)
1517
+
1518
+ def set_gradient_checkpointing(self, value=False):
1519
+ modules = self.down_blocks + [self.mid_block] + self.up_blocks
1520
+ for module in modules:
1521
+ logger.info(f"{module.__class__.__name__} {module.gradient_checkpointing} -> {value}")
1522
+ module.gradient_checkpointing = value
1523
+
1524
+ # endregion
1525
+
1526
+ def forward(
1527
+ self,
1528
+ sample: torch.FloatTensor,
1529
+ timestep: Union[torch.Tensor, float, int],
1530
+ encoder_hidden_states: torch.Tensor,
1531
+ class_labels: Optional[torch.Tensor] = None,
1532
+ return_dict: bool = True,
1533
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1534
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
1535
+ ) -> Union[Dict, Tuple]:
1536
+ r"""
1537
+ Args:
1538
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
1539
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
1540
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
1541
+ return_dict (`bool`, *optional*, defaults to `True`):
1542
+ Whether or not to return a dict instead of a plain tuple.
1543
+
1544
+ Returns:
1545
+ `SampleOutput` or `tuple`:
1546
+ `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
1547
+ """
1548
+ # By default samples have to be at least a multiple of the overall upsampling factor.
1549
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
1550
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
1551
+ # on the fly if necessary.
1552
+ # by default, the sample must be a multiple of 2 ** (number of upsamplers), i.e. 64
1553
+ # to also support other sizes, the upsample size is changed when necessary
1554
+ # image quality probably degrades, so keeping the size divisible by 64 is recommended
1555
+ default_overall_up_factor = 2**self.num_upsamplers
1556
+
1557
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
1558
+ # when the size is not divisible by 64, pass the size to the upsamplers
1559
+ forward_upsample_size = False
1560
+ upsample_size = None
1561
+
1562
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
1563
+ # logger.info("Forward upsample size to force interpolation output size.")
1564
+ forward_upsample_size = True
1565
+
1566
+ # 1. time
1567
+ timesteps = timestep
1568
+ timesteps = self.handle_unusual_timesteps(sample, timesteps)  # only handles the unusual cases
1569
+
1570
+ t_emb = self.time_proj(timesteps)
1571
+
1572
+ # timesteps does not contain any weights and will always return f32 tensors
1573
+ # but time_embedding might actually be running in fp16. so we need to cast here.
1574
+ # there might be better ways to encapsulate this.
1575
+ # timesteps contains no weights, so it always returns float32 tensors
1576
+ # but time_embedding may be running in fp16, so we need to cast here
1577
+ # it might be simpler to do the cast inside time_proj instead
1578
+ t_emb = t_emb.to(dtype=self.dtype)
1579
+ emb = self.time_embedding(t_emb)
1580
+
1581
+ # 2. pre-process
1582
+ sample = self.conv_in(sample)
1583
+
1584
+ down_block_res_samples = (sample,)
1585
+ for downsample_block in self.down_blocks:
1586
+ # the down blocks could always receive encoder_hidden_states in forward,
1587
+ # but this branching is probably easier to follow
1588
+ if downsample_block.has_cross_attention:
1589
+ sample, res_samples = downsample_block(
1590
+ hidden_states=sample,
1591
+ temb=emb,
1592
+ encoder_hidden_states=encoder_hidden_states,
1593
+ )
1594
+ else:
1595
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
1596
+
1597
+ down_block_res_samples += res_samples
1598
+
1599
+ # add the ControlNet outputs to the skip connections
1600
+ if down_block_additional_residuals is not None:
1601
+ down_block_res_samples = list(down_block_res_samples)
1602
+ for i in range(len(down_block_res_samples)):
1603
+ down_block_res_samples[i] += down_block_additional_residuals[i]
1604
+ down_block_res_samples = tuple(down_block_res_samples)
1605
+
1606
+ # 4. mid
1607
+ sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
1608
+
1609
+ # add the ControlNet output
1610
+ if mid_block_additional_residual is not None:
1611
+ sample += mid_block_additional_residual
1612
+
1613
+ # 5. up
1614
+ for i, upsample_block in enumerate(self.up_blocks):
1615
+ is_final_block = i == len(self.up_blocks) - 1
1616
+
1617
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1618
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection
1619
+
1620
+ # if we have not reached the final block and need to forward the upsample size, we do it here
1621
+ # as noted above, pass upsample_size to every block except the final one
1622
+ if not is_final_block and forward_upsample_size:
1623
+ upsample_size = down_block_res_samples[-1].shape[2:]
1624
+
1625
+ if upsample_block.has_cross_attention:
1626
+ sample = upsample_block(
1627
+ hidden_states=sample,
1628
+ temb=emb,
1629
+ res_hidden_states_tuple=res_samples,
1630
+ encoder_hidden_states=encoder_hidden_states,
1631
+ upsample_size=upsample_size,
1632
+ )
1633
+ else:
1634
+ sample = upsample_block(
1635
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
1636
+ )
1637
+
1638
+ # 6. post-process
1639
+ sample = self.conv_norm_out(sample)
1640
+ sample = self.conv_act(sample)
1641
+ sample = self.conv_out(sample)
1642
+
1643
+ if not return_dict:
1644
+ return (sample,)
1645
+
1646
+ return SampleOutput(sample=sample)
1647
+
1648
+ def handle_unusual_timesteps(self, sample, timesteps):
1649
+ r"""
1650
+ Convert timesteps to a Tensor if it is not one, and broadcast it to the batch size so it is compatible with ONNX/Core ML.
1651
+ """
1652
+ if not torch.is_tensor(timesteps):
1653
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
1654
+ # This would be a good case for the `match` statement (Python 3.10+)
1655
+ is_mps = sample.device.type == "mps"
1656
+ if isinstance(timesteps, float):
1657
+ dtype = torch.float32 if is_mps else torch.float64
1658
+ else:
1659
+ dtype = torch.int32 if is_mps else torch.int64
1660
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
1661
+ elif len(timesteps.shape) == 0:
1662
+ timesteps = timesteps[None].to(sample.device)
1663
+
1664
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1665
+ timesteps = timesteps.expand(sample.shape[0])
1666
+
1667
+ return timesteps
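+ # Minimal usage sketch (hedged; shapes and argument values are illustrative for an SD v1 style
+ # setup and rely on the channel/block constants defined earlier in this file):
+ #   unet = UNet2DConditionModel(sample_size=64, attention_head_dim=8, cross_attention_dim=768)
+ #   noise_pred = unet(latents, timesteps, text_embeddings).sample
+ #   # latents: (B, 4, 64, 64), timesteps: (B,), text_embeddings: (B, 77, 768)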
1668
+
1669
+
1670
+ class InferUNet2DConditionModel:
1671
+ def __init__(self, original_unet: UNet2DConditionModel):
1672
+ self.delegate = original_unet
1673
+
1674
+ # override original model's forward method: because forward is not called by `__call__`
1675
+ # overriding `__call__` is not enough, because nn.Module.forward has special handling
1676
+ self.delegate.forward = self.forward
1677
+
1678
+ # override original model's up blocks' forward method
1679
+ for up_block in self.delegate.up_blocks:
1680
+ if up_block.__class__.__name__ == "UpBlock2D":
1681
+
1682
+ def resnet_wrapper(func, block):
1683
+ def forward(*args, **kwargs):
1684
+ return func(block, *args, **kwargs)
1685
+
1686
+ return forward
1687
+
1688
+ up_block.forward = resnet_wrapper(self.up_block_forward, up_block)
1689
+
1690
+ elif up_block.__class__.__name__ == "CrossAttnUpBlock2D":
1691
+
1692
+ def cross_attn_up_wrapper(func, block):
1693
+ def forward(*args, **kwargs):
1694
+ return func(block, *args, **kwargs)
1695
+
1696
+ return forward
1697
+
1698
+ up_block.forward = cross_attn_up_wrapper(self.cross_attn_up_block_forward, up_block)
1699
+
1700
+ # Deep Shrink
1701
+ self.ds_depth_1 = None
1702
+ self.ds_depth_2 = None
1703
+ self.ds_timesteps_1 = None
1704
+ self.ds_timesteps_2 = None
1705
+ self.ds_ratio = None
1706
+
1707
+ # call original model's methods
1708
+ def __getattr__(self, name):
1709
+ return getattr(self.delegate, name)
1710
+
1711
+ def __call__(self, *args, **kwargs):
1712
+ return self.delegate(*args, **kwargs)
1713
+
1714
+ def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
1715
+ if ds_depth_1 is None:
1716
+ logger.info("Deep Shrink is disabled.")
1717
+ self.ds_depth_1 = None
1718
+ self.ds_timesteps_1 = None
1719
+ self.ds_depth_2 = None
1720
+ self.ds_timesteps_2 = None
1721
+ self.ds_ratio = None
1722
+ else:
1723
+ logger.info(
1724
+ f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
1725
+ )
1726
+ self.ds_depth_1 = ds_depth_1
1727
+ self.ds_timesteps_1 = ds_timesteps_1
1728
+ self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
1729
+ self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
1730
+ self.ds_ratio = ds_ratio
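+ # Usage sketch (hedged; parameter values are illustrative):
+ #   wrapper = InferUNet2DConditionModel(unet)
+ #   wrapper.set_deep_shrink(ds_depth_1=3, ds_timesteps_1=650, ds_ratio=0.5)
+ #   # the sample is shrunk at down-block depth 3 while timestep >= 650, and the up blocks
+ #   # resize it back to the skip-connection resolution via resize_like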
1731
+
1732
+ def up_block_forward(self, _self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
1733
+ for resnet in _self.resnets:
1734
+ # pop res hidden states
1735
+ res_hidden_states = res_hidden_states_tuple[-1]
1736
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1737
+
1738
+ # Deep Shrink
1739
+ if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
1740
+ hidden_states = resize_like(hidden_states, res_hidden_states)
1741
+
1742
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1743
+ hidden_states = resnet(hidden_states, temb)
1744
+
1745
+ if _self.upsamplers is not None:
1746
+ for upsampler in _self.upsamplers:
1747
+ hidden_states = upsampler(hidden_states, upsample_size)
1748
+
1749
+ return hidden_states
1750
+
1751
+ def cross_attn_up_block_forward(
1752
+ self,
1753
+ _self,
1754
+ hidden_states,
1755
+ res_hidden_states_tuple,
1756
+ temb=None,
1757
+ encoder_hidden_states=None,
1758
+ upsample_size=None,
1759
+ ):
1760
+ for resnet, attn in zip(_self.resnets, _self.attentions):
1761
+ # pop res hidden states
1762
+ res_hidden_states = res_hidden_states_tuple[-1]
1763
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1764
+
1765
+ # Deep Shrink
1766
+ if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
1767
+ hidden_states = resize_like(hidden_states, res_hidden_states)
1768
+
1769
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1770
+ hidden_states = resnet(hidden_states, temb)
1771
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
1772
+
1773
+ if _self.upsamplers is not None:
1774
+ for upsampler in _self.upsamplers:
1775
+ hidden_states = upsampler(hidden_states, upsample_size)
1776
+
1777
+ return hidden_states
1778
+
1779
+ def forward(
1780
+ self,
1781
+ sample: torch.FloatTensor,
1782
+ timestep: Union[torch.Tensor, float, int],
1783
+ encoder_hidden_states: torch.Tensor,
1784
+ class_labels: Optional[torch.Tensor] = None,
1785
+ return_dict: bool = True,
1786
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1787
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
1788
+ ) -> Union[Dict, Tuple]:
1789
+ r"""
1790
+ The current implementation is a copy of `UNet2DConditionModel.forward()` with Deep Shrink added.
1791
+ """
1792
+
1793
+ r"""
1794
+ Args:
1795
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
1796
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
1797
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
1798
+ return_dict (`bool`, *optional*, defaults to `True`):
1799
+ Whether or not to return a dict instead of a plain tuple.
1800
+
1801
+ Returns:
1802
+ `SampleOutput` or `tuple`:
1803
+ `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
1804
+ """
1805
+
1806
+ _self = self.delegate
1807
+
1808
+ # By default samples have to be at least a multiple of the overall upsampling factor.
1809
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
1810
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
1811
+ # on the fly if necessary.
1812
+ # by default, the sample must be a multiple of 2 ** (number of upsamplers), i.e. 64
1813
+ # to also support other sizes, the upsample size is changed when necessary
1814
+ # image quality probably degrades, so keeping the size divisible by 64 is recommended
1815
+ default_overall_up_factor = 2**_self.num_upsamplers
1816
+
1817
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
1818
+ # when the size is not divisible by 64, pass the size to the upsamplers
1819
+ forward_upsample_size = False
1820
+ upsample_size = None
1821
+
1822
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
1823
+ # logger.info("Forward upsample size to force interpolation output size.")
1824
+ forward_upsample_size = True
1825
+
1826
+ # 1. time
1827
+ timesteps = timestep
1828
+ timesteps = _self.handle_unusual_timesteps(sample, timesteps)  # only handles the unusual cases
1829
+
1830
+ t_emb = _self.time_proj(timesteps)
1831
+
1832
+ # timesteps does not contain any weights and will always return f32 tensors
1833
+ # but time_embedding might actually be running in fp16. so we need to cast here.
1834
+ # there might be better ways to encapsulate this.
1835
+ # timesteps contains no weights, so it always returns float32 tensors
1836
+ # but time_embedding may be running in fp16, so we need to cast here
1837
+ # it might be simpler to do the cast inside time_proj instead
1838
+ t_emb = t_emb.to(dtype=_self.dtype)
1839
+ emb = _self.time_embedding(t_emb)
1840
+
1841
+ # 2. pre-process
1842
+ sample = _self.conv_in(sample)
1843
+
1844
+ down_block_res_samples = (sample,)
1845
+ for depth, downsample_block in enumerate(_self.down_blocks):
1846
+ # Deep Shrink
1847
+ if self.ds_depth_1 is not None:
1848
+ if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
1849
+ self.ds_depth_2 is not None
1850
+ and depth == self.ds_depth_2
1851
+ and timesteps[0] < self.ds_timesteps_1
1852
+ and timesteps[0] >= self.ds_timesteps_2
1853
+ ):
1854
+ org_dtype = sample.dtype
1855
+ if org_dtype == torch.bfloat16:
1856
+ sample = sample.to(torch.float32)
1857
+ sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
1858
+
1859
+ # the down blocks could always receive encoder_hidden_states in forward,
1860
+ # but this branching is probably easier to follow
1861
+ if downsample_block.has_cross_attention:
1862
+ sample, res_samples = downsample_block(
1863
+ hidden_states=sample,
1864
+ temb=emb,
1865
+ encoder_hidden_states=encoder_hidden_states,
1866
+ )
1867
+ else:
1868
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
1869
+
1870
+ down_block_res_samples += res_samples
1871
+
1872
+ # add the ControlNet outputs to the skip connections
1873
+ if down_block_additional_residuals is not None:
1874
+ down_block_res_samples = list(down_block_res_samples)
1875
+ for i in range(len(down_block_res_samples)):
1876
+ down_block_res_samples[i] += down_block_additional_residuals[i]
1877
+ down_block_res_samples = tuple(down_block_res_samples)
1878
+
1879
+ # 4. mid
1880
+ sample = _self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
1881
+
1882
+ # add the ControlNet output
1883
+ if mid_block_additional_residual is not None:
1884
+ sample += mid_block_additional_residual
1885
+
1886
+ # 5. up
1887
+ for i, upsample_block in enumerate(_self.up_blocks):
1888
+ is_final_block = i == len(_self.up_blocks) - 1
1889
+
1890
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1891
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection
1892
+
1893
+ # if we have not reached the final block and need to forward the upsample size, we do it here
1894
+ # as noted above, pass upsample_size to every block except the final one
1895
+ if not is_final_block and forward_upsample_size:
1896
+ upsample_size = down_block_res_samples[-1].shape[2:]
1897
+
1898
+ if upsample_block.has_cross_attention:
1899
+ sample = upsample_block(
1900
+ hidden_states=sample,
1901
+ temb=emb,
1902
+ res_hidden_states_tuple=res_samples,
1903
+ encoder_hidden_states=encoder_hidden_states,
1904
+ upsample_size=upsample_size,
1905
+ )
1906
+ else:
1907
+ sample = upsample_block(
1908
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
1909
+ )
1910
+
1911
+ # 6. post-process
1912
+ sample = _self.conv_norm_out(sample)
1913
+ sample = _self.conv_act(sample)
1914
+ sample = _self.conv_out(sample)
1915
+
1916
+ if not return_dict:
1917
+ return (sample,)
1918
+
1919
+ return SampleOutput(sample=sample)
library/sai_model_spec.py ADDED
@@ -0,0 +1,309 @@
1
+ # based on https://github.com/Stability-AI/ModelSpec
2
+ import datetime
3
+ import hashlib
4
+ from io import BytesIO
5
+ import os
6
+ from typing import List, Optional, Tuple, Union
7
+ import safetensors
8
+ from library.utils import setup_logging
9
+ setup_logging()
10
+ import logging
11
+ logger = logging.getLogger(__name__)
12
+
13
+ r"""
14
+ # Metadata Example
15
+ metadata = {
16
+ # === Must ===
17
+ "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
18
+ "modelspec.architecture": "stable-diffusion-xl-v1-base", # Architecture, reference the ID of the original model of the arch to match the ID
19
+ "modelspec.implementation": "sgm",
20
+ "modelspec.title": "Example Model Version 1.0", # Clean, human-readable title. May use your own phrasing/language/etc
21
+ # === Should ===
22
+ "modelspec.author": "Example Corp", # Your name or company name
23
+ "modelspec.description": "This is my example model to show you how to do it!", # Describe the model in your own words/language/etc. Focus on what users need to know
24
+ "modelspec.date": "2023-07-20", # ISO-8601 compliant date of when the model was created
25
+ # === Can ===
26
+ "modelspec.license": "ExampleLicense-1.0", # eg CreativeML Open RAIL, etc.
27
+ "modelspec.usage_hint": "Use keyword 'example'" # In your own language, very short hints about how the user should use the model
28
+ }
29
+ """
30
+
31
+ BASE_METADATA = {
32
+ # === Must ===
33
+ "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
34
+ "modelspec.architecture": None,
35
+ "modelspec.implementation": None,
36
+ "modelspec.title": None,
37
+ "modelspec.resolution": None,
38
+ # === Should ===
39
+ "modelspec.description": None,
40
+ "modelspec.author": None,
41
+ "modelspec.date": None,
42
+ # === Can ===
43
+ "modelspec.license": None,
44
+ "modelspec.tags": None,
45
+ "modelspec.merged_from": None,
46
+ "modelspec.prediction_type": None,
47
+ "modelspec.timestep_range": None,
48
+ "modelspec.encoder_layer": None,
49
+ }
50
+
51
+ # separately define only the keys that are actually used
52
+ MODELSPEC_TITLE = "modelspec.title"
53
+
54
+ ARCH_SD_V1 = "stable-diffusion-v1"
55
+ ARCH_SD_V2_512 = "stable-diffusion-v2-512"
56
+ ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
57
+ ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
58
+
59
+ ADAPTER_LORA = "lora"
60
+ ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
61
+
62
+ IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
63
+ IMPL_DIFFUSERS = "diffusers"
64
+
65
+ PRED_TYPE_EPSILON = "epsilon"
66
+ PRED_TYPE_V = "v"
67
+
68
+
69
+ def load_bytes_in_safetensors(tensors):
70
+ bytes = safetensors.torch.save(tensors)
71
+ b = BytesIO(bytes)
72
+
73
+ b.seek(0)
74
+ header = b.read(8)
75
+ n = int.from_bytes(header, "little")
76
+
77
+ offset = n + 8
78
+ b.seek(offset)
79
+
80
+ return b.read()
81
+
82
+
83
+ def precalculate_safetensors_hashes(state_dict):
84
+ # calculate each tensor one by one to reduce memory usage
85
+ hash_sha256 = hashlib.sha256()
86
+ for tensor in state_dict.values():
87
+ single_tensor_sd = {"tensor": tensor}
88
+ bytes_for_tensor = load_bytes_in_safetensors(single_tensor_sd)
89
+ hash_sha256.update(bytes_for_tensor)
90
+
91
+ return f"0x{hash_sha256.hexdigest()}"
92
+
93
+
94
+ def update_hash_sha256(metadata: dict, state_dict: dict):
95
+ raise NotImplementedError
96
+
97
+
98
+ def build_metadata(
99
+ state_dict: Optional[dict],
100
+ v2: bool,
101
+ v_parameterization: bool,
102
+ sdxl: bool,
103
+ lora: bool,
104
+ textual_inversion: bool,
105
+ timestamp: float,
106
+ title: Optional[str] = None,
107
+ reso: Optional[Union[int, Tuple[int, int]]] = None,
108
+ is_stable_diffusion_ckpt: Optional[bool] = None,
109
+ author: Optional[str] = None,
110
+ description: Optional[str] = None,
111
+ license: Optional[str] = None,
112
+ tags: Optional[str] = None,
113
+ merged_from: Optional[str] = None,
114
+ timesteps: Optional[Tuple[int, int]] = None,
115
+ clip_skip: Optional[int] = None,
116
+ ):
117
+ # if state_dict is None, hash is not calculated
118
+
119
+ metadata = {}
120
+ metadata.update(BASE_METADATA)
121
+
122
+ # TODO implement once a correct hash calculation method that does not consume extra memory is found
123
+ # if state_dict is not None:
124
+ # hash = precalculate_safetensors_hashes(state_dict)
125
+ # metadata["modelspec.hash_sha256"] = hash
126
+
127
+ if sdxl:
128
+ arch = ARCH_SD_XL_V1_BASE
129
+ elif v2:
130
+ if v_parameterization:
131
+ arch = ARCH_SD_V2_768_V
132
+ else:
133
+ arch = ARCH_SD_V2_512
134
+ else:
135
+ arch = ARCH_SD_V1
136
+
137
+ if lora:
138
+ arch += f"/{ADAPTER_LORA}"
139
+ elif textual_inversion:
140
+ arch += f"/{ADAPTER_TEXTUAL_INVERSION}"
141
+
142
+ metadata["modelspec.architecture"] = arch
143
+
144
+ if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
145
+ is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
146
+
147
+ if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
148
+ # Stable Diffusion ckpt, TI, SDXL LoRA
149
+ impl = IMPL_STABILITY_AI
150
+ else:
151
+ # v1/v2 LoRA or Diffusers
152
+ impl = IMPL_DIFFUSERS
153
+ metadata["modelspec.implementation"] = impl
154
+
155
+ if title is None:
156
+ if lora:
157
+ title = "LoRA"
158
+ elif textual_inversion:
159
+ title = "TextualInversion"
160
+ else:
161
+ title = "Checkpoint"
162
+ title += f"@{timestamp}"
163
+ metadata[MODELSPEC_TITLE] = title
164
+
165
+ if author is not None:
166
+ metadata["modelspec.author"] = author
167
+ else:
168
+ del metadata["modelspec.author"]
169
+
170
+ if description is not None:
171
+ metadata["modelspec.description"] = description
172
+ else:
173
+ del metadata["modelspec.description"]
174
+
175
+ if merged_from is not None:
176
+ metadata["modelspec.merged_from"] = merged_from
177
+ else:
178
+ del metadata["modelspec.merged_from"]
179
+
180
+ if license is not None:
181
+ metadata["modelspec.license"] = license
182
+ else:
183
+ del metadata["modelspec.license"]
184
+
185
+ if tags is not None:
186
+ metadata["modelspec.tags"] = tags
187
+ else:
188
+ del metadata["modelspec.tags"]
189
+
190
+ # remove microsecond from time
191
+ int_ts = int(timestamp)
192
+
193
+ # time to iso-8601 compliant date
194
+ date = datetime.datetime.fromtimestamp(int_ts).isoformat()
195
+ metadata["modelspec.date"] = date
196
+
197
+ if reso is not None:
198
+ # comma separated to tuple
199
+ if isinstance(reso, str):
200
+ reso = tuple(map(int, reso.split(",")))
201
+ if len(reso) == 1:
202
+ reso = (reso[0], reso[0])
203
+ else:
204
+ # resolution is defined in dataset, so use default
205
+ if sdxl:
206
+ reso = 1024
207
+ elif v2 and v_parameterization:
208
+ reso = 768
209
+ else:
210
+ reso = 512
211
+ if isinstance(reso, int):
212
+ reso = (reso, reso)
213
+
214
+ metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"
215
+
216
+ if v_parameterization:
217
+ metadata["modelspec.prediction_type"] = PRED_TYPE_V
218
+ else:
219
+ metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON
220
+
221
+ if timesteps is not None:
222
+ if isinstance(timesteps, str) or isinstance(timesteps, int):
223
+ timesteps = (timesteps, timesteps)
224
+ if len(timesteps) == 1:
225
+ timesteps = (timesteps[0], timesteps[0])
226
+ metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}"
227
+ else:
228
+ del metadata["modelspec.timestep_range"]
229
+
230
+ if clip_skip is not None:
231
+ metadata["modelspec.encoder_layer"] = f"{clip_skip}"
232
+ else:
233
+ del metadata["modelspec.encoder_layer"]
234
+
235
+ # # assert all values are filled
236
+ # assert all([v is not None for v in metadata.values()]), metadata
237
+ if not all([v is not None for v in metadata.values()]):
238
+ logger.error(f"Internal error: some metadata values are None: {metadata}")
239
+
240
+ return metadata
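+ # Usage sketch (hedged; argument values are illustrative):
+ #   import time
+ #   md = build_metadata(
+ #       None, v2=False, v_parameterization=False, sdxl=True, lora=True, textual_inversion=False,
+ #       timestamp=time.time(), title="my-lora", reso=1024,
+ #   )
+ #   # md["modelspec.architecture"] == "stable-diffusion-xl-v1-base/lora"
+ #   # md["modelspec.resolution"] == "1024x1024"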
241
+
242
+
243
+ # region utils
244
+
245
+
246
+ def get_title(metadata: dict) -> Optional[str]:
247
+ return metadata.get(MODELSPEC_TITLE, None)
248
+
249
+
250
+ def load_metadata_from_safetensors(model: str) -> dict:
251
+ if not model.endswith(".safetensors"):
252
+ return {}
253
+
254
+ with safetensors.safe_open(model, framework="pt") as f:
255
+ metadata = f.metadata()
256
+ if metadata is None:
257
+ metadata = {}
258
+ return metadata
259
+
260
+
261
+ def build_merged_from(models: List[str]) -> str:
262
+ def get_title(model: str):
263
+ metadata = load_metadata_from_safetensors(model)
264
+ title = metadata.get(MODELSPEC_TITLE, None)
265
+ if title is None:
266
+ title = os.path.splitext(os.path.basename(model))[0] # use filename
267
+ return title
268
+
269
+ titles = [get_title(model) for model in models]
270
+ return ", ".join(titles)
271
+
272
+
273
+ # endregion
274
+
275
+
276
+ r"""
277
+ if __name__ == "__main__":
278
+ import argparse
279
+ import torch
280
+ from safetensors.torch import load_file
281
+ from library import train_util
282
+
283
+ parser = argparse.ArgumentParser()
284
+ parser.add_argument("--ckpt", type=str, required=True)
285
+ args = parser.parse_args()
286
+
287
+ print(f"Loading {args.ckpt}")
288
+ state_dict = load_file(args.ckpt)
289
+
290
+ print(f"Calculating metadata")
291
+ metadata = get(state_dict, False, False, False, False, "sgm", False, False, "title", "date", 256, 1000, 0)
292
+ print(metadata)
293
+ del state_dict
294
+
295
+ # by reference implementation
296
+ with open(args.ckpt, mode="rb") as file_data:
297
+ file_hash = hashlib.sha256()
298
+ head_len = struct.unpack("Q", file_data.read(8)) # int64 header length prefix
299
+ header = json.loads(file_data.read(head_len[0])) # header itself, json string
300
+ content = (
301
+ file_data.read()
302
+ ) # All other content is tightly packed tensors. Copy to RAM for simplicity, but you can avoid this read with a more careful FS-dependent impl.
303
+ file_hash.update(content)
304
+ # ===== Update the hash for modelspec =====
305
+ by_ref = f"0x{file_hash.hexdigest()}"
306
+ print(by_ref)
307
+ print("is same?", by_ref == metadata["modelspec.hash_sha256"])
308
+
309
+ """
library/sdxl_lpw_stable_diffusion.py ADDED
@@ -0,0 +1,1347 @@
1
+ # copy from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py
2
+ # and modify to support SD2.x
3
+
4
+ import inspect
5
+ import re
6
+ from typing import Callable, List, Optional, Union
7
+
8
+ import numpy as np
9
+ import PIL.Image
10
+ import torch
11
+ from packaging import version
12
+ from tqdm import tqdm
13
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
14
+
15
+ from diffusers import SchedulerMixin, StableDiffusionPipeline
16
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
17
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
18
+ from diffusers.utils import logging
19
+ from PIL import Image
20
+
21
+ from library import sdxl_model_util, sdxl_train_util, train_util
22
+
23
+
24
+ try:
25
+ from diffusers.utils import PIL_INTERPOLATION
26
+ except ImportError:
27
+ if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
28
+ PIL_INTERPOLATION = {
29
+ "linear": PIL.Image.Resampling.BILINEAR,
30
+ "bilinear": PIL.Image.Resampling.BILINEAR,
31
+ "bicubic": PIL.Image.Resampling.BICUBIC,
32
+ "lanczos": PIL.Image.Resampling.LANCZOS,
33
+ "nearest": PIL.Image.Resampling.NEAREST,
34
+ }
35
+ else:
36
+ PIL_INTERPOLATION = {
37
+ "linear": PIL.Image.LINEAR,
38
+ "bilinear": PIL.Image.BILINEAR,
39
+ "bicubic": PIL.Image.BICUBIC,
40
+ "lanczos": PIL.Image.LANCZOS,
41
+ "nearest": PIL.Image.NEAREST,
42
+ }
43
+ # ------------------------------------------------------------------------------
44
+
45
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
46
+
47
+ re_attention = re.compile(
48
+ r"""
49
+ \\\(|
50
+ \\\)|
51
+ \\\[|
52
+ \\]|
53
+ \\\\|
54
+ \\|
55
+ \(|
56
+ \[|
57
+ :([+-]?[.\d]+)\)|
58
+ \)|
59
+ ]|
60
+ [^\\()\[\]:]+|
61
+ :
62
+ """,
63
+ re.X,
64
+ )
65
+
66
+
67
+ def parse_prompt_attention(text):
68
+ """
69
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
70
+ Accepted tokens are:
71
+ (abc) - increases attention to abc by a multiplier of 1.1
72
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
73
+ [abc] - decreases attention to abc by a multiplier of 1.1
74
+ \( - literal character '('
75
+ \[ - literal character '['
76
+ \) - literal character ')'
77
+ \] - literal character ']'
78
+ \\ - literal character '\'
79
+ anything else - just text
80
+ >>> parse_prompt_attention('normal text')
81
+ [['normal text', 1.0]]
82
+ >>> parse_prompt_attention('an (important) word')
83
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
84
+ >>> parse_prompt_attention('(unbalanced')
85
+ [['unbalanced', 1.1]]
86
+ >>> parse_prompt_attention('\(literal\]')
87
+ [['(literal]', 1.0]]
88
+ >>> parse_prompt_attention('(unnecessary)(parens)')
89
+ [['unnecessaryparens', 1.1]]
90
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
91
+ [['a ', 1.0],
92
+ ['house', 1.5730000000000004],
93
+ [' ', 1.1],
94
+ ['on', 1.0],
95
+ [' a ', 1.1],
96
+ ['hill', 0.55],
97
+ [', sun, ', 1.1],
98
+ ['sky', 1.4641000000000006],
99
+ ['.', 1.1]]
100
+ """
101
+
102
+ res = []
103
+ round_brackets = []
104
+ square_brackets = []
105
+
106
+ round_bracket_multiplier = 1.1
107
+ square_bracket_multiplier = 1 / 1.1
108
+
109
+ def multiply_range(start_position, multiplier):
110
+ for p in range(start_position, len(res)):
111
+ res[p][1] *= multiplier
112
+
113
+ for m in re_attention.finditer(text):
114
+ text = m.group(0)
115
+ weight = m.group(1)
116
+
117
+ if text.startswith("\\"):
118
+ res.append([text[1:], 1.0])
119
+ elif text == "(":
120
+ round_brackets.append(len(res))
121
+ elif text == "[":
122
+ square_brackets.append(len(res))
123
+ elif weight is not None and len(round_brackets) > 0:
124
+ multiply_range(round_brackets.pop(), float(weight))
125
+ elif text == ")" and len(round_brackets) > 0:
126
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
127
+ elif text == "]" and len(square_brackets) > 0:
128
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
129
+ else:
130
+ res.append([text, 1.0])
131
+
132
+ for pos in round_brackets:
133
+ multiply_range(pos, round_bracket_multiplier)
134
+
135
+ for pos in square_brackets:
136
+ multiply_range(pos, square_bracket_multiplier)
137
+
138
+ if len(res) == 0:
139
+ res = [["", 1.0]]
140
+
141
+ # merge runs of identical weights
142
+ i = 0
143
+ while i + 1 < len(res):
144
+ if res[i][1] == res[i + 1][1]:
145
+ res[i][0] += res[i + 1][0]
146
+ res.pop(i + 1)
147
+ else:
148
+ i += 1
149
+
150
+ return res
151
+
152
+
153
+ def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
154
+ r"""
155
+ Tokenize a list of prompts and return its tokens with weights of each token.
156
+
157
+ No padding, starting or ending token is included.
158
+ """
159
+ tokens = []
160
+ weights = []
161
+ truncated = False
162
+ for text in prompt:
163
+ texts_and_weights = parse_prompt_attention(text)
164
+ text_token = []
165
+ text_weight = []
166
+ for word, weight in texts_and_weights:
167
+ # tokenize and discard the starting and the ending token
168
+ token = pipe.tokenizer(word).input_ids[1:-1]
169
+ text_token += token
170
+ # copy the weight by length of token
171
+ text_weight += [weight] * len(token)
172
+ # stop if the text is too long (longer than truncation limit)
173
+ if len(text_token) > max_length:
174
+ truncated = True
175
+ break
176
+ # truncate
177
+ if len(text_token) > max_length:
178
+ truncated = True
179
+ text_token = text_token[:max_length]
180
+ text_weight = text_weight[:max_length]
181
+ tokens.append(text_token)
182
+ weights.append(text_weight)
183
+ if truncated:
184
+ logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
185
+ return tokens, weights
186
+
187
+
188
+ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
189
+ r"""
190
+ Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
191
+ """
192
+ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
193
+ weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
194
+ for i in range(len(tokens)):
195
+ tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i]))
196
+ if no_boseos_middle:
197
+ weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
198
+ else:
199
+ w = []
200
+ if len(weights[i]) == 0:
201
+ w = [1.0] * weights_length
202
+ else:
203
+ for j in range(max_embeddings_multiples):
204
+ w.append(1.0) # weight for starting token in this chunk
205
+ w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
206
+ w.append(1.0) # weight for ending token in this chunk
207
+ w += [1.0] * (weights_length - len(w))
208
+ weights[i] = w[:]
209
+
210
+ return tokens, weights
211
+
212
+
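+ # Illustrative example (added for clarity, not part of the original file). With CLIP's
+ # chunk_length of 77, each chunk carries 75 content tokens plus BOS/EOS, so
+ # max_length = 75 * max_embeddings_multiples + 2. The ids 49406/49407 below are the usual
+ # CLIP BOS/EOS ids and are only used for illustration.
+ def _example_pad_tokens_and_weights():  # hypothetical helper, never called by the library code
+     tokens, weights = pad_tokens_and_weights(
+         [[320, 1125]], [[1.0, 1.2]], max_length=77, bos=49406, eos=49407, pad=49407
+     )
+     # tokens[0] == [49406, 320, 1125, 49407] + [49407] * 73   (length 77)
+     # weights[0] == [1.0, 1.0, 1.2] + [1.0] * 74              (length 77)
+     return tokens, weights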
213
+ def get_hidden_states(text_encoder, input_ids, is_sdxl_text_encoder2: bool, eos_token_id, device):
214
+ if not is_sdxl_text_encoder2:
215
+ # text_encoder1: same as SD1/2
216
+ enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True)
217
+ hidden_states = enc_out["hidden_states"][11]
218
+ pool = None
219
+ else:
220
+ # text_encoder2
221
+ enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True)
222
+ hidden_states = enc_out["hidden_states"][-2] # penuultimate layer
223
+ # pool = enc_out["text_embeds"]
224
+ pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], input_ids, eos_token_id)
225
+ hidden_states = hidden_states.to(device)
226
+ if pool is not None:
227
+ pool = pool.to(device)
228
+ return hidden_states, pool
229
+
230
+
231
+ def get_unweighted_text_embeddings(
232
+ pipe: StableDiffusionPipeline,
233
+ text_input: torch.Tensor,
234
+ chunk_length: int,
235
+ clip_skip: int,
236
+ eos: int,
237
+ pad: int,
238
+ is_sdxl_text_encoder2: bool,
239
+ no_boseos_middle: Optional[bool] = True,
240
+ ):
241
+ """
242
+ When the length of tokens is a multiple of the capacity of the text encoder,
243
+ it should be split into chunks and sent to the text encoder individually.
244
+ """
245
+ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
246
+ text_pool = None
247
+ if max_embeddings_multiples > 1:
248
+ text_embeddings = []
249
+ for i in range(max_embeddings_multiples):
250
+ # extract the i-th chunk
251
+ text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
252
+
253
+ # cover the head and the tail by the starting and the ending tokens
254
+ text_input_chunk[:, 0] = text_input[0, 0]
255
+ if pad == eos: # v1
256
+ text_input_chunk[:, -1] = text_input[0, -1]
257
+ else: # v2
258
+ for j in range(len(text_input_chunk)):
259
+ if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad:  # the last token is a regular token, not EOS or PAD
260
+ text_input_chunk[j, -1] = eos
261
+ if text_input_chunk[j, 1] == pad:  # the chunk contains only BOS followed by PAD
262
+ text_input_chunk[j, 1] = eos
263
+
264
+ text_embedding, current_text_pool = get_hidden_states(
265
+ pipe.text_encoder, text_input_chunk, is_sdxl_text_encoder2, eos, pipe.device
266
+ )
267
+ if text_pool is None:
268
+ text_pool = current_text_pool
269
+
270
+ if no_boseos_middle:
271
+ if i == 0:
272
+ # discard the ending token
273
+ text_embedding = text_embedding[:, :-1]
274
+ elif i == max_embeddings_multiples - 1:
275
+ # discard the starting token
276
+ text_embedding = text_embedding[:, 1:]
277
+ else:
278
+ # discard both starting and ending tokens
279
+ text_embedding = text_embedding[:, 1:-1]
280
+
281
+ text_embeddings.append(text_embedding)
282
+ text_embeddings = torch.concat(text_embeddings, axis=1)
283
+ else:
284
+ text_embeddings, text_pool = get_hidden_states(pipe.text_encoder, text_input, is_sdxl_text_encoder2, eos, pipe.device)
285
+ return text_embeddings, text_pool
286
+
287
+
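+ # Note (added for clarity, not part of the original file): with chunk_length=77 an input of
+ # length 75 * m + 2 is split into m windows of 77 tokens whose first/last positions are
+ # overwritten with BOS/EOS. When no_boseos_middle is True, the intermediate BOS/EOS
+ # embeddings are dropped, so the concatenated result has 76 + 75 * (m - 2) + 76 = 75 * m + 2
+ # positions, matching the padded input length.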
288
+ def get_weighted_text_embeddings(
289
+ pipe, # : SdxlStableDiffusionLongPromptWeightingPipeline,
290
+ prompt: Union[str, List[str]],
291
+ uncond_prompt: Optional[Union[str, List[str]]] = None,
292
+ max_embeddings_multiples: Optional[int] = 3,
293
+ no_boseos_middle: Optional[bool] = False,
294
+ skip_parsing: Optional[bool] = False,
295
+ skip_weighting: Optional[bool] = False,
296
+ clip_skip=None,
297
+ is_sdxl_text_encoder2=False,
298
+ ):
299
+ r"""
300
+ Prompts can be assigned with local weights using brackets. For example,
301
+ prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
302
+ and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
303
+
304
+ Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
305
+
306
+ Args:
307
+ pipe (`StableDiffusionPipeline`):
308
+ Pipe to provide access to the tokenizer and the text encoder.
309
+ prompt (`str` or `List[str]`):
310
+ The prompt or prompts to guide the image generation.
311
+ uncond_prompt (`str` or `List[str]`):
312
+ The unconditional prompt or prompts (negative prompt). If an unconditional prompt
313
+ is provided, the embeddings of prompt and uncond_prompt are concatenated.
314
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
315
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
316
+ no_boseos_middle (`bool`, *optional*, defaults to `False`):
317
+ If the length of the text tokens is a multiple of the capacity of the text encoder, whether to keep the starting and
318
+ ending token in each of the middle chunks.
319
+ skip_parsing (`bool`, *optional*, defaults to `False`):
320
+ Skip the parsing of brackets.
321
+ skip_weighting (`bool`, *optional*, defaults to `False`):
322
+ Skip the weighting. When parsing is skipped, this is forced to True.
323
+ """
324
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
325
+ if isinstance(prompt, str):
326
+ prompt = [prompt]
327
+
328
+ if not skip_parsing:
329
+ prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
330
+ if uncond_prompt is not None:
331
+ if isinstance(uncond_prompt, str):
332
+ uncond_prompt = [uncond_prompt]
333
+ uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
334
+ else:
335
+ prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
336
+ prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
337
+ if uncond_prompt is not None:
338
+ if isinstance(uncond_prompt, str):
339
+ uncond_prompt = [uncond_prompt]
340
+ uncond_tokens = [
341
+ token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
342
+ ]
343
+ uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
344
+
345
+ # round up the longest length of tokens to a multiple of (model_max_length - 2)
346
+ max_length = max([len(token) for token in prompt_tokens])
347
+ if uncond_prompt is not None:
348
+ max_length = max(max_length, max([len(token) for token in uncond_tokens]))
349
+
350
+ max_embeddings_multiples = min(
351
+ max_embeddings_multiples,
352
+ (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
353
+ )
354
+ max_embeddings_multiples = max(1, max_embeddings_multiples)
355
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
356
+
357
+ # pad the length of tokens and weights
358
+ bos = pipe.tokenizer.bos_token_id
359
+ eos = pipe.tokenizer.eos_token_id
360
+ pad = pipe.tokenizer.pad_token_id
361
+ prompt_tokens, prompt_weights = pad_tokens_and_weights(
362
+ prompt_tokens,
363
+ prompt_weights,
364
+ max_length,
365
+ bos,
366
+ eos,
367
+ pad,
368
+ no_boseos_middle=no_boseos_middle,
369
+ chunk_length=pipe.tokenizer.model_max_length,
370
+ )
371
+ prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
372
+ if uncond_prompt is not None:
373
+ uncond_tokens, uncond_weights = pad_tokens_and_weights(
374
+ uncond_tokens,
375
+ uncond_weights,
376
+ max_length,
377
+ bos,
378
+ eos,
379
+ pad,
380
+ no_boseos_middle=no_boseos_middle,
381
+ chunk_length=pipe.tokenizer.model_max_length,
382
+ )
383
+ uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
384
+
385
+ # get the embeddings
386
+ text_embeddings, text_pool = get_unweighted_text_embeddings(
387
+ pipe,
388
+ prompt_tokens,
389
+ pipe.tokenizer.model_max_length,
390
+ clip_skip,
391
+ eos,
392
+ pad,
393
+ is_sdxl_text_encoder2,
394
+ no_boseos_middle=no_boseos_middle,
395
+ )
396
+ prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
397
+
398
+ if uncond_prompt is not None:
399
+ uncond_embeddings, uncond_pool = get_unweighted_text_embeddings(
400
+ pipe,
401
+ uncond_tokens,
402
+ pipe.tokenizer.model_max_length,
403
+ clip_skip,
404
+ eos,
405
+ pad,
406
+ is_sdxl_text_encoder2,
407
+ no_boseos_middle=no_boseos_middle,
408
+ )
409
+ uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
410
+
411
+ # assign weights to the prompts and normalize in the sense of mean
412
+ # TODO: should we normalize by chunk or in a whole (current implementation)?
413
+ if (not skip_parsing) and (not skip_weighting):
414
+ previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
415
+ text_embeddings *= prompt_weights.unsqueeze(-1)
416
+ current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
417
+ text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
418
+ if uncond_prompt is not None:
419
+ previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
420
+ uncond_embeddings *= uncond_weights.unsqueeze(-1)
421
+ current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
422
+ uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
423
+
424
+ if uncond_prompt is not None:
425
+ return text_embeddings, text_pool, uncond_embeddings, uncond_pool
426
+ return text_embeddings, text_pool, None, None
427
+
428
+
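+ # Note (added for clarity, not part of the original file): the weighting step in
+ # get_weighted_text_embeddings computes e' = e * w and then rescales it by
+ # mean(e) / mean(e'), so the overall magnitude of the embedding is preserved and only the
+ # relative emphasis between tokens changes.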
429
+ def preprocess_image(image):
430
+ w, h = image.size
431
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
432
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
433
+ image = np.array(image).astype(np.float32) / 255.0
434
+ image = image[None].transpose(0, 3, 1, 2)
435
+ image = torch.from_numpy(image)
436
+ return 2.0 * image - 1.0
437
+
438
+
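+ # Illustrative usage (added for clarity, not part of the original file): preprocess_image
+ # crops width/height down to a multiple of 32 and maps pixel values from [0, 255] to [-1, 1].
+ # PIL is assumed to be imported at the top of this module (it is used elsewhere in the file).
+ def _example_preprocess_image():  # hypothetical helper, never called by the library code
+     img = PIL.Image.new("RGB", (513, 770), (255, 255, 255))  # plain white test image
+     t = preprocess_image(img)
+     # t.shape == (1, 3, 768, 512); white pixels map to 1.0, black pixels would map to -1.0
+     return t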
439
+ def preprocess_mask(mask, scale_factor=8):
440
+ mask = mask.convert("L")
441
+ w, h = mask.size
442
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
443
+ mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
444
+ mask = np.array(mask).astype(np.float32) / 255.0
445
+ mask = np.tile(mask, (4, 1, 1))
446
+ mask = mask[None].transpose(0, 1, 2, 3)  # add a batch dimension; the (0, 1, 2, 3) transpose is a no-op kept from the original implementation
447
+ mask = 1 - mask # repaint white, keep black
448
+ mask = torch.from_numpy(mask)
449
+ return mask
450
+
451
+
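+ # Note (added for clarity, not part of the original file): preprocess_mask downsamples the
+ # mask to latent resolution (1/8 by default), tiles it across the 4 latent channels, and
+ # inverts it so white regions (repaint) become 0 and black regions (keep) become 1. The
+ # result has shape (1, 4, h // scale_factor, w // scale_factor).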
452
+ def prepare_controlnet_image(
453
+ image: PIL.Image.Image,
454
+ width: int,
455
+ height: int,
456
+ batch_size: int,
457
+ num_images_per_prompt: int,
458
+ device: torch.device,
459
+ dtype: torch.dtype,
460
+ do_classifier_free_guidance: bool = False,
461
+ guess_mode: bool = False,
462
+ ):
463
+ if not isinstance(image, torch.Tensor):
464
+ if isinstance(image, PIL.Image.Image):
465
+ image = [image]
466
+
467
+ if isinstance(image[0], PIL.Image.Image):
468
+ images = []
469
+
470
+ for image_ in image:
471
+ image_ = image_.convert("RGB")
472
+ image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
473
+ image_ = np.array(image_)
474
+ image_ = image_[None, :]
475
+ images.append(image_)
476
+
477
+ image = images
478
+
479
+ image = np.concatenate(image, axis=0)
480
+ image = np.array(image).astype(np.float32) / 255.0
481
+ image = image.transpose(0, 3, 1, 2)
482
+ image = torch.from_numpy(image)
483
+ elif isinstance(image[0], torch.Tensor):
484
+ image = torch.cat(image, dim=0)
485
+
486
+ image_batch_size = image.shape[0]
487
+
488
+ if image_batch_size == 1:
489
+ repeat_by = batch_size
490
+ else:
491
+ # image batch size is the same as prompt batch size
492
+ repeat_by = num_images_per_prompt
493
+
494
+ image = image.repeat_interleave(repeat_by, dim=0)
495
+
496
+ image = image.to(device=device, dtype=dtype)
497
+
498
+ if do_classifier_free_guidance and not guess_mode:
499
+ image = torch.cat([image] * 2)
500
+
501
+ return image
502
+
503
+
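+ # Note (added for clarity, not part of the original file): prepare_controlnet_image converts
+ # the conditioning image(s) to float tensors in [0, 1], repeats them to match the batch size,
+ # and, when classifier-free guidance is active and guess mode is off, duplicates the batch so
+ # the same conditioning is used for both the unconditional and the conditional pass.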
504
+ class SdxlStableDiffusionLongPromptWeightingPipeline:
505
+ r"""
506
+ Pipeline for text-to-image generation using Stable Diffusion without a token length limit, with support for parsing
507
+ weighting in the prompt.
508
+
509
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
510
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
511
+
512
+ Args:
513
+ vae ([`AutoencoderKL`]):
514
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
515
+ text_encoder ([`CLIPTextModel`]):
516
+ Frozen text-encoder. Stable Diffusion uses the text portion of
517
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
518
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
519
+ tokenizer (`CLIPTokenizer`):
520
+ Tokenizer of class
521
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
522
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
523
+ scheduler ([`SchedulerMixin`]):
524
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
525
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
526
+ safety_checker ([`StableDiffusionSafetyChecker`]):
527
+ Classification module that estimates whether generated images could be considered offensive or harmful.
528
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
529
+ feature_extractor ([`CLIPFeatureExtractor`]):
530
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
531
+ """
532
+
533
+ # if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
534
+
535
+ def __init__(
536
+ self,
537
+ vae: AutoencoderKL,
538
+ text_encoder: List[CLIPTextModel],
539
+ tokenizer: List[CLIPTokenizer],
540
+ unet: UNet2DConditionModel,
541
+ scheduler: SchedulerMixin,
542
+ # clip_skip: int,
543
+ safety_checker: StableDiffusionSafetyChecker,
544
+ feature_extractor: CLIPFeatureExtractor,
545
+ requires_safety_checker: bool = True,
546
+ clip_skip: int = 1,
547
+ ):
548
+ # clip skip is ignored currently
549
+ self.tokenizer = tokenizer[0]
550
+ self.text_encoder = text_encoder[0]
551
+ self.unet = unet
552
+ self.scheduler = scheduler
553
+ self.safety_checker = safety_checker
554
+ self.feature_extractor = feature_extractor
555
+ self.requires_safety_checker = requires_safety_checker
556
+ self.vae = vae
557
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
558
+ self.progress_bar = lambda x: tqdm(x, leave=False)
559
+
560
+ self.clip_skip = clip_skip
561
+ self.tokenizers = tokenizer
562
+ self.text_encoders = text_encoder
563
+
564
+ # self.__init__additional__()
565
+
566
+ # def __init__additional__(self):
567
+ # if not hasattr(self, "vae_scale_factor"):
568
+ # setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
569
+
570
+ def to(self, device=None, dtype=None):
571
+ if device is not None:
572
+ self.device = device
573
+ # self.vae.to(device=self.device)
574
+ if dtype is not None:
575
+ self.dtype = dtype
576
+
577
+ # do not move Text Encoders to device, because Text Encoder should be on CPU
578
+
579
+ @property
580
+ def _execution_device(self):
581
+ r"""
582
+ Returns the device on which the pipeline's models will be executed. After calling
583
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
584
+ hooks.
585
+ """
586
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
587
+ return self.device
588
+ for module in self.unet.modules():
589
+ if (
590
+ hasattr(module, "_hf_hook")
591
+ and hasattr(module._hf_hook, "execution_device")
592
+ and module._hf_hook.execution_device is not None
593
+ ):
594
+ return torch.device(module._hf_hook.execution_device)
595
+ return self.device
596
+
597
+ def _encode_prompt(
598
+ self,
599
+ prompt,
600
+ device,
601
+ num_images_per_prompt,
602
+ do_classifier_free_guidance,
603
+ negative_prompt,
604
+ max_embeddings_multiples,
605
+ is_sdxl_text_encoder2,
606
+ ):
607
+ r"""
608
+ Encodes the prompt into text encoder hidden states.
609
+
610
+ Args:
611
+ prompt (`str` or `list(int)`):
612
+ prompt to be encoded
613
+ device: (`torch.device`):
614
+ torch device
615
+ num_images_per_prompt (`int`):
616
+ number of images that should be generated per prompt
617
+ do_classifier_free_guidance (`bool`):
618
+ whether to use classifier free guidance or not
619
+ negative_prompt (`str` or `List[str]`):
620
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
621
+ if `guidance_scale` is less than `1`).
622
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
623
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
624
+ """
625
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
626
+
627
+ if negative_prompt is None:
628
+ negative_prompt = [""] * batch_size
629
+ elif isinstance(negative_prompt, str):
630
+ negative_prompt = [negative_prompt] * batch_size
631
+ if batch_size != len(negative_prompt):
632
+ raise ValueError(
633
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
634
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
635
+ " the batch size of `prompt`."
636
+ )
637
+
638
+ text_embeddings, text_pool, uncond_embeddings, uncond_pool = get_weighted_text_embeddings(
639
+ pipe=self,
640
+ prompt=prompt,
641
+ uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
642
+ max_embeddings_multiples=max_embeddings_multiples,
643
+ clip_skip=self.clip_skip,
644
+ is_sdxl_text_encoder2=is_sdxl_text_encoder2,
645
+ )
646
+ bs_embed, seq_len, _ = text_embeddings.shape
647
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)  # duplicate each prompt's embeddings num_images_per_prompt times
648
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
649
+ if text_pool is not None:
650
+ text_pool = text_pool.repeat(1, num_images_per_prompt)
651
+ text_pool = text_pool.view(bs_embed * num_images_per_prompt, -1)
652
+
653
+ if do_classifier_free_guidance:
654
+ bs_embed, seq_len, _ = uncond_embeddings.shape
655
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
656
+ uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
657
+ if uncond_pool is not None:
658
+ uncond_pool = uncond_pool.repeat(1, num_images_per_prompt)
659
+ uncond_pool = uncond_pool.view(bs_embed * num_images_per_prompt, -1)
660
+
661
+ return text_embeddings, text_pool, uncond_embeddings, uncond_pool
662
+
663
+ return text_embeddings, text_pool, None, None
664
+
665
+ def check_inputs(self, prompt, height, width, strength, callback_steps):
666
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
667
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
668
+
669
+ if strength < 0 or strength > 1:
670
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
671
+
672
+ if height % 8 != 0 or width % 8 != 0:
673
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
674
+
675
+ if (callback_steps is None) or (
676
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
677
+ ):
678
+ raise ValueError(
679
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}."
680
+ )
681
+
682
+ def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
683
+ if is_text2img:
684
+ return self.scheduler.timesteps.to(device), num_inference_steps
685
+ else:
686
+ # get the original timestep using init_timestep
687
+ offset = self.scheduler.config.get("steps_offset", 0)
688
+ init_timestep = int(num_inference_steps * strength) + offset
689
+ init_timestep = min(init_timestep, num_inference_steps)
690
+
691
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
692
+ timesteps = self.scheduler.timesteps[t_start:].to(device)
693
+ return timesteps, num_inference_steps - t_start
694
+
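+ # Worked example (added for clarity, not part of the original file): for img2img with
+ # num_inference_steps=50, strength=0.8 and steps_offset=0, init_timestep = int(50 * 0.8) = 40
+ # and t_start = 10, so only the last 40 of the 50 scheduler timesteps are run. strength=1.0
+ # runs the full schedule and effectively ignores the init image.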
695
+ def run_safety_checker(self, image, device, dtype):
696
+ if self.safety_checker is not None:
697
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
698
+ image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype))
699
+ else:
700
+ has_nsfw_concept = None
701
+ return image, has_nsfw_concept
702
+
703
+ def decode_latents(self, latents):
704
+ with torch.no_grad():
705
+ latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents
706
+
707
+ # print("post_quant_conv dtype:", self.vae.post_quant_conv.weight.dtype) # torch.float32
708
+ # x = torch.nn.functional.conv2d(latents, self.vae.post_quant_conv.weight.detach(), stride=1, padding=0)
709
+ # print("latents dtype:", latents.dtype, "x dtype:", x.dtype) # torch.float32, torch.float16
710
+ # self.vae.to("cpu")
711
+ # self.vae.set_use_memory_efficient_attention_xformers(False)
712
+ # image = self.vae.decode(latents.to("cpu")).sample
713
+
714
+ image = self.vae.decode(latents.to(self.vae.dtype)).sample
715
+ image = (image / 2 + 0.5).clamp(0, 1)
716
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
717
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
718
+ return image
719
+
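+ # Note (added for clarity, not part of the original file): SDXL latents are stored scaled by
+ # sdxl_model_util.VAE_SCALE_FACTOR (0.13025), so they are divided by that factor before the
+ # VAE decode, and the decoded [-1, 1] image is mapped back to [0, 1] before conversion to numpy.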
720
+ def prepare_extra_step_kwargs(self, generator, eta):
721
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
722
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
723
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
724
+ # and should be between [0, 1]
725
+
726
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
727
+ extra_step_kwargs = {}
728
+ if accepts_eta:
729
+ extra_step_kwargs["eta"] = eta
730
+
731
+ # check if the scheduler accepts generator
732
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
733
+ if accepts_generator:
734
+ extra_step_kwargs["generator"] = generator
735
+ return extra_step_kwargs
736
+
737
+ def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None):
738
+ if image is None:
739
+ shape = (
740
+ batch_size,
741
+ self.unet.in_channels,
742
+ height // self.vae_scale_factor,
743
+ width // self.vae_scale_factor,
744
+ )
745
+
746
+ if latents is None:
747
+ if device.type == "mps":
748
+ # randn does not work reproducibly on mps
749
+ latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
750
+ else:
751
+ latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
752
+ else:
753
+ if latents.shape != shape:
754
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
755
+ latents = latents.to(device)
756
+
757
+ # scale the initial noise by the standard deviation required by the scheduler
758
+ latents = latents * self.scheduler.init_noise_sigma
759
+ return latents, None, None
760
+ else:
761
+ init_latent_dist = self.vae.encode(image).latent_dist
762
+ init_latents = init_latent_dist.sample(generator=generator)
763
+ init_latents = sdxl_model_util.VAE_SCALE_FACTOR * init_latents
764
+ init_latents = torch.cat([init_latents] * batch_size, dim=0)
765
+ init_latents_orig = init_latents
766
+ shape = init_latents.shape
767
+
768
+ # add noise to latents using the timesteps
769
+ if device.type == "mps":
770
+ noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
771
+ else:
772
+ noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
773
+ latents = self.scheduler.add_noise(init_latents, noise, timestep)
774
+ return latents, init_latents_orig, noise
775
+
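+ # Note (added for clarity, not part of the original file): for text2img (image is None) fresh
+ # Gaussian noise is scaled by scheduler.init_noise_sigma; for img2img/inpaint the init image
+ # is encoded by the VAE, scaled by VAE_SCALE_FACTOR, and noised to the starting timestep via
+ # scheduler.add_noise, while the clean latents are kept so the mask can restore unpainted areas.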
776
+ @torch.no_grad()
777
+ def __call__(
778
+ self,
779
+ prompt: Union[str, List[str]],
780
+ negative_prompt: Optional[Union[str, List[str]]] = None,
781
+ image: Union[torch.FloatTensor, PIL.Image.Image] = None,
782
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
783
+ height: int = 512,
784
+ width: int = 512,
785
+ num_inference_steps: int = 50,
786
+ guidance_scale: float = 7.5,
787
+ strength: float = 0.8,
788
+ num_images_per_prompt: Optional[int] = 1,
789
+ eta: float = 0.0,
790
+ generator: Optional[torch.Generator] = None,
791
+ latents: Optional[torch.FloatTensor] = None,
792
+ max_embeddings_multiples: Optional[int] = 3,
793
+ output_type: Optional[str] = "pil",
794
+ return_dict: bool = True,
795
+ controlnet=None,
796
+ controlnet_image=None,
797
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
798
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
799
+ callback_steps: int = 1,
800
+ ):
801
+ r"""
802
+ Function invoked when calling the pipeline for generation.
803
+
804
+ Args:
805
+ prompt (`str` or `List[str]`):
806
+ The prompt or prompts to guide the image generation.
807
+ negative_prompt (`str` or `List[str]`, *optional*):
808
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
809
+ if `guidance_scale` is less than `1`).
810
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
811
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
812
+ process.
813
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
814
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
815
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
816
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
817
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
818
+ height (`int`, *optional*, defaults to 512):
819
+ The height in pixels of the generated image.
820
+ width (`int`, *optional*, defaults to 512):
821
+ The width in pixels of the generated image.
822
+ num_inference_steps (`int`, *optional*, defaults to 50):
823
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
824
+ expense of slower inference.
825
+ guidance_scale (`float`, *optional*, defaults to 7.5):
826
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
827
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
828
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
829
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
830
+ usually at the expense of lower image quality.
831
+ strength (`float`, *optional*, defaults to 0.8):
832
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
833
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
834
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
835
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
836
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
837
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
838
+ The number of images to generate per prompt.
839
+ eta (`float`, *optional*, defaults to 0.0):
840
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
841
+ [`schedulers.DDIMScheduler`], will be ignored for others.
842
+ generator (`torch.Generator`, *optional*):
843
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
844
+ deterministic.
845
+ latents (`torch.FloatTensor`, *optional*):
846
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
847
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
848
+ tensor will be generated by sampling using the supplied random `generator`.
849
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
850
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
851
+ output_type (`str`, *optional*, defaults to `"pil"`):
852
+ The output format of the generated image. Choose between
853
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
854
+ return_dict (`bool`, *optional*, defaults to `True`):
855
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
856
+ plain tuple.
857
+ controlnet (`diffusers.ControlNetModel`, *optional*):
858
+ A controlnet model to be used for the inference. If not provided, controlnet will be disabled.
859
+ controlnet_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
860
+ `Image`, or tensor representing an image batch, to be used as the starting point for the controlnet
861
+ inference.
862
+ callback (`Callable`, *optional*):
863
+ A function that will be called every `callback_steps` steps during inference. The function will be
864
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
865
+ is_cancelled_callback (`Callable`, *optional*):
866
+ A function that will be called every `callback_steps` steps during inference. If the function returns
867
+ `True`, the inference will be cancelled.
868
+ callback_steps (`int`, *optional*, defaults to 1):
869
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
870
+ called at every step.
871
+
872
+ Returns:
873
+ `None` if cancelled by `is_cancelled_callback`,
874
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
875
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
876
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
877
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
878
+ (nsfw) content, according to the `safety_checker`.
879
+ """
880
+ if controlnet is not None and controlnet_image is None:
881
+ raise ValueError("controlnet_image must be provided if controlnet is not None.")
882
+
883
+ # 0. Default height and width to unet
884
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
885
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
886
+
887
+ # 1. Check inputs. Raise error if not correct
888
+ self.check_inputs(prompt, height, width, strength, callback_steps)
889
+
890
+ # 2. Define call parameters
891
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
892
+ device = self._execution_device
893
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
894
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
895
+ # corresponds to doing no classifier free guidance.
896
+ do_classifier_free_guidance = guidance_scale > 1.0
897
+
898
+ # 3. Encode input prompt
899
+ # To simplify the implementation, switch the tokenizer/text encoder and call _encode_prompt twice
901
+ text_embeddings_list = []
902
+ text_pool = None
903
+ uncond_embeddings_list = []
904
+ uncond_pool = None
905
+ for i in range(len(self.tokenizers)):
906
+ self.tokenizer = self.tokenizers[i]
907
+ self.text_encoder = self.text_encoders[i]
908
+
909
+ text_embeddings, tp1, uncond_embeddings, up1 = self._encode_prompt(
910
+ prompt,
911
+ device,
912
+ num_images_per_prompt,
913
+ do_classifier_free_guidance,
914
+ negative_prompt,
915
+ max_embeddings_multiples,
916
+ is_sdxl_text_encoder2=i == 1,
917
+ )
918
+ text_embeddings_list.append(text_embeddings)
919
+ uncond_embeddings_list.append(uncond_embeddings)
920
+
921
+ if tp1 is not None:
922
+ text_pool = tp1
923
+ if up1 is not None:
924
+ uncond_pool = up1
925
+
926
+ unet_dtype = self.unet.dtype
927
+ dtype = unet_dtype
928
+ if hasattr(dtype, "itemsize") and dtype.itemsize == 1: # fp8
929
+ dtype = torch.float16
930
+ self.unet.to(dtype)
931
+
932
+ # 4. Preprocess image and mask
933
+ if isinstance(image, PIL.Image.Image):
934
+ image = preprocess_image(image)
935
+ if image is not None:
936
+ image = image.to(device=self.device, dtype=dtype)
937
+ if isinstance(mask_image, PIL.Image.Image):
938
+ mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
939
+ if mask_image is not None:
940
+ mask = mask_image.to(device=self.device, dtype=dtype)
941
+ mask = torch.cat([mask] * batch_size * num_images_per_prompt)
942
+ else:
943
+ mask = None
944
+
945
+ # ControlNet is not working yet in SDXL, but keep the code here for future use
946
+ if controlnet_image is not None:
947
+ controlnet_image = prepare_controlnet_image(
948
+ controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False
949
+ )
950
+
951
+ # 5. set timesteps
952
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
953
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
954
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
955
+
956
+ # 6. Prepare latent variables
957
+ latents, init_latents_orig, noise = self.prepare_latents(
958
+ image,
959
+ latent_timestep,
960
+ batch_size * num_images_per_prompt,
961
+ height,
962
+ width,
963
+ dtype,
964
+ device,
965
+ generator,
966
+ latents,
967
+ )
968
+
969
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
970
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
971
+
972
+ # create size embs and concat embeddings for SDXL
973
+ orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(dtype)
974
+ crop_size = torch.zeros_like(orig_size)
975
+ target_size = orig_size
976
+ embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(dtype)
977
+
978
+ # make conditionings
979
+ if do_classifier_free_guidance:
980
+ text_embeddings = torch.cat(text_embeddings_list, dim=2)
981
+ uncond_embeddings = torch.cat(uncond_embeddings_list, dim=2)
982
+ text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(dtype)
983
+
984
+ cond_vector = torch.cat([text_pool, embs], dim=1)
985
+ uncond_vector = torch.cat([uncond_pool, embs], dim=1)
986
+ vector_embedding = torch.cat([uncond_vector, cond_vector]).to(dtype)
987
+ else:
988
+ text_embedding = torch.cat(text_embeddings_list, dim=2).to(dtype)
989
+ vector_embedding = torch.cat([text_pool, embs], dim=1).to(dtype)
990
+
991
+ # 8. Denoising loop
992
+ for i, t in enumerate(self.progress_bar(timesteps)):
993
+ # expand the latents if we are doing classifier free guidance
994
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
995
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
996
+
997
+ unet_additional_args = {}
998
+ if controlnet is not None:
999
+ down_block_res_samples, mid_block_res_sample = controlnet(
1000
+ latent_model_input,
1001
+ t,
1002
+ encoder_hidden_states=text_embeddings,
1003
+ controlnet_cond=controlnet_image,
1004
+ conditioning_scale=1.0,
1005
+ guess_mode=False,
1006
+ return_dict=False,
1007
+ )
1008
+ unet_additional_args["down_block_additional_residuals"] = down_block_res_samples
1009
+ unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample
1010
+
1011
+ # predict the noise residual
1012
+ noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding)
1013
+ noise_pred = noise_pred.to(dtype) # U-Net changes dtype in LoRA training
1014
+
1015
+ # perform guidance
1016
+ if do_classifier_free_guidance:
1017
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1018
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1019
+
1020
+ # compute the previous noisy sample x_t -> x_t-1
1021
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
1022
+
1023
+ if mask is not None:
1024
+ # masking
1025
+ init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
1026
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
1027
+
1028
+ # call the callback, if provided
1029
+ if i % callback_steps == 0:
1030
+ if callback is not None:
1031
+ callback(i, t, latents)
1032
+ if is_cancelled_callback is not None and is_cancelled_callback():
1033
+ return None
1034
+
1035
+ self.unet.to(unet_dtype)
1036
+ return latents
1037
+
1038
+ def latents_to_image(self, latents):
1039
+ # 9. Post-processing
1040
+ image = self.decode_latents(latents.to(self.vae.dtype))
1041
+ image = self.numpy_to_pil(image)
1042
+ return image
1043
+
1044
+ # copy from pil_utils.py
1045
+ def numpy_to_pil(self, images: np.ndarray) -> Image.Image:
1046
+ """
1047
+ Convert a numpy image or a batch of images to a PIL image.
1048
+ """
1049
+ if images.ndim == 3:
1050
+ images = images[None, ...]
1051
+ images = (images * 255).round().astype("uint8")
1052
+ if images.shape[-1] == 1:
1053
+ # special case for grayscale (single channel) images
1054
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
1055
+ else:
1056
+ pil_images = [Image.fromarray(image) for image in images]
1057
+
1058
+ return pil_images
1059
+
1060
+ def text2img(
1061
+ self,
1062
+ prompt: Union[str, List[str]],
1063
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1064
+ height: int = 512,
1065
+ width: int = 512,
1066
+ num_inference_steps: int = 50,
1067
+ guidance_scale: float = 7.5,
1068
+ num_images_per_prompt: Optional[int] = 1,
1069
+ eta: float = 0.0,
1070
+ generator: Optional[torch.Generator] = None,
1071
+ latents: Optional[torch.FloatTensor] = None,
1072
+ max_embeddings_multiples: Optional[int] = 3,
1073
+ output_type: Optional[str] = "pil",
1074
+ return_dict: bool = True,
1075
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1076
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
1077
+ callback_steps: int = 1,
1078
+ ):
1079
+ r"""
1080
+ Function for text-to-image generation.
1081
+ Args:
1082
+ prompt (`str` or `List[str]`):
1083
+ The prompt or prompts to guide the image generation.
1084
+ negative_prompt (`str` or `List[str]`, *optional*):
1085
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1086
+ if `guidance_scale` is less than `1`).
1087
+ height (`int`, *optional*, defaults to 512):
1088
+ The height in pixels of the generated image.
1089
+ width (`int`, *optional*, defaults to 512):
1090
+ The width in pixels of the generated image.
1091
+ num_inference_steps (`int`, *optional*, defaults to 50):
1092
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1093
+ expense of slower inference.
1094
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1095
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1096
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1097
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1098
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1099
+ usually at the expense of lower image quality.
1100
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1101
+ The number of images to generate per prompt.
1102
+ eta (`float`, *optional*, defaults to 0.0):
1103
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1104
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1105
+ generator (`torch.Generator`, *optional*):
1106
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1107
+ deterministic.
1108
+ latents (`torch.FloatTensor`, *optional*):
1109
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1110
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1111
+ tensor will be generated by sampling using the supplied random `generator`.
1112
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1113
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
1114
+ output_type (`str`, *optional*, defaults to `"pil"`):
1115
+ The output format of the generated image. Choose between
1116
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1117
+ return_dict (`bool`, *optional*, defaults to `True`):
1118
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1119
+ plain tuple.
1120
+ callback (`Callable`, *optional*):
1121
+ A function that will be called every `callback_steps` steps during inference. The function will be
1122
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1123
+ is_cancelled_callback (`Callable`, *optional*):
1124
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1125
+ `True`, the inference will be cancelled.
1126
+ callback_steps (`int`, *optional*, defaults to 1):
1127
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1128
+ called at every step.
1129
+ Returns:
1130
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1131
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1132
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1133
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1134
+ (nsfw) content, according to the `safety_checker`.
1135
+ """
1136
+ return self.__call__(
1137
+ prompt=prompt,
1138
+ negative_prompt=negative_prompt,
1139
+ height=height,
1140
+ width=width,
1141
+ num_inference_steps=num_inference_steps,
1142
+ guidance_scale=guidance_scale,
1143
+ num_images_per_prompt=num_images_per_prompt,
1144
+ eta=eta,
1145
+ generator=generator,
1146
+ latents=latents,
1147
+ max_embeddings_multiples=max_embeddings_multiples,
1148
+ output_type=output_type,
1149
+ return_dict=return_dict,
1150
+ callback=callback,
1151
+ is_cancelled_callback=is_cancelled_callback,
1152
+ callback_steps=callback_steps,
1153
+ )
1154
+
1155
+ def img2img(
1156
+ self,
1157
+ image: Union[torch.FloatTensor, PIL.Image.Image],
1158
+ prompt: Union[str, List[str]],
1159
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1160
+ strength: float = 0.8,
1161
+ num_inference_steps: Optional[int] = 50,
1162
+ guidance_scale: Optional[float] = 7.5,
1163
+ num_images_per_prompt: Optional[int] = 1,
1164
+ eta: Optional[float] = 0.0,
1165
+ generator: Optional[torch.Generator] = None,
1166
+ max_embeddings_multiples: Optional[int] = 3,
1167
+ output_type: Optional[str] = "pil",
1168
+ return_dict: bool = True,
1169
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1170
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
1171
+ callback_steps: int = 1,
1172
+ ):
1173
+ r"""
1174
+ Function for image-to-image generation.
1175
+ Args:
1176
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
1177
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
1178
+ process.
1179
+ prompt (`str` or `List[str]`):
1180
+ The prompt or prompts to guide the image generation.
1181
+ negative_prompt (`str` or `List[str]`, *optional*):
1182
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1183
+ if `guidance_scale` is less than `1`).
1184
+ strength (`float`, *optional*, defaults to 0.8):
1185
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
1186
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
1187
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
1188
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
1189
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
1190
+ num_inference_steps (`int`, *optional*, defaults to 50):
1191
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1192
+ expense of slower inference. This parameter will be modulated by `strength`.
1193
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1194
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1195
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1196
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1197
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1198
+ usually at the expense of lower image quality.
1199
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1200
+ The number of images to generate per prompt.
1201
+ eta (`float`, *optional*, defaults to 0.0):
1202
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1203
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1204
+ generator (`torch.Generator`, *optional*):
1205
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1206
+ deterministic.
1207
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1208
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
1209
+ output_type (`str`, *optional*, defaults to `"pil"`):
1210
+ The output format of the generated image. Choose between
1211
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1212
+ return_dict (`bool`, *optional*, defaults to `True`):
1213
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1214
+ plain tuple.
1215
+ callback (`Callable`, *optional*):
1216
+ A function that will be called every `callback_steps` steps during inference. The function will be
1217
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1218
+ is_cancelled_callback (`Callable`, *optional*):
1219
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1220
+ `True`, the inference will be cancelled.
1221
+ callback_steps (`int`, *optional*, defaults to 1):
1222
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1223
+ called at every step.
1224
+ Returns:
1225
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1226
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1227
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1228
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1229
+ (nsfw) content, according to the `safety_checker`.
1230
+ """
1231
+ return self.__call__(
1232
+ prompt=prompt,
1233
+ negative_prompt=negative_prompt,
1234
+ image=image,
1235
+ num_inference_steps=num_inference_steps,
1236
+ guidance_scale=guidance_scale,
1237
+ strength=strength,
1238
+ num_images_per_prompt=num_images_per_prompt,
1239
+ eta=eta,
1240
+ generator=generator,
1241
+ max_embeddings_multiples=max_embeddings_multiples,
1242
+ output_type=output_type,
1243
+ return_dict=return_dict,
1244
+ callback=callback,
1245
+ is_cancelled_callback=is_cancelled_callback,
1246
+ callback_steps=callback_steps,
1247
+ )
1248
+
1249
+ def inpaint(
1250
+ self,
1251
+ image: Union[torch.FloatTensor, PIL.Image.Image],
1252
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image],
1253
+ prompt: Union[str, List[str]],
1254
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1255
+ strength: float = 0.8,
1256
+ num_inference_steps: Optional[int] = 50,
1257
+ guidance_scale: Optional[float] = 7.5,
1258
+ num_images_per_prompt: Optional[int] = 1,
1259
+ eta: Optional[float] = 0.0,
1260
+ generator: Optional[torch.Generator] = None,
1261
+ max_embeddings_multiples: Optional[int] = 3,
1262
+ output_type: Optional[str] = "pil",
1263
+ return_dict: bool = True,
1264
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1265
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
1266
+ callback_steps: int = 1,
1267
+ ):
1268
+ r"""
1269
+ Function for inpaint.
1270
+ Args:
1271
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
1272
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
1273
+ process. This is the image whose masked region will be inpainted.
1274
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
1275
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
1276
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
1277
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
1278
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
1279
+ prompt (`str` or `List[str]`):
1280
+ The prompt or prompts to guide the image generation.
1281
+ negative_prompt (`str` or `List[str]`, *optional*):
1282
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1283
+ if `guidance_scale` is less than `1`).
1284
+ strength (`float`, *optional*, defaults to 0.8):
1285
+ Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
1286
+ is 1, the denoising process will be run on the masked area for the full number of iterations specified
1287
+ in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
1288
+ noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
1289
+ num_inference_steps (`int`, *optional*, defaults to 50):
1290
+ The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
1291
+ the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
1292
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1293
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1294
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1295
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1296
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1297
+ usually at the expense of lower image quality.
1298
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1299
+ The number of images to generate per prompt.
1300
+ eta (`float`, *optional*, defaults to 0.0):
1301
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1302
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1303
+ generator (`torch.Generator`, *optional*):
1304
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1305
+ deterministic.
1306
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1307
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
1308
+ output_type (`str`, *optional*, defaults to `"pil"`):
1309
+ The output format of the generated image. Choose between
1310
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1311
+ return_dict (`bool`, *optional*, defaults to `True`):
1312
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1313
+ plain tuple.
1314
+ callback (`Callable`, *optional*):
1315
+ A function that will be called every `callback_steps` steps during inference. The function will be
1316
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1317
+ is_cancelled_callback (`Callable`, *optional*):
1318
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1319
+ `True`, the inference will be cancelled.
1320
+ callback_steps (`int`, *optional*, defaults to 1):
1321
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1322
+ called at every step.
1323
+ Returns:
1324
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1325
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1326
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1327
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1328
+ (nsfw) content, according to the `safety_checker`.
1329
+ """
1330
+ return self.__call__(
1331
+ prompt=prompt,
1332
+ negative_prompt=negative_prompt,
1333
+ image=image,
1334
+ mask_image=mask_image,
1335
+ num_inference_steps=num_inference_steps,
1336
+ guidance_scale=guidance_scale,
1337
+ strength=strength,
1338
+ num_images_per_prompt=num_images_per_prompt,
1339
+ eta=eta,
1340
+ generator=generator,
1341
+ max_embeddings_multiples=max_embeddings_multiples,
1342
+ output_type=output_type,
1343
+ return_dict=return_dict,
1344
+ callback=callback,
1345
+ is_cancelled_callback=is_cancelled_callback,
1346
+ callback_steps=callback_steps,
1347
+ )
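+ # Usage sketch (added for illustration, not part of the original file). Loading the SDXL
+ # components (tokenizers, text encoders, U-Net, VAE, scheduler) is assumed to happen elsewhere,
+ # e.g. with this repository's sdxl_model_util / sdxl_train_util helpers, so this is shown as a
+ # commented-out, non-authoritative example rather than runnable code:
+ #
+ #   pipe = SdxlStableDiffusionLongPromptWeightingPipeline(
+ #       vae=vae, text_encoder=[text_encoder1, text_encoder2], tokenizer=[tokenizer1, tokenizer2],
+ #       unet=unet, scheduler=scheduler, safety_checker=None, feature_extractor=None,
+ #   )
+ #   pipe.to(device=torch.device("cuda"), dtype=torch.float16)
+ #   latents = pipe.text2img("a (detailed:1.2) photo of a cat", width=1024, height=1024)
+ #   images = pipe.latents_to_image(latents)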
library/sdxl_model_util.py ADDED
@@ -0,0 +1,583 @@
1
+ import torch
2
+ import safetensors
3
+ from accelerate import init_empty_weights
4
+ from accelerate.utils.modeling import set_module_tensor_to_device
5
+ from safetensors.torch import load_file, save_file
6
+ from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
7
+ from typing import List
8
+ from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
9
+ from library import model_util
10
+ from library import sdxl_original_unet
11
+ from .utils import setup_logging
12
+
13
+ setup_logging()
14
+ import logging
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ VAE_SCALE_FACTOR = 0.13025
19
+ MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0"
20
+
21
+ # Reference model used to load the Diffusers configuration
22
+ DIFFUSERS_REF_MODEL_ID_SDXL = "stabilityai/stable-diffusion-xl-base-1.0"
23
+
24
+ DIFFUSERS_SDXL_UNET_CONFIG = {
25
+ "act_fn": "silu",
26
+ "addition_embed_type": "text_time",
27
+ "addition_embed_type_num_heads": 64,
28
+ "addition_time_embed_dim": 256,
29
+ "attention_head_dim": [5, 10, 20],
30
+ "block_out_channels": [320, 640, 1280],
31
+ "center_input_sample": False,
32
+ "class_embed_type": None,
33
+ "class_embeddings_concat": False,
34
+ "conv_in_kernel": 3,
35
+ "conv_out_kernel": 3,
36
+ "cross_attention_dim": 2048,
37
+ "cross_attention_norm": None,
38
+ "down_block_types": ["DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"],
39
+ "downsample_padding": 1,
40
+ "dual_cross_attention": False,
41
+ "encoder_hid_dim": None,
42
+ "encoder_hid_dim_type": None,
43
+ "flip_sin_to_cos": True,
44
+ "freq_shift": 0,
45
+ "in_channels": 4,
46
+ "layers_per_block": 2,
47
+ "mid_block_only_cross_attention": None,
48
+ "mid_block_scale_factor": 1,
49
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
50
+ "norm_eps": 1e-05,
51
+ "norm_num_groups": 32,
52
+ "num_attention_heads": None,
53
+ "num_class_embeds": None,
54
+ "only_cross_attention": False,
55
+ "out_channels": 4,
56
+ "projection_class_embeddings_input_dim": 2816,
57
+ "resnet_out_scale_factor": 1.0,
58
+ "resnet_skip_time_act": False,
59
+ "resnet_time_scale_shift": "default",
60
+ "sample_size": 128,
61
+ "time_cond_proj_dim": None,
62
+ "time_embedding_act_fn": None,
63
+ "time_embedding_dim": None,
64
+ "time_embedding_type": "positional",
65
+ "timestep_post_act": None,
66
+ "transformer_layers_per_block": [1, 2, 10],
67
+ "up_block_types": ["CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"],
68
+ "upcast_attention": False,
69
+ "use_linear_projection": True,
70
+ }
71
+
72
+
73
+ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
74
+ SDXL_KEY_PREFIX = "conditioner.embedders.1.model."
75
+
76
+ # Basically the same as for SD2. logit_scale is needed later, so it is returned as well.
77
+ # logit_scale is used when saving the checkpoint
78
+ def convert_key(key):
79
+ # common conversion
80
+ key = key.replace(SDXL_KEY_PREFIX + "transformer.", "text_model.encoder.")
81
+ key = key.replace(SDXL_KEY_PREFIX, "text_model.")
82
+
83
+ if "resblocks" in key:
84
+ # resblocks conversion
85
+ key = key.replace(".resblocks.", ".layers.")
86
+ if ".ln_" in key:
87
+ key = key.replace(".ln_", ".layer_norm")
88
+ elif ".mlp." in key:
89
+ key = key.replace(".c_fc.", ".fc1.")
90
+ key = key.replace(".c_proj.", ".fc2.")
91
+ elif ".attn.out_proj" in key:
92
+ key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
93
+ elif ".attn.in_proj" in key:
94
+ key = None # special case (fused in_proj), handled separately below
95
+ else:
96
+ raise ValueError(f"unexpected key in SD: {key}")
97
+ elif ".positional_embedding" in key:
98
+ key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
99
+ elif ".text_projection" in key:
100
+ key = key.replace("text_model.text_projection", "text_projection.weight")
101
+ elif ".logit_scale" in key:
102
+ key = None # handled separately below
103
+ elif ".token_embedding" in key:
104
+ key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
105
+ elif ".ln_final" in key:
106
+ key = key.replace(".ln_final", ".final_layer_norm")
107
+ # ckpt from comfy has this key: text_model.encoder.text_model.embeddings.position_ids
108
+ elif ".embeddings.position_ids" in key:
109
+ key = None # remove this key: position_ids is not used in newer transformers
110
+ return key
111
+
112
+ keys = list(checkpoint.keys())
113
+ new_sd = {}
114
+ for key in keys:
115
+ new_key = convert_key(key)
116
+ if new_key is None:
117
+ continue
118
+ new_sd[new_key] = checkpoint[key]
119
+
120
+ # convert the attention weights
121
+ for key in keys:
122
+ if ".resblocks" in key and ".attn.in_proj_" in key:
123
+ # split into three (q, k, v)
124
+ values = torch.chunk(checkpoint[key], 3)
125
+
126
+ key_suffix = ".weight" if "weight" in key else ".bias"
127
+ key_pfx = key.replace(SDXL_KEY_PREFIX + "transformer.resblocks.", "text_model.encoder.layers.")
128
+ key_pfx = key_pfx.replace("_weight", "")
129
+ key_pfx = key_pfx.replace("_bias", "")
130
+ key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
131
+ new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
132
+ new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
133
+ new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
134
+
135
+ # logit_scale is not part of the Diffusers model, but we want to restore it when saving, so it is returned separately
136
+ logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)
137
+
138
+ # temporary workaround for text_projection.weight.weight for Playground-v2
139
+ if "text_projection.weight.weight" in new_sd:
140
+ logger.info("convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight")
141
+ new_sd["text_projection.weight"] = new_sd["text_projection.weight.weight"]
142
+ del new_sd["text_projection.weight.weight"]
143
+
144
+ return new_sd, logit_scale
145
+
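# Editorial illustration (not part of the uploaded file): the attention conversion above splits
# OpenCLIP's fused in_proj weight into the separate q/k/v projections expected by the
# HuggingFace CLIP layout. Toy tensors showing the shapes involved:
dim = 1280  # hidden size of SDXL's text encoder 2
fused_in_proj = torch.randn(3 * dim, dim)  # OpenCLIP stacks q, k and v along dim 0
q_w, k_w, v_w = torch.chunk(fused_in_proj, 3)
assert q_w.shape == k_w.shape == v_w.shape == torch.Size([dim, dim])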
146
+
147
+ # load state_dict without allocating new tensors
148
+ def _load_state_dict_on_device(model, state_dict, device, dtype=None):
149
+ # dtype will use fp32 as default
150
+ missing_keys = list(model.state_dict().keys() - state_dict.keys())
151
+ unexpected_keys = list(state_dict.keys() - model.state_dict().keys())
152
+
153
+ # similar to model.load_state_dict()
154
+ if not missing_keys and not unexpected_keys:
155
+ for k in list(state_dict.keys()):
156
+ set_module_tensor_to_device(model, k, device, value=state_dict.pop(k), dtype=dtype)
157
+ return "<All keys matched successfully>"
158
+
159
+ # error_msgs
160
+ error_msgs: List[str] = []
161
+ if missing_keys:
162
+ error_msgs.insert(0, "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys)))
163
+ if unexpected_keys:
164
+ error_msgs.insert(0, "Unexpected key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in unexpected_keys)))
165
+
166
+ raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)))
167
+
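# Editorial illustration (not part of the uploaded file): the helper above works because the
# models are built under accelerate's init_empty_weights(), so their parameters live on the
# "meta" device, and each tensor is then materialized straight from the checkpoint with
# set_module_tensor_to_device, avoiding a second full-size allocation. Minimal example:
with init_empty_weights():
    tiny = torch.nn.Linear(4, 4)  # parameters on the meta device, no memory allocated yet
tiny_sd = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}
for name in list(tiny_sd.keys()):
    set_module_tensor_to_device(tiny, name, "cpu", value=tiny_sd.pop(name))
assert tiny.weight.device.type == "cpu"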
168
+
169
+ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None, disable_mmap=False):
170
+ # model_version is reserved for future use
171
+ # dtype is used for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching
172
+
173
+ # Load the state dict
174
+ if model_util.is_safetensors(ckpt_path):
175
+ checkpoint = None
176
+ if disable_mmap:
177
+ with open(ckpt_path, "rb") as f:
+ state_dict = safetensors.torch.load(f.read())
178
+ else:
179
+ try:
180
+ state_dict = load_file(ckpt_path, device=map_location)
181
+ except Exception:
182
+ state_dict = load_file(ckpt_path) # prevent device invalid Error
183
+ epoch = None
184
+ global_step = None
185
+ else:
186
+ checkpoint = torch.load(ckpt_path, map_location=map_location)
187
+ if "state_dict" in checkpoint:
188
+ state_dict = checkpoint["state_dict"]
189
+ epoch = checkpoint.get("epoch", 0)
190
+ global_step = checkpoint.get("global_step", 0)
191
+ else:
192
+ state_dict = checkpoint
193
+ epoch = 0
194
+ global_step = 0
195
+ checkpoint = None
196
+
197
+ # U-Net
198
+ logger.info("building U-Net")
199
+ with init_empty_weights():
200
+ unet = sdxl_original_unet.SdxlUNet2DConditionModel()
201
+
202
+ logger.info("loading U-Net from checkpoint")
203
+ unet_sd = {}
204
+ for k in list(state_dict.keys()):
205
+ if k.startswith("model.diffusion_model."):
206
+ unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
207
+ info = _load_state_dict_on_device(unet, unet_sd, device=map_location, dtype=dtype)
208
+ logger.info(f"U-Net: {info}")
209
+
210
+ # Text Encoders
211
+ logger.info("building text encoders")
212
+
213
+ # Text Encoder 1 is the same as Stability AI's SDXL
214
+ text_model1_cfg = CLIPTextConfig(
215
+ vocab_size=49408,
216
+ hidden_size=768,
217
+ intermediate_size=3072,
218
+ num_hidden_layers=12,
219
+ num_attention_heads=12,
220
+ max_position_embeddings=77,
221
+ hidden_act="quick_gelu",
222
+ layer_norm_eps=1e-05,
223
+ dropout=0.0,
224
+ attention_dropout=0.0,
225
+ initializer_range=0.02,
226
+ initializer_factor=1.0,
227
+ pad_token_id=1,
228
+ bos_token_id=0,
229
+ eos_token_id=2,
230
+ model_type="clip_text_model",
231
+ projection_dim=768,
232
+ # torch_dtype="float32",
233
+ # transformers_version="4.25.0.dev0",
234
+ )
235
+ with init_empty_weights():
236
+ text_model1 = CLIPTextModel._from_config(text_model1_cfg)
237
+
238
+ # Text Encoder 2 is different from Stability AI's SDXL. SDXL uses open clip, but we use the model from HuggingFace.
239
+ # Note: Tokenizer from HuggingFace is different from SDXL. We must use open clip's tokenizer.
240
+ text_model2_cfg = CLIPTextConfig(
241
+ vocab_size=49408,
242
+ hidden_size=1280,
243
+ intermediate_size=5120,
244
+ num_hidden_layers=32,
245
+ num_attention_heads=20,
246
+ max_position_embeddings=77,
247
+ hidden_act="gelu",
248
+ layer_norm_eps=1e-05,
249
+ dropout=0.0,
250
+ attention_dropout=0.0,
251
+ initializer_range=0.02,
252
+ initializer_factor=1.0,
253
+ pad_token_id=1,
254
+ bos_token_id=0,
255
+ eos_token_id=2,
256
+ model_type="clip_text_model",
257
+ projection_dim=1280,
258
+ # torch_dtype="float32",
259
+ # transformers_version="4.25.0.dev0",
260
+ )
261
+ with init_empty_weights():
262
+ text_model2 = CLIPTextModelWithProjection(text_model2_cfg)
263
+
264
+ logger.info("loading text encoders from checkpoint")
265
+ te1_sd = {}
266
+ te2_sd = {}
267
+ for k in list(state_dict.keys()):
268
+ if k.startswith("conditioner.embedders.0.transformer."):
269
+ te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k)
270
+ elif k.startswith("conditioner.embedders.1.model."):
271
+ te2_sd[k] = state_dict.pop(k)
272
+
273
+ # remove position_ids: including them raises an error with the latest transformers
274
+ if "text_model.embeddings.position_ids" in te1_sd:
275
+ te1_sd.pop("text_model.embeddings.position_ids")
276
+
277
+ info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32
278
+ logger.info(f"text encoder 1: {info1}")
279
+
280
+ converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
281
+ info2 = _load_state_dict_on_device(text_model2, converted_sd, device=map_location) # remain fp32
282
+ logger.info(f"text encoder 2: {info2}")
283
+
284
+ # prepare vae
285
+ logger.info("building VAE")
286
+ vae_config = model_util.create_vae_diffusers_config()
287
+ with init_empty_weights():
288
+ vae = AutoencoderKL(**vae_config)
289
+
290
+ logger.info("loading VAE from checkpoint")
291
+ converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config)
292
+ info = _load_state_dict_on_device(vae, converted_vae_checkpoint, device=map_location, dtype=dtype)
293
+ logger.info(f"VAE: {info}")
294
+
295
+ ckpt_info = (epoch, global_step) if epoch is not None else None
296
+ return text_model1, text_model2, vae, unet, logit_scale, ckpt_info
297
+
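# Editorial usage sketch (not part of the uploaded file); the checkpoint path is a placeholder
# and the dtype choice is just an example (the text encoders stay fp32, as noted above):
#
#     text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info = load_models_from_sdxl_checkpoint(
#         MODEL_VERSION_SDXL_BASE_V1_0,
#         "/path/to/sd_xl_base_1.0.safetensors",
#         map_location="cpu",
#         dtype=torch.float16,
#     )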
298
+
299
+ def make_unet_conversion_map():
300
+ unet_conversion_map_layer = []
301
+
302
+ for i in range(3): # num_blocks is 3 in sdxl
303
+ # loop over downblocks/upblocks
304
+ for j in range(2):
305
+ # loop over resnets/attentions for downblocks
306
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
307
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
308
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
309
+
310
+ if i < 3:
311
+ # no attention layers in down_blocks.3
312
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
313
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
314
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
315
+
316
+ for j in range(3):
317
+ # loop over resnets/attentions for upblocks
318
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
319
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
320
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
321
+
322
+ # the original "if i > 0:" guard is commented out for SDXL
323
+ # no attention layers in up_blocks.0
324
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
325
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
326
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
327
+
328
+ if i < 3:
329
+ # no downsample in down_blocks.3
330
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
331
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
332
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
333
+
334
+ # no upsample in up_blocks.3
335
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
336
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
337
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
338
+
339
+ hf_mid_atn_prefix = "mid_block.attentions.0."
340
+ sd_mid_atn_prefix = "middle_block.1."
341
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
342
+
343
+ for j in range(2):
344
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
345
+ sd_mid_res_prefix = f"middle_block.{2*j}."
346
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
347
+
348
+ unet_conversion_map_resnet = [
349
+ # (stable-diffusion, HF Diffusers)
350
+ ("in_layers.0.", "norm1."),
351
+ ("in_layers.2.", "conv1."),
352
+ ("out_layers.0.", "norm2."),
353
+ ("out_layers.3.", "conv2."),
354
+ ("emb_layers.1.", "time_emb_proj."),
355
+ ("skip_connection.", "conv_shortcut."),
356
+ ]
357
+
358
+ unet_conversion_map = []
359
+ for sd, hf in unet_conversion_map_layer:
360
+ if "resnets" in hf:
361
+ for sd_res, hf_res in unet_conversion_map_resnet:
362
+ unet_conversion_map.append((sd + sd_res, hf + hf_res))
363
+ else:
364
+ unet_conversion_map.append((sd, hf))
365
+
366
+ for j in range(2):
367
+ hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
368
+ sd_time_embed_prefix = f"time_embed.{j*2}."
369
+ unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
370
+
371
+ for j in range(2):
372
+ hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
373
+ sd_label_embed_prefix = f"label_emb.0.{j*2}."
374
+ unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
375
+
376
+ unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
377
+ unet_conversion_map.append(("out.0.", "conv_norm_out."))
378
+ unet_conversion_map.append(("out.2.", "conv_out."))
379
+
380
+ return unet_conversion_map
381
+
382
+
383
+ def convert_diffusers_unet_state_dict_to_sdxl(du_sd):
384
+ unet_conversion_map = make_unet_conversion_map()
385
+
386
+ conversion_map = {hf: sd for sd, hf in unet_conversion_map}
387
+ return convert_unet_state_dict(du_sd, conversion_map)
388
+
389
+
390
+ def convert_unet_state_dict(src_sd, conversion_map):
391
+ converted_sd = {}
392
+ for src_key, value in src_sd.items():
393
+ # iterating over the whole map for every key would be too slow, so find the prefix by trimming elements from the right
394
+ src_key_fragments = src_key.split(".")[:-1] # remove weight/bias
395
+ while len(src_key_fragments) > 0:
396
+ src_key_prefix = ".".join(src_key_fragments) + "."
397
+ if src_key_prefix in conversion_map:
398
+ converted_prefix = conversion_map[src_key_prefix]
399
+ converted_key = converted_prefix + src_key[len(src_key_prefix) :]
400
+ converted_sd[converted_key] = value
401
+ break
402
+ src_key_fragments.pop(-1)
403
+ assert len(src_key_fragments) > 0, f"key {src_key} not found in conversion map"
404
+
405
+ return converted_sd
406
+
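# Editorial illustration (not part of the uploaded file): convert_unet_state_dict matches the
# longest module prefix found in the conversion map, so parameter names are remapped even
# though the map only lists module prefixes. Toy example:
toy_map = {"input_blocks.0.0.": "conv_in."}
toy_sd = {"input_blocks.0.0.weight": torch.zeros(1), "input_blocks.0.0.bias": torch.zeros(1)}
assert set(convert_unet_state_dict(toy_sd, toy_map)) == {"conv_in.weight", "conv_in.bias"}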
407
+
408
+ def convert_sdxl_unet_state_dict_to_diffusers(sd):
409
+ unet_conversion_map = make_unet_conversion_map()
410
+
411
+ conversion_dict = {sd: hf for sd, hf in unet_conversion_map}
412
+ return convert_unet_state_dict(sd, conversion_dict)
413
+
414
+
415
+ def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, logit_scale):
416
+ def convert_key(key):
417
+ # remove position_ids
418
+ if ".position_ids" in key:
419
+ return None
420
+
421
+ # common
422
+ key = key.replace("text_model.encoder.", "transformer.")
423
+ key = key.replace("text_model.", "")
424
+ if "layers" in key:
425
+ # resblocks conversion
426
+ key = key.replace(".layers.", ".resblocks.")
427
+ if ".layer_norm" in key:
428
+ key = key.replace(".layer_norm", ".ln_")
429
+ elif ".mlp." in key:
430
+ key = key.replace(".fc1.", ".c_fc.")
431
+ key = key.replace(".fc2.", ".c_proj.")
432
+ elif ".self_attn.out_proj" in key:
433
+ key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
434
+ elif ".self_attn." in key:
435
+ key = None # special case (q/k/v), handled separately below
436
+ else:
437
+ raise ValueError(f"unexpected key in DiffUsers model: {key}")
438
+ elif ".position_embedding" in key:
439
+ key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
440
+ elif ".token_embedding" in key:
441
+ key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
442
+ elif "text_projection" in key: # no dot in key
443
+ key = key.replace("text_projection.weight", "text_projection")
444
+ elif "final_layer_norm" in key:
445
+ key = key.replace("final_layer_norm", "ln_final")
446
+ return key
447
+
448
+ keys = list(checkpoint.keys())
449
+ new_sd = {}
450
+ for key in keys:
451
+ new_key = convert_key(key)
452
+ if new_key is None:
453
+ continue
454
+ new_sd[new_key] = checkpoint[key]
455
+
456
+ # convert the attention weights
457
+ for key in keys:
458
+ if "layers" in key and "q_proj" in key:
459
+ # concatenate the three (q, k, v)
460
+ key_q = key
461
+ key_k = key.replace("q_proj", "k_proj")
462
+ key_v = key.replace("q_proj", "v_proj")
463
+
464
+ value_q = checkpoint[key_q]
465
+ value_k = checkpoint[key_k]
466
+ value_v = checkpoint[key_v]
467
+ value = torch.cat([value_q, value_k, value_v])
468
+
469
+ new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
470
+ new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
471
+ new_sd[new_key] = value
472
+
473
+ if logit_scale is not None:
474
+ new_sd["logit_scale"] = logit_scale
475
+
476
+ return new_sd
477
+
478
+
479
+ def save_stable_diffusion_checkpoint(
480
+ output_file,
481
+ text_encoder1,
482
+ text_encoder2,
483
+ unet,
484
+ epochs,
485
+ steps,
486
+ ckpt_info,
487
+ vae,
488
+ logit_scale,
489
+ metadata,
490
+ save_dtype=None,
491
+ ):
492
+ state_dict = {}
493
+
494
+ def update_sd(prefix, sd):
495
+ for k, v in sd.items():
496
+ key = prefix + k
497
+ if save_dtype is not None:
498
+ v = v.detach().clone().to("cpu").to(save_dtype)
499
+ state_dict[key] = v
500
+
501
+ # Convert the UNet model
502
+ update_sd("model.diffusion_model.", unet.state_dict())
503
+
504
+ # Convert the text encoders
505
+ update_sd("conditioner.embedders.0.transformer.", text_encoder1.state_dict())
506
+
507
+ text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(text_encoder2.state_dict(), logit_scale)
508
+ update_sd("conditioner.embedders.1.model.", text_enc2_dict)
509
+
510
+ # Convert the VAE
511
+ vae_dict = model_util.convert_vae_state_dict(vae.state_dict())
512
+ update_sd("first_stage_model.", vae_dict)
513
+
514
+ # Put together new checkpoint
515
+ key_count = len(state_dict.keys())
516
+ new_ckpt = {"state_dict": state_dict}
517
+
518
+ # epoch and global_step are sometimes not int
519
+ if ckpt_info is not None:
520
+ epochs += ckpt_info[0]
521
+ steps += ckpt_info[1]
522
+
523
+ new_ckpt["epoch"] = epochs
524
+ new_ckpt["global_step"] = steps
525
+
526
+ if model_util.is_safetensors(output_file):
527
+ save_file(state_dict, output_file, metadata)
528
+ else:
529
+ torch.save(new_ckpt, output_file)
530
+
531
+ return key_count
532
+
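# Editorial usage sketch (not part of the uploaded file); the output path and metadata are
# placeholders, and the components are assumed to come from load_models_from_sdxl_checkpoint
# above:
#
#     key_count = save_stable_diffusion_checkpoint(
#         "/path/to/merged.safetensors",
#         text_encoder1,
#         text_encoder2,
#         unet,
#         epochs=0,
#         steps=0,
#         ckpt_info=ckpt_info,
#         vae=vae,
#         logit_scale=logit_scale,
#         metadata={"note": "example"},
#         save_dtype=torch.float16,
#     )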
533
+
534
+ def save_diffusers_checkpoint(
535
+ output_dir, text_encoder1, text_encoder2, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False, save_dtype=None
536
+ ):
537
+ from diffusers import StableDiffusionXLPipeline
538
+
539
+ # convert U-Net
540
+ unet_sd = unet.state_dict()
541
+ du_unet_sd = convert_sdxl_unet_state_dict_to_diffusers(unet_sd)
542
+
543
+ diffusers_unet = UNet2DConditionModel(**DIFFUSERS_SDXL_UNET_CONFIG)
544
+ if save_dtype is not None:
545
+ diffusers_unet.to(save_dtype)
546
+ diffusers_unet.load_state_dict(du_unet_sd)
547
+
548
+ # create pipeline to save
549
+ if pretrained_model_name_or_path is None:
550
+ pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_SDXL
551
+
552
+ scheduler = EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
553
+ tokenizer1 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
554
+ tokenizer2 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer_2")
555
+ if vae is None:
556
+ vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
557
+
558
+ # prevent local path from being saved
559
+ def remove_name_or_path(model):
560
+ if hasattr(model, "config"):
561
+ model.config._name_or_path = None
562
563
+
564
+ remove_name_or_path(diffusers_unet)
565
+ remove_name_or_path(text_encoder1)
566
+ remove_name_or_path(text_encoder2)
567
+ remove_name_or_path(scheduler)
568
+ remove_name_or_path(tokenizer1)
569
+ remove_name_or_path(tokenizer2)
570
+ remove_name_or_path(vae)
571
+
572
+ pipeline = StableDiffusionXLPipeline(
573
+ unet=diffusers_unet,
574
+ text_encoder=text_encoder1,
575
+ text_encoder_2=text_encoder2,
576
+ vae=vae,
577
+ scheduler=scheduler,
578
+ tokenizer=tokenizer1,
579
+ tokenizer_2=tokenizer2,
580
+ )
581
+ if save_dtype is not None:
582
+ pipeline.to(None, save_dtype)
583
+ pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
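# Editorial usage sketch (not part of the uploaded file); the output directory is a
# placeholder, and passing pretrained_model_name_or_path=None falls back to
# DIFFUSERS_REF_MODEL_ID_SDXL for the scheduler and tokenizer configs:
#
#     save_diffusers_checkpoint(
#         "/path/to/sdxl_diffusers",
#         text_encoder1,
#         text_encoder2,
#         unet,
#         pretrained_model_name_or_path=None,
#         vae=vae,
#         use_safetensors=True,
#         save_dtype=torch.float16,
#     )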
library/sdxl_original_unet.py ADDED
@@ -0,0 +1,1286 @@
1
+ # U-Net for sd_xl_base, based on the Diffusers code
2
+ # the state dict keys follow the SDXL format
3
+
4
+ """
5
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
6
+ params:
7
+ adm_in_channels: 2816
8
+ num_classes: sequential
9
+ use_checkpoint: True
10
+ in_channels: 4
11
+ out_channels: 4
12
+ model_channels: 320
13
+ attention_resolutions: [4, 2]
14
+ num_res_blocks: 2
15
+ channel_mult: [1, 2, 4]
16
+ num_head_channels: 64
17
+ use_spatial_transformer: True
18
+ use_linear_in_transformer: True
19
+ transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
20
+ context_dim: 2048
21
+ spatial_transformer_attn_type: softmax-xformers
22
+ legacy: False
23
+ """
24
+
25
+ import math
26
+ from types import SimpleNamespace
27
+ from typing import Any, Optional
28
+ import torch
29
+ import torch.utils.checkpoint
30
+ from torch import nn
31
+ from torch.nn import functional as F
32
+ from einops import rearrange
33
+ from .utils import setup_logging
34
+
35
+ setup_logging()
36
+ import logging
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+ IN_CHANNELS: int = 4
41
+ OUT_CHANNELS: int = 4
42
+ ADM_IN_CHANNELS: int = 2816
43
+ CONTEXT_DIM: int = 2048
44
+ MODEL_CHANNELS: int = 320
45
+ TIME_EMBED_DIM = 320 * 4
46
+
47
+ USE_REENTRANT = True
48
+
49
+ # region memory efficient attention
50
+
51
+ # CrossAttention using FlashAttention
52
+ # based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
53
+ # LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
54
+
55
+ # constants
56
+
57
+ EPSILON = 1e-6
58
+
59
+ # helper functions
60
+
61
+
62
+ def exists(val):
63
+ return val is not None
64
+
65
+
66
+ def default(val, d):
67
+ return val if exists(val) else d
68
+
69
+
70
+ # flash attention forwards and backwards
71
+
72
+ # https://arxiv.org/abs/2205.14135
73
+
74
+
75
+ class FlashAttentionFunction(torch.autograd.Function):
76
+ @staticmethod
77
+ @torch.no_grad()
78
+ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
79
+ """Algorithm 2 in the paper"""
80
+
81
+ device = q.device
82
+ dtype = q.dtype
83
+ max_neg_value = -torch.finfo(q.dtype).max
84
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
85
+
86
+ o = torch.zeros_like(q)
87
+ all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
88
+ all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
89
+
90
+ scale = q.shape[-1] ** -0.5
91
+
92
+ if not exists(mask):
93
+ mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
94
+ else:
95
+ mask = rearrange(mask, "b n -> b 1 1 n")
96
+ mask = mask.split(q_bucket_size, dim=-1)
97
+
98
+ row_splits = zip(
99
+ q.split(q_bucket_size, dim=-2),
100
+ o.split(q_bucket_size, dim=-2),
101
+ mask,
102
+ all_row_sums.split(q_bucket_size, dim=-2),
103
+ all_row_maxes.split(q_bucket_size, dim=-2),
104
+ )
105
+
106
+ for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
107
+ q_start_index = ind * q_bucket_size - qk_len_diff
108
+
109
+ col_splits = zip(
110
+ k.split(k_bucket_size, dim=-2),
111
+ v.split(k_bucket_size, dim=-2),
112
+ )
113
+
114
+ for k_ind, (kc, vc) in enumerate(col_splits):
115
+ k_start_index = k_ind * k_bucket_size
116
+
117
+ attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
118
+
119
+ if exists(row_mask):
120
+ attn_weights.masked_fill_(~row_mask, max_neg_value)
121
+
122
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
123
+ causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
124
+ q_start_index - k_start_index + 1
125
+ )
126
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
127
+
128
+ block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
129
+ attn_weights -= block_row_maxes
130
+ exp_weights = torch.exp(attn_weights)
131
+
132
+ if exists(row_mask):
133
+ exp_weights.masked_fill_(~row_mask, 0.0)
134
+
135
+ block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
136
+
137
+ new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
138
+
139
+ exp_values = torch.einsum("... i j, ... j d -> ... i d", exp_weights, vc)
140
+
141
+ exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
142
+ exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
143
+
144
+ new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
145
+
146
+ oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
147
+
148
+ row_maxes.copy_(new_row_maxes)
149
+ row_sums.copy_(new_row_sums)
150
+
151
+ ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
152
+ ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
153
+
154
+ return o
155
+
156
+ @staticmethod
157
+ @torch.no_grad()
158
+ def backward(ctx, do):
159
+ """Algorithm 4 in the paper"""
160
+
161
+ causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
162
+ q, k, v, o, l, m = ctx.saved_tensors
163
+
164
+ device = q.device
165
+
166
+ max_neg_value = -torch.finfo(q.dtype).max
167
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
168
+
169
+ dq = torch.zeros_like(q)
170
+ dk = torch.zeros_like(k)
171
+ dv = torch.zeros_like(v)
172
+
173
+ row_splits = zip(
174
+ q.split(q_bucket_size, dim=-2),
175
+ o.split(q_bucket_size, dim=-2),
176
+ do.split(q_bucket_size, dim=-2),
177
+ mask,
178
+ l.split(q_bucket_size, dim=-2),
179
+ m.split(q_bucket_size, dim=-2),
180
+ dq.split(q_bucket_size, dim=-2),
181
+ )
182
+
183
+ for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
184
+ q_start_index = ind * q_bucket_size - qk_len_diff
185
+
186
+ col_splits = zip(
187
+ k.split(k_bucket_size, dim=-2),
188
+ v.split(k_bucket_size, dim=-2),
189
+ dk.split(k_bucket_size, dim=-2),
190
+ dv.split(k_bucket_size, dim=-2),
191
+ )
192
+
193
+ for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
194
+ k_start_index = k_ind * k_bucket_size
195
+
196
+ attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
197
+
198
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
199
+ causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
200
+ q_start_index - k_start_index + 1
201
+ )
202
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
203
+
204
+ exp_attn_weights = torch.exp(attn_weights - mc)
205
+
206
+ if exists(row_mask):
207
+ exp_attn_weights.masked_fill_(~row_mask, 0.0)
208
+
209
+ p = exp_attn_weights / lc
210
+
211
+ dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
212
+ dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)
213
+
214
+ D = (doc * oc).sum(dim=-1, keepdims=True)
215
+ ds = p * scale * (dp - D)
216
+
217
+ dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
218
+ dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)
219
+
220
+ dqc.add_(dq_chunk)
221
+ dkc.add_(dk_chunk)
222
+ dvc.add_(dv_chunk)
223
+
224
+ return dq, dk, dv, None, None, None, None
225
+
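# Editorial illustration (not part of the uploaded file): calling the chunked flash-attention
# autograd function directly on toy tensors shaped (batch, heads, tokens, head_dim), with no
# mask and non-causal attention.
q = torch.randn(1, 8, 64, 40)
k = torch.randn(1, 8, 64, 40)
v = torch.randn(1, 8, 64, 40)
out = FlashAttentionFunction.apply(q, k, v, None, False, 512, 1024)
assert out.shape == q.shape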
226
+
227
+ # endregion
228
+
229
+
230
+ def get_parameter_dtype(parameter: torch.nn.Module):
231
+ return next(parameter.parameters()).dtype
232
+
233
+
234
+ def get_parameter_device(parameter: torch.nn.Module):
235
+ return next(parameter.parameters()).device
236
+
237
+
238
+ def get_timestep_embedding(
239
+ timesteps: torch.Tensor,
240
+ embedding_dim: int,
241
+ downscale_freq_shift: float = 1,
242
+ scale: float = 1,
243
+ max_period: int = 10000,
244
+ ):
245
+ """
246
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
247
+
248
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
249
+ These may be fractional.
250
+ :param embedding_dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
252
+ """
253
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
254
+
255
+ half_dim = embedding_dim // 2
256
+ exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
257
+ exponent = exponent / (half_dim - downscale_freq_shift)
258
+
259
+ emb = torch.exp(exponent)
260
+ emb = timesteps[:, None].float() * emb[None, :]
261
+
262
+ # scale embeddings
263
+ emb = scale * emb
264
+
265
+ # concatenate cosine and sine embeddings (order flipped from the Diffusers original, since flip_sin_to_cos is always True here)
266
+ emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1)
267
+
268
+ # zero pad
269
+ if embedding_dim % 2 == 1:
270
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
271
+ return emb
272
+
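# Editorial illustration (not part of the uploaded file): sinusoidal embeddings for a batch of
# two timesteps at the U-Net's base width (MODEL_CHANNELS = 320).
ts = torch.tensor([0.0, 999.0])
assert get_timestep_embedding(ts, MODEL_CHANNELS).shape == (2, MODEL_CHANNELS)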
273
+
274
+ # Deep Shrink: this helper is duplicated here rather than shared, to minimize dependencies.
275
+ def resize_like(x, target, mode="bicubic", align_corners=False):
276
+ org_dtype = x.dtype
277
+ if org_dtype == torch.bfloat16:
278
+ x = x.to(torch.float32)
279
+
280
+ if x.shape[-2:] != target.shape[-2:]:
281
+ if mode == "nearest":
282
+ x = F.interpolate(x, size=target.shape[-2:], mode=mode)
283
+ else:
284
+ x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners)
285
+
286
+ if org_dtype == torch.bfloat16:
287
+ x = x.to(org_dtype)
288
+ return x
289
+
290
+
291
+ class GroupNorm32(nn.GroupNorm):
292
+ def forward(self, x):
293
+ if self.weight.dtype != torch.float32:
294
+ return super().forward(x)
295
+ return super().forward(x.float()).type(x.dtype)
296
+
297
+
298
+ class ResnetBlock2D(nn.Module):
299
+ def __init__(
300
+ self,
301
+ in_channels,
302
+ out_channels,
303
+ ):
304
+ super().__init__()
305
+ self.in_channels = in_channels
306
+ self.out_channels = out_channels
307
+
308
+ self.in_layers = nn.Sequential(
309
+ GroupNorm32(32, in_channels),
310
+ nn.SiLU(),
311
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
312
+ )
313
+
314
+ self.emb_layers = nn.Sequential(nn.SiLU(), nn.Linear(TIME_EMBED_DIM, out_channels))
315
+
316
+ self.out_layers = nn.Sequential(
317
+ GroupNorm32(32, out_channels),
318
+ nn.SiLU(),
319
+ nn.Identity(), # to make state_dict compatible with original model
320
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
321
+ )
322
+
323
+ if in_channels != out_channels:
324
+ self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
325
+ else:
326
+ self.skip_connection = nn.Identity()
327
+
328
+ self.gradient_checkpointing = False
329
+
330
+ def forward_body(self, x, emb):
331
+ h = self.in_layers(x)
332
+ emb_out = self.emb_layers(emb).type(h.dtype)
333
+ h = h + emb_out[:, :, None, None]
334
+ h = self.out_layers(h)
335
+ x = self.skip_connection(x)
336
+ return x + h
337
+
338
+ def forward(self, x, emb):
339
+ if self.training and self.gradient_checkpointing:
340
+ # logger.info("ResnetBlock2D: gradient_checkpointing")
341
+
342
+ def create_custom_forward(func):
343
+ def custom_forward(*inputs):
344
+ return func(*inputs)
345
+
346
+ return custom_forward
347
+
348
+ x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x, emb, use_reentrant=USE_REENTRANT)
349
+ else:
350
+ x = self.forward_body(x, emb)
351
+
352
+ return x
353
+
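# Editorial illustration (not part of the uploaded file): a ResnetBlock2D changes the channel
# count but keeps the spatial size, and consumes the shared time embedding of width
# TIME_EMBED_DIM (1280).
res_block = ResnetBlock2D(in_channels=320, out_channels=640)
x = torch.randn(1, 320, 64, 64)
t_emb = torch.randn(1, TIME_EMBED_DIM)
assert res_block(x, t_emb).shape == (1, 640, 64, 64)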
354
+
355
+ class Downsample2D(nn.Module):
356
+ def __init__(self, channels, out_channels):
357
+ super().__init__()
358
+
359
+ self.channels = channels
360
+ self.out_channels = out_channels
361
+
362
+ self.op = nn.Conv2d(self.channels, self.out_channels, 3, stride=2, padding=1)
363
+
364
+ self.gradient_checkpointing = False
365
+
366
+ def forward_body(self, hidden_states):
367
+ assert hidden_states.shape[1] == self.channels
368
+ hidden_states = self.op(hidden_states)
369
+
370
+ return hidden_states
371
+
372
+ def forward(self, hidden_states):
373
+ if self.training and self.gradient_checkpointing:
374
+ # logger.info("Downsample2D: gradient_checkpointing")
375
+
376
+ def create_custom_forward(func):
377
+ def custom_forward(*inputs):
378
+ return func(*inputs)
379
+
380
+ return custom_forward
381
+
382
+ hidden_states = torch.utils.checkpoint.checkpoint(
383
+ create_custom_forward(self.forward_body), hidden_states, use_reentrant=USE_REENTRANT
384
+ )
385
+ else:
386
+ hidden_states = self.forward_body(hidden_states)
387
+
388
+ return hidden_states
389
+
390
+
391
+ class CrossAttention(nn.Module):
392
+ def __init__(
393
+ self,
394
+ query_dim: int,
395
+ cross_attention_dim: Optional[int] = None,
396
+ heads: int = 8,
397
+ dim_head: int = 64,
398
+ upcast_attention: bool = False,
399
+ ):
400
+ super().__init__()
401
+ inner_dim = dim_head * heads
402
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
403
+ self.upcast_attention = upcast_attention
404
+
405
+ self.scale = dim_head**-0.5
406
+ self.heads = heads
407
+
408
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
409
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False)
410
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False)
411
+
412
+ self.to_out = nn.ModuleList([])
413
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
414
+ # no dropout here
415
+
416
+ self.use_memory_efficient_attention_xformers = False
417
+ self.use_memory_efficient_attention_mem_eff = False
418
+ self.use_sdpa = False
419
+
420
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
421
+ self.use_memory_efficient_attention_xformers = xformers
422
+ self.use_memory_efficient_attention_mem_eff = mem_eff
423
+
424
+ def set_use_sdpa(self, sdpa):
425
+ self.use_sdpa = sdpa
426
+
427
+ def reshape_heads_to_batch_dim(self, tensor):
428
+ batch_size, seq_len, dim = tensor.shape
429
+ head_size = self.heads
430
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
431
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
432
+ return tensor
433
+
434
+ def reshape_batch_dim_to_heads(self, tensor):
435
+ batch_size, seq_len, dim = tensor.shape
436
+ head_size = self.heads
437
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
438
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
439
+ return tensor
440
+
441
+ def forward(self, hidden_states, context=None, mask=None):
442
+ if self.use_memory_efficient_attention_xformers:
443
+ return self.forward_memory_efficient_xformers(hidden_states, context, mask)
444
+ if self.use_memory_efficient_attention_mem_eff:
445
+ return self.forward_memory_efficient_mem_eff(hidden_states, context, mask)
446
+ if self.use_sdpa:
447
+ return self.forward_sdpa(hidden_states, context, mask)
448
+
449
+ query = self.to_q(hidden_states)
450
+ context = context if context is not None else hidden_states
451
+ key = self.to_k(context)
452
+ value = self.to_v(context)
453
+
454
+ query = self.reshape_heads_to_batch_dim(query)
455
+ key = self.reshape_heads_to_batch_dim(key)
456
+ value = self.reshape_heads_to_batch_dim(value)
457
+
458
+ hidden_states = self._attention(query, key, value)
459
+
460
+ # linear proj
461
+ hidden_states = self.to_out[0](hidden_states)
462
+ # hidden_states = self.to_out[1](hidden_states) # no dropout
463
+ return hidden_states
464
+
465
+ def _attention(self, query, key, value):
466
+ if self.upcast_attention:
467
+ query = query.float()
468
+ key = key.float()
469
+
470
+ attention_scores = torch.baddbmm(
471
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
472
+ query,
473
+ key.transpose(-1, -2),
474
+ beta=0,
475
+ alpha=self.scale,
476
+ )
477
+ attention_probs = attention_scores.softmax(dim=-1)
478
+
479
+ # cast back to the original dtype
480
+ attention_probs = attention_probs.to(value.dtype)
481
+
482
+ # compute attention output
483
+ hidden_states = torch.bmm(attention_probs, value)
484
+
485
+ # reshape hidden_states
486
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
487
+ return hidden_states
488
+
489
+ # TODO support Hypernetworks
490
+ def forward_memory_efficient_xformers(self, x, context=None, mask=None):
491
+ import xformers.ops
492
+
493
+ h = self.heads
494
+ q_in = self.to_q(x)
495
+ context = context if context is not None else x
496
+ context = context.to(x.dtype)
497
+ k_in = self.to_k(context)
498
+ v_in = self.to_v(context)
499
+
500
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
501
+ del q_in, k_in, v_in
502
+
503
+ q = q.contiguous()
504
+ k = k.contiguous()
505
+ v = v.contiguous()
506
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
507
+ del q, k, v
508
+
509
+ out = rearrange(out, "b n h d -> b n (h d)", h=h)
510
+
511
+ out = self.to_out[0](out)
512
+ return out
513
+
514
+ def forward_memory_efficient_mem_eff(self, x, context=None, mask=None):
515
+ flash_func = FlashAttentionFunction
516
+
517
+ q_bucket_size = 512
518
+ k_bucket_size = 1024
519
+
520
+ h = self.heads
521
+ q = self.to_q(x)
522
+ context = context if context is not None else x
523
+ context = context.to(x.dtype)
524
+ k = self.to_k(context)
525
+ v = self.to_v(context)
526
+ del context, x
527
+
528
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
529
+
530
+ out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
531
+
532
+ out = rearrange(out, "b h n d -> b n (h d)")
533
+
534
+ out = self.to_out[0](out)
535
+ return out
536
+
537
+ def forward_sdpa(self, x, context=None, mask=None):
538
+ h = self.heads
539
+ q_in = self.to_q(x)
540
+ context = context if context is not None else x
541
+ context = context.to(x.dtype)
542
+ k_in = self.to_k(context)
543
+ v_in = self.to_v(context)
544
+
545
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in))
546
+ del q_in, k_in, v_in
547
+
548
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
549
+
550
+ out = rearrange(out, "b h n d -> b n (h d)", h=h)
551
+
552
+ out = self.to_out[0](out)
553
+ return out
554
+
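# Editorial illustration (not part of the uploaded file): cross-attention over CLIP text
# context (2048-dim for SDXL) using PyTorch's scaled_dot_product_attention backend.
attn = CrossAttention(query_dim=640, cross_attention_dim=2048, heads=10, dim_head=64)
attn.set_use_sdpa(True)
latent_tokens = torch.randn(2, 256, 640)  # (batch, spatial tokens, query_dim)
text_context = torch.randn(2, 77, 2048)   # (batch, text tokens, cross_attention_dim)
assert attn(latent_tokens, context=text_context).shape == latent_tokens.shape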
555
+
556
+ # feedforward
557
+ class GEGLU(nn.Module):
558
+ r"""
559
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
560
+
561
+ Parameters:
562
+ dim_in (`int`): The number of channels in the input.
563
+ dim_out (`int`): The number of channels in the output.
564
+ """
565
+
566
+ def __init__(self, dim_in: int, dim_out: int):
567
+ super().__init__()
568
+ self.proj = nn.Linear(dim_in, dim_out * 2)
569
+
570
+ def gelu(self, gate):
571
+ if gate.device.type != "mps":
572
+ return F.gelu(gate)
573
+ # mps: gelu is not implemented for float16
574
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
575
+
576
+ def forward(self, hidden_states):
577
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
578
+ return hidden_states * self.gelu(gate)
579
+
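# Editorial illustration (not part of the uploaded file): GEGLU projects to twice the output
# width and gates one half with GELU of the other; FeedForward below uses it with a 4x
# expansion.
geglu = GEGLU(dim_in=320, dim_out=1280)
assert geglu(torch.randn(2, 16, 320)).shape == (2, 16, 1280)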
580
+
581
+ class FeedForward(nn.Module):
582
+ def __init__(
583
+ self,
584
+ dim: int,
585
+ ):
586
+ super().__init__()
587
+ inner_dim = int(dim * 4) # mult is always 4
588
+
589
+ self.net = nn.ModuleList([])
590
+ # project in
591
+ self.net.append(GEGLU(dim, inner_dim))
592
+ # project dropout
593
+ self.net.append(nn.Identity()) # dummy standing in for nn.Dropout(0), so the module index of the output Linear (net.2) matches Diffusers
594
+ # project out
595
+ self.net.append(nn.Linear(inner_dim, dim))
596
+
597
+ def forward(self, hidden_states):
598
+ for module in self.net:
599
+ hidden_states = module(hidden_states)
600
+ return hidden_states
601
+
602
+
603
+ class BasicTransformerBlock(nn.Module):
604
+ def __init__(
605
+ self, dim: int, num_attention_heads: int, attention_head_dim: int, cross_attention_dim: int, upcast_attention: bool = False
606
+ ):
607
+ super().__init__()
608
+
609
+ self.gradient_checkpointing = False
610
+
611
+ # 1. Self-Attn
612
+ self.attn1 = CrossAttention(
613
+ query_dim=dim,
614
+ cross_attention_dim=None,
615
+ heads=num_attention_heads,
616
+ dim_head=attention_head_dim,
617
+ upcast_attention=upcast_attention,
618
+ )
619
+ self.ff = FeedForward(dim)
620
+
621
+ # 2. Cross-Attn
622
+ self.attn2 = CrossAttention(
623
+ query_dim=dim,
624
+ cross_attention_dim=cross_attention_dim,
625
+ heads=num_attention_heads,
626
+ dim_head=attention_head_dim,
627
+ upcast_attention=upcast_attention,
628
+ )
629
+
630
+ self.norm1 = nn.LayerNorm(dim)
631
+ self.norm2 = nn.LayerNorm(dim)
632
+
633
+ # 3. Feed-forward
634
+ self.norm3 = nn.LayerNorm(dim)
635
+
636
+ def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool):
637
+ self.attn1.set_use_memory_efficient_attention(xformers, mem_eff)
638
+ self.attn2.set_use_memory_efficient_attention(xformers, mem_eff)
639
+
640
+ def set_use_sdpa(self, sdpa: bool):
641
+ self.attn1.set_use_sdpa(sdpa)
642
+ self.attn2.set_use_sdpa(sdpa)
643
+
644
+ def forward_body(self, hidden_states, context=None, timestep=None):
645
+ # 1. Self-Attention
646
+ norm_hidden_states = self.norm1(hidden_states)
647
+
648
+ hidden_states = self.attn1(norm_hidden_states) + hidden_states
649
+
650
+ # 2. Cross-Attention
651
+ norm_hidden_states = self.norm2(hidden_states)
652
+ hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
653
+
654
+ # 3. Feed-forward
655
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
656
+
657
+ return hidden_states
658
+
659
+ def forward(self, hidden_states, context=None, timestep=None):
660
+ if self.training and self.gradient_checkpointing:
661
+ # logger.info("BasicTransformerBlock: checkpointing")
662
+
663
+ def create_custom_forward(func):
664
+ def custom_forward(*inputs):
665
+ return func(*inputs)
666
+
667
+ return custom_forward
668
+
669
+ output = torch.utils.checkpoint.checkpoint(
670
+ create_custom_forward(self.forward_body), hidden_states, context, timestep, use_reentrant=USE_REENTRANT
671
+ )
672
+ else:
673
+ output = self.forward_body(hidden_states, context, timestep)
674
+
675
+ return output
676
+
677
+
678
+ class Transformer2DModel(nn.Module):
679
+ def __init__(
680
+ self,
681
+ num_attention_heads: int = 16,
682
+ attention_head_dim: int = 88,
683
+ in_channels: Optional[int] = None,
684
+ cross_attention_dim: Optional[int] = None,
685
+ use_linear_projection: bool = False,
686
+ upcast_attention: bool = False,
687
+ num_transformer_layers: int = 1,
688
+ ):
689
+ super().__init__()
690
+ self.in_channels = in_channels
691
+ self.num_attention_heads = num_attention_heads
692
+ self.attention_head_dim = attention_head_dim
693
+ inner_dim = num_attention_heads * attention_head_dim
694
+ self.use_linear_projection = use_linear_projection
695
+
696
+ self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
697
+ # self.norm = GroupNorm32(32, in_channels, eps=1e-6, affine=True)
698
+
699
+ if use_linear_projection:
700
+ self.proj_in = nn.Linear(in_channels, inner_dim)
701
+ else:
702
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
703
+
704
+ blocks = []
705
+ for _ in range(num_transformer_layers):
706
+ blocks.append(
707
+ BasicTransformerBlock(
708
+ inner_dim,
709
+ num_attention_heads,
710
+ attention_head_dim,
711
+ cross_attention_dim=cross_attention_dim,
712
+ upcast_attention=upcast_attention,
713
+ )
714
+ )
715
+
716
+ self.transformer_blocks = nn.ModuleList(blocks)
717
+
718
+ if use_linear_projection:
719
+ self.proj_out = nn.Linear(in_channels, inner_dim)
720
+ else:
721
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
722
+
723
+ self.gradient_checkpointing = False
724
+
725
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
726
+ for transformer in self.transformer_blocks:
727
+ transformer.set_use_memory_efficient_attention(xformers, mem_eff)
728
+
729
+ def set_use_sdpa(self, sdpa):
730
+ for transformer in self.transformer_blocks:
731
+ transformer.set_use_sdpa(sdpa)
732
+
733
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None):
734
+ # 1. Input
735
+ batch, _, height, weight = hidden_states.shape
736
+ residual = hidden_states
737
+
738
+ hidden_states = self.norm(hidden_states)
739
+ if not self.use_linear_projection:
740
+ hidden_states = self.proj_in(hidden_states)
741
+ inner_dim = hidden_states.shape[1]
742
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
743
+ else:
744
+ inner_dim = hidden_states.shape[1]
745
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
746
+ hidden_states = self.proj_in(hidden_states)
747
+
748
+ # 2. Blocks
749
+ for block in self.transformer_blocks:
750
+ hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
751
+
752
+ # 3. Output
753
+ if not self.use_linear_projection:
754
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
755
+ hidden_states = self.proj_out(hidden_states)
756
+ else:
757
+ hidden_states = self.proj_out(hidden_states)
758
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
759
+
760
+ output = hidden_states + residual
761
+
762
+ return output
763
+
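# Editorial illustration (not part of the uploaded file): a transformer stage preserves the
# spatial shape of its input; the configuration below matches the "level 1" blocks built in
# SdxlUNet2DConditionModel (640 channels, 10 heads of 64, 2 transformer layers).
stage = Transformer2DModel(
    num_attention_heads=10,
    attention_head_dim=64,
    in_channels=640,
    cross_attention_dim=2048,
    use_linear_projection=True,
    num_transformer_layers=2,
)
latents = torch.randn(1, 640, 16, 16)
text_ctx = torch.randn(1, 77, 2048)
assert stage(latents, encoder_hidden_states=text_ctx).shape == latents.shape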
764
+
765
+ class Upsample2D(nn.Module):
766
+ def __init__(self, channels, out_channels):
767
+ super().__init__()
768
+ self.channels = channels
769
+ self.out_channels = out_channels
770
+ self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
771
+
772
+ self.gradient_checkpointing = False
773
+
774
+ def forward_body(self, hidden_states, output_size=None):
775
+ assert hidden_states.shape[1] == self.channels
776
+
777
+ # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
778
+ # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
779
+ # https://github.com/pytorch/pytorch/issues/86679
780
+ dtype = hidden_states.dtype
781
+ if dtype == torch.bfloat16:
782
+ hidden_states = hidden_states.to(torch.float32)
783
+
784
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
785
+ if hidden_states.shape[0] >= 64:
786
+ hidden_states = hidden_states.contiguous()
787
+
788
+ # if `output_size` is passed we force the interpolation output size and do not make use of `scale_factor=2`
789
+ if output_size is None:
790
+ hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
791
+ else:
792
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
793
+
794
+ # If the input is bfloat16, we cast back to bfloat16
795
+ if dtype == torch.bfloat16:
796
+ hidden_states = hidden_states.to(dtype)
797
+
798
+ hidden_states = self.conv(hidden_states)
799
+
800
+ return hidden_states
801
+
802
+ def forward(self, hidden_states, output_size=None):
803
+ if self.training and self.gradient_checkpointing:
804
+ # logger.info("Upsample2D: gradient_checkpointing")
805
+
806
+ def create_custom_forward(func):
807
+ def custom_forward(*inputs):
808
+ return func(*inputs)
809
+
810
+ return custom_forward
811
+
812
+ hidden_states = torch.utils.checkpoint.checkpoint(
813
+ create_custom_forward(self.forward_body), hidden_states, output_size, use_reentrant=USE_REENTRANT
814
+ )
815
+ else:
816
+ hidden_states = self.forward_body(hidden_states, output_size)
817
+
818
+ return hidden_states
819
+
820
+
821
+ class SdxlUNet2DConditionModel(nn.Module):
822
+ _supports_gradient_checkpointing = True
823
+
824
+ def __init__(
825
+ self,
826
+ **kwargs,
827
+ ):
828
+ super().__init__()
829
+
830
+ self.in_channels = IN_CHANNELS
831
+ self.out_channels = OUT_CHANNELS
832
+ self.model_channels = MODEL_CHANNELS
833
+ self.time_embed_dim = TIME_EMBED_DIM
834
+ self.adm_in_channels = ADM_IN_CHANNELS
835
+
836
+ self.gradient_checkpointing = False
837
+ # self.sample_size = sample_size
838
+
839
+ # time embedding
840
+ self.time_embed = nn.Sequential(
841
+ nn.Linear(self.model_channels, self.time_embed_dim),
842
+ nn.SiLU(),
843
+ nn.Linear(self.time_embed_dim, self.time_embed_dim),
844
+ )
845
+
846
+ # label embedding
847
+ self.label_emb = nn.Sequential(
848
+ nn.Sequential(
849
+ nn.Linear(self.adm_in_channels, self.time_embed_dim),
850
+ nn.SiLU(),
851
+ nn.Linear(self.time_embed_dim, self.time_embed_dim),
852
+ )
853
+ )
854
+
855
+ # input
856
+ self.input_blocks = nn.ModuleList(
857
+ [
858
+ nn.Sequential(
859
+ nn.Conv2d(self.in_channels, self.model_channels, kernel_size=3, padding=(1, 1)),
860
+ )
861
+ ]
862
+ )
863
+
864
+ # level 0
865
+ for i in range(2):
866
+ layers = [
867
+ ResnetBlock2D(
868
+ in_channels=1 * self.model_channels,
869
+ out_channels=1 * self.model_channels,
870
+ ),
871
+ ]
872
+ self.input_blocks.append(nn.ModuleList(layers))
873
+
874
+ self.input_blocks.append(
875
+ nn.Sequential(
876
+ Downsample2D(
877
+ channels=1 * self.model_channels,
878
+ out_channels=1 * self.model_channels,
879
+ ),
880
+ )
881
+ )
882
+
883
+ # level 1
884
+ for i in range(2):
885
+ layers = [
886
+ ResnetBlock2D(
887
+ in_channels=(1 if i == 0 else 2) * self.model_channels,
888
+ out_channels=2 * self.model_channels,
889
+ ),
890
+ Transformer2DModel(
891
+ num_attention_heads=2 * self.model_channels // 64,
892
+ attention_head_dim=64,
893
+ in_channels=2 * self.model_channels,
894
+ num_transformer_layers=2,
895
+ use_linear_projection=True,
896
+ cross_attention_dim=2048,
897
+ ),
898
+ ]
899
+ self.input_blocks.append(nn.ModuleList(layers))
900
+
901
+ self.input_blocks.append(
902
+ nn.Sequential(
903
+ Downsample2D(
904
+ channels=2 * self.model_channels,
905
+ out_channels=2 * self.model_channels,
906
+ ),
907
+ )
908
+ )
909
+
910
+ # level 2
911
+ for i in range(2):
912
+ layers = [
913
+ ResnetBlock2D(
914
+ in_channels=(2 if i == 0 else 4) * self.model_channels,
915
+ out_channels=4 * self.model_channels,
916
+ ),
917
+ Transformer2DModel(
918
+ num_attention_heads=4 * self.model_channels // 64,
919
+ attention_head_dim=64,
920
+ in_channels=4 * self.model_channels,
921
+ num_transformer_layers=10,
922
+ use_linear_projection=True,
923
+ cross_attention_dim=2048,
924
+ ),
925
+ ]
926
+ self.input_blocks.append(nn.ModuleList(layers))
927
+
928
+ # mid
929
+ self.middle_block = nn.ModuleList(
930
+ [
931
+ ResnetBlock2D(
932
+ in_channels=4 * self.model_channels,
933
+ out_channels=4 * self.model_channels,
934
+ ),
935
+ Transformer2DModel(
936
+ num_attention_heads=4 * self.model_channels // 64,
937
+ attention_head_dim=64,
938
+ in_channels=4 * self.model_channels,
939
+ num_transformer_layers=10,
940
+ use_linear_projection=True,
941
+ cross_attention_dim=2048,
942
+ ),
943
+ ResnetBlock2D(
944
+ in_channels=4 * self.model_channels,
945
+ out_channels=4 * self.model_channels,
946
+ ),
947
+ ]
948
+ )
949
+
950
+ # output
951
+ self.output_blocks = nn.ModuleList([])
952
+
953
+ # level 2
954
+ for i in range(3):
955
+ layers = [
956
+ ResnetBlock2D(
957
+ in_channels=4 * self.model_channels + (4 if i <= 1 else 2) * self.model_channels,
958
+ out_channels=4 * self.model_channels,
959
+ ),
960
+ Transformer2DModel(
961
+ num_attention_heads=4 * self.model_channels // 64,
962
+ attention_head_dim=64,
963
+ in_channels=4 * self.model_channels,
964
+ num_transformer_layers=10,
965
+ use_linear_projection=True,
966
+ cross_attention_dim=2048,
967
+ ),
968
+ ]
969
+ if i == 2:
970
+ layers.append(
971
+ Upsample2D(
972
+ channels=4 * self.model_channels,
973
+ out_channels=4 * self.model_channels,
974
+ )
975
+ )
976
+
977
+ self.output_blocks.append(nn.ModuleList(layers))
978
+
979
+ # level 1
980
+ for i in range(3):
981
+ layers = [
982
+ ResnetBlock2D(
983
+ in_channels=2 * self.model_channels + (4 if i == 0 else (2 if i == 1 else 1)) * self.model_channels,
984
+ out_channels=2 * self.model_channels,
985
+ ),
986
+ Transformer2DModel(
987
+ num_attention_heads=2 * self.model_channels // 64,
988
+ attention_head_dim=64,
989
+ in_channels=2 * self.model_channels,
990
+ num_transformer_layers=2,
991
+ use_linear_projection=True,
992
+ cross_attention_dim=2048,
993
+ ),
994
+ ]
995
+ if i == 2:
996
+ layers.append(
997
+ Upsample2D(
998
+ channels=2 * self.model_channels,
999
+ out_channels=2 * self.model_channels,
1000
+ )
1001
+ )
1002
+
1003
+ self.output_blocks.append(nn.ModuleList(layers))
1004
+
1005
+ # level 0
1006
+ for i in range(3):
1007
+ layers = [
1008
+ ResnetBlock2D(
1009
+ in_channels=1 * self.model_channels + (2 if i == 0 else 1) * self.model_channels,
1010
+ out_channels=1 * self.model_channels,
1011
+ ),
1012
+ ]
1013
+
1014
+ self.output_blocks.append(nn.ModuleList(layers))
1015
+
1016
+ # output
1017
+ self.out = nn.ModuleList(
1018
+ [GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)]
1019
+ )
1020
+
1021
+ # region diffusers compatibility
1022
+ def prepare_config(self):
1023
+ self.config = SimpleNamespace()
1024
+
1025
+ @property
1026
+ def dtype(self) -> torch.dtype:
1027
+ # `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
1028
+ return get_parameter_dtype(self)
1029
+
1030
+ @property
1031
+ def device(self) -> torch.device:
1032
+ # `torch.device`: The device on which the module is (assuming that all the module parameters are on the same device).
1033
+ return get_parameter_device(self)
1034
+
1035
+ def set_attention_slice(self, slice_size):
1036
+ raise NotImplementedError("Attention slicing is not supported for this model.")
1037
+
1038
+ def is_gradient_checkpointing(self) -> bool:
1039
+ return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
1040
+
1041
+ def enable_gradient_checkpointing(self):
1042
+ self.gradient_checkpointing = True
1043
+ self.set_gradient_checkpointing(value=True)
1044
+
1045
+ def disable_gradient_checkpointing(self):
1046
+ self.gradient_checkpointing = False
1047
+ self.set_gradient_checkpointing(value=False)
1048
+
1049
+ def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> None:
1050
+ blocks = self.input_blocks + [self.middle_block] + self.output_blocks
1051
+ for block in blocks:
1052
+ for module in block:
1053
+ if hasattr(module, "set_use_memory_efficient_attention"):
1054
+ # logger.info(module.__class__.__name__)
1055
+ module.set_use_memory_efficient_attention(xformers, mem_eff)
1056
+
1057
+ def set_use_sdpa(self, sdpa: bool) -> None:
1058
+ blocks = self.input_blocks + [self.middle_block] + self.output_blocks
1059
+ for block in blocks:
1060
+ for module in block:
1061
+ if hasattr(module, "set_use_sdpa"):
1062
+ module.set_use_sdpa(sdpa)
1063
+
1064
+ def set_gradient_checkpointing(self, value=False):
1065
+ blocks = self.input_blocks + [self.middle_block] + self.output_blocks
1066
+ for block in blocks:
1067
+ for module in block.modules():
1068
+ if hasattr(module, "gradient_checkpointing"):
1069
+ # logger.info(f"{module.__class__.__name__} {module.gradient_checkpointing} -> {value}")
1070
+ module.gradient_checkpointing = value
1071
+
1072
+ # endregion
1073
+
1074
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
1075
+ # broadcast timesteps to batch dimension
1076
+ timesteps = timesteps.expand(x.shape[0])
1077
+
1078
+ hs = []
1079
+ t_emb = get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0) # , repeat_only=False)
1080
+ t_emb = t_emb.to(x.dtype)
1081
+ emb = self.time_embed(t_emb)
1082
+
1083
+ assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
1084
+ assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
1085
+ # assert x.dtype == self.dtype
1086
+ emb = emb + self.label_emb(y)
1087
+
1088
+ def call_module(module, h, emb, context):
1089
+ x = h
1090
+ for layer in module:
1091
+ # logger.info(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
1092
+ if isinstance(layer, ResnetBlock2D):
1093
+ x = layer(x, emb)
1094
+ elif isinstance(layer, Transformer2DModel):
1095
+ x = layer(x, context)
1096
+ else:
1097
+ x = layer(x)
1098
+ return x
1099
+
1100
+ # h = x.type(self.dtype)
1101
+ h = x
1102
+
1103
+ for module in self.input_blocks:
1104
+ h = call_module(module, h, emb, context)
1105
+ hs.append(h)
1106
+
1107
+ h = call_module(self.middle_block, h, emb, context)
1108
+
1109
+ for module in self.output_blocks:
1110
+ h = torch.cat([h, hs.pop()], dim=1)
1111
+ h = call_module(module, h, emb, context)
1112
+
1113
+ h = h.type(x.dtype)
1114
+ h = call_module(self.out, h, emb, context)
1115
+
1116
+ return h
1117
+
1118
+
1119
+ class InferSdxlUNet2DConditionModel:
1120
+ def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs):
1121
+ self.delegate = original_unet
1122
+
1123
+ # override the original model's forward method, because forward is not called via `__call__`;
1124
+ # overriding `__call__` alone is not enough, since nn.Module.forward receives special handling
1125
+ self.delegate.forward = self.forward
1126
+
1127
+ # Deep Shrink
1128
+ self.ds_depth_1 = None
1129
+ self.ds_depth_2 = None
1130
+ self.ds_timesteps_1 = None
1131
+ self.ds_timesteps_2 = None
1132
+ self.ds_ratio = None
1133
+
1134
+ # call original model's methods
1135
+ def __getattr__(self, name):
1136
+ return getattr(self.delegate, name)
1137
+
1138
+ def __call__(self, *args, **kwargs):
1139
+ return self.delegate(*args, **kwargs)
1140
+
1141
+ def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
1142
+ if ds_depth_1 is None:
1143
+ logger.info("Deep Shrink is disabled.")
1144
+ self.ds_depth_1 = None
1145
+ self.ds_timesteps_1 = None
1146
+ self.ds_depth_2 = None
1147
+ self.ds_timesteps_2 = None
1148
+ self.ds_ratio = None
1149
+ else:
1150
+ logger.info(
1151
+ f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
1152
+ )
1153
+ self.ds_depth_1 = ds_depth_1
1154
+ self.ds_timesteps_1 = ds_timesteps_1
1155
+ self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
1156
+ self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
1157
+ self.ds_ratio = ds_ratio
1158
+
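As a usage sketch (illustrative, not part of the diff): Deep Shrink is enabled on the inference wrapper, which then downsamples the hidden states by `ds_ratio` at `ds_depth_1` while the current timestep is at least `ds_timesteps_1`.

# hypothetical example values; the method signature and defaults come from set_deep_shrink above
unet = SdxlUNet2DConditionModel()
infer_unet = InferSdxlUNet2DConditionModel(unet)
infer_unet.set_deep_shrink(ds_depth_1=3, ds_timesteps_1=650, ds_ratio=0.5)
# infer_unet.set_deep_shrink(None)  # disables Deep Shrink again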
1159
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
1160
+ r"""
1161
+ current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink.
1162
+ """
1163
+ _self = self.delegate
1164
+
1165
+ # broadcast timesteps to batch dimension
1166
+ timesteps = timesteps.expand(x.shape[0])
1167
+
1168
+ hs = []
1169
+ t_emb = get_timestep_embedding(timesteps, _self.model_channels, downscale_freq_shift=0) # , repeat_only=False)
1170
+ t_emb = t_emb.to(x.dtype)
1171
+ emb = _self.time_embed(t_emb)
1172
+
1173
+ assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
1174
+ assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
1175
+ # assert x.dtype == _self.dtype
1176
+ emb = emb + _self.label_emb(y)
1177
+
1178
+ def call_module(module, h, emb, context):
1179
+ x = h
1180
+ for layer in module:
1181
+ # print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
1182
+ if isinstance(layer, ResnetBlock2D):
1183
+ x = layer(x, emb)
1184
+ elif isinstance(layer, Transformer2DModel):
1185
+ x = layer(x, context)
1186
+ else:
1187
+ x = layer(x)
1188
+ return x
1189
+
1190
+ # h = x.type(self.dtype)
1191
+ h = x
1192
+
1193
+ for depth, module in enumerate(_self.input_blocks):
1194
+ # Deep Shrink
1195
+ if self.ds_depth_1 is not None:
1196
+ if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
1197
+ self.ds_depth_2 is not None
1198
+ and depth == self.ds_depth_2
1199
+ and timesteps[0] < self.ds_timesteps_1
1200
+ and timesteps[0] >= self.ds_timesteps_2
1201
+ ):
1202
+ # print("downsample", h.shape, self.ds_ratio)
1203
+ org_dtype = h.dtype
1204
+ if org_dtype == torch.bfloat16:
1205
+ h = h.to(torch.float32)
1206
+ h = F.interpolate(h, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
1207
+
1208
+ h = call_module(module, h, emb, context)
1209
+ hs.append(h)
1210
+
1211
+ h = call_module(_self.middle_block, h, emb, context)
1212
+
1213
+ for module in _self.output_blocks:
1214
+ # Deep Shrink
1215
+ if self.ds_depth_1 is not None:
1216
+ if hs[-1].shape[-2:] != h.shape[-2:]:
1217
+ # print("upsample", h.shape, hs[-1].shape)
1218
+ h = resize_like(h, hs[-1])
1219
+
1220
+ h = torch.cat([h, hs.pop()], dim=1)
1221
+ h = call_module(module, h, emb, context)
1222
+
1223
+ # Deep Shrink: in case of depth 0
1224
+ if self.ds_depth_1 == 0 and h.shape[-2:] != x.shape[-2:]:
1225
+ # print("upsample", h.shape, x.shape)
1226
+ h = resize_like(h, x)
1227
+
1228
+ h = h.type(x.dtype)
1229
+ h = call_module(_self.out, h, emb, context)
1230
+
1231
+ return h
1232
+
1233
+
1234
+ if __name__ == "__main__":
1235
+ import time
1236
+
1237
+ logger.info("create unet")
1238
+ unet = SdxlUNet2DConditionModel()
1239
+
1240
+ unet.to("cuda")
1241
+ unet.set_use_memory_efficient_attention(True, False)
1242
+ unet.set_gradient_checkpointing(True)
1243
+ unet.train()
1244
+
1245
+ # pseudo training loop for checking memory usage
1246
+ logger.info("preparing optimizer")
1247
+
1248
+ # optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working
1249
+
1250
+ # import bitsandbytes
1251
+ # optimizer = bitsandbytes.adam.Adam8bit(unet.parameters(), lr=1e-3) # not working
1252
+ # optimizer = bitsandbytes.optim.RMSprop8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2
1253
+ # optimizer=bitsandbytes.optim.Adagrad8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2
1254
+
1255
+ import transformers
1256
+
1257
+ optimizer = transformers.optimization.Adafactor(unet.parameters(), relative_step=True) # working at 22.2GB with torch2
1258
+
1259
+ scaler = torch.cuda.amp.GradScaler(enabled=True)
1260
+
1261
+ logger.info("start training")
1262
+ steps = 10
1263
+ batch_size = 1
1264
+
1265
+ for step in range(steps):
1266
+ logger.info(f"step {step}")
1267
+ if step == 1:
1268
+ time_start = time.perf_counter()
1269
+
1270
+ x = torch.randn(batch_size, 4, 128, 128).cuda() # 1024x1024
1271
+ t = torch.randint(low=0, high=10, size=(batch_size,), device="cuda")
1272
+ ctx = torch.randn(batch_size, 77, 2048).cuda()
1273
+ y = torch.randn(batch_size, ADM_IN_CHANNELS).cuda()
1274
+
1275
+ with torch.cuda.amp.autocast(enabled=True):
1276
+ output = unet(x, t, ctx, y)
1277
+ target = torch.randn_like(output)
1278
+ loss = torch.nn.functional.mse_loss(output, target)
1279
+
1280
+ scaler.scale(loss).backward()
1281
+ scaler.step(optimizer)
1282
+ scaler.update()
1283
+ optimizer.zero_grad(set_to_none=True)
1284
+
1285
+ time_end = time.perf_counter()
1286
+ logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")
library/sdxl_train_util.py ADDED
@@ -0,0 +1,381 @@
1
+ import argparse
2
+ import math
3
+ import os
4
+ from typing import Optional
5
+
6
+ import torch
7
+ from library.device_utils import init_ipex, clean_memory_on_device
8
+
9
+ init_ipex()
10
+
11
+ from accelerate import init_empty_weights
12
+ from tqdm import tqdm
13
+ from transformers import CLIPTokenizer
14
+ from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
15
+ from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
16
+ from .utils import setup_logging
17
+
18
+ setup_logging()
19
+ import logging
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+ TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
24
+ TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
25
+
26
+ # DEFAULT_NOISE_OFFSET = 0.0357
27
+
28
+
29
+ def load_target_model(args, accelerator, model_version: str, weight_dtype):
30
+ model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16
31
+ for pi in range(accelerator.state.num_processes):
32
+ if pi == accelerator.state.local_process_index:
33
+ logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
34
+
35
+ (
36
+ load_stable_diffusion_format,
37
+ text_encoder1,
38
+ text_encoder2,
39
+ vae,
40
+ unet,
41
+ logit_scale,
42
+ ckpt_info,
43
+ ) = _load_target_model(
44
+ args.pretrained_model_name_or_path,
45
+ args.vae,
46
+ model_version,
47
+ weight_dtype,
48
+ accelerator.device if args.lowram else "cpu",
49
+ model_dtype,
50
+ args.disable_mmap_load_safetensors,
51
+ )
52
+
53
+ # work on low-ram device
54
+ if args.lowram:
55
+ text_encoder1.to(accelerator.device)
56
+ text_encoder2.to(accelerator.device)
57
+ unet.to(accelerator.device)
58
+ vae.to(accelerator.device)
59
+
60
+ clean_memory_on_device(accelerator.device)
61
+ accelerator.wait_for_everyone()
62
+
63
+ return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
64
+
65
+
66
+ def _load_target_model(
67
+ name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None, disable_mmap=False
68
+ ):
69
+ # model_dtype only works with full fp16/bf16
70
+ name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
71
+ load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
72
+
73
+ if load_stable_diffusion_format:
74
+ logger.info(f"load StableDiffusion checkpoint: {name_or_path}")
75
+ (
76
+ text_encoder1,
77
+ text_encoder2,
78
+ vae,
79
+ unet,
80
+ logit_scale,
81
+ ckpt_info,
82
+ ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype, disable_mmap)
83
+ else:
84
+ # Diffusers model is loaded to CPU
85
+ from diffusers import StableDiffusionXLPipeline
86
+
87
+ variant = "fp16" if weight_dtype == torch.float16 else None
88
+ logger.info(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
89
+ try:
90
+ try:
91
+ pipe = StableDiffusionXLPipeline.from_pretrained(
92
+ name_or_path, torch_dtype=model_dtype, variant=variant, tokenizer=None
93
+ )
94
+ except EnvironmentError as ex:
95
+ if variant is not None:
96
+ logger.info("try to load fp32 model")
97
+ pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=None, tokenizer=None)
98
+ else:
99
+ raise ex
100
+ except EnvironmentError as ex:
101
+ logger.error(
102
+ f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}"
103
+ )
104
+ raise ex
105
+
106
+ text_encoder1 = pipe.text_encoder
107
+ text_encoder2 = pipe.text_encoder_2
108
+
109
+ # convert to fp32 for caching the text encoder outputs
110
+ if text_encoder1.dtype != torch.float32:
111
+ text_encoder1 = text_encoder1.to(dtype=torch.float32)
112
+ if text_encoder2.dtype != torch.float32:
113
+ text_encoder2 = text_encoder2.to(dtype=torch.float32)
114
+
115
+ vae = pipe.vae
116
+ unet = pipe.unet
117
+ del pipe
118
+
119
+ # Diffusers U-Net to original U-Net
120
+ state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
121
+ with init_empty_weights():
122
+ unet = sdxl_original_unet.SdxlUNet2DConditionModel() # overwrite unet
123
+ sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device, dtype=model_dtype)
124
+ logger.info("U-Net converted to original U-Net")
125
+
126
+ logit_scale = None
127
+ ckpt_info = None
128
+
129
+ # load the additional VAE if specified
130
+ if vae_path is not None:
131
+ vae = model_util.load_vae(vae_path, weight_dtype)
132
+ logger.info("additional VAE loaded")
133
+
134
+ return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
135
+
136
+
137
+ def load_tokenizers(args: argparse.Namespace):
138
+ logger.info("prepare tokenizers")
139
+
140
+ original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH]
141
+ tokenizers = []
142
+ for i, original_path in enumerate(original_paths):
143
+ tokenizer: CLIPTokenizer = None
144
+ if args.tokenizer_cache_dir:
145
+ local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
146
+ if os.path.exists(local_tokenizer_path):
147
+ logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
148
+ tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)
149
+
150
+ if tokenizer is None:
151
+ tokenizer = CLIPTokenizer.from_pretrained(original_path)
152
+
153
+ if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
154
+ logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
155
+ tokenizer.save_pretrained(local_tokenizer_path)
156
+
157
+ if i == 1:
158
+ tokenizer.pad_token_id = 0 # fix pad token id to make same as open clip tokenizer
159
+
160
+ tokenizers.append(tokenizer)
161
+
162
+ if hasattr(args, "max_token_length") and args.max_token_length is not None:
163
+ logger.info(f"update token length: {args.max_token_length}")
164
+
165
+ return tokenizers
166
+
167
+
168
+ def match_mixed_precision(args, weight_dtype):
169
+ if args.full_fp16:
170
+ assert (
171
+ weight_dtype == torch.float16
172
+ ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
173
+ return weight_dtype
174
+ elif args.full_bf16:
175
+ assert (
176
+ weight_dtype == torch.bfloat16
177
+ ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
178
+ return weight_dtype
179
+ else:
180
+ return None
181
+
182
+
183
+ def timestep_embedding(timesteps, dim, max_period=10000):
184
+ """
185
+ Create sinusoidal timestep embeddings.
186
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
187
+ These may be fractional.
188
+ :param dim: the dimension of the output.
189
+ :param max_period: controls the minimum frequency of the embeddings.
190
+ :return: an [N x dim] Tensor of positional embeddings.
191
+ """
192
+ half = dim // 2
193
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
194
+ device=timesteps.device
195
+ )
196
+ args = timesteps[:, None].float() * freqs[None]
197
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
198
+ if dim % 2:
199
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
200
+ return embedding
201
+
202
+
203
+ def get_timestep_embedding(x, outdim):
204
+ assert len(x.shape) == 2
205
+ b, dims = x.shape[0], x.shape[1]
206
+ x = torch.flatten(x)
207
+ emb = timestep_embedding(x, outdim)
208
+ emb = torch.reshape(emb, (b, dims * outdim))
209
+ return emb
210
+
211
+
212
+ def get_size_embeddings(orig_size, crop_size, target_size, device):
213
+ emb1 = get_timestep_embedding(orig_size, 256)
214
+ emb2 = get_timestep_embedding(crop_size, 256)
215
+ emb3 = get_timestep_embedding(target_size, 256)
216
+ vector = torch.cat([emb1, emb2, emb3], dim=1).to(device)
217
+ return vector
218
+
219
+
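For orientation, each size condition is a pair of integers, so every call to `get_timestep_embedding(..., 256)` yields 2 x 256 channels and the concatenated vector has 3 x 2 x 256 = 1536 channels. A small shape check (illustrative values, assuming this module's helpers are in scope):

import torch

orig_size = torch.tensor([[1024, 1024], [768, 1344]])   # (B, 2) = (height, width)
crop_size = torch.tensor([[0, 0], [0, 0]])
target_size = torch.tensor([[1024, 1024], [1024, 1024]])
size_emb = get_size_embeddings(orig_size, crop_size, target_size, "cpu")
print(size_emb.shape)  # torch.Size([2, 1536])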
220
+ def save_sd_model_on_train_end(
221
+ args: argparse.Namespace,
222
+ src_path: str,
223
+ save_stable_diffusion_format: bool,
224
+ use_safetensors: bool,
225
+ save_dtype: torch.dtype,
226
+ epoch: int,
227
+ global_step: int,
228
+ text_encoder1,
229
+ text_encoder2,
230
+ unet,
231
+ vae,
232
+ logit_scale,
233
+ ckpt_info,
234
+ ):
235
+ def sd_saver(ckpt_file, epoch_no, global_step):
236
+ sai_metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True)
237
+ sdxl_model_util.save_stable_diffusion_checkpoint(
238
+ ckpt_file,
239
+ text_encoder1,
240
+ text_encoder2,
241
+ unet,
242
+ epoch_no,
243
+ global_step,
244
+ ckpt_info,
245
+ vae,
246
+ logit_scale,
247
+ sai_metadata,
248
+ save_dtype,
249
+ )
250
+
251
+ def diffusers_saver(out_dir):
252
+ sdxl_model_util.save_diffusers_checkpoint(
253
+ out_dir,
254
+ text_encoder1,
255
+ text_encoder2,
256
+ unet,
257
+ src_path,
258
+ vae,
259
+ use_safetensors=use_safetensors,
260
+ save_dtype=save_dtype,
261
+ )
262
+
263
+ train_util.save_sd_model_on_train_end_common(
264
+ args, save_stable_diffusion_format, use_safetensors, epoch, global_step, sd_saver, diffusers_saver
265
+ )
266
+
267
+
268
+ # epoch-end and stepwise saving are unified here because the metadata contains epoch/step and the arguments are identical
269
+ # on_epoch_end: True to save at the end of an epoch, False to save after a number of steps
270
+ def save_sd_model_on_epoch_end_or_stepwise(
271
+ args: argparse.Namespace,
272
+ on_epoch_end: bool,
273
+ accelerator,
274
+ src_path,
275
+ save_stable_diffusion_format: bool,
276
+ use_safetensors: bool,
277
+ save_dtype: torch.dtype,
278
+ epoch: int,
279
+ num_train_epochs: int,
280
+ global_step: int,
281
+ text_encoder1,
282
+ text_encoder2,
283
+ unet,
284
+ vae,
285
+ logit_scale,
286
+ ckpt_info,
287
+ ):
288
+ def sd_saver(ckpt_file, epoch_no, global_step):
289
+ sai_metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True)
290
+ sdxl_model_util.save_stable_diffusion_checkpoint(
291
+ ckpt_file,
292
+ text_encoder1,
293
+ text_encoder2,
294
+ unet,
295
+ epoch_no,
296
+ global_step,
297
+ ckpt_info,
298
+ vae,
299
+ logit_scale,
300
+ sai_metadata,
301
+ save_dtype,
302
+ )
303
+
304
+ def diffusers_saver(out_dir):
305
+ sdxl_model_util.save_diffusers_checkpoint(
306
+ out_dir,
307
+ text_encoder1,
308
+ text_encoder2,
309
+ unet,
310
+ src_path,
311
+ vae,
312
+ use_safetensors=use_safetensors,
313
+ save_dtype=save_dtype,
314
+ )
315
+
316
+ train_util.save_sd_model_on_epoch_end_or_stepwise_common(
317
+ args,
318
+ on_epoch_end,
319
+ accelerator,
320
+ save_stable_diffusion_format,
321
+ use_safetensors,
322
+ epoch,
323
+ num_train_epochs,
324
+ global_step,
325
+ sd_saver,
326
+ diffusers_saver,
327
+ )
328
+
329
+
330
+ def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
331
+ parser.add_argument(
332
+ "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
333
+ )
334
+ parser.add_argument(
335
+ "--cache_text_encoder_outputs_to_disk",
336
+ action="store_true",
337
+ help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
338
+ )
339
+ parser.add_argument(
340
+ "--disable_mmap_load_safetensors",
341
+ action="store_true",
342
+ help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる",
343
+ )
344
+
345
+
346
+ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
347
+ assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
348
+ if args.v_parameterization:
349
+ logger.warning("v_parameterization may behave unexpectedly in SDXL training / SDXL学習ではv_parameterizationは想定外の動作になります")
350
+
351
+ if args.clip_skip is not None:
352
+ logger.warning("clip_skip has no effect in SDXL training / SDXL学習ではclip_skipは動作しません")
353
+
354
+ # if args.multires_noise_iterations:
355
+ # logger.info(
356
+ # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります"
357
+ # )
358
+ # else:
359
+ # if args.noise_offset is None:
360
+ # args.noise_offset = DEFAULT_NOISE_OFFSET
361
+ # elif args.noise_offset != DEFAULT_NOISE_OFFSET:
362
+ # logger.info(
363
+ # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています"
364
+ # )
365
+ # logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
366
+
367
+ assert (
368
+ not hasattr(args, "weighted_captions") or not args.weighted_captions
369
+ ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
370
+
371
+ if supportTextEncoderCaching:
372
+ if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
373
+ args.cache_text_encoder_outputs = True
374
+ logger.warning(
375
+ "cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
376
+ + "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
377
+ )
378
+
379
+
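A minimal wiring sketch for the argument helpers above (illustrative; in the real training scripts the remaining attributes such as `v2` and `clip_skip` are added by the shared argument parser in `train_util`, here they are stubbed in by hand):

import argparse

parser = argparse.ArgumentParser()
add_sdxl_training_arguments(parser)
args = parser.parse_args(["--cache_text_encoder_outputs_to_disk"])

# attributes normally provided by the common training argument parser
args.v2 = False
args.v_parameterization = False
args.clip_skip = None
args.weighted_captions = False

verify_sdxl_training_args(args)
print(args.cache_text_encoder_outputs)  # True: enabled automatically by the disk-cache flag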
380
+ def sample_images(*args, **kwargs):
381
+ return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
library/slicing_vae.py ADDED
@@ -0,0 +1,682 @@
1
+ # Modified from Diffusers to reduce VRAM usage
2
+
3
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ from dataclasses import dataclass
17
+ from typing import Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+
23
+
24
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
25
+ from diffusers.models.modeling_utils import ModelMixin
26
+ from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
27
+ from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution
28
+ from diffusers.models.autoencoder_kl import AutoencoderKLOutput
29
+ from .utils import setup_logging
30
+ setup_logging()
31
+ import logging
32
+ logger = logging.getLogger(__name__)
33
+
34
+ def slice_h(x, num_slices):
35
+ # slice with pad 1 both sides: to eliminate side effect of padding of conv2d
36
+ # Conv2dのpaddingの副作用を排除するために、両側にpad 1しながらHをスライスする
37
+ # works with either NCHW or NHWC layouts
38
+ size = (x.shape[2] + num_slices - 1) // num_slices
39
+ sliced = []
40
+ for i in range(num_slices):
41
+ if i == 0:
42
+ sliced.append(x[:, :, : size + 1, :])
43
+ else:
44
+ end = size * (i + 1) + 1
45
+ if x.shape[2] - end < 3: # if the last slice is too thin for conv2d, use the rest of the tensor
46
+ end = x.shape[2]
47
+ sliced.append(x[:, :, size * i - 1 : end, :])
48
+ if end >= x.shape[2]:
49
+ break
50
+ return sliced
51
+
52
+
53
+ def cat_h(sliced):
54
+ # concatenate the slices, dropping the padded overlap rows
55
+ cat = []
56
+ for i, x in enumerate(sliced):
57
+ if i == 0:
58
+ cat.append(x[:, :, :-1, :])
59
+ elif i == len(sliced) - 1:
60
+ cat.append(x[:, :, 1:, :])
61
+ else:
62
+ cat.append(x[:, :, 1:-1, :])
63
+ del x
64
+ x = torch.cat(cat, dim=2)
65
+ return x
66
+
67
+
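`slice_h` keeps one extra row on each cut so that a 3x3 convolution with padding=1 sees the true neighboring rows, and `cat_h` then drops those overlap rows again. A small sanity-check sketch (illustrative; uses the two helpers defined above):

import torch
import torch.nn as nn

conv = nn.Conv2d(8, 8, kernel_size=3, padding=1)
x = torch.randn(1, 8, 64, 64)

full = conv(x)
stitched = cat_h([conv(s) for s in slice_h(x, num_slices=4)])
print(torch.allclose(full, stitched, atol=1e-6))  # expected: True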
68
+ def resblock_forward(_self, num_slices, input_tensor, temb, **kwargs):
69
+ assert _self.upsample is None and _self.downsample is None
70
+ assert _self.norm1.num_groups == _self.norm2.num_groups
71
+ assert temb is None
72
+
73
+ # make sure norms are on cpu
74
+ org_device = input_tensor.device
75
+ cpu_device = torch.device("cpu")
76
+ _self.norm1.to(cpu_device)
77
+ _self.norm2.to(cpu_device)
78
+
79
+ # workaround: GroupNorm does not run in fp16 on CPU
80
+ org_dtype = input_tensor.dtype
81
+ if org_dtype == torch.float16:
82
+ _self.norm1.to(torch.float32)
83
+ _self.norm2.to(torch.float32)
84
+
85
+ # move all tensors to CPU
86
+ input_tensor = input_tensor.to(cpu_device)
87
+ hidden_states = input_tensor
88
+
89
+ # apparently this approach gives different results...
90
+ # def sliced_norm1(norm, x):
91
+ # num_div = 4 if up_block_idx <= 2 else x.shape[1] // norm.num_groups
92
+ # sliced_tensor = torch.chunk(x, num_div, dim=1)
93
+ # sliced_weight = torch.chunk(norm.weight, num_div, dim=0)
94
+ # sliced_bias = torch.chunk(norm.bias, num_div, dim=0)
95
+ # logger.info(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape)
96
+ # normed_tensor = []
97
+ # for i in range(num_div):
98
+ # n = torch.group_norm(sliced_tensor[i], norm.num_groups, sliced_weight[i], sliced_bias[i], norm.eps)
99
+ # normed_tensor.append(n)
100
+ # del n
101
+ # x = torch.cat(normed_tensor, dim=1)
102
+ # return num_div, x
103
+
104
+ # slicing the norm changes the result, so only the norm is left unsliced; running it on GPU would exhaust VRAM, so it runs on CPU (fortunately not too slow even there)
105
+ if org_dtype == torch.float16:
106
+ hidden_states = hidden_states.to(torch.float32)
107
+ hidden_states = _self.norm1(hidden_states) # run on cpu
108
+ if org_dtype == torch.float16:
109
+ hidden_states = hidden_states.to(torch.float16)
110
+
111
+ sliced = slice_h(hidden_states, num_slices)
112
+ del hidden_states
113
+
114
+ for i in range(len(sliced)):
115
+ x = sliced[i]
116
+ sliced[i] = None
117
+
118
+ # move only the part being computed to the GPU (same pattern below)
119
+ x = x.to(org_device)
120
+ x = _self.nonlinearity(x)
121
+ x = _self.conv1(x)
122
+ x = x.to(cpu_device)
123
+ sliced[i] = x
124
+ del x
125
+
126
+ hidden_states = cat_h(sliced)
127
+ del sliced
128
+
129
+ if org_dtype == torch.float16:
130
+ hidden_states = hidden_states.to(torch.float32)
131
+ hidden_states = _self.norm2(hidden_states) # run on cpu
132
+ if org_dtype == torch.float16:
133
+ hidden_states = hidden_states.to(torch.float16)
134
+
135
+ sliced = slice_h(hidden_states, num_slices)
136
+ del hidden_states
137
+
138
+ for i in range(len(sliced)):
139
+ x = sliced[i]
140
+ sliced[i] = None
141
+
142
+ x = x.to(org_device)
143
+ x = _self.nonlinearity(x)
144
+ x = _self.dropout(x)
145
+ x = _self.conv2(x)
146
+ x = x.to(cpu_device)
147
+ sliced[i] = x
148
+ del x
149
+
150
+ hidden_states = cat_h(sliced)
151
+ del sliced
152
+
153
+ # make shortcut
154
+ if _self.conv_shortcut is not None:
155
+ sliced = list(torch.chunk(input_tensor, num_slices, dim=2)) # conv_shortcut has no padding, so slice normally
156
+ del input_tensor
157
+
158
+ for i in range(len(sliced)):
159
+ x = sliced[i]
160
+ sliced[i] = None
161
+
162
+ x = x.to(org_device)
163
+ x = _self.conv_shortcut(x)
164
+ x = x.to(cpu_device)
165
+ sliced[i] = x
166
+ del x
167
+
168
+ input_tensor = torch.cat(sliced, dim=2)
169
+ del sliced
170
+
171
+ output_tensor = (input_tensor + hidden_states) / _self.output_scale_factor
172
+
173
+ output_tensor = output_tensor.to(org_device) # the next layer computes on the GPU
174
+ return output_tensor
175
+
176
+
177
+ class SlicingEncoder(nn.Module):
178
+ def __init__(
179
+ self,
180
+ in_channels=3,
181
+ out_channels=3,
182
+ down_block_types=("DownEncoderBlock2D",),
183
+ block_out_channels=(64,),
184
+ layers_per_block=2,
185
+ norm_num_groups=32,
186
+ act_fn="silu",
187
+ double_z=True,
188
+ num_slices=2,
189
+ ):
190
+ super().__init__()
191
+ self.layers_per_block = layers_per_block
192
+
193
+ self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
194
+
195
+ self.mid_block = None
196
+ self.down_blocks = nn.ModuleList([])
197
+
198
+ # down
199
+ output_channel = block_out_channels[0]
200
+ for i, down_block_type in enumerate(down_block_types):
201
+ input_channel = output_channel
202
+ output_channel = block_out_channels[i]
203
+ is_final_block = i == len(block_out_channels) - 1
204
+
205
+ down_block = get_down_block(
206
+ down_block_type,
207
+ num_layers=self.layers_per_block,
208
+ in_channels=input_channel,
209
+ out_channels=output_channel,
210
+ add_downsample=not is_final_block,
211
+ resnet_eps=1e-6,
212
+ downsample_padding=0,
213
+ resnet_act_fn=act_fn,
214
+ resnet_groups=norm_num_groups,
215
+ attention_head_dim=output_channel,
216
+ temb_channels=None,
217
+ )
218
+ self.down_blocks.append(down_block)
219
+
220
+ # mid
221
+ self.mid_block = UNetMidBlock2D(
222
+ in_channels=block_out_channels[-1],
223
+ resnet_eps=1e-6,
224
+ resnet_act_fn=act_fn,
225
+ output_scale_factor=1,
226
+ resnet_time_scale_shift="default",
227
+ attention_head_dim=block_out_channels[-1],
228
+ resnet_groups=norm_num_groups,
229
+ temb_channels=None,
230
+ )
231
+ self.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) # use Diffusers' xformers attention for now
232
+
233
+ # out
234
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
235
+ self.conv_act = nn.SiLU()
236
+
237
+ conv_out_channels = 2 * out_channels if double_z else out_channels
238
+ self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
239
+
240
+ # replace forward of ResBlocks
241
+ def wrapper(func, module, num_slices):
242
+ def forward(*args, **kwargs):
243
+ return func(module, num_slices, *args, **kwargs)
244
+
245
+ return forward
246
+
247
+ self.num_slices = num_slices
248
+ div = num_slices / (2 ** (len(self.down_blocks) - 1)) # deeper blocks do not need as many slices, so reduce the divisor accordingly
249
+ # logger.info(f"initial divisor: {div}")
250
+ if div >= 2:
251
+ div = int(div)
252
+ for resnet in self.mid_block.resnets:
253
+ resnet.forward = wrapper(resblock_forward, resnet, div)
254
+ # midblock doesn't have downsample
255
+
256
+ for i, down_block in enumerate(self.down_blocks[::-1]):
257
+ if div >= 2:
258
+ div = int(div)
259
+ # logger.info(f"down block: {i} divisor: {div}")
260
+ for resnet in down_block.resnets:
261
+ resnet.forward = wrapper(resblock_forward, resnet, div)
262
+ if down_block.downsamplers is not None:
263
+ # logger.info("has downsample")
264
+ for downsample in down_block.downsamplers:
265
+ downsample.forward = wrapper(self.downsample_forward, downsample, div * 2)
266
+ div *= 2
267
+
268
+ def forward(self, x):
269
+ sample = x
270
+ del x
271
+
272
+ org_device = sample.device
273
+ cpu_device = torch.device("cpu")
274
+
275
+ # sample = self.conv_in(sample)
276
+ sample = sample.to(cpu_device)
277
+ sliced = slice_h(sample, self.num_slices)
278
+ del sample
279
+
280
+ for i in range(len(sliced)):
281
+ x = sliced[i]
282
+ sliced[i] = None
283
+
284
+ x = x.to(org_device)
285
+ x = self.conv_in(x)
286
+ x = x.to(cpu_device)
287
+ sliced[i] = x
288
+ del x
289
+
290
+ sample = cat_h(sliced)
291
+ del sliced
292
+
293
+ sample = sample.to(org_device)
294
+
295
+ # down
296
+ for down_block in self.down_blocks:
297
+ sample = down_block(sample)
298
+
299
+ # middle
300
+ sample = self.mid_block(sample)
301
+
302
+ # post-process
303
+ # this part could also be made memory-efficient, but it probably does not use much memory, so it is left as is
304
+ sample = self.conv_norm_out(sample)
305
+ sample = self.conv_act(sample)
306
+ sample = self.conv_out(sample)
307
+
308
+ return sample
309
+
310
+ def downsample_forward(self, _self, num_slices, hidden_states):
311
+ assert hidden_states.shape[1] == _self.channels
312
+ assert _self.use_conv and _self.padding == 0
313
+ logger.info(f"downsample forward {num_slices} {hidden_states.shape}")
314
+
315
+ org_device = hidden_states.device
316
+ cpu_device = torch.device("cpu")
317
+
318
+ hidden_states = hidden_states.to(cpu_device)
319
+ pad = (0, 1, 0, 1)
320
+ hidden_states = torch.nn.functional.pad(hidden_states, pad, mode="constant", value=0)
321
+
322
+ # slice with even number because of stride 2
323
+ # strideが2なので偶数でスライスする
324
+ # slice with pad 1 both sides: to eliminate side effect of padding of conv2d
325
+ size = (hidden_states.shape[2] + num_slices - 1) // num_slices
326
+ size = size + 1 if size % 2 == 1 else size
327
+
328
+ sliced = []
329
+ for i in range(num_slices):
330
+ if i == 0:
331
+ sliced.append(hidden_states[:, :, : size + 1, :])
332
+ else:
333
+ end = size * (i + 1) + 1
334
+ if hidden_states.shape[2] - end < 4: # if the last slice is too small, use the rest of the tensor
335
+ end = hidden_states.shape[2]
336
+ sliced.append(hidden_states[:, :, size * i - 1 : end, :])
337
+ if end >= hidden_states.shape[2]:
338
+ break
339
+ del hidden_states
340
+
341
+ for i in range(len(sliced)):
342
+ x = sliced[i]
343
+ sliced[i] = None
344
+
345
+ x = x.to(org_device)
346
+ x = _self.conv(x)
347
+ x = x.to(cpu_device)
348
+
349
+ # this block looks different in style because it was written by Copilot
350
+ if i == 0:
351
+ hidden_states = x
352
+ else:
353
+ hidden_states = torch.cat([hidden_states, x], dim=2)
354
+
355
+ hidden_states = hidden_states.to(org_device)
356
+ # logger.info(f"downsample forward done {hidden_states.shape}")
357
+ return hidden_states
358
+
359
+
360
+ class SlicingDecoder(nn.Module):
361
+ def __init__(
362
+ self,
363
+ in_channels=3,
364
+ out_channels=3,
365
+ up_block_types=("UpDecoderBlock2D",),
366
+ block_out_channels=(64,),
367
+ layers_per_block=2,
368
+ norm_num_groups=32,
369
+ act_fn="silu",
370
+ num_slices=2,
371
+ ):
372
+ super().__init__()
373
+ self.layers_per_block = layers_per_block
374
+
375
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
376
+
377
+ self.mid_block = None
378
+ self.up_blocks = nn.ModuleList([])
379
+
380
+ # mid
381
+ self.mid_block = UNetMidBlock2D(
382
+ in_channels=block_out_channels[-1],
383
+ resnet_eps=1e-6,
384
+ resnet_act_fn=act_fn,
385
+ output_scale_factor=1,
386
+ resnet_time_scale_shift="default",
387
+ attention_head_dim=block_out_channels[-1],
388
+ resnet_groups=norm_num_groups,
389
+ temb_channels=None,
390
+ )
391
+ self.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) # use Diffusers' xformers attention for now
392
+
393
+ # up
394
+ reversed_block_out_channels = list(reversed(block_out_channels))
395
+ output_channel = reversed_block_out_channels[0]
396
+ for i, up_block_type in enumerate(up_block_types):
397
+ prev_output_channel = output_channel
398
+ output_channel = reversed_block_out_channels[i]
399
+
400
+ is_final_block = i == len(block_out_channels) - 1
401
+
402
+ up_block = get_up_block(
403
+ up_block_type,
404
+ num_layers=self.layers_per_block + 1,
405
+ in_channels=prev_output_channel,
406
+ out_channels=output_channel,
407
+ prev_output_channel=None,
408
+ add_upsample=not is_final_block,
409
+ resnet_eps=1e-6,
410
+ resnet_act_fn=act_fn,
411
+ resnet_groups=norm_num_groups,
412
+ attention_head_dim=output_channel,
413
+ temb_channels=None,
414
+ )
415
+ self.up_blocks.append(up_block)
416
+ prev_output_channel = output_channel
417
+
418
+ # out
419
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
420
+ self.conv_act = nn.SiLU()
421
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
422
+
423
+ # replace forward of ResBlocks
424
+ def wrapper(func, module, num_slices):
425
+ def forward(*args, **kwargs):
426
+ return func(module, num_slices, *args, **kwargs)
427
+
428
+ return forward
429
+
430
+ self.num_slices = num_slices
431
+ div = num_slices / (2 ** (len(self.up_blocks) - 1))
432
+ logger.info(f"initial divisor: {div}")
433
+ if div >= 2:
434
+ div = int(div)
435
+ for resnet in self.mid_block.resnets:
436
+ resnet.forward = wrapper(resblock_forward, resnet, div)
437
+ # midblock doesn't have upsample
438
+
439
+ for i, up_block in enumerate(self.up_blocks):
440
+ if div >= 2:
441
+ div = int(div)
442
+ # logger.info(f"up block: {i} divisor: {div}")
443
+ for resnet in up_block.resnets:
444
+ resnet.forward = wrapper(resblock_forward, resnet, div)
445
+ if up_block.upsamplers is not None:
446
+ # logger.info("has upsample")
447
+ for upsample in up_block.upsamplers:
448
+ upsample.forward = wrapper(self.upsample_forward, upsample, div * 2)
449
+ div *= 2
450
+
451
+ def forward(self, z):
452
+ sample = z
453
+ del z
454
+ sample = self.conv_in(sample)
455
+
456
+ # middle
457
+ sample = self.mid_block(sample)
458
+
459
+ # up
460
+ for i, up_block in enumerate(self.up_blocks):
461
+ sample = up_block(sample)
462
+
463
+ # post-process
464
+ sample = self.conv_norm_out(sample)
465
+ sample = self.conv_act(sample)
466
+
467
+ # conv_out with slicing because of VRAM usage
468
+ # conv_outはとてもVRAM使うのでスライスして対応
469
+ org_device = sample.device
470
+ cpu_device = torch.device("cpu")
471
+ sample = sample.to(cpu_device)
472
+
473
+ sliced = slice_h(sample, self.num_slices)
474
+ del sample
475
+ for i in range(len(sliced)):
476
+ x = sliced[i]
477
+ sliced[i] = None
478
+
479
+ x = x.to(org_device)
480
+ x = self.conv_out(x)
481
+ x = x.to(cpu_device)
482
+ sliced[i] = x
483
+ sample = cat_h(sliced)
484
+ del sliced
485
+
486
+ sample = sample.to(org_device)
487
+ return sample
488
+
489
+ def upsample_forward(self, _self, num_slices, hidden_states, output_size=None):
490
+ assert hidden_states.shape[1] == _self.channels
491
+ assert _self.use_conv_transpose == False and _self.use_conv
492
+
493
+ org_dtype = hidden_states.dtype
494
+ org_device = hidden_states.device
495
+ cpu_device = torch.device("cpu")
496
+
497
+ hidden_states = hidden_states.to(cpu_device)
498
+ sliced = slice_h(hidden_states, num_slices)
499
+ del hidden_states
500
+
501
+ for i in range(len(sliced)):
502
+ x = sliced[i]
503
+ sliced[i] = None
504
+
505
+ x = x.to(org_device)
506
+
507
+ # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
508
+ # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
509
+ # https://github.com/pytorch/pytorch/issues/86679
510
+ # hopefully this gets fixed in PyTorch 2...
511
+ if org_dtype == torch.bfloat16:
512
+ x = x.to(torch.float32)
513
+
514
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
515
+
516
+ if org_dtype == torch.bfloat16:
517
+ x = x.to(org_dtype)
518
+
519
+ x = _self.conv(x)
520
+
521
+ # the tensor has been upsampled, so the padding to trim becomes 2
522
+ if i == 0:
523
+ x = x[:, :, :-2, :]
524
+ elif i == num_slices - 1:
525
+ x = x[:, :, 2:, :]
526
+ else:
527
+ x = x[:, :, 2:-2, :]
528
+
529
+ x = x.to(cpu_device)
530
+ sliced[i] = x
531
+ del x
532
+
533
+ hidden_states = torch.cat(sliced, dim=2)
534
+ # logger.info(f"us hidden_states {hidden_states.shape}")
535
+ del sliced
536
+
537
+ hidden_states = hidden_states.to(org_device)
538
+ return hidden_states
539
+
540
+
541
+ class SlicingAutoencoderKL(ModelMixin, ConfigMixin):
542
+ r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma
543
+ and Max Welling.
544
+
545
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
546
+ implements for all the model (such as downloading or saving, etc.)
547
+
548
+ Parameters:
549
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
550
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
551
+ down_block_types (`Tuple[str]`, *optional*, defaults to :
552
+ obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
553
+ up_block_types (`Tuple[str]`, *optional*, defaults to :
554
+ obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
555
+ block_out_channels (`Tuple[int]`, *optional*, defaults to :
556
+ obj:`(64,)`): Tuple of block output channels.
557
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
558
+ latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space.
559
+ sample_size (`int`, *optional*, defaults to `32`): TODO
560
+ """
561
+
562
+ @register_to_config
563
+ def __init__(
564
+ self,
565
+ in_channels: int = 3,
566
+ out_channels: int = 3,
567
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
568
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
569
+ block_out_channels: Tuple[int] = (64,),
570
+ layers_per_block: int = 1,
571
+ act_fn: str = "silu",
572
+ latent_channels: int = 4,
573
+ norm_num_groups: int = 32,
574
+ sample_size: int = 32,
575
+ num_slices: int = 16,
576
+ ):
577
+ super().__init__()
578
+
579
+ # pass init params to Encoder
580
+ self.encoder = SlicingEncoder(
581
+ in_channels=in_channels,
582
+ out_channels=latent_channels,
583
+ down_block_types=down_block_types,
584
+ block_out_channels=block_out_channels,
585
+ layers_per_block=layers_per_block,
586
+ act_fn=act_fn,
587
+ norm_num_groups=norm_num_groups,
588
+ double_z=True,
589
+ num_slices=num_slices,
590
+ )
591
+
592
+ # pass init params to Decoder
593
+ self.decoder = SlicingDecoder(
594
+ in_channels=latent_channels,
595
+ out_channels=out_channels,
596
+ up_block_types=up_block_types,
597
+ block_out_channels=block_out_channels,
598
+ layers_per_block=layers_per_block,
599
+ norm_num_groups=norm_num_groups,
600
+ act_fn=act_fn,
601
+ num_slices=num_slices,
602
+ )
603
+
604
+ self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
605
+ self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
606
+ self.use_slicing = False
607
+
608
+ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
609
+ h = self.encoder(x)
610
+ moments = self.quant_conv(h)
611
+ posterior = DiagonalGaussianDistribution(moments)
612
+
613
+ if not return_dict:
614
+ return (posterior,)
615
+
616
+ return AutoencoderKLOutput(latent_dist=posterior)
617
+
618
+ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
619
+ z = self.post_quant_conv(z)
620
+ dec = self.decoder(z)
621
+
622
+ if not return_dict:
623
+ return (dec,)
624
+
625
+ return DecoderOutput(sample=dec)
626
+
627
+ # note: this is slicing along the batch dimension, which is confusingly different from the H slicing above
628
+ def enable_slicing(self):
629
+ r"""
630
+ Enable sliced VAE decoding.
631
+
632
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
633
+ steps. This is useful to save some memory and allow larger batch sizes.
634
+ """
635
+ self.use_slicing = True
636
+
637
+ def disable_slicing(self):
638
+ r"""
639
+ Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
640
+ decoding in one step.
641
+ """
642
+ self.use_slicing = False
643
+
644
+ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
645
+ if self.use_slicing and z.shape[0] > 1:
646
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
647
+ decoded = torch.cat(decoded_slices)
648
+ else:
649
+ decoded = self._decode(z).sample
650
+
651
+ if not return_dict:
652
+ return (decoded,)
653
+
654
+ return DecoderOutput(sample=decoded)
655
+
656
+ def forward(
657
+ self,
658
+ sample: torch.FloatTensor,
659
+ sample_posterior: bool = False,
660
+ return_dict: bool = True,
661
+ generator: Optional[torch.Generator] = None,
662
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
663
+ r"""
664
+ Args:
665
+ sample (`torch.FloatTensor`): Input sample.
666
+ sample_posterior (`bool`, *optional*, defaults to `False`):
667
+ Whether to sample from the posterior.
668
+ return_dict (`bool`, *optional*, defaults to `True`):
669
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
670
+ """
671
+ x = sample
672
+ posterior = self.encode(x).latent_dist
673
+ if sample_posterior:
674
+ z = posterior.sample(generator=generator)
675
+ else:
676
+ z = posterior.mode()
677
+ dec = self.decode(z).sample
678
+
679
+ if not return_dict:
680
+ return (dec,)
681
+
682
+ return DecoderOutput(sample=dec)
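A usage sketch for the sliced VAE (illustrative; the constructor values below mirror the standard Stable Diffusion VAE configuration and are assumptions, not taken from this diff; note that `__init__` enables xformers attention on the mid blocks, so xformers must be installed):

import torch

vae = SlicingAutoencoderKL(
    in_channels=3,
    out_channels=3,
    down_block_types=("DownEncoderBlock2D",) * 4,
    up_block_types=("UpDecoderBlock2D",) * 4,
    block_out_channels=(128, 256, 512, 512),
    layers_per_block=2,
    latent_channels=4,
    num_slices=16,
)
vae.enable_slicing()  # additionally slice along the batch dimension when decoding

latents = torch.randn(2, 4, 128, 128)
images = vae.decode(latents).sample  # (2, 3, 1024, 1024), decoded one latent at a time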
library/train_util.py ADDED
The diff for this file is too large to render. See raw diff
 
library/utils.py ADDED
@@ -0,0 +1,287 @@
1
+ import logging
2
+ import sys
3
+ import threading
4
+ import torch
5
+ from torchvision import transforms
6
+ from typing import *
7
+ from diffusers import EulerAncestralDiscreteScheduler
8
+ import diffusers.schedulers.scheduling_euler_ancestral_discrete
9
+ from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput
10
+ import cv2
11
+ from PIL import Image
12
+ import numpy as np
13
+
14
+
15
+ def fire_in_thread(f, *args, **kwargs):
16
+ threading.Thread(target=f, args=args, kwargs=kwargs).start()
17
+
18
+
19
+ def add_logging_arguments(parser):
20
+ parser.add_argument(
21
+ "--console_log_level",
22
+ type=str,
23
+ default=None,
24
+ choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
25
+ help="Set the logging level, default is INFO / ログレベルを設定する。デフォルトはINFO",
26
+ )
27
+ parser.add_argument(
28
+ "--console_log_file",
29
+ type=str,
30
+ default=None,
31
+ help="Log to a file instead of stderr / 標準エラー出力ではなくファイルにログを出力する",
32
+ )
33
+ parser.add_argument("--console_log_simple", action="store_true", help="Simple log output / シンプルなログ出力")
34
+
35
+
36
+ def setup_logging(args=None, log_level=None, reset=False):
37
+ if logging.root.handlers:
38
+ if reset:
39
+ # remove all handlers
40
+ for handler in logging.root.handlers[:]:
41
+ logging.root.removeHandler(handler)
42
+ else:
43
+ return
44
+
45
+ # log_level can be set by the caller or by the args, the caller has priority. If not set, use INFO
46
+ if log_level is None and args is not None:
47
+ log_level = args.console_log_level
48
+ if log_level is None:
49
+ log_level = "INFO"
50
+ log_level = getattr(logging, log_level)
51
+
52
+ msg_init = None
53
+ if args is not None and args.console_log_file:
54
+ handler = logging.FileHandler(args.console_log_file, mode="w")
55
+ else:
56
+ handler = None
57
+ if not args or not args.console_log_simple:
58
+ try:
59
+ from rich.logging import RichHandler
60
+ from rich.console import Console
61
+
62
+
63
+ handler = RichHandler(console=Console(stderr=True))
64
+ except ImportError:
65
+ # print("rich is not installed, using basic logging")
66
+ msg_init = "rich is not installed, using basic logging"
67
+
68
+ if handler is None:
69
+ handler = logging.StreamHandler(sys.stdout) # same as print
70
+ handler.propagate = False
71
+
72
+ formatter = logging.Formatter(
73
+ fmt="%(message)s",
74
+ datefmt="%Y-%m-%d %H:%M:%S",
75
+ )
76
+ handler.setFormatter(formatter)
77
+ logging.root.setLevel(log_level)
78
+ logging.root.addHandler(handler)
79
+
80
+ if msg_init is not None:
81
+ logger = logging.getLogger(__name__)
82
+ logger.info(msg_init)
83
+
84
+
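A short usage sketch for the logging setup (illustrative; assumes the package is importable as `library`):

import argparse
import logging

from library.utils import add_logging_arguments, setup_logging

parser = argparse.ArgumentParser()
add_logging_arguments(parser)
args = parser.parse_args(["--console_log_level", "DEBUG"])

setup_logging(args)  # uses rich if available, otherwise falls back to a plain StreamHandler
logger = logging.getLogger(__name__)
logger.debug("logging configured")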
85
+ def pil_resize(image, size, interpolation=Image.LANCZOS):
86
+ has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False
87
+
88
+ if has_alpha:
89
+ pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA))
90
+ else:
91
+ pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
92
+
93
+ resized_pil = pil_image.resize(size, interpolation)
94
+
95
+ # Convert back to cv2 format
96
+ if has_alpha:
97
+ resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGBA2BGRA)
98
+ else:
99
+ resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR)
100
+
101
+ return resized_cv2
102
+
103
+
104
+ # TODO make inf_utils.py
105
+
106
+
107
+ # region Gradual Latent hires fix
108
+
109
+
110
+ class GradualLatent:
111
+ def __init__(
112
+ self,
113
+ ratio,
114
+ start_timesteps,
115
+ every_n_steps,
116
+ ratio_step,
117
+ s_noise=1.0,
118
+ gaussian_blur_ksize=None,
119
+ gaussian_blur_sigma=0.5,
120
+ gaussian_blur_strength=0.5,
121
+ unsharp_target_x=True,
122
+ ):
123
+ self.ratio = ratio
124
+ self.start_timesteps = start_timesteps
125
+ self.every_n_steps = every_n_steps
126
+ self.ratio_step = ratio_step
127
+ self.s_noise = s_noise
128
+ self.gaussian_blur_ksize = gaussian_blur_ksize
129
+ self.gaussian_blur_sigma = gaussian_blur_sigma
130
+ self.gaussian_blur_strength = gaussian_blur_strength
131
+ self.unsharp_target_x = unsharp_target_x
132
+
133
+ def __str__(self) -> str:
134
+ return (
135
+ f"GradualLatent(ratio={self.ratio}, start_timesteps={self.start_timesteps}, "
136
+ + f"every_n_steps={self.every_n_steps}, ratio_step={self.ratio_step}, s_noise={self.s_noise}, "
137
+ + f"gaussian_blur_ksize={self.gaussian_blur_ksize}, gaussian_blur_sigma={self.gaussian_blur_sigma}, gaussian_blur_strength={self.gaussian_blur_strength}, "
138
+ + f"unsharp_target_x={self.unsharp_target_x})"
139
+ )
140
+
141
+ def apply_unsharp_mask(self, x: torch.Tensor):
142
+ if self.gaussian_blur_ksize is None:
143
+ return x
144
+ blurred = transforms.functional.gaussian_blur(x, self.gaussian_blur_ksize, self.gaussian_blur_sigma)
145
+ # mask = torch.sigmoid((x - blurred) * self.gaussian_blur_strength)
146
+ mask = (x - blurred) * self.gaussian_blur_strength
147
+ sharpened = x + mask
148
+ return sharpened
149
+
150
+ def interpolate(self, x: torch.Tensor, resized_size, unsharp=True):
151
+ org_dtype = x.dtype
152
+ if org_dtype == torch.bfloat16:
153
+ x = x.float()
154
+
155
+ x = torch.nn.functional.interpolate(x, size=resized_size, mode="bicubic", align_corners=False).to(dtype=org_dtype)
156
+
157
+ # apply unsharp mask / アンシャープマスクを適用する
158
+ if unsharp and self.gaussian_blur_ksize:
159
+ x = self.apply_unsharp_mask(x)
160
+
161
+ return x
162
+
163
+
164
+ class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.resized_size = None
+         self.gradual_latent = None
+
+     def set_gradual_latent_params(self, size, gradual_latent: GradualLatent):
+         self.resized_size = size
+         self.gradual_latent = gradual_latent
+
+     def step(
+         self,
+         model_output: torch.FloatTensor,
+         timestep: Union[float, torch.FloatTensor],
+         sample: torch.FloatTensor,
+         generator: Optional[torch.Generator] = None,
+         return_dict: bool = True,
+     ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
+         """
+         Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+         process from the learned model outputs (most often the predicted noise).
+
+         Args:
+             model_output (`torch.FloatTensor`):
+                 The direct output from the learned diffusion model.
+             timestep (`float`):
+                 The current discrete timestep in the diffusion chain.
+             sample (`torch.FloatTensor`):
+                 A current instance of a sample created by the diffusion process.
+             generator (`torch.Generator`, *optional*):
+                 A random number generator.
+             return_dict (`bool`):
+                 Whether or not to return a
+                 [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
+
+         Returns:
+             [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
+                 If return_dict is `True`,
+                 [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
+                 otherwise a tuple is returned where the first element is the sample tensor.
+
+         """
+
+         if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor):
+             raise ValueError(
+                 (
+                     "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+                     " `EulerAncestralDiscreteScheduler.step()` is not supported. Make sure to pass"
+                     " one of the `scheduler.timesteps` as a timestep."
+                 ),
+             )
+
+         if not self.is_scale_input_called:
+             # logger.warning(
+             print(
+                 "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
+                 "See `StableDiffusionPipeline` for a usage example."
+             )
+
+         if self.step_index is None:
+             self._init_step_index(timestep)
+
+         sigma = self.sigmas[self.step_index]
+
+         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
+         if self.config.prediction_type == "epsilon":
+             pred_original_sample = sample - sigma * model_output
+         elif self.config.prediction_type == "v_prediction":
+             # * c_out + input * c_skip
+             pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
+         elif self.config.prediction_type == "sample":
+             raise NotImplementedError("prediction_type not implemented yet: sample")
+         else:
+             raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`")
+
+         sigma_from = self.sigmas[self.step_index]
+         sigma_to = self.sigmas[self.step_index + 1]
+         sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
+         sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
+
+         # 2. Convert to an ODE derivative
+         derivative = (sample - pred_original_sample) / sigma
+
+         dt = sigma_down - sigma
+
+         device = model_output.device
+         if self.resized_size is None:
+             prev_sample = sample + derivative * dt
+
+             noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
+                 model_output.shape, dtype=model_output.dtype, device=device, generator=generator
+             )
+             s_noise = 1.0
+         else:
+             print("resized_size", self.resized_size, "model_output.shape", model_output.shape, "sample.shape", sample.shape)
+             s_noise = self.gradual_latent.s_noise
+
+             if self.gradual_latent.unsharp_target_x:
+                 prev_sample = sample + derivative * dt
+                 prev_sample = self.gradual_latent.interpolate(prev_sample, self.resized_size)
+             else:
+                 sample = self.gradual_latent.interpolate(sample, self.resized_size)
+                 derivative = self.gradual_latent.interpolate(derivative, self.resized_size, unsharp=False)
+                 prev_sample = sample + derivative * dt
+
+             noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
+                 (model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1]),
+                 dtype=model_output.dtype,
+                 device=device,
+                 generator=generator,
+             )
+
+         prev_sample = prev_sample + noise * sigma_up * s_noise
+
+         # upon completion increase step index by one
+         self._step_index += 1
+
+         if not return_dict:
+             return (prev_sample,)
+
+         return EulerAncestralDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
+
+
+ # endregion
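The subclass leaves the ancestral update itself untouched: the step is still split into a deterministic move using sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 and re-injected noise scaled by sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5; the only change is that, when `resized_size` is set, the updated sample and the fresh noise are produced at that (height, width). Driving it is the caller's job; the sketch below is a hypothetical wiring example (the actual driver in this repo is the image generation script, not this module), so `pipe` and the sizes are stand-ins:

# Hypothetical wiring sketch: `pipe` is any diffusers pipeline whose scheduler is replaced.
scheduler = EulerAncestralDiscreteSchedulerGL.from_config(pipe.scheduler.config)
pipe.scheduler = scheduler

gl = GradualLatent(ratio=0.5, start_timesteps=600, every_n_steps=3, ratio_step=0.125, gaussian_blur_ksize=5)

# Before a denoising step: ask step() to interpolate the latent to the given (height, width)
# and to draw the ancestral noise at that size, scaled by gl.s_noise.
scheduler.set_gradual_latent_params((96, 96), gl)

# When no resizing is wanted for a step, clear the size; step() then behaves like the stock
# EulerAncestralDiscreteScheduler.
scheduler.set_gradual_latent_params(None, gl)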