import functools
import torch
import tops
from tops import logger
from dp2.utils import forward_D_fake
from .utils import nsgan_d_loss, nsgan_g_loss
from .r1_regularization import r1_regularization
from .pl_regularization import PLRegularization


class StyleGAN2Loss:
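    """Non-saturating StyleGAN2-style loss with epsilon penalty, lazy R1
    gradient penalty, and optional path length (PL) regularization."""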
    def __init__(
        self,
        D,
        G,
        r1_opts: dict,
        EP_lambd: float,
        lazy_reg_interval: int,
        lazy_regularization: bool,
        pl_reg_opts: dict,
    ) -> None:
        self.gradient_step_D = 0
        self._lazy_reg_interval = lazy_reg_interval
        self.D = D
        self.G = G
        self.EP_lambd = EP_lambd
        self.lazy_regularization = lazy_regularization
        self.r1_reg = functools.partial(
            r1_regularization, **r1_opts, lazy_reg_interval=lazy_reg_interval,
            lazy_regularization=lazy_regularization)
        self.do_PL_Reg = False
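        # pl_reg_opts is accessed with attribute syntax below, so it is assumed to be
        # an attribute-style config (e.g. EasyDict/OmegaConf) rather than a plain dict.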
        if pl_reg_opts.weight > 0:
            self.pl_reg = PLRegularization(**pl_reg_opts)
            self.do_PL_Reg = True
            self.pl_start_nimg = pl_reg_opts.start_nimg

    def D_loss(self, batch: dict, grad_scaler):
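        """Computes the discriminator loss for one batch.

        Returns the scalar loss to backpropagate and a dict of detached scalars for logging.
        """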
        to_log = {}
        # Forward through G and D
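        # Lazy regularization: the R1 penalty is only computed every
        # _lazy_reg_interval discriminator steps. On those steps the real images
        # need gradients so that D's real scores can be differentiated w.r.t. them.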
        do_GP = self.lazy_regularization and self.gradient_step_D % self._lazy_reg_interval == 0
        if do_GP:
            batch["img"] = batch["img"].detach().requires_grad_(True)
        with torch.cuda.amp.autocast(enabled=tops.AMP()):
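            # G runs without gradients during the D step; update_emas=True is
            # assumed to update the generator's EMA statistics on this forward pass.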
            with torch.no_grad():
                G_fake = self.G(**batch, update_emas=True)
            D_out_real = self.D(**batch)
            D_out_fake = forward_D_fake(batch, G_fake["img"], self.D)

            # Non-saturating loss
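            # nsgan_d_loss is expected to implement the softplus-based logistic loss
            # used by StyleGAN2: softplus(fake_score) + softplus(-real_score).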
            nsgan_loss = nsgan_d_loss(D_out_real["score"], D_out_fake["score"])
            tops.assert_shape(nsgan_loss, (batch["img"].shape[0], ))
            to_log["d_loss"] = nsgan_loss.mean()
            total_loss = nsgan_loss
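            # Epsilon (drift) penalty as in Progressive GAN: penalize the squared
            # real logits, weighted by EP_lambd, to keep D's output from drifting.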
            epsilon_penalty = D_out_real["score"].pow(2).view(-1)
            to_log["epsilon_penalty"] = epsilon_penalty.mean()
            tops.assert_shape(epsilon_penalty, total_loss.shape)
            total_loss = total_loss + epsilon_penalty * self.EP_lambd

        # Improved gradient penalty with lazy regularization.
        # The gradient penalty applies its own specialized autocast.
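        # self.r1_reg is assumed to return (penalty, unscaled gradient term): the R1
        # penalty on real images, scaled to compensate for the lazy regularization interval.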
        if do_GP:
            gradient_pen, grad_unscaled = self.r1_reg(
                batch["img"], D_out_real["score"], batch["mask"], scaler=grad_scaler)
            to_log["r1_gradient_penalty"] = grad_unscaled.mean()
            tops.assert_shape(gradient_pen, total_loss.shape)
            total_loss = total_loss + gradient_pen
            batch["img"] = batch["img"].detach().requires_grad_(False)
if "score" in D_out_real:
to_log["real_scores"] = D_out_real["score"]
to_log["real_logits_sign"] = D_out_real["score"].sign()
to_log["fake_logits_sign"] = D_out_fake["score"].sign()
to_log["fake_scores"] = D_out_fake["score"]
to_log = {key: item.mean().detach() for key, item in to_log.items()}
self.gradient_step_D += 1
return total_loss.mean(), to_log

    def G_loss(self, batch: dict, grad_scaler):
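        """Computes the generator loss for one batch.

        Returns the scalar loss to backpropagate and a dict of detached scalars for logging.
        """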
        with torch.cuda.amp.autocast(enabled=tops.AMP()):
            to_log = {}
            # Forward through G and D
            G_fake = self.G(**batch)
            D_out_fake = forward_D_fake(batch, G_fake["img"], self.D)
            # Adversarial Loss
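            # nsgan_g_loss is expected to be the non-saturating generator loss,
            # softplus(-fake_score), as used in StyleGAN2.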
            total_loss = nsgan_g_loss(D_out_fake["score"]).view(-1)
            to_log["g_loss"] = total_loss.mean()
            tops.assert_shape(total_loss, (batch["img"].shape[0], ))
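            # Path length regularization (StyleGAN2), enabled once training passes
            # pl_start_nimg; logger.global_step() is assumed to count images seen.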
            if self.do_PL_Reg and logger.global_step() >= self.pl_start_nimg:
                pl_reg, to_log_ = self.pl_reg(self.G, batch, grad_scaler=grad_scaler)
                total_loss = total_loss + pl_reg.mean()
                to_log.update(to_log_)
        to_log = {key: item.mean().detach() for key, item in to_log.items()}
        return total_loss.mean(), to_log
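

# Example usage (sketch): a typical training loop would drive the two losses roughly
# as below. `D`, `G`, `cfg`, `dataloader`, `D_optim` and `G_optim` are placeholders
# assumed to be constructed elsewhere.
#
#   scaler = torch.cuda.amp.GradScaler()
#   loss_fnc = StyleGAN2Loss(D=D, G=G, **cfg.loss)
#   for batch in dataloader:
#       D_optim.zero_grad()
#       d_loss, d_log = loss_fnc.D_loss(batch, grad_scaler=scaler)
#       scaler.scale(d_loss).backward()
#       scaler.step(D_optim)
#       G_optim.zero_grad()
#       g_loss, g_log = loss_fnc.G_loss(batch, grad_scaler=scaler)
#       scaler.scale(g_loss).backward()
#       scaler.step(G_optim)
#       scaler.update()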