Spaces:
Running
on
Zero
Running
on
Zero
File size: 16,945 Bytes
8f6d6cb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 |
import inspect
import math
from typing import Callable, List, Optional, Tuple, Union
from einops import rearrange
import torch
from torch import nn
import torch.nn.functional as F
from torch import Tensor
from diffusers.models.attention_processor import Attention
class LoRALinearLayer(nn.Module):
    """Low-rank adapter ``up(down(x))`` restricted to a single condition block.

    The input sequence is treated as ``[main tokens | cond_0 | ... | cond_{n_loras-1}]``,
    where every condition block holds ``cond_size`` tokens (derived from ``cond_width`` and
    ``cond_height``). Every token outside block ``number`` is multiplied by 0 before the
    projection; because both linears are bias-free, those output rows are zero for finite
    inputs.

    Args:
        in_features: Width of the input features.
        out_features: Width of the output features.
        rank: Inner rank of the ``down``/``up`` factorization.
        network_alpha: If given, the output is scaled by ``network_alpha / rank``.
        device: Forwarded to both ``nn.Linear`` layers.
        dtype: Forwarded to both ``nn.Linear`` layers; inputs are cast to it for the projection.
        cond_width: Width used to derive the token count of one condition block.
        cond_height: Height used to derive the token count of one condition block.
        number: Index of the condition block this adapter acts on.
        n_loras: Total number of condition blocks at the end of the sequence.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int = 4,
        network_alpha: Optional[float] = None,
        device: Optional[Union[torch.device, str]] = None,
        dtype: Optional[torch.dtype] = None,
        cond_width=512,
        cond_height=512,
        number=0,
        n_loras=1
    ):
        super().__init__()
        # Creation/init order is kept as-is so RNG consumption (and thus initial weights) is unchanged.
        self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
        self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        self.network_alpha = network_alpha
        self.rank = rank
        self.out_features = out_features
        self.in_features = in_features
        nn.init.normal_(self.down.weight, std=1 / rank)
        # Zero-initialised `up` makes the adapter a no-op until trained.
        nn.init.zeros_(self.up.weight)
        self.cond_height = cond_height
        self.cond_width = cond_width
        self.number = number
        self.n_loras = n_loras

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the adapter to condition block ``self.number`` only.

        Args:
            hidden_states: Tensor of shape ``(batch, seq_len, in_features)``; the last
                ``n_loras * cond_size`` positions along dim 1 are the condition blocks.

        Returns:
            Tensor of shape ``(batch, seq_len, out_features)`` in the input's dtype. Rows
            outside the selected condition block come from zeroed inputs.
        """
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype
        seq_len = hidden_states.shape[1]
        # Presumably the token count of one condition image (8x VAE downsampling, 2x2 patches);
        # evaluated left to right exactly as elsewhere in this file.
        cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64
        block_size = seq_len - cond_size * self.n_loras
        start = block_size + self.number * cond_size
        end = block_size + (self.number + 1) * cond_size
        # A (1, seq_len, 1) mask broadcasts over batch and feature dims, so any feature width
        # works (the previous mask hard-coded 3072 features and allocated batch*seq*3072).
        mask = torch.ones((1, seq_len, 1), device=hidden_states.device, dtype=dtype)
        mask[:, :start, :] = 0
        mask[:, end:, :] = 0
        hidden_states = mask * hidden_states
        down_hidden_states = self.down(hidden_states.to(dtype))
        up_hidden_states = self.up(down_hidden_states)
        if self.network_alpha is not None:
            up_hidden_states *= self.network_alpha / self.rank
        return up_hidden_states.to(orig_dtype)
