Meismaxandmaxisme committed · commit 0c4ab48 · verified · 1 parent: cef3212

Upload 5 files
src/backend/upscale/aura_sr.py ADDED
@@ -0,0 +1,1004 @@
+ # AuraSR: GAN-based Super-Resolution for real-world, a reproduction of the GigaGAN* paper. Implementation is
+ # based on the unofficial lucidrains/gigagan-pytorch repository. Heavily modified from there.
+ #
+ # https://mingukkang.github.io/GigaGAN/
+ from math import log2, ceil
+ from functools import partial
+ from typing import Any, Optional, List, Iterable
+
+ import torch
+ from torchvision import transforms
+ from PIL import Image
+ from torch import nn, einsum, Tensor
+ import torch.nn.functional as F
+
+ from einops import rearrange, repeat, reduce
+ from einops.layers.torch import Rearrange
+ from torchvision.utils import save_image
+ import math
+
+
+ def get_same_padding(size, kernel, dilation, stride):
+     return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2
+
+
+ class AdaptiveConv2DMod(nn.Module):
+     def __init__(
+         self,
+         dim,
+         dim_out,
+         kernel,
+         *,
+         demod=True,
+         stride=1,
+         dilation=1,
+         eps=1e-8,
+         num_conv_kernels=1,  # set this to be greater than 1 for adaptive
+     ):
+         super().__init__()
+         self.eps = eps
+
+         self.dim_out = dim_out
+
+         self.kernel = kernel
+         self.stride = stride
+         self.dilation = dilation
+         self.adaptive = num_conv_kernels > 1
+
+         self.weights = nn.Parameter(
+             torch.randn((num_conv_kernels, dim_out, dim, kernel, kernel))
+         )
+
+         self.demod = demod
+
+         nn.init.kaiming_normal_(
+             self.weights, a=0, mode="fan_in", nonlinearity="leaky_relu"
+         )
+
+     def forward(
+         self, fmap, mod: Optional[Tensor] = None, kernel_mod: Optional[Tensor] = None
+     ):
+         """
+         notation
+
+         b - batch
+         n - convs
+         o - output
+         i - input
+         k - kernel
+         """
+
+         b, h = fmap.shape[0], fmap.shape[-2]
+
+         # account for feature map that has been expanded by the scale in the first dimension
+         # due to multiscale inputs and outputs
+
+         if mod.shape[0] != b:
+             mod = repeat(mod, "b ... -> (s b) ...", s=b // mod.shape[0])
+
+         if exists(kernel_mod):
+             kernel_mod_has_el = kernel_mod.numel() > 0
+
+             assert self.adaptive or not kernel_mod_has_el
+
+             if kernel_mod_has_el and kernel_mod.shape[0] != b:
+                 kernel_mod = repeat(
+                     kernel_mod, "b ... -> (s b) ...", s=b // kernel_mod.shape[0]
+                 )
+
+         # prepare weights for modulation
+
+         weights = self.weights
+
+         if self.adaptive:
+             weights = repeat(weights, "... -> b ...", b=b)
+
+             # determine an adaptive weight and 'select' the kernel to use with softmax
+
+             assert exists(kernel_mod) and kernel_mod.numel() > 0
+
+             kernel_attn = kernel_mod.softmax(dim=-1)
+             kernel_attn = rearrange(kernel_attn, "b n -> b n 1 1 1 1")
+
+             weights = reduce(weights * kernel_attn, "b n ... -> b ...", "sum")
+
+         # do the modulation, demodulation, as done in stylegan2
+
+         mod = rearrange(mod, "b i -> b 1 i 1 1")
+
+         weights = weights * (mod + 1)
+
+         if self.demod:
+             inv_norm = (
+                 reduce(weights**2, "b o i k1 k2 -> b o 1 1 1", "sum")
+                 .clamp(min=self.eps)
+                 .rsqrt()
+             )
+             weights = weights * inv_norm
+
+         fmap = rearrange(fmap, "b c h w -> 1 (b c) h w")
+
+         weights = rearrange(weights, "b o ... -> (b o) ...")
+
+         padding = get_same_padding(h, self.kernel, self.dilation, self.stride)
+         fmap = F.conv2d(fmap, weights, padding=padding, groups=b)
+
+         return rearrange(fmap, "1 (b o) ... -> b o ...", b=b)
+
+
+ class Attend(nn.Module):
+     def __init__(self, dropout=0.0, flash=False):
+         super().__init__()
+         self.dropout = dropout
+         self.attn_dropout = nn.Dropout(dropout)
+         self.scale = nn.Parameter(torch.randn(1))
+         self.flash = flash
+
+     def flash_attn(self, q, k, v):
+         q, k, v = map(lambda t: t.contiguous(), (q, k, v))
+         out = F.scaled_dot_product_attention(
+             q, k, v, dropout_p=self.dropout if self.training else 0.0
+         )
+         return out
+
+     def forward(self, q, k, v):
+         if self.flash:
+             return self.flash_attn(q, k, v)
+
+         scale = q.shape[-1] ** -0.5
+
+         # similarity
+         sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale
+
+         # attention
+         attn = sim.softmax(dim=-1)
+         attn = self.attn_dropout(attn)
+
+         # aggregate values
+         out = einsum("b h i j, b h j d -> b h i d", attn, v)
+
+         return out
+
+
+ def exists(x):
+     return x is not None
+
+
+ def default(val, d):
+     if exists(val):
+         return val
+     return d() if callable(d) else d
+
+
+ def cast_tuple(t, length=1):
+     if isinstance(t, tuple):
+         return t
+     return (t,) * length
+
+
+ def identity(t, *args, **kwargs):
+     return t
+
+
+ def is_power_of_two(n):
+     return log2(n).is_integer()
+
+
+ def null_iterator():
+     while True:
+         yield None
+
+
+ def Downsample(dim, dim_out=None):
+     return nn.Sequential(
+         Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
+         nn.Conv2d(dim * 4, default(dim_out, dim), 1),
+     )
+
+
+ class RMSNorm(nn.Module):
+     def __init__(self, dim):
+         super().__init__()
+         self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
+         self.eps = 1e-4
+
+     def forward(self, x):
+         return F.normalize(x, dim=1) * self.g * (x.shape[1] ** 0.5)
+
+
+ # building block modules
+
+
+ class Block(nn.Module):
+     def __init__(self, dim, dim_out, groups=8, num_conv_kernels=0):
+         super().__init__()
+         self.proj = AdaptiveConv2DMod(
+             dim, dim_out, kernel=3, num_conv_kernels=num_conv_kernels
+         )
+         self.kernel = 3
+         self.dilation = 1
+         self.stride = 1
+
+         self.act = nn.SiLU()
+
+     def forward(self, x, conv_mods_iter: Optional[Iterable] = None):
+         conv_mods_iter = default(conv_mods_iter, null_iterator())
+
+         x = self.proj(x, mod=next(conv_mods_iter), kernel_mod=next(conv_mods_iter))
+
+         x = self.act(x)
+         return x
+
+
+ class ResnetBlock(nn.Module):
+     def __init__(
+         self, dim, dim_out, *, groups=8, num_conv_kernels=0, style_dims: List = []
+     ):
+         super().__init__()
+         style_dims.extend([dim, num_conv_kernels, dim_out, num_conv_kernels])
+
+         self.block1 = Block(
+             dim, dim_out, groups=groups, num_conv_kernels=num_conv_kernels
+         )
+         self.block2 = Block(
+             dim_out, dim_out, groups=groups, num_conv_kernels=num_conv_kernels
+         )
+         self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
+
+     def forward(self, x, conv_mods_iter: Optional[Iterable] = None):
+         h = self.block1(x, conv_mods_iter=conv_mods_iter)
+         h = self.block2(h, conv_mods_iter=conv_mods_iter)
+
+         return h + self.res_conv(x)
+
+
+ class LinearAttention(nn.Module):
+     def __init__(self, dim, heads=4, dim_head=32):
+         super().__init__()
+         self.scale = dim_head**-0.5
+         self.heads = heads
+         hidden_dim = dim_head * heads
+
+         self.norm = RMSNorm(dim)
+         self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+
+         self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), RMSNorm(dim))
+
+     def forward(self, x):
+         b, c, h, w = x.shape
+
+         x = self.norm(x)
+
+         qkv = self.to_qkv(x).chunk(3, dim=1)
+         q, k, v = map(
+             lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
+         )
+
+         q = q.softmax(dim=-2)
+         k = k.softmax(dim=-1)
+
+         q = q * self.scale
+
+         context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
+
+         out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
+         out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
+         return self.to_out(out)
+
+
+ class Attention(nn.Module):
+     def __init__(self, dim, heads=4, dim_head=32, flash=False):
+         super().__init__()
+         self.heads = heads
+         hidden_dim = dim_head * heads
+
+         self.norm = RMSNorm(dim)
+
+         self.attend = Attend(flash=flash)
+         self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+         self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+     def forward(self, x):
+         b, c, h, w = x.shape
+         x = self.norm(x)
+         qkv = self.to_qkv(x).chunk(3, dim=1)
+
+         q, k, v = map(
+             lambda t: rearrange(t, "b (h c) x y -> b h (x y) c", h=self.heads), qkv
+         )
+
+         out = self.attend(q, k, v)
+         out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
+
+         return self.to_out(out)
+
+
+ # feedforward
+ def FeedForward(dim, mult=4):
+     return nn.Sequential(
+         RMSNorm(dim),
+         nn.Conv2d(dim, dim * mult, 1),
+         nn.GELU(),
+         nn.Conv2d(dim * mult, dim, 1),
+     )
+
+
+ # transformers
+ class Transformer(nn.Module):
+     def __init__(self, dim, dim_head=64, heads=8, depth=1, flash_attn=True, ff_mult=4):
+         super().__init__()
+         self.layers = nn.ModuleList([])
+
+         for _ in range(depth):
+             self.layers.append(
+                 nn.ModuleList(
+                     [
+                         Attention(
+                             dim=dim, dim_head=dim_head, heads=heads, flash=flash_attn
+                         ),
+                         FeedForward(dim=dim, mult=ff_mult),
+                     ]
+                 )
+             )
+
+     def forward(self, x):
+         for attn, ff in self.layers:
+             x = attn(x) + x
+             x = ff(x) + x
+
+         return x
+
+
+ class LinearTransformer(nn.Module):
+     def __init__(self, dim, dim_head=64, heads=8, depth=1, ff_mult=4):
+         super().__init__()
+         self.layers = nn.ModuleList([])
+
+         for _ in range(depth):
+             self.layers.append(
+                 nn.ModuleList(
+                     [
+                         LinearAttention(dim=dim, dim_head=dim_head, heads=heads),
+                         FeedForward(dim=dim, mult=ff_mult),
+                     ]
+                 )
+             )
+
+     def forward(self, x):
+         for attn, ff in self.layers:
+             x = attn(x) + x
+             x = ff(x) + x
+
+         return x
+
+
+ class NearestNeighborhoodUpsample(nn.Module):
+     def __init__(self, dim, dim_out=None):
+         super().__init__()
+         dim_out = default(dim_out, dim)
+         self.conv = nn.Conv2d(dim, dim_out, kernel_size=3, stride=1, padding=1)
+
+     def forward(self, x):
+
+         if x.shape[0] >= 64:
+             x = x.contiguous()
+
+         x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+         x = self.conv(x)
+
+         return x
+
+
+ class EqualLinear(nn.Module):
+     def __init__(self, dim, dim_out, lr_mul=1, bias=True):
+         super().__init__()
+         self.weight = nn.Parameter(torch.randn(dim_out, dim))
+         if bias:
+             self.bias = nn.Parameter(torch.zeros(dim_out))
+
+         self.lr_mul = lr_mul
+
+     def forward(self, input):
+         return F.linear(input, self.weight * self.lr_mul, bias=self.bias * self.lr_mul)
+
+
+ class StyleGanNetwork(nn.Module):
+     def __init__(self, dim_in=128, dim_out=512, depth=8, lr_mul=0.1, dim_text_latent=0):
+         super().__init__()
+         self.dim_in = dim_in
+         self.dim_out = dim_out
+         self.dim_text_latent = dim_text_latent
+
+         layers = []
+         for i in range(depth):
+             is_first = i == 0
+
+             if is_first:
+                 dim_in_layer = dim_in + dim_text_latent
+             else:
+                 dim_in_layer = dim_out
+
+             dim_out_layer = dim_out
+
+             layers.extend(
+                 [EqualLinear(dim_in_layer, dim_out_layer, lr_mul), nn.LeakyReLU(0.2)]
+             )
+
+         self.net = nn.Sequential(*layers)
+
+     def forward(self, x, text_latent=None):
+         x = F.normalize(x, dim=1)
+         if self.dim_text_latent > 0:
+             assert exists(text_latent)
+             x = torch.cat((x, text_latent), dim=-1)
+         return self.net(x)
+
+
+ class UnetUpsampler(torch.nn.Module):
+
+     def __init__(
+         self,
+         dim: int,
+         *,
+         image_size: int,
+         input_image_size: int,
+         init_dim: Optional[int] = None,
+         out_dim: Optional[int] = None,
+         style_network: Optional[dict] = None,
+         up_dim_mults: tuple = (1, 2, 4, 8, 16),
+         down_dim_mults: tuple = (4, 8, 16),
+         channels: int = 3,
+         resnet_block_groups: int = 8,
+         full_attn: tuple = (False, False, False, True, True),
+         flash_attn: bool = True,
+         self_attn_dim_head: int = 64,
+         self_attn_heads: int = 8,
+         attn_depths: tuple = (2, 2, 2, 2, 4),
+         mid_attn_depth: int = 4,
+         num_conv_kernels: int = 4,
+         resize_mode: str = "bilinear",
+         unconditional: bool = True,
+         skip_connect_scale: Optional[float] = None,
+     ):
+         super().__init__()
+         self.style_network = style_network = StyleGanNetwork(**style_network)
+         self.unconditional = unconditional
+         assert not (
+             unconditional
+             and exists(style_network)
+             and style_network.dim_text_latent > 0
+         )
+
+         assert is_power_of_two(image_size) and is_power_of_two(
+             input_image_size
+         ), "both output image size and input image size must be power of 2"
+         assert (
+             input_image_size < image_size
+         ), "input image size must be smaller than the output image size, thus upsampling"
+
+         self.image_size = image_size
+         self.input_image_size = input_image_size
+
+         style_embed_split_dims = []
+
+         self.channels = channels
+         input_channels = channels
+
+         init_dim = default(init_dim, dim)
+
+         up_dims = [init_dim, *map(lambda m: dim * m, up_dim_mults)]
+         init_down_dim = up_dims[len(up_dim_mults) - len(down_dim_mults)]
+         down_dims = [init_down_dim, *map(lambda m: dim * m, down_dim_mults)]
+         self.init_conv = nn.Conv2d(input_channels, init_down_dim, 7, padding=3)
+
+         up_in_out = list(zip(up_dims[:-1], up_dims[1:]))
+         down_in_out = list(zip(down_dims[:-1], down_dims[1:]))
+
+         block_klass = partial(
+             ResnetBlock,
+             groups=resnet_block_groups,
+             num_conv_kernels=num_conv_kernels,
+             style_dims=style_embed_split_dims,
+         )
+
+         FullAttention = partial(Transformer, flash_attn=flash_attn)
+         *_, mid_dim = up_dims
+
+         self.skip_connect_scale = default(skip_connect_scale, 2**-0.5)
+
+         self.downs = nn.ModuleList([])
+         self.ups = nn.ModuleList([])
+
+         block_count = 6
+
+         for ind, (
+             (dim_in, dim_out),
+             layer_full_attn,
+             layer_attn_depth,
+         ) in enumerate(zip(down_in_out, full_attn, attn_depths)):
+             attn_klass = FullAttention if layer_full_attn else LinearTransformer
+
+             blocks = []
+             for i in range(block_count):
+                 blocks.append(block_klass(dim_in, dim_in))
+
+             self.downs.append(
+                 nn.ModuleList(
+                     [
+                         nn.ModuleList(blocks),
+                         nn.ModuleList(
+                             [
+                                 (
+                                     attn_klass(
+                                         dim_in,
+                                         dim_head=self_attn_dim_head,
+                                         heads=self_attn_heads,
+                                         depth=layer_attn_depth,
+                                     )
+                                     if layer_full_attn
+                                     else None
+                                 ),
+                                 nn.Conv2d(
+                                     dim_in, dim_out, kernel_size=3, stride=2, padding=1
+                                 ),
+                             ]
+                         ),
+                     ]
+                 )
+             )
+
+         self.mid_block1 = block_klass(mid_dim, mid_dim)
+         self.mid_attn = FullAttention(
+             mid_dim,
+             dim_head=self_attn_dim_head,
+             heads=self_attn_heads,
+             depth=mid_attn_depth,
+         )
+         self.mid_block2 = block_klass(mid_dim, mid_dim)
+
+         *_, last_dim = up_dims
+
+         for ind, (
+             (dim_in, dim_out),
+             layer_full_attn,
+             layer_attn_depth,
+         ) in enumerate(
+             zip(
+                 reversed(up_in_out),
+                 reversed(full_attn),
+                 reversed(attn_depths),
+             )
+         ):
+             attn_klass = FullAttention if layer_full_attn else LinearTransformer
+
+             blocks = []
+             input_dim = dim_in * 2 if ind < len(down_in_out) else dim_in
+             for i in range(block_count):
+                 blocks.append(block_klass(input_dim, dim_in))
+
+             self.ups.append(
+                 nn.ModuleList(
+                     [
+                         nn.ModuleList(blocks),
+                         nn.ModuleList(
+                             [
+                                 NearestNeighborhoodUpsample(
+                                     last_dim if ind == 0 else dim_out,
+                                     dim_in,
+                                 ),
+                                 (
+                                     attn_klass(
+                                         dim_in,
+                                         dim_head=self_attn_dim_head,
+                                         heads=self_attn_heads,
+                                         depth=layer_attn_depth,
+                                     )
+                                     if layer_full_attn
+                                     else None
+                                 ),
+                             ]
+                         ),
+                     ]
+                 )
+             )
+
+         self.out_dim = default(out_dim, channels)
+         self.final_res_block = block_klass(dim, dim)
+         self.final_to_rgb = nn.Conv2d(dim, channels, 1)
+         self.resize_mode = resize_mode
+         self.style_to_conv_modulations = nn.Linear(
+             style_network.dim_out, sum(style_embed_split_dims)
+         )
+         self.style_embed_split_dims = style_embed_split_dims
+
+     @property
+     def allowable_rgb_resolutions(self):
+         input_res_base = int(log2(self.input_image_size))
+         output_res_base = int(log2(self.image_size))
+         allowed_rgb_res_base = list(range(input_res_base, output_res_base))
+         return [*map(lambda p: 2**p, allowed_rgb_res_base)]
+
+     @property
+     def device(self):
+         return next(self.parameters()).device
+
+     @property
+     def total_params(self):
+         return sum([p.numel() for p in self.parameters()])
+
+     def resize_image_to(self, x, size):
+         return F.interpolate(x, (size, size), mode=self.resize_mode)
+
+     def forward(
+         self,
+         lowres_image: torch.Tensor,
+         styles: Optional[torch.Tensor] = None,
+         noise: Optional[torch.Tensor] = None,
+         global_text_tokens: Optional[torch.Tensor] = None,
+         return_all_rgbs: bool = False,
+     ):
+         x = lowres_image
+
+         noise_scale = 0.001  # Adjust the scale of the noise as needed
+         noise_aug = torch.randn_like(x) * noise_scale
+         x = x + noise_aug
+         x = x.clamp(0, 1)
+
+         shape = x.shape
+         batch_size = shape[0]
+
+         assert shape[-2:] == ((self.input_image_size,) * 2)
+
+         # styles
+         if not exists(styles):
+             assert exists(self.style_network)
+
+             noise = default(
+                 noise,
+                 torch.randn(
+                     (batch_size, self.style_network.dim_in), device=self.device
+                 ),
+             )
+             styles = self.style_network(noise, global_text_tokens)
+
+         # project styles to conv modulations
+         conv_mods = self.style_to_conv_modulations(styles)
+         conv_mods = conv_mods.split(self.style_embed_split_dims, dim=-1)
+         conv_mods = iter(conv_mods)
+
+         x = self.init_conv(x)
+
+         h = []
+         for blocks, (attn, downsample) in self.downs:
+             for block in blocks:
+                 x = block(x, conv_mods_iter=conv_mods)
+                 h.append(x)
+
+             if attn is not None:
+                 x = attn(x)
+
+             x = downsample(x)
+
+         x = self.mid_block1(x, conv_mods_iter=conv_mods)
+         x = self.mid_attn(x)
+         x = self.mid_block2(x, conv_mods_iter=conv_mods)
+
+         for (
+             blocks,
+             (
+                 upsample,
+                 attn,
+             ),
+         ) in self.ups:
+             x = upsample(x)
+             for block in blocks:
+                 if h != []:
+                     res = h.pop()
+                     res = res * self.skip_connect_scale
+                     x = torch.cat((x, res), dim=1)
+
+                 x = block(x, conv_mods_iter=conv_mods)
+
+             if attn is not None:
+                 x = attn(x)
+
+         x = self.final_res_block(x, conv_mods_iter=conv_mods)
+         rgb = self.final_to_rgb(x)
+
+         if not return_all_rgbs:
+             return rgb
+
+         return rgb, []
+
+
+ def tile_image(image, chunk_size=64):
+     c, h, w = image.shape
+     h_chunks = ceil(h / chunk_size)
+     w_chunks = ceil(w / chunk_size)
+     tiles = []
+     for i in range(h_chunks):
+         for j in range(w_chunks):
+             tile = image[
+                 :,
+                 i * chunk_size : (i + 1) * chunk_size,
+                 j * chunk_size : (j + 1) * chunk_size,
+             ]
+             tiles.append(tile)
+     return tiles, h_chunks, w_chunks
+
+
+ # This helps create a checkerboard pattern with some edge blending
+ def create_checkerboard_weights(tile_size):
+     x = torch.linspace(-1, 1, tile_size)
+     y = torch.linspace(-1, 1, tile_size)
+
+     x, y = torch.meshgrid(x, y, indexing="ij")
+     d = torch.sqrt(x * x + y * y)
+     sigma, mu = 0.5, 0.0
+     weights = torch.exp(-((d - mu) ** 2 / (2.0 * sigma**2)))
+
+     # saturate the values to make sure we get high weights in the center
+     weights = weights**8
+
+     return weights / weights.max()  # Normalize to [0, 1]
+
+
+ def repeat_weights(weights, image_size):
+     tile_size = weights.shape[0]
+     repeats = (
+         math.ceil(image_size[0] / tile_size),
+         math.ceil(image_size[1] / tile_size),
+     )
+     return weights.repeat(repeats)[: image_size[0], : image_size[1]]
+
+
+ def create_offset_weights(weights, image_size):
+     tile_size = weights.shape[0]
+     offset = tile_size // 2
+     full_weights = repeat_weights(
+         weights, (image_size[0] + offset, image_size[1] + offset)
+     )
+     return full_weights[offset:, offset:]
+
+
+ def merge_tiles(tiles, h_chunks, w_chunks, chunk_size=64):
+     # Determine the shape of the output tensor
+     c = tiles[0].shape[0]
+     h = h_chunks * chunk_size
+     w = w_chunks * chunk_size
+
+     # Create an empty tensor to hold the merged image
+     merged = torch.zeros((c, h, w), dtype=tiles[0].dtype)
+
+     # Iterate over the tiles and place them in the correct position
+     for idx, tile in enumerate(tiles):
+         i = idx // w_chunks
+         j = idx % w_chunks
+
+         h_start = i * chunk_size
+         w_start = j * chunk_size
+
+         tile_h, tile_w = tile.shape[1:]
+         merged[:, h_start : h_start + tile_h, w_start : w_start + tile_w] = tile
+
+     return merged
+
+
+ class AuraSR:
+     def __init__(self, config: dict[str, Any], device: str = "cuda"):
+         self.upsampler = UnetUpsampler(**config).to(device)
+         self.input_image_size = config["input_image_size"]
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_id: str = "fal-ai/AuraSR",
+         use_safetensors: bool = True,
+         device: str = "cuda",
+     ):
+         import json
+         import torch
+         from pathlib import Path
+         from huggingface_hub import snapshot_download
+
+         # Check if model_id is a local file
+         if Path(model_id).is_file():
+             local_file = Path(model_id)
+             if local_file.suffix == ".safetensors":
+                 use_safetensors = True
+             elif local_file.suffix == ".ckpt":
+                 use_safetensors = False
+             else:
+                 raise ValueError(
+                     f"Unsupported file format: {local_file.suffix}. Please use .safetensors or .ckpt files."
+                 )
+
+             # For local files, we need to provide the config separately
+             config_path = local_file.with_name("config.json")
+             if not config_path.exists():
+                 raise FileNotFoundError(
+                     f"Config file not found: {config_path}. "
+                     f"When loading from a local file, ensure that 'config.json' "
+                     f"is present in the same directory as '{local_file.name}'. "
+                     f"If you're trying to load a model from Hugging Face, "
+                     f"please provide the model ID instead of a file path."
+                 )
+
+             config = json.loads(config_path.read_text())
+             hf_model_path = local_file.parent
+         else:
+             hf_model_path = Path(
+                 snapshot_download(model_id, ignore_patterns=["*.ckpt"])
+             )
+             config = json.loads((hf_model_path / "config.json").read_text())
+
+         model = cls(config, device)
+
+         if use_safetensors:
+             try:
+                 from safetensors.torch import load_file
+
+                 checkpoint = load_file(
+                     hf_model_path / "model.safetensors"
+                     if not Path(model_id).is_file()
+                     else model_id
+                 )
+             except ImportError:
+                 raise ImportError(
+                     "The safetensors library is not installed. "
+                     "Please install it with `pip install safetensors` "
+                     "or use `use_safetensors=False` to load the model with PyTorch."
+                 )
+         else:
+             checkpoint = torch.load(
+                 hf_model_path / "model.ckpt"
+                 if not Path(model_id).is_file()
+                 else model_id
+             )
+
+         model.upsampler.load_state_dict(checkpoint, strict=True)
+         return model
+
+     @torch.no_grad()
+     def upscale_4x(self, image: Image.Image, max_batch_size=8) -> Image.Image:
+         tensor_transform = transforms.ToTensor()
+         device = self.upsampler.device
+
+         image_tensor = tensor_transform(image).unsqueeze(0)
+         _, _, h, w = image_tensor.shape
+         pad_h = (
+             self.input_image_size - h % self.input_image_size
+         ) % self.input_image_size
+         pad_w = (
+             self.input_image_size - w % self.input_image_size
+         ) % self.input_image_size
+
+         # Pad the image
+         image_tensor = torch.nn.functional.pad(
+             image_tensor, (0, pad_w, 0, pad_h), mode="reflect"
+         ).squeeze(0)
+         tiles, h_chunks, w_chunks = tile_image(image_tensor, self.input_image_size)
+
+         # Batch processing of tiles
+         num_tiles = len(tiles)
+         batches = [
+             tiles[i : i + max_batch_size] for i in range(0, num_tiles, max_batch_size)
+         ]
+         reconstructed_tiles = []
+
+         for batch in batches:
+             model_input = torch.stack(batch).to(device)
+             generator_output = self.upsampler(
+                 lowres_image=model_input,
+                 noise=torch.randn(model_input.shape[0], 128, device=device),
+             )
+             reconstructed_tiles.extend(
+                 list(generator_output.clamp_(0, 1).detach().cpu())
+             )
+
+         merged_tensor = merge_tiles(
+             reconstructed_tiles, h_chunks, w_chunks, self.input_image_size * 4
+         )
+         unpadded = merged_tensor[:, : h * 4, : w * 4]
+
+         to_pil = transforms.ToPILImage()
+         return to_pil(unpadded)
+
+     # Tiled 4x upscaling with overlapping tiles to reduce seam artifacts
+     # weights options are 'checkboard' and 'constant'
+     @torch.no_grad()
+     def upscale_4x_overlapped(self, image, max_batch_size=8, weight_type="checkboard"):
+         tensor_transform = transforms.ToTensor()
+         device = self.upsampler.device
+
+         image_tensor = tensor_transform(image).unsqueeze(0)
+         _, _, h, w = image_tensor.shape
+
+         # Calculate paddings
+         pad_h = (
+             self.input_image_size - h % self.input_image_size
+         ) % self.input_image_size
+         pad_w = (
+             self.input_image_size - w % self.input_image_size
+         ) % self.input_image_size
+
+         # Pad the image
+         image_tensor = torch.nn.functional.pad(
+             image_tensor, (0, pad_w, 0, pad_h), mode="reflect"
+         ).squeeze(0)
+
+         # Function to process tiles
+         def process_tiles(tiles, h_chunks, w_chunks):
+             num_tiles = len(tiles)
+             batches = [
+                 tiles[i : i + max_batch_size]
+                 for i in range(0, num_tiles, max_batch_size)
+             ]
+             reconstructed_tiles = []
+
+             for batch in batches:
+                 model_input = torch.stack(batch).to(device)
+                 generator_output = self.upsampler(
+                     lowres_image=model_input,
+                     noise=torch.randn(model_input.shape[0], 128, device=device),
+                 )
+                 reconstructed_tiles.extend(
+                     list(generator_output.clamp_(0, 1).detach().cpu())
+                 )
+
+             return merge_tiles(
+                 reconstructed_tiles, h_chunks, w_chunks, self.input_image_size * 4
+             )
+
+         # First pass
+         tiles1, h_chunks1, w_chunks1 = tile_image(image_tensor, self.input_image_size)
+         result1 = process_tiles(tiles1, h_chunks1, w_chunks1)
+
+         # Second pass with offset
+         offset = self.input_image_size // 2
+         image_tensor_offset = torch.nn.functional.pad(
+             image_tensor, (offset, offset, offset, offset), mode="reflect"
+         ).squeeze(0)
+
+         tiles2, h_chunks2, w_chunks2 = tile_image(
+             image_tensor_offset, self.input_image_size
+         )
+         result2 = process_tiles(tiles2, h_chunks2, w_chunks2)
+
+         # unpad
+         offset_4x = offset * 4
+         result2_interior = result2[:, offset_4x:-offset_4x, offset_4x:-offset_4x]
+
+         if weight_type == "checkboard":
+             weight_tile = create_checkerboard_weights(self.input_image_size * 4)
+
+             weight_shape = result2_interior.shape[1:]
+             weights_1 = create_offset_weights(weight_tile, weight_shape)
+             weights_2 = repeat_weights(weight_tile, weight_shape)
+
+             normalizer = weights_1 + weights_2
+             weights_1 = weights_1 / normalizer
+             weights_2 = weights_2 / normalizer
+
+             weights_1 = weights_1.unsqueeze(0).repeat(3, 1, 1)
+             weights_2 = weights_2.unsqueeze(0).repeat(3, 1, 1)
+         elif weight_type == "constant":
+             weights_1 = torch.ones_like(result2_interior) * 0.5
+             weights_2 = weights_1
+         else:
+             raise ValueError(
+                 "weight_type should be either 'checkboard' or 'constant' but got",
+                 weight_type,
+             )
+
+         result1 = result1 * weights_2
+         result2 = result2_interior * weights_1
+
+         # Average the overlapping region
+         result1 = result1 + result2
+
+         # Remove padding
+         unpadded = result1[:, : h * 4, : w * 4]
+
+         to_pil = transforms.ToPILImage()
+         return to_pil(unpadded)
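
For reference (not part of the committed files), a minimal usage sketch of the AuraSR class above; the "fal/AuraSR-v2" model ID matches the one used by aura_sr_upscale.py below, a CUDA device is assumed, and the file paths are placeholders:

from PIL import Image
from backend.upscale.aura_sr import AuraSR

# Load config.json and weights from the Hugging Face Hub (network access assumed)
aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2", device="cuda")

low_res = Image.open("input.png").convert("RGB")    # placeholder input path
high_res = aura_sr.upscale_4x_overlapped(low_res)   # 4x upscale, two tile passes blended to hide seams
high_res.save("output_4x.png")                      # placeholder output path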
src/backend/upscale/aura_sr_upscale.py ADDED
@@ -0,0 +1,9 @@
+ from backend.upscale.aura_sr import AuraSR
+ from PIL import Image
+
+
+ def upscale_aura_sr(image_path: str):
+
+     aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2", device="cpu")
+     image_in = Image.open(image_path)  # .resize((256, 256))
+     return aura_sr.upscale_4x(image_in)
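
A short usage sketch of this wrapper (the path is a placeholder; note that the wrapper loads the model on CPU on every call):

from backend.upscale.aura_sr_upscale import upscale_aura_sr

result = upscale_aura_sr("photo.png")  # returns a PIL.Image upscaled 4x
result.save("photo_4x.png")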
src/backend/upscale/edsr_upscale_onnx.py ADDED
@@ -0,0 +1,37 @@
+ import numpy as np
+ import onnxruntime
+ from huggingface_hub import hf_hub_download
+ from PIL import Image
+
+
+ def upscale_edsr_2x(image_path: str):
+     input_image = Image.open(image_path).convert("RGB")
+     input_image = np.array(input_image).astype("float32")
+     input_image = np.transpose(input_image, (2, 0, 1))
+     img_arr = np.expand_dims(input_image, axis=0)
+
+     if np.max(img_arr) > 256:  # 16-bit image
+         max_range = 65535
+     else:
+         max_range = 255.0
+     img = img_arr / max_range
+
+     model_path = hf_hub_download(
+         repo_id="rupeshs/edsr-onnx",
+         filename="edsr_onnxsim_2x.onnx",
+     )
+     sess = onnxruntime.InferenceSession(model_path)
+
+     input_name = sess.get_inputs()[0].name
+     output_name = sess.get_outputs()[0].name
+     output = sess.run(
+         [output_name],
+         {input_name: img},
+     )[0]
+
+     result = output.squeeze()
+     result = result.clip(0, 1)
+     image_array = np.transpose(result, (1, 2, 0))
+     image_array = np.uint8(image_array * 255)
+     upscaled_image = Image.fromarray(image_array)
+     return upscaled_image
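
A minimal usage sketch of the EDSR helper above (placeholder paths; the rupeshs/edsr-onnx model is downloaded from the Hub on first use):

from backend.upscale.edsr_upscale_onnx import upscale_edsr_2x

upscaled = upscale_edsr_2x("photo.png")  # returns a PIL.Image upscaled 2x
upscaled.save("photo_2x.png")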
src/backend/upscale/tiled_upscale.py ADDED
@@ -0,0 +1,237 @@
+ import time
+ import math
+ import logging
+ from PIL import Image, ImageDraw, ImageFilter
+ from backend.models.lcmdiffusion_setting import DiffusionTask
+ from context import Context
+ from constants import DEVICE
+
+
+ def generate_upscaled_image(
+     config,
+     input_path=None,
+     strength=0.3,
+     scale_factor=2.0,
+     tile_overlap=16,
+     upscale_settings=None,
+     context: Context = None,
+     output_path=None,
+     image_format="PNG",
+ ):
+     if config == None or (
+         input_path == None or input_path == "" and upscale_settings == None
+     ):
+         logging.error("Wrong arguments in tiled upscale function call!")
+         return
+
+     # Use the upscale_settings dict if provided; otherwise, build the
+     # upscale_settings dict using the function arguments and default values
+     if upscale_settings == None:
+         upscale_settings = {
+             "source_file": input_path,
+             "target_file": None,
+             "output_format": image_format,
+             "strength": strength,
+             "scale_factor": scale_factor,
+             "prompt": config.lcm_diffusion_setting.prompt,
+             "tile_overlap": tile_overlap,
+             "tile_size": 256,
+             "tiles": [],
+         }
+         source_image = Image.open(input_path)  # PIL image
+     else:
+         source_image = Image.open(upscale_settings["source_file"])
+
+     upscale_settings["source_image"] = source_image
+
+     if upscale_settings["target_file"]:
+         result = Image.open(upscale_settings["target_file"])
+     else:
+         result = Image.new(
+             mode="RGBA",
+             size=(
+                 source_image.size[0] * int(upscale_settings["scale_factor"]),
+                 source_image.size[1] * int(upscale_settings["scale_factor"]),
+             ),
+             color=(0, 0, 0, 0),
+         )
+     upscale_settings["target_image"] = result
+
+     # If the custom tile definition array 'tiles' is empty, proceed with the
+     # default tiled upscale task by defining all the possible image tiles; note
+     # that the actual tile size is 'tile_size' + 'tile_overlap' and the target
+     # image width and height are no longer constrained to multiples of 256 but
+     # are instead multiples of the actual tile size
+     if len(upscale_settings["tiles"]) == 0:
+         tile_size = upscale_settings["tile_size"]
+         scale_factor = upscale_settings["scale_factor"]
+         tile_overlap = upscale_settings["tile_overlap"]
+         total_cols = math.ceil(
+             source_image.size[0] / tile_size
+         )  # Image width / tile size
+         total_rows = math.ceil(
+             source_image.size[1] / tile_size
+         )  # Image height / tile size
+         for y in range(0, total_rows):
+             y_offset = tile_overlap if y > 0 else 0  # Tile mask offset
+             for x in range(0, total_cols):
+                 x_offset = tile_overlap if x > 0 else 0  # Tile mask offset
+                 x1 = x * tile_size
+                 y1 = y * tile_size
+                 w = tile_size + (tile_overlap if x < total_cols - 1 else 0)
+                 h = tile_size + (tile_overlap if y < total_rows - 1 else 0)
+                 mask_box = (  # Default tile mask box definition
+                     x_offset,
+                     y_offset,
+                     int(w * scale_factor),
+                     int(h * scale_factor),
+                 )
+                 upscale_settings["tiles"].append(
+                     {
+                         "x": x1,
+                         "y": y1,
+                         "w": w,
+                         "h": h,
+                         "mask_box": mask_box,
+                         "prompt": upscale_settings["prompt"],  # Use top level prompt if available
+                         "scale_factor": scale_factor,
+                     }
+                 )
+
+     # Generate the output image tiles
+     for i in range(0, len(upscale_settings["tiles"])):
+         generate_upscaled_tile(
+             config,
+             i,
+             upscale_settings,
+             context=context,
+         )
+
+     # Save completed upscaled image
+     if upscale_settings["output_format"].upper() == "JPEG":
+         result_rgb = result.convert("RGB")
+         result.close()
+         result = result_rgb
+     result.save(output_path)
+     result.close()
+     source_image.close()
+     return
+
+
+ def get_current_tile(
+     config,
+     context,
+     strength,
+ ):
+     config.lcm_diffusion_setting.strength = strength
+     config.lcm_diffusion_setting.diffusion_task = DiffusionTask.image_to_image.value
+     if (
+         config.lcm_diffusion_setting.use_tiny_auto_encoder
+         and config.lcm_diffusion_setting.use_openvino
+     ):
+         config.lcm_diffusion_setting.use_tiny_auto_encoder = False
+     current_tile = context.generate_text_to_image(
+         settings=config,
+         reshape=True,
+         device=DEVICE,
+         save_config=False,
+     )[0]
+     return current_tile
+
+
+ # Generates a single tile from the source image as defined in the
+ # upscale_settings["tiles"] array with the corresponding index and pastes the
+ # generated tile into the target image using the corresponding mask and scale
+ # factor; note that scale factor for the target image and the individual tiles
+ # can be different, this function will adjust scale factors as needed
+ def generate_upscaled_tile(
+     config,
+     index,
+     upscale_settings,
+     context: Context = None,
+ ):
+     if config == None or upscale_settings == None:
+         logging.error("Wrong arguments in tile creation function call!")
+         return
+
+     x = upscale_settings["tiles"][index]["x"]
+     y = upscale_settings["tiles"][index]["y"]
+     w = upscale_settings["tiles"][index]["w"]
+     h = upscale_settings["tiles"][index]["h"]
+     tile_prompt = upscale_settings["tiles"][index]["prompt"]
+     scale_factor = upscale_settings["scale_factor"]
+     tile_scale_factor = upscale_settings["tiles"][index]["scale_factor"]
+     target_width = int(w * tile_scale_factor)
+     target_height = int(h * tile_scale_factor)
+     strength = upscale_settings["strength"]
+     source_image = upscale_settings["source_image"]
+     target_image = upscale_settings["target_image"]
+     mask_image = generate_tile_mask(config, index, upscale_settings)
+
+     config.lcm_diffusion_setting.number_of_images = 1
+     config.lcm_diffusion_setting.prompt = tile_prompt
+     config.lcm_diffusion_setting.image_width = target_width
+     config.lcm_diffusion_setting.image_height = target_height
+     config.lcm_diffusion_setting.init_image = source_image.crop((x, y, x + w, y + h))
+
+     current_tile = None
+     print(f"[SD Upscale] Generating tile {index + 1}/{len(upscale_settings['tiles'])} ")
+     if tile_prompt == None or tile_prompt == "":
+         config.lcm_diffusion_setting.prompt = ""
+         config.lcm_diffusion_setting.negative_prompt = ""
+         current_tile = get_current_tile(config, context, strength)
+     else:
+         # Attempt to use img2img with low denoising strength to
+         # generate the tiles with the extra aid of a prompt
+         # context = get_context(InterfaceType.CLI)
+         current_tile = get_current_tile(config, context, strength)
+
+     if math.isclose(scale_factor, tile_scale_factor):
+         target_image.paste(
+             current_tile, (int(x * scale_factor), int(y * scale_factor)), mask_image
+         )
+     else:
+         target_image.paste(
+             current_tile.resize((int(w * scale_factor), int(h * scale_factor))),
+             (int(x * scale_factor), int(y * scale_factor)),
+             mask_image.resize((int(w * scale_factor), int(h * scale_factor))),
+         )
+     mask_image.close()
+     current_tile.close()
+     config.lcm_diffusion_setting.init_image.close()
+
+
+ # Generate tile mask using the box definition in the upscale_settings["tiles"]
+ # array with the corresponding index; note that tile masks for the default
+ # tiled upscale task can be reused but that would complicate the code, so
+ # new tile masks are instead created for each tile
+ def generate_tile_mask(
+     config,
+     index,
+     upscale_settings,
+ ):
+     scale_factor = upscale_settings["scale_factor"]
+     tile_overlap = upscale_settings["tile_overlap"]
+     tile_scale_factor = upscale_settings["tiles"][index]["scale_factor"]
+     w = int(upscale_settings["tiles"][index]["w"] * tile_scale_factor)
+     h = int(upscale_settings["tiles"][index]["h"] * tile_scale_factor)
+     # The Stable Diffusion pipeline automatically adjusts the output size
+     # to multiples of 8 pixels; the mask must be created with the same
+     # size as the output tile
+     w = w - (w % 8)
+     h = h - (h % 8)
+     mask_box = upscale_settings["tiles"][index]["mask_box"]
+     if mask_box == None:
+         # Build a default solid mask with soft/transparent edges
+         mask_box = (
+             tile_overlap,
+             tile_overlap,
+             w - tile_overlap,
+             h - tile_overlap,
+         )
+     mask_image = Image.new(mode="RGBA", size=(w, h), color=(0, 0, 0, 0))
+     mask_draw = ImageDraw.Draw(mask_image)
+     mask_draw.rectangle(tuple(mask_box), fill=(0, 0, 0))
+     mask_blur = mask_image.filter(ImageFilter.BoxBlur(tile_overlap - 1))
+     mask_image.close()
+     return mask_blur
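
For clarity, this is the shape of the upscale_settings dict that generate_upscaled_image builds when none is supplied; the keys come from the code above, while the concrete values shown here are only illustrative:

upscale_settings = {
    "source_file": "input.png",   # placeholder source path
    "target_file": None,          # None -> tiles are pasted onto a fresh RGBA canvas
    "output_format": "PNG",
    "strength": 0.3,
    "scale_factor": 2.0,
    "prompt": "",
    "tile_overlap": 16,
    "tile_size": 256,
    "tiles": [],                  # empty -> the default tile grid is generated
}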
src/backend/upscale/upscaler.py ADDED
@@ -0,0 +1,52 @@
+ from backend.models.lcmdiffusion_setting import DiffusionTask
+ from backend.models.upscale import UpscaleMode
+ from backend.upscale.edsr_upscale_onnx import upscale_edsr_2x
+ from backend.upscale.aura_sr_upscale import upscale_aura_sr
+ from backend.upscale.tiled_upscale import generate_upscaled_image
+ from context import Context
+ from PIL import Image
+ from state import get_settings
+
+
+ config = get_settings()
+
+
+ def upscale_image(
+     context: Context,
+     src_image_path: str,
+     dst_image_path: str,
+     scale_factor: int = 2,
+     upscale_mode: UpscaleMode = UpscaleMode.normal.value,
+     strength: float = 0.1,
+ ):
+     if upscale_mode == UpscaleMode.normal.value:
+         upscaled_img = upscale_edsr_2x(src_image_path)
+         upscaled_img.save(dst_image_path)
+         print(f"Upscaled image saved {dst_image_path}")
+     elif upscale_mode == UpscaleMode.aura_sr.value:
+         upscaled_img = upscale_aura_sr(src_image_path)
+         upscaled_img.save(dst_image_path)
+         print(f"Upscaled image saved {dst_image_path}")
+     else:
+         config.settings.lcm_diffusion_setting.strength = (
+             0.3 if config.settings.lcm_diffusion_setting.use_openvino else strength
+         )
+         config.settings.lcm_diffusion_setting.diffusion_task = (
+             DiffusionTask.image_to_image.value
+         )
+
+         generate_upscaled_image(
+             config.settings,
+             src_image_path,
+             config.settings.lcm_diffusion_setting.strength,
+             upscale_settings=None,
+             context=context,
+             tile_overlap=(
+                 32 if config.settings.lcm_diffusion_setting.use_openvino else 16
+             ),
+             output_path=dst_image_path,
+             image_format=config.settings.generated_images.format,
+         )
+         print(f"Upscaled image saved {dst_image_path}")
+
+     return [Image.open(dst_image_path)]
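
A hedged sketch of how this entry point might be called from application code; building the Context instance is app-specific and elided here, and the paths are placeholders:

from backend.models.upscale import UpscaleMode
from backend.upscale.upscaler import upscale_image

images = upscale_image(
    context=app_context,                     # an existing application Context instance
    src_image_path="photo.png",
    dst_image_path="photo_upscaled.png",
    upscale_mode=UpscaleMode.aura_sr.value,  # or UpscaleMode.normal.value for EDSR 2x
)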