# yuandong513
# feat: init
# 17cd746
# raw
# history blame
# 13.1 kB
import torch
import numpy as np
import math
import torch.nn as nn
from pytorch3d.structures import Meshes
from pytorch3d.io import load_obj
from pytorch3d.renderer.mesh import rasterize_meshes
from pytorch3d.ops import mesh_face_areas_normals
#-------------------------------------------------------------------------------#
def gen_tritex(vt: np.ndarray, vi: np.ndarray, vti: np.ndarray, texsize: int):
    """
    Copied from MVP.

    Rasterize the UV layout of a triangle mesh into three texture-space maps
    holding, for every texel, the vertex indices, the texture-vertex indices
    and the barycentric coordinates of the triangle that covers it.

    Parameters
    ----------
    vt: uv coordinates of texels (only the first two columns are used)
    vi: triangle list mapping into vertex positions
    vti: triangle list mapping into texel coordinates
    texsize: Size of the generated maps

    Returns
    -------
    Tuple ``(viim, vtiim, baryim)`` of tensors with shape
    ``(texsize, texsize, 3)``; ``viim`` is a ``torch.LongTensor``, the other
    two are built with ``torch.Tensor``. Texels not covered by any triangle
    stay zero in all three maps.
    """
    vt = np.array(vt[:, :2], dtype=np.float32)
    vi = np.array(vi, dtype=np.int32)
    vti = np.array(vti, dtype=np.int32)
    ntris = vi.shape[0]

    # UV position of every texel centre.
    texel_centers = (np.arange(texsize) + 0.5) / texsize
    texu, texv = np.meshgrid(texel_centers, texel_centers)
    texuv = np.stack((texu, texv), axis=-1)

    # Per-triangle corner UVs, indexed as [triangle, corner, uv].
    tri_uv = vt[vti]

    viim = np.zeros((texsize, texsize, 3), dtype=np.int32)
    vtiim = np.zeros((texsize, texsize, 3), dtype=np.int32)
    baryim = np.zeros((texsize, texsize, 3), dtype=np.float32)

    # Walk the triangles back to front: where triangles overlap in UV space
    # the one with the lowest index is written last and therefore wins.
    for i in reversed(range(ntris)):
        corners = tri_uv[i]
        u_lo, v_lo = corners.min(axis=0)
        u_hi, v_hi = corners.max(axis=0)

        # Texel bounding box of the triangle, padded and clamped to the map.
        x0 = max(0, int(u_lo * texsize) - 1)
        x1 = min(texsize, int(u_hi * texsize) + 2)
        y0 = max(0, int(v_lo * texsize) - 1)
        y1 = min(texsize, int(v_hi * texsize) + 2)
        region = (slice(y0, y1), slice(x0, x1))

        origin = corners[0]
        e0 = corners[1] - origin
        e1 = corners[2] - origin
        rel = texuv[region] - origin

        d00 = np.sum(e0 * e0, axis=-1)
        d01 = np.sum(e0 * e1, axis=-1)
        d11 = np.sum(e1 * e1, axis=-1)
        d20 = np.sum(rel * e0, axis=-1)
        d21 = np.sum(rel * e1, axis=-1)

        denom = d00 * d11 - d01 * d01
        if denom == 0.:
            # Degenerate (zero-area) triangle in UV space: nothing to write.
            continue

        baryv = (d11 * d20 - d01 * d21) / denom
        baryw = (d00 * d21 - d01 * d20) / denom
        baryu = 1. - baryv - baryw

        # A texel is covered when all three barycentric weights are >= 0.
        inside = ((baryu >= 0.) & (baryv >= 0.) & (baryw >= 0.))[:, :, None]

        baryim[region] = np.where(
            inside, np.stack((baryu, baryv, baryw), axis=-1), baryim[region])
        viim[region] = np.where(inside, vi[i], viim[region])
        vtiim[region] = np.where(inside, vti[i], vtiim[region])

    return torch.LongTensor(viim), torch.Tensor(vtiim), torch.Tensor(baryim)
