# Loss for Portrait4D, modified from EG3D: https://github.com/NVlabs/eg3d
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

"""Loss functions."""

import numpy as np
import PIL
import torch
import torch.nn.functional as F
import torchvision.transforms as transform
# from kornia.geometry import warp_affine
from torch_utils import training_stats
from torch_utils.ops import conv2d_gradfix
from torch_utils.ops import upfirdn2d
from recon.training.discriminator.dual_discriminator_next3D import filtered_resizing
import cv2
from PIL import Image
# from recon.utils.preprocess import estimate_norm_torch, estimate_norm_torch_pdfgc
from camera_utils import LookAtPoseSampler, FOV_to_intrinsics

# ----------------------------------------------------------------------------

class Loss:
    def accumulate_gradients(self, phase, real_img_app, real_img_mot, real_img_recon, real_depth_recon,
                             real_feature_recon, real_triplane_recon,
                             real_c_recon, motions_app, motions, gain, cur_nimg, real_cano_tri, real_ref_tri,
                             motion_scale=1.0, swapping_prob=0.5, half_static=False):  # to be overridden by subclass
        raise NotImplementedError()

# ----------------------------------------------------------------------------

class AnimatableGalleryPortraitReconLoss(Loss):
    def __init__(self, device, G, D, G_syn, D_patch=None, augment_pipe=None, lpips=None, facenet=None, pd_fgc=None,
                 gmain=1.0, r1_gamma=10, r1_gamma_patch=10, r1_gamma_uv=30,
                 r1_gamma_seg=10, style_mixing_prob=0, pl_weight=0, pl_batch_shrink=2, pl_decay=0.01,
                 pl_no_weight_grad=False, blur_init_sigma=0, blur_init_sigma_patch=0, blur_fade_kimg=0,
                 blur_patch_seg=0, r1_gamma_init=0, r1_gamma_fade_kimg=0, neural_rendering_resolution_initial=64,
                 neural_rendering_resolution_final=None, neural_rendering_resolution_fade_kimg=0,
                 gpc_reg_fade_kimg=1000, gpc_reg_prob=None, discrimination_kimg=1000, dual_discrimination=False,
                 filter_mode='antialiased', patch_scale=1.0, patch_gan=0.2, masked_sampling=None, perturb_params=False,
                 id_loss=None, use_D=True, truncation_psi=0.7, conditioning_params=None, w_avg=None):
        super().__init__()
        self.device = device
        self.G = G
        self.D = D
        self.G_syn = G_syn
        self.D_patch = D_patch
        self.augment_pipe = augment_pipe
        self.lpips = lpips
        self.pd_fgc = pd_fgc
        self.gmain = gmain
        self.r1_gamma = r1_gamma
        self.r1_gamma_patch = r1_gamma_patch
        self.r1_gamma_uv = r1_gamma_uv
        self.r1_gamma_seg = r1_gamma_seg
        self.style_mixing_prob = style_mixing_prob
        self.pl_weight = pl_weight
        self.pl_batch_shrink = pl_batch_shrink
        self.pl_decay = pl_decay
        self.pl_no_weight_grad = pl_no_weight_grad
        self.pl_mean = torch.zeros([], device=device)
        self.blur_init_sigma = blur_init_sigma
        self.blur_init_sigma_patch = blur_init_sigma_patch
        self.blur_fade_kimg = blur_fade_kimg
        self.blur_patch_seg = blur_patch_seg
        self.r1_gamma_init = r1_gamma_init
        self.r1_gamma_fade_kimg = r1_gamma_fade_kimg
        self.bg_reg = True
        self.c_headpose = False
        self.neural_rendering_resolution_initial = neural_rendering_resolution_initial
        self.neural_rendering_resolution_final = neural_rendering_resolution_final
        self.neural_rendering_resolution_fade_kimg = neural_rendering_resolution_fade_kimg
        self.gpc_reg_fade_kimg = gpc_reg_fade_kimg
        self.gpc_reg_prob = gpc_reg_prob
        self.discrimination_kimg = discrimination_kimg
        self.dual_discrimination = dual_discrimination
        self.filter_mode = filter_mode
        self.resample_filter = upfirdn2d.setup_filter([1, 3, 3, 1], device=device)
        self.blur_raw_target = True
        assert self.gpc_reg_prob is None or (0 <= self.gpc_reg_prob <= 1)
        self.patch_scale = patch_scale
        self.masked_sampling = masked_sampling
        self.patch_gan = patch_gan
        self.perturb_params = perturb_params
        self.use_D = use_D
        self.truncation_psi = truncation_psi
        self.conditioning_params = torch.load(conditioning_params).to(device)
        self.w_avg = torch.load(w_avg).to(device)[0]
        self.id_loss = id_loss.to(device)
    # extract pdfgc motion embedding
    # def get_motion_feature(self, imgs, lmks, crop_size=224, crop_len=16):
    #
    #     trans_m = estimate_norm_torch_pdfgc(lmks, imgs.shape[-1])
    #     imgs_warp = warp_affine(imgs, trans_m, dsize=(224, 224))
    #     imgs_warp = imgs_warp[:, :, :crop_size - crop_len * 2, crop_len:crop_size - crop_len]
    #     imgs_warp = torch.clamp(F.interpolate(imgs_warp, size=[crop_size, crop_size], mode='bilinear'), -1, 1)
    #
    #     out = self.pd_fgc(imgs_warp)
    #     motions = torch.cat([out[1], out[2], out[3]], dim=-1)
    #
    #     return motions
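    # The PD-FGC motion-embedding extractor above is disabled in this file; accumulate_gradients() instead
    # receives the motion embeddings (motions_app, motions) already computed by the caller.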

    # generate online training data using the pre-trained Next3D models; the first frame is fixed
    def gen_data_by_G_next3D(self, z, exp_params, c, model_name_list,
                             render_res=64):
        out_sr_img = []
        out_img_raw = []
        out_img_depth = []
        out_static_plane = []
        out_feature_img = []
        out_triplane = []
        out_texture = []
        out_rendering_stitch = []
        batchsize_subject = z.shape[0]
        # assert batchsize_subject == 3
        # cam_pivot = torch.tensor(self.G_syn.rendering_kwargs.get('avg_camera_pivot', [0, 0, 0]), device=device)
        # cam_radius = self.G_syn.rendering_kwargs.get('avg_camera_radius', 2.7)
        # conditioning_cam2world_pose = LookAtPoseSampler.sample(np.pi / 2, np.pi / 2, cam_pivot, radius=cam_radius,
        #                                                        device=device)
        # conditioning_params = torch.cat([conditioning_cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1).to(
        #     device)
        # w = G.mapping(z, conditioning_params, truncation_psi=0.7, truncation_cutoff=14)
        # latent = latent.reshape(-1, *latent.shape[2:])
        assert len(model_name_list) == z.shape[0]
        model_name_list = [name for name1 in model_name_list for name in name1]
        z = z.reshape(-1, *z.shape[2:])  # (b*3, 512)
        assert len(model_name_list) == z.shape[0]
        exp_params = exp_params.reshape(-1, *exp_params.shape[2:])  # (b*3, 100)
        c = c.reshape(-1, *c.shape[2:])  # (b*3, 25)
        # out_motion = [real_motion_1, real_motion_2]

        # random head rotation
        angle_ys_head = torch.rand((z.shape[0], 1), device=z.device) * 0.60 * 2 - 0.60
        angle_ys_head2 = torch.rand((z.shape[0], 1), device=z.device) * 0.35 * 2 - 0.35 + 0.2
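        # angle_ys_head is a horizontal (yaw) offset in [-0.60, 0.60] rad and angle_ys_head2 a vertical
        # (pitch) offset in [-0.15, 0.55] rad, both applied around the frontal pose (pi/2, pi/2) below.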
        # angle_ys_head3 = torch.rand((z.shape[0], 1), device=z.device) * 0.25 * 2 - 0.25

        # random camera pose
        cam_pivot_x = torch.rand((z.shape[0], 1), device=z.device) * 0.02 - 0.01
        cam_pivot_y = torch.rand((z.shape[0], 1), device=z.device) * 0.02 - 0.01
        cam_pivot_z = torch.rand((z.shape[0], 1), device=z.device) * 0.02 - 0.01 + 0.03
        cam_pivot = torch.cat([cam_pivot_x * 3, cam_pivot_y * 3, cam_pivot_z * 3], dim=-1)
        # cam_radius = torch.rand((z.shape[0], 1), device=z.device) * 0.8 + 2.7
        cam2world_pose = LookAtPoseSampler.sample(np.pi / 2 + angle_ys_head, np.pi / 2 - angle_ys_head2, cam_pivot,
                                                  radius=2.7,
                                                  batch_size=z.shape[0], device=z.device)
        c_syn = torch.cat([cam2world_pose.reshape(-1, 16), c[:, 16:].reshape(-1, 9)], dim=-1)
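        # c_syn is the EG3D-style 25-dim camera label: 16 flattened cam2world extrinsics followed by the
        # 9 flattened intrinsics kept from the input c.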
        prob = torch.rand((c.shape[0], 1), device=c.device)
        c_syn_final_out = c_syn.reshape(-1, 3, c_syn.shape[-1])

        for batch_index, model_name in enumerate(model_name_list):
            z_value = z[batch_index].unsqueeze(0)
            ws = self.G_syn[model_name].mapping(z_value, self.conditioning_params, truncation_psi=self.truncation_psi,
                                                truncation_cutoff=14)
            vert_value = exp_params[batch_index].unsqueeze(0)
            c = c_syn[batch_index].unsqueeze(0)
            out = self.G_syn[model_name].synthesis(ws, c, vert_value, noise_mode='const',
                                                   neural_rendering_resolution=128, return_featmap=True)
            # img = (out['image'][0] * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            # img = img.permute(1, 2, 0)
            # img = img.cpu().numpy()
            # img = Image.fromarray(np.uint8(img))
            # print('savesavesavesavesavesave')
            # save_dir = f'/home/liuhongyu/code/HeadArtist2/HeadGallery/training-runs-portrait4d/00073--multi_style-gpus2-batch8/{batch_index}.png'
            # print(save_dir)
            # # cv2.imwrite(save_dir, img)
            # img.save(save_dir)
            out_sr_img.append(out['image'])
            out_img_raw.append(out['image_raw'])
            out_img_depth.append(out['image_depth'])
            out_static_plane.append(out['static_plane'])
            out_feature_img.append(out['image_feature'])
            out_triplane.append(out['triplane'])
            out_rendering_stitch.append(out['rendering_stitch'])

        final_out = {'image_sr': torch.cat(out_sr_img), 'image': torch.cat(out_img_raw),
                     'image_depth': torch.cat(out_img_depth), 'static_plane': torch.cat(out_static_plane),
                     'image_feature': torch.cat(out_feature_img), 'triplane': torch.cat(out_triplane),
                     'rendering_stitch': torch.cat(out_rendering_stitch),
                     'c': c_syn_final_out,
                     # 'motions': out_motion
                     }
        return final_out
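
    # Minimal usage sketch for gen_data_by_G_next3D (an illustrative assumption, not part of the training
    # code): shapes follow the reshape comments above, i.e. 3 frames per subject, 512-dim z, 100-dim
    # expression params and 25-dim cameras, and each entry of model_name_list holds the three G_syn keys
    # for one subject.
    #
    #   z = torch.randn(B, 3, 512, device=device)
    #   exp_params = torch.zeros(B, 3, 100, device=device)
    #   c = cams.reshape(B, 3, 25)
    #   names = [[key] * 3 for key in model_keys[:B]]
    #   synth = loss.gen_data_by_G_next3D(z, exp_params, c, names)
    #   # synth['image_sr'], synth['image'], synth['image_depth'], ... have batch dimension B * 3;
    #   # synth['c'] is reshaped back to (B, 3, 25).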

    def run_G(self, imgs_app, imgs_mot, motions_app, motions, c, mesh, real_cano_tri, real_ref_tri,
              neural_rendering_resolution,
              motion_scale=1.0, swapping_prob=0.5, half_static=False):
        motion_scale = torch.ones([imgs_app.shape[0], 1, 1], device=c.device) * motion_scale
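        # With probability swapping_prob, the appearance image is replaced by the motion (driving) frame,
        # motion_scale is zeroed (deactivating the motion-related cross-attention layers, see below), and
        # the appearance motion embedding is replaced by the driving one.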
        if swapping_prob is not None:
            imgs_app_swapped = imgs_mot
            prob = torch.rand((imgs_app.shape[0], 1), device=c.device)
            imgs_app_conditioning = torch.where(prob.reshape(imgs_app.shape[0], 1, 1, 1) < swapping_prob,
                                                imgs_app_swapped, imgs_app)
            motion_scale_conditioning = torch.where(prob.reshape(imgs_app.shape[0], 1, 1) < swapping_prob,
                                                    torch.zeros_like(motion_scale), motion_scale)
            motions_app_conditioning = torch.where(prob < swapping_prob, motions, motions_app)
        else:
            imgs_app_conditioning = imgs_app
            motion_scale_conditioning = motion_scale
            motions_app_conditioning = motions_app

        # whether or not the second half of the batch is static data;
        # if true, set motion scale to zero to deactivate motion-related cross-attention layers
        if half_static:
            num_static = imgs_app.shape[0] // 2
            if swapping_prob is None:
                motion_scale_conditioning = torch.cat([motion_scale[:num_static], motion_scale[num_static:] * 0], dim=0)
            else:
                prob = torch.rand((num_static, 1), device=c.device)
                motion_scale_static = torch.where(prob.reshape(num_static, 1, 1) < 1 - swapping_prob,
                                                  torch.zeros_like(motion_scale[num_static:]),
                                                  motion_scale[num_static:])
                motion_scale_conditioning = torch.cat([motion_scale_conditioning[:num_static], motion_scale_static],
                                                      dim=0)

        gen_output = self.G.synthesis(imgs_app_conditioning, imgs_mot, motions_app_conditioning, motions, c, mesh,
                                      real_ref_tri, real_ref_tri, self.w_avg,
                                      neural_rendering_resolution=neural_rendering_resolution,
                                      motion_scale=motion_scale_conditioning)
        return gen_output

    def run_D(self, img, c, blur_sigma=0, blur_sigma_raw=0, update_emas=False):
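        # When blur_sigma > 0, low-pass filter the discriminator input with a truncated Gaussian-like
        # kernel (profile 2^(-(x/sigma)^2), radius floor(3*sigma)), inherited from the EG3D loss.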
        blur_size = np.floor(blur_sigma * 3)
        if blur_size > 0:
            with torch.autograd.profiler.record_function('blur'):
                if self.G.has_superresolution:
                    f = torch.arange(-blur_size, blur_size + 1, device=img['image_sr'].device).div(
                        blur_sigma).square().neg().exp2()
                    img['image_sr'] = upfirdn2d.filter2d(img['image_sr'], f / f.sum())
                else:
                    f = torch.arange(-blur_size, blur_size + 1, device=img['image'].device).div(
                        blur_sigma).square().neg().exp2()
                    img['image'] = upfirdn2d.filter2d(img['image'], f / f.sum())
        logits = self.D(img, c, update_emas=update_emas)
        return logits

    def accumulate_gradients(self, phase, real_img_app, real_img_mot, real_img_recon, real_depth_recon,
                             real_feature_recon, real_triplane_recon,
                             real_c_recon, mesh, motions_app, motions, gain, cur_nimg, real_cano_tri, real_ref_tri,
                             motion_scale=1.0, swapping_prob=0.5, half_static=True):
        if self.G.rendering_kwargs.get('density_reg', 0) == 0:
            phase = {'Greg': 'none', 'Gboth': 'Gmain'}.get(phase, phase)
        if self.r1_gamma == 0:
            phase = {'Dreg': 'none', 'Dboth': 'Dmain'}.get(phase, phase)
        # if self.r1_gamma_patch == 0:
        #     phase = {'D_patchreg': 'none', 'D_patchboth': 'Dmain'}.get(phase, phase)
        blur_sigma = 0
        r1_gamma = self.r1_gamma
        # r1_gamma_patch = self.r1_gamma_patch
        # r1_gamma_uv = self.r1_gamma_uv
        # r1_gamma_seg = self.r1_gamma_seg
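        # Fade the neural rendering resolution from the initial to the final value once the discriminator
        # starts training (after discrimination_kimg). E.g. with initial=64, final=128 and alpha=0.5 the
        # expression below gives rint(64 * 0.5 + 128 * 0.5) = 96.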
        if self.neural_rendering_resolution_final is not None:
            alpha = min(
                max((cur_nimg - self.discrimination_kimg * 1e3) / (self.neural_rendering_resolution_fade_kimg * 1e3),
                    0), 1)  # begin fading when D starts to be optimized
            neural_rendering_resolution = int(np.rint(self.neural_rendering_resolution_initial * (
                    1 - alpha) + self.neural_rendering_resolution_final * alpha))
            neural_rendering_resolution_patch = self.neural_rendering_resolution_final
        else:
            neural_rendering_resolution = self.neural_rendering_resolution_initial
            neural_rendering_resolution_patch = neural_rendering_resolution
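        # Build the targets: the full-resolution image plus a low-pass resized copy matching the current
        # neural rendering resolution (used as the raw-image reconstruction target and, under dual
        # discrimination, as the second discriminator input).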
        if self.G.has_superresolution:
            real_img_raw = filtered_resizing(real_img_recon, size=neural_rendering_resolution, f=self.resample_filter,
                                             filter_mode=self.filter_mode)

            if self.blur_raw_target and blur_sigma > 0:
                blur_size = np.floor(blur_sigma * 3)
                if blur_size > 0:
                    f = torch.arange(-blur_size, blur_size + 1, device=real_img_raw.device).div(
                        blur_sigma).square().neg().exp2()
                    real_img_raw = upfirdn2d.filter2d(real_img_raw, f / f.sum())

            real_img = {'image_sr': real_img_recon, 'image': real_img_raw}
        else:
            real_img = {'image': real_img_recon}

        # Gmain: Maximize logits for generated images.
        if phase in ['Gmain', 'Gboth']:
            with torch.autograd.profiler.record_function('Gmain_forward'):
                gen_img = self.run_G(real_img_app, real_img_mot, motions_app, motions, real_c_recon, mesh,
                                     real_cano_tri, real_ref_tri,
                                     neural_rendering_resolution=neural_rendering_resolution,
                                     motion_scale=motion_scale, swapping_prob=swapping_prob, half_static=half_static)

                # main image-level reconstruction loss
                gen_img_recon = gen_img['image_sr']
                gen_img_recon_raw = gen_img['image']
                gen_depth = gen_img['image_depth']
                gen_feature = gen_img['image_feature']
                gen_triplane_recon = gen_img['triplane']

                loss_recon_lpips = self.lpips(gen_img_recon, real_img_recon) + self.lpips(gen_img_recon_raw,
                                                                                          real_img_raw)
                training_stats.report('Loss/G/lrecon_lpips', loss_recon_lpips)
                loss_recon_l1 = torch.abs(gen_img_recon - real_img_recon).mean() + torch.abs(
                    gen_img_recon_raw - real_img_raw).mean()
                training_stats.report('Loss/G/lrecon_l1', loss_recon_l1)

                # use id loss only after seeing 400k images
                if cur_nimg < 400 * 1e3:
                    loss_id = 0
                else:
                    loss_id = self.id_loss(gen_img_recon, real_img_recon)
                    training_stats.report('G_Loss/real/loss_id', loss_id)

                # depth loss (used before seeing 400k images; applied only when real_depth_recon is provided)
                if real_depth_recon is not None:
                    if real_depth_recon.shape != gen_depth.shape:
                        real_depth_recon = F.interpolate(real_depth_recon, size=[gen_depth.shape[2], gen_depth.shape[3]],
                                                         mode='bilinear', antialias=True)
                    loss_recon_depth = torch.abs(real_depth_recon - gen_depth).mean()
                    training_stats.report('Loss/G/lrecon_depth', loss_recon_depth)
                else:
                    loss_recon_depth = 0.

                # feature map loss (used before seeing 400k images; applied only when real_feature_recon is provided)
                if real_feature_recon is not None:
                    if real_feature_recon.shape != gen_feature.shape:
                        real_feature_recon = F.interpolate(real_feature_recon, size=[gen_feature.shape[2], gen_feature.shape[3]],
                                                           mode='bilinear', antialias=True)
                    loss_recon_feature = torch.abs(real_feature_recon - gen_feature).mean()
                    training_stats.report('Loss/G/lrecon_feature', loss_recon_feature)
                else:
                    loss_recon_feature = 0.

                # triplane feature loss (used before seeing 400k images; applied only when real_triplane_recon is provided)
                if real_triplane_recon is not None:
                    loss_recon_triplane = torch.abs(real_triplane_recon - gen_triplane_recon).mean()
                    training_stats.report('Loss/G/lrecon_triplane', loss_recon_triplane)
                else:
                    loss_recon_triplane = 0.
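                # total reconstruction objective: LPIPS + L1 at both resolutions, optional depth / feature
                # supervision, the triplane term down-weighted by 0.1, and the identity loss (zero for the
                # first 400k images)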
                loss_recon = loss_recon_lpips + loss_recon_l1 + loss_recon_depth + loss_recon_feature + loss_recon_triplane * 0.1 + loss_id

                # adversarial loss after the warm-up stage
                if cur_nimg >= self.discrimination_kimg * 1e3 and self.use_D:
                    gen_logits = self.run_D(gen_img, real_c_recon, blur_sigma=blur_sigma)
                    loss_Gmain = torch.nn.functional.softplus(-gen_logits)
                    training_stats.report('Loss/scores/fake', gen_logits)
                    training_stats.report('Loss/signs/fake', gen_logits.sign())
                    training_stats.report('Loss/G/loss', loss_Gmain)
                else:
                    loss_Gmain = None

            with torch.autograd.profiler.record_function('Gmain_backward'):
                loss_G = loss_recon.mean()
                if loss_Gmain is not None:
                    loss_G += loss_Gmain.mean() * self.gmain
                loss_G.mul(gain).backward()

        # Density Regularization
        if phase in ['Greg', 'Gboth'] and self.G.rendering_kwargs.get('density_reg', 0) > 0 and self.G.rendering_kwargs['reg_type'] == 'l1':
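            # Density regularization: sample 1000 points uniformly in [-1, 1]^3, jitter them by
            # density_reg_p_dist, and penalize the L1 difference between the densities predicted at the
            # original and perturbed points (reported as TVloss) to encourage locally smooth geometry.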
            initial_coordinates = torch.rand((real_c_recon.shape[0], 1000, 3), device=real_c_recon.device) * 2 - 1
            perturbed_coordinates = initial_coordinates + torch.randn_like(initial_coordinates) * \
                                    self.G.rendering_kwargs['density_reg_p_dist']
            all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1)
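            # The appearance/motion swapping and half-static handling below mirror run_G, so the sampled
            # densities are conditioned the same way as the rendered images.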
            motion_scale = torch.ones([real_img_app.shape[0], 1, 1], device=real_img_app.device) * motion_scale
            if swapping_prob is not None:
                real_img_app_swapped = real_img_mot
                prob = torch.rand((real_img_app.shape[0], 1), device=real_img_app.device)
                real_img_app_conditioning = torch.where(prob.reshape(real_img_app.shape[0], 1, 1, 1) < swapping_prob,
                                                        real_img_app_swapped, real_img_app)
                motion_scale_conditioning = torch.where(prob.reshape(real_img_app.shape[0], 1, 1) < swapping_prob,
                                                        torch.zeros_like(motion_scale), motion_scale)
                motions_app_conditioning = torch.where(prob < swapping_prob, motions, motions_app)
            else:
                real_img_app_conditioning = real_img_app
                motion_scale_conditioning = motion_scale
                motions_app_conditioning = motions_app

            if half_static:
                num_static = real_img_app.shape[0] // 2
                if swapping_prob is None:
                    motion_scale_conditioning = torch.cat([motion_scale[:num_static], motion_scale[num_static:] * 0],
                                                          dim=0)
                else:
                    prob = torch.rand((num_static, 1), device=real_img_app.device)
                    motion_scale_static = torch.where(prob.reshape(num_static, 1, 1) < 1 - swapping_prob,
                                                      torch.zeros_like(motion_scale[num_static:]),
                                                      motion_scale[num_static:])
                    motion_scale_conditioning = torch.cat([motion_scale_conditioning[:num_static], motion_scale_static],
                                                          dim=0)

            out = self.G.sample_mixed(real_img_app_conditioning, real_img_mot, mesh, self.w_avg, motions_app_conditioning,
                                      motions, all_coordinates, torch.randn_like(all_coordinates), real_cano_tri,
                                      real_ref_tri, motion_scale=motion_scale_conditioning)

            if isinstance(out, tuple):
                TVloss = 0
                for out_ in out:
                    sigma = out_['sigma'][:, :initial_coordinates.shape[1] * 2]
                    sigma_initial = sigma[:, :sigma.shape[1] // 2]
                    sigma_perturbed = sigma[:, sigma.shape[1] // 2:]
                    TVloss += torch.nn.functional.l1_loss(sigma_initial, sigma_perturbed) * self.G.rendering_kwargs[
                        'density_reg'] / len(out)
                training_stats.report('Loss/G/TVloss', TVloss)
            else:
                sigma = out['sigma'][:, :initial_coordinates.shape[1] * 2]
                sigma_initial = sigma[:, :sigma.shape[1] // 2]
                sigma_perturbed = sigma[:, sigma.shape[1] // 2:]
                TVloss = torch.nn.functional.l1_loss(sigma_initial, sigma_perturbed) * self.G.rendering_kwargs[
                    'density_reg']
                training_stats.report('Loss/G/TVloss', TVloss)

            (TVloss).mul(gain).backward()

        # Dmain: Minimize logits for generated images.
        if cur_nimg >= self.discrimination_kimg * 1e3 and self.use_D:
            loss_Dgen = 0
            if phase in ['Dmain', 'Dboth']:
                with torch.autograd.profiler.record_function('Dgen_forward'):
                    gen_img = self.run_G(real_img_app, real_img_mot, motions_app, motions, real_c_recon, mesh,
                                         real_cano_tri, real_ref_tri,
                                         neural_rendering_resolution=neural_rendering_resolution,
                                         motion_scale=motion_scale, swapping_prob=swapping_prob,
                                         half_static=half_static)
                    gen_logits = self.run_D(gen_img, real_c_recon, blur_sigma=blur_sigma, update_emas=True)
                    training_stats.report('Loss/scores/fake', gen_logits)
                    training_stats.report('Loss/signs/fake', gen_logits.sign())
                    loss_Dgen = torch.nn.functional.softplus(gen_logits)
                with torch.autograd.profiler.record_function('Dgen_backward'):
                    loss_Dgen.mean().mul(gain).backward()

            # Dmain: Maximize logits for real images.
            # Dr1: Apply R1 regularization.
            if phase in ['Dmain', 'Dreg', 'Dboth']:
                name = 'Dreal' if phase == 'Dmain' else 'Dr1' if phase == 'Dreg' else 'Dreal_Dr1'
                with torch.autograd.profiler.record_function(name + '_forward'):
                    real_img_tmp_image = real_img['image_sr'].detach().requires_grad_(phase in ['Dreg', 'Dboth'])
                    real_img_tmp_image_raw = real_img['image'].detach().requires_grad_(phase in ['Dreg', 'Dboth'])
                    real_img_tmp = {'image_sr': real_img_tmp_image, 'image': real_img_tmp_image_raw}

                    real_logits = self.run_D(real_img_tmp, real_c_recon, blur_sigma=blur_sigma)
                    training_stats.report('Loss/scores/real', real_logits)
                    training_stats.report('Loss/signs/real', real_logits.sign())

                    loss_Dreal = 0
                    if phase in ['Dmain', 'Dboth']:
                        loss_Dreal = torch.nn.functional.softplus(-real_logits)
                        training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal)

                    loss_Dr1 = 0
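                    # R1 regularization: penalize the squared gradient norm of the real logits w.r.t. the
                    # real image(s), weighted by r1_gamma / 2 (both inputs under dual discrimination).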
                    if phase in ['Dreg', 'Dboth']:
                        if self.dual_discrimination:
                            with torch.autograd.profiler.record_function(
                                    'r1_grads'), conv2d_gradfix.no_weight_gradients():
                                r1_grads = torch.autograd.grad(outputs=[real_logits.sum()],
                                                               inputs=[real_img_tmp['image_sr'],
                                                                       real_img_tmp['image']], create_graph=True,
                                                               only_inputs=True)
                                r1_grads_image = r1_grads[0]
                                r1_grads_image_raw = r1_grads[1]
                            r1_penalty = r1_grads_image.square().sum([1, 2, 3]) + r1_grads_image_raw.square().sum(
                                [1, 2, 3])
                        else:  # single discrimination
                            with torch.autograd.profiler.record_function(
                                    'r1_grads'), conv2d_gradfix.no_weight_gradients():
                                if self.G.has_superresolution:
                                    r1_grads = torch.autograd.grad(outputs=[real_logits.sum()],
                                                                   inputs=[real_img_tmp['image_sr']], create_graph=True,
                                                                   only_inputs=True)
                                else:
                                    r1_grads = torch.autograd.grad(outputs=[real_logits.sum()],
                                                                   inputs=[real_img_tmp['image']], create_graph=True,
                                                                   only_inputs=True)
                                r1_grads_image = r1_grads[0]
                            r1_penalty = r1_grads_image.square().sum([1, 2, 3])
                        loss_Dr1 = r1_penalty * (r1_gamma / 2)
                        training_stats.report('Loss/r1_penalty', r1_penalty)
                        training_stats.report('Loss/D/reg', loss_Dr1)

                with torch.autograd.profiler.record_function(name + '_backward'):
                    (loss_Dreal + loss_Dr1).mean().mul(gain).backward()
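
# ----------------------------------------------------------------------------
# A minimal sketch of how a training loop is assumed to drive this loss (names such as `loss`, `opt_G`,
# `opt_D` and the batch fields are placeholders for illustration, not part of this file):
#
#   for phase, opt in [('Gmain', opt_G), ('Greg', opt_G), ('Dmain', opt_D), ('Dreg', opt_D)]:
#       opt.zero_grad(set_to_none=True)
#       loss.accumulate_gradients(phase, batch['img_app'], batch['img_mot'], batch['img_recon'],
#                                 batch['depth_recon'], batch['feature_recon'], batch['triplane_recon'],
#                                 batch['c_recon'], batch['mesh'], batch['motions_app'], batch['motions'],
#                                 gain=1.0, cur_nimg=cur_nimg, real_cano_tri=batch['cano_tri'],
#                                 real_ref_tri=batch['ref_tri'])
#       opt.step()
# ----------------------------------------------------------------------------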