|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import numpy as np |
|
import re |
|
import os |
|
import random |
|
from pathlib import Path |
|
from types import SimpleNamespace |
|
from utils import download_ckpt |
|
from config import Config |
|
from netdissect import proggan, zdataset |
|
from . import biggan |
|
from . import stylegan |
|
from . import stylegan2 |
|
from abc import abstractmethod, ABC as AbstractBaseClass |
|
from functools import singledispatch |
|
|
|
class BaseModel(AbstractBaseClass, torch.nn.Module):
    """Common interface shared by all generator wrappers.

    Concrete subclasses assign ``self.model`` (the wrapped generator) and
    ``self.device``, and implement :meth:`partial_forward` and
    :meth:`sample_latent`.
    """

    def __init__(self, model_name, class_name):
        """Store the identifiers used to recognise a model instance.

        Args:
            model_name: Architecture identifier (e.g. ``'StyleGAN2'``).
            class_name: Output class / dataset identifier.
        """
        super(BaseModel, self).__init__()
        self.model_name = model_name
        self.outclass = class_name

    @abstractmethod
    def partial_forward(self, x, layer_name):
        """Evaluate the generator only until ``layer_name`` has run."""
        pass

    @abstractmethod
    def sample_latent(self, n_samples=1, seed=None, truncation=None):
        """Return a batch of ``n_samples`` latent vectors."""
        pass

    def get_max_latents(self):
        """Number of separate latents the generator accepts (default: 1)."""
        return 1

    def latent_space_name(self):
        """Name of the latent space consumed by :meth:`forward` (default ``'Z'``)."""
        return 'Z'

    def get_latent_shape(self):
        """Shape of a single-sample latent batch, as a tuple."""
        single = self.sample_latent(1)
        return tuple(single.shape)

    def get_latent_dims(self):
        """Total number of scalar entries in a single-sample latent batch."""
        shape = self.get_latent_shape()
        return np.prod(shape)

    def set_output_class(self, new_class):
        """Replace the stored output class identifier."""
        self.outclass = new_class

    def forward(self, x):
        """Run the wrapped generator and map its output ``v`` to ``0.5 * (v + 1)``."""
        raw = self.model.forward(x)
        return 0.5*(raw+1)

    def sample_np(self, z=None, n_samples=1, seed=None):
        """Generate images and return them as a clipped numpy array.

        Args:
            z: Optional latent(s): a tensor, an array-like, or a list of
                either. When omitted, ``n_samples`` latents are sampled.
            n_samples: Batch size used when ``z`` is None.
            seed: Seed forwarded to :meth:`sample_latent` when ``z`` is None.

        Returns:
            ``numpy.ndarray`` with axis 1 of the generator output moved last,
            values clipped to ``[0, 1]`` and singleton axes squeezed.
        """
        if z is None:
            latents = self.sample_latent(n_samples, seed=seed)
        elif isinstance(z, list):
            latents = [item if torch.is_tensor(item) else torch.tensor(item).to(self.device)
                       for item in z]
        elif torch.is_tensor(z):
            latents = z
        else:
            latents = torch.tensor(z).to(self.device)

        images = self.forward(latents)
        # Move axis 1 (presumably channels) to the end for numpy consumers
        as_numpy = images.permute(0, 2, 3, 1).cpu().detach().numpy()
        return np.clip(as_numpy, 0.0, 1.0).squeeze()

    def get_conditional_state(self, z):
        """Return conditioning information; unconditional models have none."""
        return None

    def set_conditional_state(self, z, c):
        """Apply conditioning ``c`` to ``z``; a no-op for unconditional models."""
        return z

    def named_modules(self, *args, **kwargs):
        """Expose the wrapped generator's module names rather than the wrapper's."""
        return self.model.named_modules(*args, **kwargs)