class MultiSingleStreamBlockLoraProcessor(nn.Module):
    """Attention processor for single-stream blocks with per-condition LoRA and a K/V cache.

    ``hidden_states`` is treated as ``[main tokens | cond_0 | ... | cond_{n_loras-1}]``, each
    condition block holding ``cond_size`` tokens (derived from ``cond_width``/``cond_height``).

    * First call (``bank_kv`` empty): LoRA deltas are added to q/k/v. The condition-token
      keys/values, taken before ``norm_k`` and the rotary embedding, are cached in ``bank_kv``.
      Attention then runs with a block mask, and the condition-token attention outputs are
      cached in ``bank_attn``.
    * Later calls: no LoRA is applied. The cached condition keys/values replace the freshly
      projected ones, only the main tokens act as queries, and the cached condition outputs
      are appended.

    Nothing in this class resets the cache; the caller must empty ``bank_kv`` to rebuild it.
    """

    def __init__(self, dim: int, ranks=None, lora_weights=None, network_alphas=None, device=None, dtype=None, cond_width=512, cond_height=512, n_loras=1):
        """Create ``n_loras`` q/k/v adapters of width ``dim``.

        Args:
            dim: Feature width of the q/k/v projections.
            ranks: Per-condition LoRA ranks (at least ``n_loras`` entries).
            lora_weights: Per-condition scalar weights for the LoRA deltas. The list object is
                kept by reference, so in-place edits to it are seen by ``__call__``.
            network_alphas: Per-condition alpha values (at least ``n_loras`` entries).
            device: Forwarded to every ``LoRALinearLayer``.
            dtype: Forwarded to every ``LoRALinearLayer``.
            cond_width: Width used to derive the condition block size.
            cond_height: Height used to derive the condition block size.
            n_loras: Number of condition blocks / adapters.

        Raises:
            ValueError: If ``ranks`` or ``network_alphas`` has fewer than ``n_loras`` entries.
        """
        super().__init__()
        # None defaults avoid sharing one mutable list between instances.
        ranks = [] if ranks is None else ranks
        network_alphas = [] if network_alphas is None else network_alphas
        if len(ranks) < n_loras or len(network_alphas) < n_loras:
            raise ValueError(
                f"ranks and network_alphas need at least n_loras={n_loras} entries, "
                f"got {len(ranks)} and {len(network_alphas)}"
            )
        self.n_loras = n_loras
        self.cond_width = cond_width
        self.cond_height = cond_height

        def build_loras() -> nn.ModuleList:
            # One adapter per condition block; adapter i only touches block i.
            return nn.ModuleList([
                LoRALinearLayer(dim, dim, ranks[i], network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
                for i in range(n_loras)
            ])

        # Built in q, k, v order so weight initialisation consumes the RNG as before.
        self.q_loras = build_loras()
        self.k_loras = build_loras()
        self.v_loras = build_loras()
        self.lora_weights = [] if lora_weights is None else lora_weights
        self.bank_attn = None
        self.bank_kv = []

    @staticmethod
    def _build_cond_mask(seq_len, block_size, cond_size, n_conds, device, dtype):
        """Return the additive ``(seq_len, seq_len)`` attention mask used on the caching call.

        Rows are queries and columns are keys. The first ``block_size`` query rows may attend
        to every key. The query rows of condition block ``i`` may attend only to the keys of
        that same block. Allowed entries are 0 and blocked entries are ``-1e20`` cast to
        ``dtype`` (which overflows to ``-inf`` in float16).
        """
        mask = torch.ones((seq_len, seq_len), device=device)
        mask[:block_size, :] = 0
        for i in range(n_conds):
            start = block_size + i * cond_size
            end = block_size + (i + 1) * cond_size
            mask[start:end, start:end] = 0
        return (mask * -1e20).to(dtype)

    def __call__(self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        use_cond = False
    ) -> torch.FloatTensor:
        """Run attention over ``[main | cond_0 | ...]`` tokens, caching condition K/V on the first call.

        Args:
            attn: Attention module providing ``to_q/to_k/to_v``, ``heads``, ``norm_q``, ``norm_k``.
            hidden_states: Rank-3 tensor ``(batch, seq_len, dim)`` that still contains the
                condition tokens at its end, on every call.
            encoder_hidden_states: Only used (when given) to read the batch size.
            attention_mask: Accepted for interface compatibility but ignored.
            image_rotary_emb: Optional rotary embedding applied to queries and keys.
            use_cond: If True, also return the condition-token outputs.

        Returns:
            Main-token outputs ``(batch, block_size, heads * head_dim)``, or
            ``(main_outputs, cond_outputs)`` when ``use_cond`` is True.
        """
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        seq_len = hidden_states.shape[1]
        cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64
        # Number of main (non-condition) tokens at the front of the sequence.
        block_size = seq_len - cond_size * self.n_loras

        first_call = len(self.bank_kv) == 0

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        if first_call:
            for i in range(self.n_loras):
                weight = self.lora_weights[i]
                query = query + weight * self.q_loras[i](hidden_states)
                key = key + weight * self.k_loras[i](hidden_states)
                value = value + weight * self.v_loras[i](hidden_states)

        head_dim = key.shape[-1] // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if first_call:
            # Cache condition K/V before norm_k / rotary; both are re-applied on later calls.
            self.bank_kv.append(key[:, :, block_size:, :])
            self.bank_kv.append(value[:, :, block_size:, :])
        else:
            key = torch.concat([key[:, :, :block_size, :], self.bank_kv[0]], dim=-2)
            value = torch.concat([value[:, :, :block_size, :], self.bank_kv[1]], dim=-2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)
        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        if first_call:
            mask = self._build_cond_mask(seq_len, block_size, cond_size, self.n_loras, hidden_states.device, query.dtype)
            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask)
            self.bank_attn = hidden_states[:, :, block_size:, :]
        else:
            # Only main tokens are queried; condition outputs come from the cache.
            query = query[:, :, :block_size, :]
            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=None)
            hidden_states = torch.concat([hidden_states, self.bank_attn], dim=-2)

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)
        cond_hidden_states = hidden_states[:, block_size:, :]
        hidden_states = hidden_states[:, :block_size, :]
        return hidden_states if not use_cond else (hidden_states, cond_hidden_states)
