import os import time import logging import argparse import numpy as np import cv2 as cv import trimesh import torch import torch.nn.functional as F from torch.utils.tensorboard import SummaryWriter from shutil import copyfile from icecream import ic from tqdm import tqdm from pyhocon import ConfigFactory from models.dataset_mvdiff import Dataset from models.fields import RenderingNetwork, SDFNetwork, SingleVarianceNetwork, NeRF from models.renderer import NeuSRenderer import pdb import math def ranking_loss(error, penalize_ratio=0.7, type='mean'): error, indices = torch.sort(error) # only sum relatively small errors s_error = torch.index_select(error, 0, index=indices[: int(penalize_ratio * indices.shape[0])]) if type == 'mean': return torch.mean(s_error) elif type == 'sum': return torch.sum(s_error) class Runner: def __init__(self, conf_path, mode='train', case='CASE_NAME', is_continue=False, data_dir=None): self.device = torch.device('cuda') # Configuration self.conf_path = conf_path f = open(self.conf_path) conf_text = f.read() conf_text = conf_text.replace('CASE_NAME', case) f.close() self.conf = ConfigFactory.parse_string(conf_text) self.conf['dataset']['data_dir'] = data_dir self.conf['dataset.data_dir'] = self.conf['dataset.data_dir'].replace('CASE_NAME', case) self.base_exp_dir = self.conf['general.base_exp_dir'] os.makedirs(self.base_exp_dir, exist_ok=True) self.dataset = Dataset(self.conf['dataset']) self.dataloader = torch.utils.data.DataLoader( self.dataset, batch_size=self.conf['train']['batch_size'], shuffle=True, num_workers=64, ) self.iter_step = 1 # Training parameters self.end_iter = self.conf.get_int('train.end_iter') self.save_freq = self.conf.get_int('train.save_freq') self.report_freq = self.conf.get_int('train.report_freq') self.val_freq = self.conf.get_int('train.val_freq') self.val_mesh_freq = self.conf.get_int('train.val_mesh_freq') self.batch_size = self.conf.get_int('train.batch_size') self.validate_resolution_level = self.conf.get_int('train.validate_resolution_level') self.learning_rate = self.conf.get_float('train.learning_rate') self.learning_rate_alpha = self.conf.get_float('train.learning_rate_alpha') self.use_white_bkgd = self.conf.get_bool('train.use_white_bkgd') self.warm_up_end = self.conf.get_float('train.warm_up_end', default=0.0) self.anneal_end = self.conf.get_float('train.anneal_end', default=0.0) # Weights self.color_weight = self.conf.get_float('train.color_weight') self.igr_weight = self.conf.get_float('train.igr_weight') self.mask_weight = self.conf.get_float('train.mask_weight') self.normal_weight = self.conf.get_float('train.normal_weight') self.sparse_weight = self.conf.get_float('train.sparse_weight') self.is_continue = is_continue self.mode = mode self.model_list = [] self.writer = None # Networks params_to_train_slow = [] self.nerf_outside = NeRF(**self.conf['model.nerf']).to(self.device) self.sdf_network = SDFNetwork(**self.conf['model.sdf_network']).to(self.device) self.deviation_network = SingleVarianceNetwork(**self.conf['model.variance_network']).to(self.device) self.color_network = RenderingNetwork(**self.conf['model.rendering_network']).to(self.device) # params_to_train += list(self.nerf_outside.parameters()) params_to_train_slow += list(self.sdf_network.parameters()) params_to_train_slow += list(self.deviation_network.parameters()) # params_to_train += list(self.color_network.parameters()) self.optimizer = torch.optim.Adam( [{'params': params_to_train_slow}, {'params': self.color_network.parameters(), 'lr': self.learning_rate * 2}], lr=self.learning_rate ) self.renderer = NeuSRenderer( self.nerf_outside, self.sdf_network, self.deviation_network, self.color_network, **self.conf['model.neus_renderer'] ) # Load checkpoint latest_model_name = None if is_continue: model_list_raw = os.listdir(os.path.join(self.base_exp_dir, 'checkpoints')) model_list = [] for model_name in model_list_raw: if model_name[-3:] == 'pth' and int(model_name[5:-4]) <= self.end_iter: model_list.append(model_name) model_list.sort() latest_model_name = model_list[-1] if latest_model_name is not None: logging.info('Find checkpoint: {}'.format(latest_model_name)) self.load_checkpoint(latest_model_name) # Backup codes and configs for debug if self.mode[:5] == 'train': self.file_backup() def train(self): self.writer = SummaryWriter(log_dir=os.path.join(self.base_exp_dir, 'logs')) self.update_learning_rate() res_step = self.end_iter - self.iter_step image_perm = self.get_image_perm() num_train_epochs = math.ceil(res_step / len(self.dataloader)) print("training ", num_train_epochs, " epoches") for epoch in range(num_train_epochs): # for iter_i in tqdm(range(res_step)): print("epoch ", epoch) for iter_i, data in enumerate(self.dataloader): # img_idx = image_perm[self.iter_step % len(image_perm)] # data = self.dataset.gen_random_rays_at(img_idx, self.batch_size) data = data.cuda() rays_o, rays_d, true_rgb, mask, true_normal, cosines = ( data[:, :3], data[:, 3:6], data[:, 6:9], data[:, 9:10], data[:, 10:13], data[:, 13:], ) # near, far = self.dataset.near_far_from_sphere(rays_o, rays_d) near, far = self.dataset.get_near_far() background_rgb = None if self.use_white_bkgd: background_rgb = torch.ones([1, 3]) if self.mask_weight > 0.0: mask = (mask > 0.5).float() else: mask = torch.ones_like(mask) cosines[cosines > -0.1] = 0 mask = ((mask > 0) & (cosines < -0.1)).to(torch.float32) mask_sum = mask.sum() + 1e-5 render_out = self.renderer.render( rays_o, rays_d, near, far, background_rgb=background_rgb, cos_anneal_ratio=self.get_cos_anneal_ratio() ) color_fine = render_out['color_fine'] s_val = render_out['s_val'] cdf_fine = render_out['cdf_fine'] gradient_error = render_out['gradient_error'] weight_max = render_out['weight_max'] weight_sum = render_out['weight_sum'] # Loss # color_error = (color_fine - true_rgb) * mask # color_fine_loss = F.l1_loss(color_error, torch.zeros_like(color_error), reduction='sum') / mask_sum color_errors = (color_fine - true_rgb).abs().sum(dim=1) color_fine_loss = ranking_loss(color_errors[mask[:, 0] > 0]) psnr = 20.0 * torch.log10(1.0 / (((color_fine - true_rgb) ** 2 * mask).sum() / (mask_sum * 3.0)).sqrt()) eikonal_loss = gradient_error # pdb.set_trace() mask_errors = F.binary_cross_entropy(weight_sum.clip(1e-3, 1.0 - 1e-3), mask, reduction='none') mask_loss = ranking_loss(mask_errors[:, 0], penalize_ratio=0.8) def feasible(key): return (key in render_out) and (render_out[key] is not None) # calculate normal loss n_samples = self.renderer.n_samples + self.renderer.n_importance normals = render_out['gradients'] * render_out['weights'][:, :n_samples, None] if feasible('inside_sphere'): normals = normals * render_out['inside_sphere'][..., None] normals = normals.sum(dim=1) # pdb.set_trace() normal_errors = 1 - F.cosine_similarity(normals, true_normal, dim=1) # normal_error = normal_error * mask[:, 0] # normal_loss = F.l1_loss(normal_error, torch.zeros_like(normal_error), reduction='sum') / mask_sum normal_errors = normal_errors * torch.exp(cosines.abs()[:, 0]) / torch.exp(cosines.abs()).sum() normal_loss = ranking_loss(normal_errors[mask[:, 0] > 0], penalize_ratio=0.9, type='sum') sparse_loss = render_out['sparse_loss'] loss = ( color_fine_loss * self.color_weight + eikonal_loss * self.igr_weight + sparse_loss * self.sparse_weight + mask_loss * self.mask_weight + normal_loss * self.normal_weight ) self.optimizer.zero_grad() loss.backward() self.optimizer.step() self.writer.add_scalar('Loss/loss', loss, self.iter_step) self.writer.add_scalar('Loss/color_loss', color_fine_loss, self.iter_step) self.writer.add_scalar('Loss/eikonal_loss', eikonal_loss, self.iter_step) self.writer.add_scalar('Statistics/s_val', s_val.mean(), self.iter_step) self.writer.add_scalar('Statistics/cdf', (cdf_fine[:, :1] * mask).sum() / mask_sum, self.iter_step) self.writer.add_scalar('Statistics/weight_max', (weight_max * mask).sum() / mask_sum, self.iter_step) self.writer.add_scalar('Statistics/psnr', psnr, self.iter_step) if self.iter_step % self.report_freq == 0: print(self.base_exp_dir) print( 'iter:{:8>d} loss = {:4>f} color_ls = {:4>f} eik_ls = {:4>f} normal_ls = {:4>f} mask_ls = {:4>f} sparse_ls = {:4>f} lr={:5>f}'.format( self.iter_step, loss, color_fine_loss, eikonal_loss, normal_loss, mask_loss, sparse_loss, self.optimizer.param_groups[0]['lr'], ) ) print('iter:{:8>d} s_val = {:4>f}'.format(self.iter_step, s_val.mean())) if self.iter_step % self.val_mesh_freq == 0: self.validate_mesh(resolution=256) self.update_learning_rate() self.iter_step += 1 if self.iter_step % self.val_freq == 0: self.validate_image(idx=0) self.validate_image(idx=1) self.validate_image(idx=2) self.validate_image(idx=3) if self.iter_step % self.save_freq == 0: self.save_checkpoint() if self.iter_step % len(image_perm) == 0: image_perm = self.get_image_perm() def get_image_perm(self): return torch.randperm(self.dataset.n_images) def get_cos_anneal_ratio(self): if self.anneal_end == 0.0: return 1.0 else: return np.min([1.0, self.iter_step / self.anneal_end]) def update_learning_rate(self): if self.iter_step < self.warm_up_end: learning_factor = self.iter_step / self.warm_up_end else: alpha = self.learning_rate_alpha progress = (self.iter_step - self.warm_up_end) / (self.end_iter - self.warm_up_end) learning_factor = (np.cos(np.pi * progress) + 1.0) * 0.5 * (1 - alpha) + alpha for g in self.optimizer.param_groups: g['lr'] = self.learning_rate * learning_factor def file_backup(self): dir_lis = self.conf['general.recording'] os.makedirs(os.path.join(self.base_exp_dir, 'recording'), exist_ok=True) for dir_name in dir_lis: cur_dir = os.path.join(self.base_exp_dir, 'recording', dir_name) os.makedirs(cur_dir, exist_ok=True) files = os.listdir(dir_name) for f_name in files: if f_name[-3:] == '.py': copyfile(os.path.join(dir_name, f_name), os.path.join(cur_dir, f_name)) copyfile(self.conf_path, os.path.join(self.base_exp_dir, 'recording', 'config.conf')) def load_checkpoint(self, checkpoint_name): checkpoint = torch.load(os.path.join(self.base_exp_dir, 'checkpoints', checkpoint_name), map_location=self.device) self.nerf_outside.load_state_dict(checkpoint['nerf']) self.sdf_network.load_state_dict(checkpoint['sdf_network_fine']) self.deviation_network.load_state_dict(checkpoint['variance_network_fine']) self.color_network.load_state_dict(checkpoint['color_network_fine']) self.optimizer.load_state_dict(checkpoint['optimizer']) self.iter_step = checkpoint['iter_step'] logging.info('End') def save_checkpoint(self): checkpoint = { 'nerf': self.nerf_outside.state_dict(), 'sdf_network_fine': self.sdf_network.state_dict(), 'variance_network_fine': self.deviation_network.state_dict(), 'color_network_fine': self.color_network.state_dict(), 'optimizer': self.optimizer.state_dict(), 'iter_step': self.iter_step, } os.makedirs(os.path.join(self.base_exp_dir, 'checkpoints'), exist_ok=True) torch.save(checkpoint, os.path.join(self.base_exp_dir, 'checkpoints', 'ckpt_{:0>6d}.pth'.format(self.iter_step))) def validate_image(self, idx=-1, resolution_level=-1): if idx < 0: idx = np.random.randint(self.dataset.n_images) print('Validate: iter: {}, camera: {}'.format(self.iter_step, idx)) if resolution_level < 0: resolution_level = self.validate_resolution_level rays_o, rays_d = self.dataset.gen_rays_at(idx, resolution_level=resolution_level) H, W, _ = rays_o.shape rays_o = rays_o.reshape(-1, 3).split(self.batch_size) rays_d = rays_d.reshape(-1, 3).split(self.batch_size) out_rgb_fine = [] out_normal_fine = [] out_mask = [] for rays_o_batch, rays_d_batch in zip(rays_o, rays_d): # near, far = self.dataset.near_far_from_sphere(rays_o_batch, rays_d_batch) near, far = self.dataset.get_near_far() background_rgb = torch.ones([1, 3]) if self.use_white_bkgd else None render_out = self.renderer.render( rays_o_batch, rays_d_batch, near, far, cos_anneal_ratio=self.get_cos_anneal_ratio(), background_rgb=background_rgb ) def feasible(key): return (key in render_out) and (render_out[key] is not None) if feasible('color_fine'): out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy()) if feasible('gradients') and feasible('weights'): n_samples = self.renderer.n_samples + self.renderer.n_importance normals = render_out['gradients'] * render_out['weights'][:, :n_samples, None] if feasible('inside_sphere'): normals = normals * render_out['inside_sphere'][..., None] normals = normals.sum(dim=1).detach().cpu().numpy() out_normal_fine.append(normals) if feasible('weight_sum'): out_mask.append(render_out['weight_sum'].detach().clip(0, 1).cpu().numpy()) del render_out img_fine = None if len(out_rgb_fine) > 0: img_fine = (np.concatenate(out_rgb_fine, axis=0).reshape([H, W, 3, -1]) * 256).clip(0, 255) mask_map = None if len(out_mask) > 0: mask_map = (np.concatenate(out_mask, axis=0).reshape([H, W, -1]) * 256).clip(0, 255) normal_img = None if len(out_normal_fine) > 0: normal_img = np.concatenate(out_normal_fine, axis=0) rot = np.linalg.inv(self.dataset.pose_all[idx, :3, :3].detach().cpu().numpy()) normal_img = (np.matmul(rot[None, :, :], normal_img[:, :, None]).reshape([H, W, 3, -1]) * 128 + 128).clip(0, 255) os.makedirs(os.path.join(self.base_exp_dir, 'validations_fine'), exist_ok=True) os.makedirs(os.path.join(self.base_exp_dir, 'normals'), exist_ok=True) for i in range(img_fine.shape[-1]): if len(out_rgb_fine) > 0: cv.imwrite( os.path.join(self.base_exp_dir, 'validations_fine', '{:0>8d}_{}_{}.png'.format(self.iter_step, i, idx)), np.concatenate( [ img_fine[..., i], self.dataset.image_at(idx, resolution_level=resolution_level), self.dataset.mask_at(idx, resolution_level=resolution_level), ] ), ) if len(out_normal_fine) > 0: cv.imwrite( os.path.join(self.base_exp_dir, 'normals', '{:0>8d}_{}_{}.png'.format(self.iter_step, i, idx)), np.concatenate([normal_img[..., i], self.dataset.normal_cam_at(idx, resolution_level=resolution_level)])[:, :, ::-1], ) if len(out_mask) > 0: cv.imwrite(os.path.join(self.base_exp_dir, 'normals', '{:0>8d}_{}_{}_mask.png'.format(self.iter_step, i, idx)), mask_map[..., i]) def save_maps(self, idx, img_idx, resolution_level=1): view_types = ['front', 'back', 'left', 'right'] print('Validate: iter: {}, camera: {}'.format(self.iter_step, idx)) rays_o, rays_d = self.dataset.gen_rays_at(idx, resolution_level=resolution_level) H, W, _ = rays_o.shape rays_o = rays_o.reshape(-1, 3).split(self.batch_size) rays_d = rays_d.reshape(-1, 3).split(self.batch_size) out_rgb_fine = [] out_normal_fine = [] out_mask = [] for rays_o_batch, rays_d_batch in zip(rays_o, rays_d): # near, far = self.dataset.near_far_from_sphere(rays_o_batch, rays_d_batch) near, far = self.dataset.get_near_far() background_rgb = torch.ones([1, 3]) if self.use_white_bkgd else None render_out = self.renderer.render( rays_o_batch, rays_d_batch, near, far, cos_anneal_ratio=self.get_cos_anneal_ratio(), background_rgb=background_rgb ) def feasible(key): return (key in render_out) and (render_out[key] is not None) if feasible('color_fine'): out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy()) if feasible('gradients') and feasible('weights'): n_samples = self.renderer.n_samples + self.renderer.n_importance normals = render_out['gradients'] * render_out['weights'][:, :n_samples, None] if feasible('inside_sphere'): normals = normals * render_out['inside_sphere'][..., None] normals = normals.sum(dim=1).detach().cpu().numpy() out_normal_fine.append(normals) if feasible('weight_sum'): out_mask.append(render_out['weight_sum'].detach().clip(0, 1).cpu().numpy()) del render_out img_fine = None if len(out_rgb_fine) > 0: img_fine = (np.concatenate(out_rgb_fine, axis=0).reshape([H, W, 3]) * 256).clip(0, 255) mask_map = None if len(out_mask) > 0: mask_map = (np.concatenate(out_mask, axis=0).reshape([H, W, 1]) * 256).clip(0, 255) normal_img = None if len(out_normal_fine) > 0: normal_img = np.concatenate(out_normal_fine, axis=0) # rot = np.linalg.inv(self.dataset.pose_all[idx, :3, :3].detach().cpu().numpy()) world_normal_img = (normal_img.reshape([H, W, 3]) * 128 + 128).clip(0, 255) os.makedirs(os.path.join(self.base_exp_dir, 'coarse_maps'), exist_ok=True) img_rgba = np.concatenate([img_fine[:, :, ::-1], mask_map], axis=-1) normal_rgba = np.concatenate([world_normal_img[:, :, ::-1], mask_map], axis=-1) cv.imwrite(os.path.join(self.base_exp_dir, 'coarse_maps', "normals_mlp_%03d_%s.png" % (img_idx, view_types[idx])), img_rgba) cv.imwrite(os.path.join(self.base_exp_dir, 'coarse_maps', "normals_grad_%03d_%s.png" % (img_idx, view_types[idx])), normal_rgba) def render_novel_image(self, idx_0, idx_1, ratio, resolution_level): """ Interpolate view between two cameras. """ rays_o, rays_d = self.dataset.gen_rays_between(idx_0, idx_1, ratio, resolution_level=resolution_level) H, W, _ = rays_o.shape rays_o = rays_o.reshape(-1, 3).split(self.batch_size) rays_d = rays_d.reshape(-1, 3).split(self.batch_size) out_rgb_fine = [] for rays_o_batch, rays_d_batch in zip(rays_o, rays_d): # near, far = self.dataset.near_far_from_sphere(rays_o_batch, rays_d_batch) near, far = self.dataset.get_near_far() background_rgb = torch.ones([1, 3]) if self.use_white_bkgd else None render_out = self.renderer.render( rays_o_batch, rays_d_batch, near, far, cos_anneal_ratio=self.get_cos_anneal_ratio(), background_rgb=background_rgb ) out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy()) del render_out img_fine = (np.concatenate(out_rgb_fine, axis=0).reshape([H, W, 3]) * 256).clip(0, 255).astype(np.uint8) return img_fine def validate_mesh(self, world_space=False, resolution=64, threshold=0.0): bound_min = torch.tensor(self.dataset.object_bbox_min, dtype=torch.float32) bound_max = torch.tensor(self.dataset.object_bbox_max, dtype=torch.float32) vertices, triangles, vertex_colors = self.renderer.extract_geometry(bound_min, bound_max, resolution=resolution, threshold=threshold) os.makedirs(os.path.join(self.base_exp_dir, 'meshes'), exist_ok=True) if world_space: vertices = vertices * self.dataset.scale_mats_np[0][0, 0] + self.dataset.scale_mats_np[0][:3, 3][None] mesh = trimesh.Trimesh(vertices, triangles, vertex_colors=vertex_colors) # mesh.export(os.path.join(self.base_exp_dir, 'meshes', '{:0>8d}.ply'.format(self.iter_step))) # export as glb mesh.export(os.path.join(self.base_exp_dir, 'meshes', 'tmp.glb')) logging.info('End') def interpolate_view(self, img_idx_0, img_idx_1): images = [] n_frames = 60 for i in range(n_frames): print(i) images.append(self.render_novel_image(img_idx_0, img_idx_1, np.sin(((i / n_frames) - 0.5) * np.pi) * 0.5 + 0.5, resolution_level=4)) for i in range(n_frames): images.append(images[n_frames - i - 1]) fourcc = cv.VideoWriter_fourcc(*'mp4v') video_dir = os.path.join(self.base_exp_dir, 'render') os.makedirs(video_dir, exist_ok=True) h, w, _ = images[0].shape writer = cv.VideoWriter(os.path.join(video_dir, '{:0>8d}_{}_{}.mp4'.format(self.iter_step, img_idx_0, img_idx_1)), fourcc, 30, (w, h)) for image in images: writer.write(image) writer.release() if __name__ == '__main__': print('Hello Wooden') torch.set_default_tensor_type('torch.FloatTensor') FORMAT = "[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s" logging.basicConfig(level=logging.DEBUG, format=FORMAT) parser = argparse.ArgumentParser() parser.add_argument('--conf', type=str, default='./confs/base.conf') parser.add_argument('--mode', type=str, default='train') parser.add_argument('--mcube_threshold', type=float, default=0.0) parser.add_argument('--is_continue', default=False, action="store_true") parser.add_argument('--gpu', type=int, default=0) parser.add_argument('--case', type=str, default='') parser.add_argument('--data_dir', type=str, default='') args = parser.parse_args() torch.cuda.set_device(args.gpu) runner = Runner(args.conf, args.mode, args.case, args.is_continue, args.data_dir) if args.mode == 'train': runner.train() runner.validate_mesh(world_space=False, resolution=256, threshold=args.mcube_threshold) elif args.mode == 'save_maps': for i in range(4): runner.save_maps(idx=i, img_idx=runner.dataset.object_viewidx) elif args.mode == 'validate_mesh': runner.validate_mesh(world_space=False, resolution=512, threshold=args.mcube_threshold) elif args.mode.startswith('interpolate'): # Interpolate views given two image indices _, img_idx_0, img_idx_1 = args.mode.split('_') img_idx_0 = int(img_idx_0) img_idx_1 = int(img_idx_1) runner.interpolate_view(img_idx_0, img_idx_1)