Spaces:
Paused
Paused
Upload 5 files
Browse files- sf3d/sf3d_box_uv_unwrap.py +610 -0
- sf3d/sf3d_system.py +482 -0
- sf3d/sf3d_texture_baker.py +87 -0
- sf3d/sf3d_texture_baker.slang +93 -0
- sf3d/sf3d_utils.py +91 -0
sf3d/sf3d_box_uv_unwrap.py
ADDED
|
@@ -0,0 +1,610 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from jaxtyping import Float, Integer
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
|
| 9 |
+
from sf3d.models.utils import dot, triangle_intersection_2d
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _box_assign_vertex_to_cube_face(
|
| 13 |
+
vertex_positions: Float[Tensor, "Nv 3"],
|
| 14 |
+
vertex_normals: Float[Tensor, "Nv 3"],
|
| 15 |
+
triangle_idxs: Integer[Tensor, "Nf 3"],
|
| 16 |
+
bbox: Float[Tensor, "2 3"],
|
| 17 |
+
) -> Tuple[Float[Tensor, "Nf 3 2"], Integer[Tensor, "Nf 3"]]:
|
| 18 |
+
# Test to not have a scaled model to fit the space better
|
| 19 |
+
# bbox_min = bbox[:1].mean(-1, keepdim=True)
|
| 20 |
+
# bbox_max = bbox[1:].mean(-1, keepdim=True)
|
| 21 |
+
# v_pos_normalized = (vertex_positions - bbox_min) / (bbox_max - bbox_min)
|
| 22 |
+
|
| 23 |
+
# Create a [0, 1] normalized vertex position
|
| 24 |
+
v_pos_normalized = (vertex_positions - bbox[:1]) / (bbox[1:] - bbox[:1])
|
| 25 |
+
# And to [-1, 1]
|
| 26 |
+
v_pos_normalized = 2.0 * v_pos_normalized - 1.0
|
| 27 |
+
|
| 28 |
+
# Get all vertex positions for each triangle
|
| 29 |
+
# Now how do we define to which face the triangle belongs? Mean face pos? Max vertex pos?
|
| 30 |
+
v0 = v_pos_normalized[triangle_idxs[:, 0]]
|
| 31 |
+
v1 = v_pos_normalized[triangle_idxs[:, 1]]
|
| 32 |
+
v2 = v_pos_normalized[triangle_idxs[:, 2]]
|
| 33 |
+
tri_stack = torch.stack([v0, v1, v2], dim=1)
|
| 34 |
+
|
| 35 |
+
vn0 = vertex_normals[triangle_idxs[:, 0]]
|
| 36 |
+
vn1 = vertex_normals[triangle_idxs[:, 1]]
|
| 37 |
+
vn2 = vertex_normals[triangle_idxs[:, 2]]
|
| 38 |
+
tri_stack_nrm = torch.stack([vn0, vn1, vn2], dim=1)
|
| 39 |
+
|
| 40 |
+
# Just average the normals per face
|
| 41 |
+
face_normal = F.normalize(torch.sum(tri_stack_nrm, 1), eps=1e-6, dim=-1)
|
| 42 |
+
|
| 43 |
+
# Now decide based on the face normal in which box map we project
|
| 44 |
+
# abs_x, abs_y, abs_z = tri_stack_nrm.abs().unbind(-1)
|
| 45 |
+
abs_x, abs_y, abs_z = tri_stack.abs().unbind(-1)
|
| 46 |
+
|
| 47 |
+
axis = torch.tensor(
|
| 48 |
+
[
|
| 49 |
+
[1, 0, 0], # 0
|
| 50 |
+
[-1, 0, 0], # 1
|
| 51 |
+
[0, 1, 0], # 2
|
| 52 |
+
[0, -1, 0], # 3
|
| 53 |
+
[0, 0, 1], # 4
|
| 54 |
+
[0, 0, -1], # 5
|
| 55 |
+
],
|
| 56 |
+
device=face_normal.device,
|
| 57 |
+
dtype=face_normal.dtype,
|
| 58 |
+
)
|
| 59 |
+
face_normal_axis = (face_normal[:, None] * axis[None]).sum(-1)
|
| 60 |
+
index = face_normal_axis.argmax(-1)
|
| 61 |
+
|
| 62 |
+
max_axis, uc, vc = (
|
| 63 |
+
torch.ones_like(abs_x),
|
| 64 |
+
torch.zeros_like(tri_stack[..., :1]),
|
| 65 |
+
torch.zeros_like(tri_stack[..., :1]),
|
| 66 |
+
)
|
| 67 |
+
mask_pos_x = index == 0
|
| 68 |
+
max_axis[mask_pos_x] = abs_x[mask_pos_x]
|
| 69 |
+
uc[mask_pos_x] = tri_stack[mask_pos_x][..., 1:2]
|
| 70 |
+
vc[mask_pos_x] = -tri_stack[mask_pos_x][..., -1:]
|
| 71 |
+
|
| 72 |
+
mask_neg_x = index == 1
|
| 73 |
+
max_axis[mask_neg_x] = abs_x[mask_neg_x]
|
| 74 |
+
uc[mask_neg_x] = tri_stack[mask_neg_x][..., 1:2]
|
| 75 |
+
vc[mask_neg_x] = -tri_stack[mask_neg_x][..., -1:]
|
| 76 |
+
|
| 77 |
+
mask_pos_y = index == 2
|
| 78 |
+
max_axis[mask_pos_y] = abs_y[mask_pos_y]
|
| 79 |
+
uc[mask_pos_y] = tri_stack[mask_pos_y][..., 0:1]
|
| 80 |
+
vc[mask_pos_y] = -tri_stack[mask_pos_y][..., -1:]
|
| 81 |
+
|
| 82 |
+
mask_neg_y = index == 3
|
| 83 |
+
max_axis[mask_neg_y] = abs_y[mask_neg_y]
|
| 84 |
+
uc[mask_neg_y] = tri_stack[mask_neg_y][..., 0:1]
|
| 85 |
+
vc[mask_neg_y] = -tri_stack[mask_neg_y][..., -1:]
|
| 86 |
+
|
| 87 |
+
mask_pos_z = index == 4
|
| 88 |
+
max_axis[mask_pos_z] = abs_z[mask_pos_z]
|
| 89 |
+
uc[mask_pos_z] = tri_stack[mask_pos_z][..., 0:1]
|
| 90 |
+
vc[mask_pos_z] = tri_stack[mask_pos_z][..., 1:2]
|
| 91 |
+
|
| 92 |
+
mask_neg_z = index == 5
|
| 93 |
+
max_axis[mask_neg_z] = abs_z[mask_neg_z]
|
| 94 |
+
uc[mask_neg_z] = tri_stack[mask_neg_z][..., 0:1]
|
| 95 |
+
vc[mask_neg_z] = -tri_stack[mask_neg_z][..., 1:2]
|
| 96 |
+
|
| 97 |
+
# UC from [-1, 1] to [0, 1]
|
| 98 |
+
max_dim_div = max_axis.max(dim=0, keepdims=True).values
|
| 99 |
+
uc = ((uc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1)
|
| 100 |
+
vc = ((vc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1)
|
| 101 |
+
|
| 102 |
+
uv = torch.stack([uc, vc], dim=-1)
|
| 103 |
+
|
| 104 |
+
return uv, index
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def _assign_faces_uv_to_atlas_index(
|
| 108 |
+
vertex_positions: Float[Tensor, "Nv 3"],
|
| 109 |
+
triangle_idxs: Integer[Tensor, "Nf 3"],
|
| 110 |
+
face_uv: Float[Tensor, "Nf 3 2"],
|
| 111 |
+
face_index: Integer[Tensor, "Nf 3"],
|
| 112 |
+
) -> Integer[Tensor, "Nf"]: # noqa: F821
|
| 113 |
+
triangle_pos = vertex_positions[triangle_idxs]
|
| 114 |
+
# We need to do perform 3 overlap checks.
|
| 115 |
+
# The first set is placed in the upper two thirds of the UV atlas.
|
| 116 |
+
# Conceptually, this is the direct visible surfaces from the each cube side
|
| 117 |
+
# The second set is placed in the lower thirds and the left half of the UV atlas.
|
| 118 |
+
# This is the first set of occluded surfaces. They will also be saved in the projected fashion
|
| 119 |
+
# The third pass finds all non assigned faces. They will be placed in the bottom right half of
|
| 120 |
+
# the UV atlas in scattered fashion.
|
| 121 |
+
assign_idx = face_index.clone()
|
| 122 |
+
for overlap_step in range(3):
|
| 123 |
+
overlapping_indicator = torch.zeros_like(assign_idx, dtype=torch.bool)
|
| 124 |
+
for i in range(overlap_step * 6, (overlap_step + 1) * 6):
|
| 125 |
+
mask = assign_idx == i
|
| 126 |
+
if not mask.any():
|
| 127 |
+
continue
|
| 128 |
+
# Get all elements belonging to the projection face
|
| 129 |
+
uv_triangle = face_uv[mask]
|
| 130 |
+
cur_triangle_pos = triangle_pos[mask]
|
| 131 |
+
# Find the center of the uv coordinates
|
| 132 |
+
center_uv = uv_triangle.mean(dim=1, keepdim=True)
|
| 133 |
+
# And also the radius of the triangle
|
| 134 |
+
uv_triangle_radius = (uv_triangle - center_uv).norm(dim=-1).max(-1).values
|
| 135 |
+
|
| 136 |
+
potentially_overlapping_mask = (
|
| 137 |
+
# Find all close triangles
|
| 138 |
+
(center_uv[None, ...] - center_uv[:, None]).norm(dim=-1)
|
| 139 |
+
# Do not select the same element by offseting with an large valued identity matrix
|
| 140 |
+
+ torch.eye(
|
| 141 |
+
uv_triangle.shape[0],
|
| 142 |
+
device=uv_triangle.device,
|
| 143 |
+
dtype=uv_triangle.dtype,
|
| 144 |
+
).unsqueeze(-1)
|
| 145 |
+
* 1000
|
| 146 |
+
)
|
| 147 |
+
# Mark all potentially overlapping triangles to reduce the number of triangle intersection tests
|
| 148 |
+
potentially_overlapping_mask = (
|
| 149 |
+
potentially_overlapping_mask
|
| 150 |
+
<= (uv_triangle_radius.view(-1, 1, 1) * 3.0)
|
| 151 |
+
).squeeze(-1)
|
| 152 |
+
overlap_coords = torch.stack(torch.where(potentially_overlapping_mask), -1)
|
| 153 |
+
|
| 154 |
+
# Only unique triangles (A|B and B|A should be the same)
|
| 155 |
+
f = torch.min(overlap_coords, dim=-1).values
|
| 156 |
+
s = torch.max(overlap_coords, dim=-1).values
|
| 157 |
+
overlap_coords = torch.unique(torch.stack([f, s], dim=1), dim=0)
|
| 158 |
+
first, second = overlap_coords.unbind(-1)
|
| 159 |
+
|
| 160 |
+
# Get the triangles
|
| 161 |
+
tri_1 = uv_triangle[first]
|
| 162 |
+
tri_2 = uv_triangle[second]
|
| 163 |
+
|
| 164 |
+
# Perform the actual set with the reduced number of potentially overlapping triangles
|
| 165 |
+
its = triangle_intersection_2d(tri_1, tri_2, eps=1e-6)
|
| 166 |
+
|
| 167 |
+
# So we now need to detect which triangles are the occluded ones.
|
| 168 |
+
# We always assume the first to be the visible one (the others should move)
|
| 169 |
+
# In the previous step we use a lexigraphical sort to get the unique pairs
|
| 170 |
+
# In this we use a sort based on the orthographic projection
|
| 171 |
+
ax = 0 if i < 2 else 1 if i < 4 else 2
|
| 172 |
+
use_max = i % 2 == 1
|
| 173 |
+
|
| 174 |
+
tri1_c = cur_triangle_pos[first].mean(dim=1)
|
| 175 |
+
tri2_c = cur_triangle_pos[second].mean(dim=1)
|
| 176 |
+
|
| 177 |
+
mark_first = (
|
| 178 |
+
(tri1_c[..., ax] > tri2_c[..., ax])
|
| 179 |
+
if use_max
|
| 180 |
+
else (tri1_c[..., ax] < tri2_c[..., ax])
|
| 181 |
+
)
|
| 182 |
+
first[mark_first] = second[mark_first]
|
| 183 |
+
|
| 184 |
+
# Lastly the same index can be tested multiple times.
|
| 185 |
+
# If one marks it as overlapping we keep it marked as such.
|
| 186 |
+
# We do this by testing if it has been marked at least once.
|
| 187 |
+
unique_idx, rev_idx = torch.unique(first, return_inverse=True)
|
| 188 |
+
|
| 189 |
+
add = torch.zeros_like(unique_idx, dtype=torch.float32)
|
| 190 |
+
add.index_add_(0, rev_idx, its.float())
|
| 191 |
+
its_mask = add > 0
|
| 192 |
+
|
| 193 |
+
# And fill it in the overlapping indicator
|
| 194 |
+
idx = torch.where(mask)[0][unique_idx]
|
| 195 |
+
overlapping_indicator[idx] = its_mask
|
| 196 |
+
|
| 197 |
+
# Move the index to the overlap regions (shift by 6)
|
| 198 |
+
assign_idx[overlapping_indicator] += 6
|
| 199 |
+
|
| 200 |
+
# We do not care about the correct face placement after the first 2 slices
|
| 201 |
+
max_idx = 6 * 2
|
| 202 |
+
return assign_idx.clamp(0, max_idx)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def _find_slice_offset_and_scale(
|
| 206 |
+
index: Integer[Tensor, "Nf"], # noqa: F821
|
| 207 |
+
) -> Tuple[
|
| 208 |
+
Float[Tensor, "Nf"], Float[Tensor, "Nf"], Float[Tensor, "Nf"], Float[Tensor, "Nf"] # noqa: F821
|
| 209 |
+
]: # noqa: F821
|
| 210 |
+
# 6 due to the 6 cube faces
|
| 211 |
+
off = 1 / 3
|
| 212 |
+
dupl_off = 1 / 6
|
| 213 |
+
|
| 214 |
+
# Here, we need to decide how to pack the textures in the case of overlap
|
| 215 |
+
def x_offset_calc(x, i):
|
| 216 |
+
offset_calc = i // 6
|
| 217 |
+
# Initial coordinates - just 3x2 grid
|
| 218 |
+
if offset_calc == 0:
|
| 219 |
+
return off * x
|
| 220 |
+
else:
|
| 221 |
+
# Smaller 3x2 grid plus eventual shift to right for
|
| 222 |
+
# second overlap
|
| 223 |
+
return dupl_off * x + min(offset_calc - 1, 1) * 0.5
|
| 224 |
+
|
| 225 |
+
def y_offset_calc(x, i):
|
| 226 |
+
offset_calc = i // 6
|
| 227 |
+
# Initial coordinates - just a 3x2 grid
|
| 228 |
+
if offset_calc == 0:
|
| 229 |
+
return off * x
|
| 230 |
+
else:
|
| 231 |
+
# Smaller coordinates in the lowest row
|
| 232 |
+
return dupl_off * x + off * 2
|
| 233 |
+
|
| 234 |
+
offset_x = torch.zeros_like(index, dtype=torch.float32)
|
| 235 |
+
offset_y = torch.zeros_like(index, dtype=torch.float32)
|
| 236 |
+
offset_x_vals = [0, 1, 2, 0, 1, 2]
|
| 237 |
+
offset_y_vals = [0, 0, 0, 1, 1, 1]
|
| 238 |
+
for i in range(index.max().item() + 1):
|
| 239 |
+
mask = index == i
|
| 240 |
+
if not mask.any():
|
| 241 |
+
continue
|
| 242 |
+
offset_x[mask] = x_offset_calc(offset_x_vals[i % 6], i)
|
| 243 |
+
offset_y[mask] = y_offset_calc(offset_y_vals[i % 6], i)
|
| 244 |
+
|
| 245 |
+
div_x = torch.full_like(index, 6 // 2, dtype=torch.float32)
|
| 246 |
+
# All overlap elements are saved in half scale
|
| 247 |
+
div_x[index >= 6] = 6
|
| 248 |
+
div_y = div_x.clone() # Same for y
|
| 249 |
+
# Except for the random overlaps
|
| 250 |
+
div_x[index >= 12] = 2
|
| 251 |
+
# But the random overlaps are saved in a large block in the lower thirds
|
| 252 |
+
div_y[index >= 12] = 3
|
| 253 |
+
|
| 254 |
+
return offset_x, offset_y, div_x, div_y
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def rotation_flip_matrix_2d(
|
| 258 |
+
rad: float, flip_x: bool = False, flip_y: bool = False
|
| 259 |
+
) -> Float[Tensor, "2 2"]:
|
| 260 |
+
cos = math.cos(rad)
|
| 261 |
+
sin = math.sin(rad)
|
| 262 |
+
rot_mat = torch.tensor([[cos, -sin], [sin, cos]], dtype=torch.float32)
|
| 263 |
+
flip_mat = torch.tensor(
|
| 264 |
+
[
|
| 265 |
+
[-1 if flip_x else 1, 0],
|
| 266 |
+
[0, -1 if flip_y else 1],
|
| 267 |
+
],
|
| 268 |
+
dtype=torch.float32,
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
return flip_mat @ rot_mat
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def calculate_tangents(
|
| 275 |
+
vertex_positions: Float[Tensor, "Nv 3"],
|
| 276 |
+
vertex_normals: Float[Tensor, "Nv 3"],
|
| 277 |
+
triangle_idxs: Integer[Tensor, "Nf 3"],
|
| 278 |
+
face_uv: Float[Tensor, "Nf 3 2"],
|
| 279 |
+
) -> Float[Tensor, "Nf 3 4"]: # noqa: F821
|
| 280 |
+
vn_idx = [None] * 3
|
| 281 |
+
pos = [None] * 3
|
| 282 |
+
tex = face_uv.unbind(1)
|
| 283 |
+
for i in range(0, 3):
|
| 284 |
+
pos[i] = vertex_positions[triangle_idxs[:, i]]
|
| 285 |
+
# t_nrm_idx is always the same as t_pos_idx
|
| 286 |
+
vn_idx[i] = triangle_idxs[:, i]
|
| 287 |
+
|
| 288 |
+
tangents = torch.zeros_like(vertex_normals)
|
| 289 |
+
tansum = torch.zeros_like(vertex_normals)
|
| 290 |
+
|
| 291 |
+
# Compute tangent space for each triangle
|
| 292 |
+
duv1 = tex[1] - tex[0]
|
| 293 |
+
duv2 = tex[2] - tex[0]
|
| 294 |
+
dpos1 = pos[1] - pos[0]
|
| 295 |
+
dpos2 = pos[2] - pos[0]
|
| 296 |
+
|
| 297 |
+
tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]
|
| 298 |
+
|
| 299 |
+
denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]
|
| 300 |
+
|
| 301 |
+
# Avoid division by zero for degenerated texture coordinates
|
| 302 |
+
denom_safe = denom.clip(1e-6)
|
| 303 |
+
tang = tng_nom / denom_safe
|
| 304 |
+
|
| 305 |
+
# Update all 3 vertices
|
| 306 |
+
for i in range(0, 3):
|
| 307 |
+
idx = vn_idx[i][:, None].repeat(1, 3)
|
| 308 |
+
tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang
|
| 309 |
+
tansum.scatter_add_(
|
| 310 |
+
0, idx, torch.ones_like(tang)
|
| 311 |
+
) # tansum[n_i] = tansum[n_i] + 1
|
| 312 |
+
# Also normalize it. Here we do not normalize the individual triangles first so larger area
|
| 313 |
+
# triangles influence the tangent space more
|
| 314 |
+
tangents = tangents / tansum
|
| 315 |
+
|
| 316 |
+
# Normalize and make sure tangent is perpendicular to normal
|
| 317 |
+
tangents = F.normalize(tangents, dim=1)
|
| 318 |
+
tangents = F.normalize(tangents - dot(tangents, vertex_normals) * vertex_normals)
|
| 319 |
+
|
| 320 |
+
return tangents
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def _rotate_uv_slices_consistent_space(
|
| 324 |
+
vertex_positions: Float[Tensor, "Nv 3"],
|
| 325 |
+
vertex_normals: Float[Tensor, "Nv 3"],
|
| 326 |
+
triangle_idxs: Integer[Tensor, "Nf 3"],
|
| 327 |
+
uv: Float[Tensor, "Nf 3 2"],
|
| 328 |
+
index: Integer[Tensor, "Nf"], # noqa: F821
|
| 329 |
+
):
|
| 330 |
+
tangents = calculate_tangents(vertex_positions, vertex_normals, triangle_idxs, uv)
|
| 331 |
+
pos_stack = torch.stack(
|
| 332 |
+
[
|
| 333 |
+
-vertex_positions[..., 1],
|
| 334 |
+
vertex_positions[..., 0],
|
| 335 |
+
torch.zeros_like(vertex_positions[..., 0]),
|
| 336 |
+
],
|
| 337 |
+
dim=-1,
|
| 338 |
+
)
|
| 339 |
+
expected_tangents = F.normalize(
|
| 340 |
+
torch.linalg.cross(
|
| 341 |
+
vertex_normals, torch.linalg.cross(pos_stack, vertex_normals)
|
| 342 |
+
),
|
| 343 |
+
-1,
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
actual_tangents = tangents[triangle_idxs]
|
| 347 |
+
expected_tangents = expected_tangents[triangle_idxs]
|
| 348 |
+
|
| 349 |
+
def rotation_matrix_2d(theta):
|
| 350 |
+
c, s = torch.cos(theta), torch.sin(theta)
|
| 351 |
+
return torch.tensor([[c, -s], [s, c]])
|
| 352 |
+
|
| 353 |
+
# Now find the rotation
|
| 354 |
+
index_mod = index % 6 # Shouldn't happen. Just for safety
|
| 355 |
+
for i in range(6):
|
| 356 |
+
mask = index_mod == i
|
| 357 |
+
if not mask.any():
|
| 358 |
+
continue
|
| 359 |
+
|
| 360 |
+
actual_mean_tangent = actual_tangents[mask].mean(dim=(0, 1))
|
| 361 |
+
expected_mean_tangent = expected_tangents[mask].mean(dim=(0, 1))
|
| 362 |
+
|
| 363 |
+
dot_product = torch.dot(actual_mean_tangent, expected_mean_tangent)
|
| 364 |
+
cross_product = (
|
| 365 |
+
actual_mean_tangent[0] * expected_mean_tangent[1]
|
| 366 |
+
- actual_mean_tangent[1] * expected_mean_tangent[0]
|
| 367 |
+
)
|
| 368 |
+
angle = torch.atan2(cross_product, dot_product)
|
| 369 |
+
|
| 370 |
+
rot_matrix = rotation_matrix_2d(angle).to(mask.device)
|
| 371 |
+
# Center the uv coordinate to be in the range of -1 to 1 and 0 centered
|
| 372 |
+
uv_cur = uv[mask] * 2 - 1 # Center it first
|
| 373 |
+
# Rotate it
|
| 374 |
+
uv[mask] = torch.einsum("ij,nfj->nfi", rot_matrix, uv_cur)
|
| 375 |
+
|
| 376 |
+
# Rescale uv[mask] to be within the 0-1 range
|
| 377 |
+
uv[mask] = (uv[mask] - uv[mask].min()) / (uv[mask].max() - uv[mask].min())
|
| 378 |
+
|
| 379 |
+
return uv
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
def _handle_slice_uvs(
|
| 383 |
+
uv: Float[Tensor, "Nf 3 2"],
|
| 384 |
+
index: Integer[Tensor, "Nf"], # noqa: F821
|
| 385 |
+
island_padding: float,
|
| 386 |
+
max_index: int = 6 * 2,
|
| 387 |
+
) -> Float[Tensor, "Nf 3 2"]: # noqa: F821
|
| 388 |
+
uc, vc = uv.unbind(-1)
|
| 389 |
+
|
| 390 |
+
# Get the second slice (The first overlap)
|
| 391 |
+
index_filter = [index == i for i in range(6, max_index)]
|
| 392 |
+
|
| 393 |
+
# Normalize them to always fully fill the atlas patch
|
| 394 |
+
for i, fi in enumerate(index_filter):
|
| 395 |
+
if fi.sum() > 0:
|
| 396 |
+
# Scale the slice but only up to a factor of 2
|
| 397 |
+
# This keeps the texture resolution with the first slice in line (Half space in UV)
|
| 398 |
+
uc[fi] = (uc[fi] - uc[fi].min()) / (uc[fi].max() - uc[fi].min()).clip(0.5)
|
| 399 |
+
vc[fi] = (vc[fi] - vc[fi].min()) / (vc[fi].max() - vc[fi].min()).clip(0.5)
|
| 400 |
+
|
| 401 |
+
uc_padded = (uc * (1 - 2 * island_padding) + island_padding).clip(0, 1)
|
| 402 |
+
vc_padded = (vc * (1 - 2 * island_padding) + island_padding).clip(0, 1)
|
| 403 |
+
|
| 404 |
+
return torch.stack([uc_padded, vc_padded], dim=-1)
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def _handle_remaining_uvs(
|
| 408 |
+
uv: Float[Tensor, "Nf 3 2"],
|
| 409 |
+
index: Integer[Tensor, "Nf"], # noqa: F821
|
| 410 |
+
island_padding: float,
|
| 411 |
+
) -> Float[Tensor, "Nf 3 2"]:
|
| 412 |
+
uc, vc = uv.unbind(-1)
|
| 413 |
+
# Get all remaining elements
|
| 414 |
+
remaining_filter = index >= 6 * 2
|
| 415 |
+
squares_left = remaining_filter.sum()
|
| 416 |
+
|
| 417 |
+
if squares_left == 0:
|
| 418 |
+
return uv
|
| 419 |
+
|
| 420 |
+
uc = uc[remaining_filter]
|
| 421 |
+
vc = vc[remaining_filter]
|
| 422 |
+
|
| 423 |
+
# Or remaining triangles are distributed in a rectangle
|
| 424 |
+
# The rectangle takes 0.5 of the entire uv space in width and 1/3 in height
|
| 425 |
+
ratio = 0.5 * (1 / 3) # 1.5
|
| 426 |
+
# sqrt(744/(0.5*(1/3)))
|
| 427 |
+
|
| 428 |
+
mult = math.sqrt(squares_left / ratio)
|
| 429 |
+
num_square_width = int(math.ceil(0.5 * mult))
|
| 430 |
+
num_square_height = int(math.ceil(squares_left / num_square_width))
|
| 431 |
+
|
| 432 |
+
width = 1 / num_square_width
|
| 433 |
+
height = 1 / num_square_height
|
| 434 |
+
|
| 435 |
+
# The idea is again to keep the texture resolution consistent with the first slice
|
| 436 |
+
# This only occupys half the region in the texture chart but the scaling on the squares
|
| 437 |
+
# assumes full coverage.
|
| 438 |
+
clip_val = min(width, height) * 1.5
|
| 439 |
+
# Now normalize the UVs with taking into account the maximum scaling
|
| 440 |
+
uc = (uc - uc.min(dim=1, keepdim=True).values) / (
|
| 441 |
+
uc.amax(dim=1, keepdim=True) - uc.amin(dim=1, keepdim=True)
|
| 442 |
+
).clip(clip_val)
|
| 443 |
+
vc = (vc - vc.min(dim=1, keepdim=True).values) / (
|
| 444 |
+
vc.amax(dim=1, keepdim=True) - vc.amin(dim=1, keepdim=True)
|
| 445 |
+
).clip(clip_val)
|
| 446 |
+
# Add a small padding
|
| 447 |
+
uc = (
|
| 448 |
+
uc * (1 - island_padding * num_square_width * 0.5)
|
| 449 |
+
+ island_padding * num_square_width * 0.25
|
| 450 |
+
).clip(0, 1)
|
| 451 |
+
vc = (
|
| 452 |
+
vc * (1 - island_padding * num_square_height * 0.5)
|
| 453 |
+
+ island_padding * num_square_height * 0.25
|
| 454 |
+
).clip(0, 1)
|
| 455 |
+
|
| 456 |
+
uc = uc * width
|
| 457 |
+
vc = vc * height
|
| 458 |
+
|
| 459 |
+
# And calculate offsets for each element
|
| 460 |
+
idx = torch.arange(uc.shape[0], device=uc.device, dtype=torch.int32)
|
| 461 |
+
x_idx = idx % num_square_width
|
| 462 |
+
y_idx = idx // num_square_width
|
| 463 |
+
# And move each triangle to its own spot
|
| 464 |
+
uc = uc + x_idx[:, None] * width
|
| 465 |
+
vc = vc + y_idx[:, None] * height
|
| 466 |
+
|
| 467 |
+
uc = (uc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)
|
| 468 |
+
vc = (vc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)
|
| 469 |
+
|
| 470 |
+
uv[remaining_filter] = torch.stack([uc, vc], dim=-1)
|
| 471 |
+
|
| 472 |
+
return uv
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
def _distribute_individual_uvs_in_atlas(
|
| 476 |
+
face_uv: Float[Tensor, "Nf 3 2"],
|
| 477 |
+
assigned_faces: Integer[Tensor, "Nf"], # noqa: F821
|
| 478 |
+
offset_x: Float[Tensor, "Nf"], # noqa: F821
|
| 479 |
+
offset_y: Float[Tensor, "Nf"], # noqa: F821
|
| 480 |
+
div_x: Float[Tensor, "Nf"], # noqa: F821
|
| 481 |
+
div_y: Float[Tensor, "Nf"], # noqa: F821
|
| 482 |
+
island_padding: float,
|
| 483 |
+
):
|
| 484 |
+
# Place the slice first
|
| 485 |
+
placed_uv = _handle_slice_uvs(face_uv, assigned_faces, island_padding)
|
| 486 |
+
# Then handle the remaining overlap elements
|
| 487 |
+
placed_uv = _handle_remaining_uvs(placed_uv, assigned_faces, island_padding)
|
| 488 |
+
|
| 489 |
+
uc, vc = placed_uv.unbind(-1)
|
| 490 |
+
uc = uc / div_x[:, None] + offset_x[:, None]
|
| 491 |
+
vc = vc / div_y[:, None] + offset_y[:, None]
|
| 492 |
+
|
| 493 |
+
uv = torch.stack([uc, vc], dim=-1).view(-1, 2)
|
| 494 |
+
|
| 495 |
+
return uv
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
def _get_unique_face_uv(
|
| 499 |
+
uv: Float[Tensor, "Nf 3 2"],
|
| 500 |
+
) -> Tuple[Float[Tensor, "Utex 3"], Integer[Tensor, "Nf"]]: # noqa: F821
|
| 501 |
+
unique_uv, unique_idx = torch.unique(uv, return_inverse=True, dim=0)
|
| 502 |
+
# And add the face to uv index mapping
|
| 503 |
+
vtex_idx = unique_idx.view(-1, 3)
|
| 504 |
+
|
| 505 |
+
return unique_uv, vtex_idx
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
def _align_mesh_with_main_axis(
|
| 509 |
+
vertex_positions: Float[Tensor, "Nv 3"], vertex_normals: Float[Tensor, "Nv 3"]
|
| 510 |
+
) -> Tuple[Float[Tensor, "Nv 3"], Float[Tensor, "Nv 3"]]:
|
| 511 |
+
# Use pca to find the 2 main axis (third is derived by cross product)
|
| 512 |
+
# Set the random seed so it's repeatable
|
| 513 |
+
torch.manual_seed(0)
|
| 514 |
+
_, _, v = torch.pca_lowrank(vertex_positions, q=2)
|
| 515 |
+
main_axis, seconday_axis = v[:, 0], v[:, 1]
|
| 516 |
+
|
| 517 |
+
main_axis: Float[Tensor, "3"] = F.normalize(main_axis, eps=1e-6, dim=-1)
|
| 518 |
+
# Orthogonalize the second axis
|
| 519 |
+
seconday_axis: Float[Tensor, "3"] = F.normalize(
|
| 520 |
+
seconday_axis - dot(seconday_axis, main_axis) * main_axis, eps=1e-6, dim=-1
|
| 521 |
+
)
|
| 522 |
+
# Create perpendicular third axis
|
| 523 |
+
third_axis: Float[Tensor, "3"] = F.normalize(
|
| 524 |
+
torch.cross(main_axis, seconday_axis), dim=-1, eps=1e-6
|
| 525 |
+
)
|
| 526 |
+
|
| 527 |
+
# Check to which canonical axis each aligns
|
| 528 |
+
main_axis_max_idx = main_axis.abs().argmax().item()
|
| 529 |
+
seconday_axis_max_idx = seconday_axis.abs().argmax().item()
|
| 530 |
+
third_axis_max_idx = third_axis.abs().argmax().item()
|
| 531 |
+
|
| 532 |
+
# Now sort the axes based on the argmax so they align with thecanonoical axes
|
| 533 |
+
# If two axes have the same argmax move one of them
|
| 534 |
+
all_possible_axis = {0, 1, 2}
|
| 535 |
+
cur_index = 1
|
| 536 |
+
while len(set([main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx])) != 3:
|
| 537 |
+
# Find missing axis
|
| 538 |
+
missing_axis = all_possible_axis - set(
|
| 539 |
+
[main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx]
|
| 540 |
+
)
|
| 541 |
+
missing_axis = missing_axis.pop()
|
| 542 |
+
# Just assign it to third axis as it had the smallest contribution to the
|
| 543 |
+
# overall shape
|
| 544 |
+
if cur_index == 1:
|
| 545 |
+
third_axis_max_idx = missing_axis
|
| 546 |
+
elif cur_index == 2:
|
| 547 |
+
seconday_axis_max_idx = missing_axis
|
| 548 |
+
else:
|
| 549 |
+
raise ValueError("Could not find 3 unique axis")
|
| 550 |
+
cur_index += 1
|
| 551 |
+
|
| 552 |
+
if len({main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx}) != 3:
|
| 553 |
+
raise ValueError("Could not find 3 unique axis")
|
| 554 |
+
|
| 555 |
+
axes = [None] * 3
|
| 556 |
+
axes[main_axis_max_idx] = main_axis
|
| 557 |
+
axes[seconday_axis_max_idx] = seconday_axis
|
| 558 |
+
axes[third_axis_max_idx] = third_axis
|
| 559 |
+
# Create rotation matrix from the individual axes
|
| 560 |
+
rot_mat = torch.stack(axes, dim=1).T
|
| 561 |
+
|
| 562 |
+
# Now rotate the vertex positions and vertex normals so the mesh aligns with the main axis
|
| 563 |
+
vertex_positions = torch.einsum("ij,nj->ni", rot_mat, vertex_positions)
|
| 564 |
+
vertex_normals = torch.einsum("ij,nj->ni", rot_mat, vertex_normals)
|
| 565 |
+
|
| 566 |
+
return vertex_positions, vertex_normals
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
def box_projection_uv_unwrap(
|
| 570 |
+
vertex_positions: Float[Tensor, "Nv 3"],
|
| 571 |
+
vertex_normals: Float[Tensor, "Nv 3"],
|
| 572 |
+
triangle_idxs: Integer[Tensor, "Nf 3"],
|
| 573 |
+
island_padding: float,
|
| 574 |
+
) -> Tuple[Float[Tensor, "Utex 3"], Integer[Tensor, "Nf"]]: # noqa: F821
|
| 575 |
+
# Align the mesh with main axis directions first
|
| 576 |
+
vertex_positions, vertex_normals = _align_mesh_with_main_axis(
|
| 577 |
+
vertex_positions, vertex_normals
|
| 578 |
+
)
|
| 579 |
+
|
| 580 |
+
bbox: Float[Tensor, "2 3"] = torch.stack(
|
| 581 |
+
[vertex_positions.min(dim=0).values, vertex_positions.max(dim=0).values], dim=0
|
| 582 |
+
)
|
| 583 |
+
# First decide in which cube face the triangle is placed
|
| 584 |
+
face_uv, face_index = _box_assign_vertex_to_cube_face(
|
| 585 |
+
vertex_positions, vertex_normals, triangle_idxs, bbox
|
| 586 |
+
)
|
| 587 |
+
|
| 588 |
+
# Rotate the UV islands in a way that they align with the radial z tangent space
|
| 589 |
+
face_uv = _rotate_uv_slices_consistent_space(
|
| 590 |
+
vertex_positions, vertex_normals, triangle_idxs, face_uv, face_index
|
| 591 |
+
)
|
| 592 |
+
|
| 593 |
+
# Then find where where the face is placed in the atlas.
|
| 594 |
+
# This has to detect potential overlaps
|
| 595 |
+
assigned_atlas_index = _assign_faces_uv_to_atlas_index(
|
| 596 |
+
vertex_positions, triangle_idxs, face_uv, face_index
|
| 597 |
+
)
|
| 598 |
+
|
| 599 |
+
# Then figure out the final place in the atlas based on the assignment
|
| 600 |
+
offset_x, offset_y, div_x, div_y = _find_slice_offset_and_scale(
|
| 601 |
+
assigned_atlas_index
|
| 602 |
+
)
|
| 603 |
+
|
| 604 |
+
# Next distribute the faces in the uv atlas
|
| 605 |
+
placed_uv = _distribute_individual_uvs_in_atlas(
|
| 606 |
+
face_uv, assigned_atlas_index, offset_x, offset_y, div_x, div_y, island_padding
|
| 607 |
+
)
|
| 608 |
+
|
| 609 |
+
# And get the unique per-triangle UV coordinates
|
| 610 |
+
return _get_unique_face_uv(placed_uv)
|
sf3d/sf3d_system.py
ADDED
|
@@ -0,0 +1,482 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
from typing import Any, List, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
import trimesh
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
from huggingface_hub import hf_hub_download
|
| 11 |
+
from jaxtyping import Float
|
| 12 |
+
from omegaconf import OmegaConf
|
| 13 |
+
from PIL import Image
|
| 14 |
+
from safetensors.torch import load_model
|
| 15 |
+
from torch import Tensor
|
| 16 |
+
|
| 17 |
+
from sf3d.models.isosurface import MarchingTetrahedraHelper
|
| 18 |
+
from sf3d.models.mesh import Mesh
|
| 19 |
+
from sf3d.models.utils import (
|
| 20 |
+
BaseModule,
|
| 21 |
+
ImageProcessor,
|
| 22 |
+
convert_data,
|
| 23 |
+
dilate_fill,
|
| 24 |
+
dot,
|
| 25 |
+
find_class,
|
| 26 |
+
float32_to_uint8_np,
|
| 27 |
+
normalize,
|
| 28 |
+
scale_tensor,
|
| 29 |
+
)
|
| 30 |
+
from sf3d.utils import create_intrinsic_from_fov_deg, default_cond_c2w
|
| 31 |
+
|
| 32 |
+
from .texture_baker import TextureBaker
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class SF3D(BaseModule):
|
| 36 |
+
@dataclass
|
| 37 |
+
class Config(BaseModule.Config):
|
| 38 |
+
cond_image_size: int
|
| 39 |
+
isosurface_resolution: int
|
| 40 |
+
isosurface_threshold: float = 10.0
|
| 41 |
+
radius: float = 1.0
|
| 42 |
+
background_color: list[float] = field(default_factory=lambda: [0.5, 0.5, 0.5])
|
| 43 |
+
default_fovy_deg: float = 40.0
|
| 44 |
+
default_distance: float = 1.6
|
| 45 |
+
|
| 46 |
+
camera_embedder_cls: str = ""
|
| 47 |
+
camera_embedder: dict = field(default_factory=dict)
|
| 48 |
+
|
| 49 |
+
image_tokenizer_cls: str = ""
|
| 50 |
+
image_tokenizer: dict = field(default_factory=dict)
|
| 51 |
+
|
| 52 |
+
tokenizer_cls: str = ""
|
| 53 |
+
tokenizer: dict = field(default_factory=dict)
|
| 54 |
+
|
| 55 |
+
backbone_cls: str = ""
|
| 56 |
+
backbone: dict = field(default_factory=dict)
|
| 57 |
+
|
| 58 |
+
post_processor_cls: str = ""
|
| 59 |
+
post_processor: dict = field(default_factory=dict)
|
| 60 |
+
|
| 61 |
+
decoder_cls: str = ""
|
| 62 |
+
decoder: dict = field(default_factory=dict)
|
| 63 |
+
|
| 64 |
+
image_estimator_cls: str = ""
|
| 65 |
+
image_estimator: dict = field(default_factory=dict)
|
| 66 |
+
|
| 67 |
+
global_estimator_cls: str = ""
|
| 68 |
+
global_estimator: dict = field(default_factory=dict)
|
| 69 |
+
|
| 70 |
+
cfg: Config
|
| 71 |
+
|
| 72 |
+
@classmethod
|
| 73 |
+
def from_pretrained(
|
| 74 |
+
cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
|
| 75 |
+
):
|
| 76 |
+
if os.path.isdir(pretrained_model_name_or_path):
|
| 77 |
+
config_path = os.path.join(pretrained_model_name_or_path, config_name)
|
| 78 |
+
weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
|
| 79 |
+
else:
|
| 80 |
+
config_path = hf_hub_download(
|
| 81 |
+
repo_id=pretrained_model_name_or_path, filename=config_name
|
| 82 |
+
)
|
| 83 |
+
weight_path = hf_hub_download(
|
| 84 |
+
repo_id=pretrained_model_name_or_path, filename=weight_name
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
cfg = OmegaConf.load(config_path)
|
| 88 |
+
OmegaConf.resolve(cfg)
|
| 89 |
+
model = cls(cfg)
|
| 90 |
+
load_model(model, weight_path)
|
| 91 |
+
return model
|
| 92 |
+
|
| 93 |
+
@property
|
| 94 |
+
def device(self):
|
| 95 |
+
return next(self.parameters()).device
|
| 96 |
+
|
| 97 |
+
def configure(self):
|
| 98 |
+
self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)(
|
| 99 |
+
self.cfg.image_tokenizer
|
| 100 |
+
)
|
| 101 |
+
self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer)
|
| 102 |
+
self.camera_embedder = find_class(self.cfg.camera_embedder_cls)(
|
| 103 |
+
self.cfg.camera_embedder
|
| 104 |
+
)
|
| 105 |
+
self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone)
|
| 106 |
+
self.post_processor = find_class(self.cfg.post_processor_cls)(
|
| 107 |
+
self.cfg.post_processor
|
| 108 |
+
)
|
| 109 |
+
self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder)
|
| 110 |
+
self.image_estimator = find_class(self.cfg.image_estimator_cls)(
|
| 111 |
+
self.cfg.image_estimator
|
| 112 |
+
)
|
| 113 |
+
self.global_estimator = find_class(self.cfg.global_estimator_cls)(
|
| 114 |
+
self.cfg.global_estimator
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
self.bbox: Float[Tensor, "2 3"]
|
| 118 |
+
self.register_buffer(
|
| 119 |
+
"bbox",
|
| 120 |
+
torch.as_tensor(
|
| 121 |
+
[
|
| 122 |
+
[-self.cfg.radius, -self.cfg.radius, -self.cfg.radius],
|
| 123 |
+
[self.cfg.radius, self.cfg.radius, self.cfg.radius],
|
| 124 |
+
],
|
| 125 |
+
dtype=torch.float32,
|
| 126 |
+
),
|
| 127 |
+
)
|
| 128 |
+
self.isosurface_helper = MarchingTetrahedraHelper(
|
| 129 |
+
self.cfg.isosurface_resolution,
|
| 130 |
+
os.path.join(
|
| 131 |
+
os.path.dirname(__file__),
|
| 132 |
+
"..",
|
| 133 |
+
"load",
|
| 134 |
+
"tets",
|
| 135 |
+
f"{self.cfg.isosurface_resolution}_tets.npz",
|
| 136 |
+
),
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
self.baker = TextureBaker()
|
| 140 |
+
self.image_processor = ImageProcessor()
|
| 141 |
+
|
| 142 |
+
def triplane_to_meshes(
|
| 143 |
+
self, triplanes: Float[Tensor, "B 3 Cp Hp Wp"]
|
| 144 |
+
) -> list[Mesh]:
|
| 145 |
+
meshes = []
|
| 146 |
+
for i in range(triplanes.shape[0]):
|
| 147 |
+
triplane = triplanes[i]
|
| 148 |
+
grid_vertices = scale_tensor(
|
| 149 |
+
self.isosurface_helper.grid_vertices.to(triplanes.device),
|
| 150 |
+
self.isosurface_helper.points_range,
|
| 151 |
+
self.bbox,
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
values = self.query_triplane(grid_vertices, triplane)
|
| 155 |
+
decoded = self.decoder(values, include=["vertex_offset", "density"])
|
| 156 |
+
sdf = decoded["density"] - self.cfg.isosurface_threshold
|
| 157 |
+
|
| 158 |
+
deform = decoded["vertex_offset"].squeeze(0)
|
| 159 |
+
|
| 160 |
+
mesh: Mesh = self.isosurface_helper(
|
| 161 |
+
sdf.view(-1, 1), deform.view(-1, 3) if deform is not None else None
|
| 162 |
+
)
|
| 163 |
+
mesh.v_pos = scale_tensor(
|
| 164 |
+
mesh.v_pos, self.isosurface_helper.points_range, self.bbox
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
meshes.append(mesh)
|
| 168 |
+
|
| 169 |
+
return meshes
|
| 170 |
+
|
| 171 |
+
def query_triplane(
|
| 172 |
+
self,
|
| 173 |
+
positions: Float[Tensor, "*B N 3"],
|
| 174 |
+
triplanes: Float[Tensor, "*B 3 Cp Hp Wp"],
|
| 175 |
+
) -> Float[Tensor, "*B N F"]:
|
| 176 |
+
batched = positions.ndim == 3
|
| 177 |
+
if not batched:
|
| 178 |
+
# no batch dimension
|
| 179 |
+
triplanes = triplanes[None, ...]
|
| 180 |
+
positions = positions[None, ...]
|
| 181 |
+
assert triplanes.ndim == 5 and positions.ndim == 3
|
| 182 |
+
|
| 183 |
+
positions = scale_tensor(
|
| 184 |
+
positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
indices2D: Float[Tensor, "B 3 N 2"] = torch.stack(
|
| 188 |
+
(positions[..., [0, 1]], positions[..., [0, 2]], positions[..., [1, 2]]),
|
| 189 |
+
dim=-3,
|
| 190 |
+
).to(triplanes.dtype)
|
| 191 |
+
out: Float[Tensor, "B3 Cp 1 N"] = F.grid_sample(
|
| 192 |
+
rearrange(triplanes, "B Np Cp Hp Wp -> (B Np) Cp Hp Wp", Np=3).float(),
|
| 193 |
+
rearrange(indices2D, "B Np N Nd -> (B Np) () N Nd", Np=3).float(),
|
| 194 |
+
align_corners=True,
|
| 195 |
+
mode="bilinear",
|
| 196 |
+
)
|
| 197 |
+
out = rearrange(out, "(B Np) Cp () N -> B N (Np Cp)", Np=3)
|
| 198 |
+
|
| 199 |
+
return out
|
| 200 |
+
|
| 201 |
+
def get_scene_codes(self, batch) -> Float[Tensor, "B 3 C H W"]:
|
| 202 |
+
# if batch[rgb_cond] is only one view, add a view dimension
|
| 203 |
+
if len(batch["rgb_cond"].shape) == 4:
|
| 204 |
+
batch["rgb_cond"] = batch["rgb_cond"].unsqueeze(1)
|
| 205 |
+
batch["mask_cond"] = batch["mask_cond"].unsqueeze(1)
|
| 206 |
+
batch["c2w_cond"] = batch["c2w_cond"].unsqueeze(1)
|
| 207 |
+
batch["intrinsic_cond"] = batch["intrinsic_cond"].unsqueeze(1)
|
| 208 |
+
batch["intrinsic_normed_cond"] = batch["intrinsic_normed_cond"].unsqueeze(1)
|
| 209 |
+
batch_size, n_input_views = batch["rgb_cond"].shape[:2]
|
| 210 |
+
|
| 211 |
+
camera_embeds: Optional[Float[Tensor, "B Nv Cc"]]
|
| 212 |
+
camera_embeds = self.camera_embedder(**batch)
|
| 213 |
+
|
| 214 |
+
input_image_tokens: Float[Tensor, "B Nv Cit Nit"] = self.image_tokenizer(
|
| 215 |
+
rearrange(batch["rgb_cond"], "B Nv H W C -> B Nv C H W"),
|
| 216 |
+
modulation_cond=camera_embeds,
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
input_image_tokens = rearrange(
|
| 220 |
+
input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=n_input_views
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
tokens: Float[Tensor, "B Ct Nt"] = self.tokenizer(batch_size)
|
| 224 |
+
|
| 225 |
+
tokens = self.backbone(
|
| 226 |
+
tokens,
|
| 227 |
+
encoder_hidden_states=input_image_tokens,
|
| 228 |
+
modulation_cond=None,
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
direct_codes = self.tokenizer.detokenize(tokens)
|
| 232 |
+
scene_codes = self.post_processor(direct_codes)
|
| 233 |
+
return scene_codes, direct_codes
|
| 234 |
+
|
| 235 |
+
def run_image(
|
| 236 |
+
self,
|
| 237 |
+
image: Image,
|
| 238 |
+
bake_resolution: int,
|
| 239 |
+
estimate_illumination: bool = False,
|
| 240 |
+
) -> Tuple[trimesh.Trimesh, dict[str, Any]]:
|
| 241 |
+
if image.mode != "RGBA":
|
| 242 |
+
raise ValueError("Image must be in RGBA mode")
|
| 243 |
+
img_cond = (
|
| 244 |
+
torch.from_numpy(
|
| 245 |
+
np.asarray(
|
| 246 |
+
image.resize((self.cfg.cond_image_size, self.cfg.cond_image_size))
|
| 247 |
+
).astype(np.float32)
|
| 248 |
+
/ 255.0
|
| 249 |
+
)
|
| 250 |
+
.float()
|
| 251 |
+
.clip(0, 1)
|
| 252 |
+
.to(self.device)
|
| 253 |
+
)
|
| 254 |
+
mask_cond = img_cond[:, :, -1:]
|
| 255 |
+
rgb_cond = torch.lerp(
|
| 256 |
+
torch.tensor(self.cfg.background_color, device=self.device)[None, None, :],
|
| 257 |
+
img_cond[:, :, :3],
|
| 258 |
+
mask_cond,
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
c2w_cond = default_cond_c2w(self.cfg.default_distance).to(self.device)
|
| 262 |
+
intrinsic, intrinsic_normed_cond = create_intrinsic_from_fov_deg(
|
| 263 |
+
self.cfg.default_fovy_deg,
|
| 264 |
+
self.cfg.cond_image_size,
|
| 265 |
+
self.cfg.cond_image_size,
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
batch = {
|
| 269 |
+
"rgb_cond": rgb_cond,
|
| 270 |
+
"mask_cond": mask_cond,
|
| 271 |
+
"c2w_cond": c2w_cond.unsqueeze(0),
|
| 272 |
+
"intrinsic_cond": intrinsic.to(self.device).unsqueeze(0),
|
| 273 |
+
"intrinsic_normed_cond": intrinsic_normed_cond.to(self.device).unsqueeze(0),
|
| 274 |
+
}
|
| 275 |
+
|
| 276 |
+
meshes, global_dict = self.generate_mesh(
|
| 277 |
+
batch, bake_resolution, estimate_illumination
|
| 278 |
+
)
|
| 279 |
+
return meshes[0], global_dict
|
| 280 |
+
|
| 281 |
+
def generate_mesh(
|
| 282 |
+
self,
|
| 283 |
+
batch,
|
| 284 |
+
bake_resolution: int,
|
| 285 |
+
estimate_illumination: bool = False,
|
| 286 |
+
) -> Tuple[List[trimesh.Trimesh], dict[str, Any]]:
|
| 287 |
+
batch["rgb_cond"] = self.image_processor(
|
| 288 |
+
batch["rgb_cond"], self.cfg.cond_image_size
|
| 289 |
+
)
|
| 290 |
+
batch["mask_cond"] = self.image_processor(
|
| 291 |
+
batch["mask_cond"], self.cfg.cond_image_size
|
| 292 |
+
)
|
| 293 |
+
scene_codes, non_postprocessed_codes = self.get_scene_codes(batch)
|
| 294 |
+
|
| 295 |
+
global_dict = {}
|
| 296 |
+
if self.image_estimator is not None:
|
| 297 |
+
global_dict.update(
|
| 298 |
+
self.image_estimator(batch["rgb_cond"] * batch["mask_cond"])
|
| 299 |
+
)
|
| 300 |
+
if self.global_estimator is not None and estimate_illumination:
|
| 301 |
+
global_dict.update(self.global_estimator(non_postprocessed_codes))
|
| 302 |
+
|
| 303 |
+
with torch.no_grad():
|
| 304 |
+
with torch.autocast(device_type="cuda", enabled=False):
|
| 305 |
+
meshes = self.triplane_to_meshes(scene_codes)
|
| 306 |
+
|
| 307 |
+
rets = []
|
| 308 |
+
for i, mesh in enumerate(meshes):
|
| 309 |
+
# Check for empty mesh
|
| 310 |
+
if mesh.v_pos.shape[0] == 0:
|
| 311 |
+
rets.append(trimesh.Trimesh())
|
| 312 |
+
continue
|
| 313 |
+
|
| 314 |
+
mesh.unwrap_uv()
|
| 315 |
+
|
| 316 |
+
# Build textures
|
| 317 |
+
rast = self.baker.rasterize(
|
| 318 |
+
mesh.v_tex, mesh.t_pos_idx, bake_resolution
|
| 319 |
+
)
|
| 320 |
+
bake_mask = self.baker.get_mask(rast)
|
| 321 |
+
|
| 322 |
+
pos_bake = self.baker.interpolate(
|
| 323 |
+
mesh.v_pos,
|
| 324 |
+
rast,
|
| 325 |
+
mesh.t_pos_idx,
|
| 326 |
+
mesh.v_tex,
|
| 327 |
+
)
|
| 328 |
+
gb_pos = pos_bake[bake_mask]
|
| 329 |
+
|
| 330 |
+
tri_query = self.query_triplane(gb_pos, scene_codes[i])[0]
|
| 331 |
+
decoded = self.decoder(
|
| 332 |
+
tri_query, exclude=["density", "vertex_offset"]
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
nrm = self.baker.interpolate(
|
| 336 |
+
mesh.v_nrm,
|
| 337 |
+
rast,
|
| 338 |
+
mesh.t_pos_idx,
|
| 339 |
+
mesh.v_tex,
|
| 340 |
+
)
|
| 341 |
+
gb_nrm = F.normalize(nrm[bake_mask], dim=-1)
|
| 342 |
+
decoded["normal"] = gb_nrm
|
| 343 |
+
|
| 344 |
+
# Check if any keys in global_dict start with decoded_
|
| 345 |
+
for k, v in global_dict.items():
|
| 346 |
+
if k.startswith("decoder_"):
|
| 347 |
+
decoded[k.replace("decoder_", "")] = v[i]
|
| 348 |
+
|
| 349 |
+
mat_out = {
|
| 350 |
+
"albedo": decoded["features"],
|
| 351 |
+
"roughness": decoded["roughness"],
|
| 352 |
+
"metallic": decoded["metallic"],
|
| 353 |
+
"normal": normalize(decoded["perturb_normal"]),
|
| 354 |
+
"bump": None,
|
| 355 |
+
}
|
| 356 |
+
|
| 357 |
+
for k, v in mat_out.items():
|
| 358 |
+
if v is None:
|
| 359 |
+
continue
|
| 360 |
+
if v.shape[0] == 1:
|
| 361 |
+
# Skip and directly add a single value
|
| 362 |
+
mat_out[k] = v[0]
|
| 363 |
+
else:
|
| 364 |
+
f = torch.zeros(
|
| 365 |
+
bake_resolution,
|
| 366 |
+
bake_resolution,
|
| 367 |
+
v.shape[-1],
|
| 368 |
+
dtype=v.dtype,
|
| 369 |
+
device=v.device,
|
| 370 |
+
)
|
| 371 |
+
if v.shape == f.shape:
|
| 372 |
+
continue
|
| 373 |
+
if k == "normal":
|
| 374 |
+
# Use un-normalized tangents here so that larger smaller tris
|
| 375 |
+
# Don't effect the tangents that much
|
| 376 |
+
tng = self.baker.interpolate(
|
| 377 |
+
mesh.v_tng,
|
| 378 |
+
rast,
|
| 379 |
+
mesh.t_pos_idx,
|
| 380 |
+
mesh.v_tex,
|
| 381 |
+
)
|
| 382 |
+
gb_tng = tng[bake_mask]
|
| 383 |
+
gb_tng = F.normalize(gb_tng, dim=-1)
|
| 384 |
+
gb_btng = F.normalize(
|
| 385 |
+
torch.cross(gb_tng, gb_nrm, dim=-1), dim=-1
|
| 386 |
+
)
|
| 387 |
+
normal = F.normalize(mat_out["normal"], dim=-1)
|
| 388 |
+
|
| 389 |
+
bump = torch.cat(
|
| 390 |
+
# Check if we have to flip some things
|
| 391 |
+
(
|
| 392 |
+
dot(normal, gb_tng),
|
| 393 |
+
dot(normal, gb_btng),
|
| 394 |
+
dot(normal, gb_nrm).clip(
|
| 395 |
+
0.3, 1
|
| 396 |
+
), # Never go below 0.3. This would indicate a flipped (or close to one) normal
|
| 397 |
+
),
|
| 398 |
+
-1,
|
| 399 |
+
)
|
| 400 |
+
bump = (bump * 0.5 + 0.5).clamp(0, 1)
|
| 401 |
+
|
| 402 |
+
f[bake_mask] = bump.view(-1, 3)
|
| 403 |
+
mat_out["bump"] = f
|
| 404 |
+
else:
|
| 405 |
+
f[bake_mask] = v.view(-1, v.shape[-1])
|
| 406 |
+
mat_out[k] = f
|
| 407 |
+
|
| 408 |
+
def uv_padding(arr):
|
| 409 |
+
if arr.ndim == 1:
|
| 410 |
+
return arr
|
| 411 |
+
return (
|
| 412 |
+
dilate_fill(
|
| 413 |
+
arr.permute(2, 0, 1)[None, ...],
|
| 414 |
+
bake_mask.unsqueeze(0).unsqueeze(0),
|
| 415 |
+
iterations=bake_resolution // 150,
|
| 416 |
+
)
|
| 417 |
+
.squeeze(0)
|
| 418 |
+
.permute(1, 2, 0)
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
verts_np = convert_data(mesh.v_pos)
|
| 422 |
+
faces = convert_data(mesh.t_pos_idx)
|
| 423 |
+
uvs = convert_data(mesh.v_tex)
|
| 424 |
+
|
| 425 |
+
basecolor_tex = Image.fromarray(
|
| 426 |
+
float32_to_uint8_np(convert_data(uv_padding(mat_out["albedo"])))
|
| 427 |
+
).convert("RGB")
|
| 428 |
+
basecolor_tex.format = "JPEG"
|
| 429 |
+
|
| 430 |
+
metallic = mat_out["metallic"].squeeze().cpu().item()
|
| 431 |
+
roughness = mat_out["roughness"].squeeze().cpu().item()
|
| 432 |
+
|
| 433 |
+
if "bump" in mat_out and mat_out["bump"] is not None:
|
| 434 |
+
bump_np = convert_data(uv_padding(mat_out["bump"]))
|
| 435 |
+
bump_up = np.ones_like(bump_np)
|
| 436 |
+
bump_up[..., :2] = 0.5
|
| 437 |
+
bump_up[..., 2:] = 1
|
| 438 |
+
bump_tex = Image.fromarray(
|
| 439 |
+
float32_to_uint8_np(
|
| 440 |
+
bump_np,
|
| 441 |
+
dither=True,
|
| 442 |
+
# Do not dither if something is perfectly flat
|
| 443 |
+
dither_mask=np.all(
|
| 444 |
+
bump_np == bump_up, axis=-1, keepdims=True
|
| 445 |
+
).astype(np.float32),
|
| 446 |
+
)
|
| 447 |
+
).convert("RGB")
|
| 448 |
+
bump_tex.format = (
|
| 449 |
+
"JPEG" # PNG would be better but the assets are larger
|
| 450 |
+
)
|
| 451 |
+
else:
|
| 452 |
+
bump_tex = None
|
| 453 |
+
|
| 454 |
+
material = trimesh.visual.material.PBRMaterial(
|
| 455 |
+
baseColorTexture=basecolor_tex,
|
| 456 |
+
roughnessFactor=roughness,
|
| 457 |
+
metallicFactor=metallic,
|
| 458 |
+
normalTexture=bump_tex,
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
tmesh = trimesh.Trimesh(
|
| 462 |
+
vertices=verts_np,
|
| 463 |
+
faces=faces,
|
| 464 |
+
visual=trimesh.visual.texture.TextureVisuals(
|
| 465 |
+
uv=uvs, material=material
|
| 466 |
+
),
|
| 467 |
+
)
|
| 468 |
+
rot = trimesh.transformations.rotation_matrix(
|
| 469 |
+
np.radians(-90), [1, 0, 0]
|
| 470 |
+
)
|
| 471 |
+
tmesh.apply_transform(rot)
|
| 472 |
+
tmesh.apply_transform(
|
| 473 |
+
trimesh.transformations.rotation_matrix(
|
| 474 |
+
np.radians(90), [0, 1, 0]
|
| 475 |
+
)
|
| 476 |
+
)
|
| 477 |
+
|
| 478 |
+
tmesh.invert()
|
| 479 |
+
|
| 480 |
+
rets.append(tmesh)
|
| 481 |
+
|
| 482 |
+
return rets, global_dict
|
sf3d/sf3d_texture_baker.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import slangtorch
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from jaxtyping import Bool, Float
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class TextureBaker(nn.Module):
|
| 11 |
+
def __init__(self):
|
| 12 |
+
super().__init__()
|
| 13 |
+
self.baker = slangtorch.loadModule(
|
| 14 |
+
os.path.join(os.path.dirname(__file__), "texture_baker.slang")
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
def rasterize(
|
| 18 |
+
self,
|
| 19 |
+
uv: Float[Tensor, "Nv 2"],
|
| 20 |
+
face_indices: Float[Tensor, "Nf 3"],
|
| 21 |
+
bake_resolution: int,
|
| 22 |
+
) -> Float[Tensor, "bake_resolution bake_resolution 4"]:
|
| 23 |
+
if not face_indices.is_cuda or not uv.is_cuda:
|
| 24 |
+
raise ValueError("All input tensors must be on cuda")
|
| 25 |
+
|
| 26 |
+
face_indices = face_indices.to(torch.int32)
|
| 27 |
+
uv = uv.to(torch.float32)
|
| 28 |
+
|
| 29 |
+
rast_result = torch.empty(
|
| 30 |
+
bake_resolution, bake_resolution, 4, device=uv.device, dtype=torch.float32
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
block_size = 16
|
| 34 |
+
grid_size = bake_resolution // block_size
|
| 35 |
+
self.baker.bake_uv(uv=uv, indices=face_indices, output=rast_result).launchRaw(
|
| 36 |
+
blockSize=(block_size, block_size, 1), gridSize=(grid_size, grid_size, 1)
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
return rast_result
|
| 40 |
+
|
| 41 |
+
def get_mask(
|
| 42 |
+
self, rast: Float[Tensor, "bake_resolution bake_resolution 4"]
|
| 43 |
+
) -> Bool[Tensor, "bake_resolution bake_resolution"]:
|
| 44 |
+
return rast[..., -1] >= 0
|
| 45 |
+
|
| 46 |
+
def interpolate(
|
| 47 |
+
self,
|
| 48 |
+
attr: Float[Tensor, "Nv 3"],
|
| 49 |
+
rast: Float[Tensor, "bake_resolution bake_resolution 4"],
|
| 50 |
+
face_indices: Float[Tensor, "Nf 3"],
|
| 51 |
+
uv: Float[Tensor, "Nv 2"],
|
| 52 |
+
) -> Float[Tensor, "bake_resolution bake_resolution 3"]:
|
| 53 |
+
# Make sure all input tensors are on torch
|
| 54 |
+
if not attr.is_cuda or not face_indices.is_cuda or not rast.is_cuda:
|
| 55 |
+
raise ValueError("All input tensors must be on cuda")
|
| 56 |
+
|
| 57 |
+
attr = attr.to(torch.float32)
|
| 58 |
+
face_indices = face_indices.to(torch.int32)
|
| 59 |
+
uv = uv.to(torch.float32)
|
| 60 |
+
|
| 61 |
+
pos_bake = torch.zeros(
|
| 62 |
+
rast.shape[0],
|
| 63 |
+
rast.shape[1],
|
| 64 |
+
3,
|
| 65 |
+
device=attr.device,
|
| 66 |
+
dtype=attr.dtype,
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
block_size = 16
|
| 70 |
+
grid_size = rast.shape[0] // block_size
|
| 71 |
+
self.baker.interpolate(
|
| 72 |
+
attr=attr, indices=face_indices, rast=rast, output=pos_bake
|
| 73 |
+
).launchRaw(
|
| 74 |
+
blockSize=(block_size, block_size, 1), gridSize=(grid_size, grid_size, 1)
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
return pos_bake
|
| 78 |
+
|
| 79 |
+
def forward(
|
| 80 |
+
self,
|
| 81 |
+
attr: Float[Tensor, "Nv 3"],
|
| 82 |
+
uv: Float[Tensor, "Nv 2"],
|
| 83 |
+
face_indices: Float[Tensor, "Nf 3"],
|
| 84 |
+
bake_resolution: int,
|
| 85 |
+
) -> Float[Tensor, "bake_resolution bake_resolution 3"]:
|
| 86 |
+
rast = self.rasterize(uv, face_indices, bake_resolution)
|
| 87 |
+
return self.interpolate(attr, rast, face_indices, uv)
|
sf3d/sf3d_texture_baker.slang
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// xy: 2D test position
|
| 2 |
+
// v1: vertex position 1
|
| 3 |
+
// v2: vertex position 2
|
| 4 |
+
// v3: vertex position 3
|
| 5 |
+
//
|
| 6 |
+
bool barycentric_coordinates(float2 xy, float2 v1, float2 v2, float2 v3, out float u, out float v, out float w)
|
| 7 |
+
{
|
| 8 |
+
// Return true if the point (x,y) is inside the triangle defined by the vertices v1, v2, v3.
|
| 9 |
+
// If the point is inside the triangle, the barycentric coordinates are stored in u, v, and w.
|
| 10 |
+
float2 v1v2 = v2 - v1;
|
| 11 |
+
float2 v1v3 = v3 - v1;
|
| 12 |
+
float2 xyv1 = xy - v1;
|
| 13 |
+
|
| 14 |
+
float d00 = dot(v1v2, v1v2);
|
| 15 |
+
float d01 = dot(v1v2, v1v3);
|
| 16 |
+
float d11 = dot(v1v3, v1v3);
|
| 17 |
+
float d20 = dot(xyv1, v1v2);
|
| 18 |
+
float d21 = dot(xyv1, v1v3);
|
| 19 |
+
|
| 20 |
+
float denom = d00 * d11 - d01 * d01;
|
| 21 |
+
v = (d11 * d20 - d01 * d21) / denom;
|
| 22 |
+
w = (d00 * d21 - d01 * d20) / denom;
|
| 23 |
+
u = 1.0 - v - w;
|
| 24 |
+
|
| 25 |
+
return (v >= 0.0) && (w >= 0.0) && (v + w <= 1.0);
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
[AutoPyBindCUDA]
|
| 29 |
+
[CUDAKernel]
|
| 30 |
+
void interpolate(
|
| 31 |
+
TensorView<float3> attr,
|
| 32 |
+
TensorView<int3> indices,
|
| 33 |
+
TensorView<float4> rast,
|
| 34 |
+
TensorView<float3> output)
|
| 35 |
+
{
|
| 36 |
+
// Interpolate the attr into output based on the rast result (barycentric coordinates, + triangle idx)
|
| 37 |
+
|
| 38 |
+
uint3 dispatch_id = cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx();
|
| 39 |
+
|
| 40 |
+
if (dispatch_id.x > output.size(0) || dispatch_id.y > output.size(1))
|
| 41 |
+
return;
|
| 42 |
+
|
| 43 |
+
float4 barycentric = rast[dispatch_id.x, dispatch_id.y];
|
| 44 |
+
int triangle_idx = int(barycentric.w);
|
| 45 |
+
|
| 46 |
+
if (triangle_idx < 0) {
|
| 47 |
+
output[dispatch_id.x, dispatch_id.y] = float3(0.0, 0.0, 0.0);
|
| 48 |
+
return;
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
float3 v1 = attr[indices[triangle_idx].x];
|
| 52 |
+
float3 v2 = attr[indices[triangle_idx].y];
|
| 53 |
+
float3 v3 = attr[indices[triangle_idx].z];
|
| 54 |
+
|
| 55 |
+
output[dispatch_id.x, dispatch_id.y] = v1 * barycentric.x + v2 * barycentric.y + v3 * barycentric.z;
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
[AutoPyBindCUDA]
|
| 59 |
+
[CUDAKernel]
|
| 60 |
+
void bake_uv(
|
| 61 |
+
TensorView<float2> uv,
|
| 62 |
+
TensorView<int3> indices,
|
| 63 |
+
TensorView<float4> output)
|
| 64 |
+
{
|
| 65 |
+
uint3 dispatch_id = cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx();
|
| 66 |
+
|
| 67 |
+
if (dispatch_id.y > output.size(0) || dispatch_id.x > output.size(1))
|
| 68 |
+
return;
|
| 69 |
+
|
| 70 |
+
// We index x,y but the orginal coords are HW. So swap them
|
| 71 |
+
float2 pixel_coord = float2(dispatch_id.y, dispatch_id.x);
|
| 72 |
+
// Normalize to [0, 1]
|
| 73 |
+
pixel_coord /= float2(output.size(1), output.size(0));
|
| 74 |
+
pixel_coord = clamp(pixel_coord, 0.0, 1.0);
|
| 75 |
+
// Flip x-axis
|
| 76 |
+
pixel_coord.y = 1 - pixel_coord.y;
|
| 77 |
+
|
| 78 |
+
for (int i = 0; i < indices.size(0); i++) {
|
| 79 |
+
float2 v1 = float2(uv[indices[i].x].x, uv[indices[i].x].y);
|
| 80 |
+
float2 v2 = float2(uv[indices[i].y].x, uv[indices[i].y].y);
|
| 81 |
+
float2 v3 = float2(uv[indices[i].z].x, uv[indices[i].z].y);
|
| 82 |
+
|
| 83 |
+
float u, v, w;
|
| 84 |
+
bool hit = barycentric_coordinates(pixel_coord, v1, v2, v3, u, v, w);
|
| 85 |
+
|
| 86 |
+
if (hit){
|
| 87 |
+
output[dispatch_id.x, dispatch_id.y] = float4(u, v, w, i);
|
| 88 |
+
return;
|
| 89 |
+
}
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
output[dispatch_id.x, dispatch_id.y] = float4(0.0, 0.0, 0.0, -1);
|
| 93 |
+
}
|
sf3d/sf3d_utils.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import rembg
|
| 5 |
+
import torch
|
| 6 |
+
from PIL import Image
|
| 7 |
+
|
| 8 |
+
import sf3d.models.utils as sf3d_utils
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def create_intrinsic_from_fov_deg(fov_deg: float, cond_height: int, cond_width: int):
|
| 12 |
+
intrinsic = sf3d_utils.get_intrinsic_from_fov(
|
| 13 |
+
np.deg2rad(fov_deg),
|
| 14 |
+
H=cond_height,
|
| 15 |
+
W=cond_width,
|
| 16 |
+
)
|
| 17 |
+
intrinsic_normed_cond = intrinsic.clone()
|
| 18 |
+
intrinsic_normed_cond[..., 0, 2] /= cond_width
|
| 19 |
+
intrinsic_normed_cond[..., 1, 2] /= cond_height
|
| 20 |
+
intrinsic_normed_cond[..., 0, 0] /= cond_width
|
| 21 |
+
intrinsic_normed_cond[..., 1, 1] /= cond_height
|
| 22 |
+
|
| 23 |
+
return intrinsic, intrinsic_normed_cond
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def default_cond_c2w(distance: float):
|
| 27 |
+
c2w_cond = torch.as_tensor(
|
| 28 |
+
[
|
| 29 |
+
[0, 0, 1, distance],
|
| 30 |
+
[1, 0, 0, 0],
|
| 31 |
+
[0, 1, 0, 0],
|
| 32 |
+
[0, 0, 0, 1],
|
| 33 |
+
]
|
| 34 |
+
).float()
|
| 35 |
+
return c2w_cond
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def remove_background(
|
| 39 |
+
image: Image,
|
| 40 |
+
rembg_session: Any = None,
|
| 41 |
+
force: bool = False,
|
| 42 |
+
**rembg_kwargs,
|
| 43 |
+
) -> Image:
|
| 44 |
+
do_remove = True
|
| 45 |
+
if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
|
| 46 |
+
do_remove = False
|
| 47 |
+
do_remove = do_remove or force
|
| 48 |
+
if do_remove:
|
| 49 |
+
image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
|
| 50 |
+
return image
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def resize_foreground(
|
| 54 |
+
image: Image,
|
| 55 |
+
ratio: float,
|
| 56 |
+
) -> Image:
|
| 57 |
+
image = np.array(image)
|
| 58 |
+
assert image.shape[-1] == 4
|
| 59 |
+
alpha = np.where(image[..., 3] > 0)
|
| 60 |
+
y1, y2, x1, x2 = (
|
| 61 |
+
alpha[0].min(),
|
| 62 |
+
alpha[0].max(),
|
| 63 |
+
alpha[1].min(),
|
| 64 |
+
alpha[1].max(),
|
| 65 |
+
)
|
| 66 |
+
# crop the foreground
|
| 67 |
+
fg = image[y1:y2, x1:x2]
|
| 68 |
+
# pad to square
|
| 69 |
+
size = max(fg.shape[0], fg.shape[1])
|
| 70 |
+
ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
|
| 71 |
+
ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
|
| 72 |
+
new_image = np.pad(
|
| 73 |
+
fg,
|
| 74 |
+
((ph0, ph1), (pw0, pw1), (0, 0)),
|
| 75 |
+
mode="constant",
|
| 76 |
+
constant_values=((0, 0), (0, 0), (0, 0)),
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
# compute padding according to the ratio
|
| 80 |
+
new_size = int(new_image.shape[0] / ratio)
|
| 81 |
+
# pad to size, double side
|
| 82 |
+
ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
|
| 83 |
+
ph1, pw1 = new_size - size - ph0, new_size - size - pw0
|
| 84 |
+
new_image = np.pad(
|
| 85 |
+
new_image,
|
| 86 |
+
((ph0, ph1), (pw0, pw1), (0, 0)),
|
| 87 |
+
mode="constant",
|
| 88 |
+
constant_values=((0, 0), (0, 0), (0, 0)),
|
| 89 |
+
)
|
| 90 |
+
new_image = Image.fromarray(new_image, mode="RGBA")
|
| 91 |
+
return new_image
|