Spaces:
Runtime error
Runtime error
add discriminator.
Browse files- README.md +1 -1
- basicsr/archs/vqgan_arch.py +46 -1
README.md
CHANGED
|
@@ -16,7 +16,7 @@ S-Lab, Nanyang Technological University
|
|
| 16 |
<img src="assets/network.jpg" width="800px"/>
|
| 17 |
|
| 18 |
|
| 19 |
-
:star: If CodeFormer is helpful to your
|
| 20 |
|
| 21 |
### Updates
|
| 22 |
|
|
|
|
| 16 |
<img src="assets/network.jpg" width="800px"/>
|
| 17 |
|
| 18 |
|
| 19 |
+
:star: If CodeFormer is helpful to your projects, please help star this repo. Thanks! :hugs:
|
| 20 |
|
| 21 |
### Updates
|
| 22 |
|
basicsr/archs/vqgan_arch.py
CHANGED
|
@@ -387,4 +387,49 @@ class VQAutoEncoder(nn.Module):
|
|
| 387 |
x = self.encoder(x)
|
| 388 |
quant, codebook_loss, quant_stats = self.quantize(x)
|
| 389 |
x = self.generator(quant)
|
| 390 |
-
return x, codebook_loss, quant_stats
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 387 |
x = self.encoder(x)
|
| 388 |
quant, codebook_loss, quant_stats = self.quantize(x)
|
| 389 |
x = self.generator(quant)
|
| 390 |
+
return x, codebook_loss, quant_stats
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
# patch based discriminator
|
| 395 |
+
@ARCH_REGISTRY.register()
|
| 396 |
+
class VQGANDiscriminator(nn.Module):
|
| 397 |
+
def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
|
| 398 |
+
super().__init__()
|
| 399 |
+
|
| 400 |
+
layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
|
| 401 |
+
ndf_mult = 1
|
| 402 |
+
ndf_mult_prev = 1
|
| 403 |
+
for n in range(1, n_layers): # gradually increase the number of filters
|
| 404 |
+
ndf_mult_prev = ndf_mult
|
| 405 |
+
ndf_mult = min(2 ** n, 8)
|
| 406 |
+
layers += [
|
| 407 |
+
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
|
| 408 |
+
nn.BatchNorm2d(ndf * ndf_mult),
|
| 409 |
+
nn.LeakyReLU(0.2, True)
|
| 410 |
+
]
|
| 411 |
+
|
| 412 |
+
ndf_mult_prev = ndf_mult
|
| 413 |
+
ndf_mult = min(2 ** n_layers, 8)
|
| 414 |
+
|
| 415 |
+
layers += [
|
| 416 |
+
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
|
| 417 |
+
nn.BatchNorm2d(ndf * ndf_mult),
|
| 418 |
+
nn.LeakyReLU(0.2, True)
|
| 419 |
+
]
|
| 420 |
+
|
| 421 |
+
layers += [
|
| 422 |
+
nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
|
| 423 |
+
self.main = nn.Sequential(*layers)
|
| 424 |
+
|
| 425 |
+
if model_path is not None:
|
| 426 |
+
chkpt = torch.load(model_path, map_location='cpu')
|
| 427 |
+
if 'params_d' in chkpt:
|
| 428 |
+
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d'])
|
| 429 |
+
elif 'params' in chkpt:
|
| 430 |
+
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
|
| 431 |
+
else:
|
| 432 |
+
raise ValueError(f'Wrong params!')
|
| 433 |
+
|
| 434 |
+
def forward(self, x):
|
| 435 |
+
return self.main(x)
|