Spaces:
Build error
Build error
Update apps/Normal.py
Browse files- apps/Normal.py +31 -38
apps/Normal.py
CHANGED
|
@@ -10,13 +10,9 @@ import pytorch_lightning as pl
|
|
| 10 |
torch.backends.cudnn.benchmark = True
|
| 11 |
|
| 12 |
logging.getLogger("lightning").setLevel(logging.ERROR)
|
| 13 |
-
import warnings
|
| 14 |
-
|
| 15 |
-
warnings.filterwarnings("ignore")
|
| 16 |
|
| 17 |
|
| 18 |
class Normal(pl.LightningModule):
|
| 19 |
-
|
| 20 |
def __init__(self, cfg):
|
| 21 |
super(Normal, self).__init__()
|
| 22 |
self.cfg = cfg
|
|
@@ -42,28 +38,26 @@ class Normal(pl.LightningModule):
|
|
| 42 |
weight_decay = self.cfg.weight_decay
|
| 43 |
momentum = self.cfg.momentum
|
| 44 |
|
| 45 |
-
optim_params_N_F = [
|
| 46 |
-
"params": self.netG.netF.parameters(),
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
optim_params_N_B = [{
|
| 50 |
-
"params": self.netG.netB.parameters(),
|
| 51 |
-
"lr": self.lr_N
|
| 52 |
-
}]
|
| 53 |
|
| 54 |
-
optimizer_N_F = torch.optim.Adam(
|
| 55 |
-
|
| 56 |
-
|
| 57 |
|
| 58 |
-
optimizer_N_B = torch.optim.Adam(
|
| 59 |
-
|
| 60 |
-
|
| 61 |
|
| 62 |
scheduler_N_F = torch.optim.lr_scheduler.MultiStepLR(
|
| 63 |
-
optimizer_N_F, milestones=self.cfg.schedule, gamma=self.cfg.gamma
|
|
|
|
| 64 |
|
| 65 |
scheduler_N_B = torch.optim.lr_scheduler.MultiStepLR(
|
| 66 |
-
optimizer_N_B, milestones=self.cfg.schedule, gamma=self.cfg.gamma
|
|
|
|
| 67 |
|
| 68 |
self.schedulers = [scheduler_N_F, scheduler_N_B]
|
| 69 |
optims = [optimizer_N_F, optimizer_N_B]
|
|
@@ -78,11 +72,13 @@ class Normal(pl.LightningModule):
|
|
| 78 |
for name in render_tensor.keys():
|
| 79 |
result_list.append(
|
| 80 |
resize(
|
| 81 |
-
((render_tensor[name].cpu().numpy()[0] + 1.0) /
|
| 82 |
-
|
|
|
|
| 83 |
(height, height),
|
| 84 |
anti_aliasing=True,
|
| 85 |
-
)
|
|
|
|
| 86 |
result_array = np.concatenate(result_list, axis=1)
|
| 87 |
|
| 88 |
return result_array
|
|
@@ -96,16 +92,14 @@ class Normal(pl.LightningModule):
|
|
| 96 |
for name in self.in_nml:
|
| 97 |
in_tensor[name] = batch[name]
|
| 98 |
|
| 99 |
-
FB_tensor = {
|
| 100 |
-
|
| 101 |
-
"normal_B": batch["normal_B"]
|
| 102 |
-
}
|
| 103 |
|
| 104 |
self.netG.train()
|
| 105 |
|
| 106 |
preds_F, preds_B = self.netG(in_tensor)
|
| 107 |
-
error_NF, error_NB = self.netG.get_norm_error(
|
| 108 |
-
|
| 109 |
|
| 110 |
(opt_nf, opt_nb) = self.optimizers()
|
| 111 |
|
|
@@ -175,19 +169,18 @@ class Normal(pl.LightningModule):
|
|
| 175 |
for name in self.in_nml:
|
| 176 |
in_tensor[name] = batch[name]
|
| 177 |
|
| 178 |
-
FB_tensor = {
|
| 179 |
-
|
| 180 |
-
"normal_B": batch["normal_B"]
|
| 181 |
-
}
|
| 182 |
|
| 183 |
self.netG.train()
|
| 184 |
|
| 185 |
preds_F, preds_B = self.netG(in_tensor)
|
| 186 |
-
error_NF, error_NB = self.netG.get_norm_error(
|
| 187 |
-
|
| 188 |
|
| 189 |
-
if (batch_idx > 0 and batch_idx % int(self.cfg.freq_show_train)
|
| 190 |
-
|
|
|
|
| 191 |
|
| 192 |
with torch.no_grad():
|
| 193 |
nmlF, nmlB = self.netG(in_tensor)
|
|
@@ -217,4 +210,4 @@ class Normal(pl.LightningModule):
|
|
| 217 |
|
| 218 |
tf_log = tf_log_convert(metrics_log)
|
| 219 |
|
| 220 |
-
return {"log": tf_log}
|
|
|
|
| 10 |
torch.backends.cudnn.benchmark = True
|
| 11 |
|
| 12 |
logging.getLogger("lightning").setLevel(logging.ERROR)
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
class Normal(pl.LightningModule):
|
|
|
|
| 16 |
def __init__(self, cfg):
|
| 17 |
super(Normal, self).__init__()
|
| 18 |
self.cfg = cfg
|
|
|
|
| 38 |
weight_decay = self.cfg.weight_decay
|
| 39 |
momentum = self.cfg.momentum
|
| 40 |
|
| 41 |
+
optim_params_N_F = [
|
| 42 |
+
{"params": self.netG.netF.parameters(), "lr": self.lr_N}]
|
| 43 |
+
optim_params_N_B = [
|
| 44 |
+
{"params": self.netG.netB.parameters(), "lr": self.lr_N}]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
+
optimizer_N_F = torch.optim.Adam(
|
| 47 |
+
optim_params_N_F, lr=self.lr_N, weight_decay=weight_decay
|
| 48 |
+
)
|
| 49 |
|
| 50 |
+
optimizer_N_B = torch.optim.Adam(
|
| 51 |
+
optim_params_N_B, lr=self.lr_N, weight_decay=weight_decay
|
| 52 |
+
)
|
| 53 |
|
| 54 |
scheduler_N_F = torch.optim.lr_scheduler.MultiStepLR(
|
| 55 |
+
optimizer_N_F, milestones=self.cfg.schedule, gamma=self.cfg.gamma
|
| 56 |
+
)
|
| 57 |
|
| 58 |
scheduler_N_B = torch.optim.lr_scheduler.MultiStepLR(
|
| 59 |
+
optimizer_N_B, milestones=self.cfg.schedule, gamma=self.cfg.gamma
|
| 60 |
+
)
|
| 61 |
|
| 62 |
self.schedulers = [scheduler_N_F, scheduler_N_B]
|
| 63 |
optims = [optimizer_N_F, optimizer_N_B]
|
|
|
|
| 72 |
for name in render_tensor.keys():
|
| 73 |
result_list.append(
|
| 74 |
resize(
|
| 75 |
+
((render_tensor[name].cpu().numpy()[0] + 1.0) / 2.0).transpose(
|
| 76 |
+
1, 2, 0
|
| 77 |
+
),
|
| 78 |
(height, height),
|
| 79 |
anti_aliasing=True,
|
| 80 |
+
)
|
| 81 |
+
)
|
| 82 |
result_array = np.concatenate(result_list, axis=1)
|
| 83 |
|
| 84 |
return result_array
|
|
|
|
| 92 |
for name in self.in_nml:
|
| 93 |
in_tensor[name] = batch[name]
|
| 94 |
|
| 95 |
+
FB_tensor = {"normal_F": batch["normal_F"],
|
| 96 |
+
"normal_B": batch["normal_B"]}
|
|
|
|
|
|
|
| 97 |
|
| 98 |
self.netG.train()
|
| 99 |
|
| 100 |
preds_F, preds_B = self.netG(in_tensor)
|
| 101 |
+
error_NF, error_NB = self.netG.get_norm_error(
|
| 102 |
+
preds_F, preds_B, FB_tensor)
|
| 103 |
|
| 104 |
(opt_nf, opt_nb) = self.optimizers()
|
| 105 |
|
|
|
|
| 169 |
for name in self.in_nml:
|
| 170 |
in_tensor[name] = batch[name]
|
| 171 |
|
| 172 |
+
FB_tensor = {"normal_F": batch["normal_F"],
|
| 173 |
+
"normal_B": batch["normal_B"]}
|
|
|
|
|
|
|
| 174 |
|
| 175 |
self.netG.train()
|
| 176 |
|
| 177 |
preds_F, preds_B = self.netG(in_tensor)
|
| 178 |
+
error_NF, error_NB = self.netG.get_norm_error(
|
| 179 |
+
preds_F, preds_B, FB_tensor)
|
| 180 |
|
| 181 |
+
if (batch_idx > 0 and batch_idx % int(self.cfg.freq_show_train) == 0) or (
|
| 182 |
+
batch_idx == 0
|
| 183 |
+
):
|
| 184 |
|
| 185 |
with torch.no_grad():
|
| 186 |
nmlF, nmlB = self.netG(in_tensor)
|
|
|
|
| 210 |
|
| 211 |
tf_log = tf_log_convert(metrics_log)
|
| 212 |
|
| 213 |
+
return {"log": tf_log}
|