# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""
Projected discriminator architecture from
"StyleGAN-T: Unlocking the Power of GANs for Fast Large-Scale Text-to-Image Synthesis".
"""
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.spectral_norm import SpectralNorm
from torchvision.transforms import RandomCrop, Normalize
import timm
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

from ADD.th_utils import misc
from ADD.utils.util_net import reload_model_
from models.shared import ResidualBlock, FullyConnectedLayer
from models.vit_utils import make_vit_backbone, forward_vit, make_sd_backbone
from models.DiffAugment import DiffAugment
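

# 1D convolution with spectral normalization applied to its weight
# (one power iteration per forward pass), used throughout the
# discriminator heads to keep their Lipschitz constant in check.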
class SpectralConv1d(nn.Conv1d):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
SpectralNorm.apply(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12)
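

# Batch norm over "virtual" sub-batches: the batch is split into groups
# of roughly `virtual_bs` samples, and each group is normalized with its
# own statistics (mean/var over the group and token dimensions).
# Note: unused below; make_block() uses nn.GroupNorm instead.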
class BatchNormLocal(nn.Module):
def __init__(self, num_features: int, affine: bool = True, virtual_bs: int = 3, eps: float = 1e-5):
super().__init__()
self.virtual_bs = virtual_bs
self.eps = eps
self.affine = affine
if self.affine:
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
def forward(self, x: torch.Tensor) -> torch.Tensor:
shape = x.size()
        # Reshape the batch into groups of (at most) virtual_bs samples.
        # Note: view() below requires the batch size to be divisible by
        # the number of groups G.
        G = int(np.ceil(x.size(0) / self.virtual_bs))
        x = x.view(G, -1, x.size(-2), x.size(-1))
# Calculate stats.
mean = x.mean([1, 3], keepdim=True)
var = x.var([1, 3], keepdim=True, unbiased=False)
x = (x - mean) / (torch.sqrt(var + self.eps))
if self.affine:
x = x * self.weight[None, :, None] + self.bias[None, :, None]
return x.view(shape)
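

# Basic head block: spectrally normalized 1D conv with circular padding
# over the token axis, followed by GroupNorm and LeakyReLU.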
def make_block(channels: int, kernel_size: int) -> nn.Module:
return nn.Sequential(
SpectralConv1d(
channels,
channels,
kernel_size = kernel_size,
padding = kernel_size//2,
padding_mode = 'circular',
),
#BatchNormLocal(channels),
nn.GroupNorm(4, channels),
nn.LeakyReLU(0.2, True),
)
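

# Discriminator head operating on one level of ViT features, shaped
# (batch, channels, tokens). When c_dim > 0, conditioning follows the
# projection-discriminator formulation: the logit is the dot product
# between the head output and a learned embedding of the condition c.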
class DiscHead(nn.Module):
def __init__(self, channels: int, c_dim: int, cmap_dim: int = 64):
super().__init__()
self.channels = channels
self.c_dim = c_dim
self.cmap_dim = cmap_dim
self.main = nn.Sequential(
make_block(channels, kernel_size=1),
ResidualBlock(make_block(channels, kernel_size=9))
)
if self.c_dim > 0:
self.cmapper = FullyConnectedLayer(self.c_dim, cmap_dim)
self.cls = SpectralConv1d(channels, cmap_dim, kernel_size=1, padding=0)
else:
self.cls = SpectralConv1d(channels, 1, kernel_size=1, padding=0)
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
h = self.main(x)
out = self.cls(h)
if self.c_dim > 0:
cmap = self.cmapper(c).unsqueeze(-1)
out = (out * cmap).sum(1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
return out
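

# Frozen DINO ViT-S/16 feature network. Activations are hooked at the
# transformer blocks listed in `hooks` (plus the patch embedding when
# `hook_patch` is set); weights come from a local checkpoint rather
# than timm's downloader.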
class DINO(torch.nn.Module):
def __init__(self, hooks: list[int] = [2,5,8,11], hook_patch: bool = True):
super().__init__()
self.n_hooks = len(hooks) + int(hook_patch)
self.model = make_vit_backbone(
timm.create_model('vit_small_patch16_224.dino', pretrained=False),
patch_size=[16,16], hooks=hooks, hook_patch=hook_patch,
)
reload_model_(self.model, torch.load('preset/models/dino/dino_deitsmall16_pretrain.pth'))
self.model = self.model.eval().requires_grad_(False)
self.img_resolution = self.model.model.patch_embed.img_size[0]
self.embed_dim = self.model.model.embed_dim
self.norm = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
def forward(self, x: torch.Tensor) -> torch.Tensor:
''' input: x in [0, 1]; output: dict of activations '''
x = F.interpolate(x, self.img_resolution, mode='area')
x = self.norm(x)
features = forward_vit(self.model, x)
return features
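

# Projected discriminator: DiffAugment -> frozen DINO features -> one
# DiscHead per hooked layer. Only the heads are trained; the DINO
# backbone is kept in eval mode and never receives gradient updates.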
class ProjectedDiscriminator(nn.Module):
def __init__(self, c_dim: int, diffaug: bool = True, p_crop: float = 0.5):
super().__init__()
self.c_dim = c_dim
self.diffaug = diffaug
self.p_crop = p_crop
self.dino = DINO()
        heads = []
        for i in range(self.dino.n_hooks):
            heads.append([str(i), DiscHead(self.dino.embed_dim, c_dim)])
        self.heads = nn.ModuleDict(heads)
def train(self, mode: bool = True):
self.dino = self.dino.train(False)
self.heads = self.heads.train(mode)
return self
def eval(self):
return self.train(False)
    def forward(self, x: torch.Tensor, c: torch.Tensor) -> tuple[list[torch.Tensor], dict]:
# Apply augmentation (x in [-1, 1]).
if self.diffaug:
x = DiffAugment(x, policy='translation,cutout')
# Transform to [0, 1].
x = x.add(1).div(2)
        # Take crops with probability p_crop if the image is larger than
        # the backbone resolution.
if x.size(-1) > self.dino.img_resolution and np.random.random() < self.p_crop:
x = RandomCrop(self.dino.img_resolution)(x)
# Forward pass through DINO ViT.
features = self.dino(x)
# Apply discriminator heads.
logits = []
for k, head in self.heads.items():
features[k].requires_grad_(True)
logits.append(head(features[k], c).view(x.size(0), -1))
#logits = torch.cat(logits, dim=1)
return logits, features
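

# ----------------------------------------------------------------------
# Minimal smoke test (an illustrative sketch, not part of the original
# file). It assumes the DINO checkpoint at
# 'preset/models/dino/dino_deitsmall16_pretrain.pth' is available, since
# DINO.__init__() loads it unconditionally.
if __name__ == '__main__':
    D = ProjectedDiscriminator(c_dim=0)      # unconditional heads
    x = torch.rand(2, 3, 224, 224) * 2 - 1   # fake image batch in [-1, 1]
    c = torch.zeros(2, 0)                    # ignored when c_dim == 0
    logits, features = D(x, c)
    print([l.shape for l in logits])         # one logit tensor per head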