import base64
import io
from typing import Union, Optional
import numpy as np
import torch
from PIL import Image
from shap_e.models.nn.camera import DifferentiableCameraBatch, DifferentiableProjectiveCamera
from shap_e.models.transmitter.base import Transmitter, VectorDecoder
from shap_e.rendering.torch_mesh import TorchMesh
from shap_e.util.collections import AttrDict


def create_pan_cameras(
    size: int,
    device: torch.device,
    batch_size: Optional[int] = 1,
    dist: int = 4,
) -> DifferentiableCameraBatch:
    """Build a batch of 20 cameras panning in a ring around the origin.

    Each camera sits at distance `dist` from the origin, looks slightly
    downward at the scene, and renders a `size` x `size` view. The 20-view
    ring is repeated `batch_size` times.
    """
    origins = []
    xs = []
    ys = []
    zs = []
    for theta in np.linspace(0, 2 * np.pi, num=20):
        # View direction: unit vector pointing from above the ring toward the origin.
        z = np.array([np.sin(theta), np.cos(theta), -0.5])
        z /= np.sqrt(np.sum(z**2))
        origin = -z * dist
        # Camera basis: x is horizontal (tangent to the ring), y completes a right-handed frame.
        x = np.array([np.cos(theta), -np.sin(theta), 0.0])
        y = np.cross(z, x)
        origins.append(origin)
        xs.append(x)
        ys.append(y)
        zs.append(z)
    return DifferentiableCameraBatch(
        shape=(batch_size, len(xs)),
        flat_camera=DifferentiableProjectiveCamera(
            origin=torch.from_numpy(np.stack(origins, axis=0)).float().to(device).repeat(batch_size, 1),
            x=torch.from_numpy(np.stack(xs, axis=0)).float().to(device).repeat(batch_size, 1),
            y=torch.from_numpy(np.stack(ys, axis=0)).float().to(device).repeat(batch_size, 1),
            z=torch.from_numpy(np.stack(zs, axis=0)).float().to(device).repeat(batch_size, 1),
            width=size,
            height=size,
            x_fov=0.7,
            y_fov=0.7,
        ),
    )


@torch.no_grad()
def decode_latent_images(
    xm: Union[Transmitter, VectorDecoder],
    latent: torch.Tensor,
    cameras: DifferentiableCameraBatch,
    rendering_mode: str = "stf",
):
    """Render a single latent from the given cameras and return a list of PIL images."""
    decoded = xm.renderer.render_views(
        AttrDict(cameras=cameras),
        params=(xm.encoder if isinstance(xm, Transmitter) else xm).bottleneck_to_params(
            latent[None]
        ),
        options=AttrDict(rendering_mode=rendering_mode, render_with_direction=False),
    )
    # Convert the rendered channels to uint8 frames, one per camera.
    arr = decoded.channels.clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
    return [Image.fromarray(x) for x in arr]


@torch.no_grad()
def decode_latent_mesh(
    xm: Union[Transmitter, VectorDecoder],
    latent: torch.Tensor,
) -> TorchMesh:
    """Decode a latent into a triangle mesh using an STF render at minimal resolution."""
    decoded = xm.renderer.render_views(
        AttrDict(cameras=create_pan_cameras(2, latent.device)),  # lowest resolution possible
        params=(xm.encoder if isinstance(xm, Transmitter) else xm).bottleneck_to_params(
            latent[None]
        ),
        options=AttrDict(rendering_mode="stf", render_with_direction=False),
    )
    return decoded.raw_meshes[0]
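

# Example usage (illustrative sketch, kept as comments): it assumes shap_e's
# `load_model` helper from shap_e.models.download is available and that
# `latent` is a single latent tensor produced by the Shap-E diffusion sampler;
# both are assumptions outside this module.
#
#   import torch
#   from shap_e.models.download import load_model
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   xm = load_model("transmitter", device=device)
#
#   # Render 20 panning views of the latent and save the first frame.
#   cameras = create_pan_cameras(size=64, device=device)
#   images = decode_latent_images(xm, latent, cameras, rendering_mode="nerf")
#   images[0].save("view_00.png")
#
#   # Extract a triangle mesh and write it to an OBJ file.
#   with open("mesh.obj", "w") as f:
#       decode_latent_mesh(xm, latent).tri_mesh().write_obj(f)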