|
|
|
|
|
class StyleGAN2(BaseModel):
    """Wrapper around ``stylegan2.Generator`` with optional W-space input."""

    def __init__(self, device, class_name, truncation=1.0, use_w=False):
        """Build the generator for ``class_name`` and load its checkpoint.

        Args:
            device: Device the generator and noise tensors are placed on.
            class_name: Dataset key (see ``configs``); defaults to ``'ffhq'``.
            truncation: Truncation value forwarded to the generator.
            use_w: If True, latents passed to ``forward`` are W-space vectors.

        Raises:
            AssertionError: If ``class_name`` is not a supported dataset.
        """
        super(StyleGAN2, self).__init__('StyleGAN2', class_name or 'ffhq')
        self.device = device
        self.truncation = truncation
        self.latent_avg = None
        self.w_primary = use_w

        # Output resolution of each supported pretrained checkpoint
        configs = {
            'ffhq': 1024,
            'car': 512,
            'cat': 256,
            'church': 256,
            'horse': 256,
            'bedrooms': 256,
            'kitchen': 256,
            'places': 256,
            'lookbook': 512
        }

        assert self.outclass in configs, \
            f'Invalid StyleGAN2 class {self.outclass}, should be one of [{", ".join(configs.keys())}]'

        self.resolution = configs[self.outclass]
        self.name = f'StyleGAN2-{self.outclass}'
        self.has_latent_residual = True
        self.load_model()
        self.set_noise_seed(0)

    def latent_space_name(self):
        """Return ``'W'`` when W-space input is active, else ``'Z'``."""
        return 'W' if self.w_primary else 'Z'

    def use_w(self):
        """Treat subsequent latents as W-space vectors."""
        self.w_primary = True

    def use_z(self):
        """Treat subsequent latents as Z-space vectors."""
        self.w_primary = False

    def download_checkpoint(self, outfile):
        """Download the pretrained checkpoint for ``self.outclass`` to ``outfile``."""
        checkpoints = {
            'horse': 'https://drive.google.com/uc?export=download&id=18SkqWAkgt0fIwDEf2pqeaenNi4OoCo-0',
            'ffhq': 'https://drive.google.com/uc?export=download&id=1FJRwzAkV-XWbxgTwxEmEACvuqF5DsBiV',
            'church': 'https://drive.google.com/uc?export=download&id=1HFM694112b_im01JT7wop0faftw9ty5g',
            'car': 'https://drive.google.com/uc?export=download&id=1iRoWclWVbDBAy5iXYZrQnKYSbZUqXI6y',
            'cat': 'https://drive.google.com/uc?export=download&id=15vJP8GDr0FlRYpE8gD7CdeEz2mXrQMgN',
            'places': 'https://drive.google.com/uc?export=download&id=1X8-wIH3aYKjgDZt4KMOtQzN1m4AlCVhm',
            'bedrooms': 'https://drive.google.com/uc?export=download&id=1nZTW7mjazs-qPhkmbsOLLA_6qws-eNQu',
            'kitchen': 'https://drive.google.com/uc?export=download&id=15dCpnZ1YLAnETAPB0FGmXwdBclbwMEkZ',
            'lookbook': 'https://drive.google.com/uc?export=download&id=1-F-RMkbHUv_S_k-_olh43mu5rDUMGYKe'
        }

        url = checkpoints[self.outclass]
        download_ckpt(url, outfile)

    def load_model(self):
        """Create the generator and load its ``g_ema`` weights, downloading if needed.

        The checkpoint directory is ``$GANCONTROL_CHECKPOINT_DIR`` or
        ``checkpoints/`` next to this file.
        """
        checkpoint_root = os.environ.get('GANCONTROL_CHECKPOINT_DIR', Path(__file__).parent / 'checkpoints')
        checkpoint = Path(checkpoint_root) / f'stylegan2/stylegan2_{self.outclass}_{self.resolution}.pt'

        self.model = stylegan2.Generator(self.resolution, 512, 8).to(self.device)

        if not checkpoint.is_file():
            os.makedirs(checkpoint.parent, exist_ok=True)
            self.download_checkpoint(checkpoint)

        # map_location allows a checkpoint saved from a different device
        # (e.g. a GPU) to be loaded on the current one (e.g. CPU-only).
        ckpt = torch.load(checkpoint, map_location=self.device)
        self.model.load_state_dict(ckpt['g_ema'], strict=False)
        self.latent_avg = 0

    def sample_latent(self, n_samples=1, seed=None, truncation=None):
        """Sample ``n_samples`` 512-d Gaussian latents (mapped to W if ``w_primary``).

        Args:
            n_samples: Batch size.
            seed: RNG seed; a random one is drawn when None.
            truncation: Unused; kept for interface compatibility.
        """
        if seed is None:
            seed = np.random.randint(np.iinfo(np.int32).max)

        rng = np.random.RandomState(seed)
        z = torch.from_numpy(
            rng.standard_normal(512 * n_samples)
            .reshape(n_samples, 512)).float().to(self.device)

        if self.w_primary:
            z = self.model.style(z)

        return z

    def get_max_latents(self):
        """Number of per-layer latents the generator accepts."""
        return self.model.n_latent

    def set_output_class(self, new_class):
        """Reject class changes; a different checkpoint would be required."""
        if self.outclass != new_class:
            raise RuntimeError('StyleGAN2: cannot change output class without reloading')

    def forward(self, x):
        """Generate images from a latent or list of latents, mapped via ``0.5 * (v + 1)``."""
        x = x if isinstance(x, list) else [x]
        out, _ = self.model(x, noise=self.noise,
            truncation=self.truncation, truncation_latent=self.latent_avg, input_is_w=self.w_primary)
        return 0.5*(out+1)

    @staticmethod
    def _layer_matches(prefix, layer_name):
        """True if ``prefix`` appears in ``layer_name`` as whole dot-separated components.

        A plain substring test would let ``'convs.1'`` match ``'convs.10'``
        and stop the forward pass several layers too early.
        """
        return re.search(r'(^|\.)' + re.escape(prefix) + r'(\.|$)', layer_name) is not None

    def partial_forward(self, x, layer_name):
        """Run the generator only until ``layer_name`` has executed.

        Args:
            x: A latent tensor or a list of 1, 2 or ``n_latent`` latents, in
                Z or W space depending on ``self.w_primary``.
            layer_name: Module name as reported by ``named_modules()``.

        Returns:
            None once the requested layer has run.

        Raises:
            AssertionError: If the number of latents is not supported.
            RuntimeError: If ``layer_name`` is never reached.
        """
        styles = x if isinstance(x, list) else [x]
        noise = self.noise
        n_latent = self.model.n_latent

        if not self.w_primary:
            styles = [self.model.style(s) for s in styles]

        if len(styles) == 1:
            # A single global latent is repeated for every layer
            latent = self.model.strided_style(styles[0].unsqueeze(1).repeat(1, n_latent, 1))
        elif len(styles) == 2:
            # Style mixing: first latent up to a random crossover index, second after it
            inject_index = random.randint(1, n_latent - 1)
            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, n_latent - inject_index, 1)
            latent = self.model.strided_style(torch.cat([latent, latent2], 1))
        else:
            # One latent per layer
            assert len(styles) == n_latent, f'Expected {n_latent} latents, got {len(styles)}'
            latent = self.model.strided_style(torch.stack(styles, dim=1))

        # Substring test on purpose: also covers 'strided_style', already applied above
        if 'style' in layer_name:
            return

        out = self.model.input(latent)
        if 'input' == layer_name:
            return

        out = self.model.conv1(out, latent[:, 0], noise=noise[0])
        if 'conv1' in layer_name:
            return

        skip = self.model.to_rgb1(out, latent[:, 1])
        if 'to_rgb1' in layer_name:
            return

        # Latent index and noise index advance together, two per block
        i = 1
        for conv1, conv2, to_rgb in zip(
            self.model.convs[::2], self.model.convs[1::2], self.model.to_rgbs
        ):
            out = conv1(out, latent[:, i], noise=noise[i])
            if self._layer_matches(f'convs.{i-1}', layer_name):
                return

            out = conv2(out, latent[:, i + 1], noise=noise[i + 1])
            if self._layer_matches(f'convs.{i}', layer_name):
                return

            skip = to_rgb(out, latent[:, i + 2], skip)
            if self._layer_matches(f'to_rgbs.{i//2}', layer_name):
                return

            i += 2

        raise RuntimeError(f'Layer {layer_name} not encountered in partial_forward')

    def set_noise_seed(self, seed):
        """Regenerate the fixed per-layer noise maps from ``seed`` (reseeds torch's global RNG)."""
        torch.manual_seed(seed)
        self.noise = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=self.device)]

        # Two noise maps per resolution from 8x8 up to the output resolution
        for i in range(3, self.model.log_size + 1):
            for _ in range(2):
                self.noise.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=self.device))
