"""Utility functions.""" |
|
import collections |
|
import os |
|
from os import path |
|
import pickle |
|
from absl import flags |
|
import flax |
|
import jax |
|
import jax.numpy as jnp |
|
import jax.scipy as jsp |
|
import numpy as np |
|
from PIL import Image |
|
import yaml |
|
from nerf import datasets |
|
|
|
BASE_DIR = "" |
|
INTERNAL = False |
|
|
|
|
|
@flax.struct.dataclass |
|
class TrainState: |
|
optimizer: flax.optim.Optimizer |
|
|
|
|
|
@flax.struct.dataclass |
|
class Stats: |
|
loss: float |
|
psnr: float |
|
loss_c: float |
|
psnr_c: float |
|
weight_l2: float |
|
|
|
|
|
Rays = collections.namedtuple("Rays", ("origins", "directions", "viewdirs")) |
|
|
|
|
|
def namedtuple_map(fn, tup): |
|
"""Apply `fn` to each element of `tup` and cast to `tup`'s namedtuple.""" |
|
return type(tup)(*map(fn, tup)) |
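
# Example: `namedtuple_map` preserves the namedtuple type, so flattening the
# spatial dimensions of a `Rays` of [height, width, 3] arrays yields a `Rays`
# of [height * width, 3] arrays:
#
#   flat_rays = namedtuple_map(lambda r: r.reshape((-1, 3)), rays)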
|
|
|
|
|
def define_flags(): |
|
"""Define flags for both training and evaluation modes.""" |
|
flags.DEFINE_string("train_dir", None, "where to store ckpts and logs") |
|
flags.DEFINE_string("data_dir", None, "input data directory.") |
|
  flags.DEFINE_string("config", None,
                      "use a config file to set hyperparameters.")
|
|
|
|
|
  flags.DEFINE_bool("use_semantic_loss", True,
                    "whether to use the semantic loss.")
|
  flags.DEFINE_string("clip_model_name", "openai/clip-vit-base-patch32",
                      "model type for CLIP.")
|
  flags.DEFINE_string("clip_output_dtype", "float16",
                      "float32 / float16 (float16 saves memory).")
|
  flags.DEFINE_integer("sc_loss_every", 16,
                       "number of steps between semantic loss evaluations.")
|
  flags.DEFINE_float("sc_loss_mult", 1e-2,
                     "weighting for the semantic loss from CLIP.")
|
|
|
|
|
|
|
  flags.DEFINE_enum("dataset", "blender",
                    list(datasets.dataset_dict.keys()),
                    "The type of dataset to feed to NeRF.")
|
  flags.DEFINE_enum(
      "batching", "single_image", ["single_image", "all_images"],
      "source of ray sampling when collecting a training batch: "
      "single_image samples from only one image per batch, "
      "all_images samples from all training images.")
|
  flags.DEFINE_bool(
      "white_bkgd", True, "use white as the default background "
      "(blender dataset only).")
|
flags.DEFINE_integer("batch_size", 1024, |
|
"the number of rays in a mini-batch (for training).") |
|
flags.DEFINE_integer("factor", 4, |
|
"the downsample factor of images, 0 for no downsample.") |
|
flags.DEFINE_bool("spherify", False, "set for spherical 360 scenes.") |
|
  flags.DEFINE_bool(
      "render_path", False, "render the generated path if set to True "
      "(llff dataset only).")
|
  flags.DEFINE_integer(
      "llffhold", 8, "take every 1/N images as the LLFF test set "
      "(llff dataset only).")
|
  flags.DEFINE_bool(
      "use_pixel_centers", False,
      "If True, generate rays through the center of each pixel. Note: While "
      "this is the correct way to handle rays, it is not the way rays are "
      "handled in the original NeRF paper. Setting this to True yields "
      "~+1 PSNR compared to vanilla NeRF.")
|
|
|
|
|
flags.DEFINE_string("model", "nerf", "name of model to use.") |
|
flags.DEFINE_float("near", 2., "near clip of volumetric rendering.") |
|
  flags.DEFINE_float("far", 6., "far clip of volumetric rendering.")
|
flags.DEFINE_integer("net_depth", 8, "depth of the first part of MLP.") |
|
flags.DEFINE_integer("net_width", 256, "width of the first part of MLP.") |
|
flags.DEFINE_integer("net_depth_condition", 1, |
|
"depth of the second part of MLP.") |
|
flags.DEFINE_integer("net_width_condition", 128, |
|
"width of the second part of MLP.") |
|
  flags.DEFINE_float("weight_decay_mult", 0, "The multiplier on weight decay.")
|
  flags.DEFINE_integer(
      "skip_layer", 4, "add a skip connection to the output vector of every "
      "skip_layer layers.")
|
flags.DEFINE_integer("num_rgb_channels", 3, "the number of RGB channels.") |
|
flags.DEFINE_integer("num_sigma_channels", 1, |
|
"the number of density channels.") |
|
flags.DEFINE_bool("randomized", True, "use randomized stratified sampling.") |
|
flags.DEFINE_integer("min_deg_point", 0, |
|
"Minimum degree of positional encoding for points.") |
|
flags.DEFINE_integer("max_deg_point", 10, |
|
"Maximum degree of positional encoding for points.") |
|
flags.DEFINE_integer("deg_view", 4, |
|
"Degree of positional encoding for viewdirs.") |
|
flags.DEFINE_integer( |
|
"num_coarse_samples", 64, |
|
"the number of samples on each ray for the coarse model.") |
|
flags.DEFINE_integer("num_fine_samples", 128, |
|
"the number of samples on each ray for the fine model.") |
|
flags.DEFINE_bool("use_viewdirs", True, "use view directions as a condition.") |
|
  flags.DEFINE_float(
      "noise_std", None, "std dev of noise added to regularize sigma output "
      "(llff dataset only).")
|
flags.DEFINE_bool("lindisp", False, |
|
"sampling linearly in disparity rather than depth.") |
|
flags.DEFINE_string("net_activation", "relu", |
|
"activation function used within the MLP.") |
|
flags.DEFINE_string("rgb_activation", "sigmoid", |
|
"activation function used to produce RGB.") |
|
flags.DEFINE_string("sigma_activation", "relu", |
|
"activation function used to produce density.") |
|
  flags.DEFINE_bool(
      "legacy_posenc_order", False,
      "If True, revert the positional encoding feature order to an older "
      "version of this codebase.")
|
|
|
|
|
flags.DEFINE_float("lr_init", 5e-4, "The initial learning rate.") |
|
flags.DEFINE_float("lr_final", 5e-6, "The final learning rate.") |
|
  flags.DEFINE_integer(
      "lr_delay_steps", 0, "The number of steps at the beginning of "
      "training to reduce the learning rate by lr_delay_mult.")
|
  flags.DEFINE_float(
      "lr_delay_mult", 1., "A multiplier on the learning rate when the step "
      "is < lr_delay_steps.")
|
flags.DEFINE_float("grad_max_norm", 0., |
|
"The gradient clipping magnitude (disabled if == 0).") |
|
flags.DEFINE_float("grad_max_val", 0., |
|
"The gradient clipping value (disabled if == 0).") |
|
|
|
flags.DEFINE_integer("max_steps", 1000000, |
|
"the number of optimization steps.") |
|
  flags.DEFINE_integer("stop_sc_loss", 1000000,
                       "the step after which the semantic loss is no longer "
                       "applied.")
|
flags.DEFINE_integer("save_every", 10000, |
|
"the number of steps to save a checkpoint.") |
|
flags.DEFINE_integer("print_every", 100, |
|
"the number of steps between reports to tensorboard.") |
|
  flags.DEFINE_integer(
      "render_every", 5000, "the number of steps between test-image renders; "
      "ideally a multiple of 100 for accurate step time records.")
|
flags.DEFINE_integer("gc_every", 10000, |
|
"the number of steps to run python garbage collection.") |
|
  flags.DEFINE_integer("few_shot", -1,
                       "the number of training images to use for few-shot "
                       "training (-1 uses all images).")
|
|
|
|
|
  flags.DEFINE_bool(
      "eval_once", True,
      "evaluate the model only once if True, otherwise keep evaluating new "
      "checkpoints as they appear.")
|
flags.DEFINE_bool("save_output", True, |
|
"save predicted images to disk if True.") |
|
  flags.DEFINE_integer(
      "chunk", 1024,
      "the size of chunks for evaluation inference, set to the value that "
      "fits your GPU/TPU memory.")
|
  flags.DEFINE_bool("generate_gif_only", False,
                    "in eval.py, only generate a GIF for the trained model.")
|
|
|
|
|
def update_flags(args): |
|
"""Update the flags in `args` with the contents of the config YAML file.""" |
|
pth = path.join(BASE_DIR, args.config + ".yaml") |
|
with open_file(pth, "r") as fin: |
|
configs = yaml.load(fin, Loader=yaml.FullLoader) |
|
|
|
invalid_args = list(set(configs.keys()) - set(dir(args))) |
|
if invalid_args: |
|
raise ValueError(f"Invalid args {invalid_args} in {pth}.") |
|
args.__dict__.update(configs) |
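
# A minimal usage sketch for `update_flags` (the config name and contents are
# hypothetical): a YAML file at path.join(BASE_DIR, "configs/demo.yaml")
# containing, e.g.,
#
#   batch_size: 2048
#   lr_init: 1.0e-3
#
# is applied by parsing flags with --config=configs/demo and then calling:
#
#   update_flags(flags.FLAGS)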
|
|
|
def open_file(pth, mode="r"): |
|
if not INTERNAL: |
|
return open(pth, mode=mode) |
|
|
|
|
|
def file_exists(pth): |
|
if not INTERNAL: |
|
return path.exists(pth) |
|
|
|
|
|
def listdir(pth): |
|
if not INTERNAL: |
|
return os.listdir(pth) |
|
|
|
|
|
def isdir(pth): |
|
if not INTERNAL: |
|
return path.isdir(pth) |
|
|
|
|
|
def makedirs(pth): |
|
if not INTERNAL: |
|
os.makedirs(pth) |
|
|
|
|
|
def render_image(render_fn, rays, rng, normalize_disp, chunk=8192): |
|
"""Render all the pixels of an image (in test mode). |
|
|
|
Args: |
|
render_fn: function, jit-ed render function. |
|
rays: a `Rays` namedtuple, the rays to be rendered. |
|
rng: jnp.ndarray, random number generator (used in training mode only). |
|
normalize_disp: bool, if true then normalize `disp` to [0, 1]. |
|
chunk: int, the size of chunks to render sequentially. |
|
|
|
Returns: |
|
rgb: jnp.ndarray, rendered color image. |
|
disp: jnp.ndarray, rendered disparity image. |
|
acc: jnp.ndarray, rendered accumulated weights per pixel. |
|
""" |
|
height, width = rays[0].shape[:2] |
|
num_rays = height * width |
|
rays = namedtuple_map(lambda r: r.reshape((num_rays, -1)), rays) |
|
unused_rng, key_0, key_1 = jax.random.split(rng, 3) |
|
  host_id = jax.process_index()
|
results = [] |
|
for i in range(0, num_rays, chunk): |
|
|
|
chunk_rays = namedtuple_map(lambda r: r[i:i + chunk], rays) |
|
chunk_size = chunk_rays[0].shape[0] |
|
rays_remaining = chunk_size % jax.device_count() |
|
if rays_remaining != 0: |
|
padding = jax.device_count() - rays_remaining |
|
chunk_rays = namedtuple_map( |
|
lambda r: jnp.pad(r, ((0, padding), (0, 0)), mode="edge"), chunk_rays) |
|
else: |
|
padding = 0 |
|
|
|
|
|
rays_per_host = chunk_rays[0].shape[0] // jax.process_count() |
|
start, stop = host_id * rays_per_host, (host_id + 1) * rays_per_host |
|
chunk_rays = namedtuple_map(lambda r: shard(r[start:stop]), chunk_rays) |
|
chunk_results = render_fn(key_0, key_1, chunk_rays)[-1] |
|
results.append([unshard(x, padding) for x in chunk_results]) |
|
|
|
rgb, disp, acc = [jnp.concatenate(r, axis=0) for r in zip(*results)] |
|
|
|
if normalize_disp: |
|
disp = (disp - disp.min()) / (disp.max() - disp.min()) |
|
return (rgb.reshape((height, width, -1)), disp.reshape( |
|
(height, width, -1)), acc.reshape((height, width, -1))) |
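
# A usage sketch (names are hypothetical): `render_fn` is typically a pmapped
# rendering function and `rays` a `Rays` namedtuple of [height, width, 3]
# arrays:
#
#   rgb, disp, acc = render_image(
#       render_fn, test_rays, jax.random.PRNGKey(0),
#       normalize_disp=True, chunk=8192)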
|
|
|
|
|
def compute_psnr(mse): |
|
"""Compute psnr value given mse (we assume the maximum pixel value is 1). |
|
|
|
Args: |
|
mse: float, mean square error of pixels. |
|
|
|
Returns: |
|
psnr: float, the psnr value. |
|
""" |
|
return -10. * jnp.log(mse) / jnp.log(10.) |
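
# Worked example: for mse = 1e-2, PSNR = -10 * log10(1e-2) = 20 dB:
#
#   compute_psnr(jnp.array(1e-2))  # -> 20.0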
|
|
|
|
|
def compute_ssim(img0, |
|
img1, |
|
max_val, |
|
filter_size=11, |
|
filter_sigma=1.5, |
|
k1=0.01, |
|
k2=0.03, |
|
return_map=False): |
|
"""Computes SSIM from two images. |
|
|
|
This function was modeled after tf.image.ssim, and should produce comparable |
|
output. |
|
|
|
Args: |
|
img0: array. An image of size [..., width, height, num_channels]. |
|
img1: array. An image of size [..., width, height, num_channels]. |
|
max_val: float > 0. The maximum magnitude that `img0` or `img1` can have. |
|
filter_size: int >= 1. Window size. |
|
filter_sigma: float > 0. The bandwidth of the Gaussian used for filtering. |
|
k1: float > 0. One of the SSIM dampening parameters. |
|
k2: float > 0. One of the SSIM dampening parameters. |
|
    return_map: Bool. If True, the per-pixel SSIM "map" is returned.
|
|
|
Returns: |
|
Each image's mean SSIM, or a tensor of individual values if `return_map`. |
|
""" |
|
|
|
hw = filter_size // 2 |
|
shift = (2 * hw - filter_size + 1) / 2 |
|
f_i = ((jnp.arange(filter_size) - hw + shift) / filter_sigma) ** 2 |
|
filt = jnp.exp(-0.5 * f_i) |
|
filt /= jnp.sum(filt) |
|
|
|
|
|
filt_fn1 = lambda z: jsp.signal.convolve2d(z, filt[:, None], mode="valid") |
|
filt_fn2 = lambda z: jsp.signal.convolve2d(z, filt[None, :], mode="valid") |
|
|
|
|
|
num_dims = len(img0.shape) |
|
map_axes = tuple(list(range(num_dims - 3)) + [num_dims - 1]) |
|
for d in map_axes: |
|
filt_fn1 = jax.vmap(filt_fn1, in_axes=d, out_axes=d) |
|
filt_fn2 = jax.vmap(filt_fn2, in_axes=d, out_axes=d) |
|
filt_fn = lambda z: filt_fn1(filt_fn2(z)) |
|
|
|
mu0 = filt_fn(img0) |
|
mu1 = filt_fn(img1) |
|
mu00 = mu0 * mu0 |
|
mu11 = mu1 * mu1 |
|
mu01 = mu0 * mu1 |
|
sigma00 = filt_fn(img0 ** 2) - mu00 |
|
sigma11 = filt_fn(img1 ** 2) - mu11 |
|
sigma01 = filt_fn(img0 * img1) - mu01 |
|
|
|
|
|
|
|
sigma00 = jnp.maximum(0., sigma00) |
|
sigma11 = jnp.maximum(0., sigma11) |
|
sigma01 = jnp.sign(sigma01) * jnp.minimum( |
|
jnp.sqrt(sigma00 * sigma11), jnp.abs(sigma01)) |
|
|
|
c1 = (k1 * max_val) ** 2 |
|
c2 = (k2 * max_val) ** 2 |
|
numer = (2 * mu01 + c1) * (2 * sigma01 + c2) |
|
denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2) |
|
ssim_map = numer / denom |
|
ssim = jnp.mean(ssim_map, list(range(num_dims - 3, num_dims))) |
|
return ssim_map if return_map else ssim |
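
# A sanity-check sketch: the SSIM of an image with itself is 1.0 (the image
# below is a made-up constant example):
#
#   img = jnp.full((16, 16, 3), 0.5)
#   compute_ssim(img, img, max_val=1.0)  # -> 1.0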
|
|
|
|
|
def save_img(img, pth): |
|
"""Save an image to disk. |
|
|
|
Args: |
|
    img: jnp.ndarray, [height, width, channels], img will be clipped to [0, 1]
|
before saved to pth. |
|
pth: string, path to save the image to. |
|
""" |
|
with open_file(pth, "wb") as imgout: |
|
    Image.fromarray(np.array(
        (np.clip(img, 0., 1.) * 255.).astype(np.uint8))).save(imgout, "PNG")
|
|
|
|
|
def learning_rate_decay(step, |
|
lr_init, |
|
lr_final, |
|
max_steps, |
|
lr_delay_steps=0, |
|
lr_delay_mult=1): |
|
"""Continuous learning rate decay function. |
|
|
|
The returned rate is lr_init when step=0 and lr_final when step=max_steps, and |
|
is log-linearly interpolated elsewhere (equivalent to exponential decay). |
|
If lr_delay_steps>0 then the learning rate will be scaled by some smooth |
|
function of lr_delay_mult, such that the initial learning rate is |
|
lr_init*lr_delay_mult at the beginning of optimization but will be eased back |
|
to the normal learning rate when steps>lr_delay_steps. |
|
|
|
Args: |
|
step: int, the current optimization step. |
|
lr_init: float, the initial learning rate. |
|
lr_final: float, the final learning rate. |
|
max_steps: int, the number of steps during optimization. |
|
lr_delay_steps: int, the number of steps to delay the full learning rate. |
|
lr_delay_mult: float, the multiplier on the rate when delaying it. |
|
|
|
Returns: |
|
    lr: the learning rate for the current step 'step'.
|
""" |
|
if lr_delay_steps > 0: |
|
|
|
delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( |
|
0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)) |
|
else: |
|
delay_rate = 1. |
|
t = np.clip(step / max_steps, 0, 1) |
|
log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) |
|
return delay_rate * log_lerp |
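
# Worked example: the decay is log-linear, so at the halfway point the rate is
# the geometric mean of lr_init and lr_final:
#
#   learning_rate_decay(500000, 5e-4, 5e-6, 1000000)
#   # -> exp(0.5 * log(5e-4) + 0.5 * log(5e-6)) = sqrt(5e-4 * 5e-6) = 5e-5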
|
|
|
|
|
def shard(xs): |
|
"""Split data into shards for multiple devices along the first dimension.""" |
|
|
  return jax.tree_map(
      lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:])
      if len(x.shape) != 0 else x, xs)
|
|
|
|
|
def to_device(xs): |
|
"""Transfer data to devices (GPU/TPU).""" |
|
return jax.tree_map(jnp.array, xs) |
|
|
|
|
|
def unshard(x, padding=0): |
|
"""Collect the sharded tensor to the shape before sharding.""" |
|
y = x.reshape([x.shape[0] * x.shape[1]] + list(x.shape[2:])) |
|
if padding > 0: |
|
y = y[:-padding] |
|
return y |
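
# A roundtrip sketch (assuming a hypothetical host with 8 local devices):
# sharding a [1024, 3] array yields shape [8, 128, 3], and `unshard` restores
# the original shape:
#
#   sharded = shard(np.zeros((1024, 3)))  # shape (8, 128, 3)
#   restored = unshard(sharded)           # shape (1024, 3)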
|
|
|
|
|
def write_pickle(data, fn):
  """Pickle `data` and write it to the file `fn`."""
  with open(fn, "wb") as f:
    pickle.dump(data, f)
|
|
|
|
|
def read_pickle(fn):
  """Load and return pickled data from the file `fn`."""
  with open(fn, "rb") as f:
    data = pickle.load(f)
  return data
|
|