# stylemc-demo / w_s_converter.py
# Author: adirik
# Commit 27fc598: "update w_to_s converter"
# (Header scraped from a web file viewer: raw | history | blame, 5.59 kB.)
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Convert projected W latents to per-layer style (S) vectors for StyleGAN2 synthesis blocks."""
import os
import re
from typing import List
import numpy as np
import torch
from torch_utils import misc
from torch_utils import persistence
from torch_utils.ops import conv2d_resample
from torch_utils.ops import upfirdn2d
from torch_utils.ops import bias_act
from torch_utils.ops import fma
def block_forward(self, x, img, ws, shapes, force_fp32=False, fused_modconv=None, **layer_kwargs):
    """Run a synthesis block's forward pass driven by per-layer style rows.

    Meant to be called with a synthesis block as ``self`` (it reads the
    block's ``conv0``/``conv1``/``torgb``/``skip``/``const`` attributes).
    Unlike a stock forward, each row of ``ws`` is trimmed to a per-layer width
    taken from ``shapes`` before being passed to the layer.
    NOTE(review): this presumably pairs with ``w_to_s``, which replaces each
    layer's ``affine`` with ``torch.nn.Identity`` so that these rows are the
    zero-padded style vectors themselves -- confirm against the caller.

    Args:
        self: the synthesis block whose layers are run.
        x: feature map from the previous block. When ``self.in_channels == 0``
            it is replaced by the block's learned constant.
        img: running RGB image from the previous block, or ``None``.
        ws: style rows, shape ``[batch, num_conv + num_torgb, w_dim]``.
        shapes: per-layer widths. Index 0 trims the first conv's row, index 1
            the second conv's row, and index 2 the ToRGB row.
        force_fp32: disable fp16 / channels_last even if the block enables them.
        fused_modconv: forwarded to the layers. When ``None`` it is derived
            from training mode, dtype and batch size.
        **layer_kwargs: extra keyword arguments forwarded to the conv layers.

    Returns:
        ``(x, img)``: ``x`` in the block's compute dtype, and ``img`` as a
        float32 tensor or ``None``.
    """
    misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
    w_iter = iter(ws.unbind(dim=1))
    dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
    memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
    if fused_modconv is None:
        with misc.suppress_tracer_warnings(): # this value will be treated as a constant
            fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1)

    # Input.
    if self.in_channels == 0:
        x = self.const.to(dtype=dtype, memory_format=memory_format)
        x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
    else:
        misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2])
        x = x.to(dtype=dtype, memory_format=memory_format)

    # Main layers. Every branch trims its style rows to the layer's width, the
    # resnet branch included. Previously the resnet branch passed the padded
    # rows through untrimmed, which does not match narrower layers.
    if self.in_channels == 0:
        x = self.conv1(x, next(w_iter)[..., :shapes[0]], fused_modconv=fused_modconv, **layer_kwargs)
    elif self.architecture == 'resnet':
        y = self.skip(x, gain=np.sqrt(0.5))
        x = self.conv0(x, next(w_iter)[..., :shapes[0]], fused_modconv=fused_modconv, **layer_kwargs)
        x = self.conv1(x, next(w_iter)[..., :shapes[1]], fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
        x = y.add_(x)
    else:
        x = self.conv0(x, next(w_iter)[..., :shapes[0]], fused_modconv=fused_modconv, **layer_kwargs)
        x = self.conv1(x, next(w_iter)[..., :shapes[1]], fused_modconv=fused_modconv, **layer_kwargs)

    # ToRGB.
    if img is not None:
        misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
        img = upfirdn2d.upsample2d(img, self.resample_filter)
    if self.is_last or self.architecture == 'skip':
        y = self.torgb(x, next(w_iter)[..., :shapes[2]], fused_modconv=fused_modconv)
        y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
        img = img.add_(y) if img is not None else y

    assert x.dtype == dtype
    assert img is None or img.dtype == torch.float32
    return x, img
def unravel_index(index, shape):
    """Map a flat, row-major index onto one coordinate per dimension.

    Dimensions are peeled off from the last axis to the first using ``%`` and
    ``//``. Any operand supporting those operators works, not just ``int``.

    Args:
        index: the flat index, or an array/tensor of flat indices.
        shape: sequence of dimension sizes.

    Returns:
        A tuple with one coordinate per entry of ``shape``, outermost first.
    """
    coords = []
    remainder = index
    for size in reversed(shape):
        # Prepend so the final tuple is already outermost-first.
        coords.insert(0, remainder % size)
        remainder = remainder // size
    return tuple(coords)
def w_to_s(
    G,
    outdir: str,
    projected_w: str,
    truncation_psi: float = 0.7,
    noise_mode: str = "const",
):
    """Convert a projected W latent into per-layer style (S) vectors and save them.

    Steps:
      1. Load the ``'w'`` array from the ``.npz`` file ``projected_w``.
      2. Split it per synthesis block.
      3. For each block, run the conv/ToRGB layers' ``affine`` modules on the
         first sample's W rows to obtain style vectors.
      4. Write the rows into a zero-initialised tensor of shape
         ``[1, 26, 512]``, one row per affine layer. Each row is filled up to
         that layer's affine output width; the remainder stays zero.
      5. Save the tensor to ``<outdir>/input.npz`` under key ``'s'``.

    NOTE(review): the 26 rows presumably correspond to a 1024px generator with
    a ToRGB layer in every block (2 + 3 * 8 layers) -- confirm before using
    other resolutions.

    Side effects on ``G``:
      * every parameter gets ``requires_grad = True``;
      * each processed layer's ``affine`` is replaced by
        ``torch.nn.Identity()``, so the layers can afterwards be fed style
        vectors directly. Because the replaced modules no longer have
        ``.weight``, calling ``w_to_s`` twice on the same ``G`` fails.

    Args:
        G: generator exposing ``G.synthesis`` with ``b<res>`` blocks.
        outdir: output directory; created if missing.
        projected_w: path to an ``.npz`` file holding a ``'w'`` array of shape
            ``[batch, num_ws, w_dim]``.
        truncation_psi: unused; kept for interface compatibility.
        noise_mode: unused; kept for interface compatibility.

    Returns:
        The detached ``[1, 26, 512]`` style tensor that was saved.
    """
    # Use GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(outdir, exist_ok=True)

    for param in G.parameters():
        param.requires_grad = True

    # NpzFile keeps the underlying file open; close it deterministically.
    with np.load(projected_w) as data:
        ws = torch.tensor(data['w'], device=device)

    block_ws = []
    with torch.autograd.profiler.record_function('split_ws'):
        misc.assert_shape(ws, [None, G.synthesis.num_ws, G.synthesis.w_dim])
        ws = ws.to(torch.float32)
        w_idx = 0
        for res in G.synthesis.block_resolutions:
            block = getattr(G.synthesis, f'b{res}')
            block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb))
            # Consecutive blocks share one W row (ToRGB reuses the next conv's row).
            w_idx += block.num_conv

    styles = torch.zeros(1, 26, 512, device=device)
    styles_idx = 0
    # The styles are detached anyway, so skip building an autograd graph.
    with torch.no_grad():
        for res, cur_ws in zip(G.synthesis.block_resolutions, block_ws):
            block = getattr(G.synthesis, f'b{res}')
            # The 4x4 block starts from a constant and has no conv0.
            layer_names = ('conv1', 'torgb') if res == 4 else ('conv0', 'conv1', 'torgb')
            for offset, name in enumerate(layer_names):
                layer = getattr(block, name)
                width = layer.affine.weight.shape[0]
                row = styles_idx + offset
                # Keep the W slice 2-D ([1, w_dim]) as in the original code,
                # which fed affine a [1, w_dim] slice.
                styles[0, row:row + 1, :width] = layer.affine(cur_ws[0, offset:offset + 1, :])
                layer.affine = torch.nn.Identity()
            styles_idx += len(layer_names)

    styles = styles.detach()
    np.savez(f'{outdir}/input.npz', s=styles.cpu().numpy())
    return styles