|
|
|
|
|
class StyleGAN(BaseModel):
    """Wrapper around ``stylegan.StyleGAN_G`` with optional W-space input."""

    def __init__(self, device, class_name, truncation=1.0, use_w=False):
        """Build the generator for ``class_name`` and load its weights.

        Args:
            device: Device the generator and noise tensors are placed on.
            class_name: Dataset key (see ``configs``); defaults to ``'ffhq'``.
            truncation: Accepted for interface parity; not used by this wrapper.
            use_w: If True, latents passed to ``forward`` are W-space vectors.

        Raises:
            AssertionError: If ``class_name`` is not a supported dataset.
        """
        super(StyleGAN, self).__init__('StyleGAN', class_name or 'ffhq')
        self.device = device
        self.w_primary = use_w

        # Output resolution of each supported checkpoint
        configs = {
            'ffhq': 1024,
            'celebahq': 1024,
            'bedrooms': 256,
            'cars': 512,
            'cats': 256,
            'vases': 1024,
            'wikiart': 512,
            'fireworks': 512,
            'abstract': 512,
            'anime': 512,
            'ukiyo-e': 512,
        }

        assert self.outclass in configs, \
            f'Invalid StyleGAN class {self.outclass}, should be one of [{", ".join(configs.keys())}]'

        self.resolution = configs[self.outclass]
        self.name = f'StyleGAN-{self.outclass}'
        self.has_latent_residual = True
        self.load_model()
        self.set_noise_seed(0)

    def latent_space_name(self):
        """Return ``'W'`` when W-space input is active, else ``'Z'``."""
        return 'W' if self.w_primary else 'Z'

    def use_w(self):
        """Treat subsequent latents as W-space vectors."""
        self.w_primary = True

    def use_z(self):
        """Treat subsequent latents as Z-space vectors."""
        self.w_primary = False

    def load_model(self):
        """Create the generator and load weights, downloading/converting if needed.

        PyTorch checkpoints are downloaded directly; the remaining classes are
        fetched as TensorFlow ``.pkl`` files and converted via ``export_from_tf``.
        """
        checkpoint_root = os.environ.get('GANCONTROL_CHECKPOINT_DIR', Path(__file__).parent / 'checkpoints')
        checkpoint = Path(checkpoint_root) / f'stylegan/stylegan_{self.outclass}_{self.resolution}.pt'

        self.model = stylegan.StyleGAN_G(self.resolution).to(self.device)

        urls_tf = {
            'vases': 'https://thisvesseldoesnotexist.s3-us-west-2.amazonaws.com/public/network-snapshot-008980.pkl',
            'fireworks': 'https://mega.nz/#!7uBHnACY!quIW-pjdDa7NqnZOYh1z5UemWwPOW6HkYSoJ4usCg9U',
            'abstract': 'https://mega.nz/#!vCQyHQZT!zdeOg3VvT4922Z2UfxO51xgAfJD-NAK2nW7H_jMlilU',
            'anime': 'https://mega.nz/#!vawjXISI!F7s13yRicxDA3QYqYDL2kjnc2K7Zk3DwCIYETREmBP4',
            'ukiyo-e': 'https://drive.google.com/uc?id=1CHbJlci9NhVFifNQb3vCGu6zw4eqzvTd',
        }

        urls_torch = {
            'celebahq': 'https://drive.google.com/uc?export=download&id=1lGcRwNoXy_uwXkD6sy43aAa-rMHRR7Ad',
            'bedrooms': 'https://drive.google.com/uc?export=download&id=1r0_s83-XK2dKlyY3WjNYsfZ5-fnH8QgI',
            'ffhq': 'https://drive.google.com/uc?export=download&id=1GcxTcLDPYxQqcQjeHpLUutGzwOlXXcks',
            'cars': 'https://drive.google.com/uc?export=download&id=1aaUXHRHjQ9ww91x4mtPZD0w50fsIkXWt',
            'cats': 'https://drive.google.com/uc?export=download&id=1JzA5iiS3qPrztVofQAjbb0N4xKdjOOyV',
            'wikiart': 'https://drive.google.com/uc?export=download&id=1fN3noa7Rsl9slrDXsgZVDsYFxV0O08Vx',
        }

        if not checkpoint.is_file():
            os.makedirs(checkpoint.parent, exist_ok=True)
            if self.outclass in urls_torch:
                download_ckpt(urls_torch[self.outclass], checkpoint)
            else:
                checkpoint_tf = checkpoint.with_suffix('.pkl')
                if not checkpoint_tf.is_file():
                    download_ckpt(urls_tf[self.outclass], checkpoint_tf)
                print('Converting TensorFlow checkpoint to PyTorch')
                self.model.export_from_tf(checkpoint_tf)

        self.model.load_weights(checkpoint)

    def sample_latent(self, n_samples=1, seed=None, truncation=None):
        """Sample ``n_samples`` 512-d Gaussian latents (mapped to W if ``w_primary``).

        Args:
            n_samples: Batch size.
            seed: RNG seed; a random one is drawn when None.
            truncation: Unused; kept for interface compatibility.
        """
        if seed is None:
            seed = np.random.randint(np.iinfo(np.int32).max)

        rng = np.random.RandomState(seed)
        noise = torch.from_numpy(
            rng.standard_normal(512 * n_samples)
            .reshape(n_samples, 512)).float().to(self.device)

        if self.w_primary:
            noise = self.model._modules['g_mapping'].forward(noise)

        return noise

    def get_max_latents(self):
        """Number of per-layer latents the synthesis network accepts."""
        return 18

    def set_output_class(self, new_class):
        """Reject class changes; a different checkpoint would be required."""
        if self.outclass != new_class:
            raise RuntimeError('StyleGAN: cannot change output class without reloading')

    def forward(self, x):
        """Generate images from ``x``, mapped via ``0.5 * (v + 1)``."""
        out = self.model.forward(x, latent_is_w=self.w_primary)
        return 0.5*(out+1)

    def partial_forward(self, x, layer_name):
        """Run mapping/truncation/synthesis only until ``layer_name`` has executed.

        Args:
            x: A latent tensor or list of latents in Z or W space
                (per ``self.w_primary``).
            layer_name: Module name as reported by ``named_modules()``.

        Returns:
            None once the requested layer has run.

        Raises:
            RuntimeError: If ``layer_name`` is never reached.
        """
        mapping = self.model._modules['g_mapping']
        G = self.model._modules['g_synthesis']
        trunc = self.model._modules.get('truncation', lambda x : x)

        if not self.w_primary:
            x = mapping.forward(x)

        if isinstance(x, list):
            x = torch.stack(x, dim=1)
        else:
            # Broadcast one latent to all 18 per-layer inputs
            x = x.unsqueeze(1).expand(-1, 18, -1)

        if 'g_mapping' in layer_name:
            return

        x = trunc(x)
        if layer_name == 'truncation':
            return

        def leaf_names(module, prefix):
            """Return the dotted names of all leaf descendants of ``module`` (each once)."""
            children = getattr(module, '_modules', {})
            if not children:
                return [prefix]
            names = []
            for child_name, child in children.items():
                names.extend(leaf_names(child, f'{prefix}.{child_name}'))
            return names

        r = None
        for i, (block_name, block) in enumerate(G.blocks.items()):
            # Each synthesis block consumes two consecutive per-layer latents
            w_pair = x[:, 2*i:2*i+2]
            r = block(w_pair) if i == 0 else block(r, w_pair)

            leaves = leaf_names(block, f'g_synthesis.blocks.{block_name}')
            if any(layer_name in leaf for leaf in leaves):
                return

        raise RuntimeError(f'Layer {layer_name} not encountered in partial_forward')

    def set_noise_seed(self, seed):
        """Replace every ``NoiseLayer``'s noise map with one drawn from ``seed``.

        The global torch RNG is reseeded before each layer, so every layer's
        map is generated from the same seed state.
        """
        G = self.model._modules['g_synthesis']

        def for_each_child(this, name, func):
            # Post-order traversal: children are visited before their parent
            children = getattr(this, '_modules', {})
            for child_name, module in children.items():
                for_each_child(module, f'{name}.{child_name}', func)
            func(this, name)

        def modify(m, name):
            if isinstance(m, stylegan.NoiseLayer):
                # The third name component encodes the block resolution, e.g. '16x16'
                H, W = [int(s) for s in name.split('.')[2].split('x')]
                torch.random.manual_seed(seed)
                m.noise = torch.randn(1, 1, H, W, device=self.device, dtype=torch.float32)

        for_each_child(G, 'g_synthesis', modify)
