NightRaven109's picture
Upload 73 files
6ecc7d4 verified
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""
Projected discriminator architecture from
"StyleGAN-T: Unlocking the Power of GANs for Fast Large-Scale Text-to-Image Synthesis".
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.spectral_norm import SpectralNorm
from torchvision.transforms import RandomCrop, Normalize
import timm
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from ADD.th_utils import misc
from models.shared import ResidualBlock, FullyConnectedLayer
from models.vit_utils import make_vit_backbone, forward_vit, make_sd_backbone
from models.DiffAugment import DiffAugment
from ADD.utils.util_net import reload_model_
from functools import partial
class SpectralConv1d(nn.Conv1d):
    """``nn.Conv1d`` with torch's ``SpectralNorm`` applied to its ``weight``.

    All positional and keyword arguments are forwarded unchanged to
    ``nn.Conv1d``.
    """

    def __init__(self, *conv_args, **conv_kwargs):
        # The convolution has to exist before the normalization can be
        # attached to its ``weight`` parameter.
        super().__init__(*conv_args, **conv_kwargs)

        spectral_norm_settings = dict(
            name='weight',
            n_power_iterations=1,
            dim=0,
            eps=1e-12,
        )
        SpectralNorm.apply(self, **spectral_norm_settings)
class BatchNormLocal(nn.Module):
    """Normalization computed over small groups ("virtual batches") of samples.

    The batch is split into ``ceil(batch / virtual_bs)`` groups and every
    group is normalized with its own mean and variance.

    Args:
        num_features: Length of the learnable ``weight`` / ``bias`` vectors.
        affine: When true, a learnable scale and shift are applied afterwards.
        virtual_bs: Target number of samples per group.
        eps: Constant added to the variance before the square root.
    """

    def __init__(self, num_features: int, affine: bool = True, virtual_bs: int = 3, eps: float = 1e-5):
        super().__init__()
        self.affine = affine
        self.eps = eps
        self.virtual_bs = virtual_bs

        # ``weight`` / ``bias`` only exist on affine instances.
        if self.affine:
            self.weight = nn.Parameter(torch.ones(num_features))
            self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` group-wise and return a tensor of the same shape.

        NOTE(review): presumably ``x`` is (batch, channels, length) -- only the
        first and the last two dimensions are read here; confirm with callers.
        """
        original_shape = x.shape

        # Fold the batch into (groups, samples-per-group, size(-2), size(-1)).
        # ``view`` raises if the elements cannot be split evenly.
        n_groups = int(np.ceil(x.size(0) / self.virtual_bs))
        grouped = x.view(n_groups, -1, x.size(-2), x.size(-1))

        # Statistics are shared across dims 1 and 3 of the grouped tensor.
        stat_dims = [1, 3]
        group_mean = grouped.mean(stat_dims, keepdim=True)
        group_var = grouped.var(stat_dims, keepdim=True, unbiased=False)
        normalized = (grouped - group_mean) / torch.sqrt(group_var + self.eps)

        if self.affine:
            scale = self.weight[None, :, None]
            shift = self.bias[None, :, None]
            normalized = normalized * scale + shift

        return normalized.view(original_shape)
def make_block(channels: int, kernel_size: int) -> nn.Module:
    """Build a conv -> norm -> activation stage that keeps the channel count.

    Args:
        channels: Number of input and output channels of the convolution.
        kernel_size: Kernel width; the convolution is padded circularly by
            ``kernel_size // 2`` on each side.

    Returns:
        An ``nn.Sequential`` of the three layers.
    """
    conv = SpectralConv1d(
        channels,
        channels,
        kernel_size=kernel_size,
        padding=kernel_size // 2,
        padding_mode='circular',
    )
    # GroupNorm with 4 groups is the active normalization; the
    # BatchNormLocal(channels) alternative is currently disabled.
    norm = nn.GroupNorm(4, channels)
    activation = nn.LeakyReLU(negative_slope=0.2, inplace=True)
    return nn.Sequential(conv, norm, activation)
class DiscHead(nn.Module):
    """Discriminator head applied to one feature map.

    Args:
        channels: Channel count of the incoming features.
        c_dim: Size of the conditioning vector; ``0`` (or less) builds an
            unconditional head with a single output channel.
        cmap_dim: Size of the embedding the conditioning vector is mapped to.
    """

    def __init__(self, channels: int, c_dim: int, cmap_dim: int = 64):
        super().__init__()
        self.channels = channels
        self.c_dim = c_dim
        self.cmap_dim = cmap_dim

        pointwise_stage = make_block(channels, kernel_size=1)
        residual_stage = ResidualBlock(make_block(channels, kernel_size=9))
        self.main = nn.Sequential(pointwise_stage, residual_stage)

        conditional = self.c_dim > 0
        if conditional:
            # Maps the conditioning vector into the projection space.
            self.cmapper = FullyConnectedLayer(self.c_dim, cmap_dim)
        # Conditional heads emit cmap_dim channels that are projected onto the
        # conditioning embedding in forward(); unconditional ones emit 1.
        cls_channels = cmap_dim if conditional else 1
        self.cls = SpectralConv1d(channels, cls_channels, kernel_size=1, padding=0)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Score features ``x``, optionally conditioned on ``c``.

        ``c`` is only read when ``c_dim > 0``.
        """
        scores = self.cls(self.main(x))
        if self.c_dim <= 0:
            return scores

        # Project the per-position scores onto the conditioning embedding and
        # rescale by 1 / sqrt(cmap_dim).
        embedding = self.cmapper(c).unsqueeze(-1)
        projected = (scores * embedding).sum(1, keepdim=True)
        return projected * (1 / np.sqrt(self.cmap_dim))
class DINO(torch.nn.Module):
    """Frozen DINO ViT-S/16 feature extractor.

    Wraps timm's ``vit_small_patch16_224.dino`` architecture with
    ``make_vit_backbone``, loads weights from a local checkpoint, and puts
    the network in eval mode with gradients disabled for its parameters.

    Args:
        hooks: Hook indices forwarded to ``make_vit_backbone``.
        hook_patch: Forwarded to ``make_vit_backbone``; when true it counts as
            one additional entry in ``n_hooks``.
        ckpt_path: Path of the checkpoint passed to ``torch.load``. Defaults
            to the location that was previously hard-coded.
    """

    def __init__(
        self,
        hooks: list[int] = [2,5,8,11],
        hook_patch: bool = True,
        ckpt_path: str = 'preset/models/dino/dino_deitsmall16_pretrain.pth',
    ):
        super().__init__()
        # Work on a private copy so the shared default list object is never
        # handed to (and possibly mutated by) external code.
        hooks = list(hooks)
        self.n_hooks = len(hooks) + int(hook_patch)

        self.model = make_vit_backbone(
            timm.create_model('vit_small_patch16_224.dino', pretrained=False),
            patch_size=[16,16], hooks=hooks, hook_patch=hook_patch,
        )
        # map_location='cpu' lets checkpoints that were saved from GPU tensors
        # load on hosts without that device; the module can be moved later.
        # NOTE(review): torch.load unpickles the file -- only point ckpt_path
        # at trusted checkpoints.
        state = torch.load(ckpt_path, map_location='cpu')
        reload_model_(self.model, state)
        self.model = self.model.eval().requires_grad_(False)

        self.img_resolution = self.model.model.patch_embed.img_size[0]
        self.embed_dim = self.model.model.embed_dim
        self.norm = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)

    def forward(self, x: torch.Tensor) -> dict:
        """Extract backbone activations.

        Args:
            x: Image batch with values in [0, 1].

        Returns:
            The activations produced by ``forward_vit`` (a dict of
            activations; presumably keyed by hook index -- confirm against
            ``forward_vit``).
        """
        # Resize to the backbone's native resolution, then apply ImageNet
        # mean/std normalization.
        x = F.interpolate(x, self.img_resolution, mode='area')
        x = self.norm(x)
        features = forward_vit(self.model, x)
        return features
class ProjectedDiscriminator(nn.Module):
    """Discriminator that scores images through a frozen DINO backbone.

    One ``DiscHead`` is created for each of the backbone's ``n_hooks``
    features; only the heads are ever switched to training mode.

    Args:
        c_dim: Size of the conditioning vector given to every head.
        diffaug: Whether to apply ``DiffAugment`` to the input.
        p_crop: Probability of taking a random crop when the input is wider
            than the backbone resolution.
    """

    def __init__(self, c_dim: int, diffaug: bool = True, p_crop: float = 0.5):
        super().__init__()
        self.c_dim = c_dim
        self.diffaug = diffaug
        self.p_crop = p_crop

        self.dino = DINO()

        # One head per hooked feature, keyed by the hook position as a string.
        # A list of (key, module) pairs keeps insertion order explicit.
        self.heads = nn.ModuleDict([
            (str(i), DiscHead(self.dino.embed_dim, c_dim))
            for i in range(self.dino.n_hooks)
        ])

    def train(self, mode: bool = True):
        """Set the training mode of the heads; the backbone always stays in eval."""
        self.dino = self.dino.train(False)
        self.heads = self.heads.train(mode)
        # Keep this module's own flag in sync so that ``self.training``
        # reflects the last train()/eval() call, as it does for nn.Module.
        self.training = mode
        return self

    def eval(self):
        """Shortcut for ``train(False)``."""
        return self.train(False)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> tuple[list[torch.Tensor], dict]:
        """Compute per-head logits for a batch of images.

        Args:
            x: Image batch with values in [-1, 1].
            c: Conditioning input passed unchanged to every head.

        Returns:
            ``(logits, features)``: ``logits`` is a list with one tensor per
            head, each flattened to ``(batch, -1)``; ``features`` is the
            backbone output the heads were applied to.
        """
        # Apply augmentation (x in [-1, 1]).
        if self.diffaug:
            x = DiffAugment(x, policy='translation,cutout')

        # Transform to [0, 1].
        x = x.add(1).div(2)

        # Take crops with probability p_crop if the image is larger.
        if x.size(-1) > self.dino.img_resolution and np.random.random() < self.p_crop:
            x = RandomCrop(self.dino.img_resolution)(x)

        # Forward pass through DINO ViT.
        features = self.dino(x)

        # Apply discriminator heads. The logits are returned as a list, one
        # entry per head, rather than concatenated.
        logits = []
        for k, head in self.heads.items():
            features[k].requires_grad_(True)
            logits.append(head(features[k], c).view(x.size(0), -1))

        return logits, features