# modified from https://github.com/facebookresearch/pytorch3d
class Pytorch3dRasterizer(nn.Module):
    """
    Thin wrapper around PyTorch3D's ``rasterize_meshes`` that always uses the
    same fixed raster settings and returns only the face index and
    barycentric coordinate of each pixel.
    """

    def __init__(self, image_size=224):
        """
        use fixed raster_settings for rendering faces

        Args:
            image_size: output resolution used when ``forward`` is called
                without an explicit ``h`` / ``w``.
        """
        super().__init__()
        self.raster_settings = {
            'image_size': image_size,
            'blur_radius': 0.0,
            'faces_per_pixel': 1,
            'bin_size': None,
            'max_faces_per_bin': None,
            'perspective_correct': False,
            'cull_backfaces': True,
        }

    def forward(self, vertices, faces, h=None, w=None):
        """
        Rasterize a batch of meshes.

        Args:
            vertices: vertex positions; the input is cloned, not modified.
            faces: vertex indices of each face (cast to long).
            h, w: optional output height / width. Give both or neither; when
                omitted the ``image_size`` passed to the constructor is used.

        Returns:
            ``(pix_to_face, bary_coords)`` exactly as produced by
            ``rasterize_meshes``; the depth buffer and distances are dropped.

        Raises:
            ValueError: if only one of ``h`` and ``w`` is given.
        """
        # Reject a half-specified size up front; otherwise it only fails
        # further down when ``h > w`` compares a number against None.
        if (h is None) != (w is None):
            raise ValueError(
                "h and w must be given together or both left as None "
                f"(got h={h!r}, w={w!r})")

        fixed_vertices = vertices.clone()
        # Negate x and y (presumably to match PyTorch3D's screen axis
        # convention - TODO confirm).
        fixed_vertices[..., :2] = -fixed_vertices[..., :2]

        settings = self.raster_settings
        if h is None:
            image_size = settings['image_size']
        else:
            image_size = [h, w]
            # Stretch the longer image axis by the aspect ratio for
            # non-square outputs.
            if h > w:
                fixed_vertices[..., 1] = fixed_vertices[..., 1] * h / w
            else:
                fixed_vertices[..., 0] = fixed_vertices[..., 0] * w / h

        meshes_screen = Meshes(verts=fixed_vertices.float(), faces=faces.long())
        pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
            meshes_screen,
            image_size=image_size,
            blur_radius=settings['blur_radius'],
            faces_per_pixel=settings['faces_per_pixel'],
            bin_size=settings['bin_size'],
            max_faces_per_bin=settings['max_faces_per_bin'],
            perspective_correct=settings['perspective_correct'],
            cull_backfaces=settings['cull_backfaces'],
        )
        return pix_to_face, bary_coords
#-------------------------------------------------------------------------------#
# borrowed from https://github.com/daniilidis-group/neural_renderer/blob/master/neural_renderer/vertices_to_faces.py
def face_vertices(vertices, faces):
    """
    Indexing the coordinates of the three vertices on each face.

    Args:
        vertices: [bs, V, 3]
        faces: [bs, F, 3] vertex indices local to each batch element; a
            leading dimension of 1 is broadcast over the vertex batch.
    Return:
        face_to_vertices: [bs, F, 3, 3]
    """
    assert (vertices.ndimension() == 3)
    assert (faces.ndimension() == 3)
    assert (vertices.shape[2] == 3)
    assert (faces.shape[2] == 3)

    # The batch size comes from the vertices: they are what gets flattened
    # below, so the per-batch index offsets must follow their layout.
    bs, nv = vertices.shape[:2]

    # Shift each batch element's indices into its slice of the flat vertex
    # list (element b starts at row b * nv).
    batch_offset = torch.arange(bs, dtype=torch.long, device=vertices.device) * nv
    flat_faces = faces.long() + batch_offset[:, None, None]

    flat_vertices = vertices.reshape((bs * nv, 3))
    # pytorch only supports long and byte tensors for indexing
    return flat_vertices[flat_faces]