|
|
|
class GANZooModel(BaseModel):
    """Wrapper for generators loaded from ``facebookresearch/pytorch_GAN_zoo`` via torch.hub."""

    def __init__(self, device, model_name):
        """Load the pretrained hub model ``model_name`` onto ``device``."""
        super(GANZooModel, self).__init__(model_name, 'default')
        self.device = device
        on_gpu = device.type == 'cuda'
        self.base_model = torch.hub.load('facebookresearch/pytorch_GAN_zoo:hub',
                                         model_name, pretrained=True, useGPU=on_gpu)
        self.model = self.base_model.netG.to(self.device)
        self.name = model_name
        self.has_latent_residual = False

    def sample_latent(self, n_samples=1, seed=0, truncation=None):
        """Return ``n_samples`` latents from the hub model's ``buildNoiseData``.

        Note: ``seed`` and ``truncation`` are currently ignored.
        """
        latents, _unused = self.base_model.buildNoiseData(n_samples)
        return latents

    def partial_forward(self, x, layer_name):
        """No early exit is implemented; runs the full :meth:`forward`."""
        return self.forward(x)

    def get_conditional_state(self, z):
        """Return the last 20 entries along dim 1 of ``z`` (treated as conditioning)."""
        return z[:, -20:]

    def set_conditional_state(self, z, c):
        """Overwrite the last 20 entries along dim 1 of ``z`` in place with ``c``."""
        z[:, -20:] = c
        return z

    def forward(self, x):
        """Generate images via the hub model's ``test`` method, mapped via ``0.5 * (v + 1)``."""
        generated = self.base_model.test(x)
        return 0.5*(generated+1)
