# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
"""
The renderer is a module that takes in rays, decides where to sample along each
ray, and computes pixel colors using the volume rendering equation.
"""
import math
import torch
import torch.nn as nn
import numpy as np

from training_avatar_texture.volumetric_rendering.ray_marcher import MipRayMarcher2
from training_avatar_texture.volumetric_rendering import math_utils

from pytorch3d.structures import Meshes
from pytorch3d.io import load_obj
from pytorch3d.renderer.mesh import rasterize_meshes

def generate_planes():
    """
    Defines planes by the three vectors that form the "axes" of the
    plane. Should work with an arbitrary number of planes of
    arbitrary orientation.
    """
    return torch.tensor([[[1, 0, 0],
                          [0, 1, 0],
                          [0, 0, 1]],
                         [[1, 0, 0],
                          [0, 0, 1],
                          [0, 1, 0]],
                         [[0, 0, 1],
                          [0, 1, 0],
                          [1, 0, 0]]], dtype=torch.float32)

def project_onto_planes(planes, coordinates):
    """
    Projects a batch of 3D points onto a batch of 2D planes,
    returning 2D plane coordinates.

    Takes plane axes of shape (n_planes, 3, 3)
    Takes coordinates of shape (N, M, 3)
    Returns projections of shape (N*n_planes, M, 2)
    """
    N, M, C = coordinates.shape
    n_planes, _, _ = planes.shape
    coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N * n_planes, M, 3)
    inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N * n_planes, 3, 3)
    projections = torch.bmm(coordinates, inv_planes)
    return projections[..., :2]

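# A minimal shape sanity check (illustrative sketch, not part of the original module):
#   plane_axes = generate_planes()                 # (3, 3, 3)
#   coords = torch.randn(2, 128, 3)                # N=2 batches of M=128 points
#   uv = project_onto_planes(plane_axes, coords)   # (N * n_planes, M, 2) == (6, 128, 2)
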
def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None):
    assert padding_mode == 'zeros'
    N, n_planes, C, H, W = plane_features.shape
    _, M, _ = coordinates.shape
    plane_features = plane_features.view(N * n_planes, C, H, W)

    coordinates = (2 / box_warp) * coordinates  # normalize world coordinates to [-1, 1]; TODO: add specific box bounds

    projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1)
    output_features = torch.nn.functional.grid_sample(plane_features, projected_coordinates.float(), mode=mode,
                                                      padding_mode=padding_mode,
                                                      align_corners=False).permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
    return output_features

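# Tri-plane feature lookup sketch (illustrative shapes and box_warp value):
#   plane_feats = torch.randn(2, 3, 32, 256, 256)      # (N, n_planes, C, H, W)
#   coords = torch.rand(2, 1024, 3) - 0.5              # (N, M, 3), world space
#   feats = sample_from_planes(generate_planes(), plane_feats,
#                              coords, box_warp=1.0)   # (2, 3, 1024, 32)
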
def sample_from_3dgrid(grid, coordinates):
    """
    Expects coordinates in shape (batch_size, num_points_per_batch, 3)
    Expects grid in shape (1, channels, H, W, D)
    (Also works if grid has batch size)
    Returns sampled features of shape (batch_size, num_points_per_batch, feature_channels)
    """
    batch_size, n_coords, n_dims = coordinates.shape
    sampled_features = torch.nn.functional.grid_sample(grid.expand(batch_size, -1, -1, -1, -1),
                                                       coordinates.reshape(batch_size, 1, 1, -1, n_dims),
                                                       mode='bilinear', padding_mode='zeros', align_corners=False)
    N, C, H, W, D = sampled_features.shape
    sampled_features = sampled_features.permute(0, 4, 3, 2, 1).reshape(N, H * W * D, C)
    return sampled_features

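# Dense-grid lookup sketch (illustrative shapes):
#   grid = torch.randn(1, 32, 64, 64, 64)       # (1, channels, H, W, D)
#   coords = torch.rand(4, 1000, 3) * 2 - 1     # normalized to [-1, 1]
#   feats = sample_from_3dgrid(grid, coords)    # (4, 1000, 32)
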
class ImportanceRenderer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ray_marcher = MipRayMarcher2()
        self.plane_axes = generate_planes()

    def forward(self, planes, decoder, ray_origins, ray_directions, rendering_options):
        self.plane_axes = self.plane_axes.to(ray_origins.device)

        if rendering_options['ray_start'] == rendering_options['ray_end'] == 'auto':
            ray_start, ray_end = math_utils.get_ray_limits_box(ray_origins, ray_directions, box_side_length=rendering_options['box_warp'])
            is_ray_valid = ray_end > ray_start
            if torch.any(is_ray_valid).item():
                # Give rays that miss the box degenerate but well-defined limits.
                ray_start[~is_ray_valid] = ray_start[is_ray_valid].min()
                ray_end[~is_ray_valid] = ray_start[is_ray_valid].max()
            depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'],
                                                   rendering_options['disparity_space_sampling'])
        else:
            # Create stratified depth samples
            depths_coarse = self.sample_stratified(ray_origins, rendering_options['ray_start'], rendering_options['ray_end'],
                                                   rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'])

        batch_size, num_rays, samples_per_ray, _ = depths_coarse.shape

        # Coarse Pass
        sample_coordinates = (ray_origins.unsqueeze(-2) + depths_coarse * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3)
        sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3)

        out = self.run_model(planes, decoder, sample_coordinates, sample_directions, rendering_options)
        colors_coarse = out['rgb']
        densities_coarse = out['sigma']
        colors_coarse = colors_coarse.reshape(batch_size, num_rays, samples_per_ray, colors_coarse.shape[-1])
        densities_coarse = densities_coarse.reshape(batch_size, num_rays, samples_per_ray, 1)

        # Fine Pass
        N_importance = rendering_options['depth_resolution_importance']
        if N_importance > 0:
            _, _, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options)

            depths_fine = self.sample_importance(depths_coarse, weights, N_importance)

            sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, N_importance, -1).reshape(batch_size, -1, 3)
            sample_coordinates = (ray_origins.unsqueeze(-2) + depths_fine * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3)

            out = self.run_model(planes, decoder, sample_coordinates, sample_directions, rendering_options)
            colors_fine = out['rgb']
            densities_fine = out['sigma']
            colors_fine = colors_fine.reshape(batch_size, num_rays, N_importance, colors_fine.shape[-1])
            densities_fine = densities_fine.reshape(batch_size, num_rays, N_importance, 1)

            all_depths, all_colors, all_densities = self.unify_samples(depths_coarse, colors_coarse, densities_coarse,
                                                                       depths_fine, colors_fine, densities_fine)

            # Aggregate
            rgb_final, depth_final, weights = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options)
        else:
            rgb_final, depth_final, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options)

        return rgb_final, depth_final, weights.sum(2)

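    # Sketch of a forward() call (hypothetical option values; in training the
    # config comes from the caller and planes/decoder from the generator):
    #   options = {'ray_start': 'auto', 'ray_end': 'auto', 'box_warp': 1.0,
    #              'depth_resolution': 48, 'depth_resolution_importance': 48,
    #              'disparity_space_sampling': False, 'density_noise': 0}
    #   rgb, depth, acc = renderer(planes, decoder, ray_origins, ray_dirs, options)
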
    def run_model(self, planes, decoder, sample_coordinates, sample_directions, options):
        sampled_features = sample_from_planes(self.plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=options['box_warp'])

        out = decoder(sampled_features, sample_directions)
        if options.get('density_noise', 0) > 0:
            out['sigma'] += torch.randn_like(out['sigma']) * options['density_noise']
        return out

    def sort_samples(self, all_depths, all_colors, all_densities):
        _, indices = torch.sort(all_depths, dim=-2)
        all_depths = torch.gather(all_depths, -2, indices)
        all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
        all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
        return all_depths, all_colors, all_densities

    def unify_samples(self, depths1, colors1, densities1, depths2, colors2, densities2, normals1=None, normals2=None):
        all_depths = torch.cat([depths1, depths2], dim=-2)
        all_colors = torch.cat([colors1, colors2], dim=-2)
        all_densities = torch.cat([densities1, densities2], dim=-2)

        if normals1 is not None and normals2 is not None:
            all_normals = torch.cat([normals1, normals2], dim=-2)
        else:
            all_normals = None

        _, indices = torch.sort(all_depths, dim=-2)
        all_depths = torch.gather(all_depths, -2, indices)
        all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
        all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))

        if all_normals is not None:
            all_normals = torch.gather(all_normals, -2, indices.expand(-1, -1, -1, all_normals.shape[-1]))
            # Returns a 4-tuple when normals are supplied, a 3-tuple otherwise.
            return all_depths, all_colors, all_normals, all_densities

        return all_depths, all_colors, all_densities

    def sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False):
        """
        Return depths of approximately uniformly spaced samples along rays.
        """
        N, M, _ = ray_origins.shape
        if disparity_space_sampling:
            depths_coarse = torch.linspace(0,
                                           1,
                                           depth_resolution,
                                           device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
            depth_delta = 1 / (depth_resolution - 1)
            depths_coarse += torch.rand_like(depths_coarse) * depth_delta
            depths_coarse = 1. / (1. / ray_start * (1. - depths_coarse) + 1. / ray_end * depths_coarse)
        else:
            if isinstance(ray_start, torch.Tensor):
                depths_coarse = math_utils.linspace(ray_start, ray_end, depth_resolution).permute(1, 2, 0, 3)
                depth_delta = (ray_end - ray_start) / (depth_resolution - 1)
                depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None]
            else:
                depths_coarse = torch.linspace(ray_start, ray_end, depth_resolution,
                                               device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
                depth_delta = (ray_end - ray_start) / (depth_resolution - 1)
                depths_coarse += torch.rand_like(depths_coarse) * depth_delta
        return depths_coarse

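    # Stratified sampling sketch: for ray_start=2.0, ray_end=3.0 and
    # depth_resolution=5, base depths are linspace(2.0, 3.0, 5) and each sample
    # is jittered by U(0, delta) with delta = (3.0 - 2.0) / 4, one per bin.
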
    def sample_importance(self, z_vals, weights, N_importance):
        """
        Return depths of importance sampled points along rays. See NeRF importance sampling for more.
        """
        with torch.no_grad():
            batch_size, num_rays, samples_per_ray, _ = z_vals.shape

            z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray)
            weights = weights.reshape(batch_size * num_rays, -1)  # -1 to account for loss of 1 sample in MipRayMarcher

            # smooth weights
            weights = torch.nn.functional.max_pool1d(weights.unsqueeze(1).float(), 2, 1, padding=1)
            weights = torch.nn.functional.avg_pool1d(weights, 2, 1).squeeze()
            weights = weights + 0.01

            z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:])
            importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1],
                                                N_importance).detach().reshape(batch_size, num_rays, N_importance, 1)
        return importance_z_vals

    def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5):
        """
        Sample @N_importance samples from @bins with distribution defined by @weights.
        Inputs:
            bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2"
            weights: (N_rays, N_samples_)
            N_importance: the number of samples to draw from the distribution
            det: deterministic or not
            eps: a small number to prevent division by zero
        Outputs:
            samples: the sampled samples
        """
        N_rays, N_samples_ = weights.shape
        weights = weights + eps  # prevent division by zero (don't do inplace op!)
        pdf = weights / torch.sum(weights, -1, keepdim=True)  # (N_rays, N_samples_)
        cdf = torch.cumsum(pdf, -1)  # (N_rays, N_samples_), cumulative distribution function
        cdf = torch.cat([torch.zeros_like(cdf[:, :1]), cdf], -1)  # (N_rays, N_samples_+1), padded to 0~1 inclusive

        if det:
            u = torch.linspace(0, 1, N_importance, device=bins.device)
            u = u.expand(N_rays, N_importance)
        else:
            u = torch.rand(N_rays, N_importance, device=bins.device)
        u = u.contiguous()

        inds = torch.searchsorted(cdf, u, right=True)
        below = torch.clamp_min(inds - 1, 0)
        above = torch.clamp_max(inds, N_samples_)

        inds_sampled = torch.stack([below, above], -1).view(N_rays, 2 * N_importance)
        cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2)
        bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2)

        denom = cdf_g[..., 1] - cdf_g[..., 0]
        # denom == 0 means a bin has weight 0, in which case it will not be
        # sampled anyway, therefore any value for it is fine (set to 1 here).
        denom[denom < eps] = 1
        samples = bins_g[..., 0] + (u - cdf_g[..., 0]) / denom * (bins_g[..., 1] - bins_g[..., 0])
        return samples

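    # Inverse-CDF sampling sketch: with bins=[[0., 1., 2.]] and
    # weights=[[0.2, 0.8]], roughly 80% of the samples drawn by sample_pdf fall
    # in the second bin (1..2); with det=True they are evenly spaced quantiles.
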
    def normal_forward(self, planes, decoder, ray_origins, ray_directions, rendering_options):
        max_batch = 100000  # chunk size: bounds peak memory when querying the decoder
        self.plane_axes = self.plane_axes.to(ray_origins.device)

        if rendering_options['ray_start'] == rendering_options['ray_end'] == 'auto':
            ray_start, ray_end = math_utils.get_ray_limits_box(ray_origins, ray_directions, box_side_length=rendering_options['box_warp'])
            is_ray_valid = ray_end > ray_start
            if torch.any(is_ray_valid).item():
                ray_start[~is_ray_valid] = ray_start[is_ray_valid].min()
                ray_end[~is_ray_valid] = ray_start[is_ray_valid].max()
            depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'],
                                                   rendering_options['disparity_space_sampling'])
        else:
            # Create stratified depth samples
            depths_coarse = self.sample_stratified(ray_origins, rendering_options['ray_start'], rendering_options['ray_end'],
                                                   rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'])

        batch_size, num_rays, samples_per_ray, _ = depths_coarse.shape

        # Coarse Pass
        sample_coordinates = (ray_origins.unsqueeze(-2) + depths_coarse * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3)
        sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3)
        sample_coordinates.requires_grad_()

        with torch.set_grad_enabled(True):
            colors = torch.zeros((sample_coordinates.shape[0], sample_coordinates.shape[1], 32), device=sample_coordinates.device)
            sigmas = torch.zeros((sample_coordinates.shape[0], sample_coordinates.shape[1], 1), device=sample_coordinates.device)
            head = 0
            while head < sample_coordinates.shape[1]:
                out = self.run_model(planes, decoder, sample_coordinates[:, head:head + max_batch],
                                     sample_directions[:, head:head + max_batch], rendering_options)
                colors[:, head:head + max_batch] = out['rgb']
                sigmas[:, head:head + max_batch] = out['sigma']
                head += max_batch
            colors_coarse = colors
            densities_coarse = sigmas
            # Normals are approximated as the negative gradient of density w.r.t. position.
            input_grad = torch.autograd.grad(torch.sum(densities_coarse), sample_coordinates, create_graph=False)[0]
        normal = -input_grad
        normals_coarse = normal.reshape(batch_size, num_rays, samples_per_ray, normal.shape[-1])

        colors_coarse = colors_coarse.reshape(batch_size, num_rays, samples_per_ray, colors_coarse.shape[-1])
        densities_coarse = densities_coarse.reshape(batch_size, num_rays, samples_per_ray, 1)

        # Fine Pass
        N_importance = rendering_options['depth_resolution_importance']
        if N_importance > 0:
            _, _, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options)

            depths_fine = self.sample_importance(depths_coarse, weights, N_importance)

            sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, N_importance, -1).reshape(batch_size, -1, 3)
            sample_coordinates = (ray_origins.unsqueeze(-2) + depths_fine * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3)
            sample_coordinates.requires_grad_()

            with torch.set_grad_enabled(True):
                colors = torch.zeros((sample_coordinates.shape[0], sample_coordinates.shape[1], 32), device=sample_coordinates.device)
                sigmas = torch.zeros((sample_coordinates.shape[0], sample_coordinates.shape[1], 1), device=sample_coordinates.device)
                head = 0
                while head < sample_coordinates.shape[1]:
                    out = self.run_model(planes, decoder, sample_coordinates[:, head:head + max_batch],
                                         sample_directions[:, head:head + max_batch], rendering_options)
                    colors[:, head:head + max_batch] = out['rgb']
                    sigmas[:, head:head + max_batch] = out['sigma']
                    head += max_batch
                colors_fine = colors
                densities_fine = sigmas
                input_grad = torch.autograd.grad(torch.sum(densities_fine), sample_coordinates, create_graph=False)[0]
            normal = -input_grad
            normals_fine = normal.reshape(batch_size, num_rays, N_importance, normal.shape[-1])

            colors_fine = colors_fine.reshape(batch_size, num_rays, N_importance, colors_fine.shape[-1])
            densities_fine = densities_fine.reshape(batch_size, num_rays, N_importance, 1)

            all_depths, all_colors, all_normals, all_densities = self.unify_samples(depths_coarse, colors_coarse, densities_coarse,
                                                                                    depths_fine, colors_fine, densities_fine,
                                                                                    normals_coarse, normals_fine)

            # Aggregate
            rgb_final, depth_final, normal_final, weights = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options, all_normals)
        else:
            rgb_final, depth_final, normal_final, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse,
                                                                             rendering_options, normals_coarse)

        return rgb_final, depth_final, normal_final, weights.sum(2)

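    # Note: unlike forward(), normal_forward() passes per-sample normals to the
    # ray marcher and unpacks four outputs, so it assumes a marcher variant
    # that composites normals alongside color and depth.
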
from torch_utils import misc


def dict2obj(d):
    # if isinstance(d, list):
    #     d = [dict2obj(x) for x in d]
    if not isinstance(d, dict):
        return d

    class C(object):
        pass

    o = C()
    for k in d:
        o.__dict__[k] = dict2obj(d[k])
    return o

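# dict2obj sketch: converts nested dicts to attribute access, e.g.
#   cfg = dict2obj({'raster': {'image_size': 224}})
#   cfg.raster.image_size  # -> 224
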
from torch_utils import persistence


class Pytorch3dRasterizer(nn.Module):
    ## TODO: add support for rendering non-square images, since pytorch3d supports this now
    """ Borrowed from https://github.com/facebookresearch/pytorch3d
    Notice:
        x, y, z are in image space, normalized
        can only render square images for now
    """

    def __init__(self, image_size=224):
        """
        use fixed raster_settings for rendering faces
        """
        super().__init__()
        raster_settings = {
            'image_size': image_size,
            'blur_radius': 0.0,
            'faces_per_pixel': 1,
            'bin_size': None,
            'max_faces_per_bin': None,
            'perspective_correct': False,
            'cull_backfaces': True
        }
        # raster_settings = dict2obj(raster_settings)
        self.raster_settings = raster_settings

    def forward(self, vertices, faces, attributes=None, h=None, w=None):
        fixed_vertices = vertices.clone()
        fixed_vertices[..., :2] = -fixed_vertices[..., :2]
        raster_settings = self.raster_settings
        if h is None and w is None:
            image_size = raster_settings['image_size']
        else:
            # Both h and w must be given; rescale the longer side so the
            # normalized coordinates stay isotropic.
            image_size = [h, w]
            if h > w:
                fixed_vertices[..., 1] = fixed_vertices[..., 1] * h / w
            else:
                fixed_vertices[..., 0] = fixed_vertices[..., 0] * w / h

        meshes_screen = Meshes(verts=fixed_vertices.float(), faces=faces.long())
        pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
            meshes_screen,
            image_size=image_size,
            blur_radius=raster_settings['blur_radius'],
            faces_per_pixel=raster_settings['faces_per_pixel'],
            bin_size=raster_settings['bin_size'],
            max_faces_per_bin=raster_settings['max_faces_per_bin'],
            perspective_correct=raster_settings['perspective_correct'],
            cull_backfaces=raster_settings['cull_backfaces']
        )
        vismask = (pix_to_face > -1).float()
        D = attributes.shape[-1]
        attributes = attributes.clone()
        attributes = attributes.view(attributes.shape[0] * attributes.shape[1], 3, attributes.shape[-1])
        N, H, W, K, _ = bary_coords.shape
        mask = pix_to_face == -1
        pix_to_face = pix_to_face.clone()
        pix_to_face[mask] = 0
        idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D)
        pixel_face_vals = attributes.gather(0, idx).view(N, H, W, K, 3, D)
        pixel_vals = (bary_coords[..., None] * pixel_face_vals).sum(dim=-2)
        pixel_vals[mask] = 0  # Replace masked values in output.
        pixel_vals = pixel_vals[:, :, :, 0].permute(0, 3, 1, 2)
        pixel_vals = torch.cat([pixel_vals, vismask[:, :, :, 0][:, None, :, :]], dim=1)
        return pixel_vals

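# Rasterizer usage sketch (hypothetical tensors; attributes are per-face
# per-vertex values of any channel count D, built e.g. with face_vertices below):
#   rast = Pytorch3dRasterizer(image_size=256)
#   attrs = face_vertices(vertex_colors, faces)     # (bs, nf, 3, D)
#   out = rast(verts_ndc, faces, attributes=attrs)  # (bs, D + 1, 256, 256)
#   image, alpha = out[:, :-1], out[:, -1:]         # last channel = visibility
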
# borrowed from https://github.com/daniilidis-group/neural_renderer/blob/master/neural_renderer/vertices_to_faces.py
def face_vertices(vertices, faces):
    """
    :param vertices: [batch size, number of vertices, 3]
    :param faces: [batch size, number of faces, 3]
    :return: [batch size, number of faces, 3, 3]
    """
    assert (vertices.ndimension() == 3)
    assert (faces.ndimension() == 3)
    assert (vertices.shape[0] == faces.shape[0])
    assert (vertices.shape[2] == 3)
    assert (faces.shape[2] == 3)

    bs, nv = vertices.shape[:2]
    bs, nf = faces.shape[:2]
    device = vertices.device
    faces = faces + (torch.arange(bs, dtype=torch.int32).to(device) * nv)[:, None, None]
    vertices = vertices.reshape((bs * nv, 3))
    # pytorch only supports long and byte tensors for indexing
    return vertices[faces.long()]

# ---------------------------- process/generate vertices, normals, faces
def generate_triangles(h, w, margin_x=2, margin_y=5, mask=None):
    # quad layout of vertex indices:
    #   0         1         ...   w-1
    #   w         w+1       ...   2*w-1
    #   ...
    #   (h-1)*w   ...             h*w-1
    triangles = []
    for x in range(margin_x, w - 1 - margin_x):
        for y in range(margin_y, h - 1 - margin_y):
            triangle0 = [y * w + x, y * w + x + 1, (y + 1) * w + x]
            triangle1 = [y * w + x + 1, (y + 1) * w + x + 1, (y + 1) * w + x]
            triangles.append(triangle0)
            triangles.append(triangle1)
    triangles = np.array(triangles)
    triangles = triangles[:, [0, 2, 1]]  # flip winding order
    return triangles

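# Example: generate_triangles(4, 4, margin_x=0, margin_y=0) triangulates a
# 4x4 vertex grid into 2 * 3 * 3 = 18 triangles (two per quad, winding flipped).
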
def transform_points(points, tform, points_scale=None, out_scale=None):
    points_2d = points[:, :, :2]

    # input points must use the original range
    if points_scale:
        assert points_scale[0] == points_scale[1]
        points_2d = (points_2d * 0.5 + 0.5) * points_scale[0]

    batch_size, n_points, _ = points.shape
    trans_points_2d = torch.bmm(
        torch.cat([points_2d, torch.ones([batch_size, n_points, 1], device=points.device, dtype=points.dtype)], dim=-1),
        tform
    )
    if out_scale:  # h, w of output image size
        trans_points_2d[:, :, 0] = trans_points_2d[:, :, 0] / out_scale[1] * 2 - 1
        trans_points_2d[:, :, 1] = trans_points_2d[:, :, 1] / out_scale[0] * 2 - 1
    trans_points = torch.cat([trans_points_2d[:, :, :2], points[:, :, 2:]], dim=-1)
    return trans_points

def batch_orth_proj(X, camera):
    ''' orthographic projection
        X: 3d vertices, [bz, n_point, 3]
        camera: scale and translation, [bz, 3], [scale, tx, ty]
    '''
    camera = camera.clone().view(-1, 1, 3)
    X_trans = X[:, :, :2] + camera[:, :, 1:]
    X_trans = torch.cat([X_trans, X[:, :, 2:]], 2)
    Xn = (camera[:, :, 0:1] * X_trans)
    return Xn

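# Orthographic projection sketch: camera = [s, tx, ty] maps (x, y, z) to
# (s * (x + tx), s * (y + ty), s * z); e.g. camera=[[2., 0.1, 0.]] shifts x by
# 0.1, then doubles all coordinates.
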
def angle2matrix(angles):
    ''' get rotation matrix from three rotation angles (degrees). right-handed.
    Args:
        angles: [batch_size, 3] tensor containing X, Y, and Z angles.
            x: pitch. positive for looking down.
            y: yaw. positive for looking left.
            z: roll. positive for tilting head right.
    Returns:
        R: [batch_size, 3, 3]. rotation matrices.
    '''
    angles = angles * (np.pi) / 180.
    s = torch.sin(angles)
    c = torch.cos(angles)

    cx, cy, cz = (c[:, 0], c[:, 1], c[:, 2])
    sx, sy, sz = (s[:, 0], s[:, 1], s[:, 2])

    zeros = torch.zeros_like(s[:, 0]).to(angles.device)
    ones = torch.ones_like(s[:, 0]).to(angles.device)

    # Rz.dot(Ry.dot(Rx))
    # Stack along the last dim so each row of R_flattened is one 3x3 matrix
    # (stacking along dim=0 would give [9, batch_size] and break the reshape).
    R_flattened = torch.stack(
        [
            cz * cy, cz * sy * sx - sz * cx, cz * sy * cx + sz * sx,
            sz * cy, sz * sy * sx + cz * cx, sz * sy * cx - cz * sx,
            -sy, cy * sx, cy * cx,
        ],
        dim=-1)  # [batch_size, 9]
    R = torch.reshape(R_flattened, (-1, 3, 3))  # [batch_size, 3, 3]
    return R

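# Example: angle2matrix(torch.zeros(1, 3)) yields the 3x3 identity; in general
# R = Rz(roll) @ Ry(yaw) @ Rx(pitch), with angles supplied in degrees.
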
import cv2


# end_list = np.array([17, 22, 27, 42, 48, 31, 36, 68], dtype=np.int32) - 1
def plot_kpts(image, kpts, color='r', end_list=[19]):
    ''' Draw key points
    Args:
        image: the input image
        kpts: (68, 3).
    '''
    if color == 'r':
        c = (255, 0, 0)
    elif color == 'g':
        c = (0, 255, 0)
    elif color == 'b':
        c = (0, 0, 255)
    image = image.copy()
    kpts = kpts.copy()
    radius = max(int(min(image.shape[0], image.shape[1]) / 200), 1)
    for i in range(kpts.shape[0]):
        st = kpts[i, :2]
        if kpts.shape[1] == 4:
            if kpts[i, 3] > 0.5:
                c = (0, 255, 0)
            else:
                c = (0, 0, 255)
        image = cv2.circle(image, (int(st[0]), int(st[1])), radius, c, radius * 2)
        # end_list marks the last point of each contour; don't connect it to
        # the next contour (this also guards kpts[i + 1] at the final point).
        if i in end_list:
            continue
        ed = kpts[i + 1, :2]
        image = cv2.line(image, (int(st[0]), int(st[1])), (int(ed[0]), int(ed[1])), (255, 255, 255), radius)
    return image

def fill_mouth(images):
    # Input: images: [batch, 1, h, w]
    device = images.device
    mouth_masks = []
    for image in images:
        image = image[0].cpu().numpy()
        image = image * 255.
        copyImg = image.copy()
        h, w = image.shape[:2]
        mask = np.zeros([h + 2, w + 2], np.uint8)
        # Flood-fill from the top-left corner: everything reachable from the
        # border is filled, so enclosed holes (e.g. the mouth cavity) survive.
        cv2.floodFill(copyImg, mask, (0, 0), (255, 255, 255), (0, 0, 0), (254, 254, 254), cv2.FLOODFILL_FIXED_RANGE)
        copyImg = torch.tensor(copyImg).to(device).to(torch.float32) / 127.5 - 1
        mouth_masks.append(copyImg.unsqueeze(0))
    mouth_masks = torch.stack(mouth_masks, 0)
    mouth_masks = ((mouth_masks * 2. - 1.) * -1. + 1.) / 2.  # equivalent to (1 - mouth_masks): filled regions -> 0
    # images = (images.bool() | mouth_masks.bool()).float()
    res = (images + mouth_masks).clip(0, 1)
    return res

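# fill_mouth sketch: for a binary face mask `images` of shape [B, 1, H, W] in
# {0, 1}, the flood fill marks everything reachable from the border, so the
# enclosed mouth interior (a hole in the mask) is added back before clipping.
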