Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
# | |
# NVIDIA CORPORATION and its licensors retain all intellectual property | |
# and proprietary rights in and to this software, related documentation | |
# and any modifications thereto. Any use, reproduction, disclosure or | |
# distribution of this software and related documentation without an express | |
# license agreement from NVIDIA CORPORATION is strictly prohibited. | |
""" | |
Projected discriminator architecture from | |
"StyleGAN-T: Unlocking the Power of GANs for Fast Large-Scale Text-to-Image Synthesis". | |
""" | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.nn.utils.spectral_norm import SpectralNorm | |
from torchvision.transforms import RandomCrop, Normalize | |
import timm | |
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD | |
from ADD.th_utils import misc | |
from models.shared import ResidualBlock, FullyConnectedLayer | |
from models.vit_utils import make_vit_backbone, forward_vit, make_sd_backbone | |
from models.DiffAugment import DiffAugment | |
from ADD.utils.util_net import reload_model_ | |
from functools import partial | |
class SpectralConv1d(nn.Conv1d):
    """``nn.Conv1d`` whose ``weight`` is spectrally normalized.

    Every positional and keyword argument is forwarded untouched to
    ``nn.Conv1d``. Spectral normalization is then attached to ``weight`` with a
    single power iteration, normalization along dim 0 and ``eps=1e-12``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        sn_options = dict(name='weight', n_power_iterations=1, dim=0, eps=1e-12)
        # Replaces the plain ``weight`` parameter with ``weight_orig`` plus
        # power-iteration buffers, recomputing ``weight`` on each forward call.
        SpectralNorm.apply(self, **sn_options)
class BatchNormLocal(nn.Module):
    """Batch normalization computed independently over small groups of the batch.

    The batch is split into ``ceil(B / virtual_bs)`` groups. Mean and (biased)
    variance are computed per group over the samples in the group and the last
    axis, separately for each channel. Channels are taken from dim -2 and the
    sequence/length from dim -1 (assumes a 3-D ``(B, C, L)`` input — TODO
    confirm for other ranks). The batch size must split evenly into the groups,
    otherwise ``view`` raises.

    Args:
        num_features: Number of channels (size of the learnable scale/shift).
        affine: When True, apply a learnable per-channel scale and shift.
        virtual_bs: Target number of samples per normalization group.
        eps: Added to the variance before the square root.
    """

    def __init__(self, num_features: int, affine: bool = True, virtual_bs: int = 3, eps: float = 1e-5):
        super().__init__()
        self.virtual_bs = virtual_bs
        self.eps = eps
        self.affine = affine
        if affine:
            self.weight = nn.Parameter(torch.ones(num_features))
            self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_shape = x.size()
        n_groups = int(np.ceil(x.size(0) / self.virtual_bs))
        grouped = x.view(n_groups, -1, x.size(-2), x.size(-1))

        # Statistics are shared by the samples of one group (dim 1) and the
        # length axis (dim 3); each channel (dim 2) keeps its own stats.
        mu = grouped.mean([1, 3], keepdim=True)
        sigma2 = grouped.var([1, 3], keepdim=True, unbiased=False)
        normed = (grouped - mu) / torch.sqrt(sigma2 + self.eps)

        if self.affine:
            scale = self.weight[None, :, None]
            shift = self.bias[None, :, None]
            normed = normed * scale + shift
        return normed.view(orig_shape)
def make_block(channels: int, kernel_size: int) -> nn.Module:
    """Build a channel-preserving conv -> GroupNorm -> LeakyReLU stage.

    Args:
        channels: Input and output channel count (must be divisible by 4 for
            the GroupNorm with 4 groups).
        kernel_size: Conv kernel size; padding is ``kernel_size // 2`` with
            circular padding, so odd sizes keep the sequence length.

    Returns:
        An ``nn.Sequential`` of the three layers.
    """
    conv = SpectralConv1d(
        channels,
        channels,
        kernel_size=kernel_size,
        padding=kernel_size // 2,
        padding_mode='circular',
    )
    # GroupNorm is used here; BatchNormLocal(channels) was the alternative
    # normalization left disabled in the original code.
    norm = nn.GroupNorm(4, channels)
    act = nn.LeakyReLU(0.2, True)
    return nn.Sequential(conv, norm, act)
class DiscHead(nn.Module):
    """Convolutional discriminator head over a 1-D feature sequence.

    ``main`` is a 1x1 conv block followed by a residual 9-tap conv block.
    ``cls`` then projects to ``cmap_dim`` channels when conditioning is
    enabled (``c_dim > 0``), or to a single channel otherwise. With
    conditioning, the output is a scaled dot product between the projected
    features and ``cmapper(c)``, giving one logit channel.

    Args:
        channels: Channel count of the incoming features.
        c_dim: Size of the conditioning vector; ``<= 0`` disables conditioning.
        cmap_dim: Size of the conditioning projection space.
    """

    def __init__(self, channels: int, c_dim: int, cmap_dim: int = 64):
        super().__init__()
        self.channels = channels
        self.c_dim = c_dim
        self.cmap_dim = cmap_dim

        stem = make_block(channels, kernel_size=1)
        body = ResidualBlock(make_block(channels, kernel_size=9))
        self.main = nn.Sequential(stem, body)

        conditional = self.c_dim > 0
        if conditional:
            self.cmapper = FullyConnectedLayer(self.c_dim, cmap_dim)
        out_channels = cmap_dim if conditional else 1
        self.cls = SpectralConv1d(channels, out_channels, kernel_size=1, padding=0)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        logits = self.cls(self.main(x))
        if not self.c_dim > 0:
            return logits
        # Project c into cmap_dim and take a per-position dot product over the
        # channel axis, scaled by 1/sqrt(cmap_dim).
        cmap = self.cmapper(c).unsqueeze(-1)
        return (logits * cmap).sum(1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
class DINO(torch.nn.Module):
    """Frozen DINO ViT-S/16 feature extractor.

    Wraps timm's ``vit_small_patch16_224.dino`` with ``make_vit_backbone`` so
    that activations from the transformer blocks listed in ``hooks`` are
    returned. When ``hook_patch`` is True, one extra hooked output is added
    (presumably the patch-embedding output — TODO confirm in
    ``make_vit_backbone``). Weights are loaded from a local checkpoint, the
    model is put in eval mode and all parameters are frozen.

    Args:
        hooks: Transformer block indices to tap. ``None`` means ``[2, 5, 8, 11]``.
        hook_patch: Whether to add the extra patch-level hook.
        ckpt_path: Path of the checkpoint passed to ``reload_model_``.
    """

    def __init__(self, hooks: 'list[int] | None' = None, hook_patch: bool = True,
                 ckpt_path: str = 'preset/models/dino/dino_deitsmall16_pretrain.pth'):
        super().__init__()
        # None default avoids a shared mutable default list; copying also keeps
        # the caller's sequence from being aliased into the backbone.
        hooks = [2, 5, 8, 11] if hooks is None else list(hooks)
        self.n_hooks = len(hooks) + int(hook_patch)
        self.model = make_vit_backbone(
            timm.create_model('vit_small_patch16_224.dino', pretrained=False),
            patch_size=[16, 16], hooks=hooks, hook_patch=hook_patch,
        )
        # The freshly created model lives on CPU; loading the checkpoint to CPU
        # lets it load regardless of the device it was saved from.
        # NOTE(review): torch.load unpickles; only point ckpt_path at trusted files.
        state = torch.load(ckpt_path, map_location='cpu')
        reload_model_(self.model, state)
        self.model = self.model.eval().requires_grad_(False)
        self.img_resolution = self.model.model.patch_embed.img_size[0]
        self.embed_dim = self.model.model.embed_dim
        self.norm = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)

    def forward(self, x: torch.Tensor) -> dict:
        ''' input: x in [0, 1]; output: dict of activations '''
        # Resize to the backbone's native square resolution, then apply
        # ImageNet mean/std normalization (assumes x is (B, 3, H, W) — TODO confirm).
        x = F.interpolate(x, self.img_resolution, mode='area')
        x = self.norm(x)
        features = forward_vit(self.model, x)
        return features
class ProjectedDiscriminator(nn.Module):
    """Projected discriminator: frozen DINO features scored by trainable heads.

    One ``DiscHead`` is built for each of the ``self.dino.n_hooks`` feature
    maps, keyed ``'0'``, ``'1'``, ... to match the feature dict returned by the
    backbone. Only the heads follow the train/eval mode; the DINO backbone is
    always kept in eval mode.

    Args:
        c_dim: Size of the conditioning vector ``c``; ``<= 0`` gives
            unconditional heads.
        diffaug: Apply DiffAugment (``'translation,cutout'``) before the backbone.
        p_crop: Probability of taking a random crop of the backbone resolution
            when the input is wider than that resolution.
    """

    def __init__(self, c_dim: int, diffaug: bool = True, p_crop: float = 0.5):
        super().__init__()
        self.c_dim = c_dim
        self.diffaug = diffaug
        self.p_crop = p_crop
        self.dino = DINO()
        self.heads = nn.ModuleDict({
            str(i): DiscHead(self.dino.embed_dim, c_dim)
            for i in range(self.dino.n_hooks)
        })

    def train(self, mode: bool = True):
        """Put the heads in ``mode``; the DINO backbone always stays in eval mode."""
        # This override bypasses nn.Module.train, so record the flag on self
        # explicitly; otherwise ``self.training`` goes stale after eval().
        self.training = mode
        self.dino.train(False)
        self.heads.train(mode)
        return self

    def eval(self):
        return self.train(False)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> tuple[list[torch.Tensor], dict]:
        """Score a batch of images.

        Args:
            x: Images in ``[-1, 1]``.
            c: Conditioning vectors passed to every head.

        Returns:
            ``(logits, features)``: a list holding one ``(B, -1)`` logit tensor
            per head (in ``self.heads`` order), and the backbone feature dict.
        """
        # Apply augmentation (x in [-1, 1]).
        if self.diffaug:
            x = DiffAugment(x, policy='translation,cutout')
        # Map to [0, 1], the range DINO.forward expects.
        x = x.add(1).div(2)
        # With probability p_crop, take one random crop (shared by the whole
        # batch) when the image is wider than the backbone resolution. The
        # height check keeps RandomCrop from raising on wide-but-short inputs.
        res = self.dino.img_resolution
        if x.size(-1) > res and x.size(-2) >= res and np.random.random() < self.p_crop:
            x = RandomCrop(res)(x)
        features = self.dino(x)
        logits = []
        for key, head in self.heads.items():
            # Enable grad on the backbone features (presumably so penalties
            # w.r.t. features can be taken by the caller — TODO confirm).
            features[key].requires_grad_(True)
            logits.append(head(features[key], c).view(x.size(0), -1))
        return logits, features