|
|
|
|
|
class ProGAN(BaseModel):
    """Wrapper for the netdissect ProGAN LSUN generators."""

    def __init__(self, device, lsun_class=None):
        """Validate ``lsun_class`` and load the matching generator onto ``device``.

        Raises:
            AssertionError: If ``lsun_class`` is not a supported LSUN class.
        """
        super(ProGAN, self).__init__('ProGAN', lsun_class)
        self.device = device

        valid_classes = [ 'bedroom', 'churchoutdoor', 'conferenceroom', 'diningroom', 'kitchen', 'livingroom', 'restaurant' ]
        assert self.outclass in valid_classes, \
            f'Invalid LSUN class {self.outclass}, should be one of {valid_classes}'

        self.load_model()
        self.name = f'ProGAN-{self.outclass}'
        self.has_latent_residual = False

    def load_model(self):
        """Load ``<root>/progan/<class>_lsun.pth``, downloading it first if absent."""
        root = os.environ.get('GANCONTROL_CHECKPOINT_DIR', Path(__file__).parent / 'checkpoints')
        filename = f'{self.outclass}_lsun.pth'
        checkpoint = Path(root) / 'progan' / filename

        if not checkpoint.is_file():
            os.makedirs(checkpoint.parent, exist_ok=True)
            download_ckpt(f'http://netdissect.csail.mit.edu/data/ganmodel/karras/{filename}', checkpoint)

        self.model = proggan.from_pth_file(str(checkpoint.resolve())).to(self.device)

    def sample_latent(self, n_samples=1, seed=None, truncation=None):
        """Sample ``n_samples`` latents with netdissect's ``z_sample_for_model``."""
        chosen_seed = np.random.randint(np.iinfo(np.int32).max) if seed is None else seed
        samples = zdataset.z_sample_for_model(self.model, n_samples, seed=chosen_seed)[...]
        return samples.to(self.device)

    @staticmethod
    def _single_latent(x):
        """Unwrap a one-element latent list; ProGAN takes a single global latent."""
        if isinstance(x, list):
            assert len(x) == 1, "ProGAN only supports a single global latent"
            x = x[0]
        return x

    def forward(self, x):
        """Generate images from a single latent, mapped via ``0.5 * (v + 1)``."""
        x = self._single_latent(x)
        generated = self.model.forward(x)
        return 0.5*(generated+1)

    def partial_forward(self, x, layer_name):
        """Run the sequential generator until the module named ``layer_name`` has executed.

        Raises:
            RuntimeError: If no top-level module has that name.
        """
        assert isinstance(self.model, torch.nn.Sequential), 'Expected sequential model'

        x = self._single_latent(x)

        # Treat the latent as a 1x1 spatial map for the first layer
        activation = x.view(x.shape[0], x.shape[1], 1, 1)
        for module_name, module in self.model._modules.items():
            activation = module(activation)
            if module_name == layer_name:
                return

        raise RuntimeError(f'Layer {layer_name} not encountered in partial_forward')
