|
import torch |
|
import pytorch3d |
|
import torch.nn.functional as F |
|
|
|
from pytorch3d.ops import interpolate_face_attributes |
|
|
|
from pytorch3d.renderer import ( |
|
look_at_view_transform, |
|
FoVPerspectiveCameras, |
|
AmbientLights, |
|
PointLights, |
|
DirectionalLights, |
|
Materials, |
|
RasterizationSettings, |
|
MeshRenderer, |
|
MeshRasterizer, |
|
SoftPhongShader, |
|
SoftSilhouetteShader, |
|
HardPhongShader, |
|
TexturesVertex, |
|
TexturesUV, |
|
Materials, |
|
) |
|
from pytorch3d.renderer.blending import BlendParams, hard_rgb_blend |
|
from pytorch3d.renderer.utils import convert_to_tensors_and_broadcast, TensorProperties |
|
from pytorch3d.renderer.mesh.shader import ShaderBase |
|
|
|
|
|
def get_cos_angle(points, normals, camera_position):
    """
    Return the cosine between the surface->camera direction and the normal.

    Args:
        points: tensor of surface positions with the batch dimension first and
            xyz last (e.g. (N, H, W, K, 3) for interpolated pixel positions).
        normals: tensor with exactly the same shape as ``points``. It does not
            need to be unit length; it is L2-normalized here.
        camera_position: camera location(s) in the same coordinate space as
            ``points``; anything ``convert_to_tensors_and_broadcast`` accepts
            (e.g. an (N, 3) tensor).

    Returns:
        Tensor shaped like ``points`` but with a trailing dimension of 1,
        holding the cosine clamped to [0, 1]. Surfaces facing away from the
        camera therefore get 0.

    Raises:
        ValueError: if ``points`` and ``normals`` differ in shape.
    """
    if normals.shape != points.shape:
        raise ValueError(
            "Expected points and normals to have the same shape: got %r, %r"
            % (points.shape, normals.shape)
        )

    # Give the camera tensor the same batch dimension as the points.
    camera_position = convert_to_tensors_and_broadcast(
        points, camera_position, device=points.device
    )[1]

    # Add a singleton axis for every dimension between batch and xyz so the
    # subtraction below broadcasts over H, W, K, ...
    if camera_position.shape != normals.shape:
        n_middle = len(points.shape[1:-1])
        camera_position = camera_position.view((-1,) + (1,) * n_middle + (3,))

    unit_normals = F.normalize(normals, p=2, dim=-1, eps=1e-6)
    to_camera = F.normalize(camera_position - points, p=2, dim=-1, eps=1e-6)
    cosine = (to_camera * unit_normals).sum(dim=-1, keepdim=True)
    return cosine.clamp(0, 1)
|
|
|
|
|
def _geometry_shading_with_pixels(
    meshes, fragments, lights, cameras, materials, texels
):
    """
    Interpolate per-pixel geometry buffers for a rasterized batch of meshes.

    Produces the surface position, the surface normal, the depth, and the
    cosine between the view direction and the normal. Computing the cosine is
    delegated to ``get_cos_angle``.

    Args:
        meshes: Batch of meshes. Positions and normals are read from
            ``verts_packed()`` / ``verts_normals_packed()``, i.e. in the
            meshes' own coordinate space. This is presumably world space when
            the shader is driven by a standard ``MeshRenderer``; confirm for
            other callers.
        fragments: Fragments named tuple with the outputs of rasterization.
            ``pix_to_face``, ``bary_coords`` and ``zbuf`` are used.
        lights: Unused; accepted for signature parity with PyTorch3D's
            shading helpers.
        cameras: Cameras class containing a batch of cameras. Only
            ``get_camera_center()`` is used; it must be in the same space as
            the mesh vertices.
        materials: Unused; accepted for signature parity.
        texels: Unused; accepted for signature parity.

    Returns:
        A 4-tuple ``(pixel_coords, pixel_normals, depths, cos_angles)``:
            pixel_coords: (N, H, W, K, 3) barycentric interpolation of the
                vertex positions for each of the K faces per pixel.
            pixel_normals: (N, H, W, K, 3) interpolated vertex normals. They
                are NOT re-normalized, so they may be shorter than unit length.
            depths: (N, H, W, K, 1) ``fragments.zbuf`` with a trailing axis.
            cos_angles: (N, H, W, K, 1) view/normal cosine clamped to [0, 1].
    """
    verts = meshes.verts_packed()  # (V, 3)
    faces = meshes.faces_packed()  # (F, 3)
    vertex_normals = meshes.verts_normals_packed()  # (V, 3)
    faces_verts = verts[faces]  # (F, 3, 3)
    faces_normals = vertex_normals[faces]  # (F, 3, 3)

    pixel_coords = interpolate_face_attributes(
        fragments.pix_to_face, fragments.bary_coords, faces_verts
    )
    pixel_normals = interpolate_face_attributes(
        fragments.pix_to_face, fragments.bary_coords, faces_normals
    )

    cos_angles = get_cos_angle(
        pixel_coords, pixel_normals, cameras.get_camera_center()
    )

    depths = fragments.zbuf[..., None]  # (N, H, W, K) -> (N, H, W, K, 1)
    return pixel_coords, pixel_normals, depths, cos_angles
