Luisgust committed
Commit 081dbbf · verified · Parent: a10ee1f

Create vtoonify/model/stylegan/model.py

Files changed (1):
  vtoonify/model/stylegan/model.py (+709, -0)
vtoonify/model/stylegan/model.py ADDED
@@ -0,0 +1,709 @@
import math
import random
import functools
import operator

import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function

from model.stylegan.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix


class PixelNorm(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)


def make_kernel(k):
    k = torch.tensor(k, dtype=torch.float32)

    if k.ndim == 1:
        k = k[None, :] * k[:, None]

    k /= k.sum()

    return k
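
# A worked example (a sketch, using the default kernel seen below): for
# blur_kernel = [1, 3, 3, 1], make_kernel returns the 4x4 outer product of the
# vector with itself, normalized by its sum (64):
#   [[1, 3, 3, 1],
#    [3, 9, 9, 3],
#    [3, 9, 9, 3],
#    [1, 3, 3, 1]] / 64
# so applying it through upfirdn2d leaves overall signal magnitude unchanged.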

class Upsample(nn.Module):
    def __init__(self, kernel, factor=2):
        super().__init__()

        self.factor = factor
        kernel = make_kernel(kernel) * (factor ** 2)
        self.register_buffer("kernel", kernel)

        p = kernel.shape[0] - factor

        pad0 = (p + 1) // 2 + factor - 1
        pad1 = p // 2

        self.pad = (pad0, pad1)

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)

        return out


class Downsample(nn.Module):
    def __init__(self, kernel, factor=2):
        super().__init__()

        self.factor = factor
        kernel = make_kernel(kernel)
        self.register_buffer("kernel", kernel)

        p = kernel.shape[0] - factor

        pad0 = (p + 1) // 2
        pad1 = p // 2

        self.pad = (pad0, pad1)

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)

        return out


class Blur(nn.Module):
    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()

        kernel = make_kernel(kernel)

        if upsample_factor > 1:
            kernel = kernel * (upsample_factor ** 2)

        self.register_buffer("kernel", kernel)

        self.pad = pad

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, pad=self.pad)

        return out


class EqualConv2d(nn.Module):
    def __init__(
        self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True, dilation=1  ## modified
    ):
        super().__init__()

        self.weight = nn.Parameter(
            torch.randn(out_channel, in_channel, kernel_size, kernel_size)
        )
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)

        self.stride = stride
        self.padding = padding
        self.dilation = dilation  ## modified

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel))

        else:
            self.bias = None

    def forward(self, input):
        out = conv2d_gradfix.conv2d(
            input,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,  ## modified
        )

        return out

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
            f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding}, dilation={self.dilation})"  ## modified
        )


class EqualLinear(nn.Module):
    def __init__(
        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
    ):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))

        else:
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        if self.activation:
            out = F.linear(input, self.weight * self.scale)
            out = fused_leaky_relu(out, self.bias * self.lr_mul)

        else:
            out = F.linear(
                input, self.weight * self.scale, bias=self.bias * self.lr_mul
            )

        return out

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
        )
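
# Equalized learning rate in brief (a sketch of the idea, not new behavior):
# both Equal* layers store weights near unit variance and rescale them on every
# forward pass, w_hat = w * (1 / sqrt(fan_in)) * lr_mul, so effective update
# magnitudes stay comparable across layers regardless of fan-in; the mapping
# network additionally uses lr_mul = 0.01 (see lr_mlp in Generator below).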

class ModulatedConv2d(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        demodulate=True,
        upsample=False,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        fused=True,
    ):
        super().__init__()

        self.eps = 1e-8
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.upsample = upsample
        self.downsample = downsample

        if upsample:
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1

            self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            self.blur = Blur(blur_kernel, pad=(pad0, pad1))

        fan_in = in_channel * kernel_size ** 2
        self.scale = 1 / math.sqrt(fan_in)
        self.padding = kernel_size // 2

        self.weight = nn.Parameter(
            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
        )

        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)

        self.demodulate = demodulate
        self.fused = fused

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
            f"upsample={self.upsample}, downsample={self.downsample})"
        )

    def forward(self, input, style, externalweight=None):
        batch, in_channel, height, width = input.shape

        if not self.fused:
            weight = self.scale * self.weight.squeeze(0)
            style = self.modulation(style)

            if self.demodulate:
                w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1)
                dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt()

            input = input * style.reshape(batch, in_channel, 1, 1)

            if self.upsample:
                weight = weight.transpose(0, 1)
                out = conv2d_gradfix.conv_transpose2d(
                    input, weight, padding=0, stride=2
                )
                out = self.blur(out)

            elif self.downsample:
                input = self.blur(input)
                out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2)

            else:
                out = conv2d_gradfix.conv2d(input, weight, padding=self.padding)

            if self.demodulate:
                out = out * dcoefs.view(batch, -1, 1, 1)

            return out

        style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
        if externalweight is None:
            weight = self.scale * self.weight * style
        else:
            weight = self.scale * (self.weight + externalweight) * style

        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
            weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)

        weight = weight.view(
            batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
        )

        if self.upsample:
            input = input.view(1, batch * in_channel, height, width)
            weight = weight.view(
                batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
            )
            weight = weight.transpose(1, 2).reshape(
                batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
            )
            out = conv2d_gradfix.conv_transpose2d(
                input, weight, padding=0, stride=2, groups=batch
            )
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)
            out = self.blur(out)

        elif self.downsample:
            input = self.blur(input)
            _, _, height, width = input.shape
            input = input.view(1, batch * in_channel, height, width)
            out = conv2d_gradfix.conv2d(
                input, weight, padding=0, stride=2, groups=batch
            )
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        else:
            input = input.view(1, batch * in_channel, height, width)
            out = conv2d_gradfix.conv2d(
                input, weight, padding=self.padding, groups=batch
            )
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        return out
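
# Modulation/demodulation in brief (a sketch of the math implemented above):
# for the per-channel style s produced by self.modulation, weights are scaled
# per input channel, w'_{o,i,k} = scale * s_i * w_{o,i,k}; with demodulate=True
# each output channel o is then divided by sqrt(sum_{i,k} w'^2 + 1e-8), which
# normalizes activation magnitudes much like instance norm, but baked into the
# weights. externalweight lets a caller add a per-sample weight offset before
# modulation, presumably the hook the surrounding VToonify code relies on.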

class NoiseInjection(nn.Module):
    def __init__(self):
        super().__init__()

        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, image, noise=None):
        if noise is None:
            batch, _, height, width = image.shape
            noise = image.new_empty(batch, 1, height, width).normal_()

        return image + self.weight * noise


class ConstantInput(nn.Module):
    def __init__(self, channel, size=4):
        super().__init__()

        self.input = nn.Parameter(torch.randn(1, channel, size, size))

    def forward(self, input):
        batch = input.shape[0]
        out = self.input.repeat(batch, 1, 1, 1)

        return out


class StyledConv(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        style_dim,
        upsample=False,
        blur_kernel=[1, 3, 3, 1],
        demodulate=True,
    ):
        super().__init__()

        self.conv = ModulatedConv2d(
            in_channel,
            out_channel,
            kernel_size,
            style_dim,
            upsample=upsample,
            blur_kernel=blur_kernel,
            demodulate=demodulate,
        )

        self.noise = NoiseInjection()
        # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
        # self.activate = ScaledLeakyReLU(0.2)
        self.activate = FusedLeakyReLU(out_channel)

    def forward(self, input, style, noise=None, externalweight=None):
        out = self.conv(input, style, externalweight)
        out = self.noise(out, noise=noise)
        # out = out + self.bias
        out = self.activate(out)

        return out


class ToRGB(nn.Module):
    def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        if upsample:
            self.upsample = Upsample(blur_kernel)

        self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, input, style, skip=None, externalweight=None):
        out = self.conv(input, style, externalweight)
        out = out + self.bias

        if skip is not None:
            skip = self.upsample(skip)

            out = out + skip

        return out


class Generator(nn.Module):
    def __init__(
        self,
        size,
        style_dim,
        n_mlp,
        channel_multiplier=2,
        blur_kernel=[1, 3, 3, 1],
        lr_mlp=0.01,
    ):
        super().__init__()

        self.size = size

        self.style_dim = style_dim

        layers = [PixelNorm()]

        for i in range(n_mlp):
            layers.append(
                EqualLinear(
                    style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
                )
            )

        self.style = nn.Sequential(*layers)

        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        self.input = ConstantInput(self.channels[4])
        self.conv1 = StyledConv(
            self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
        )
        self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)

        self.log_size = int(math.log(size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1

        self.convs = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channel = self.channels[4]

        for layer_idx in range(self.num_layers):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2 ** res, 2 ** res]
            self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))

        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2 ** i]

            self.convs.append(
                StyledConv(
                    in_channel,
                    out_channel,
                    3,
                    style_dim,
                    upsample=True,
                    blur_kernel=blur_kernel,
                )
            )

            self.convs.append(
                StyledConv(
                    out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
                )
            )

            self.to_rgbs.append(ToRGB(out_channel, style_dim))

            in_channel = out_channel

        self.n_latent = self.log_size * 2 - 2

    def make_noise(self):
        device = self.input.input.device

        noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))

        return noises

    def mean_latent(self, n_latent):
        latent_in = torch.randn(
            n_latent, self.style_dim, device=self.input.input.device
        )
        latent = self.style(latent_in).mean(0, keepdim=True)

        return latent

    def get_latent(self, input):
        return self.style(input)

    def forward(
        self,
        styles,
        return_latents=False,
        inject_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
        z_plus_latent=False,
        return_feature_ind=999,
    ):
        if not input_is_latent:
            if not z_plus_latent:
                styles = [self.style(s) for s in styles]
            else:
                styles_ = []
                for s in styles:
                    style_ = []
                    for i in range(s.shape[1]):
                        style_.append(self.style(s[:, i]).unsqueeze(1))
                    styles_.append(torch.cat(style_, dim=1))
                styles = styles_

        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers
            else:
                noise = [
                    getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
                ]

        if truncation < 1:
            style_t = []

            for style in styles:
                style_t.append(
                    truncation_latent + truncation * (style - truncation_latent)
                )

            styles = style_t

        if len(styles) < 2:
            inject_index = self.n_latent

            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)

            else:
                latent = styles[0]

        else:
            if inject_index is None:
                inject_index = random.randint(1, self.n_latent - 1)

            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
                latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)

                latent = torch.cat([latent, latent2], 1)
            else:
                latent = torch.cat([styles[0][:, 0:inject_index], styles[1][:, inject_index:]], 1)

        out = self.input(latent)
        out = self.conv1(out, latent[:, 0], noise=noise[0])

        skip = self.to_rgb1(out, latent[:, 1])

        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(
            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
        ):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)

            i += 2
            if i > return_feature_ind:
                return out, skip

        image = skip

        if return_latents:
            return image, latent

        else:
            return image, None
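
# Shape notes for Generator.forward (a sketch): each entry of `styles` is
# (batch, style_dim) in z or w space, or (batch, n_latent, style_dim) when
# z_plus_latent=True; after the mixing logic, `latent` is always
# (batch, n_latent, style_dim), with n_latent = 2 * log2(size) - 2 (18 for a
# 1024x1024 generator). When i exceeds return_feature_ind, the synthesis loop
# exits early and returns the intermediate feature map plus the RGB skip.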

class ConvLayer(nn.Sequential):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
        dilation=1,  ## modified
    ):
        layers = []

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))

            stride = 2
            self.padding = 0

        else:
            stride = 1
            self.padding = kernel_size // 2 + dilation - 1  ## modified

        layers.append(
            EqualConv2d(
                in_channel,
                out_channel,
                kernel_size,
                padding=self.padding,
                stride=stride,
                bias=bias and not activate,
                dilation=dilation,  ## modified
            )
        )

        if activate:
            layers.append(FusedLeakyReLU(out_channel, bias=bias))

        super().__init__(*layers)


class ResBlock(nn.Module):
    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)

        self.skip = ConvLayer(
            in_channel, out_channel, 1, downsample=True, activate=False, bias=False
        )

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)

        skip = self.skip(input)
        out = (out + skip) / math.sqrt(2)

        return out


class Discriminator(nn.Module):
    def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        convs = [ConvLayer(3, channels[size], 1)]

        log_size = int(math.log(size, 2))

        in_channel = channels[size]

        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]

            convs.append(ResBlock(in_channel, out_channel, blur_kernel))

            in_channel = out_channel

        self.convs = nn.Sequential(*convs)

        self.stddev_group = 4
        self.stddev_feat = 1

        self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
        self.final_linear = nn.Sequential(
            EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
            EqualLinear(channels[4], 1),
        )

    def forward(self, input):
        out = self.convs(input)

        batch, channel, height, width = out.shape
        group = min(batch, self.stddev_group)
        stddev = out.view(
            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
        )
        # The hunk is truncated at this point; the lines below reconstruct the
        # standard minibatch-stddev tail of the reference StyleGAN2 discriminator.
        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
        stddev = stddev.repeat(group, 1, height, width)
        out = torch.cat([out, stddev], 1)

        out = self.final_conv(out)

        out = out.view(batch, -1)
        out = self.final_linear(out)

        return out
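
# A minimal smoke-test sketch (illustrative only; it assumes the fused ops
# behind model.stylegan.op are built and that this module is importable).
# The sizes below simply exercise the constructors defined in this file.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"

    g = Generator(size=256, style_dim=512, n_mlp=8).to(device)
    d = Discriminator(size=256).to(device)

    z = torch.randn(2, 512, device=device)  # a batch of two latent codes
    img, _ = g([z])                          # image tensor of shape (2, 3, 256, 256)
    print(img.shape)

    score = d(img)                           # realism scores of shape (2, 1)
    print(score.shape)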