|
from typing import Optional |
|
|
|
import torch |
|
import pytorch3d |
|
|
|
|
|
from pytorch3d.io import load_objs_as_meshes, load_obj, save_obj |
|
from pytorch3d.ops import interpolate_face_attributes |
|
|
|
from pytorch3d.structures import Meshes |
|
from pytorch3d.renderer import ( |
|
look_at_view_transform, |
|
FoVPerspectiveCameras, |
|
AmbientLights, |
|
PointLights, |
|
DirectionalLights, |
|
Materials, |
|
RasterizationSettings, |
|
MeshRenderer, |
|
MeshRasterizer, |
|
SoftPhongShader, |
|
SoftSilhouetteShader, |
|
HardPhongShader, |
|
TexturesVertex, |
|
TexturesUV, |
|
Materials, |
|
) |
|
from pytorch3d.renderer.blending import BlendParams, hard_rgb_blend |
|
from pytorch3d.renderer.utils import convert_to_tensors_and_broadcast, TensorProperties |
|
|
|
from pytorch3d.renderer.lighting import AmbientLights |
|
from pytorch3d.renderer.materials import Materials |
|
from pytorch3d.renderer.mesh.shader import ShaderBase |
|
from pytorch3d.renderer.mesh.shading import _apply_lighting, flat_shading |
|
from pytorch3d.renderer.mesh.rasterizer import Fragments |
|
|
|
|
|
""" |
|
Customized version of PyTorch3D's HardFlatShader that supports N-channel (not just RGB) flat shading.
|
""" |
|
|
|
|
|
class HardNChannelFlatShader(ShaderBase):
    """
    Hard flat shader whose output has an arbitrary number of channels.

    Per face lighting - the lighting model is applied using the average face
    position and the face normal. The blending function hard assigns
    the color of the closest face for each pixel.

    Unlike PyTorch3D's ``HardFlatShader`` this shader is configured for
    ``channels`` output channels instead of RGB. The constructor coerces its
    inputs so that they agree with ``channels``:

    - ``lights`` is replaced by ``AmbientLights`` with an all-ones color of
      length ``channels`` unless it already is an ``AmbientLights`` whose
      ``ambient_color`` has ``channels`` entries in its last dimension. Any
      other light type (e.g. ``PointLights``) is silently replaced.
    - ``materials`` is replaced by a purely ambient material (ambient ones,
      diffuse and specular zeros, shininess 0) unless its ``ambient_color``
      has ``channels`` entries in its last dimension.
    - ``blend_params`` is replaced by ``BlendParams`` with an all-ones
      background of length ``channels`` if it is not a ``BlendParams`` or its
      background color visibly has a different channel count.

    To use the default values, simply initialize the shader with the desired
    device (and channel count) e.g.

    .. code-block::

        shader = HardNChannelFlatShader(device=torch.device("cuda:0"), channels=4)
    """

    def __init__(
        self,
        device="cpu",
        cameras: Optional[TensorProperties] = None,
        lights: Optional[TensorProperties] = None,
        materials: Optional[Materials] = None,
        blend_params: Optional[BlendParams] = None,
        channels: int = 3,
    ):
        """
        Args:
            device: device on which default lights/materials are created and
                which is forwarded to ``ShaderBase``.
            cameras: cameras forwarded to ``ShaderBase``.
            lights: lights; coerced as described in the class docstring.
            materials: materials; coerced as described in the class docstring.
            blend_params: blending parameters; coerced as described in the
                class docstring.
            channels: number of color channels to shade. Must be >= 1.

        Raises:
            ValueError: if ``channels`` is smaller than 1.
        """
        if channels < 1:
            raise ValueError(f"channels must be a positive integer, got {channels}")

        # Shape (1, channels): one row of per-channel values.
        ones = ((1.0,) * channels,)
        zeros = ((0.0,) * channels,)

        if (
            not isinstance(lights, AmbientLights)
            or lights.ambient_color.shape[-1] != channels
        ):
            lights = AmbientLights(ambient_color=ones, device=device)

        if materials is None or materials.ambient_color.shape[-1] != channels:
            # Purely ambient material: no diffuse or specular contribution.
            materials = Materials(
                device=device,
                diffuse_color=zeros,
                ambient_color=ones,
                specular_color=zeros,
                shininess=0.0,
            )

        if not isinstance(blend_params, BlendParams) or not self._background_matches(
            blend_params.background_color, channels
        ):
            blend_params = BlendParams(background_color=(1.0,) * channels)

        super().__init__(device, cameras, lights, materials, blend_params)
        # Assigned after nn.Module initialization so the module is fully set up.
        self.channels = channels

    @staticmethod
    def _background_matches(background_color, channels: int) -> bool:
        """Return False only if ``background_color`` visibly has a channel count other than ``channels``."""
        if isinstance(background_color, torch.Tensor):
            # A 0-dim tensor has no channel axis to compare; keep it as given.
            return background_color.ndim == 0 or background_color.shape[-1] == channels
        if isinstance(background_color, (list, tuple)):
            return len(background_color) == channels
        # Other types are not inspected and are kept as given.
        return True

    def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
        """
        Flat-shade the rasterized ``fragments`` of ``meshes`` and hard-blend them.

        Args:
            fragments: rasterizer output for ``meshes``.
            meshes: meshes whose textures are sampled at ``fragments``; their
                texture channel count presumably must equal ``self.channels``
                (TODO confirm).
            **kwargs: optional overrides for ``lights``, ``materials`` and
                ``blend_params`` (and whatever ``_get_cameras`` reads,
                presumably ``cameras``). Overrides are used as given and are
                NOT checked against ``self.channels``.

        Returns:
            The tensor produced by ``hard_rgb_blend``; presumably shaped
            (N, H, W, channels + 1) with an alpha channel last. Confirm this
            against the installed PyTorch3D version.
        """
        cameras = self._get_cameras(**kwargs)
        texels = meshes.sample_textures(fragments)
        lights = kwargs.get("lights", self.lights)
        materials = kwargs.get("materials", self.materials)
        blend_params = kwargs.get("blend_params", self.blend_params)
        colors = flat_shading(
            meshes=meshes,
            fragments=fragments,
            texels=texels,
            lights=lights,
            cameras=cameras,
            materials=materials,
        )
        return hard_rgb_blend(colors, fragments, blend_params)
|
|