|
|
|
|
|
class HardGeometryShader(ShaderBase):
    """
    Shader that renders common per-pixel geometric information.

    ``forward`` hard-blends (``hard_rgb_blend``, front-most face only) images
    of surface position, surface normal, depth, view/normal cosine and sampled
    texture color. It returns them together with the rasterizer fragments.
    """

    def forward(self, fragments, meshes, **kwargs):
        """
        Render the geometry buffers for ``meshes``.

        Args:
            fragments: rasterizer output for ``meshes``.
            meshes: batch of meshes. Their textures must expose
                ``maps_padded()``, ``faces_uvs_padded()`` and
                ``verts_uvs_padded()`` (e.g. ``TexturesUV``) because
                ``texel_from_uv`` calls them.
            **kwargs: optional per-call overrides for ``cameras``, ``lights``,
                ``materials`` and ``blend_params``; otherwise the shader's own
                attributes are used.

        Returns:
            Tuple ``(verts, normals, depths, cos_angles, texels, fragments)``.
            The first five entries are ``hard_rgb_blend`` outputs. For
            3-channel inputs this is an (N, H, W, 4) image whose last channel
            is the foreground mask. The layout for the 1-channel
            depths/cos_angles depends on the installed PyTorch3D version;
            confirm. ``fragments`` is passed through unchanged.
        """
        cameras = super()._get_cameras(**kwargs)
        # NOTE(review): these UV-coordinate texels only reach the ``texels``
        # argument of _geometry_shading_with_pixels (which ignores it). The
        # returned texels are then resampled from the mesh's own textures
        # below. Confirm which texel image callers actually expect.
        uv_texels = self.texel_from_uv(fragments, meshes)

        lights = kwargs.get("lights", self.lights)
        materials = kwargs.get("materials", self.materials)
        blend_params = kwargs.get("blend_params", self.blend_params)

        verts, normals, depths, cos_angles = _geometry_shading_with_pixels(
            meshes=meshes,
            fragments=fragments,
            texels=uv_texels,
            lights=lights,
            cameras=cameras,
            materials=materials,
        )
        texels = meshes.sample_textures(fragments)

        verts = hard_rgb_blend(verts, fragments, blend_params)
        normals = hard_rgb_blend(normals, fragments, blend_params)
        depths = hard_rgb_blend(depths, fragments, blend_params)
        cos_angles = hard_rgb_blend(cos_angles, fragments, blend_params)
        texels = hard_rgb_blend(texels, fragments, blend_params)
        return verts, normals, depths, cos_angles, texels, fragments

    def texel_from_uv(self, fragments, meshes):
        """
        Sample a per-pixel texture-coordinate image for ``meshes``.

        The mesh's textures are temporarily replaced by a 2x2 ``TexturesUV``
        map. One channel of the map changes along rows and the other along
        columns, so bilinear sampling presumably yields the interpolated
        texture coordinates at each pixel. Which channel is u and which is v
        depends on PyTorch3D's map orientation; confirm. The caller's textures
        are always restored, even if sampling raises.

        Args:
            fragments: rasterizer output for ``meshes``.
            meshes: batch of meshes whose textures expose ``maps_padded()``,
                ``faces_uvs_padded()`` and ``verts_uvs_padded()``.

        Returns:
            The sampled 2-channel texels with an all-zero channel appended.
            The result has 3 channels so it can be passed to
            ``hard_rgb_blend``.
        """
        original_textures = meshes.textures
        maps = original_textures.maps_padded()  # (N, H, W, C) tensor
        # Map indexed (row, column, channel):
        # channel 0 is 1 in row 0 and 0 in row 1;
        # channel 1 is 0 in column 0 and 1 in column 1.
        uv_map = torch.tensor(
            [[[1, 0], [1, 1]], [[0, 0], [0, 1]]],
            dtype=maps.dtype,
            device=maps.device,
        )
        uv_texture = TexturesUV(
            [uv_map.clone() for _ in range(maps.shape[0])],
            original_textures.faces_uvs_padded(),
            original_textures.verts_uvs_padded(),
            sampling_mode="bilinear",
        )

        meshes.textures = uv_texture
        try:
            texels = meshes.sample_textures(fragments)
        finally:
            # Never leave the caller's mesh carrying the temporary texture.
            meshes.textures = original_textures

        return torch.cat((texels, torch.zeros_like(texels[..., -1:])), dim=-1)
|
|