Upload 5 files
src/backend/upscale/aura_sr.py
ADDED
@@ -0,0 +1,1004 @@
# AuraSR: GAN-based Super-Resolution for real-world images, a reproduction of the GigaGAN* paper. The implementation is
# based on the unofficial lucidrains/gigagan-pytorch repository and heavily modified from there.
#
# https://mingukkang.github.io/GigaGAN/
from math import log2, ceil
from functools import partial
from typing import Any, Optional, List, Iterable

import torch
from torchvision import transforms
from PIL import Image
from torch import nn, einsum, Tensor
import torch.nn.functional as F

from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
from torchvision.utils import save_image
import math


def get_same_padding(size, kernel, dilation, stride):
    return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2


class AdaptiveConv2DMod(nn.Module):
    def __init__(
        self,
        dim,
        dim_out,
        kernel,
        *,
        demod=True,
        stride=1,
        dilation=1,
        eps=1e-8,
        num_conv_kernels=1,  # set this to be greater than 1 for adaptive
    ):
        super().__init__()
        self.eps = eps

        self.dim_out = dim_out

        self.kernel = kernel
        self.stride = stride
        self.dilation = dilation
        self.adaptive = num_conv_kernels > 1

        self.weights = nn.Parameter(
            torch.randn((num_conv_kernels, dim_out, dim, kernel, kernel))
        )

        self.demod = demod

        nn.init.kaiming_normal_(
            self.weights, a=0, mode="fan_in", nonlinearity="leaky_relu"
        )

    def forward(
        self, fmap, mod: Optional[Tensor] = None, kernel_mod: Optional[Tensor] = None
    ):
        """
        notation

        b - batch
        n - convs
        o - output
        i - input
        k - kernel
        """

        b, h = fmap.shape[0], fmap.shape[-2]

        # account for feature map that has been expanded by the scale in the first dimension
        # due to multiscale inputs and outputs

        if mod.shape[0] != b:
            mod = repeat(mod, "b ... -> (s b) ...", s=b // mod.shape[0])

        if exists(kernel_mod):
            kernel_mod_has_el = kernel_mod.numel() > 0

            assert self.adaptive or not kernel_mod_has_el

            if kernel_mod_has_el and kernel_mod.shape[0] != b:
                kernel_mod = repeat(
                    kernel_mod, "b ... -> (s b) ...", s=b // kernel_mod.shape[0]
                )

        # prepare weights for modulation

        weights = self.weights

        if self.adaptive:
            weights = repeat(weights, "... -> b ...", b=b)

            # determine an adaptive weight and 'select' the kernel to use with softmax

            assert exists(kernel_mod) and kernel_mod.numel() > 0

            kernel_attn = kernel_mod.softmax(dim=-1)
            kernel_attn = rearrange(kernel_attn, "b n -> b n 1 1 1 1")

            weights = reduce(weights * kernel_attn, "b n ... -> b ...", "sum")

        # do the modulation, demodulation, as done in stylegan2

        mod = rearrange(mod, "b i -> b 1 i 1 1")

        weights = weights * (mod + 1)

        if self.demod:
            inv_norm = (
                reduce(weights**2, "b o i k1 k2 -> b o 1 1 1", "sum")
                .clamp(min=self.eps)
                .rsqrt()
            )
            weights = weights * inv_norm

        fmap = rearrange(fmap, "b c h w -> 1 (b c) h w")

        weights = rearrange(weights, "b o ... -> (b o) ...")

        padding = get_same_padding(h, self.kernel, self.dilation, self.stride)
        fmap = F.conv2d(fmap, weights, padding=padding, groups=b)

        return rearrange(fmap, "1 (b o) ... -> b o ...", b=b)


class Attend(nn.Module):
    def __init__(self, dropout=0.0, flash=False):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)
        self.scale = nn.Parameter(torch.randn(1))
        self.flash = flash

    def flash_attn(self, q, k, v):
        q, k, v = map(lambda t: t.contiguous(), (q, k, v))
        out = F.scaled_dot_product_attention(
            q, k, v, dropout_p=self.dropout if self.training else 0.0
        )
        return out

    def forward(self, q, k, v):
        if self.flash:
            return self.flash_attn(q, k, v)

        scale = q.shape[-1] ** -0.5

        # similarity
        sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

        # attention
        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate values
        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        return out


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d


def cast_tuple(t, length=1):
    if isinstance(t, tuple):
        return t
    return (t,) * length


def identity(t, *args, **kwargs):
    return t


def is_power_of_two(n):
    return log2(n).is_integer()


def null_iterator():
    while True:
        yield None


def Downsample(dim, dim_out=None):
    return nn.Sequential(
        Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
        nn.Conv2d(dim * 4, default(dim_out, dim), 1),
    )


class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.eps = 1e-4

    def forward(self, x):
        return F.normalize(x, dim=1) * self.g * (x.shape[1] ** 0.5)


# building block modules


class Block(nn.Module):
    def __init__(self, dim, dim_out, groups=8, num_conv_kernels=0):
        super().__init__()
        self.proj = AdaptiveConv2DMod(
            dim, dim_out, kernel=3, num_conv_kernels=num_conv_kernels
        )
        self.kernel = 3
        self.dilation = 1
        self.stride = 1

        self.act = nn.SiLU()

    def forward(self, x, conv_mods_iter: Optional[Iterable] = None):
        conv_mods_iter = default(conv_mods_iter, null_iterator())

        x = self.proj(x, mod=next(conv_mods_iter), kernel_mod=next(conv_mods_iter))

        x = self.act(x)
        return x


class ResnetBlock(nn.Module):
    def __init__(
        self, dim, dim_out, *, groups=8, num_conv_kernels=0, style_dims: List = []
    ):
        super().__init__()
        style_dims.extend([dim, num_conv_kernels, dim_out, num_conv_kernels])

        self.block1 = Block(
            dim, dim_out, groups=groups, num_conv_kernels=num_conv_kernels
        )
        self.block2 = Block(
            dim_out, dim_out, groups=groups, num_conv_kernels=num_conv_kernels
        )
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, conv_mods_iter: Optional[Iterable] = None):
        h = self.block1(x, conv_mods_iter=conv_mods_iter)
        h = self.block2(h, conv_mods_iter=conv_mods_iter)

        return h + self.res_conv(x)


class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        self.norm = RMSNorm(dim)
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)

        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), RMSNorm(dim))

    def forward(self, x):
        b, c, h, w = x.shape

        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )

        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)

        q = q * self.scale

        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)

        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)


class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32, flash=False):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads

        self.norm = RMSNorm(dim)

        self.attend = Attend(flash=flash)
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim=1)

        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h (x y) c", h=self.heads), qkv
        )

        out = self.attend(q, k, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)

        return self.to_out(out)


# feedforward
def FeedForward(dim, mult=4):
    return nn.Sequential(
        RMSNorm(dim),
        nn.Conv2d(dim, dim * mult, 1),
        nn.GELU(),
        nn.Conv2d(dim * mult, dim, 1),
    )


# transformers
class Transformer(nn.Module):
    def __init__(self, dim, dim_head=64, heads=8, depth=1, flash_attn=True, ff_mult=4):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        Attention(
                            dim=dim, dim_head=dim_head, heads=heads, flash=flash_attn
                        ),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return x


class LinearTransformer(nn.Module):
    def __init__(self, dim, dim_head=64, heads=8, depth=1, ff_mult=4):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        LinearAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return x


class NearestNeighborhoodUpsample(nn.Module):
    def __init__(self, dim, dim_out=None):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.conv = nn.Conv2d(dim, dim_out, kernel_size=3, stride=1, padding=1)

    def forward(self, x):

        if x.shape[0] >= 64:
            x = x.contiguous()

        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)

        return x


class EqualLinear(nn.Module):
    def __init__(self, dim, dim_out, lr_mul=1, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim_out, dim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(dim_out))

        self.lr_mul = lr_mul

    def forward(self, input):
        return F.linear(input, self.weight * self.lr_mul, bias=self.bias * self.lr_mul)


class StyleGanNetwork(nn.Module):
    def __init__(self, dim_in=128, dim_out=512, depth=8, lr_mul=0.1, dim_text_latent=0):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dim_text_latent = dim_text_latent

        layers = []
        for i in range(depth):
            is_first = i == 0

            if is_first:
                dim_in_layer = dim_in + dim_text_latent
            else:
                dim_in_layer = dim_out

            dim_out_layer = dim_out

            layers.extend(
                [EqualLinear(dim_in_layer, dim_out_layer, lr_mul), nn.LeakyReLU(0.2)]
            )

        self.net = nn.Sequential(*layers)

    def forward(self, x, text_latent=None):
        x = F.normalize(x, dim=1)
        if self.dim_text_latent > 0:
            assert exists(text_latent)
            x = torch.cat((x, text_latent), dim=-1)
        return self.net(x)


class UnetUpsampler(torch.nn.Module):

    def __init__(
        self,
        dim: int,
        *,
        image_size: int,
        input_image_size: int,
        init_dim: Optional[int] = None,
        out_dim: Optional[int] = None,
        style_network: Optional[dict] = None,
        up_dim_mults: tuple = (1, 2, 4, 8, 16),
        down_dim_mults: tuple = (4, 8, 16),
        channels: int = 3,
        resnet_block_groups: int = 8,
        full_attn: tuple = (False, False, False, True, True),
        flash_attn: bool = True,
        self_attn_dim_head: int = 64,
        self_attn_heads: int = 8,
        attn_depths: tuple = (2, 2, 2, 2, 4),
        mid_attn_depth: int = 4,
        num_conv_kernels: int = 4,
        resize_mode: str = "bilinear",
        unconditional: bool = True,
        skip_connect_scale: Optional[float] = None,
    ):
        super().__init__()
        self.style_network = style_network = StyleGanNetwork(**style_network)
        self.unconditional = unconditional
        assert not (
            unconditional
            and exists(style_network)
            and style_network.dim_text_latent > 0
        )

        assert is_power_of_two(image_size) and is_power_of_two(
            input_image_size
        ), "both output image size and input image size must be power of 2"
        assert (
            input_image_size < image_size
        ), "input image size must be smaller than the output image size, thus upsampling"

        self.image_size = image_size
        self.input_image_size = input_image_size

        style_embed_split_dims = []

        self.channels = channels
        input_channels = channels

        init_dim = default(init_dim, dim)

        up_dims = [init_dim, *map(lambda m: dim * m, up_dim_mults)]
        init_down_dim = up_dims[len(up_dim_mults) - len(down_dim_mults)]
        down_dims = [init_down_dim, *map(lambda m: dim * m, down_dim_mults)]
        self.init_conv = nn.Conv2d(input_channels, init_down_dim, 7, padding=3)

        up_in_out = list(zip(up_dims[:-1], up_dims[1:]))
        down_in_out = list(zip(down_dims[:-1], down_dims[1:]))

        block_klass = partial(
            ResnetBlock,
            groups=resnet_block_groups,
            num_conv_kernels=num_conv_kernels,
            style_dims=style_embed_split_dims,
        )

        FullAttention = partial(Transformer, flash_attn=flash_attn)
        *_, mid_dim = up_dims

        self.skip_connect_scale = default(skip_connect_scale, 2**-0.5)

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])

        block_count = 6

        for ind, (
            (dim_in, dim_out),
            layer_full_attn,
            layer_attn_depth,
        ) in enumerate(zip(down_in_out, full_attn, attn_depths)):
            attn_klass = FullAttention if layer_full_attn else LinearTransformer

            blocks = []
            for i in range(block_count):
                blocks.append(block_klass(dim_in, dim_in))

            self.downs.append(
                nn.ModuleList(
                    [
                        nn.ModuleList(blocks),
                        nn.ModuleList(
                            [
                                (
                                    attn_klass(
                                        dim_in,
                                        dim_head=self_attn_dim_head,
                                        heads=self_attn_heads,
                                        depth=layer_attn_depth,
                                    )
                                    if layer_full_attn
                                    else None
                                ),
                                nn.Conv2d(
                                    dim_in, dim_out, kernel_size=3, stride=2, padding=1
                                ),
                            ]
                        ),
                    ]
                )
            )

        self.mid_block1 = block_klass(mid_dim, mid_dim)
        self.mid_attn = FullAttention(
            mid_dim,
            dim_head=self_attn_dim_head,
            heads=self_attn_heads,
            depth=mid_attn_depth,
        )
        self.mid_block2 = block_klass(mid_dim, mid_dim)

        *_, last_dim = up_dims

        for ind, (
            (dim_in, dim_out),
            layer_full_attn,
            layer_attn_depth,
        ) in enumerate(
            zip(
                reversed(up_in_out),
                reversed(full_attn),
                reversed(attn_depths),
            )
        ):
            attn_klass = FullAttention if layer_full_attn else LinearTransformer

            blocks = []
            input_dim = dim_in * 2 if ind < len(down_in_out) else dim_in
            for i in range(block_count):
                blocks.append(block_klass(input_dim, dim_in))

            self.ups.append(
                nn.ModuleList(
                    [
                        nn.ModuleList(blocks),
                        nn.ModuleList(
                            [
                                NearestNeighborhoodUpsample(
                                    last_dim if ind == 0 else dim_out,
                                    dim_in,
                                ),
                                (
                                    attn_klass(
                                        dim_in,
                                        dim_head=self_attn_dim_head,
                                        heads=self_attn_heads,
                                        depth=layer_attn_depth,
                                    )
                                    if layer_full_attn
                                    else None
                                ),
                            ]
                        ),
                    ]
                )
            )

        self.out_dim = default(out_dim, channels)
        self.final_res_block = block_klass(dim, dim)
        self.final_to_rgb = nn.Conv2d(dim, channels, 1)
        self.resize_mode = resize_mode
        self.style_to_conv_modulations = nn.Linear(
            style_network.dim_out, sum(style_embed_split_dims)
        )
        self.style_embed_split_dims = style_embed_split_dims

    @property
    def allowable_rgb_resolutions(self):
        input_res_base = int(log2(self.input_image_size))
        output_res_base = int(log2(self.image_size))
        allowed_rgb_res_base = list(range(input_res_base, output_res_base))
        return [*map(lambda p: 2**p, allowed_rgb_res_base)]

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def total_params(self):
        return sum([p.numel() for p in self.parameters()])

    def resize_image_to(self, x, size):
        return F.interpolate(x, (size, size), mode=self.resize_mode)

    def forward(
        self,
        lowres_image: torch.Tensor,
        styles: Optional[torch.Tensor] = None,
        noise: Optional[torch.Tensor] = None,
        global_text_tokens: Optional[torch.Tensor] = None,
        return_all_rgbs: bool = False,
    ):
        x = lowres_image

        noise_scale = 0.001  # Adjust the scale of the noise as needed
        noise_aug = torch.randn_like(x) * noise_scale
        x = x + noise_aug
        x = x.clamp(0, 1)

        shape = x.shape
        batch_size = shape[0]

        assert shape[-2:] == ((self.input_image_size,) * 2)

        # styles
        if not exists(styles):
            assert exists(self.style_network)

            noise = default(
                noise,
                torch.randn(
                    (batch_size, self.style_network.dim_in), device=self.device
                ),
            )
            styles = self.style_network(noise, global_text_tokens)

        # project styles to conv modulations
        conv_mods = self.style_to_conv_modulations(styles)
        conv_mods = conv_mods.split(self.style_embed_split_dims, dim=-1)
        conv_mods = iter(conv_mods)

        x = self.init_conv(x)

        h = []
        for blocks, (attn, downsample) in self.downs:
            for block in blocks:
                x = block(x, conv_mods_iter=conv_mods)
                h.append(x)

            if attn is not None:
                x = attn(x)

            x = downsample(x)

        x = self.mid_block1(x, conv_mods_iter=conv_mods)
        x = self.mid_attn(x)
        x = self.mid_block2(x, conv_mods_iter=conv_mods)

        for (
            blocks,
            (
                upsample,
                attn,
            ),
        ) in self.ups:
            x = upsample(x)
            for block in blocks:
                if h != []:
                    res = h.pop()
                    res = res * self.skip_connect_scale
                    x = torch.cat((x, res), dim=1)

                x = block(x, conv_mods_iter=conv_mods)

            if attn is not None:
                x = attn(x)

        x = self.final_res_block(x, conv_mods_iter=conv_mods)
        rgb = self.final_to_rgb(x)

        if not return_all_rgbs:
            return rgb

        return rgb, []


def tile_image(image, chunk_size=64):
    c, h, w = image.shape
    h_chunks = ceil(h / chunk_size)
    w_chunks = ceil(w / chunk_size)
    tiles = []
    for i in range(h_chunks):
        for j in range(w_chunks):
            tile = image[
                :,
                i * chunk_size : (i + 1) * chunk_size,
                j * chunk_size : (j + 1) * chunk_size,
            ]
            tiles.append(tile)
    return tiles, h_chunks, w_chunks


# This helps create a checkerboard pattern with some edge blending
def create_checkerboard_weights(tile_size):
    x = torch.linspace(-1, 1, tile_size)
    y = torch.linspace(-1, 1, tile_size)

    x, y = torch.meshgrid(x, y, indexing="ij")
    d = torch.sqrt(x * x + y * y)
    sigma, mu = 0.5, 0.0
    weights = torch.exp(-((d - mu) ** 2 / (2.0 * sigma**2)))

    # saturate the values to make sure we get high weights in the center
    weights = weights**8

    return weights / weights.max()  # Normalize to [0, 1]


def repeat_weights(weights, image_size):
    tile_size = weights.shape[0]
    repeats = (
        math.ceil(image_size[0] / tile_size),
        math.ceil(image_size[1] / tile_size),
    )
    return weights.repeat(repeats)[: image_size[0], : image_size[1]]


def create_offset_weights(weights, image_size):
    tile_size = weights.shape[0]
    offset = tile_size // 2
    full_weights = repeat_weights(
        weights, (image_size[0] + offset, image_size[1] + offset)
    )
    return full_weights[offset:, offset:]


def merge_tiles(tiles, h_chunks, w_chunks, chunk_size=64):
    # Determine the shape of the output tensor
    c = tiles[0].shape[0]
    h = h_chunks * chunk_size
    w = w_chunks * chunk_size

    # Create an empty tensor to hold the merged image
    merged = torch.zeros((c, h, w), dtype=tiles[0].dtype)

    # Iterate over the tiles and place them in the correct position
    for idx, tile in enumerate(tiles):
        i = idx // w_chunks
        j = idx % w_chunks

        h_start = i * chunk_size
        w_start = j * chunk_size

        tile_h, tile_w = tile.shape[1:]
        merged[:, h_start : h_start + tile_h, w_start : w_start + tile_w] = tile

    return merged


class AuraSR:
    def __init__(self, config: dict[str, Any], device: str = "cuda"):
        self.upsampler = UnetUpsampler(**config).to(device)
        self.input_image_size = config["input_image_size"]

    @classmethod
    def from_pretrained(
        cls,
        model_id: str = "fal-ai/AuraSR",
        use_safetensors: bool = True,
        device: str = "cuda",
    ):
        import json
        import torch
        from pathlib import Path
        from huggingface_hub import snapshot_download

        # Check if model_id is a local file
        if Path(model_id).is_file():
            local_file = Path(model_id)
            if local_file.suffix == ".safetensors":
                use_safetensors = True
            elif local_file.suffix == ".ckpt":
                use_safetensors = False
            else:
                raise ValueError(
                    f"Unsupported file format: {local_file.suffix}. Please use .safetensors or .ckpt files."
                )

            # For local files, we need to provide the config separately
            config_path = local_file.with_name("config.json")
            if not config_path.exists():
                raise FileNotFoundError(
                    f"Config file not found: {config_path}. "
                    f"When loading from a local file, ensure that 'config.json' "
                    f"is present in the same directory as '{local_file.name}'. "
                    f"If you're trying to load a model from Hugging Face, "
                    f"please provide the model ID instead of a file path."
                )

            config = json.loads(config_path.read_text())
            hf_model_path = local_file.parent
        else:
            hf_model_path = Path(
                snapshot_download(model_id, ignore_patterns=["*.ckpt"])
            )
            config = json.loads((hf_model_path / "config.json").read_text())

        model = cls(config, device)

        if use_safetensors:
            try:
                from safetensors.torch import load_file

                checkpoint = load_file(
                    hf_model_path / "model.safetensors"
                    if not Path(model_id).is_file()
                    else model_id
                )
            except ImportError:
                raise ImportError(
                    "The safetensors library is not installed. "
                    "Please install it with `pip install safetensors` "
                    "or use `use_safetensors=False` to load the model with PyTorch."
                )
        else:
            checkpoint = torch.load(
                hf_model_path / "model.ckpt"
                if not Path(model_id).is_file()
                else model_id
            )

        model.upsampler.load_state_dict(checkpoint, strict=True)
        return model

    @torch.no_grad()
    def upscale_4x(self, image: Image.Image, max_batch_size=8) -> Image.Image:
        tensor_transform = transforms.ToTensor()
        device = self.upsampler.device

        image_tensor = tensor_transform(image).unsqueeze(0)
        _, _, h, w = image_tensor.shape
        pad_h = (
            self.input_image_size - h % self.input_image_size
        ) % self.input_image_size
        pad_w = (
            self.input_image_size - w % self.input_image_size
        ) % self.input_image_size

        # Pad the image
        image_tensor = torch.nn.functional.pad(
            image_tensor, (0, pad_w, 0, pad_h), mode="reflect"
        ).squeeze(0)
        tiles, h_chunks, w_chunks = tile_image(image_tensor, self.input_image_size)

        # Batch processing of tiles
        num_tiles = len(tiles)
        batches = [
            tiles[i : i + max_batch_size] for i in range(0, num_tiles, max_batch_size)
        ]
        reconstructed_tiles = []

        for batch in batches:
            model_input = torch.stack(batch).to(device)
            generator_output = self.upsampler(
                lowres_image=model_input,
                noise=torch.randn(model_input.shape[0], 128, device=device),
            )
            reconstructed_tiles.extend(
                list(generator_output.clamp_(0, 1).detach().cpu())
            )

        merged_tensor = merge_tiles(
            reconstructed_tiles, h_chunks, w_chunks, self.input_image_size * 4
        )
        unpadded = merged_tensor[:, : h * 4, : w * 4]

        to_pil = transforms.ToPILImage()
        return to_pil(unpadded)

    # Tiled 4x upscaling with overlapping tiles to reduce seam artifacts
    # weights options are 'checkboard' and 'constant'
    @torch.no_grad()
    def upscale_4x_overlapped(self, image, max_batch_size=8, weight_type="checkboard"):
        tensor_transform = transforms.ToTensor()
        device = self.upsampler.device

        image_tensor = tensor_transform(image).unsqueeze(0)
        _, _, h, w = image_tensor.shape

        # Calculate paddings
        pad_h = (
            self.input_image_size - h % self.input_image_size
        ) % self.input_image_size
        pad_w = (
            self.input_image_size - w % self.input_image_size
        ) % self.input_image_size

        # Pad the image
        image_tensor = torch.nn.functional.pad(
            image_tensor, (0, pad_w, 0, pad_h), mode="reflect"
        ).squeeze(0)

        # Function to process tiles
        def process_tiles(tiles, h_chunks, w_chunks):
            num_tiles = len(tiles)
            batches = [
                tiles[i : i + max_batch_size]
                for i in range(0, num_tiles, max_batch_size)
            ]
            reconstructed_tiles = []

            for batch in batches:
                model_input = torch.stack(batch).to(device)
                generator_output = self.upsampler(
                    lowres_image=model_input,
                    noise=torch.randn(model_input.shape[0], 128, device=device),
                )
                reconstructed_tiles.extend(
                    list(generator_output.clamp_(0, 1).detach().cpu())
                )

            return merge_tiles(
                reconstructed_tiles, h_chunks, w_chunks, self.input_image_size * 4
            )

        # First pass
        tiles1, h_chunks1, w_chunks1 = tile_image(image_tensor, self.input_image_size)
        result1 = process_tiles(tiles1, h_chunks1, w_chunks1)

        # Second pass with offset
        offset = self.input_image_size // 2
        image_tensor_offset = torch.nn.functional.pad(
            image_tensor, (offset, offset, offset, offset), mode="reflect"
        ).squeeze(0)

        tiles2, h_chunks2, w_chunks2 = tile_image(
            image_tensor_offset, self.input_image_size
        )
        result2 = process_tiles(tiles2, h_chunks2, w_chunks2)

        # unpad
        offset_4x = offset * 4
        result2_interior = result2[:, offset_4x:-offset_4x, offset_4x:-offset_4x]

        if weight_type == "checkboard":
            weight_tile = create_checkerboard_weights(self.input_image_size * 4)

            weight_shape = result2_interior.shape[1:]
            weights_1 = create_offset_weights(weight_tile, weight_shape)
            weights_2 = repeat_weights(weight_tile, weight_shape)

            normalizer = weights_1 + weights_2
            weights_1 = weights_1 / normalizer
            weights_2 = weights_2 / normalizer

            weights_1 = weights_1.unsqueeze(0).repeat(3, 1, 1)
            weights_2 = weights_2.unsqueeze(0).repeat(3, 1, 1)
        elif weight_type == "constant":
            weights_1 = torch.ones_like(result2_interior) * 0.5
            weights_2 = weights_1
        else:
            raise ValueError(
                "weight_type should be either 'checkboard' or 'constant' but got",
                weight_type,
            )

        result1 = result1 * weights_2
        result2 = result2_interior * weights_1

        # Average the overlapping region
        result1 = result1 + result2

        # Remove padding
        unpadded = result1[:, : h * 4, : w * 4]

        to_pil = transforms.ToPILImage()
        return to_pil(unpadded)
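For reference, a minimal sketch (not part of the uploaded files) of how the tiling helpers in aura_sr.py compose: tile_image() splits at the input resolution, while merge_tiles() is later called with chunk_size equal to input_image_size * 4 because every tile comes back 4x larger from the upsampler. The import path assumes the project's src/ directory is on PYTHONPATH.

import torch
import torch.nn.functional as F
from backend.upscale.aura_sr import tile_image, merge_tiles

image = torch.rand(3, 300, 500)  # C x H x W, deliberately not a multiple of 64
tiles, h_chunks, w_chunks = tile_image(image, chunk_size=64)
print(len(tiles), h_chunks, w_chunks)  # 40 tiles in a 5 x 8 grid

# Stand-in for the upsampler: scale every tile 4x, then stitch the grid back together.
upscaled = [
    F.interpolate(t.unsqueeze(0), scale_factor=4.0, mode="nearest").squeeze(0)
    for t in tiles
]
merged = merge_tiles(upscaled, h_chunks, w_chunks, chunk_size=64 * 4)
print(merged.shape)  # torch.Size([3, 1280, 2048]); upscale_4x() then crops back to 4 * (h, w)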
src/backend/upscale/aura_sr_upscale.py
ADDED
@@ -0,0 +1,9 @@
from backend.upscale.aura_sr import AuraSR
from PIL import Image


def upscale_aura_sr(image_path: str):

    aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2", device="cpu")
    image_in = Image.open(image_path)  # .resize((256, 256))
    return aura_sr.upscale_4x(image_in)
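For reference, an illustrative call of this wrapper (not part of the uploaded files), assuming the project's src/ directory is importable and a local file named low_res.png exists; note that the wrapper pins the device to "cpu", whereas AuraSR.from_pretrained() itself defaults to "cuda".

from backend.upscale.aura_sr_upscale import upscale_aura_sr

upscaled = upscale_aura_sr("low_res.png")  # returns a PIL.Image upscaled 4x
upscaled.save("low_res_4x.png")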
src/backend/upscale/edsr_upscale_onnx.py
ADDED
@@ -0,0 +1,37 @@
import numpy as np
import onnxruntime
from huggingface_hub import hf_hub_download
from PIL import Image


def upscale_edsr_2x(image_path: str):
    input_image = Image.open(image_path).convert("RGB")
    input_image = np.array(input_image).astype("float32")
    input_image = np.transpose(input_image, (2, 0, 1))
    img_arr = np.expand_dims(input_image, axis=0)

    if np.max(img_arr) > 256:  # 16-bit image
        max_range = 65535
    else:
        max_range = 255.0
    img = img_arr / max_range

    model_path = hf_hub_download(
        repo_id="rupeshs/edsr-onnx",
        filename="edsr_onnxsim_2x.onnx",
    )
    sess = onnxruntime.InferenceSession(model_path)

    input_name = sess.get_inputs()[0].name
    output_name = sess.get_outputs()[0].name
    output = sess.run(
        [output_name],
        {input_name: img},
    )[0]

    result = output.squeeze()
    result = result.clip(0, 1)
    image_array = np.transpose(result, (1, 2, 0))
    image_array = np.uint8(image_array * 255)
    upscaled_image = Image.fromarray(image_array)
    return upscaled_image
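For reference, an illustrative call (not part of the uploaded files), assuming onnxruntime and huggingface_hub are installed; the ONNX weights are downloaded from the rupeshs/edsr-onnx repo on first use and cached by huggingface_hub.

from backend.upscale.edsr_upscale_onnx import upscale_edsr_2x

result = upscale_edsr_2x("photo_512.png")  # PIL image with 2x width and height
result.save("photo_1024.png")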
src/backend/upscale/tiled_upscale.py
ADDED
@@ -0,0 +1,237 @@
import time
import math
import logging
from PIL import Image, ImageDraw, ImageFilter
from backend.models.lcmdiffusion_setting import DiffusionTask
from context import Context
from constants import DEVICE


def generate_upscaled_image(
    config,
    input_path=None,
    strength=0.3,
    scale_factor=2.0,
    tile_overlap=16,
    upscale_settings=None,
    context: Context = None,
    output_path=None,
    image_format="PNG",
):
    if config == None or (
        input_path == None or input_path == "" and upscale_settings == None
    ):
        logging.error("Wrong arguments in tiled upscale function call!")
        return

    # Use the upscale_settings dict if provided; otherwise, build the
    # upscale_settings dict using the function arguments and default values
    if upscale_settings == None:
        upscale_settings = {
            "source_file": input_path,
            "target_file": None,
            "output_format": image_format,
            "strength": strength,
            "scale_factor": scale_factor,
            "prompt": config.lcm_diffusion_setting.prompt,
            "tile_overlap": tile_overlap,
            "tile_size": 256,
            "tiles": [],
        }
        source_image = Image.open(input_path)  # PIL image
    else:
        source_image = Image.open(upscale_settings["source_file"])

    upscale_settings["source_image"] = source_image

    if upscale_settings["target_file"]:
        result = Image.open(upscale_settings["target_file"])
    else:
        result = Image.new(
            mode="RGBA",
            size=(
                source_image.size[0] * int(upscale_settings["scale_factor"]),
                source_image.size[1] * int(upscale_settings["scale_factor"]),
            ),
            color=(0, 0, 0, 0),
        )
    upscale_settings["target_image"] = result

    # If the custom tile definition array 'tiles' is empty, proceed with the
    # default tiled upscale task by defining all the possible image tiles; note
    # that the actual tile size is 'tile_size' + 'tile_overlap' and the target
    # image width and height are no longer constrained to multiples of 256 but
    # are instead multiples of the actual tile size
    if len(upscale_settings["tiles"]) == 0:
        tile_size = upscale_settings["tile_size"]
        scale_factor = upscale_settings["scale_factor"]
        tile_overlap = upscale_settings["tile_overlap"]
        total_cols = math.ceil(
            source_image.size[0] / tile_size
        )  # Image width / tile size
        total_rows = math.ceil(
            source_image.size[1] / tile_size
        )  # Image height / tile size
        for y in range(0, total_rows):
            y_offset = tile_overlap if y > 0 else 0  # Tile mask offset
            for x in range(0, total_cols):
                x_offset = tile_overlap if x > 0 else 0  # Tile mask offset
                x1 = x * tile_size
                y1 = y * tile_size
                w = tile_size + (tile_overlap if x < total_cols - 1 else 0)
                h = tile_size + (tile_overlap if y < total_rows - 1 else 0)
                mask_box = (  # Default tile mask box definition
                    x_offset,
                    y_offset,
                    int(w * scale_factor),
                    int(h * scale_factor),
                )
                upscale_settings["tiles"].append(
                    {
                        "x": x1,
                        "y": y1,
                        "w": w,
                        "h": h,
                        "mask_box": mask_box,
                        "prompt": upscale_settings["prompt"],  # Use top level prompt if available
                        "scale_factor": scale_factor,
                    }
                )

    # Generate the output image tiles
    for i in range(0, len(upscale_settings["tiles"])):
        generate_upscaled_tile(
            config,
            i,
            upscale_settings,
            context=context,
        )

    # Save completed upscaled image
    if upscale_settings["output_format"].upper() == "JPEG":
        result_rgb = result.convert("RGB")
        result.close()
        result = result_rgb
    result.save(output_path)
    result.close()
    source_image.close()
    return


def get_current_tile(
    config,
    context,
    strength,
):
    config.lcm_diffusion_setting.strength = strength
    config.lcm_diffusion_setting.diffusion_task = DiffusionTask.image_to_image.value
    if (
        config.lcm_diffusion_setting.use_tiny_auto_encoder
        and config.lcm_diffusion_setting.use_openvino
    ):
        config.lcm_diffusion_setting.use_tiny_auto_encoder = False
    current_tile = context.generate_text_to_image(
        settings=config,
        reshape=True,
        device=DEVICE,
        save_config=False,
    )[0]
    return current_tile


# Generates a single tile from the source image as defined in the
# upscale_settings["tiles"] array with the corresponding index and pastes the
# generated tile into the target image using the corresponding mask and scale
# factor; note that scale factor for the target image and the individual tiles
# can be different, this function will adjust scale factors as needed
def generate_upscaled_tile(
    config,
    index,
    upscale_settings,
    context: Context = None,
):
    if config == None or upscale_settings == None:
        logging.error("Wrong arguments in tile creation function call!")
        return

    x = upscale_settings["tiles"][index]["x"]
    y = upscale_settings["tiles"][index]["y"]
    w = upscale_settings["tiles"][index]["w"]
    h = upscale_settings["tiles"][index]["h"]
    tile_prompt = upscale_settings["tiles"][index]["prompt"]
    scale_factor = upscale_settings["scale_factor"]
    tile_scale_factor = upscale_settings["tiles"][index]["scale_factor"]
    target_width = int(w * tile_scale_factor)
    target_height = int(h * tile_scale_factor)
    strength = upscale_settings["strength"]
    source_image = upscale_settings["source_image"]
    target_image = upscale_settings["target_image"]
    mask_image = generate_tile_mask(config, index, upscale_settings)

    config.lcm_diffusion_setting.number_of_images = 1
    config.lcm_diffusion_setting.prompt = tile_prompt
    config.lcm_diffusion_setting.image_width = target_width
    config.lcm_diffusion_setting.image_height = target_height
    config.lcm_diffusion_setting.init_image = source_image.crop((x, y, x + w, y + h))

    current_tile = None
    print(f"[SD Upscale] Generating tile {index + 1}/{len(upscale_settings['tiles'])} ")
    if tile_prompt == None or tile_prompt == "":
        config.lcm_diffusion_setting.prompt = ""
        config.lcm_diffusion_setting.negative_prompt = ""
        current_tile = get_current_tile(config, context, strength)
    else:
        # Attempt to use img2img with low denoising strength to
        # generate the tiles with the extra aid of a prompt
        # context = get_context(InterfaceType.CLI)
        current_tile = get_current_tile(config, context, strength)

    if math.isclose(scale_factor, tile_scale_factor):
        target_image.paste(
            current_tile, (int(x * scale_factor), int(y * scale_factor)), mask_image
        )
    else:
        target_image.paste(
            current_tile.resize((int(w * scale_factor), int(h * scale_factor))),
            (int(x * scale_factor), int(y * scale_factor)),
            mask_image.resize((int(w * scale_factor), int(h * scale_factor))),
        )
    mask_image.close()
    current_tile.close()
    config.lcm_diffusion_setting.init_image.close()


# Generate tile mask using the box definition in the upscale_settings["tiles"]
# array with the corresponding index; note that tile masks for the default
# tiled upscale task can be reused but that would complicate the code, so
# new tile masks are instead created for each tile
def generate_tile_mask(
    config,
    index,
    upscale_settings,
):
    scale_factor = upscale_settings["scale_factor"]
    tile_overlap = upscale_settings["tile_overlap"]
    tile_scale_factor = upscale_settings["tiles"][index]["scale_factor"]
    w = int(upscale_settings["tiles"][index]["w"] * tile_scale_factor)
    h = int(upscale_settings["tiles"][index]["h"] * tile_scale_factor)
    # The Stable Diffusion pipeline automatically adjusts the output size
    # to multiples of 8 pixels; the mask must be created with the same
    # size as the output tile
    w = w - (w % 8)
    h = h - (h % 8)
    mask_box = upscale_settings["tiles"][index]["mask_box"]
    if mask_box == None:
        # Build a default solid mask with soft/transparent edges
        mask_box = (
            tile_overlap,
            tile_overlap,
            w - tile_overlap,
            h - tile_overlap,
        )
    mask_image = Image.new(mode="RGBA", size=(w, h), color=(0, 0, 0, 0))
    mask_draw = ImageDraw.Draw(mask_image)
    mask_draw.rectangle(tuple(mask_box), fill=(0, 0, 0))
    mask_blur = mask_image.filter(ImageFilter.BoxBlur(tile_overlap - 1))
    mask_image.close()
    return mask_blur
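For reference, an illustrative upscale_settings dict for generate_upscaled_image() (not part of the uploaded files); the keys mirror the defaults built inside the function and every value below is hypothetical.

custom_settings = {
    "source_file": "input.png",
    "target_file": None,  # None -> start from a blank RGBA canvas
    "output_format": "PNG",
    "strength": 0.3,
    "scale_factor": 2.0,
    "prompt": "high quality photo",
    "tile_overlap": 16,
    "tile_size": 256,
    "tiles": [],  # empty -> the default tile grid is generated
}
# generate_upscaled_image(config, upscale_settings=custom_settings,
#                         context=context, output_path="output.png")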
src/backend/upscale/upscaler.py
ADDED
@@ -0,0 +1,52 @@
from backend.models.lcmdiffusion_setting import DiffusionTask
from backend.models.upscale import UpscaleMode
from backend.upscale.edsr_upscale_onnx import upscale_edsr_2x
from backend.upscale.aura_sr_upscale import upscale_aura_sr
from backend.upscale.tiled_upscale import generate_upscaled_image
from context import Context
from PIL import Image
from state import get_settings


config = get_settings()


def upscale_image(
    context: Context,
    src_image_path: str,
    dst_image_path: str,
    scale_factor: int = 2,
    upscale_mode: UpscaleMode = UpscaleMode.normal.value,
    strength: float = 0.1,
):
    if upscale_mode == UpscaleMode.normal.value:
        upscaled_img = upscale_edsr_2x(src_image_path)
        upscaled_img.save(dst_image_path)
        print(f"Upscaled image saved {dst_image_path}")
    elif upscale_mode == UpscaleMode.aura_sr.value:
        upscaled_img = upscale_aura_sr(src_image_path)
        upscaled_img.save(dst_image_path)
        print(f"Upscaled image saved {dst_image_path}")
    else:
        config.settings.lcm_diffusion_setting.strength = (
            0.3 if config.settings.lcm_diffusion_setting.use_openvino else strength
        )
        config.settings.lcm_diffusion_setting.diffusion_task = (
            DiffusionTask.image_to_image.value
        )

        generate_upscaled_image(
            config.settings,
            src_image_path,
            config.settings.lcm_diffusion_setting.strength,
            upscale_settings=None,
            context=context,
            tile_overlap=(
                32 if config.settings.lcm_diffusion_setting.use_openvino else 16
            ),
            output_path=dst_image_path,
            image_format=config.settings.generated_images.format,
        )
        print(f"Upscaled image saved {dst_image_path}")

    return [Image.open(dst_image_path)]
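For reference, how this dispatcher might be invoked (not part of the uploaded files), assuming a previously constructed Context instance named context and the UpscaleMode enum from backend.models.upscale.

from backend.models.upscale import UpscaleMode
from backend.upscale.upscaler import upscale_image

images = upscale_image(
    context=context,  # assumed: an existing Context instance
    src_image_path="input.png",
    dst_image_path="output.png",
    upscale_mode=UpscaleMode.aura_sr.value,
)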