primerz committed on
Commit 125014f · verified · 1 Parent(s): a0e6fd7

Update lora.py

Files changed (1)
  1. lora.py +63 -1189
lora.py CHANGED
@@ -1,1222 +1,96 @@
1
- # LoRA network module taken from https://github.com/bmaltais/kohya_ss/blob/master/networks/lora.py
2
- # reference:
3
- # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
4
- # https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
5
-
6
  import math
7
- import os
8
- from typing import Dict, List, Optional, Tuple, Type, Union
9
- from diffusers import AutoencoderKL
10
- from transformers import CLIPTextModel
11
- import numpy as np
12
  import torch
13
- import re
14
-
15
-
16
- RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
17
-
18
- RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
19
 
20
-
21
- class LoRAModule(torch.nn.Module):
22
  """
23
- replaces the forward method of the original Linear module, instead of replacing the module itself.
24
  """
25
 
26
  def __init__(
27
  self,
28
- lora_name,
29
- org_module: torch.nn.Module,
30
- multiplier=1.0,
31
- lora_dim=4,
32
- alpha=1,
33
- dropout=None,
34
- rank_dropout=None,
35
- module_dropout=None,
36
- ):
37
- """if alpha == 0 or None, alpha is rank (no scaling)."""
38
- super().__init__()
39
- self.lora_name = lora_name
40
-
41
- if org_module.__class__.__name__ == "Conv2d":
42
- in_dim = org_module.in_channels
43
- out_dim = org_module.out_channels
44
- else:
45
- in_dim = org_module.in_features
46
- out_dim = org_module.out_features
47
-
48
- # if limit_rank:
49
- # self.lora_dim = min(lora_dim, in_dim, out_dim)
50
- # if self.lora_dim != lora_dim:
51
- # print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
52
- # else:
53
- self.lora_dim = lora_dim
54
-
55
- if org_module.__class__.__name__ == "Conv2d":
56
- kernel_size = org_module.kernel_size
57
- stride = org_module.stride
58
- padding = org_module.padding
59
- self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
60
- self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
61
- else:
62
- self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
63
- self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
64
-
65
- if type(alpha) == torch.Tensor:
66
- alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
67
- alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
68
- self.scale = alpha / self.lora_dim
69
- self.register_buffer("alpha", torch.tensor(alpha)) # can be treated as a constant
70
-
71
- # same as microsoft's
72
- torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
73
- torch.nn.init.zeros_(self.lora_up.weight)
74
-
75
- self.multiplier = multiplier
76
- self.org_module = org_module # remove in applying
77
- self.dropout = dropout
78
- self.rank_dropout = rank_dropout
79
- self.module_dropout = module_dropout
80
-
81
- def apply_to(self):
82
- self.org_forward = self.org_module.forward
83
- self.org_module.forward = self.forward
84
- del self.org_module
85
-
86
- def forward(self, x):
87
- org_forwarded = self.org_forward(x)
88
-
89
- # module dropout
90
- if self.module_dropout is not None and self.training:
91
- if torch.rand(1) < self.module_dropout:
92
- return org_forwarded
93
-
94
- lx = self.lora_down(x)
95
-
96
- # normal dropout
97
- if self.dropout is not None and self.training:
98
- lx = torch.nn.functional.dropout(lx, p=self.dropout)
99
-
100
- # rank dropout
101
- if self.rank_dropout is not None and self.training:
102
- mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
103
- if len(lx.size()) == 3:
104
- mask = mask.unsqueeze(1) # for Text Encoder
105
- elif len(lx.size()) == 4:
106
- mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
107
- lx = lx * mask
108
-
109
- # scaling for rank dropout: treat as if the rank is changed
110
- # this could also be computed from the mask, but rank_dropout is used here for its augmentation-like effect
111
- scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
112
- else:
113
- scale = self.scale
114
-
115
- lx = self.lora_up(lx)
116
-
117
- return org_forwarded + lx * self.multiplier * scale
118
-
119
-
120
- class LoRAInfModule(LoRAModule):
121
- def __init__(
122
- self,
123
- lora_name,
124
- org_module: torch.nn.Module,
125
- multiplier=1.0,
126
- lora_dim=4,
127
- alpha=1,
128
- **kwargs,
129
- ):
130
- # no dropout for inference
131
- super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
132
-
133
- self.org_module_ref = [org_module] # keep a reference for later use
134
- self.enabled = True
135
-
136
- # check regional or not by lora_name
137
- self.text_encoder = False
138
- if lora_name.startswith("lora_te_"):
139
- self.regional = False
140
- self.use_sub_prompt = True
141
- self.text_encoder = True
142
- elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name:
143
- self.regional = False
144
- self.use_sub_prompt = True
145
- elif "time_emb" in lora_name:
146
- self.regional = False
147
- self.use_sub_prompt = False
148
- else:
149
- self.regional = True
150
- self.use_sub_prompt = False
151
-
152
- self.network: LoRANetwork = None
153
-
154
- def set_network(self, network):
155
- self.network = network
156
-
157
- # freeze and merge the weights
158
- def merge_to(self, sd, dtype, device):
159
- # get up/down weight
160
- up_weight = sd["lora_up.weight"].to(torch.float).to(device)
161
- down_weight = sd["lora_down.weight"].to(torch.float).to(device)
162
-
163
- # extract weight from org_module
164
- org_sd = self.org_module.state_dict()
165
- weight = org_sd["weight"].to(torch.float)
166
-
167
- # merge weight
168
- if len(weight.size()) == 2:
169
- # linear
170
- weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
171
- elif down_weight.size()[2:4] == (1, 1):
172
- # conv2d 1x1
173
- weight = (
174
- weight
175
- + self.multiplier
176
- * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
177
- * self.scale
178
- )
179
- else:
180
- # conv2d 3x3
181
- conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
182
- # print(conved.size(), weight.size(), module.stride, module.padding)
183
- weight = weight + self.multiplier * conved * self.scale
184
-
185
- # set weight to org_module
186
- org_sd["weight"] = weight.to(dtype)
187
- self.org_module.load_state_dict(org_sd)
188
-
189
- # return this module's weight so the merge can be undone later
190
- def get_weight(self, multiplier=None):
191
- if multiplier is None:
192
- multiplier = self.multiplier
193
-
194
- # get up/down weight from module
195
- up_weight = self.lora_up.weight.to(torch.float)
196
- down_weight = self.lora_down.weight.to(torch.float)
197
-
198
- # pre-calculated weight
199
- if len(down_weight.size()) == 2:
200
- # linear
201
- weight = self.multiplier * (up_weight @ down_weight) * self.scale
202
- elif down_weight.size()[2:4] == (1, 1):
203
- # conv2d 1x1
204
- weight = (
205
- self.multiplier
206
- * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
207
- * self.scale
208
- )
209
- else:
210
- # conv2d 3x3
211
- conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
212
- weight = self.multiplier * conved * self.scale
213
-
214
- return weight
215
-
216
- def set_region(self, region):
217
- self.region = region
218
- self.region_mask = None
219
-
220
- def default_forward(self, x):
221
- # print("default_forward", self.lora_name, x.size())
222
- return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
223
-
224
- def forward(self, x):
225
- if not self.enabled:
226
- return self.org_forward(x)
227
-
228
- if self.network is None or self.network.sub_prompt_index is None:
229
- return self.default_forward(x)
230
- if not self.regional and not self.use_sub_prompt:
231
- return self.default_forward(x)
232
-
233
- if self.regional:
234
- return self.regional_forward(x)
235
- else:
236
- return self.sub_prompt_forward(x)
237
-
238
- def get_mask_for_x(self, x):
239
- # calculate size from shape of x
240
- if len(x.size()) == 4:
241
- h, w = x.size()[2:4]
242
- area = h * w
243
- else:
244
- area = x.size()[1]
245
-
246
- mask = self.network.mask_dic[area]
247
- if mask is None:
248
- raise ValueError(f"mask is None for resolution {area}")
249
- if len(x.size()) != 4:
250
- mask = torch.reshape(mask, (1, -1, 1))
251
- return mask
252
-
253
- def regional_forward(self, x):
254
- if "attn2_to_out" in self.lora_name:
255
- return self.to_out_forward(x)
256
-
257
- if self.network.mask_dic is None: # sub_prompt_index >= 3
258
- return self.default_forward(x)
259
-
260
- # apply mask for LoRA result
261
- lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
262
- mask = self.get_mask_for_x(lx)
263
- # print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
264
- lx = lx * mask
265
-
266
- x = self.org_forward(x)
267
- x = x + lx
268
-
269
- if "attn2_to_q" in self.lora_name and self.network.is_last_network:
270
- x = self.postp_to_q(x)
271
-
272
- return x
273
-
274
- def postp_to_q(self, x):
275
- # repeat x to num_sub_prompts
276
- has_real_uncond = x.size()[0] // self.network.batch_size == 3
277
- qc = self.network.batch_size # uncond
278
- qc += self.network.batch_size * self.network.num_sub_prompts # cond
279
- if has_real_uncond:
280
- qc += self.network.batch_size # real_uncond
281
-
282
- query = torch.zeros((qc, x.size()[1], x.size()[2]), device=x.device, dtype=x.dtype)
283
- query[: self.network.batch_size] = x[: self.network.batch_size]
284
-
285
- for i in range(self.network.batch_size):
286
- qi = self.network.batch_size + i * self.network.num_sub_prompts
287
- query[qi : qi + self.network.num_sub_prompts] = x[self.network.batch_size + i]
288
-
289
- if has_real_uncond:
290
- query[-self.network.batch_size :] = x[-self.network.batch_size :]
291
-
292
- # print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
293
- return query
294
-
295
- def sub_prompt_forward(self, x):
296
- if x.size()[0] == self.network.batch_size: # if uncond in text_encoder, do not apply LoRA
297
- return self.org_forward(x)
298
-
299
- emb_idx = self.network.sub_prompt_index
300
- if not self.text_encoder:
301
- emb_idx += self.network.batch_size
302
-
303
- # apply sub prompt of X
304
- lx = x[emb_idx :: self.network.num_sub_prompts]
305
- lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
306
-
307
- # print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
308
-
309
- x = self.org_forward(x)
310
- x[emb_idx :: self.network.num_sub_prompts] += lx
311
-
312
- return x
313
-
314
- def to_out_forward(self, x):
315
- # print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
316
-
317
- if self.network.is_last_network:
318
- masks = [None] * self.network.num_sub_prompts
319
- self.network.shared[self.lora_name] = (None, masks)
320
- else:
321
- lx, masks = self.network.shared[self.lora_name]
322
-
323
- # call own LoRA
324
- x1 = x[self.network.batch_size + self.network.sub_prompt_index :: self.network.num_sub_prompts]
325
- lx1 = self.lora_up(self.lora_down(x1)) * self.multiplier * self.scale
326
-
327
- if self.network.is_last_network:
328
- lx = torch.zeros(
329
- (self.network.num_sub_prompts * self.network.batch_size, *lx1.size()[1:]), device=lx1.device, dtype=lx1.dtype
330
- )
331
- self.network.shared[self.lora_name] = (lx, masks)
332
-
333
- # print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
334
- lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
335
- masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)
336
-
337
- # if not last network, return x and masks
338
- x = self.org_forward(x)
339
- if not self.network.is_last_network:
340
- return x
341
-
342
- lx, masks = self.network.shared.pop(self.lora_name)
343
-
344
- # if last network, combine separated x with mask weighted sum
345
- has_real_uncond = x.size()[0] // self.network.batch_size == self.network.num_sub_prompts + 2
346
-
347
- out = torch.zeros((self.network.batch_size * (3 if has_real_uncond else 2), *x.size()[1:]), device=x.device, dtype=x.dtype)
348
- out[: self.network.batch_size] = x[: self.network.batch_size] # uncond
349
- if has_real_uncond:
350
- out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond
351
-
352
- # print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
353
- # for i in range(len(masks)):
354
- # if masks[i] is None:
355
- # masks[i] = torch.zeros_like(masks[-1])
356
-
357
- mask = torch.cat(masks)
358
- mask_sum = torch.sum(mask, dim=0) + 1e-4
359
- for i in range(self.network.batch_size):
360
- # process one image at a time
361
- lx1 = lx[i * self.network.num_sub_prompts : (i + 1) * self.network.num_sub_prompts]
362
- lx1 = lx1 * mask
363
- lx1 = torch.sum(lx1, dim=0)
364
-
365
- xi = self.network.batch_size + i * self.network.num_sub_prompts
366
- x1 = x[xi : xi + self.network.num_sub_prompts]
367
- x1 = x1 * mask
368
- x1 = torch.sum(x1, dim=0)
369
- x1 = x1 / mask_sum
370
-
371
- x1 = x1 + lx1
372
- out[self.network.batch_size + i] = x1
373
-
374
- # print("to_out_forward", x.size(), out.size(), has_real_uncond)
375
- return out
376
-
377
-
378
- def parse_block_lr_kwargs(nw_kwargs):
379
- down_lr_weight = nw_kwargs.get("down_lr_weight", None)
380
- mid_lr_weight = nw_kwargs.get("mid_lr_weight", None)
381
- up_lr_weight = nw_kwargs.get("up_lr_weight", None)
382
-
383
- # if none of these are set, treat block LR as disabled and return None
384
- if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None:
385
- return None, None, None
386
-
387
- # extract learning rate weight for each block
388
- if down_lr_weight is not None:
389
- # if some parameters are not set, use zero
390
- if "," in down_lr_weight:
391
- down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
392
-
393
- if mid_lr_weight is not None:
394
- mid_lr_weight = float(mid_lr_weight)
395
-
396
- if up_lr_weight is not None:
397
- if "," in up_lr_weight:
398
- up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
399
-
400
- down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight(
401
- down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0))
402
- )
403
-
404
- return down_lr_weight, mid_lr_weight, up_lr_weight
405
-
406
-
407
- def create_network(
408
- multiplier: float,
409
- network_dim: Optional[int],
410
- network_alpha: Optional[float],
411
- vae: AutoencoderKL,
412
- text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
413
- unet,
414
- neuron_dropout: Optional[float] = None,
415
- **kwargs,
416
- ):
417
- if network_dim is None:
418
- network_dim = 4 # default
419
- if network_alpha is None:
420
- network_alpha = 1.0
421
-
422
- # extract dim/alpha for conv2d, and block dim
423
- conv_dim = kwargs.get("conv_dim", None)
424
- conv_alpha = kwargs.get("conv_alpha", None)
425
- if conv_dim is not None:
426
- conv_dim = int(conv_dim)
427
- if conv_alpha is None:
428
- conv_alpha = 1.0
429
- else:
430
- conv_alpha = float(conv_alpha)
431
-
432
- # block dim/alpha/lr
433
- block_dims = kwargs.get("block_dims", None)
434
- down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
435
-
436
- # enable per-block dim (rank) if any of these are specified
437
- if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None:
438
- block_alphas = kwargs.get("block_alphas", None)
439
- conv_block_dims = kwargs.get("conv_block_dims", None)
440
- conv_block_alphas = kwargs.get("conv_block_alphas", None)
441
-
442
- block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas(
443
- block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
444
- )
445
-
446
- # remove block dim/alpha without learning rate
447
- block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas(
448
- block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
449
- )
450
-
451
- else:
452
- block_alphas = None
453
- conv_block_dims = None
454
- conv_block_alphas = None
455
-
456
- # rank/module dropout
457
- rank_dropout = kwargs.get("rank_dropout", None)
458
- if rank_dropout is not None:
459
- rank_dropout = float(rank_dropout)
460
- module_dropout = kwargs.get("module_dropout", None)
461
- if module_dropout is not None:
462
- module_dropout = float(module_dropout)
463
-
464
- # that's a lot of arguments ( ^ω^)...
465
- network = LoRANetwork(
466
- text_encoder,
467
- unet,
468
- multiplier=multiplier,
469
- lora_dim=network_dim,
470
- alpha=network_alpha,
471
- dropout=neuron_dropout,
472
- rank_dropout=rank_dropout,
473
- module_dropout=module_dropout,
474
- conv_lora_dim=conv_dim,
475
- conv_alpha=conv_alpha,
476
- block_dims=block_dims,
477
- block_alphas=block_alphas,
478
- conv_block_dims=conv_block_dims,
479
- conv_block_alphas=conv_block_alphas,
480
- varbose=True,
481
- )
482
-
483
- if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
484
- network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
485
-
486
- return network
487
-
488
-
489
- # keep in mind that this method may be called from external code
490
- # network_dim and network_alpha hold default values.
491
- # block_dims and block_alphas are either both None or both set
492
- # conv_dim and conv_alpha are either both None or both set
493
- def get_block_dims_and_alphas(
494
- block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
495
- ):
496
- num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + 1
497
-
498
- def parse_ints(s):
499
- return [int(i) for i in s.split(",")]
500
-
501
- def parse_floats(s):
502
- return [float(i) for i in s.split(",")]
503
-
504
- # parse block_dims and block_alphas; values are always filled in
505
- if block_dims is not None:
506
- block_dims = parse_ints(block_dims)
507
- assert (
508
- len(block_dims) == num_total_blocks
509
- ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
510
- else:
511
- print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
512
- block_dims = [network_dim] * num_total_blocks
513
-
514
- if block_alphas is not None:
515
- block_alphas = parse_floats(block_alphas)
516
- assert (
517
- len(block_alphas) == num_total_blocks
518
- ), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください"
519
- else:
520
- print(
521
- f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります"
522
- )
523
- block_alphas = [network_alpha] * num_total_blocks
524
-
525
- # parse conv_block_dims and conv_block_alphas only when specified; otherwise use conv_dim and conv_alpha
526
- if conv_block_dims is not None:
527
- conv_block_dims = parse_ints(conv_block_dims)
528
- assert (
529
- len(conv_block_dims) == num_total_blocks
530
- ), f"conv_block_dims must have {num_total_blocks} elements / conv_block_dimsは{num_total_blocks}個指定してください"
531
-
532
- if conv_block_alphas is not None:
533
- conv_block_alphas = parse_floats(conv_block_alphas)
534
- assert (
535
- len(conv_block_alphas) == num_total_blocks
536
- ), f"conv_block_alphas must have {num_total_blocks} elements / conv_block_alphasは{num_total_blocks}個指定してください"
537
- else:
538
- if conv_alpha is None:
539
- conv_alpha = 1.0
540
- print(
541
- f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
542
- )
543
- conv_block_alphas = [conv_alpha] * num_total_blocks
544
- else:
545
- if conv_dim is not None:
546
- print(
547
- f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
548
- )
549
- conv_block_dims = [conv_dim] * num_total_blocks
550
- conv_block_alphas = [conv_alpha] * num_total_blocks
551
- else:
552
- conv_block_dims = None
553
- conv_block_alphas = None
554
-
555
- return block_dims, block_alphas, conv_block_dims, conv_block_alphas
556
-
557
-
558
- # define per-layer multipliers for layer-wise learning rates; may be called from external code
559
- def get_block_lr_weight(
560
- down_lr_weight, mid_lr_weight, up_lr_weight, zero_threshold
561
- ) -> Tuple[List[float], List[float], List[float]]:
562
- # if no parameters are specified, do nothing and keep the previous behavior
563
- if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
564
- return None, None, None
565
-
566
- max_len = LoRANetwork.NUM_OF_BLOCKS # number of up/down blocks in the full model
567
-
568
- def get_list(name_with_suffix) -> List[float]:
569
- import math
570
-
571
- tokens = name_with_suffix.split("+")
572
- name = tokens[0]
573
- base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0
574
-
575
- if name == "cosine":
576
- return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))]
577
- elif name == "sine":
578
- return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)]
579
- elif name == "linear":
580
- return [i / (max_len - 1) + base_lr for i in range(max_len)]
581
- elif name == "reverse_linear":
582
- return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))]
583
- elif name == "zeros":
584
- return [0.0 + base_lr] * max_len
585
- else:
586
- print(
587
- "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
588
- % (name)
589
- )
590
- return None
591
-
592
- if type(down_lr_weight) == str:
593
- down_lr_weight = get_list(down_lr_weight)
594
- if type(up_lr_weight) == str:
595
- up_lr_weight = get_list(up_lr_weight)
596
-
597
- if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
598
- print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
599
- print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
600
- up_lr_weight = up_lr_weight[:max_len]
601
- down_lr_weight = down_lr_weight[:max_len]
602
-
603
- if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
604
- print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
605
- print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
606
-
607
- if down_lr_weight != None and len(down_lr_weight) < max_len:
608
- down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
609
- if up_lr_weight != None and len(up_lr_weight) < max_len:
610
- up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
611
-
612
- if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
613
- print("apply block learning rate / 階層別学習���を適用します。")
614
- if down_lr_weight != None:
615
- down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
616
- print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight)
617
- else:
618
- print("down_lr_weight: all 1.0, すべて1.0")
619
-
620
- if mid_lr_weight != None:
621
- mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
622
- print("mid_lr_weight:", mid_lr_weight)
623
- else:
624
- print("mid_lr_weight: 1.0")
625
-
626
- if up_lr_weight != None:
627
- up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
628
- print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight)
629
- else:
630
- print("up_lr_weight: all 1.0, すべて1.0")
631
-
632
- return down_lr_weight, mid_lr_weight, up_lr_weight
633
-
634
-
635
- # exclude blocks whose lr_weight is 0 from block_dims; may be called from external code
636
- def remove_block_dims_and_alphas(
637
- block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
638
- ):
639
- # set 0 to block dim without learning rate to remove the block
640
- if down_lr_weight != None:
641
- for i, lr in enumerate(down_lr_weight):
642
- if lr == 0:
643
- block_dims[i] = 0
644
- if conv_block_dims is not None:
645
- conv_block_dims[i] = 0
646
- if mid_lr_weight != None:
647
- if mid_lr_weight == 0:
648
- block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
649
- if conv_block_dims is not None:
650
- conv_block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
651
- if up_lr_weight != None:
652
- for i, lr in enumerate(up_lr_weight):
653
- if lr == 0:
654
- block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
655
- if conv_block_dims is not None:
656
- conv_block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
657
-
658
- return block_dims, block_alphas, conv_block_dims, conv_block_alphas
659
-
660
-
661
- # keep in mind that this may be called from external code
662
- def get_block_index(lora_name: str) -> int:
663
- block_idx = -1 # invalid lora name
664
-
665
- m = RE_UPDOWN.search(lora_name)
666
- if m:
667
- g = m.groups()
668
- i = int(g[1])
669
- j = int(g[3])
670
- if g[2] == "resnets":
671
- idx = 3 * i + j
672
- elif g[2] == "attentions":
673
- idx = 3 * i + j
674
- elif g[2] == "upsamplers" or g[2] == "downsamplers":
675
- idx = 3 * i + 2
676
-
677
- if g[0] == "down":
678
- block_idx = 1 + idx # no LoRA corresponds to index 0
679
- elif g[0] == "up":
680
- block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
681
-
682
- elif "mid_block_" in lora_name:
683
- block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12
684
-
685
- return block_idx
686
-
687
-
688
- # Create network from weights for inference; weights are not loaded here (because they may be merged)
689
- def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
690
- if weights_sd is None:
691
- if os.path.splitext(file)[1] == ".safetensors":
692
- from safetensors.torch import load_file, safe_open
693
-
694
- weights_sd = load_file(file)
695
- else:
696
- weights_sd = torch.load(file, map_location="cpu")
697
-
698
- # get dim/alpha mapping
699
- modules_dim = {}
700
- modules_alpha = {}
701
- for key, value in weights_sd.items():
702
- if "." not in key:
703
- continue
704
-
705
- lora_name = key.split(".")[0]
706
- if "alpha" in key:
707
- modules_alpha[lora_name] = value
708
- elif "lora_down" in key:
709
- dim = value.size()[0]
710
- modules_dim[lora_name] = dim
711
- # print(lora_name, value.size(), dim)
712
-
713
- # support old LoRA without alpha
714
- for key in modules_dim.keys():
715
- if key not in modules_alpha:
716
- modules_alpha[key] = modules_dim[key]
717
-
718
- module_class = LoRAInfModule if for_inference else LoRAModule
719
-
720
- network = LoRANetwork(
721
- text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
722
- )
723
-
724
- # block lr
725
- down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
726
- if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
727
- network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
728
-
729
- return network, weights_sd
730
-
731
-
732
- class LoRANetwork(torch.nn.Module):
733
- NUM_OF_BLOCKS = 12 # number of up/down blocks in the full model
734
-
735
- UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
736
- UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
737
- TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
738
- LORA_PREFIX_UNET = "lora_unet"
739
- LORA_PREFIX_TEXT_ENCODER = "lora_te"
740
-
741
- # SDXL: must start with LORA_PREFIX_TEXT_ENCODER
742
- LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
743
- LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
744
-
745
- def __init__(
746
- self,
747
- text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
748
- unet,
749
  multiplier: float = 1.0,
750
  lora_dim: int = 4,
751
- alpha: float = 1,
752
  dropout: Optional[float] = None,
753
  rank_dropout: Optional[float] = None,
754
  module_dropout: Optional[float] = None,
755
- conv_lora_dim: Optional[int] = None,
756
- conv_alpha: Optional[float] = None,
757
- block_dims: Optional[List[int]] = None,
758
- block_alphas: Optional[List[float]] = None,
759
- conv_block_dims: Optional[List[int]] = None,
760
- conv_block_alphas: Optional[List[float]] = None,
761
- modules_dim: Optional[Dict[str, int]] = None,
762
- modules_alpha: Optional[Dict[str, int]] = None,
763
- module_class: Type[object] = LoRAModule,
764
- varbose: Optional[bool] = False,
765
- ) -> None:
766
  """
767
- LoRA network: there are many arguments, but the patterns are as follows
768
- 1. specify lora_dim and alpha
769
- 2. specify lora_dim, alpha, conv_lora_dim, and conv_alpha
770
- 3. specify block_dims and block_alphas : not applied to Conv2d3x3
771
- 4. specify block_dims, block_alphas, conv_block_dims, and conv_block_alphas : also applied to Conv2d3x3
772
- 5. specify modules_dim and modules_alpha (for inference)
773
  """
774
  super().__init__()
775
  self.multiplier = multiplier
776
-
777
  self.lora_dim = lora_dim
778
- self.alpha = alpha
779
- self.conv_lora_dim = conv_lora_dim
780
- self.conv_alpha = conv_alpha
781
  self.dropout = dropout
782
  self.rank_dropout = rank_dropout
783
  self.module_dropout = module_dropout
784
 
785
- if modules_dim is not None:
786
- print(f"create LoRA network from weights")
787
- elif block_dims is not None:
788
- print(f"create LoRA network from block_dims")
789
- print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
790
- print(f"block_dims: {block_dims}")
791
- print(f"block_alphas: {block_alphas}")
792
- if conv_block_dims is not None:
793
- print(f"conv_block_dims: {conv_block_dims}")
794
- print(f"conv_block_alphas: {conv_block_alphas}")
795
- else:
796
- print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
797
- print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
798
- if self.conv_lora_dim is not None:
799
- print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
800
-
801
- # create module instances
802
- def create_modules(
803
- is_unet: bool,
804
- text_encoder_idx: Optional[int], # None, 1, 2
805
- root_module: torch.nn.Module,
806
- target_replace_modules: List[torch.nn.Module],
807
- ) -> List[LoRAModule]:
808
- prefix = (
809
- self.LORA_PREFIX_UNET
810
- if is_unet
811
- else (
812
- self.LORA_PREFIX_TEXT_ENCODER
813
- if text_encoder_idx is None
814
- else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
815
- )
816
- )
817
- loras = []
818
- skipped = []
819
- for name, module in root_module.named_modules():
820
- if module.__class__.__name__ in target_replace_modules:
821
- for child_name, child_module in module.named_modules():
822
- is_linear = child_module.__class__.__name__ == "Linear"
823
- is_conv2d = child_module.__class__.__name__ == "Conv2d"
824
- is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
825
-
826
- if is_linear or is_conv2d:
827
- lora_name = prefix + "." + name + "." + child_name
828
- lora_name = lora_name.replace(".", "_")
829
-
830
- dim = None
831
- alpha = None
832
-
833
- if modules_dim is not None:
834
- # per-module dims specified
835
- if lora_name in modules_dim:
836
- dim = modules_dim[lora_name]
837
- alpha = modules_alpha[lora_name]
838
- elif is_unet and block_dims is not None:
839
- # block_dims specified for the U-Net
840
- block_idx = get_block_index(lora_name)
841
- if is_linear or is_conv2d_1x1:
842
- dim = block_dims[block_idx]
843
- alpha = block_alphas[block_idx]
844
- elif conv_block_dims is not None:
845
- dim = conv_block_dims[block_idx]
846
- alpha = conv_block_alphas[block_idx]
847
- else:
848
- # default: target all modules
849
- if is_linear or is_conv2d_1x1:
850
- dim = self.lora_dim
851
- alpha = self.alpha
852
- elif self.conv_lora_dim is not None:
853
- dim = self.conv_lora_dim
854
- alpha = self.conv_alpha
855
-
856
- if dim is None or dim == 0:
857
- # report which modules were skipped
858
- if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None):
859
- skipped.append(lora_name)
860
- continue
861
-
862
- lora = module_class(
863
- lora_name,
864
- child_module,
865
- self.multiplier,
866
- dim,
867
- alpha,
868
- dropout=dropout,
869
- rank_dropout=rank_dropout,
870
- module_dropout=module_dropout,
871
- )
872
- loras.append(lora)
873
- return loras, skipped
874
-
875
- text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
876
- print(text_encoders)
877
- # create LoRA for text encoder
878
- # creating every module each time is wasteful; worth revisiting
879
- self.text_encoder_loras = []
880
- skipped_te = []
881
- for i, text_encoder in enumerate(text_encoders):
882
- if len(text_encoders) > 1:
883
- index = i + 1
884
- print(f"create LoRA for Text Encoder {index}:")
885
- else:
886
- index = None
887
- print(f"create LoRA for Text Encoder:")
888
-
889
- print(text_encoder)
890
- text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
891
- self.text_encoder_loras.extend(text_encoder_loras)
892
- skipped_te += skipped
893
- print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
894
-
895
- # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
896
- target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
897
- if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
898
- target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
899
-
900
- self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
901
- print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
902
 
903
- skipped = skipped_te + skipped_un
904
- if varbose and len(skipped) > 0:
905
- print(
906
- f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
907
- )
908
- for name in skipped:
909
- print(f"\t{name}")
910
-
911
- self.up_lr_weight: List[float] = None
912
- self.down_lr_weight: List[float] = None
913
- self.mid_lr_weight: float = None
914
- self.block_lr = False
915
-
916
- # assertion
917
- names = set()
918
- for lora in self.text_encoder_loras + self.unet_loras:
919
- assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
920
- names.add(lora.lora_name)
921
-
922
- def set_multiplier(self, multiplier):
923
- self.multiplier = multiplier
924
- for lora in self.text_encoder_loras + self.unet_loras:
925
- lora.multiplier = self.multiplier
926
-
927
- def load_weights(self, file):
928
- if os.path.splitext(file)[1] == ".safetensors":
929
- from safetensors.torch import load_file
930
-
931
- weights_sd = load_file(file)
932
- else:
933
- weights_sd = torch.load(file, map_location="cpu")
934
- info = self.load_state_dict(weights_sd, False)
935
- return info
936
-
937
- def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
938
- if apply_text_encoder:
939
- print("enable LoRA for text encoder")
940
- else:
941
- self.text_encoder_loras = []
942
-
943
- if apply_unet:
944
- print("enable LoRA for U-Net")
945
- else:
946
- self.unet_loras = []
947
-
948
- for lora in self.text_encoder_loras + self.unet_loras:
949
- lora.apply_to()
950
- self.add_module(lora.lora_name, lora)
951
-
952
- # return whether merging is possible
953
- def is_mergeable(self):
954
- return True
955
-
956
- # TODO refactor to common function with apply_to
957
- def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
958
- apply_text_encoder = apply_unet = False
959
- for key in weights_sd.keys():
960
- if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
961
- apply_text_encoder = True
962
- elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
963
- apply_unet = True
964
-
965
- if apply_text_encoder:
966
- print("enable LoRA for text encoder")
967
- else:
968
- self.text_encoder_loras = []
969
-
970
- if apply_unet:
971
- print("enable LoRA for U-Net")
972
- else:
973
- self.unet_loras = []
974
-
975
- for lora in self.text_encoder_loras + self.unet_loras:
976
- sd_for_lora = {}
977
- for key in weights_sd.keys():
978
- if key.startswith(lora.lora_name):
979
- sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
980
- lora.merge_to(sd_for_lora, dtype, device)
981
-
982
- print(f"weights are merged")
983
-
984
- # define per-layer multipliers for layer-wise learning rates (the argument order is reversed, but leave it for now)
985
- def set_block_lr_weight(
986
- self,
987
- up_lr_weight: List[float] = None,
988
- mid_lr_weight: float = None,
989
- down_lr_weight: List[float] = None,
990
- ):
991
- self.block_lr = True
992
- self.down_lr_weight = down_lr_weight
993
- self.mid_lr_weight = mid_lr_weight
994
- self.up_lr_weight = up_lr_weight
995
-
996
- def get_lr_weight(self, lora: LoRAModule) -> float:
997
- lr_weight = 1.0
998
- block_idx = get_block_index(lora.lora_name)
999
- if block_idx < 0:
1000
- return lr_weight
1001
-
1002
- if block_idx < LoRANetwork.NUM_OF_BLOCKS:
1003
- if self.down_lr_weight != None:
1004
- lr_weight = self.down_lr_weight[block_idx]
1005
- elif block_idx == LoRANetwork.NUM_OF_BLOCKS:
1006
- if self.mid_lr_weight != None:
1007
- lr_weight = self.mid_lr_weight
1008
- elif block_idx > LoRANetwork.NUM_OF_BLOCKS:
1009
- if self.up_lr_weight != None:
1010
- lr_weight = self.up_lr_weight[block_idx - LoRANetwork.NUM_OF_BLOCKS - 1]
1011
-
1012
- return lr_weight
1013
-
1014
- # it might be good to allow separate learning rates for the two Text Encoders
1015
- def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
1016
- self.requires_grad_(True)
1017
- all_params = []
1018
-
1019
- def enumerate_params(loras):
1020
- params = []
1021
- for lora in loras:
1022
- params.extend(lora.parameters())
1023
- return params
1024
-
1025
- if self.text_encoder_loras:
1026
- param_data = {"params": enumerate_params(self.text_encoder_loras)}
1027
- if text_encoder_lr is not None:
1028
- param_data["lr"] = text_encoder_lr
1029
- all_params.append(param_data)
1030
-
1031
- if self.unet_loras:
1032
- if self.block_lr:
1033
- # group loras by block so the learning-rate graph can be drawn per block
1034
- block_idx_to_lora = {}
1035
- for lora in self.unet_loras:
1036
- idx = get_block_index(lora.lora_name)
1037
- if idx not in block_idx_to_lora:
1038
- block_idx_to_lora[idx] = []
1039
- block_idx_to_lora[idx].append(lora)
1040
-
1041
- # set parameters per block
1042
- for idx, block_loras in block_idx_to_lora.items():
1043
- param_data = {"params": enumerate_params(block_loras)}
1044
-
1045
- if unet_lr is not None:
1046
- param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
1047
- elif default_lr is not None:
1048
- param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
1049
- if ("lr" in param_data) and (param_data["lr"] == 0):
1050
- continue
1051
- all_params.append(param_data)
1052
-
1053
- else:
1054
- param_data = {"params": enumerate_params(self.unet_loras)}
1055
- if unet_lr is not None:
1056
- param_data["lr"] = unet_lr
1057
- all_params.append(param_data)
1058
-
1059
- return all_params
1060
-
1061
- def enable_gradient_checkpointing(self):
1062
- # not supported
1063
- pass
1064
-
1065
- def prepare_grad_etc(self, text_encoder, unet):
1066
- self.requires_grad_(True)
1067
-
1068
- def on_epoch_start(self, text_encoder, unet):
1069
- self.train()
1070
-
1071
- def get_trainable_params(self):
1072
- return self.parameters()
1073
-
1074
- def save_weights(self, file, dtype, metadata):
1075
- if metadata is not None and len(metadata) == 0:
1076
- metadata = None
1077
-
1078
- state_dict = self.state_dict()
1079
-
1080
- if dtype is not None:
1081
- for key in list(state_dict.keys()):
1082
- v = state_dict[key]
1083
- v = v.detach().clone().to("cpu").to(dtype)
1084
- state_dict[key] = v
1085
-
1086
- if os.path.splitext(file)[1] == ".safetensors":
1087
- from safetensors.torch import save_file
1088
- from library import train_util
1089
-
1090
- # Precalculate model hashes to save time on indexing
1091
- if metadata is None:
1092
- metadata = {}
1093
- model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
1094
- metadata["sshs_model_hash"] = model_hash
1095
- metadata["sshs_legacy_hash"] = legacy_hash
1096
-
1097
- save_file(state_dict, file, metadata)
1098
  else:
1099
- torch.save(state_dict, file)
1100
-
1101
- # mask is a tensor with values from 0 to 1
1102
- def set_region(self, sub_prompt_index, is_last_network, mask):
1103
- if mask.max() == 0:
1104
- mask = torch.ones_like(mask)
1105
-
1106
- self.mask = mask
1107
- self.sub_prompt_index = sub_prompt_index
1108
- self.is_last_network = is_last_network
1109
-
1110
- for lora in self.text_encoder_loras + self.unet_loras:
1111
- lora.set_network(self)
1112
 
1113
- def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
1114
- self.batch_size = batch_size
1115
- self.num_sub_prompts = num_sub_prompts
1116
- self.current_size = (height, width)
1117
- self.shared = shared
1118
 
1119
- # create masks
1120
- mask = self.mask
1121
- mask_dic = {}
1122
- mask = mask.unsqueeze(0).unsqueeze(1) # b(1),c(1),h,w
1123
- ref_weight = self.text_encoder_loras[0].lora_down.weight if self.text_encoder_loras else self.unet_loras[0].lora_down.weight
1124
- dtype = ref_weight.dtype
1125
- device = ref_weight.device
1126
 
1127
- def resize_add(mh, mw):
1128
- # print(mh, mw, mh * mw)
1129
- m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16
1130
- m = m.to(device, dtype=dtype)
1131
- mask_dic[mh * mw] = m
1132
 
1133
- h = height // 8
1134
- w = width // 8
1135
- for _ in range(4):
1136
- resize_add(h, w)
1137
- if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2
1138
- resize_add(h + h % 2, w + w % 2)
1139
- h = (h + 1) // 2
1140
- w = (w + 1) // 2
1141
-
1142
- self.mask_dic = mask_dic
1143
-
1144
- def backup_weights(self):
1145
- # back up the weights
1146
- loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
1147
- for lora in loras:
1148
- org_module = lora.org_module_ref[0]
1149
- if not hasattr(org_module, "_lora_org_weight"):
1150
- sd = org_module.state_dict()
1151
- org_module._lora_org_weight = sd["weight"].detach().clone()
1152
- org_module._lora_restored = True
1153
-
1154
- def restore_weights(self):
1155
- # restore the weights
1156
- loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
1157
- for lora in loras:
1158
- org_module = lora.org_module_ref[0]
1159
- if not org_module._lora_restored:
1160
- sd = org_module.state_dict()
1161
- sd["weight"] = org_module._lora_org_weight
1162
- org_module.load_state_dict(sd)
1163
- org_module._lora_restored = True
1164
-
1165
- def pre_calculation(self):
1166
- # perform the pre-calculation
1167
- loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
1168
- for lora in loras:
1169
- org_module = lora.org_module_ref[0]
1170
- sd = org_module.state_dict()
1171
-
1172
- org_weight = sd["weight"]
1173
- lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
1174
- sd["weight"] = org_weight + lora_weight
1175
- assert sd["weight"].shape == org_weight.shape
1176
- org_module.load_state_dict(sd)
1177
-
1178
- org_module._lora_restored = False
1179
- lora.enabled = False
1180
-
1181
- def apply_max_norm_regularization(self, max_norm_value, device):
1182
- downkeys = []
1183
- upkeys = []
1184
- alphakeys = []
1185
- norms = []
1186
- keys_scaled = 0
1187
-
1188
- state_dict = self.state_dict()
1189
- for key in state_dict.keys():
1190
- if "lora_down" in key and "weight" in key:
1191
- downkeys.append(key)
1192
- upkeys.append(key.replace("lora_down", "lora_up"))
1193
- alphakeys.append(key.replace("lora_down.weight", "alpha"))
1194
 
1195
- for i in range(len(downkeys)):
1196
- down = state_dict[downkeys[i]].to(device)
1197
- up = state_dict[upkeys[i]].to(device)
1198
- alpha = state_dict[alphakeys[i]].to(device)
1199
- dim = down.shape[0]
1200
- scale = alpha / dim
1201
 
1202
- if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
1203
- updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
1204
- elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
1205
- updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
1206
- else:
1207
- updown = up @ down
1208
 
1209
- updown *= scale
1210
 
1211
- norm = updown.norm().clamp(min=max_norm_value / 2)
1212
- desired = torch.clamp(norm, max=max_norm_value)
1213
- ratio = desired.cpu() / norm.cpu()
1214
- sqrt_ratio = ratio**0.5
1215
- if ratio != 1:
1216
- keys_scaled += 1
1217
- state_dict[upkeys[i]] *= sqrt_ratio
1218
- state_dict[downkeys[i]] *= sqrt_ratio
1219
- scalednorm = updown.norm() * ratio
1220
- norms.append(scalednorm.item())
1221
 
1222
- return keys_scaled, sum(norms) / len(norms), max(norms)
1
  import math
2
  import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
+ from typing import Optional  # needed for the Optional[...] annotations below
5
 
6
+ class LoRAModule(nn.Module):
7
  """
8
+ LoRA module that replaces the forward method of an original Linear or Conv2D module.
9
  """
10
 
11
  def __init__(
12
  self,
13
+ lora_name: str,
14
+ org_module: nn.Module,
15
  multiplier: float = 1.0,
16
  lora_dim: int = 4,
17
+ alpha: Optional[float] = None,
18
  dropout: Optional[float] = None,
19
  rank_dropout: Optional[float] = None,
20
  module_dropout: Optional[float] = None,
21
+ ):
22
  """
23
+ Args:
24
+ lora_name (str): Name of the LoRA module.
25
+ org_module (nn.Module): The original module to wrap.
26
+ multiplier (float): Scaling factor for the LoRA output.
27
+ lora_dim (int): The rank of the LoRA decomposition.
28
+ alpha (float, optional): Scaling factor for LoRA weights. Defaults to lora_dim.
29
+ dropout (float, optional): Dropout probability. Defaults to None.
30
+ rank_dropout (float, optional): Dropout probability for rank reduction. Defaults to None.
31
+ module_dropout (float, optional): Probability of completely dropping the module during training. Defaults to None.
32
  """
33
  super().__init__()
34
+ self.lora_name = lora_name
35
  self.multiplier = multiplier
36
  self.lora_dim = lora_dim
37
  self.dropout = dropout
38
  self.rank_dropout = rank_dropout
39
  self.module_dropout = module_dropout
40
 
41
+ # Determine layer type (Linear or Conv2D)
42
+ is_conv2d = isinstance(org_module, nn.Conv2d)
43
+ in_dim = org_module.in_channels if is_conv2d else org_module.in_features
44
+ out_dim = org_module.out_channels if is_conv2d else org_module.out_features
45
 
46
+ # Define LoRA layers
47
+ if is_conv2d:
48
+ self.lora_down = nn.Conv2d(in_dim, lora_dim, kernel_size=org_module.kernel_size,
49
+ stride=org_module.stride, padding=org_module.padding, bias=False)
50
+ self.lora_up = nn.Conv2d(lora_dim, out_dim, kernel_size=1, stride=1, bias=False)
51
  else:
52
+ self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
53
+ self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
54
 
55
+ # Initialize weights
56
+ nn.init.xavier_uniform_(self.lora_down.weight)
57
+ nn.init.zeros_(self.lora_up.weight)
58
 
59
+ # Set alpha scaling factor (scale = alpha / rank; the buffer stores alpha itself,
+ # following the usual LoRA checkpoint convention)
60
+ alpha = alpha if alpha is not None else lora_dim
+ self.scale = alpha / lora_dim
61
+ self.register_buffer("alpha", torch.tensor(float(alpha), dtype=torch.float32))
62
 
63
+ # Store reference to the original module
64
+ self.org_module = org_module
65
+ self.org_forward = org_module.forward
66
 
67
+ def apply_to(self):
68
+ """Replace the forward method of the original module with this module's forward method."""
69
+ self.org_module.forward = self.forward
70
+ del self.org_module
71
 
72
+ def forward(self, x):
73
+ """
74
+ Forward pass for LoRA-enhanced module.
75
+ """
76
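+ # Module dropout: occasionally bypass the adapter entirely during training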
+ if self.module_dropout and self.training and torch.rand(1).item() < self.module_dropout:
77
+ return self.org_forward(x)
78
 
79
+ # Compute LoRA down projection
80
+ lora_output = self.lora_down(x)
81
 
82
+ # Apply dropout if training
83
+ if self.training:
84
+ if self.dropout:
85
+ lora_output = F.dropout(lora_output, p=self.dropout)
86
+ if self.rank_dropout:
87
+ dropout_mask = torch.rand_like(lora_output) > self.rank_dropout
88
+ lora_output *= dropout_mask
89
+ scale_factor = 1.0 / (1.0 - self.rank_dropout)
90
+ lora_output *= scale_factor
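+ # inverted-dropout rescaling: E[mask] = 1 - p, so dividing by (1 - p) keeps the expected magnitude unchanged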
91
 
92
+ # Compute LoRA up projection
93
+ lora_output = self.lora_up(lora_output)
94
 
95
+ # Combine with original output
96
+ return self.org_forward(x) + lora_output * self.multiplier * self.scale
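
For reference, a minimal usage sketch of the rewritten module (not part of the commit; the names base, lora, and x below are illustrative). It wraps a plain nn.Linear, applies the forward hook, and checks that the patched forward matches the algebraically merged weight W + multiplier * scale * (up @ down), which is the same merge the removed merge_to performed for the Linear case:

    import torch
    import torch.nn as nn

    base = nn.Linear(16, 32, bias=False)  # module to adapt
    lora = LoRAModule("lora_unet_example", base, multiplier=1.0, lora_dim=4, alpha=4.0)
    lora.apply_to()  # base.forward now routes through the LoRA module
    nn.init.normal_(lora.lora_up.weight, std=0.02)  # make the adapter contribute a nonzero delta

    x = torch.randn(2, 16)
    with torch.no_grad():
        hooked = base(x)  # org_forward(x) + lora_up(lora_down(x)) * multiplier * scale
        merged_w = base.weight + lora.multiplier * lora.scale * (lora.lora_up.weight @ lora.lora_down.weight)
        merged = x @ merged_w.T
    assert torch.allclose(hooked, merged, atol=1e-6)

With all dropout arguments left at None the forward pass is deterministic, so the hooked and merged outputs agree exactly up to floating-point tolerance.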