|
|
|
|
|
class BigGAN(BaseModel):
    """Wrapper around the class-conditional BigGAN-deep generator."""

    def __init__(self, device, resolution, class_name, truncation=1.0):
        """Load ``biggan-deep-<resolution>`` and select ``class_name`` (default ``'husky'``).

        Args:
            device: Device for the generator and class vectors.
            resolution: Resolution suffix of the pretrained model name.
            class_name: ImageNet class name or integer id.
            truncation: Truncation used for sampling and generation.
        """
        super(BigGAN, self).__init__(f'BigGAN-{resolution}', class_name)
        self.device = device
        self.truncation = truncation
        self.load_model(f'biggan-deep-{resolution}')
        self.set_output_class(class_name or 'husky')
        self.name = f'BigGAN-{resolution}-{self.outclass}-t{self.truncation}'
        self.has_latent_residual = True

    def load_model(self, name):
        """Download (if needed) and load the pretrained weights/config for ``name``.

        Raises:
            RuntimeError: If ``name`` is not a known pretrained model.
        """
        if name not in biggan.model.PRETRAINED_MODEL_ARCHIVE_MAP:
            raise RuntimeError('Unknown BigGAN model name', name)

        checkpoint_root = os.environ.get('GANCONTROL_CHECKPOINT_DIR', Path(__file__).parent / 'checkpoints')
        model_path = Path(checkpoint_root) / name

        os.makedirs(model_path, exist_ok=True)

        model_file = model_path / biggan.model.WEIGHTS_NAME
        config_file = model_path / biggan.model.CONFIG_NAME
        model_url = biggan.model.PRETRAINED_MODEL_ARCHIVE_MAP[name]
        config_url = biggan.model.PRETRAINED_CONFIG_ARCHIVE_MAP[name]

        for filename, url in ((model_file, model_url), (config_file, config_url)):
            if filename.is_file():
                continue
            print('Downloading', url)
            try:
                with open(filename, 'wb') as f:
                    if url.startswith("s3://"):
                        biggan.s3_get(url, f)
                    else:
                        biggan.http_get(url, f)
            except BaseException:
                # Remove the partially written file (also on Ctrl-C) so a later run
                # re-downloads it instead of loading a truncated file.
                if filename.exists():
                    filename.unlink()
                raise

        self.model = biggan.BigGAN.from_pretrained(model_path).to(self.device)

    def sample_latent(self, n_samples=1, truncation=None, seed=None):
        """Sample truncated-normal latents.

        Uses ``truncation`` if given, otherwise ``self.truncation``.
        Note the parameter order (``truncation`` before ``seed``).
        """
        if seed is None:
            seed = np.random.randint(np.iinfo(np.int32).max)

        noise_vector = biggan.truncated_noise_sample(truncation=truncation or self.truncation, batch_size=n_samples, seed=seed)
        noise = torch.from_numpy(noise_vector)

        return noise.to(self.device)

    def get_max_latents(self):
        """One latent for ``gen_z`` plus one per generator layer config entry."""
        return len(self.model.config.layers) + 1

    def get_conditional_state(self, z):
        """Return the current class vector."""
        return self.v_class

    def set_conditional_state(self, z, c):
        """Set the class vector to ``c`` and return ``z``, matching the base class contract."""
        self.v_class = c
        return z

    def is_valid_class(self, class_id):
        """Return whether ``class_id`` (int id in [0, 1000) or class name) is recognised.

        Raises:
            RuntimeError: If ``class_id`` is neither an int nor a str.
        """
        if isinstance(class_id, int):
            return 0 <= class_id < 1000
        elif isinstance(class_id, str):
            return biggan.one_hot_from_names([class_id.replace(' ', '_')]) is not None
        else:
            raise RuntimeError(f'Unknown class identifier {class_id}')

    def set_output_class(self, class_id):
        """Select the output class by integer id or class name.

        State is only updated once the class has been resolved successfully.

        Raises:
            RuntimeError: If ``class_id`` has an unsupported type or names an
                unknown class.
        """
        if isinstance(class_id, int):
            v_class = torch.from_numpy(biggan.one_hot_from_int([class_id])).to(self.device)
            outclass = f'class{class_id}'
        elif isinstance(class_id, str):
            one_hot = biggan.one_hot_from_names([class_id])
            if one_hot is None:
                # Unknown name; avoids an opaque TypeError from torch.from_numpy(None)
                raise RuntimeError(f'Unknown class identifier {class_id}')
            v_class = torch.from_numpy(one_hot).to(self.device)
            outclass = class_id.replace(' ', '_')
        else:
            raise RuntimeError(f'Unknown class identifier {class_id}')

        self.v_class = v_class
        self.outclass = outclass

    def forward(self, x):
        """Generate images for the current class, mapped via ``0.5 * (v + 1)``.

        ``x`` may be a single latent or a list of latents; each list entry
        gets its own copy of the class vector.
        """
        if isinstance(x, list):
            c = self.v_class.repeat(x[0].shape[0], 1)
            class_vector = len(x)*[c]
        else:
            class_vector = self.v_class.repeat(x.shape[0], 1)

        out = self.model.forward(x, class_vector, self.truncation)
        return 0.5*(out+1)

    def partial_forward(self, x, layer_name):
        """Run the generator only up to the layer containing ``layer_name``.

        Args:
            x: A latent tensor or a list of ``self.model.n_latents`` latents.
            layer_name: ``'embeddings'``, ``'generator.gen_z'``,
                ``'generator.layers.<k>...'`` or any other name (full depth).

        Returns:
            None.
        """
        if layer_name in ['embeddings', 'generator.gen_z']:
            n_layers = 0
        elif 'generator.layers' in layer_name:
            layer_base = re.match(r'^generator\.layers\.[0-9]+', layer_name)[0]
            n_layers = int(layer_base.split('.')[-1]) + 1
        else:
            n_layers = len(self.model.config.layers)

        if not isinstance(x, list):
            x = self.model.n_latents*[x]

        if isinstance(self.v_class, list):
            # Per-latent class vectors
            labels = [c.repeat(x[0].shape[0], 1) for c in self.v_class]
            embed = [self.model.embeddings(l) for l in labels]
        else:
            class_label = self.v_class.repeat(x[0].shape[0], 1)
            embed = len(x)*[self.model.embeddings(class_label)]

        assert len(x) == self.model.n_latents, f'Expected {self.model.n_latents} latents, got {len(x)}'
        assert len(embed) == self.model.n_latents, f'Expected {self.model.n_latents} class vectors, got {len(embed)}'

        cond_vectors = [torch.cat((z, e), dim=1) for (z, e) in zip(x, embed)]

        z = self.model.generator.gen_z(cond_vectors[0])
        z = z.view(-1, 4, 4, 16 * self.model.generator.config.channel_width)
        z = z.permute(0, 3, 1, 2).contiguous()

        # Conditioning vector 0 was consumed by gen_z; GenBlocks take the rest in order
        cond_idx = 1
        for layer in self.model.generator.layers[:n_layers]:
            if isinstance(layer, biggan.GenBlock):
                z = layer(z, cond_vectors[cond_idx], self.truncation)
                cond_idx += 1
            else:
                z = layer(z)

        return None