def uniform_sampling_barycoords(
    num_points: int,
    tex_coord: torch.Tensor,
    uv_faces: torch.Tensor,
    d_size: float=1.0,
    strict: bool=False,
    use_mask: bool=True,
):
    """
    Sample barycentric coordinates on a regular grid by rasterizing the UV
    layout: every covered pixel of the raster yields one sample.

    Args:
        num_points: requested number of samples; the raster side length is
            int(sqrt(num_points) * d_size)
        tex_coord: [V', 2] UV coords for each UV vertex
        uv_faces: [F, 3] UV faces indexing into tex_coord
        d_size: multiplier on the raster side length
        strict: if True, pad (with random samples) or truncate the result so
            that exactly num_points samples are returned
        use_mask: if True keep only the covered pixels; if False return the
            raw rasterizer output (strict is then ignored)
    Returns:
        face_index [n] the face each sample lies on
        bary_coords [n, 3]
        (or the unfiltered pix_to_face / bary_coords when use_mask is False)
    """
    raster_res = int(math.sqrt(num_points) * d_size)
    rasterizer = Pytorch3dRasterizer(raster_res)

    # Add a batch dimension of 1 to both inputs.
    uv = tex_coord[None, ...]
    tris = uv_faces[None, ...]

    # Append a constant third coordinate, map everything from [0, 1] to
    # [-1, 1], then flip the v axis.
    third = uv[:, :, 0:1] * 0. + 1.
    uv_3d = torch.cat([uv, third], -1)
    uv_3d = uv_3d * 2 - 1
    uv_3d[..., 1] = -uv_3d[..., 1]

    pix_to_face, bary_coords = rasterizer(
        uv_3d.expand(1, -1, -1), tris.expand(1, -1, -1))

    if not use_mask:
        return pix_to_face, bary_coords

    # -1 marks pixels that no face covers.
    covered = ~(pix_to_face == -1)
    face_index = pix_to_face[covered]
    bary_coords = bary_coords[covered]

    if not strict:
        return face_index, bary_coords

    # fix sampling number to num_points
    cur_n = face_index.shape[0]
    if cur_n < num_points:
        # Too few: reuse randomly chosen faces with fresh random weights
        # normalised to sum to one.
        pad_size = num_points - cur_n
        extra_faces = face_index[torch.randint(0, cur_n, (pad_size,))]
        extra_bary = torch.rand((pad_size, 3), device=bary_coords.device)
        extra_bary = extra_bary / extra_bary.sum(dim=-1, keepdim=True)
        face_index = torch.cat([face_index, extra_faces], dim=0)
        bary_coords = torch.cat([bary_coords, extra_bary], dim=0)
    elif cur_n > num_points:
        # Too many: keep the first num_points samples.
        face_index = face_index[:num_points]
        bary_coords = bary_coords[:num_points]
    return face_index, bary_coords
def random_sampling_barycoords(
    num_points: int,
    vertices: torch.Tensor,
    faces: torch.Tensor
):
    """
    Randomly sample barycentric coordinates on a mesh, picking faces with
    probability proportional to their area.

    Args:
        num_points: int sampling points number
        vertices: [V, 3] (a leading batch dimension of size 1 is squeezed)
        faces: [F, 3]
    Returns:
        face_index [num_points] save which face each bary_coords belongs to
        bary_coords [num_points, 3] rows normalised to sum to one
    """
    face_areas, _ = mesh_face_areas_normals(vertices.squeeze(0), faces)

    # Fixed seed: the selected faces are the same on every call.
    rng = torch.Generator(device=vertices.device)
    rng.manual_seed(0)
    face_index = torch.multinomial(
        face_areas, num_points, replacement=True, generator=rng
    )  # (num_points,)

    # NOTE(review): these weights come from the global RNG, not the seeded
    # generator above, so they differ between calls - confirm this is intended.
    raw_weights = torch.rand((face_index.shape[0], 3), device=vertices.device)
    bary_coords = raw_weights / raw_weights.sum(dim=-1, keepdim=True)
    return face_index, bary_coords
