import torch.nn as nn
from lib.config import cfg
import torch
from lib.networks.renderer import if_clight_renderer
from lib.train import make_optimizer


def _img2mse(x, y):
    """Return the mean squared error between tensors *x* and *y*."""
    return torch.mean((x - y) ** 2)


class NetworkWrapper(nn.Module):
    """Wrap a network with its renderer and compute the training loss.

    The wrapper builds an ``if_clight_renderer.Renderer`` around ``net``
    and, in ``forward``, turns the renderer output into a scalar loss plus
    a dictionary of individual loss terms.
    """

    def __init__(self, net):
        """Store *net* and build the renderer and loss callables.

        Args:
            net: The network handed to ``if_clight_renderer.Renderer``.
        """
        super(NetworkWrapper, self).__init__()

        self.net = net
        self.renderer = if_clight_renderer.Renderer(self.net)

        # A module-level function instead of a lambda: lambdas cannot be
        # pickled, which would prevent pickling this wrapper as a whole.
        self.img2mse = _img2mse
        # Not used by ``forward``; kept as a public attribute.
        self.acc_crit = torch.nn.functional.smooth_l1_loss

    def forward(self, batch):
        """Render *batch* and compute the image reconstruction loss.

        Args:
            batch: Mapping that is passed to the renderer and must provide
                ``'rgb'`` (target colours) and ``'mask_at_box'`` (index
                used to select entries of both the prediction and the
                target; presumably a boolean mask -- confirm with the
                dataset code).

        Returns:
            A 4-tuple ``(ret, loss, scalar_stats, image_stats)``:
            ``ret`` is the renderer output, ``loss`` the summed loss,
            ``scalar_stats`` a dict with ``'img_loss'``, optionally
            ``'img_loss0'``, and ``'loss'``, and ``image_stats`` an
            (always empty) dict.
        """
        ret = self.renderer.render(batch)

        scalar_stats = {}
        loss = 0

        # Main image loss, restricted to the entries selected by the mask.
        mask = batch['mask_at_box']
        img_loss = self.img2mse(ret['rgb_map'][mask], batch['rgb'][mask])
        scalar_stats.update({'img_loss': img_loss})
        loss = loss + img_loss

        if 'rgb0' in ret:
            # NOTE(review): unlike the loss above, this term is computed
            # without applying ``mask_at_box`` -- confirm this is intended.
            img_loss0 = self.img2mse(ret['rgb0'], batch['rgb'])
            scalar_stats.update({'img_loss0': img_loss0})
            # Out-of-place add: avoids mutating a tensor that is part of
            # the autograd graph.
            loss = loss + img_loss0

        scalar_stats.update({'loss': loss})

        image_stats = {}

        return ret, loss, scalar_stats, image_stats