|
|
|
|
|
@singledispatch
def get_model(name, output_class, device, **kwargs):
    """Create a generator wrapper, or reuse a compatible cached one.

    Args:
        name: ``'DCGAN'``, ``'ProGAN'``, ``'BigGAN-<res>'``, ``'StyleGAN'`` or
            ``'StyleGAN2'``.
        output_class: Class/dataset identifier passed to the wrapper.
        device: Torch device for the model.
        **kwargs: ``inst`` (an instrumented model whose ``.model`` may be
            reused) or ``model`` (a wrapper to reuse). Other keys are ignored here.

    Returns:
        A ``BaseModel`` subclass instance.

    Raises:
        RuntimeError: If ``name`` is not recognised.
    """
    inst = kwargs.get('inst', None)
    model = kwargs.get('model', None)

    # Reuse an existing model when the architecture matches and the class
    # either matches or can be switched in place (BigGAN only).
    if inst or model:
        cached = model or inst.model
        same_network = cached.model_name == name
        same_class = cached.outclass == output_class
        class_switchable = 'BigGAN' in name
        if same_network and (same_class or class_switchable):
            cached.set_output_class(output_class)
            return cached

    if name == 'DCGAN':
        import warnings
        warnings.filterwarnings("ignore", message="nn.functional.tanh is deprecated")
        return GANZooModel(device, 'DCGAN')
    if name == 'ProGAN':
        return ProGAN(device, output_class)
    if 'BigGAN' in name:
        assert '-' in name, 'Please specify BigGAN resolution, e.g. BigGAN-512'
        resolution = name.split('-')[-1]
        return BigGAN(device, resolution, class_name=output_class)
    if name == 'StyleGAN':
        return StyleGAN(device, class_name=output_class)
    if name == 'StyleGAN2':
        return StyleGAN2(device, class_name=output_class)
    raise RuntimeError(f'Unknown model {name}')
