# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""
Projected discriminator architecture from
"StyleGAN-T: Unlocking the Power of GANs for Fast Large-Scale Text-to-Image Synthesis".
"""

from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.spectral_norm import SpectralNorm
from torchvision.transforms import RandomCrop, Normalize
import timm
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

from ADD.th_utils import misc
from ADD.utils.util_net import reload_model_
from models.shared import ResidualBlock, FullyConnectedLayer
from models.vit_utils import make_vit_backbone, forward_vit, make_sd_backbone
from models.DiffAugment import DiffAugment

class SpectralConv1d(nn.Conv1d):
    """Conv1d with spectral normalization applied to its weight."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        SpectralNorm.apply(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12)

class BatchNormLocal(nn.Module):
    """Batch norm over small "virtual" batches of size `virtual_bs`.
    Kept for reference; `make_block` below uses GroupNorm instead."""
    def __init__(self, num_features: int, affine: bool = True, virtual_bs: int = 3, eps: float = 1e-5):
        super().__init__()
        self.virtual_bs = virtual_bs
        self.eps = eps
        self.affine = affine

        if self.affine:
            self.weight = nn.Parameter(torch.ones(num_features))
            self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shape = x.size()

        # Reshape batch into groups of (at most) virtual_bs samples.
        G = int(np.ceil(x.size(0) / self.virtual_bs))
        x = x.view(G, -1, x.size(-2), x.size(-1))

        # Calculate stats over the virtual-batch and length dimensions.
        mean = x.mean([1, 3], keepdim=True)
        var = x.var([1, 3], keepdim=True, unbiased=False)
        x = (x - mean) / (torch.sqrt(var + self.eps))

        if self.affine:
            x = x * self.weight[None, :, None] + self.bias[None, :, None]

        return x.view(shape)

def make_block(channels: int, kernel_size: int) -> nn.Module:
    return nn.Sequential(
        SpectralConv1d(
            channels,
            channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            padding_mode='circular',
        ),
        # BatchNormLocal(channels),
        nn.GroupNorm(4, channels),
        nn.LeakyReLU(0.2, True),
    )
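
# Illustrative note (not part of the original file): make_block(384, kernel_size=9)
# yields a spectrally normalized circular Conv1d -> GroupNorm -> LeakyReLU stack
# that operates on [B, 384, N] token sequences; GroupNorm(4, channels) requires
# channels to be divisible by 4, which holds for the ViT-S embed_dim of 384.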

class DiscHead(nn.Module):
    def __init__(self, channels: int, c_dim: int, cmap_dim: int = 64):
        super().__init__()
        self.channels = channels
        self.c_dim = c_dim
        self.cmap_dim = cmap_dim

        self.main = nn.Sequential(
            make_block(channels, kernel_size=1),
            ResidualBlock(make_block(channels, kernel_size=9)),
        )

        if self.c_dim > 0:
            self.cmapper = FullyConnectedLayer(self.c_dim, cmap_dim)
            self.cls = SpectralConv1d(channels, cmap_dim, kernel_size=1, padding=0)
        else:
            self.cls = SpectralConv1d(channels, 1, kernel_size=1, padding=0)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        h = self.main(x)
        out = self.cls(h)

        if self.c_dim > 0:
            # Projection conditioning: inner product between the per-token logits
            # and the mapped conditioning vector, scaled by 1/sqrt(cmap_dim).
            cmap = self.cmapper(c).unsqueeze(-1)
            out = (out * cmap).sum(1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))

        return out
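
# Illustrative shape walkthrough (assumes DINO ViT-S/16 features, so
# channels = 384; the token count N depends on the backbone):
#   x:    [B, 384, N]       token activations from one hooked block
#   h:    [B, 384, N]       after the residual conv stack
#   out:  [B, cmap_dim, N]  if conditional, else [B, 1, N]
#   cmap: [B, cmap_dim, 1]  mapped conditioning vector
#   final out: [B, 1, N]    per-token logits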

class DINO(torch.nn.Module):
    def __init__(self, hooks: list[int] = [2, 5, 8, 11], hook_patch: bool = True):
        super().__init__()
        self.n_hooks = len(hooks) + int(hook_patch)

        self.model = make_vit_backbone(
            timm.create_model('vit_small_patch16_224.dino', pretrained=False),
            patch_size=[16, 16], hooks=hooks, hook_patch=hook_patch,
        )
        reload_model_(self.model, torch.load('preset/models/dino/dino_deitsmall16_pretrain.pth'))
        self.model = self.model.eval().requires_grad_(False)

        self.img_resolution = self.model.model.patch_embed.img_size[0]
        self.embed_dim = self.model.model.embed_dim
        self.norm = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)

    def forward(self, x: torch.Tensor) -> dict:
        ''' input: x in [0, 1]; output: dict of activations '''
        x = F.interpolate(x, self.img_resolution, mode='area')
        x = self.norm(x)
        features = forward_vit(self.model, x)
        return features
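
# Note (illustrative, not part of the original file): the forward pass returns
# a dict of activations keyed '0' .. str(n_hooks - 1), one entry per hooked
# transformer block plus, when hook_patch is set, the patch embedding. Each
# entry is a [B, embed_dim, N] token sequence, which is what the Conv1d-based
# DiscHead above consumes directly.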

class ProjectedDiscriminator(nn.Module):
    def __init__(self, c_dim: int, diffaug: bool = True, p_crop: float = 0.5):
        super().__init__()
        self.c_dim = c_dim
        self.diffaug = diffaug
        self.p_crop = p_crop
        self.dino = DINO()

        # One discriminator head per hooked DINO layer.
        heads = []
        for i in range(self.dino.n_hooks):
            heads.append((str(i), DiscHead(self.dino.embed_dim, c_dim)))
        self.heads = nn.ModuleDict(heads)

    def train(self, mode: bool = True):
        # The frozen DINO backbone always stays in eval mode.
        self.dino = self.dino.train(False)
        self.heads = self.heads.train(mode)
        return self

    def eval(self):
        return self.train(False)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> tuple[list[torch.Tensor], dict]:
        # Apply augmentation (x in [-1, 1]).
        if self.diffaug:
            x = DiffAugment(x, policy='translation,cutout')

        # Transform to [0, 1].
        x = x.add(1).div(2)

        # Take crops with probability p_crop if the image is larger.
        if x.size(-1) > self.dino.img_resolution and np.random.random() < self.p_crop:
            x = RandomCrop(self.dino.img_resolution)(x)

        # Forward pass through DINO ViT.
        features = self.dino(x)

        # Apply discriminator heads.
        logits = []
        for k, head in self.heads.items():
            features[k].requires_grad_(True)
            logits.append(head(features[k], c).view(x.size(0), -1))
        # logits = torch.cat(logits, dim=1)

        return logits, features
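
# Minimal smoke-test sketch (not part of the original file). It assumes the
# DINO checkpoint at 'preset/models/dino/dino_deitsmall16_pretrain.pth' and
# the ADD/models helper packages imported above are available; the shapes
# below follow from the code in this file, not from external documentation.
if __name__ == '__main__':
    D = ProjectedDiscriminator(c_dim=0)      # unconditional: c is unused
    x = torch.rand(2, 3, 224, 224) * 2 - 1   # fake image batch in [-1, 1]
    c = torch.zeros(2, 0)                    # empty conditioning vector
    logits, features = D(x, c)
    print(len(logits))                       # one logit tensor per head
    print([l.shape for l in logits])         # each [2, N] per-token logits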