class MultiDoubleStreamBlockLoraProcessor(nn.Module):
    """Joint text/image attention processor with per-condition LoRA and a K/V cache.

    Text tokens (``encoder_hidden_states``) are projected with ``add_q/k/v_proj`` and prepended
    to the image tokens. The image sequence is treated as
    ``[main tokens | cond_0 | ... | cond_{n_loras-1}]``, each condition block holding
    ``cond_size`` tokens (derived from ``cond_width``/``cond_height``).

    * First call (``bank_kv`` empty): LoRA deltas are added to the image q/k/v. The
      condition-token keys/values, taken before ``norm_k`` and the rotary embedding, are cached
      in ``bank_kv``. Joint attention then runs with a block mask, and the condition-token
      attention outputs are cached in ``bank_attn``.
    * Later calls: no q/k/v LoRA is applied. The cached condition keys/values replace the
      freshly projected ones, only text and main tokens act as queries, and the cached
      condition outputs are appended.

    Nothing in this class resets the cache; the caller must empty ``bank_kv`` to rebuild it.
    """

    def __init__(self, dim: int, ranks=None, lora_weights=None, network_alphas=None, device=None, dtype=None, cond_width=512, cond_height=512, n_loras=1):
        """Create ``n_loras`` q/k/v/output-projection adapters of width ``dim``.

        Args:
            dim: Feature width of the image q/k/v and output projections.
            ranks: Per-condition LoRA ranks (at least ``n_loras`` entries).
            lora_weights: Per-condition scalar weights for the LoRA deltas. The list object is
                kept by reference, so in-place edits to it are seen by ``__call__``.
            network_alphas: Per-condition alpha values (at least ``n_loras`` entries).
            device: Forwarded to every ``LoRALinearLayer``.
            dtype: Forwarded to every ``LoRALinearLayer``.
            cond_width: Width used to derive the condition block size.
            cond_height: Height used to derive the condition block size.
            n_loras: Number of condition blocks / adapters.

        Raises:
            ValueError: If ``ranks`` or ``network_alphas`` has fewer than ``n_loras`` entries.
        """
        super().__init__()
        # None defaults avoid sharing one mutable list between instances.
        ranks = [] if ranks is None else ranks
        network_alphas = [] if network_alphas is None else network_alphas
        if len(ranks) < n_loras or len(network_alphas) < n_loras:
            raise ValueError(
                f"ranks and network_alphas need at least n_loras={n_loras} entries, "
                f"got {len(ranks)} and {len(network_alphas)}"
            )
        self.n_loras = n_loras
        self.cond_width = cond_width
        self.cond_height = cond_height

        def build_loras() -> nn.ModuleList:
            # One adapter per condition block; adapter i only touches block i.
            return nn.ModuleList([
                LoRALinearLayer(dim, dim, ranks[i], network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
                for i in range(n_loras)
            ])

        # Built in q, k, v, proj order so weight initialisation consumes the RNG as before.
        self.q_loras = build_loras()
        self.k_loras = build_loras()
        self.v_loras = build_loras()
        self.proj_loras = build_loras()
        self.lora_weights = [] if lora_weights is None else lora_weights
        self.bank_attn = None
        self.bank_kv = []

    @staticmethod
    def _build_cond_mask(seq_len, block_size, cond_size, n_conds, device, dtype):
        """Return the additive ``(seq_len, seq_len)`` attention mask used on the caching call.

        Rows are queries and columns are keys. The first ``block_size`` query rows may attend
        to every key. The query rows of condition block ``i`` may attend only to the keys of
        that same block. Allowed entries are 0 and blocked entries are ``-1e20`` cast to
        ``dtype`` (which overflows to ``-inf`` in float16).
        """
        mask = torch.ones((seq_len, seq_len), device=device)
        mask[:block_size, :] = 0
        for i in range(n_conds):
            start = block_size + i * cond_size
            end = block_size + (i + 1) * cond_size
            mask[start:end, start:end] = 0
        return (mask * -1e20).to(dtype)

    def __call__(self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        use_cond=False,
    ) -> torch.FloatTensor:
        """Run joint text/image attention, caching condition K/V on the first call.

        Args:
            attn: Attention module providing ``to_q/to_k/to_v``, ``add_q/k/v_proj``, ``heads``,
                the optional norms, ``to_out`` and ``to_add_out``.
            hidden_states: Image tokens ``(batch, img_len, dim)`` that still contain the
                condition tokens at their end, on every call.
            encoder_hidden_states: Text tokens ``(batch, txt_len, dim_txt)``; required.
            attention_mask: Accepted for interface compatibility but ignored.
            image_rotary_emb: Optional rotary embedding applied to the joint queries/keys.
            use_cond: If True, also return the condition-token outputs.

        Returns:
            ``(image_main, text, image_cond)`` when ``use_cond`` is True, otherwise
            ``(text, image_main)``.

        Raises:
            ValueError: If ``encoder_hidden_states`` is None.
        """
        if encoder_hidden_states is None:
            raise ValueError("MultiDoubleStreamBlockLoraProcessor requires encoder_hidden_states")
        batch_size, _, _ = encoder_hidden_states.shape
        txt_len = encoder_hidden_states.shape[1]
        cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64
        cond_total = cond_size * self.n_loras
        # Main image tokens (image sequence only).
        block_size = hidden_states.shape[1] - cond_total
        # The same boundary in the joint [text | image] sequence.
        joint_len = txt_len + hidden_states.shape[1]
        joint_block_size = joint_len - cond_total

        # `context` projections. head_dim is derived from the projection width instead of a
        # hard-coded 3072, so non-default model widths split into heads correctly.
        txt_query = attn.add_q_proj(encoder_hidden_states)
        txt_key = attn.add_k_proj(encoder_hidden_states)
        txt_value = attn.add_v_proj(encoder_hidden_states)
        txt_head_dim = txt_query.shape[-1] // attn.heads
        txt_query = txt_query.view(batch_size, -1, attn.heads, txt_head_dim).transpose(1, 2)
        txt_key = txt_key.view(batch_size, -1, attn.heads, txt_head_dim).transpose(1, 2)
        txt_value = txt_value.view(batch_size, -1, attn.heads, txt_head_dim).transpose(1, 2)
        if attn.norm_added_q is not None:
            txt_query = attn.norm_added_q(txt_query)
        if attn.norm_added_k is not None:
            txt_key = attn.norm_added_k(txt_key)

        first_call = len(self.bank_kv) == 0

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        if first_call:
            for i in range(self.n_loras):
                weight = self.lora_weights[i]
                query = query + weight * self.q_loras[i](hidden_states)
                key = key + weight * self.k_loras[i](hidden_states)
                value = value + weight * self.v_loras[i](hidden_states)

        head_dim = key.shape[-1] // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if first_call:
            # Cache condition K/V before norm_k / rotary; both are re-applied on later calls.
            self.bank_kv.append(key[:, :, block_size:, :])
            self.bank_kv.append(value[:, :, block_size:, :])
        else:
            key = torch.concat([key[:, :, :block_size, :], self.bank_kv[0]], dim=-2)
            value = torch.concat([value[:, :, :block_size, :], self.bank_kv[1]], dim=-2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # attention over [text | image]
        query = torch.cat([txt_query, query], dim=2)
        key = torch.cat([txt_key, key], dim=2)
        value = torch.cat([txt_value, value], dim=2)
        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        if first_call:
            mask = self._build_cond_mask(joint_len, joint_block_size, cond_size, self.n_loras, hidden_states.device, query.dtype)
            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask)
            self.bank_attn = hidden_states[:, :, joint_block_size:, :]
        else:
            # Only text + main tokens are queried; condition outputs come from the cache.
            query = query[:, :, :joint_block_size, :]
            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=None)
            hidden_states = torch.concat([hidden_states, self.bank_attn], dim=-2)

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)
        encoder_hidden_states = hidden_states[:, :txt_len]
        hidden_states = hidden_states[:, txt_len:]

        # Linear projection (with LoRA weight applied to each proj layer)
        hidden_states = attn.to_out[0](hidden_states)
        # NOTE(review): each proj LoRA sees the output of to_out[0] (and of the previous proj
        # LoRA in this loop), not to_out[0]'s input as in a standard (W + BA)x LoRA. Presumably
        # this matches how the weights were trained; confirm before changing.
        for i in range(self.n_loras):
            hidden_states = hidden_states + self.lora_weights[i] * self.proj_loras[i](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        cond_hidden_states = hidden_states[:, block_size:, :]
        hidden_states = hidden_states[:, :block_size, :]
        # NOTE(review): the two return shapes order text/image differently; kept unchanged for
        # existing callers. Confirm that (text, image) is intended when use_cond is False.
        if use_cond:
            return hidden_states, encoder_hidden_states, cond_hidden_states
        return encoder_hidden_states, hidden_states