|
|
|
|
|
@get_model.register(Config)
def _(cfg, device, **kwargs):
    """Config-based overload of ``get_model``; ``use_w`` defaults to ``cfg.use_w``."""
    kwargs.setdefault('use_w', cfg.use_w)
    return get_model(cfg.model, cfg.output_class, device, **kwargs)
|
|
|
|
|
@singledispatch
def get_instrumented_model(name, output_class, layers, device, **kwargs):
    """Build (or reuse) a generator and wrap it with netdissect instrumentation.

    Args:
        name: Model name accepted by ``get_model``.
        output_class: Class/dataset identifier.
        layers: A layer name or list of layer names to instrument.
        device: Torch device.
        **kwargs: Forwarded to ``get_model``. ``inst`` is a previous
            instrumented model, closed here. ``use_w`` switches the model
            to W-space after instrumentation.

    Returns:
        The object returned by ``netdissect.modelconfig.create_instrumented_model``.

    Raises:
        RuntimeError: If a requested layer is not a module of the model.
    """
    model = get_model(name, output_class, device, **kwargs)
    model.eval()

    # Close the previous instrumented model, if any, before creating a new one
    inst = kwargs.get('inst', None)
    if inst:
        inst.close()

    if not isinstance(layers, list):
        layers = [layers]

    # Validate requested layers up front so the error lists what is available
    module_names = [mod_name for (mod_name, _) in model.named_modules()]
    known_names = set(module_names)
    for layer_name in layers:
        if layer_name not in known_names:
            print(f"Layer '{layer_name}' not found in model!")
            print("Available layers:", '\n'.join(module_names))
            raise RuntimeError(f"Unknown layer '{layer_name}'")

    # Instrument in Z space; W space is re-enabled below if requested
    if hasattr(model, 'use_z'):
        model.use_z()

    from netdissect.modelconfig import create_instrumented_model
    inst = create_instrumented_model(SimpleNamespace(
        model = model,
        layers = layers,
        cuda = device.type == 'cuda',
        gen = True,
        latent_shape = model.get_latent_shape()
    ))

    if kwargs.get('use_w', False):
        model.use_w()

    return inst
|
|
|
|
|
@get_instrumented_model.register(Config)
def _(cfg, device, **kwargs):
    """Config-based overload of ``get_instrumented_model``; ``use_w`` defaults to ``cfg.use_w``."""
    kwargs.setdefault('use_w', cfg.use_w)
    return get_instrumented_model(cfg.model, cfg.output_class, cfg.layer, device, **kwargs)
|
|