from internal import stepfun import numpy as np from matplotlib import cm def weighted_percentile(x, w, ps, assume_sorted=False): """Compute the weighted percentile(s) of a single vector.""" if len(x.shape) != len(w.shape): w = np.broadcast_to(w[..., None], x.shape) x = x.reshape([-1]) w = w.reshape([-1]) if not assume_sorted: sortidx = np.argsort(x) x, w = x[sortidx], w[sortidx] acc_w = np.cumsum(w) return np.interp(np.array(ps) * (acc_w[-1] / 100), acc_w, x) def sinebow(h): """A cyclic and uniform colormap, see http://basecase.org/env/on-rainbows.""" f = lambda x: np.sin(np.pi * x) ** 2 return np.stack([f(3 / 6 - h), f(5 / 6 - h), f(7 / 6 - h)], -1) def matte(vis, acc, dark=0.8, light=1.0, width=8): """Set non-accumulated pixels to a Photoshop-esque checker pattern.""" bg_mask = np.logical_xor( (np.arange(acc.shape[0]) % (2 * width) // width)[:, None], (np.arange(acc.shape[1]) % (2 * width) // width)[None, :]) bg = np.where(bg_mask, light, dark) return vis * acc[:, :, None] + (bg * (1 - acc))[:, :, None] def visualize_cmap(value, weight, colormap, lo=None, hi=None, percentile=99., curve_fn=lambda x: x, modulus=None, matte_background=True): """Visualize a 1D image and a 1D weighting according to some colormap. Args: value: A 1D image. weight: A weight map, in [0, 1]. colormap: A colormap function. lo: The lower bound to use when rendering, if None then use a percentile. hi: The upper bound to use when rendering, if None then use a percentile. percentile: What percentile of the value map to crop to when automatically generating `lo` and `hi`. Depends on `weight` as well as `value'. curve_fn: A curve function that gets applied to `value`, `lo`, and `hi` before the rest of visualization. Good choices: x, 1/(x+eps), log(x+eps). modulus: If not None, mod the normalized value by `modulus`. Use (0, 1]. If `modulus` is not None, `lo`, `hi` and `percentile` will have no effect. matte_background: If True, matte the image over a checkerboard. Returns: A colormap rendering. """ # Identify the values that bound the middle of `value' according to `weight`. lo_auto, hi_auto = weighted_percentile( value, weight, [50 - percentile / 2, 50 + percentile / 2], assume_sorted=True) # If `lo` or `hi` are None, use the automatically-computed bounds above. eps = np.finfo(np.float32).eps lo = lo or (lo_auto - eps) hi = hi or (hi_auto + eps) # Curve all values. value, lo, hi = [curve_fn(x) for x in [value, lo, hi]] # Wrap the values around if requested. if modulus: value = np.mod(value, modulus) / modulus else: # Otherwise, just scale to [0, 1]. value = np.nan_to_num( np.clip((value - np.minimum(lo, hi)) / np.abs(hi - lo), 0, 1)) if colormap: colorized = colormap(value)[:, :, :3] else: if len(value.shape) != 3: raise ValueError(f'value must have 3 dims but has {len(value.shape)}') if value.shape[-1] != 3: raise ValueError( f'value must have 3 channels but has {len(value.shape[-1])}') colorized = value return matte(colorized, weight) if matte_background else colorized def visualize_coord_mod(coords, acc): """Visualize the coordinate of each point within its "cell".""" return matte(((coords + 1) % 2) / 2, acc) def visualize_rays(dist, dist_range, weights, rgbs, accumulate=False, renormalize=False, resolution=2048, bg_color=0.8): """Visualize a bundle of rays.""" dist_vis = np.linspace(*dist_range, resolution + 1) vis_rgb, vis_alpha = [], [] for ds, ws, rs in zip(dist, weights, rgbs): vis_rs, vis_ws = [], [] for d, w, r in zip(ds, ws, rs): if accumulate: # Produce the accumulated color and weight at each point along the ray. w_csum = np.cumsum(w, axis=0) rw_csum = np.cumsum((r * w[:, None]), axis=0) eps = np.finfo(np.float32).eps r, w = (rw_csum + eps) / (w_csum[:, None] + 2 * eps), w_csum vis_rs.append(stepfun.resample_np(dist_vis, d, r.T, use_avg=True).T) vis_ws.append(stepfun.resample_np(dist_vis, d, w.T, use_avg=True).T) vis_rgb.append(np.stack(vis_rs)) vis_alpha.append(np.stack(vis_ws)) vis_rgb = np.stack(vis_rgb, axis=1) vis_alpha = np.stack(vis_alpha, axis=1) if renormalize: # Scale the alphas so that the largest value is 1, for visualization. vis_alpha /= np.maximum(np.finfo(np.float32).eps, np.max(vis_alpha)) if resolution > vis_rgb.shape[0]: rep = resolution // (vis_rgb.shape[0] * vis_rgb.shape[1] + 1) stride = rep * vis_rgb.shape[1] vis_rgb = np.tile(vis_rgb, (1, 1, rep, 1)).reshape((-1,) + vis_rgb.shape[2:]) vis_alpha = np.tile(vis_alpha, (1, 1, rep)).reshape((-1,) + vis_alpha.shape[2:]) # Add a strip of background pixels after each set of levels of rays. vis_rgb = vis_rgb.reshape((-1, stride) + vis_rgb.shape[1:]) vis_alpha = vis_alpha.reshape((-1, stride) + vis_alpha.shape[1:]) vis_rgb = np.concatenate([vis_rgb, np.zeros_like(vis_rgb[:, :1])], axis=1).reshape((-1,) + vis_rgb.shape[2:]) vis_alpha = np.concatenate( [vis_alpha, np.zeros_like(vis_alpha[:, :1])], axis=1).reshape((-1,) + vis_alpha.shape[2:]) # Matte the RGB image over the background. vis = vis_rgb * vis_alpha[..., None] + (bg_color * (1 - vis_alpha))[..., None] # Remove the final row of background pixels. vis = vis[:-1] vis_alpha = vis_alpha[:-1] return vis, vis_alpha def visualize_suite(rendering, batch): """A wrapper around other visualizations for easy integration.""" depth_curve_fn = lambda x: -np.log(x + np.finfo(np.float32).eps) rgb = rendering['rgb'] acc = rendering['acc'] distance_mean = rendering['distance_mean'] distance_median = rendering['distance_median'] distance_p5 = rendering['distance_percentile_5'] distance_p95 = rendering['distance_percentile_95'] acc = np.where(np.isnan(distance_mean), np.zeros_like(acc), acc) # The xyz coordinates where rays terminate. coords = batch['origins'] + batch['directions'] * distance_mean[:, :, None] vis_depth_mean, vis_depth_median = [ visualize_cmap(x, acc, cm.get_cmap('turbo'), curve_fn=depth_curve_fn) for x in [distance_mean, distance_median] ] # Render three depth percentiles directly to RGB channels, where the spacing # determines the color. delta == big change, epsilon = small change. # Gray: A strong discontinuitiy, [x-epsilon, x, x+epsilon] # Purple: A thin but even density, [x-delta, x, x+delta] # Red: A thin density, then a thick density, [x-delta, x, x+epsilon] # Blue: A thick density, then a thin density, [x-epsilon, x, x+delta] vis_depth_triplet = visualize_cmap( np.stack( [2 * distance_median - distance_p5, distance_median, distance_p95], axis=-1), acc, None, curve_fn=lambda x: np.log(x + np.finfo(np.float32).eps)) dist = rendering['ray_sdist'] dist_range = (0, 1) weights = rendering['ray_weights'] rgbs = [np.clip(r, 0, 1) for r in rendering['ray_rgbs']] vis_ray_colors, _ = visualize_rays(dist, dist_range, weights, rgbs) sqrt_weights = [np.sqrt(w) for w in weights] sqrt_ray_weights, ray_alpha = visualize_rays( dist, dist_range, [np.ones_like(lw) for lw in sqrt_weights], [lw[..., None] for lw in sqrt_weights], bg_color=0, ) sqrt_ray_weights = sqrt_ray_weights[..., 0] null_color = np.array([1., 0., 0.]) vis_ray_weights = np.where( ray_alpha[:, :, None] == 0, null_color[None, None], visualize_cmap( sqrt_ray_weights, np.ones_like(sqrt_ray_weights), cm.get_cmap('gray'), lo=0, hi=1, matte_background=False, ), ) vis = { 'color': rgb, 'acc': acc, 'color_matte': matte(rgb, acc), 'depth_mean': vis_depth_mean, 'depth_median': vis_depth_median, 'depth_triplet': vis_depth_triplet, 'coords_mod': visualize_coord_mod(coords, acc), 'ray_colors': vis_ray_colors, 'ray_weights': vis_ray_weights, } if 'rgb_cc' in rendering: vis['color_corrected'] = rendering['rgb_cc'] # Render every item named "normals*". for key, val in rendering.items(): if key.startswith('normals'): vis[key] = matte(val / 2. + 0.5, acc) if 'roughness' in rendering: vis['roughness'] = matte(np.tanh(rendering['roughness']), acc) return vis