def reweight_verts_by_barycoords(
    verts: torch.Tensor,
    faces: torch.Tensor,
    face_index: torch.Tensor,
    bary_coords: torch.Tensor,
):
    """
    Interpolate vertex positions at sampled surface points: each sample is
    the barycentric-weighted sum of the three corners of its face.

    Args:
        verts: [bs, V, 3].
        faces: [F, 3]
        face_index: [N].
        bary_coords: [N, 3].
    Returns:
        Reweighted vertex positions of shape [bs, N, 3].
    """
    batch = verts.shape[0]
    num_samples = face_index.shape[0]

    # Corner positions of every face: [bs, F, 3 (corner), 3 (xyz)].
    corners = face_vertices(verts, faces.expand(batch, -1, -1))

    # Pick the face of each sample along the face axis.
    gather_index = face_index.view(1, num_samples, 1, 1).expand(
        batch, num_samples, 3, 3)
    sample_corners = torch.gather(corners, 1, gather_index)

    # Weight each corner by its barycentric coordinate and sum over corners.
    weights = bary_coords[..., None]
    return (weights * sample_corners).sum(dim=-2)
def reweight_uvcoords_by_barycoords(
    uvcoords: torch.Tensor,
    uvfaces: torch.Tensor,
    face_index: torch.Tensor,
    bary_coords: torch.Tensor,
):
    """
    Interpolate UV coordinates at sampled surface points: each sample is the
    barycentric-weighted sum of the three UV corners of its face.

    Args:
        uvcoords: [V', 2] (unbatched; a column of ones is appended and a
            batch dimension of 1 is added internally).
        uvfaces: [F, 3].
        face_index: [N].
        bary_coords: [N, 3].
    Returns:
        Reweighted homogeneous UV coordinates, shape [1, N, 3]; the last
        channel is the interpolated column of ones.
    """
    # homogeneous coordinates
    vertex_count = uvcoords.shape[0]
    ones_column = torch.ones((vertex_count, 1)).to(uvcoords.device)
    uv_h = torch.cat([uvcoords, ones_column], dim=1)[None, ...]

    # Corner UVs of every face: [1, F, 3 (corner), 3 (u, v, 1)].
    corners = face_vertices(uv_h, uvfaces.expand(1, -1, -1))

    # Pick the face of each sample along the face axis.
    num_samples = face_index.shape[0]
    gather_index = face_index.view(1, num_samples, 1, 1).expand(
        1, num_samples, 3, 3)
    sample_corners = torch.gather(corners, 1, gather_index)

    # Weight each corner by its barycentric coordinate and sum over corners.
    weights = bary_coords[..., None]
    return (weights * sample_corners).sum(dim=-2)
# modified from https://github.com/computational-imaging/GSM/blob/main/main/gsm/deformer/util.py
def get_shell_verts_from_base(
    template_verts: torch.Tensor,
    template_faces: torch.Tensor,
    offset_len: float,
    num_shells: int,
    deflat = False,
):
    """
    Build a stack of shell meshes by moving every template vertex along its
    vertex normal by a set of linearly spaced offsets.

    Args:
        template_verts: [bs, V, 3].
        template_faces: [F, 3].
        offset_len: Positive number specifying the offset length for generating shells.
        num_shells: The number of shells to generate.
        deflat: If False, all shells go from +offset_len down to 0. If True,
            num_shells // 2 + 1 shells go from +offset_len to 0 and the
            remaining ones continue from 0 (exclusive) to -offset_len.
    Returns:
        shell verts: [bs, num_shells, V, 3], ordered from outermost inwards.
    """
    tensor_opts = dict(device=template_verts.device, dtype=template_verts.dtype)

    batch_size = template_verts.shape[0]
    mesh = Meshes(
        verts=template_verts,
        faces=template_faces[None].repeat(batch_size, 1, 1),
    )
    # [bs, 1, V, 3] so the normals broadcast over the shell axis.
    normals = mesh.verts_normals_padded()[:, None]

    # Shells on the outer side, including the zero-offset base surface.
    n_outer = num_shells // 2 + 1 if deflat else num_shells
    outer_scale = torch.linspace(offset_len, 0, n_outer, **tensor_opts)
    offset = outer_scale.reshape(1, n_outer, 1, 1) * normals

    if deflat:
        # Inner shells: drop the leading 0 because the base surface is
        # already part of the outer set.
        inner_scale = torch.linspace(
            0, -offset_len, num_shells - n_outer + 1, **tensor_opts)[1:]
        inner_offset = inner_scale.reshape(1, -1, 1, 1) * normals
        offset = torch.cat([offset, inner_offset], dim=1)

    verts = template_verts[:, None] + offset
    assert verts.isfinite().all()
    return verts