gokaygokay committed
Commit: 0038320
Parent: 0b32f48
Files changed (39)
  1. app.py +0 -162
  2. demo_files/comp.gif +0 -3
  3. demo_files/examples/animal_character.png +0 -3
  4. demo_files/examples/animal_character_2.png +0 -3
  5. demo_files/examples/axe.png +0 -0
  6. demo_files/examples/chair1.png +0 -0
  7. demo_files/examples/character1.png +0 -0
  8. demo_files/examples/otter_samurai.png +0 -0
  9. demo_files/examples/raccoon_wizard.png +0 -0
  10. demo_files/examples/stylized-rocks.png +0 -0
  11. demo_files/examples/tree.png +0 -0
  12. demo_files/hdri/abandoned_tiled_room_1k.hdr +0 -0
  13. demo_files/hdri/metro_noord_1k.hdr +0 -0
  14. demo_files/hdri/neon_photostudio_1k.hdr +0 -0
  15. demo_files/hdri/peppermint_powerplant_1k.hdr +0 -0
  16. demo_files/hdri/rainforest_trail_1k.hdr +0 -0
  17. demo_files/hdri/studio_small_08_1k.hdr +0 -0
  18. demo_files/hdri/urban_alley_01_1k.hdr +0 -0
  19. demo_files/scatterplot.jpg +0 -0
  20. demo_files/teaser.gif +0 -3
  21. load/tets/160_tets.npz +0 -3
  22. sf3d/box_uv_unwrap.py +0 -610
  23. sf3d/models/camera.py +0 -32
  24. sf3d/models/global_estimator/multi_head_estimator.py +0 -118
  25. sf3d/models/image_estimator/clip_based_estimator.py +0 -168
  26. sf3d/models/isosurface.py +0 -229
  27. sf3d/models/mesh.py +0 -172
  28. sf3d/models/network.py +0 -195
  29. sf3d/models/tokenizers/dinov2.py +0 -1196
  30. sf3d/models/tokenizers/image.py +0 -99
  31. sf3d/models/tokenizers/triplane.py +0 -49
  32. sf3d/models/transformers/attention.py +0 -31
  33. sf3d/models/transformers/backbone.py +0 -515
  34. sf3d/models/utils.py +0 -292
  35. sf3d/system.py +0 -482
  36. sf3d/texture_baker.py +0 -87
  37. sf3d/texture_baker.slang +0 -93
  38. sf3d/utils.py +0 -91
  39. stable_fast.py +0 -355
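Every file in the Space is removed by this commit. The deleted code and assets remain reachable at the parent revision (0b32f48), so a checkout of that revision recovers them. Below is a minimal recovery sketch using huggingface_hub's snapshot_download; the repo_id is a placeholder, since the Space id itself is not shown in this diff.

# Recovery sketch (assumptions: huggingface_hub is installed; repo_id is hypothetical).
from huggingface_hub import snapshot_download

local_dir = snapshot_download(
    repo_id="gokaygokay/your-space-name",  # placeholder -- substitute the real Space id
    repo_type="space",
    revision="0b32f48",  # parent commit; use the full SHA if the short form is rejected
)
print(local_dir)  # local copy containing app.py, sf3d/, demo_files/, load/, stable_fast.py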
app.py DELETED
@@ -1,162 +0,0 @@
1
- import spaces
2
- import os
3
- import tempfile
4
- import time
5
- import gradio as gr
6
- import torch
7
- from PIL import Image
8
- from diffusers import DiffusionPipeline
9
- from huggingface_hub import hf_hub_download
10
- from sf3d.system import SF3D
11
- import sf3d.utils as sf3d_utils
12
- from gradio_litmodel3d import LitModel3D
13
- from huggingface_hub import login
14
- import subprocess
15
-
16
- dtype = torch.bfloat16
17
-
18
- torch.backends.cuda.matmul.allow_tf32 = True
19
- huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
20
-
21
- device = torch.device('cuda')
22
-
23
- import shutil
24
-
25
- def find_cuda():
26
- # Check if CUDA_HOME or CUDA_PATH environment variables are set
27
- cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
28
-
29
- if cuda_home and os.path.exists(cuda_home):
30
- return cuda_home
31
-
32
- # Search for the nvcc executable in the system's PATH
33
- nvcc_path = shutil.which('nvcc')
34
-
35
- if nvcc_path:
36
- # Remove the 'bin/nvcc' part to get the CUDA installation path
37
- cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
38
- return cuda_path
39
-
40
- return None
41
-
42
- cuda_path = find_cuda()
43
-
44
- if cuda_path:
45
- print(f"CUDA installation found at: {cuda_path}")
46
- else:
47
- print("CUDA installation not found")
48
-
49
- login(token=huggingface_token)
50
- # Set up environment and cache
51
- cache_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
52
- os.environ["TRANSFORMERS_CACHE"] = cache_path
53
- os.environ["HF_HUB_CACHE"] = cache_path
54
- os.environ["HF_HOME"] = cache_path
55
-
56
- if not os.path.exists(cache_path):
57
- os.makedirs(cache_path, exist_ok=True)
58
-
59
- # Initialize Flux pipeline
60
- pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype, token = huggingface_token).to(device)
61
- pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
62
- pipe.fuse_lora(lora_scale=0.125)
63
- pipe.to(device="cuda", dtype=torch.bfloat16)
64
-
65
- # Initialize SF3D model
66
- sf3d_model = SF3D.from_pretrained(
67
- "stabilityai/stable-fast-3d",
68
- config_name="config.yaml",
69
- weight_name="model.safetensors",
70
-
71
- ).eval().to(device)
72
-
73
- # Constants for SF3D
74
- COND_WIDTH, COND_HEIGHT = 512, 512
75
- COND_DISTANCE, COND_FOVY_DEG = 1.6, 40
76
- BACKGROUND_COLOR = [0.5, 0.5, 0.5]
77
-
78
- c2w_cond = sf3d_utils.default_cond_c2w(COND_DISTANCE)
79
- intrinsic, intrinsic_normed_cond = sf3d_utils.create_intrinsic_from_fov_deg(
80
- COND_FOVY_DEG, COND_HEIGHT, COND_WIDTH
81
- )
82
-
83
- def generate_image(prompt, height, width, steps, scales, seed):
84
- with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
85
- return pipe(
86
- prompt=[prompt],
87
- generator=torch.Generator().manual_seed(int(seed)),
88
- num_inference_steps=int(steps),
89
- guidance_scale=float(scales),
90
- height=int(height),
91
- width=int(width),
92
- max_sequence_length=256
93
- ).images[0]
94
-
95
- def create_batch(input_image: Image.Image) -> dict:
96
- img_cond = torch.from_numpy(
97
- np.asarray(input_image.resize((COND_WIDTH, COND_HEIGHT))).astype(np.float32) / 255.0
98
- ).float().clip(0, 1)
99
- mask_cond = img_cond[:, :, -1:]
100
- rgb_cond = torch.lerp(
101
- torch.tensor(BACKGROUND_COLOR)[None, None, :], img_cond[:, :, :3], mask_cond
102
- )
103
-
104
- batch_elem = {
105
- "rgb_cond": rgb_cond,
106
- "mask_cond": mask_cond,
107
- "c2w_cond": c2w_cond.unsqueeze(0),
108
- "intrinsic_cond": intrinsic.unsqueeze(0),
109
- "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
110
- }
111
- return {k: v.unsqueeze(0) for k, v in batch_elem.items()}
112
-
113
- def generate_3d_model(input_image):
114
- with torch.no_grad():
115
- with torch.autocast(device_type="cuda", dtype=torch.float16):
116
- model_batch = create_batch(input_image)
117
- model_batch = {k: v.cuda() for k, v in model_batch.items()}
118
- trimesh_mesh, _ = sf3d_model.generate_mesh(model_batch, 1024)
119
- trimesh_mesh = trimesh_mesh[0]
120
-
121
- tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".glb")
122
- trimesh_mesh.export(tmp_file.name, file_type="glb", include_normals=True)
123
- return tmp_file.name
124
-
125
- @spaces.GPU
126
- def process_and_generate(prompt, height, width, steps, scales, seed):
127
- # Generate image from prompt
128
- generated_image = generate_image(prompt, height, width, steps, scales, seed)
129
-
130
- # Generate 3D model from the image
131
- glb_file = generate_3d_model(generated_image)
132
-
133
- return generated_image, glb_file
134
-
135
- # Gradio interface
136
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
137
- gr.Markdown("# Text-to-3D Model Generator")
138
-
139
- with gr.Row():
140
- with gr.Column(scale=3):
141
- prompt = gr.Textbox(label="Your Image Description", lines=3)
142
- with gr.Accordion("Advanced Settings", open=False):
143
- height = gr.Slider(label="Height", minimum=256, maximum=1152, step=64, value=1024)
144
- width = gr.Slider(label="Width", minimum=256, maximum=1152, step=64, value=1024)
145
- steps = gr.Slider(label="Inference Steps", minimum=6, maximum=25, step=1, value=8)
146
- scales = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=5.0, step=0.1, value=3.5)
147
- seed = gr.Number(label="Seed", value=3413, precision=0)
148
-
149
- generate_btn = gr.Button("Generate 3D Model", variant="primary")
150
-
151
- with gr.Column(scale=4):
152
- output_image = gr.Image(label="Generated Image")
153
- output_3d = LitModel3D(label="3D Model", clear_color=[0.0, 0.0, 0.0, 0.0])
154
-
155
- generate_btn.click(
156
- process_and_generate,
157
- inputs=[prompt, height, width, steps, scales, seed],
158
- outputs=[output_image, output_3d]
159
- )
160
-
161
- if __name__ == "__main__":
162
- demo.launch()
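Note on the deleted app.py above: create_batch() calls np.asarray, but the file never imports NumPy, so the image-to-batch step would raise a NameError at runtime. Had the file been kept, the fix would have been a one-line import alongside the others at the top:

# Missing import in the deleted app.py (create_batch uses np.asarray)
import numpy as np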
 
demo_files/comp.gif DELETED

Git LFS Details

  • SHA256: 1d5e060d90f29889c55c1c5681dbeb4b4c2408709d18f7451bb0a6f02c6e9bc5
  • Pointer size: 132 Bytes
  • Size of remote file: 1.93 MB
demo_files/examples/animal_character.png DELETED

Git LFS Details

  • SHA256: 5949f60c651e71a41b7291197f91bb8be2c8861472765fc884e604e18b7806a0
  • Pointer size: 132 Bytes
  • Size of remote file: 1.39 MB
demo_files/examples/animal_character_2.png DELETED

Git LFS Details

  • SHA256: ffc3f10c629afd64798d38dad2cc419eb343c7106149426f78634a91367bf031
  • Pointer size: 132 Bytes
  • Size of remote file: 1.6 MB
demo_files/examples/axe.png DELETED
Binary file (277 kB)
 
demo_files/examples/chair1.png DELETED
Binary file (115 kB)
 
demo_files/examples/character1.png DELETED
Binary file (120 kB)
 
demo_files/examples/otter_samurai.png DELETED
Binary file (980 kB)
 
demo_files/examples/raccoon_wizard.png DELETED
Binary file (774 kB)
 
demo_files/examples/stylized-rocks.png DELETED
Binary file (439 kB)
 
demo_files/examples/tree.png DELETED
Binary file (693 kB)
 
demo_files/hdri/abandoned_tiled_room_1k.hdr DELETED
Binary file (478 kB)
 
demo_files/hdri/metro_noord_1k.hdr DELETED
Binary file (467 kB)
 
demo_files/hdri/neon_photostudio_1k.hdr DELETED
Binary file (438 kB)
 
demo_files/hdri/peppermint_powerplant_1k.hdr DELETED
Binary file (473 kB)
 
demo_files/hdri/rainforest_trail_1k.hdr DELETED
Binary file (512 kB)
 
demo_files/hdri/studio_small_08_1k.hdr DELETED
Binary file (412 kB)
 
demo_files/hdri/urban_alley_01_1k.hdr DELETED
Binary file (458 kB)
 
demo_files/scatterplot.jpg DELETED
Binary file (879 kB)
 
demo_files/teaser.gif DELETED

Git LFS Details

  • SHA256: 1d5dcb4fbe710e94c0fa70cc2c783d66e327222cb5e74839cfd003e619bc2e1d
  • Pointer size: 132 Bytes
  • Size of remote file: 2.81 MB
load/tets/160_tets.npz DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:1f4be37efc604d28d55a1a78c2aabefeeab7e63149f541aa45f9dd858ee35bb9
3
- size 15408790
 
sf3d/box_uv_unwrap.py DELETED
@@ -1,610 +0,0 @@
1
- import math
2
- from typing import Tuple
3
-
4
- import torch
5
- import torch.nn.functional as F
6
- from jaxtyping import Float, Integer
7
- from torch import Tensor
8
-
9
- from sf3d.models.utils import dot, triangle_intersection_2d
10
-
11
-
12
- def _box_assign_vertex_to_cube_face(
13
- vertex_positions: Float[Tensor, "Nv 3"],
14
- vertex_normals: Float[Tensor, "Nv 3"],
15
- triangle_idxs: Integer[Tensor, "Nf 3"],
16
- bbox: Float[Tensor, "2 3"],
17
- ) -> Tuple[Float[Tensor, "Nf 3 2"], Integer[Tensor, "Nf 3"]]:
18
- # Test to not have a scaled model to fit the space better
19
- # bbox_min = bbox[:1].mean(-1, keepdim=True)
20
- # bbox_max = bbox[1:].mean(-1, keepdim=True)
21
- # v_pos_normalized = (vertex_positions - bbox_min) / (bbox_max - bbox_min)
22
-
23
- # Create a [0, 1] normalized vertex position
24
- v_pos_normalized = (vertex_positions - bbox[:1]) / (bbox[1:] - bbox[:1])
25
- # And to [-1, 1]
26
- v_pos_normalized = 2.0 * v_pos_normalized - 1.0
27
-
28
- # Get all vertex positions for each triangle
29
- # Now how do we define to which face the triangle belongs? Mean face pos? Max vertex pos?
30
- v0 = v_pos_normalized[triangle_idxs[:, 0]]
31
- v1 = v_pos_normalized[triangle_idxs[:, 1]]
32
- v2 = v_pos_normalized[triangle_idxs[:, 2]]
33
- tri_stack = torch.stack([v0, v1, v2], dim=1)
34
-
35
- vn0 = vertex_normals[triangle_idxs[:, 0]]
36
- vn1 = vertex_normals[triangle_idxs[:, 1]]
37
- vn2 = vertex_normals[triangle_idxs[:, 2]]
38
- tri_stack_nrm = torch.stack([vn0, vn1, vn2], dim=1)
39
-
40
- # Just average the normals per face
41
- face_normal = F.normalize(torch.sum(tri_stack_nrm, 1), eps=1e-6, dim=-1)
42
-
43
- # Now decide based on the face normal in which box map we project
44
- # abs_x, abs_y, abs_z = tri_stack_nrm.abs().unbind(-1)
45
- abs_x, abs_y, abs_z = tri_stack.abs().unbind(-1)
46
-
47
- axis = torch.tensor(
48
- [
49
- [1, 0, 0], # 0
50
- [-1, 0, 0], # 1
51
- [0, 1, 0], # 2
52
- [0, -1, 0], # 3
53
- [0, 0, 1], # 4
54
- [0, 0, -1], # 5
55
- ],
56
- device=face_normal.device,
57
- dtype=face_normal.dtype,
58
- )
59
- face_normal_axis = (face_normal[:, None] * axis[None]).sum(-1)
60
- index = face_normal_axis.argmax(-1)
61
-
62
- max_axis, uc, vc = (
63
- torch.ones_like(abs_x),
64
- torch.zeros_like(tri_stack[..., :1]),
65
- torch.zeros_like(tri_stack[..., :1]),
66
- )
67
- mask_pos_x = index == 0
68
- max_axis[mask_pos_x] = abs_x[mask_pos_x]
69
- uc[mask_pos_x] = tri_stack[mask_pos_x][..., 1:2]
70
- vc[mask_pos_x] = -tri_stack[mask_pos_x][..., -1:]
71
-
72
- mask_neg_x = index == 1
73
- max_axis[mask_neg_x] = abs_x[mask_neg_x]
74
- uc[mask_neg_x] = tri_stack[mask_neg_x][..., 1:2]
75
- vc[mask_neg_x] = -tri_stack[mask_neg_x][..., -1:]
76
-
77
- mask_pos_y = index == 2
78
- max_axis[mask_pos_y] = abs_y[mask_pos_y]
79
- uc[mask_pos_y] = tri_stack[mask_pos_y][..., 0:1]
80
- vc[mask_pos_y] = -tri_stack[mask_pos_y][..., -1:]
81
-
82
- mask_neg_y = index == 3
83
- max_axis[mask_neg_y] = abs_y[mask_neg_y]
84
- uc[mask_neg_y] = tri_stack[mask_neg_y][..., 0:1]
85
- vc[mask_neg_y] = -tri_stack[mask_neg_y][..., -1:]
86
-
87
- mask_pos_z = index == 4
88
- max_axis[mask_pos_z] = abs_z[mask_pos_z]
89
- uc[mask_pos_z] = tri_stack[mask_pos_z][..., 0:1]
90
- vc[mask_pos_z] = tri_stack[mask_pos_z][..., 1:2]
91
-
92
- mask_neg_z = index == 5
93
- max_axis[mask_neg_z] = abs_z[mask_neg_z]
94
- uc[mask_neg_z] = tri_stack[mask_neg_z][..., 0:1]
95
- vc[mask_neg_z] = -tri_stack[mask_neg_z][..., 1:2]
96
-
97
- # UC from [-1, 1] to [0, 1]
98
- max_dim_div = max_axis.max(dim=0, keepdims=True).values
99
- uc = ((uc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1)
100
- vc = ((vc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1)
101
-
102
- uv = torch.stack([uc, vc], dim=-1)
103
-
104
- return uv, index
105
-
106
-
107
- def _assign_faces_uv_to_atlas_index(
108
- vertex_positions: Float[Tensor, "Nv 3"],
109
- triangle_idxs: Integer[Tensor, "Nf 3"],
110
- face_uv: Float[Tensor, "Nf 3 2"],
111
- face_index: Integer[Tensor, "Nf 3"],
112
- ) -> Integer[Tensor, "Nf"]: # noqa: F821
113
- triangle_pos = vertex_positions[triangle_idxs]
114
- # We need to do perform 3 overlap checks.
115
- # The first set is placed in the upper two thirds of the UV atlas.
116
- # Conceptually, this is the direct visible surfaces from the each cube side
117
- # The second set is placed in the lower thirds and the left half of the UV atlas.
118
- # This is the first set of occluded surfaces. They will also be saved in the projected fashion
119
- # The third pass finds all non assigned faces. They will be placed in the bottom right half of
120
- # the UV atlas in scattered fashion.
121
- assign_idx = face_index.clone()
122
- for overlap_step in range(3):
123
- overlapping_indicator = torch.zeros_like(assign_idx, dtype=torch.bool)
124
- for i in range(overlap_step * 6, (overlap_step + 1) * 6):
125
- mask = assign_idx == i
126
- if not mask.any():
127
- continue
128
- # Get all elements belonging to the projection face
129
- uv_triangle = face_uv[mask]
130
- cur_triangle_pos = triangle_pos[mask]
131
- # Find the center of the uv coordinates
132
- center_uv = uv_triangle.mean(dim=1, keepdim=True)
133
- # And also the radius of the triangle
134
- uv_triangle_radius = (uv_triangle - center_uv).norm(dim=-1).max(-1).values
135
-
136
- potentially_overlapping_mask = (
137
- # Find all close triangles
138
- (center_uv[None, ...] - center_uv[:, None]).norm(dim=-1)
139
- # Do not select the same element by offseting with an large valued identity matrix
140
- + torch.eye(
141
- uv_triangle.shape[0],
142
- device=uv_triangle.device,
143
- dtype=uv_triangle.dtype,
144
- ).unsqueeze(-1)
145
- * 1000
146
- )
147
- # Mark all potentially overlapping triangles to reduce the number of triangle intersection tests
148
- potentially_overlapping_mask = (
149
- potentially_overlapping_mask
150
- <= (uv_triangle_radius.view(-1, 1, 1) * 3.0)
151
- ).squeeze(-1)
152
- overlap_coords = torch.stack(torch.where(potentially_overlapping_mask), -1)
153
-
154
- # Only unique triangles (A|B and B|A should be the same)
155
- f = torch.min(overlap_coords, dim=-1).values
156
- s = torch.max(overlap_coords, dim=-1).values
157
- overlap_coords = torch.unique(torch.stack([f, s], dim=1), dim=0)
158
- first, second = overlap_coords.unbind(-1)
159
-
160
- # Get the triangles
161
- tri_1 = uv_triangle[first]
162
- tri_2 = uv_triangle[second]
163
-
164
- # Perform the actual set with the reduced number of potentially overlapping triangles
165
- its = triangle_intersection_2d(tri_1, tri_2, eps=1e-6)
166
-
167
- # So we now need to detect which triangles are the occluded ones.
168
- # We always assume the first to be the visible one (the others should move)
169
- # In the previous step we use a lexigraphical sort to get the unique pairs
170
- # In this we use a sort based on the orthographic projection
171
- ax = 0 if i < 2 else 1 if i < 4 else 2
172
- use_max = i % 2 == 1
173
-
174
- tri1_c = cur_triangle_pos[first].mean(dim=1)
175
- tri2_c = cur_triangle_pos[second].mean(dim=1)
176
-
177
- mark_first = (
178
- (tri1_c[..., ax] > tri2_c[..., ax])
179
- if use_max
180
- else (tri1_c[..., ax] < tri2_c[..., ax])
181
- )
182
- first[mark_first] = second[mark_first]
183
-
184
- # Lastly the same index can be tested multiple times.
185
- # If one marks it as overlapping we keep it marked as such.
186
- # We do this by testing if it has been marked at least once.
187
- unique_idx, rev_idx = torch.unique(first, return_inverse=True)
188
-
189
- add = torch.zeros_like(unique_idx, dtype=torch.float32)
190
- add.index_add_(0, rev_idx, its.float())
191
- its_mask = add > 0
192
-
193
- # And fill it in the overlapping indicator
194
- idx = torch.where(mask)[0][unique_idx]
195
- overlapping_indicator[idx] = its_mask
196
-
197
- # Move the index to the overlap regions (shift by 6)
198
- assign_idx[overlapping_indicator] += 6
199
-
200
- # We do not care about the correct face placement after the first 2 slices
201
- max_idx = 6 * 2
202
- return assign_idx.clamp(0, max_idx)
203
-
204
-
205
- def _find_slice_offset_and_scale(
206
- index: Integer[Tensor, "Nf"], # noqa: F821
207
- ) -> Tuple[
208
- Float[Tensor, "Nf"], Float[Tensor, "Nf"], Float[Tensor, "Nf"], Float[Tensor, "Nf"] # noqa: F821
209
- ]: # noqa: F821
210
- # 6 due to the 6 cube faces
211
- off = 1 / 3
212
- dupl_off = 1 / 6
213
-
214
- # Here, we need to decide how to pack the textures in the case of overlap
215
- def x_offset_calc(x, i):
216
- offset_calc = i // 6
217
- # Initial coordinates - just 3x2 grid
218
- if offset_calc == 0:
219
- return off * x
220
- else:
221
- # Smaller 3x2 grid plus eventual shift to right for
222
- # second overlap
223
- return dupl_off * x + min(offset_calc - 1, 1) * 0.5
224
-
225
- def y_offset_calc(x, i):
226
- offset_calc = i // 6
227
- # Initial coordinates - just a 3x2 grid
228
- if offset_calc == 0:
229
- return off * x
230
- else:
231
- # Smaller coordinates in the lowest row
232
- return dupl_off * x + off * 2
233
-
234
- offset_x = torch.zeros_like(index, dtype=torch.float32)
235
- offset_y = torch.zeros_like(index, dtype=torch.float32)
236
- offset_x_vals = [0, 1, 2, 0, 1, 2]
237
- offset_y_vals = [0, 0, 0, 1, 1, 1]
238
- for i in range(index.max().item() + 1):
239
- mask = index == i
240
- if not mask.any():
241
- continue
242
- offset_x[mask] = x_offset_calc(offset_x_vals[i % 6], i)
243
- offset_y[mask] = y_offset_calc(offset_y_vals[i % 6], i)
244
-
245
- div_x = torch.full_like(index, 6 // 2, dtype=torch.float32)
246
- # All overlap elements are saved in half scale
247
- div_x[index >= 6] = 6
248
- div_y = div_x.clone() # Same for y
249
- # Except for the random overlaps
250
- div_x[index >= 12] = 2
251
- # But the random overlaps are saved in a large block in the lower thirds
252
- div_y[index >= 12] = 3
253
-
254
- return offset_x, offset_y, div_x, div_y
255
-
256
-
257
- def rotation_flip_matrix_2d(
258
- rad: float, flip_x: bool = False, flip_y: bool = False
259
- ) -> Float[Tensor, "2 2"]:
260
- cos = math.cos(rad)
261
- sin = math.sin(rad)
262
- rot_mat = torch.tensor([[cos, -sin], [sin, cos]], dtype=torch.float32)
263
- flip_mat = torch.tensor(
264
- [
265
- [-1 if flip_x else 1, 0],
266
- [0, -1 if flip_y else 1],
267
- ],
268
- dtype=torch.float32,
269
- )
270
-
271
- return flip_mat @ rot_mat
272
-
273
-
274
- def calculate_tangents(
275
- vertex_positions: Float[Tensor, "Nv 3"],
276
- vertex_normals: Float[Tensor, "Nv 3"],
277
- triangle_idxs: Integer[Tensor, "Nf 3"],
278
- face_uv: Float[Tensor, "Nf 3 2"],
279
- ) -> Float[Tensor, "Nf 3 4"]: # noqa: F821
280
- vn_idx = [None] * 3
281
- pos = [None] * 3
282
- tex = face_uv.unbind(1)
283
- for i in range(0, 3):
284
- pos[i] = vertex_positions[triangle_idxs[:, i]]
285
- # t_nrm_idx is always the same as t_pos_idx
286
- vn_idx[i] = triangle_idxs[:, i]
287
-
288
- tangents = torch.zeros_like(vertex_normals)
289
- tansum = torch.zeros_like(vertex_normals)
290
-
291
- # Compute tangent space for each triangle
292
- duv1 = tex[1] - tex[0]
293
- duv2 = tex[2] - tex[0]
294
- dpos1 = pos[1] - pos[0]
295
- dpos2 = pos[2] - pos[0]
296
-
297
- tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]
298
-
299
- denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]
300
-
301
- # Avoid division by zero for degenerated texture coordinates
302
- denom_safe = denom.clip(1e-6)
303
- tang = tng_nom / denom_safe
304
-
305
- # Update all 3 vertices
306
- for i in range(0, 3):
307
- idx = vn_idx[i][:, None].repeat(1, 3)
308
- tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang
309
- tansum.scatter_add_(
310
- 0, idx, torch.ones_like(tang)
311
- ) # tansum[n_i] = tansum[n_i] + 1
312
- # Also normalize it. Here we do not normalize the individual triangles first so larger area
313
- # triangles influence the tangent space more
314
- tangents = tangents / tansum
315
-
316
- # Normalize and make sure tangent is perpendicular to normal
317
- tangents = F.normalize(tangents, dim=1)
318
- tangents = F.normalize(tangents - dot(tangents, vertex_normals) * vertex_normals)
319
-
320
- return tangents
321
-
322
-
323
- def _rotate_uv_slices_consistent_space(
324
- vertex_positions: Float[Tensor, "Nv 3"],
325
- vertex_normals: Float[Tensor, "Nv 3"],
326
- triangle_idxs: Integer[Tensor, "Nf 3"],
327
- uv: Float[Tensor, "Nf 3 2"],
328
- index: Integer[Tensor, "Nf"], # noqa: F821
329
- ):
330
- tangents = calculate_tangents(vertex_positions, vertex_normals, triangle_idxs, uv)
331
- pos_stack = torch.stack(
332
- [
333
- -vertex_positions[..., 1],
334
- vertex_positions[..., 0],
335
- torch.zeros_like(vertex_positions[..., 0]),
336
- ],
337
- dim=-1,
338
- )
339
- expected_tangents = F.normalize(
340
- torch.linalg.cross(
341
- vertex_normals, torch.linalg.cross(pos_stack, vertex_normals)
342
- ),
343
- -1,
344
- )
345
-
346
- actual_tangents = tangents[triangle_idxs]
347
- expected_tangents = expected_tangents[triangle_idxs]
348
-
349
- def rotation_matrix_2d(theta):
350
- c, s = torch.cos(theta), torch.sin(theta)
351
- return torch.tensor([[c, -s], [s, c]])
352
-
353
- # Now find the rotation
354
- index_mod = index % 6 # Shouldn't happen. Just for safety
355
- for i in range(6):
356
- mask = index_mod == i
357
- if not mask.any():
358
- continue
359
-
360
- actual_mean_tangent = actual_tangents[mask].mean(dim=(0, 1))
361
- expected_mean_tangent = expected_tangents[mask].mean(dim=(0, 1))
362
-
363
- dot_product = torch.dot(actual_mean_tangent, expected_mean_tangent)
364
- cross_product = (
365
- actual_mean_tangent[0] * expected_mean_tangent[1]
366
- - actual_mean_tangent[1] * expected_mean_tangent[0]
367
- )
368
- angle = torch.atan2(cross_product, dot_product)
369
-
370
- rot_matrix = rotation_matrix_2d(angle).to(mask.device)
371
- # Center the uv coordinate to be in the range of -1 to 1 and 0 centered
372
- uv_cur = uv[mask] * 2 - 1 # Center it first
373
- # Rotate it
374
- uv[mask] = torch.einsum("ij,nfj->nfi", rot_matrix, uv_cur)
375
-
376
- # Rescale uv[mask] to be within the 0-1 range
377
- uv[mask] = (uv[mask] - uv[mask].min()) / (uv[mask].max() - uv[mask].min())
378
-
379
- return uv
380
-
381
-
382
- def _handle_slice_uvs(
383
- uv: Float[Tensor, "Nf 3 2"],
384
- index: Integer[Tensor, "Nf"], # noqa: F821
385
- island_padding: float,
386
- max_index: int = 6 * 2,
387
- ) -> Float[Tensor, "Nf 3 2"]: # noqa: F821
388
- uc, vc = uv.unbind(-1)
389
-
390
- # Get the second slice (The first overlap)
391
- index_filter = [index == i for i in range(6, max_index)]
392
-
393
- # Normalize them to always fully fill the atlas patch
394
- for i, fi in enumerate(index_filter):
395
- if fi.sum() > 0:
396
- # Scale the slice but only up to a factor of 2
397
- # This keeps the texture resolution with the first slice in line (Half space in UV)
398
- uc[fi] = (uc[fi] - uc[fi].min()) / (uc[fi].max() - uc[fi].min()).clip(0.5)
399
- vc[fi] = (vc[fi] - vc[fi].min()) / (vc[fi].max() - vc[fi].min()).clip(0.5)
400
-
401
- uc_padded = (uc * (1 - 2 * island_padding) + island_padding).clip(0, 1)
402
- vc_padded = (vc * (1 - 2 * island_padding) + island_padding).clip(0, 1)
403
-
404
- return torch.stack([uc_padded, vc_padded], dim=-1)
405
-
406
-
407
- def _handle_remaining_uvs(
408
- uv: Float[Tensor, "Nf 3 2"],
409
- index: Integer[Tensor, "Nf"], # noqa: F821
410
- island_padding: float,
411
- ) -> Float[Tensor, "Nf 3 2"]:
412
- uc, vc = uv.unbind(-1)
413
- # Get all remaining elements
414
- remaining_filter = index >= 6 * 2
415
- squares_left = remaining_filter.sum()
416
-
417
- if squares_left == 0:
418
- return uv
419
-
420
- uc = uc[remaining_filter]
421
- vc = vc[remaining_filter]
422
-
423
- # Or remaining triangles are distributed in a rectangle
424
- # The rectangle takes 0.5 of the entire uv space in width and 1/3 in height
425
- ratio = 0.5 * (1 / 3) # 1.5
426
- # sqrt(744/(0.5*(1/3)))
427
-
428
- mult = math.sqrt(squares_left / ratio)
429
- num_square_width = int(math.ceil(0.5 * mult))
430
- num_square_height = int(math.ceil(squares_left / num_square_width))
431
-
432
- width = 1 / num_square_width
433
- height = 1 / num_square_height
434
-
435
- # The idea is again to keep the texture resolution consistent with the first slice
436
- # This only occupys half the region in the texture chart but the scaling on the squares
437
- # assumes full coverage.
438
- clip_val = min(width, height) * 1.5
439
- # Now normalize the UVs with taking into account the maximum scaling
440
- uc = (uc - uc.min(dim=1, keepdim=True).values) / (
441
- uc.amax(dim=1, keepdim=True) - uc.amin(dim=1, keepdim=True)
442
- ).clip(clip_val)
443
- vc = (vc - vc.min(dim=1, keepdim=True).values) / (
444
- vc.amax(dim=1, keepdim=True) - vc.amin(dim=1, keepdim=True)
445
- ).clip(clip_val)
446
- # Add a small padding
447
- uc = (
448
- uc * (1 - island_padding * num_square_width * 0.5)
449
- + island_padding * num_square_width * 0.25
450
- ).clip(0, 1)
451
- vc = (
452
- vc * (1 - island_padding * num_square_height * 0.5)
453
- + island_padding * num_square_height * 0.25
454
- ).clip(0, 1)
455
-
456
- uc = uc * width
457
- vc = vc * height
458
-
459
- # And calculate offsets for each element
460
- idx = torch.arange(uc.shape[0], device=uc.device, dtype=torch.int32)
461
- x_idx = idx % num_square_width
462
- y_idx = idx // num_square_width
463
- # And move each triangle to its own spot
464
- uc = uc + x_idx[:, None] * width
465
- vc = vc + y_idx[:, None] * height
466
-
467
- uc = (uc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)
468
- vc = (vc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1)
469
-
470
- uv[remaining_filter] = torch.stack([uc, vc], dim=-1)
471
-
472
- return uv
473
-
474
-
475
- def _distribute_individual_uvs_in_atlas(
476
- face_uv: Float[Tensor, "Nf 3 2"],
477
- assigned_faces: Integer[Tensor, "Nf"], # noqa: F821
478
- offset_x: Float[Tensor, "Nf"], # noqa: F821
479
- offset_y: Float[Tensor, "Nf"], # noqa: F821
480
- div_x: Float[Tensor, "Nf"], # noqa: F821
481
- div_y: Float[Tensor, "Nf"], # noqa: F821
482
- island_padding: float,
483
- ):
484
- # Place the slice first
485
- placed_uv = _handle_slice_uvs(face_uv, assigned_faces, island_padding)
486
- # Then handle the remaining overlap elements
487
- placed_uv = _handle_remaining_uvs(placed_uv, assigned_faces, island_padding)
488
-
489
- uc, vc = placed_uv.unbind(-1)
490
- uc = uc / div_x[:, None] + offset_x[:, None]
491
- vc = vc / div_y[:, None] + offset_y[:, None]
492
-
493
- uv = torch.stack([uc, vc], dim=-1).view(-1, 2)
494
-
495
- return uv
496
-
497
-
498
- def _get_unique_face_uv(
499
- uv: Float[Tensor, "Nf 3 2"],
500
- ) -> Tuple[Float[Tensor, "Utex 3"], Integer[Tensor, "Nf"]]: # noqa: F821
501
- unique_uv, unique_idx = torch.unique(uv, return_inverse=True, dim=0)
502
- # And add the face to uv index mapping
503
- vtex_idx = unique_idx.view(-1, 3)
504
-
505
- return unique_uv, vtex_idx
506
-
507
-
508
- def _align_mesh_with_main_axis(
509
- vertex_positions: Float[Tensor, "Nv 3"], vertex_normals: Float[Tensor, "Nv 3"]
510
- ) -> Tuple[Float[Tensor, "Nv 3"], Float[Tensor, "Nv 3"]]:
511
- # Use pca to find the 2 main axis (third is derived by cross product)
512
- # Set the random seed so it's repeatable
513
- torch.manual_seed(0)
514
- _, _, v = torch.pca_lowrank(vertex_positions, q=2)
515
- main_axis, seconday_axis = v[:, 0], v[:, 1]
516
-
517
- main_axis: Float[Tensor, "3"] = F.normalize(main_axis, eps=1e-6, dim=-1)
518
- # Orthogonalize the second axis
519
- seconday_axis: Float[Tensor, "3"] = F.normalize(
520
- seconday_axis - dot(seconday_axis, main_axis) * main_axis, eps=1e-6, dim=-1
521
- )
522
- # Create perpendicular third axis
523
- third_axis: Float[Tensor, "3"] = F.normalize(
524
- torch.cross(main_axis, seconday_axis), dim=-1, eps=1e-6
525
- )
526
-
527
- # Check to which canonical axis each aligns
528
- main_axis_max_idx = main_axis.abs().argmax().item()
529
- seconday_axis_max_idx = seconday_axis.abs().argmax().item()
530
- third_axis_max_idx = third_axis.abs().argmax().item()
531
-
532
- # Now sort the axes based on the argmax so they align with thecanonoical axes
533
- # If two axes have the same argmax move one of them
534
- all_possible_axis = {0, 1, 2}
535
- cur_index = 1
536
- while len(set([main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx])) != 3:
537
- # Find missing axis
538
- missing_axis = all_possible_axis - set(
539
- [main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx]
540
- )
541
- missing_axis = missing_axis.pop()
542
- # Just assign it to third axis as it had the smallest contribution to the
543
- # overall shape
544
- if cur_index == 1:
545
- third_axis_max_idx = missing_axis
546
- elif cur_index == 2:
547
- seconday_axis_max_idx = missing_axis
548
- else:
549
- raise ValueError("Could not find 3 unique axis")
550
- cur_index += 1
551
-
552
- if len({main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx}) != 3:
553
- raise ValueError("Could not find 3 unique axis")
554
-
555
- axes = [None] * 3
556
- axes[main_axis_max_idx] = main_axis
557
- axes[seconday_axis_max_idx] = seconday_axis
558
- axes[third_axis_max_idx] = third_axis
559
- # Create rotation matrix from the individual axes
560
- rot_mat = torch.stack(axes, dim=1).T
561
-
562
- # Now rotate the vertex positions and vertex normals so the mesh aligns with the main axis
563
- vertex_positions = torch.einsum("ij,nj->ni", rot_mat, vertex_positions)
564
- vertex_normals = torch.einsum("ij,nj->ni", rot_mat, vertex_normals)
565
-
566
- return vertex_positions, vertex_normals
567
-
568
-
569
- def box_projection_uv_unwrap(
570
- vertex_positions: Float[Tensor, "Nv 3"],
571
- vertex_normals: Float[Tensor, "Nv 3"],
572
- triangle_idxs: Integer[Tensor, "Nf 3"],
573
- island_padding: float,
574
- ) -> Tuple[Float[Tensor, "Utex 3"], Integer[Tensor, "Nf"]]: # noqa: F821
575
- # Align the mesh with main axis directions first
576
- vertex_positions, vertex_normals = _align_mesh_with_main_axis(
577
- vertex_positions, vertex_normals
578
- )
579
-
580
- bbox: Float[Tensor, "2 3"] = torch.stack(
581
- [vertex_positions.min(dim=0).values, vertex_positions.max(dim=0).values], dim=0
582
- )
583
- # First decide in which cube face the triangle is placed
584
- face_uv, face_index = _box_assign_vertex_to_cube_face(
585
- vertex_positions, vertex_normals, triangle_idxs, bbox
586
- )
587
-
588
- # Rotate the UV islands in a way that they align with the radial z tangent space
589
- face_uv = _rotate_uv_slices_consistent_space(
590
- vertex_positions, vertex_normals, triangle_idxs, face_uv, face_index
591
- )
592
-
593
- # Then find where where the face is placed in the atlas.
594
- # This has to detect potential overlaps
595
- assigned_atlas_index = _assign_faces_uv_to_atlas_index(
596
- vertex_positions, triangle_idxs, face_uv, face_index
597
- )
598
-
599
- # Then figure out the final place in the atlas based on the assignment
600
- offset_x, offset_y, div_x, div_y = _find_slice_offset_and_scale(
601
- assigned_atlas_index
602
- )
603
-
604
- # Next distribute the faces in the uv atlas
605
- placed_uv = _distribute_individual_uvs_in_atlas(
606
- face_uv, assigned_atlas_index, offset_x, offset_y, div_x, div_y, island_padding
607
- )
608
-
609
- # And get the unique per-triangle UV coordinates
610
- return _get_unique_face_uv(placed_uv)
 
sf3d/models/camera.py DELETED
@@ -1,32 +0,0 @@
1
- from dataclasses import dataclass, field
2
- from typing import List
3
-
4
- import torch
5
- import torch.nn as nn
6
-
7
- from sf3d.models.utils import BaseModule
8
-
9
-
10
- class LinearCameraEmbedder(BaseModule):
11
- @dataclass
12
- class Config(BaseModule.Config):
13
- in_channels: int = 25
14
- out_channels: int = 768
15
- conditions: List[str] = field(default_factory=list)
16
-
17
- cfg: Config
18
-
19
- def configure(self) -> None:
20
- self.linear = nn.Linear(self.cfg.in_channels, self.cfg.out_channels)
21
-
22
- def forward(self, **kwargs):
23
- cond_tensors = []
24
- for cond_name in self.cfg.conditions:
25
- assert cond_name in kwargs
26
- cond = kwargs[cond_name]
27
- # cond in shape (B, Nv, ...)
28
- cond_tensors.append(cond.view(*cond.shape[:2], -1))
29
- cond_tensor = torch.cat(cond_tensors, dim=-1)
30
- assert cond_tensor.shape[-1] == self.cfg.in_channels
31
- embedding = self.linear(cond_tensor)
32
- return embedding
 
sf3d/models/global_estimator/multi_head_estimator.py DELETED
@@ -1,118 +0,0 @@
1
- from dataclasses import dataclass, field
2
- from typing import Any, List, Optional
3
-
4
- import torch.nn as nn
5
- from jaxtyping import Float
6
- from torch import Tensor
7
-
8
- from sf3d.models.network import get_activation
9
- from sf3d.models.utils import BaseModule
10
-
11
-
12
- @dataclass
13
- class HeadSpec:
14
- name: str
15
- out_channels: int
16
- n_hidden_layers: int
17
- output_activation: Optional[str] = None
18
- output_bias: float = 0.0
19
- add_to_decoder_features: bool = False
20
- shape: Optional[list[int]] = None
21
-
22
-
23
- class MultiHeadEstimator(BaseModule):
24
- @dataclass
25
- class Config(BaseModule.Config):
26
- triplane_features: int = 1024
27
-
28
- n_layers: int = 2
29
- hidden_features: int = 512
30
- activation: str = "relu"
31
-
32
- pool: str = "max"
33
- # Literal["mean", "max"] = "mean" # noqa: F821
34
-
35
- heads: List[HeadSpec] = field(default_factory=lambda: [])
36
-
37
- cfg: Config
38
-
39
- def configure(self):
40
- layers = []
41
- cur_features = self.cfg.triplane_features * 3
42
- for _ in range(self.cfg.n_layers):
43
- layers.append(
44
- nn.Conv2d(
45
- cur_features,
46
- self.cfg.hidden_features,
47
- kernel_size=3,
48
- padding=0,
49
- stride=2,
50
- )
51
- )
52
- layers.append(self.make_activation(self.cfg.activation))
53
-
54
- cur_features = self.cfg.hidden_features
55
-
56
- self.layers = nn.Sequential(*layers)
57
-
58
- assert len(self.cfg.heads) > 0
59
- heads = {}
60
- for head in self.cfg.heads:
61
- head_layers = []
62
- for i in range(head.n_hidden_layers):
63
- head_layers += [
64
- nn.Linear(
65
- self.cfg.hidden_features,
66
- self.cfg.hidden_features,
67
- ),
68
- self.make_activation(self.cfg.activation),
69
- ]
70
- head_layers += [
71
- nn.Linear(
72
- self.cfg.hidden_features,
73
- head.out_channels,
74
- ),
75
- ]
76
- heads[head.name] = nn.Sequential(*head_layers)
77
- self.heads = nn.ModuleDict(heads)
78
-
79
- def make_activation(self, activation):
80
- if activation == "relu":
81
- return nn.ReLU(inplace=True)
82
- elif activation == "silu":
83
- return nn.SiLU(inplace=True)
84
- else:
85
- raise NotImplementedError
86
-
87
- def forward(
88
- self,
89
- triplane: Float[Tensor, "B 3 F Ht Wt"],
90
- ) -> dict[str, Any]:
91
- x = self.layers(
92
- triplane.reshape(
93
- triplane.shape[0], -1, triplane.shape[-2], triplane.shape[-1]
94
- )
95
- )
96
-
97
- if self.cfg.pool == "max":
98
- x = x.amax(dim=[-2, -1])
99
- elif self.cfg.pool == "mean":
100
- x = x.mean(dim=[-2, -1])
101
- else:
102
- raise NotImplementedError
103
-
104
- out = {
105
- ("decoder_" if head.add_to_decoder_features else "")
106
- + head.name: get_activation(head.output_activation)(
107
- self.heads[head.name](x) + head.output_bias
108
- )
109
- for head in self.cfg.heads
110
- }
111
- for head in self.cfg.heads:
112
- if head.shape:
113
- head_name = (
114
- "decoder_" if head.add_to_decoder_features else ""
115
- ) + head.name
116
- out[head_name] = out[head_name].reshape(*head.shape)
117
-
118
- return out
 
sf3d/models/image_estimator/clip_based_estimator.py DELETED
@@ -1,168 +0,0 @@
1
- from dataclasses import dataclass, field
2
- from typing import Any, List, Optional
3
-
4
- import open_clip
5
- import torch
6
- import torch.nn as nn
7
- from jaxtyping import Float
8
- from torch import Tensor
9
- from torchvision.transforms import Normalize
10
-
11
- from sf3d.models.network import get_activation
12
- from sf3d.models.utils import BaseModule
13
-
14
-
15
- @dataclass
16
- class HeadSpec:
17
- name: str
18
- out_channels: int
19
- n_hidden_layers: int
20
- output_activation: Optional[str] = None
21
- output_bias: float = 0.0
22
- add_to_decoder_features: bool = False
23
- shape: Optional[list[int]] = None
24
-
25
-
26
- class ClipBasedHeadEstimator(BaseModule):
27
- @dataclass
28
- class Config(BaseModule.Config):
29
- model: str = "ViT-B-32"
30
- pretrain: str = "laion2b_s34b_b79k"
31
-
32
- distribution: str = "beta"
33
-
34
- # ["mean", "mode", "sample", "sample_mean"]
35
- distribution_eval: str = "mode"
36
-
37
- activation: str = "relu"
38
- hidden_features: int = 512
39
- heads: List[HeadSpec] = field(default_factory=lambda: [])
40
-
41
- cfg: Config
42
-
43
- def configure(self):
44
- self.model, _, self.preprocess = open_clip.create_model_and_transforms(
45
- self.cfg.model, pretrained=self.cfg.pretrain
46
- )
47
- self.model.eval()
48
-
49
- # Do not add the weights in self.model to the optimizer
50
- for param in self.model.parameters():
51
- param.requires_grad = False
52
-
53
- assert len(self.cfg.heads) > 0
54
- heads = {}
55
- for head in self.cfg.heads:
56
- head_layers = []
57
-
58
- for i in range(head.n_hidden_layers):
59
- head_layers += [
60
- nn.Linear(
61
- self.cfg.hidden_features,
62
- self.cfg.hidden_features,
63
- ),
64
- self.make_activation(self.cfg.activation),
65
- ]
66
-
67
- head_layers = [nn.Sequential(*head_layers)]
68
- head_layers += [
69
- nn.Sequential(
70
- nn.Linear(
71
- self.cfg.hidden_features,
72
- self.cfg.hidden_features,
73
- ),
74
- self.make_activation(self.cfg.activation),
75
- nn.Linear(self.cfg.hidden_features, 1),
76
- )
77
- for _ in range(2)
78
- ]
79
- heads[head.name] = nn.ModuleList(head_layers)
80
- self.heads = nn.ModuleDict(heads)
81
-
82
- def make_activation(self, activation):
83
- if activation == "relu":
84
- return nn.ReLU(inplace=True)
85
- elif activation == "silu":
86
- return nn.SiLU(inplace=True)
87
- else:
88
- raise NotImplementedError
89
-
90
- def forward(
91
- self,
92
- cond_image: Float[Tensor, "B 1 H W 3"],
93
- sample: bool = True,
94
- ) -> dict[str, Any]:
95
- # Run the model
96
- # Resize cond_image to 224
97
- cond_image = nn.functional.interpolate(
98
- cond_image.flatten(0, 1).permute(0, 3, 1, 2),
99
- size=(224, 224),
100
- mode="bilinear",
101
- align_corners=False,
102
- )
103
- cond_image = Normalize(
104
- mean=open_clip.constants.OPENAI_DATASET_MEAN,
105
- std=open_clip.constants.OPENAI_DATASET_STD,
106
- )(cond_image)
107
- image_features = self.model.encode_image(cond_image)
108
-
109
- # Run the heads
110
- outputs = {}
111
-
112
- for head_dict in self.cfg.heads:
113
- head_name = head_dict.name
114
- shared_head, d1_h, d2_h = self.heads[head_name]
115
- shared_features = shared_head(image_features)
116
- d1, d2 = [head(shared_features).squeeze(-1) for head in [d1_h, d2_h]]
117
- if self.cfg.distribution == "normal":
118
- mean = d1
119
- var = d2
120
- if mean.shape[-1] == 1:
121
- outputs[head_name] = torch.distributions.Normal(
122
- mean + head_dict.output_bias,
123
- torch.nn.functional.softplus(var),
124
- )
125
- else:
126
- outputs[head_name] = torch.distributions.MultivariateNormal(
127
- mean + head_dict.output_bias,
128
- torch.nn.functional.softplus(var).diag_embed(),
129
- )
130
- elif self.cfg.distribution == "beta":
131
- outputs[head_name] = torch.distributions.Beta(
132
- torch.nn.functional.softplus(d1 + head_dict.output_bias),
133
- torch.nn.functional.softplus(d2 + head_dict.output_bias),
134
- )
135
- else:
136
- raise NotImplementedError
137
-
138
- if sample:
139
- for head_dict in self.cfg.heads:
140
- head_name = head_dict.name
141
- dist = outputs[head_name]
142
-
143
- if self.cfg.distribution_eval == "mean":
144
- out = dist.mean
145
- elif self.cfg.distribution_eval == "mode":
146
- out = dist.mode
147
- elif self.cfg.distribution_eval == "sample_mean":
148
- out = dist.sample([10]).mean(-1)
149
- else:
150
- # use rsample if gradient is needed
151
- out = dist.rsample() if self.training else dist.sample()
152
-
153
- outputs[head_name] = get_activation(head_dict.output_activation)(out)
154
- outputs[f"{head_name}_dist"] = dist
155
-
156
- for head in self.cfg.heads:
157
- if head.shape:
158
- if not sample:
159
- raise ValueError(
160
- "Cannot reshape non-sampled probabilisitic outputs"
161
- )
162
- outputs[head.name] = outputs[head.name].reshape(*head.shape)
163
-
164
- if head.add_to_decoder_features:
165
- outputs[f"decoder_{head.name}"] = outputs[head.name]
166
- del outputs[head.name]
167
-
168
- return outputs
 
sf3d/models/isosurface.py DELETED
@@ -1,229 +0,0 @@
1
- from typing import Optional, Tuple
2
-
3
- import numpy as np
4
- import torch
5
- import torch.nn as nn
6
- from jaxtyping import Float, Integer
7
- from torch import Tensor
8
-
9
- from .mesh import Mesh
10
-
11
-
12
- class IsosurfaceHelper(nn.Module):
13
- points_range: Tuple[float, float] = (0, 1)
14
-
15
- @property
16
- def grid_vertices(self) -> Float[Tensor, "N 3"]:
17
- raise NotImplementedError
18
-
19
- @property
20
- def requires_instance_per_batch(self) -> bool:
21
- return False
22
-
23
-
24
- class MarchingTetrahedraHelper(IsosurfaceHelper):
25
- def __init__(self, resolution: int, tets_path: str):
26
- super().__init__()
27
- self.resolution = resolution
28
- self.tets_path = tets_path
29
-
30
- self.triangle_table: Float[Tensor, "..."]
31
- self.register_buffer(
32
- "triangle_table",
33
- torch.as_tensor(
34
- [
35
- [-1, -1, -1, -1, -1, -1],
36
- [1, 0, 2, -1, -1, -1],
37
- [4, 0, 3, -1, -1, -1],
38
- [1, 4, 2, 1, 3, 4],
39
- [3, 1, 5, -1, -1, -1],
40
- [2, 3, 0, 2, 5, 3],
41
- [1, 4, 0, 1, 5, 4],
42
- [4, 2, 5, -1, -1, -1],
43
- [4, 5, 2, -1, -1, -1],
44
- [4, 1, 0, 4, 5, 1],
45
- [3, 2, 0, 3, 5, 2],
46
- [1, 3, 5, -1, -1, -1],
47
- [4, 1, 2, 4, 3, 1],
48
- [3, 0, 4, -1, -1, -1],
49
- [2, 0, 1, -1, -1, -1],
50
- [-1, -1, -1, -1, -1, -1],
51
- ],
52
- dtype=torch.long,
53
- ),
54
- persistent=False,
55
- )
56
- self.num_triangles_table: Integer[Tensor, "..."]
57
- self.register_buffer(
58
- "num_triangles_table",
59
- torch.as_tensor(
60
- [0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long
61
- ),
62
- persistent=False,
63
- )
64
- self.base_tet_edges: Integer[Tensor, "..."]
65
- self.register_buffer(
66
- "base_tet_edges",
67
- torch.as_tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long),
68
- persistent=False,
69
- )
70
-
71
- tets = np.load(self.tets_path)
72
- self._grid_vertices: Float[Tensor, "..."]
73
- self.register_buffer(
74
- "_grid_vertices",
75
- torch.from_numpy(tets["vertices"]).float(),
76
- persistent=False,
77
- )
78
- self.indices: Integer[Tensor, "..."]
79
- self.register_buffer(
80
- "indices", torch.from_numpy(tets["indices"]).long(), persistent=False
81
- )
82
-
83
- self._all_edges: Optional[Integer[Tensor, "Ne 2"]] = None
84
-
85
- center_indices, boundary_indices = self.get_center_boundary_index(
86
- self._grid_vertices
87
- )
88
- self.center_indices: Integer[Tensor, "..."]
89
- self.register_buffer("center_indices", center_indices, persistent=False)
90
- self.boundary_indices: Integer[Tensor, "..."]
91
- self.register_buffer("boundary_indices", boundary_indices, persistent=False)
92
-
93
- def get_center_boundary_index(self, verts):
94
- magn = torch.sum(verts**2, dim=-1)
95
-
96
- center_idx = torch.argmin(magn)
97
- boundary_neg = verts == verts.max()
98
- boundary_pos = verts == verts.min()
99
-
100
- boundary = torch.bitwise_or(boundary_pos, boundary_neg)
101
- boundary = torch.sum(boundary.float(), dim=-1)
102
-
103
- boundary_idx = torch.nonzero(boundary)
104
- return center_idx, boundary_idx.squeeze(dim=-1)
105
-
106
- def normalize_grid_deformation(
107
- self, grid_vertex_offsets: Float[Tensor, "Nv 3"]
108
- ) -> Float[Tensor, "Nv 3"]:
109
- return (
110
- (self.points_range[1] - self.points_range[0])
111
- / self.resolution # half tet size is approximately 1 / self.resolution
112
- * torch.tanh(grid_vertex_offsets)
113
- ) # FIXME: hard-coded activation
114
-
115
- @property
116
- def grid_vertices(self) -> Float[Tensor, "Nv 3"]:
117
- return self._grid_vertices
118
-
119
- @property
120
- def all_edges(self) -> Integer[Tensor, "Ne 2"]:
121
- if self._all_edges is None:
122
- # compute edges on GPU, or it would be VERY SLOW (basically due to the unique operation)
123
- edges = torch.tensor(
124
- [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3],
125
- dtype=torch.long,
126
- device=self.indices.device,
127
- )
128
- _all_edges = self.indices[:, edges].reshape(-1, 2)
129
- _all_edges_sorted = torch.sort(_all_edges, dim=1)[0]
130
- _all_edges = torch.unique(_all_edges_sorted, dim=0)
131
- self._all_edges = _all_edges
132
- return self._all_edges
133
-
134
- def sort_edges(self, edges_ex2):
135
- with torch.no_grad():
136
- order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
137
- order = order.unsqueeze(dim=1)
138
-
139
- a = torch.gather(input=edges_ex2, index=order, dim=1)
140
- b = torch.gather(input=edges_ex2, index=1 - order, dim=1)
141
-
142
- return torch.stack([a, b], -1)
143
-
144
- def _forward(self, pos_nx3, sdf_n, tet_fx4):
145
- with torch.no_grad():
146
- occ_n = sdf_n > 0
147
- occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
148
- occ_sum = torch.sum(occ_fx4, -1)
149
- valid_tets = (occ_sum > 0) & (occ_sum < 4)
150
- occ_sum = occ_sum[valid_tets]
151
-
152
- # find all vertices
153
- all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2)
154
- all_edges = self.sort_edges(all_edges)
155
- unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
156
-
157
- unique_edges = unique_edges.long()
158
- mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
159
- mapping = (
160
- torch.ones(
161
- (unique_edges.shape[0]), dtype=torch.long, device=pos_nx3.device
162
- )
163
- * -1
164
- )
165
- mapping[mask_edges] = torch.arange(
166
- mask_edges.sum(), dtype=torch.long, device=pos_nx3.device
167
- )
168
- idx_map = mapping[idx_map] # map edges to verts
169
-
170
- interp_v = unique_edges[mask_edges]
171
- edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
172
- edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
173
- edges_to_interp_sdf[:, -1] *= -1
174
-
175
- denominator = edges_to_interp_sdf.sum(1, keepdim=True)
176
-
177
- edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
178
- verts = (edges_to_interp * edges_to_interp_sdf).sum(1)
179
-
180
- idx_map = idx_map.reshape(-1, 6)
181
-
182
- v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=pos_nx3.device))
183
- tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
184
- num_triangles = self.num_triangles_table[tetindex]
185
-
186
- # Generate triangle indices
187
- faces = torch.cat(
188
- (
189
- torch.gather(
190
- input=idx_map[num_triangles == 1],
191
- dim=1,
192
- index=self.triangle_table[tetindex[num_triangles == 1]][:, :3],
193
- ).reshape(-1, 3),
194
- torch.gather(
195
- input=idx_map[num_triangles == 2],
196
- dim=1,
197
- index=self.triangle_table[tetindex[num_triangles == 2]][:, :6],
198
- ).reshape(-1, 3),
199
- ),
200
- dim=0,
201
- )
202
-
203
- return verts, faces
204
-
205
- def forward(
206
- self,
207
- level: Float[Tensor, "N3 1"],
208
- deformation: Optional[Float[Tensor, "N3 3"]] = None,
209
- ) -> Mesh:
210
- if deformation is not None:
211
- grid_vertices = self.grid_vertices + self.normalize_grid_deformation(
212
- deformation
213
- )
214
- else:
215
- grid_vertices = self.grid_vertices
216
-
217
- v_pos, t_pos_idx = self._forward(grid_vertices, level, self.indices)
218
-
219
- mesh = Mesh(
220
- v_pos=v_pos,
221
- t_pos_idx=t_pos_idx,
222
- # extras
223
- grid_vertices=grid_vertices,
224
- tet_edges=self.all_edges,
225
- grid_level=level,
226
- grid_deformation=deformation,
227
- )
228
-
229
- return mesh
 
sf3d/models/mesh.py DELETED
@@ -1,172 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from typing import Any, Dict, Optional
4
-
5
- import torch
6
- import torch.nn.functional as F
7
- from jaxtyping import Float, Integer
8
- from torch import Tensor
9
-
10
- from sf3d.box_uv_unwrap import box_projection_uv_unwrap
11
- from sf3d.models.utils import dot
12
-
13
-
14
- class Mesh:
15
- def __init__(
16
- self, v_pos: Float[Tensor, "Nv 3"], t_pos_idx: Integer[Tensor, "Nf 3"], **kwargs
17
- ) -> None:
18
- self.v_pos: Float[Tensor, "Nv 3"] = v_pos
19
- self.t_pos_idx: Integer[Tensor, "Nf 3"] = t_pos_idx
20
- self._v_nrm: Optional[Float[Tensor, "Nv 3"]] = None
21
- self._v_tng: Optional[Float[Tensor, "Nv 3"]] = None
22
- self._v_tex: Optional[Float[Tensor, "Nt 3"]] = None
23
- self._edges: Optional[Integer[Tensor, "Ne 2"]] = None
24
- self.extras: Dict[str, Any] = {}
25
- for k, v in kwargs.items():
26
- self.add_extra(k, v)
27
-
28
- def add_extra(self, k, v) -> None:
29
- self.extras[k] = v
30
-
31
- @property
32
- def requires_grad(self):
33
- return self.v_pos.requires_grad
34
-
35
- @property
36
- def v_nrm(self):
37
- if self._v_nrm is None:
38
- self._v_nrm = self._compute_vertex_normal()
39
- return self._v_nrm
40
-
41
- @property
42
- def v_tng(self):
43
- if self._v_tng is None:
44
- self._v_tng = self._compute_vertex_tangent()
45
- return self._v_tng
46
-
47
- @property
48
- def v_tex(self):
49
- if self._v_tex is None:
50
- self.unwrap_uv()
51
- return self._v_tex
52
-
53
- @property
54
- def edges(self):
55
- if self._edges is None:
56
- self._edges = self._compute_edges()
57
- return self._edges
58
-
59
- def _compute_vertex_normal(self):
60
- i0 = self.t_pos_idx[:, 0]
61
- i1 = self.t_pos_idx[:, 1]
62
- i2 = self.t_pos_idx[:, 2]
63
-
64
- v0 = self.v_pos[i0, :]
65
- v1 = self.v_pos[i1, :]
66
- v2 = self.v_pos[i2, :]
67
-
68
- face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
69
-
70
- # Splat face normals to vertices
71
- v_nrm = torch.zeros_like(self.v_pos)
72
- v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
73
- v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
74
- v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
75
-
76
- # Normalize, replace zero (degenerated) normals with some default value
77
- v_nrm = torch.where(
78
- dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
79
- )
80
- v_nrm = F.normalize(v_nrm, dim=1)
81
-
82
- if torch.is_anomaly_enabled():
83
- assert torch.all(torch.isfinite(v_nrm))
84
-
85
- return v_nrm
86
-
87
- def _compute_vertex_tangent(self):
88
- vn_idx = [None] * 3
89
- pos = [None] * 3
90
- tex = [None] * 3
91
- for i in range(0, 3):
92
- pos[i] = self.v_pos[self.t_pos_idx[:, i]]
93
- tex[i] = self.v_tex[self.t_pos_idx[:, i]]
94
- # t_nrm_idx is always the same as t_pos_idx
95
- vn_idx[i] = self.t_pos_idx[:, i]
96
-
97
- tangents = torch.zeros_like(self.v_nrm)
98
- tansum = torch.zeros_like(self.v_nrm)
99
-
100
- # Compute tangent space for each triangle
101
- duv1 = tex[1] - tex[0]
102
- duv2 = tex[2] - tex[0]
103
- dpos1 = pos[1] - pos[0]
104
- dpos2 = pos[2] - pos[0]
105
-
106
- tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]
107
-
108
- denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]
109
-
110
- # Avoid division by zero for degenerated texture coordinates
111
- denom_safe = denom.clip(1e-6)
112
- tang = tng_nom / denom_safe
113
-
114
- # Update all 3 vertices
115
- for i in range(0, 3):
116
- idx = vn_idx[i][:, None].repeat(1, 3)
117
- tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang
118
- tansum.scatter_add_(
119
- 0, idx, torch.ones_like(tang)
120
- ) # tansum[n_i] = tansum[n_i] + 1
121
- # Also normalize it. Here we do not normalize the individual triangles first so larger area
122
- # triangles influence the tangent space more
123
- tangents = tangents / tansum
124
-
125
- # Normalize and make sure tangent is perpendicular to normal
126
- tangents = F.normalize(tangents, dim=1)
127
- tangents = F.normalize(tangents - dot(tangents, self.v_nrm) * self.v_nrm)
128
-
129
- if torch.is_anomaly_enabled():
130
- assert torch.all(torch.isfinite(tangents))
131
-
132
- return tangents
133
-
134
- @torch.no_grad()
135
- def unwrap_uv(
136
- self,
137
- island_padding: float = 0.02,
138
- ) -> Mesh:
139
- uv, indices = box_projection_uv_unwrap(
140
- self.v_pos, self.v_nrm, self.t_pos_idx, island_padding
141
- )
142
-
143
- # Do store per vertex UVs.
144
- # This means we need to duplicate some vertices at the seams
145
- individual_vertices = self.v_pos[self.t_pos_idx].reshape(-1, 3)
146
- individual_faces = torch.arange(
147
- individual_vertices.shape[0],
148
- device=individual_vertices.device,
149
- dtype=self.t_pos_idx.dtype,
150
- ).reshape(-1, 3)
151
- uv_flat = uv[indices].reshape((-1, 2))
152
- # uv_flat[:, 1] = 1 - uv_flat[:, 1]
153
-
154
- self.v_pos = individual_vertices
155
- self.t_pos_idx = individual_faces
156
- self._v_tex = uv_flat
157
- self._v_nrm = self._compute_vertex_normal()
158
- self._v_tng = self._compute_vertex_tangent()
159
-
160
- def _compute_edges(self):
161
- # Compute edges
162
- edges = torch.cat(
163
- [
164
- self.t_pos_idx[:, [0, 1]],
165
- self.t_pos_idx[:, [1, 2]],
166
- self.t_pos_idx[:, [2, 0]],
167
- ],
168
- dim=0,
169
- )
170
- edges = edges.sort()[0]
171
- edges = torch.unique(edges, dim=0)
172
- return edges
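
Reviewer note on the mesh code above: `_compute_vertex_normal` accumulates *unnormalized* face normals onto their three vertices with `scatter_add_`, so larger triangles contribute more to each vertex normal, and degenerate vertices fall back to +Z before the final normalize. A minimal standalone sketch of that area-weighted splatting (hypothetical tensors, not tied to the `Mesh` class):

```python
import torch
import torch.nn.functional as F

def area_weighted_vertex_normals(v_pos: torch.Tensor, t_pos_idx: torch.Tensor) -> torch.Tensor:
    # v_pos: (V, 3) vertex positions, t_pos_idx: (F, 3) triangle vertex indices
    i0, i1, i2 = t_pos_idx[:, 0], t_pos_idx[:, 1], t_pos_idx[:, 2]
    v0, v1, v2 = v_pos[i0], v_pos[i1], v_pos[i2]
    # The cross product's length is proportional to triangle area, so this is area-weighted
    face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
    v_nrm = torch.zeros_like(v_pos)
    for idx in (i0, i1, i2):
        v_nrm.scatter_add_(0, idx[:, None].repeat(1, 3), face_normals)
    # Vertices with a (near-)zero accumulated normal get a default +Z direction
    v_nrm = torch.where(
        (v_nrm * v_nrm).sum(-1, keepdim=True) > 1e-20,
        v_nrm,
        torch.tensor([0.0, 0.0, 1.0], device=v_pos.device),
    )
    return F.normalize(v_nrm, dim=-1)

# Example: a single right triangle in the XY plane -> all vertex normals point along +Z
verts = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
faces = torch.tensor([[0, 1, 2]])
print(area_weighted_vertex_normals(verts, faces))
```
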
sf3d/models/network.py DELETED
@@ -1,195 +0,0 @@
1
- from dataclasses import dataclass, field
2
- from typing import Callable, List, Optional
3
-
4
- import torch
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
- from einops import rearrange
8
- from jaxtyping import Float
9
- from torch import Tensor
10
- from torch.autograd import Function
11
- from torch.cuda.amp import custom_bwd, custom_fwd
12
-
13
- from sf3d.models.utils import BaseModule, normalize
14
-
15
-
16
- class PixelShuffleUpsampleNetwork(BaseModule):
17
- @dataclass
18
- class Config(BaseModule.Config):
19
- in_channels: int = 1024
20
- out_channels: int = 40
21
- scale_factor: int = 4
22
-
23
- conv_layers: int = 4
24
- conv_kernel_size: int = 3
25
-
26
- cfg: Config
27
-
28
- def configure(self) -> None:
29
- layers = []
30
- output_channels = self.cfg.out_channels * self.cfg.scale_factor**2
31
-
32
- in_channels = self.cfg.in_channels
33
- for i in range(self.cfg.conv_layers):
34
- cur_out_channels = (
35
- in_channels if i != self.cfg.conv_layers - 1 else output_channels
36
- )
37
- layers.append(
38
- nn.Conv2d(
39
- in_channels,
40
- cur_out_channels,
41
- self.cfg.conv_kernel_size,
42
- padding=(self.cfg.conv_kernel_size - 1) // 2,
43
- )
44
- )
45
- if i != self.cfg.conv_layers - 1:
46
- layers.append(nn.ReLU(inplace=True))
47
-
48
- layers.append(nn.PixelShuffle(self.cfg.scale_factor))
49
-
50
- self.upsample = nn.Sequential(*layers)
51
-
52
- def forward(
53
- self, triplanes: Float[Tensor, "B 3 Ci Hp Wp"]
54
- ) -> Float[Tensor, "B 3 Co Hp2 Wp2"]:
55
- return rearrange(
56
- self.upsample(
57
- rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
58
- ),
59
- "(B Np) Co Hp Wp -> B Np Co Hp Wp",
60
- Np=3,
61
- )
62
-
63
-
64
- class _TruncExp(Function): # pylint: disable=abstract-method
65
- # Implementation from torch-ngp:
66
- # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
67
- @staticmethod
68
- @custom_fwd(cast_inputs=torch.float32)
69
- def forward(ctx, x): # pylint: disable=arguments-differ
70
- ctx.save_for_backward(x)
71
- return torch.exp(x)
72
-
73
- @staticmethod
74
- @custom_bwd
75
- def backward(ctx, g): # pylint: disable=arguments-differ
76
- x = ctx.saved_tensors[0]
77
- return g * torch.exp(torch.clamp(x, max=15))
78
-
79
-
80
- trunc_exp = _TruncExp.apply
81
-
82
-
83
- def get_activation(name) -> Callable:
84
- if name is None:
85
- return lambda x: x
86
- name = name.lower()
87
- if name == "none" or name == "linear" or name == "identity":
88
- return lambda x: x
89
- elif name == "lin2srgb":
90
- return lambda x: torch.where(
91
- x > 0.0031308,
92
- torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055,
93
- 12.92 * x,
94
- ).clamp(0.0, 1.0)
95
- elif name == "exp":
96
- return lambda x: torch.exp(x)
97
- elif name == "shifted_exp":
98
- return lambda x: torch.exp(x - 1.0)
99
- elif name == "trunc_exp":
100
- return trunc_exp
101
- elif name == "shifted_trunc_exp":
102
- return lambda x: trunc_exp(x - 1.0)
103
- elif name == "sigmoid":
104
- return lambda x: torch.sigmoid(x)
105
- elif name == "tanh":
106
- return lambda x: torch.tanh(x)
107
- elif name == "shifted_softplus":
108
- return lambda x: F.softplus(x - 1.0)
109
- elif name == "scale_-11_01":
110
- return lambda x: x * 0.5 + 0.5
111
- elif name == "negative":
112
- return lambda x: -x
113
- elif name == "normalize_channel_last":
114
- return lambda x: normalize(x)
115
- elif name == "normalize_channel_first":
116
- return lambda x: normalize(x, dim=1)
117
- else:
118
- try:
119
- return getattr(F, name)
120
- except AttributeError:
121
- raise ValueError(f"Unknown activation function: {name}")
122
-
123
-
124
- @dataclass
125
- class HeadSpec:
126
- name: str
127
- out_channels: int
128
- n_hidden_layers: int
129
- output_activation: Optional[str] = None
130
- out_bias: float = 0.0
131
-
132
-
133
- class MaterialMLP(BaseModule):
134
- @dataclass
135
- class Config(BaseModule.Config):
136
- in_channels: int = 120
137
- n_neurons: int = 64
138
- activation: str = "silu"
139
- heads: List[HeadSpec] = field(default_factory=lambda: [])
140
-
141
- cfg: Config
142
-
143
- def configure(self) -> None:
144
- assert len(self.cfg.heads) > 0
145
- heads = {}
146
- for head in self.cfg.heads:
147
- head_layers = []
148
- for i in range(head.n_hidden_layers):
149
- head_layers += [
150
- nn.Linear(
151
- self.cfg.in_channels if i == 0 else self.cfg.n_neurons,
152
- self.cfg.n_neurons,
153
- ),
154
- self.make_activation(self.cfg.activation),
155
- ]
156
- head_layers += [
157
- nn.Linear(
158
- self.cfg.n_neurons,
159
- head.out_channels,
160
- ),
161
- ]
162
- heads[head.name] = nn.Sequential(*head_layers)
163
- self.heads = nn.ModuleDict(heads)
164
-
165
- def make_activation(self, activation):
166
- if activation == "relu":
167
- return nn.ReLU(inplace=True)
168
- elif activation == "silu":
169
- return nn.SiLU(inplace=True)
170
- else:
171
- raise NotImplementedError
172
-
173
- def keys(self):
174
- return self.heads.keys()
175
-
176
- def forward(
177
- self, x, include: Optional[List] = None, exclude: Optional[List] = None
178
- ):
179
- if include is not None and exclude is not None:
180
- raise ValueError("Cannot specify both include and exclude.")
181
- if include is not None:
182
- heads = [h for h in self.cfg.heads if h.name in include]
183
- elif exclude is not None:
184
- heads = [h for h in self.cfg.heads if h.name not in exclude]
185
- else:
186
- heads = self.cfg.heads
187
-
188
- out = {
189
- head.name: get_activation(head.output_activation)(
190
- self.heads[head.name](x) + head.out_bias
191
- )
192
- for head in heads
193
- }
194
-
195
- return out
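
Reviewer note on `network.py`: the `PixelShuffleUpsampleNetwork` makes its final conv emit `out_channels * scale_factor**2` channels, which `nn.PixelShuffle` then trades for spatial resolution after the three planes are folded into the batch dimension. A rough shape-bookkeeping sketch using the default config values above (1024 input channels, 40 output channels, scale factor 4); the conv stack is replaced by a single stand-in layer:

```python
import torch
import torch.nn as nn
from einops import rearrange

in_ch, out_ch, scale = 1024, 40, 4  # defaults from the Config above

# Stand-in for the conv stack: only the channel bookkeeping matters for this sketch
head = nn.Conv2d(in_ch, out_ch * scale**2, kernel_size=3, padding=1)
shuffle = nn.PixelShuffle(scale)

triplanes = torch.randn(2, 3, in_ch, 96, 96)              # (B, Np, Ci, Hp, Wp)
x = rearrange(triplanes, "B Np C H W -> (B Np) C H W")    # fold the 3 planes into the batch
x = shuffle(head(x))                                      # (B*Np, 40, 384, 384)
out = rearrange(x, "(B Np) C H W -> B Np C H W", Np=3)
print(out.shape)  # torch.Size([2, 3, 40, 384, 384])
```
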
sf3d/models/tokenizers/dinov2.py DELETED
@@ -1,1196 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """PyTorch DINOv2 model."""
16
-
17
- import collections.abc
18
- import math
19
- from dataclasses import dataclass
20
- from typing import Dict, List, Optional, Set, Tuple, Union
21
-
22
- import torch
23
- import torch.nn.functional as F
24
- import torch.utils.checkpoint
25
- from torch import nn
26
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
- from transformers.activations import ACT2FN
28
- from transformers.modeling_outputs import (
29
- BackboneOutput,
30
- BaseModelOutput,
31
- BaseModelOutputWithPooling,
32
- ImageClassifierOutput,
33
- )
34
- from transformers.modeling_utils import PreTrainedModel
35
- from transformers.models.dinov2.configuration_dinov2 import Dinov2Config
36
- from transformers.pytorch_utils import (
37
- find_pruneable_heads_and_indices,
38
- prune_linear_layer,
39
- )
40
- from transformers.utils import (
41
- add_code_sample_docstrings,
42
- add_start_docstrings,
43
- add_start_docstrings_to_model_forward,
44
- logging,
45
- replace_return_docstrings,
46
- )
47
- from transformers.utils.backbone_utils import BackboneMixin
48
-
49
- logger = logging.get_logger(__name__)
50
-
51
- # General docstring
52
- _CONFIG_FOR_DOC = "Dinov2Config"
53
-
54
- # Base docstring
55
- _CHECKPOINT_FOR_DOC = "facebook/dinov2-base"
56
- _EXPECTED_OUTPUT_SHAPE = [1, 257, 768]
57
-
58
- # Image classification docstring
59
- _IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-base"
60
-
61
-
62
- DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = [
63
- "facebook/dinov2-base",
64
- # See all DINOv2 models at https://huggingface.co/models?filter=dinov2
65
- ]
66
-
67
-
68
- class Dinov2Embeddings(nn.Module):
69
- """
70
- Construct the CLS token, mask token, position and patch embeddings.
71
- """
72
-
73
- def __init__(self, config: Dinov2Config) -> None:
74
- super().__init__()
75
-
76
- self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
77
- # register as mask token as it's not used in optimization
78
- # to avoid the use of find_unused_parameters_true
79
- # self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
80
- self.register_buffer("mask_token", torch.zeros(1, config.hidden_size))
81
- self.patch_embeddings = Dinov2PatchEmbeddings(config)
82
- num_patches = self.patch_embeddings.num_patches
83
- self.position_embeddings = nn.Parameter(
84
- torch.randn(1, num_patches + 1, config.hidden_size)
85
- )
86
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
87
- self.config = config
88
-
89
- def interpolate_pos_encoding(
90
- self, embeddings: torch.Tensor, height: int, width: int
91
- ) -> torch.Tensor:
92
- """
93
- This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
94
- resolution images.
95
-
96
- Source:
97
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
98
- """
99
-
100
- num_patches = embeddings.shape[1] - 1
101
- num_positions = self.position_embeddings.shape[1] - 1
102
- if num_patches == num_positions and height == width:
103
- return self.position_embeddings
104
- class_pos_embed = self.position_embeddings[:, 0]
105
- patch_pos_embed = self.position_embeddings[:, 1:]
106
- dim = embeddings.shape[-1]
107
- height = height // self.config.patch_size
108
- width = width // self.config.patch_size
109
- # we add a small number to avoid floating point error in the interpolation
110
- # see discussion at https://github.com/facebookresearch/dino/issues/8
111
- height, width = height + 0.1, width + 0.1
112
- patch_pos_embed = patch_pos_embed.reshape(
113
- 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
114
- )
115
- patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
116
- patch_pos_embed = nn.functional.interpolate(
117
- patch_pos_embed,
118
- scale_factor=(
119
- height / math.sqrt(num_positions),
120
- width / math.sqrt(num_positions),
121
- ),
122
- mode="bicubic",
123
- align_corners=False,
124
- )
125
- if (
126
- int(height) != patch_pos_embed.shape[-2]
127
- or int(width) != patch_pos_embed.shape[-1]
128
- ):
129
- raise ValueError(
130
- "Width or height does not match with the interpolated position embeddings"
131
- )
132
- patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
133
- return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
134
-
135
- def forward(
136
- self,
137
- pixel_values: torch.Tensor,
138
- bool_masked_pos: Optional[torch.Tensor] = None,
139
- ) -> torch.Tensor:
140
- batch_size, _, height, width = pixel_values.shape
141
- patch_embeddings = self.patch_embeddings(pixel_values)
142
- embeddings = patch_embeddings
143
-
144
- if bool_masked_pos is not None:
145
- embeddings = torch.where(
146
- bool_masked_pos.unsqueeze(-1),
147
- self.mask_token.to(embeddings.dtype).unsqueeze(0),
148
- embeddings,
149
- )
150
-
151
- # add the [CLS] token to the embedded patch tokens
152
- cls_tokens = self.cls_token.expand(batch_size, -1, -1)
153
- embeddings = torch.cat((cls_tokens, embeddings), dim=1)
154
-
155
- # add positional encoding to each token
156
- embeddings = embeddings + self.interpolate_pos_encoding(
157
- embeddings, height, width
158
- )
159
-
160
- embeddings = self.dropout(embeddings)
161
-
162
- return embeddings
163
-
164
-
165
- class Dinov2PatchEmbeddings(nn.Module):
166
- """
167
- This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
168
- `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
169
- Transformer.
170
- """
171
-
172
- def __init__(self, config):
173
- super().__init__()
174
- image_size, patch_size = config.image_size, config.patch_size
175
- num_channels, hidden_size = config.num_channels, config.hidden_size
176
-
177
- image_size = (
178
- image_size
179
- if isinstance(image_size, collections.abc.Iterable)
180
- else (image_size, image_size)
181
- )
182
- patch_size = (
183
- patch_size
184
- if isinstance(patch_size, collections.abc.Iterable)
185
- else (patch_size, patch_size)
186
- )
187
- num_patches = (image_size[1] // patch_size[1]) * (
188
- image_size[0] // patch_size[0]
189
- )
190
- self.image_size = image_size
191
- self.patch_size = patch_size
192
- self.num_channels = num_channels
193
- self.num_patches = num_patches
194
-
195
- self.projection = nn.Conv2d(
196
- num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
197
- )
198
-
199
- def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
200
- """
201
- num_channels = pixel_values.shape[1]
202
- if num_channels != self.num_channels:
203
- raise ValueError(
204
- "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
205
- f" Expected {self.num_channels} but got {num_channels}."
206
- )
207
- """
208
- embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
209
- return embeddings
210
-
211
-
212
- # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2
213
- class Dinov2SelfAttention(nn.Module):
214
- def __init__(self, config: Dinov2Config) -> None:
215
- super().__init__()
216
- if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
217
- config, "embedding_size"
218
- ):
219
- raise ValueError(
220
- f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
221
- f"heads {config.num_attention_heads}."
222
- )
223
-
224
- self.num_attention_heads = config.num_attention_heads
225
- self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
226
- self.all_head_size = self.num_attention_heads * self.attention_head_size
227
- self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
228
-
229
- self.query = nn.Linear(
230
- config.hidden_size, self.all_head_size, bias=config.qkv_bias
231
- )
232
- self.key = nn.Linear(
233
- config.hidden_size, self.all_head_size, bias=config.qkv_bias
234
- )
235
- self.value = nn.Linear(
236
- config.hidden_size, self.all_head_size, bias=config.qkv_bias
237
- )
238
-
239
- self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
240
-
241
- def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
242
- new_x_shape = x.size()[:-1] + (
243
- self.num_attention_heads,
244
- self.attention_head_size,
245
- )
246
- x = x.view(new_x_shape)
247
- return x.permute(0, 2, 1, 3)
248
-
249
- def forward(
250
- self,
251
- hidden_states,
252
- head_mask: Optional[torch.Tensor] = None,
253
- output_attentions: bool = False,
254
- ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
255
- mixed_query_layer = self.query(hidden_states)
256
-
257
- if hasattr(F, "scaled_dot_product_attention"):
258
- assert head_mask is None and not output_attentions
259
- new_size = hidden_states.size()[:-1] + (
260
- self.num_attention_heads,
261
- self.attention_head_size,
262
- )
263
- key_layer = self.key(hidden_states).reshape(new_size).transpose(1, 2)
264
- value_layer = self.value(hidden_states).reshape(new_size).transpose(1, 2)
265
- query_layer = mixed_query_layer.reshape(new_size).transpose(1, 2)
266
- context_layer = F.scaled_dot_product_attention(
267
- query_layer,
268
- key_layer,
269
- value_layer,
270
- dropout_p=self.attention_probs_dropout_prob,
271
- is_causal=False,
272
- )
273
- context_layer = context_layer.transpose(1, 2).reshape(
274
- *hidden_states.size()[:-1], -1
275
- )
276
- else:
277
- key_layer = self.transpose_for_scores(self.key(hidden_states))
278
- value_layer = self.transpose_for_scores(self.value(hidden_states))
279
- query_layer = self.transpose_for_scores(mixed_query_layer)
280
-
281
- # Take the dot product between "query" and "key" to get the raw attention scores.
282
- attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
283
-
284
- attention_scores = attention_scores / math.sqrt(self.attention_head_size)
285
-
286
- # Normalize the attention scores to probabilities.
287
- attention_probs = nn.functional.softmax(attention_scores, dim=-1)
288
-
289
- # This is actually dropping out entire tokens to attend to, which might
290
- # seem a bit unusual, but is taken from the original Transformer paper.
291
- attention_probs = self.dropout(attention_probs)
292
-
293
- # Mask heads if we want to
294
- if head_mask is not None:
295
- attention_probs = attention_probs * head_mask
296
-
297
- context_layer = torch.matmul(attention_probs, value_layer)
298
-
299
- context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
300
- new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
301
- context_layer = context_layer.view(new_context_layer_shape)
302
-
303
- outputs = (
304
- (context_layer, attention_probs) if output_attentions else (context_layer,)
305
- )
306
-
307
- return outputs
308
-
309
-
310
- # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2
311
- class Dinov2SelfOutput(nn.Module):
312
- """
313
- The residual connection is defined in Dinov2Layer instead of here (as is the case with other models), due to the
314
- layernorm applied before each block.
315
- """
316
-
317
- def __init__(self, config: Dinov2Config) -> None:
318
- super().__init__()
319
- self.dense = nn.Linear(config.hidden_size, config.hidden_size)
320
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
321
-
322
- def forward(
323
- self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
324
- ) -> torch.Tensor:
325
- hidden_states = self.dense(hidden_states)
326
- hidden_states = self.dropout(hidden_states)
327
-
328
- return hidden_states
329
-
330
-
331
- # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Dinov2
332
- class Dinov2Attention(nn.Module):
333
- def __init__(self, config: Dinov2Config) -> None:
334
- super().__init__()
335
- self.attention = Dinov2SelfAttention(config)
336
- self.output = Dinov2SelfOutput(config)
337
- self.pruned_heads = set()
338
-
339
- def prune_heads(self, heads: Set[int]) -> None:
340
- if len(heads) == 0:
341
- return
342
- heads, index = find_pruneable_heads_and_indices(
343
- heads,
344
- self.attention.num_attention_heads,
345
- self.attention.attention_head_size,
346
- self.pruned_heads,
347
- )
348
-
349
- # Prune linear layers
350
- self.attention.query = prune_linear_layer(self.attention.query, index)
351
- self.attention.key = prune_linear_layer(self.attention.key, index)
352
- self.attention.value = prune_linear_layer(self.attention.value, index)
353
- self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
354
-
355
- # Update hyper params and store pruned heads
356
- self.attention.num_attention_heads = self.attention.num_attention_heads - len(
357
- heads
358
- )
359
- self.attention.all_head_size = (
360
- self.attention.attention_head_size * self.attention.num_attention_heads
361
- )
362
- self.pruned_heads = self.pruned_heads.union(heads)
363
-
364
- def forward(
365
- self,
366
- hidden_states: torch.Tensor,
367
- head_mask: Optional[torch.Tensor] = None,
368
- output_attentions: bool = False,
369
- ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
370
- self_outputs = self.attention(hidden_states, head_mask, output_attentions)
371
-
372
- attention_output = self.output(self_outputs[0], hidden_states)
373
-
374
- outputs = (attention_output,) + self_outputs[
375
- 1:
376
- ] # add attentions if we output them
377
- return outputs
378
-
379
-
380
- class Dinov2LayerScale(nn.Module):
381
- def __init__(self, config) -> None:
382
- super().__init__()
383
- self.lambda1 = nn.Parameter(
384
- config.layerscale_value * torch.ones(config.hidden_size)
385
- )
386
-
387
- def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
388
- return hidden_state * self.lambda1
389
-
390
-
391
- # Copied from transformers.models.beit.modeling_beit.drop_path
392
- def drop_path(
393
- input: torch.Tensor, drop_prob: float = 0.0, training: bool = False
394
- ) -> torch.Tensor:
395
- """
396
- Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
397
-
398
- Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
399
- however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
400
- See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
401
- layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
402
- argument.
403
- """
404
- if drop_prob == 0.0 or not training:
405
- return input
406
- keep_prob = 1 - drop_prob
407
- shape = (input.shape[0],) + (1,) * (
408
- input.ndim - 1
409
- ) # work with diff dim tensors, not just 2D ConvNets
410
- random_tensor = keep_prob + torch.rand(
411
- shape, dtype=input.dtype, device=input.device
412
- )
413
- random_tensor.floor_() # binarize
414
- output = input.div(keep_prob) * random_tensor
415
- return output
416
-
417
-
418
- # Copied from transformers.models.beit.modeling_beit.BeitDropPath
419
- class Dinov2DropPath(nn.Module):
420
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
421
-
422
- def __init__(self, drop_prob: Optional[float] = None) -> None:
423
- super().__init__()
424
- self.drop_prob = drop_prob
425
-
426
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
427
- return drop_path(hidden_states, self.drop_prob, self.training)
428
-
429
- def extra_repr(self) -> str:
430
- return "p={}".format(self.drop_prob)
431
-
432
-
433
- class Dinov2MLP(nn.Module):
434
- def __init__(self, config) -> None:
435
- super().__init__()
436
- in_features = out_features = config.hidden_size
437
- hidden_features = int(config.hidden_size * config.mlp_ratio)
438
- self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
439
- if isinstance(config.hidden_act, str):
440
- self.activation = ACT2FN[config.hidden_act]
441
- else:
442
- self.activation = config.hidden_act
443
- self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
444
-
445
- def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
446
- hidden_state = self.fc1(hidden_state)
447
- hidden_state = self.activation(hidden_state)
448
- hidden_state = self.fc2(hidden_state)
449
- return hidden_state
450
-
451
-
452
- class Dinov2SwiGLUFFN(nn.Module):
453
- def __init__(self, config) -> None:
454
- super().__init__()
455
- in_features = out_features = config.hidden_size
456
- hidden_features = int(config.hidden_size * config.mlp_ratio)
457
- hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
458
-
459
- self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True)
460
- self.weights_out = nn.Linear(hidden_features, out_features, bias=True)
461
-
462
- def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
463
- hidden_state = self.weights_in(hidden_state)
464
- x1, x2 = hidden_state.chunk(2, dim=-1)
465
- hidden = nn.functional.silu(x1) * x2
466
- return self.weights_out(hidden)
467
-
468
-
469
- class Dinov2Layer(nn.Module):
470
- """This corresponds to the Block class in the original implementation."""
471
-
472
- def __init__(self, config: Dinov2Config) -> None:
473
- super().__init__()
474
-
475
- self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
476
- self.norm1_modulation = None
477
- self.attention = Dinov2Attention(config)
478
- self.layer_scale1 = Dinov2LayerScale(config)
479
- self.drop_path1 = (
480
- Dinov2DropPath(config.drop_path_rate)
481
- if config.drop_path_rate > 0.0
482
- else nn.Identity()
483
- )
484
-
485
- self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
486
- self.norm2_modulation = None
487
-
488
- if config.use_swiglu_ffn:
489
- self.mlp = Dinov2SwiGLUFFN(config)
490
- else:
491
- self.mlp = Dinov2MLP(config)
492
- self.layer_scale2 = Dinov2LayerScale(config)
493
- self.drop_path2 = (
494
- Dinov2DropPath(config.drop_path_rate)
495
- if config.drop_path_rate > 0.0
496
- else nn.Identity()
497
- )
498
-
499
- def forward(
500
- self,
501
- hidden_states: torch.Tensor,
502
- head_mask: Optional[torch.Tensor] = None,
503
- modulation_cond: Optional[torch.Tensor] = None,
504
- output_attentions: bool = False,
505
- ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
506
- hidden_states_norm = self.norm1(hidden_states)
507
- if self.norm1_modulation is not None:
508
- assert modulation_cond is not None
509
- hidden_states_norm = self.norm1_modulation(
510
- hidden_states_norm, modulation_cond
511
- )
512
- self_attention_outputs = self.attention(
513
- hidden_states_norm, # in Dinov2, layernorm is applied before self-attention
514
- head_mask,
515
- output_attentions=output_attentions,
516
- )
517
- attention_output = self_attention_outputs[0]
518
-
519
- attention_output = self.layer_scale1(attention_output)
520
- outputs = self_attention_outputs[
521
- 1:
522
- ] # add self attentions if we output attention weights
523
-
524
- # first residual connection
525
- hidden_states = attention_output + hidden_states
526
-
527
- # in Dinov2, layernorm is also applied after self-attention
528
- layer_output = self.norm2(hidden_states)
529
- if self.norm2_modulation is not None:
530
- assert modulation_cond is not None
531
- layer_output = self.norm2_modulation(layer_output, modulation_cond)
532
- layer_output = self.mlp(layer_output)
533
- layer_output = self.layer_scale2(layer_output)
534
-
535
- # second residual connection
536
- layer_output = layer_output + hidden_states
537
-
538
- outputs = (layer_output,) + outputs
539
-
540
- return outputs
541
-
542
- def register_ada_norm_modulation(self, norm1_mod: nn.Module, norm2_mod: nn.Module):
543
- self.norm1_modulation = norm1_mod
544
- self.norm2_modulation = norm2_mod
545
-
546
-
547
- # Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->Dinov2
548
- class Dinov2Encoder(nn.Module):
549
- def __init__(self, config: Dinov2Config) -> None:
550
- super().__init__()
551
- self.config = config
552
- self.layer = nn.ModuleList(
553
- [Dinov2Layer(config) for _ in range(config.num_hidden_layers)]
554
- )
555
- self.gradient_checkpointing = False
556
-
557
- def forward(
558
- self,
559
- hidden_states: torch.Tensor,
560
- head_mask: Optional[torch.Tensor] = None,
561
- modulation_cond: Optional[torch.Tensor] = None,
562
- output_attentions: bool = False,
563
- output_hidden_states: bool = False,
564
- return_dict: bool = True,
565
- ) -> Union[tuple, BaseModelOutput]:
566
- all_hidden_states = () if output_hidden_states else None
567
- all_self_attentions = () if output_attentions else None
568
-
569
- for i, layer_module in enumerate(self.layer):
570
- if output_hidden_states:
571
- all_hidden_states = all_hidden_states + (hidden_states,)
572
-
573
- layer_head_mask = head_mask[i] if head_mask is not None else None
574
-
575
- if self.gradient_checkpointing and self.training:
576
-
577
- def create_custom_forward(module):
578
- def custom_forward(*inputs):
579
- return module(*inputs, output_attentions)
580
-
581
- return custom_forward
582
-
583
- layer_outputs = torch.utils.checkpoint.checkpoint(
584
- create_custom_forward(layer_module),
585
- hidden_states,
586
- layer_head_mask,
587
- modulation_cond,
588
- use_reentrant=False,
589
- )
590
- else:
591
- layer_outputs = layer_module(
592
- hidden_states, layer_head_mask, modulation_cond, output_attentions
593
- )
594
-
595
- hidden_states = layer_outputs[0]
596
-
597
- if output_attentions:
598
- all_self_attentions = all_self_attentions + (layer_outputs[1],)
599
-
600
- if output_hidden_states:
601
- all_hidden_states = all_hidden_states + (hidden_states,)
602
-
603
- if not return_dict:
604
- return tuple(
605
- v
606
- for v in [hidden_states, all_hidden_states, all_self_attentions]
607
- if v is not None
608
- )
609
- return BaseModelOutput(
610
- last_hidden_state=hidden_states,
611
- hidden_states=all_hidden_states,
612
- attentions=all_self_attentions,
613
- )
614
-
615
-
616
- class Dinov2PreTrainedModel(PreTrainedModel):
617
- """
618
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
619
- models.
620
- """
621
-
622
- config_class = Dinov2Config
623
- base_model_prefix = "dinov2"
624
- main_input_name = "pixel_values"
625
- supports_gradient_checkpointing = True
626
-
627
- def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
628
- """Initialize the weights"""
629
- if isinstance(module, (nn.Linear, nn.Conv2d)):
630
- # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
631
- # `trunc_normal_cpu` not implemented in `half` issues
632
- module.weight.data = nn.init.trunc_normal_(
633
- module.weight.data.to(torch.float32),
634
- mean=0.0,
635
- std=self.config.initializer_range,
636
- ).to(module.weight.dtype)
637
- if module.bias is not None:
638
- module.bias.data.zero_()
639
- elif isinstance(module, nn.LayerNorm):
640
- module.bias.data.zero_()
641
- module.weight.data.fill_(1.0)
642
- elif isinstance(module, Dinov2Embeddings):
643
- module.position_embeddings.data = nn.init.trunc_normal_(
644
- module.position_embeddings.data.to(torch.float32),
645
- mean=0.0,
646
- std=self.config.initializer_range,
647
- ).to(module.position_embeddings.dtype)
648
-
649
- module.cls_token.data = nn.init.trunc_normal_(
650
- module.cls_token.data.to(torch.float32),
651
- mean=0.0,
652
- std=self.config.initializer_range,
653
- ).to(module.cls_token.dtype)
654
-
655
- def _set_gradient_checkpointing(
656
- self, module: Dinov2Encoder, value: bool = False
657
- ) -> None:
658
- if isinstance(module, Dinov2Encoder):
659
- module.gradient_checkpointing = value
660
-
661
-
662
- DINOV2_START_DOCSTRING = r"""
663
- This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
664
- as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
665
- behavior.
666
-
667
- Parameters:
668
- config ([`Dinov2Config`]): Model configuration class with all the parameters of the model.
669
- Initializing with a config file does not load the weights associated with the model, only the
670
- configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
671
- """
672
-
673
- DINOV2_BASE_INPUTS_DOCSTRING = r"""
674
- Args:
675
- pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
676
- Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
677
- [`BitImageProcessor.preprocess`] for details.
678
-
679
- bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
680
- Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for
681
- pre-training.
682
-
683
- head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
684
- Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
685
-
686
- - 1 indicates the head is **not masked**,
687
- - 0 indicates the head is **masked**.
688
-
689
- output_attentions (`bool`, *optional*):
690
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
691
- tensors for more detail.
692
- output_hidden_states (`bool`, *optional*):
693
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
694
- more detail.
695
- return_dict (`bool`, *optional*):
696
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
697
- """
698
-
699
- DINOV2_INPUTS_DOCSTRING = r"""
700
- Args:
701
- pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
702
- Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
703
- [`BitImageProcessor.preprocess`] for details.
704
-
705
- head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
706
- Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
707
-
708
- - 1 indicates the head is **not masked**,
709
- - 0 indicates the head is **masked**.
710
-
711
- output_attentions (`bool`, *optional*):
712
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
713
- tensors for more detail.
714
- output_hidden_states (`bool`, *optional*):
715
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
716
- more detail.
717
- return_dict (`bool`, *optional*):
718
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
719
- """
720
-
721
-
722
- @dataclass
723
- class CustomBaseModelOutputWithPooling(BaseModelOutputWithPooling):
724
- patch_embeddings: Optional[torch.FloatTensor] = None
725
-
726
-
727
- @add_start_docstrings(
728
- "The bare DINOv2 Model transformer outputting raw hidden-states without any specific head on top.",
729
- DINOV2_START_DOCSTRING,
730
- )
731
- class Dinov2Model(Dinov2PreTrainedModel):
732
- def __init__(self, config: Dinov2Config):
733
- super().__init__(config)
734
- self.config = config
735
-
736
- self.embeddings = Dinov2Embeddings(config)
737
- self.encoder = Dinov2Encoder(config)
738
-
739
- self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
740
-
741
- # Initialize weights and apply final processing
742
- self.post_init()
743
-
744
- def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
745
- return self.embeddings.patch_embeddings
746
-
747
- def expand_input_channels(self, extra_input_channels: int) -> None:
748
- if extra_input_channels == 0:
749
- return
750
- conv_old = self.embeddings.patch_embeddings.projection
751
- conv_new = nn.Conv2d(
752
- self.config.num_channels + extra_input_channels,
753
- self.config.hidden_size,
754
- kernel_size=self.config.patch_size,
755
- stride=self.config.patch_size,
756
- ).to(self.device)
757
- with torch.no_grad():
758
- conv_new.weight[:, :3] = conv_old.weight
759
- conv_new.bias = conv_old.bias
760
- self.embeddings.patch_embeddings.projection = conv_new
761
- del conv_old
762
-
763
- def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
764
- """
765
- Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
766
- class PreTrainedModel
767
- """
768
- for layer, heads in heads_to_prune.items():
769
- self.encoder.layer[layer].attention.prune_heads(heads)
770
-
771
- @add_start_docstrings_to_model_forward(DINOV2_BASE_INPUTS_DOCSTRING)
772
- @add_code_sample_docstrings(
773
- checkpoint=_CHECKPOINT_FOR_DOC,
774
- output_type=BaseModelOutputWithPooling,
775
- config_class=_CONFIG_FOR_DOC,
776
- modality="vision",
777
- expected_output=_EXPECTED_OUTPUT_SHAPE,
778
- )
779
- def forward(
780
- self,
781
- pixel_values: Optional[torch.Tensor] = None,
782
- bool_masked_pos: Optional[torch.Tensor] = None,
783
- head_mask: Optional[torch.Tensor] = None,
784
- modulation_cond: Optional[torch.Tensor] = None,
785
- output_attentions: Optional[bool] = None,
786
- output_hidden_states: Optional[bool] = None,
787
- return_dict: Optional[bool] = None,
788
- ) -> Union[Tuple, BaseModelOutputWithPooling]:
789
- output_attentions = (
790
- output_attentions
791
- if output_attentions is not None
792
- else self.config.output_attentions
793
- )
794
- output_hidden_states = (
795
- output_hidden_states
796
- if output_hidden_states is not None
797
- else self.config.output_hidden_states
798
- )
799
- return_dict = (
800
- return_dict if return_dict is not None else self.config.use_return_dict
801
- )
802
-
803
- if pixel_values is None:
804
- raise ValueError("You have to specify pixel_values")
805
-
806
- # Prepare head mask if needed
807
- # 1.0 in head_mask indicate we keep the head
808
- # attention_probs has shape bsz x n_heads x N x N
809
- # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
810
- # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
811
- head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
812
-
813
- embedding_output = self.embeddings(
814
- pixel_values, bool_masked_pos=bool_masked_pos
815
- )
816
-
817
- encoder_outputs = self.encoder(
818
- embedding_output,
819
- head_mask=head_mask,
820
- modulation_cond=modulation_cond,
821
- output_attentions=output_attentions,
822
- output_hidden_states=output_hidden_states,
823
- return_dict=return_dict,
824
- )
825
- sequence_output = encoder_outputs[0]
826
- sequence_output = self.layernorm(sequence_output)
827
- pooled_output = sequence_output[:, 0, :]
828
-
829
- if not return_dict:
830
- head_outputs = (sequence_output, pooled_output)
831
- return head_outputs + encoder_outputs[1:]
832
-
833
- return CustomBaseModelOutputWithPooling(
834
- last_hidden_state=sequence_output,
835
- pooler_output=pooled_output,
836
- hidden_states=encoder_outputs.hidden_states,
837
- attentions=encoder_outputs.attentions,
838
- patch_embeddings=embedding_output,
839
- )
840
-
841
- def set_gradient_checkpointing(self, value: bool = False) -> None:
842
- self._set_gradient_checkpointing(self.encoder, value)
843
-
844
-
845
- @add_start_docstrings(
846
- """
847
- Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
848
- of the [CLS] token) e.g. for ImageNet.
849
- """,
850
- DINOV2_START_DOCSTRING,
851
- )
852
- class Dinov2ForImageClassification(Dinov2PreTrainedModel):
853
- def __init__(self, config: Dinov2Config) -> None:
854
- super().__init__(config)
855
-
856
- self.num_labels = config.num_labels
857
- self.dinov2 = Dinov2Model(config)
858
-
859
- # Classifier head
860
- self.classifier = (
861
- nn.Linear(config.hidden_size * 2, config.num_labels)
862
- if config.num_labels > 0
863
- else nn.Identity()
864
- )
865
-
866
- # Initialize weights and apply final processing
867
- self.post_init()
868
-
869
- @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
870
- @add_code_sample_docstrings(
871
- checkpoint=_IMAGE_CLASS_CHECKPOINT,
872
- output_type=ImageClassifierOutput,
873
- config_class=_CONFIG_FOR_DOC,
874
- )
875
- def forward(
876
- self,
877
- pixel_values: Optional[torch.Tensor] = None,
878
- head_mask: Optional[torch.Tensor] = None,
879
- labels: Optional[torch.Tensor] = None,
880
- output_attentions: Optional[bool] = None,
881
- output_hidden_states: Optional[bool] = None,
882
- return_dict: Optional[bool] = None,
883
- ) -> Union[tuple, ImageClassifierOutput]:
884
- r"""
885
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
886
- Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
887
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
888
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
889
- """
890
- return_dict = (
891
- return_dict if return_dict is not None else self.config.use_return_dict
892
- )
893
-
894
- outputs = self.dinov2(
895
- pixel_values,
896
- head_mask=head_mask,
897
- output_attentions=output_attentions,
898
- output_hidden_states=output_hidden_states,
899
- return_dict=return_dict,
900
- )
901
-
902
- sequence_output = outputs[0] # batch_size, sequence_length, hidden_size
903
-
904
- cls_token = sequence_output[:, 0]
905
- patch_tokens = sequence_output[:, 1:]
906
-
907
- linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
908
-
909
- logits = self.classifier(linear_input)
910
-
911
- loss = None
912
- if labels is not None:
913
- # move labels to correct device to enable model parallelism
914
- labels = labels.to(logits.device)
915
- if self.config.problem_type is None:
916
- if self.num_labels == 1:
917
- self.config.problem_type = "regression"
918
- elif self.num_labels > 1 and (
919
- labels.dtype == torch.long or labels.dtype == torch.int
920
- ):
921
- self.config.problem_type = "single_label_classification"
922
- else:
923
- self.config.problem_type = "multi_label_classification"
924
-
925
- if self.config.problem_type == "regression":
926
- loss_fct = MSELoss()
927
- if self.num_labels == 1:
928
- loss = loss_fct(logits.squeeze(), labels.squeeze())
929
- else:
930
- loss = loss_fct(logits, labels)
931
- elif self.config.problem_type == "single_label_classification":
932
- loss_fct = CrossEntropyLoss()
933
- loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
934
- elif self.config.problem_type == "multi_label_classification":
935
- loss_fct = BCEWithLogitsLoss()
936
- loss = loss_fct(logits, labels)
937
-
938
- if not return_dict:
939
- output = (logits,) + outputs[2:]
940
- return ((loss,) + output) if loss is not None else output
941
-
942
- return ImageClassifierOutput(
943
- loss=loss,
944
- logits=logits,
945
- hidden_states=outputs.hidden_states,
946
- attentions=outputs.attentions,
947
- )
948
-
949
-
950
- @add_start_docstrings(
951
- """
952
- Dinov2 backbone, to be used with frameworks like DETR and MaskFormer.
953
- """,
954
- DINOV2_START_DOCSTRING,
955
- )
956
- class Dinov2Backbone(Dinov2PreTrainedModel, BackboneMixin):
957
- def __init__(self, config):
958
- super().__init__(config)
959
- super()._init_backbone(config)
960
-
961
- self.num_features = [
962
- config.hidden_size for _ in range(config.num_hidden_layers + 1)
963
- ]
964
- self.embeddings = Dinov2Embeddings(config)
965
- self.encoder = Dinov2Encoder(config)
966
-
967
- self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
968
-
969
- # Initialize weights and apply final processing
970
- self.post_init()
971
-
972
- def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
973
- return self.embeddings.patch_embeddings
974
-
975
- @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
976
- @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
977
- def forward(
978
- self,
979
- pixel_values: torch.Tensor,
980
- output_hidden_states: Optional[bool] = None,
981
- output_attentions: Optional[bool] = None,
982
- return_dict: Optional[bool] = None,
983
- ) -> BackboneOutput:
984
- """
985
- Returns:
986
-
987
- Examples:
988
-
989
- ```python
990
- >>> from transformers import AutoImageProcessor, AutoBackbone
991
- >>> import torch
992
- >>> from PIL import Image
993
- >>> import requests
994
-
995
- >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
996
- >>> image = Image.open(requests.get(url, stream=True).raw)
997
-
998
- >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
999
- >>> model = AutoBackbone.from_pretrained(
1000
- ... "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"]
1001
- ... )
1002
-
1003
- >>> inputs = processor(image, return_tensors="pt")
1004
-
1005
- >>> outputs = model(**inputs)
1006
- >>> feature_maps = outputs.feature_maps
1007
- >>> list(feature_maps[-1].shape)
1008
- [1, 768, 16, 16]
1009
- ```"""
1010
- return_dict = (
1011
- return_dict if return_dict is not None else self.config.use_return_dict
1012
- )
1013
- output_hidden_states = (
1014
- output_hidden_states
1015
- if output_hidden_states is not None
1016
- else self.config.output_hidden_states
1017
- )
1018
- output_attentions = (
1019
- output_attentions
1020
- if output_attentions is not None
1021
- else self.config.output_attentions
1022
- )
1023
-
1024
- embedding_output = self.embeddings(pixel_values)
1025
-
1026
- outputs = self.encoder(
1027
- embedding_output,
1028
- output_hidden_states=True,
1029
- output_attentions=output_attentions,
1030
- return_dict=return_dict,
1031
- )
1032
-
1033
- hidden_states = outputs.hidden_states if return_dict else outputs[1]
1034
-
1035
- feature_maps = ()
1036
- for stage, hidden_state in zip(self.stage_names, hidden_states):
1037
- if stage in self.out_features:
1038
- if self.config.apply_layernorm:
1039
- hidden_state = self.layernorm(hidden_state)
1040
- if self.config.reshape_hidden_states:
1041
- batch_size, _, height, width = pixel_values.shape
1042
- patch_size = self.config.patch_size
1043
- hidden_state = hidden_state[:, 1:, :].reshape(
1044
- batch_size, width // patch_size, height // patch_size, -1
1045
- )
1046
- hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
1047
- feature_maps += (hidden_state,)
1048
-
1049
- if not return_dict:
1050
- if output_hidden_states:
1051
- output = (feature_maps,) + outputs[1:]
1052
- else:
1053
- output = (feature_maps,) + outputs[2:]
1054
- return output
1055
-
1056
- return BackboneOutput(
1057
- feature_maps=feature_maps,
1058
- hidden_states=outputs.hidden_states if output_hidden_states else None,
1059
- attentions=outputs.attentions if output_attentions else None,
1060
- )
1061
-
1062
-
1063
- class CustomPatchEmbeddings(nn.Module):
1064
- """
1065
- This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
1066
- `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
1067
- Transformer.
1068
- """
1069
-
1070
- def __init__(
1071
- self, image_size: int, patch_size: int, num_channels: int, hidden_size: int
1072
- ):
1073
- super().__init__()
1074
-
1075
- image_size = (
1076
- image_size
1077
- if isinstance(image_size, collections.abc.Iterable)
1078
- else (image_size, image_size)
1079
- )
1080
- patch_size = (
1081
- patch_size
1082
- if isinstance(patch_size, collections.abc.Iterable)
1083
- else (patch_size, patch_size)
1084
- )
1085
- num_patches = (image_size[1] // patch_size[1]) * (
1086
- image_size[0] // patch_size[0]
1087
- )
1088
- self.image_size = image_size
1089
- self.patch_size = patch_size
1090
- self.num_channels = num_channels
1091
- self.num_patches = num_patches
1092
-
1093
- self.projection = nn.Conv2d(
1094
- num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
1095
- )
1096
-
1097
- def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
1098
- num_channels = pixel_values.shape[1]
1099
- if num_channels != self.num_channels:
1100
- raise ValueError(
1101
- "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
1102
- f" Expected {self.num_channels} but got {num_channels}."
1103
- )
1104
- embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
1105
- return embeddings
1106
-
1107
-
1108
- class CustomEmbeddings(nn.Module):
1109
- """
1110
- Construct the CLS token, mask token, position and patch embeddings.
1111
- """
1112
-
1113
- def __init__(
1114
- self, image_size: int, patch_size: int, num_channels: int, hidden_size: int
1115
- ) -> None:
1116
- super().__init__()
1117
-
1118
- self.image_size = image_size
1119
- self.patch_size = patch_size
1120
- self.num_channels = num_channels
1121
- self.hidden_size = hidden_size
1122
-
1123
- self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_size))
1124
-
1125
- self.patch_embeddings = CustomPatchEmbeddings(
1126
- image_size, patch_size, num_channels, hidden_size
1127
- )
1128
- num_patches = self.patch_embeddings.num_patches
1129
- self.position_embeddings = nn.Parameter(
1130
- torch.randn(1, num_patches + 1, self.hidden_size)
1131
- )
1132
-
1133
- def interpolate_pos_encoding(
1134
- self, embeddings: torch.Tensor, height: int, width: int
1135
- ) -> torch.Tensor:
1136
- """
1137
- This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
1138
- resolution images.
1139
-
1140
- Source:
1141
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
1142
- """
1143
-
1144
- num_patches = embeddings.shape[1] - 1
1145
- num_positions = self.position_embeddings.shape[1] - 1
1146
- if num_patches == num_positions and height == width:
1147
- return self.position_embeddings
1148
- class_pos_embed = self.position_embeddings[:, 0]
1149
- patch_pos_embed = self.position_embeddings[:, 1:]
1150
- dim = embeddings.shape[-1]
1151
- height = height // self.patch_size
1152
- width = width // self.patch_size
1153
- # we add a small number to avoid floating point error in the interpolation
1154
- # see discussion at https://github.com/facebookresearch/dino/issues/8
1155
- height, width = height + 0.1, width + 0.1
1156
- patch_pos_embed = patch_pos_embed.reshape(
1157
- 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
1158
- )
1159
- patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
1160
- patch_pos_embed = nn.functional.interpolate(
1161
- patch_pos_embed,
1162
- scale_factor=(
1163
- height / math.sqrt(num_positions),
1164
- width / math.sqrt(num_positions),
1165
- ),
1166
- mode="bicubic",
1167
- align_corners=False,
1168
- )
1169
- if (
1170
- int(height) != patch_pos_embed.shape[-2]
1171
- or int(width) != patch_pos_embed.shape[-1]
1172
- ):
1173
- raise ValueError(
1174
- "Width or height does not match with the interpolated position embeddings"
1175
- )
1176
- patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
1177
- return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
1178
-
1179
- def forward(
1180
- self,
1181
- pixel_values: torch.Tensor,
1182
- ) -> torch.Tensor:
1183
- batch_size, _, height, width = pixel_values.shape
1184
- patch_embeddings = self.patch_embeddings(pixel_values)
1185
- embeddings = patch_embeddings
1186
-
1187
- # add the [CLS] token to the embedded patch tokens
1188
- cls_tokens = self.cls_token.expand(batch_size, -1, -1)
1189
- embeddings = torch.cat((cls_tokens, embeddings), dim=1)
1190
-
1191
- # add positional encoding to each token
1192
- embeddings = embeddings + self.interpolate_pos_encoding(
1193
- embeddings, height, width
1194
- )
1195
-
1196
- return embeddings
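
Reviewer note on `dinov2.py`: this file closely tracks the upstream `transformers` DINOv2 implementation, with local tweaks such as the buffer-registered mask token, the SDPA fast path, the AdaLN modulation hooks (`register_ada_norm_modulation`), and `expand_input_channels`. The recurring trick worth a note is `interpolate_pos_encoding`: the learned patch position embeddings are reshaped back to their square grid and bicubically resized so the ViT can run at a different input resolution. A stripped-down sketch of just that resampling (hypothetical sizes, independent of the classes above, and using `size=` instead of the scale-factor form):

```python
import math
import torch
import torch.nn.functional as F

def resize_patch_pos_embed(pos_embed: torch.Tensor, new_h: int, new_w: int) -> torch.Tensor:
    """pos_embed: (1, N, dim) patch-only embeddings laid out on a sqrt(N) x sqrt(N) grid."""
    n, dim = pos_embed.shape[1], pos_embed.shape[2]
    side = int(math.sqrt(n))
    grid = pos_embed.reshape(1, side, side, dim).permute(0, 3, 1, 2)  # (1, dim, side, side)
    grid = F.interpolate(grid, size=(new_h, new_w), mode="bicubic", align_corners=False)
    return grid.permute(0, 2, 3, 1).reshape(1, new_h * new_w, dim)

# e.g. a 16x16 grid of 768-d embeddings resampled for a 20x30 patch grid
pos = torch.randn(1, 256, 768)
print(resize_patch_pos_embed(pos, 20, 30).shape)  # torch.Size([1, 600, 768])
```
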
sf3d/models/tokenizers/image.py DELETED
@@ -1,99 +0,0 @@
1
- from dataclasses import dataclass
2
- from typing import Optional
3
-
4
- import torch
5
- import torch.nn as nn
6
- from einops import rearrange
7
- from jaxtyping import Float
8
- from torch import Tensor
9
-
10
- from sf3d.models.tokenizers.dinov2 import Dinov2Model
11
- from sf3d.models.transformers.attention import Modulation
12
- from sf3d.models.utils import BaseModule
13
-
14
-
15
- class DINOV2SingleImageTokenizer(BaseModule):
16
- @dataclass
17
- class Config(BaseModule.Config):
18
- pretrained_model_name_or_path: str = "facebook/dinov2-large"
19
- width: int = 512
20
- height: int = 512
21
- modulation_cond_dim: int = 768
22
-
23
- cfg: Config
24
-
25
- def configure(self) -> None:
26
- self.model = Dinov2Model.from_pretrained(self.cfg.pretrained_model_name_or_path)
27
-
28
- for p in self.model.parameters():
29
- p.requires_grad_(False)
30
- self.model.eval()
31
-
32
- self.model.set_gradient_checkpointing(False)
33
-
34
- # add modulation
35
- modulations = []
36
- for layer in self.model.encoder.layer:
37
- norm1_modulation = Modulation(
38
- self.model.config.hidden_size,
39
- self.cfg.modulation_cond_dim,
40
- zero_init=True,
41
- single_layer=True,
42
- )
43
- norm2_modulation = Modulation(
44
- self.model.config.hidden_size,
45
- self.cfg.modulation_cond_dim,
46
- zero_init=True,
47
- single_layer=True,
48
- )
49
- layer.register_ada_norm_modulation(norm1_modulation, norm2_modulation)
50
- modulations += [norm1_modulation, norm2_modulation]
51
- self.modulations = nn.ModuleList(modulations)
52
-
53
- self.register_buffer(
54
- "image_mean",
55
- torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
56
- persistent=False,
57
- )
58
- self.register_buffer(
59
- "image_std",
60
- torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
61
- persistent=False,
62
- )
63
-
64
- def forward(
65
- self,
66
- images: Float[Tensor, "B *N C H W"],
67
- modulation_cond: Optional[Float[Tensor, "B *N Cc"]],
68
- **kwargs,
69
- ) -> Float[Tensor, "B *N Ct Nt"]:
70
- model = self.model
71
-
72
- packed = False
73
- if images.ndim == 4:
74
- packed = True
75
- images = images.unsqueeze(1)
76
- if modulation_cond is not None:
77
- assert modulation_cond.ndim == 2
78
- modulation_cond = modulation_cond.unsqueeze(1)
79
-
80
- batch_size, n_input_views = images.shape[:2]
81
- images = (images - self.image_mean) / self.image_std
82
- out = model(
83
- rearrange(images, "B N C H W -> (B N) C H W"),
84
- modulation_cond=rearrange(modulation_cond, "B N Cc -> (B N) Cc")
85
- if modulation_cond is not None
86
- else None,
87
- )
88
- local_features = out.last_hidden_state
89
- local_features = local_features.permute(0, 2, 1)
90
- local_features = rearrange(
91
- local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
92
- )
93
- if packed:
94
- local_features = local_features.squeeze(1)
95
-
96
- return local_features
97
-
98
- def detokenize(self, *args, **kwargs):
99
- raise NotImplementedError
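
For reference, a minimal usage sketch of the tokenizer removed above. The config values and tensor shapes are illustrative assumptions, not values taken from the released config:

import torch
from sf3d.models.tokenizers.image import DINOV2SingleImageTokenizer

# Hypothetical config; the real values live in the model's config.yaml
tokenizer = DINOV2SingleImageTokenizer(
    {"pretrained_model_name_or_path": "facebook/dinov2-large",
     "modulation_cond_dim": 768}
)
images = torch.rand(2, 3, 512, 512)   # [B, C, H, W] in [0, 1]; normalization happens inside
camera_cond = torch.rand(2, 768)      # per-image modulation condition
tokens = tokenizer(images, modulation_cond=camera_cond)  # -> [B, Ct, Nt]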
 
sf3d/models/tokenizers/triplane.py DELETED
@@ -1,49 +0,0 @@
1
- import math
2
- from dataclasses import dataclass
3
-
4
- import torch
5
- import torch.nn as nn
6
- from einops import rearrange, repeat
7
- from jaxtyping import Float
8
- from torch import Tensor
9
-
10
- from sf3d.models.utils import BaseModule
11
-
12
-
13
- class TriplaneLearnablePositionalEmbedding(BaseModule):
14
- @dataclass
15
- class Config(BaseModule.Config):
16
- plane_size: int = 96
17
- num_channels: int = 1024
18
-
19
- cfg: Config
20
-
21
- def configure(self) -> None:
22
- self.embeddings = nn.Parameter(
23
- torch.randn(
24
- (3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size),
25
- dtype=torch.float32,
26
- )
27
- * 1
28
- / math.sqrt(self.cfg.num_channels)
29
- )
30
-
31
- def forward(self, batch_size: int) -> Float[Tensor, "B Ct Nt"]:
32
- return rearrange(
33
- repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size),
34
- "B Np Ct Hp Wp -> B Ct (Np Hp Wp)",
35
- )
36
-
37
- def detokenize(
38
- self, tokens: Float[Tensor, "B Ct Nt"]
39
- ) -> Float[Tensor, "B 3 Ct Hp Wp"]:
40
- batch_size, Ct, Nt = tokens.shape
41
- assert Nt == self.cfg.plane_size**2 * 3
42
- assert Ct == self.cfg.num_channels
43
- return rearrange(
44
- tokens,
45
- "B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
46
- Np=3,
47
- Hp=self.cfg.plane_size,
48
- Wp=self.cfg.plane_size,
49
- )
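
A quick shape check of the learnable triplane tokens removed above; tokenize/detokenize is a pure reshape, so the round trip recovers the three planes (plane size and channel count are the class defaults):

import torch
from sf3d.models.tokenizers.triplane import TriplaneLearnablePositionalEmbedding

tri = TriplaneLearnablePositionalEmbedding({"plane_size": 96, "num_channels": 1024})
tokens = tri(batch_size=2)        # [B, Ct, 3 * Hp * Wp] = [2, 1024, 27648]
planes = tri.detokenize(tokens)   # [B, 3, Ct, Hp, Wp]
assert planes.shape == (2, 3, 1024, 96, 96)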
 
sf3d/models/transformers/attention.py DELETED
@@ -1,31 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
-
4
-
5
- class Modulation(nn.Module):
6
- def __init__(
7
- self,
8
- embedding_dim: int,
9
- condition_dim: int,
10
- zero_init: bool = False,
11
- single_layer: bool = False,
12
- ):
13
- super().__init__()
14
- self.silu = nn.SiLU()
15
- if single_layer:
16
- self.linear1 = nn.Identity()
17
- else:
18
- self.linear1 = nn.Linear(condition_dim, condition_dim)
19
-
20
- self.linear2 = nn.Linear(condition_dim, embedding_dim * 2)
21
-
22
- # Only zero init the last linear layer
23
- if zero_init:
24
- nn.init.zeros_(self.linear2.weight)
25
- nn.init.zeros_(self.linear2.bias)
26
-
27
- def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
28
- emb = self.linear2(self.silu(self.linear1(condition)))
29
- scale, shift = torch.chunk(emb, 2, dim=1)
30
- x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
31
- return x
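
The Modulation block removed above is adaLN-style conditioning: the condition is projected to a scale and a shift, and the tokens are transformed as x * (1 + scale) + shift. A small sketch with illustrative dimensions; with zero_init=True the block starts out as an identity:

import torch
from sf3d.models.transformers.attention import Modulation

mod = Modulation(embedding_dim=1024, condition_dim=768, zero_init=True, single_layer=True)
x = torch.randn(2, 257, 1024)   # token sequence [B, N, C]
cond = torch.randn(2, 768)      # per-sample condition [B, Cc]
out = mod(x, cond)              # zero-init => identical to x at initialization
assert torch.allclose(out, x)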
 
sf3d/models/transformers/backbone.py DELETED
@@ -1,515 +0,0 @@
1
- from dataclasses import dataclass
2
- from typing import Optional
3
-
4
- import torch
5
- import torch.nn.functional as F
6
- from torch import nn
7
-
8
- from sf3d.models.utils import BaseModule
9
-
10
-
11
- class GEGLU(nn.Module):
12
- r"""
13
- A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
14
-
15
- Parameters:
16
- dim_in (`int`): The number of channels in the input.
17
- dim_out (`int`): The number of channels in the output.
18
- """
19
-
20
- def __init__(self, dim_in: int, dim_out: int):
21
- super().__init__()
22
- self.proj = nn.Linear(dim_in, dim_out * 2)
23
-
24
- def gelu(self, gate: torch.Tensor) -> torch.Tensor:
25
- if gate.device.type != "mps":
26
- return F.gelu(gate)
27
- # mps: gelu is not implemented for float16
28
- return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
29
-
30
- def forward(self, hidden_states, scale: float = 1.0):
31
- args = ()
32
- hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
33
- return hidden_states * self.gelu(gate)
34
-
35
-
36
- class CrossAttention(nn.Module):
37
- def __init__(
38
- self,
39
- dim,
40
- kv_dim=None,
41
- num_heads=16,
42
- qkv_bias=False,
43
- attn_drop=0.0,
44
- proj_drop=0.0,
45
- ):
46
- super().__init__()
47
- self.num_heads = num_heads
48
- head_dim = dim // num_heads
49
- self.scale = head_dim**-0.5
50
- kv_dim = dim if not kv_dim else kv_dim
51
- self.wq = nn.Linear(dim, dim, bias=qkv_bias)
52
- self.wk = nn.Linear(kv_dim, dim, bias=qkv_bias)
53
- self.wv = nn.Linear(kv_dim, dim, bias=qkv_bias)
54
- self.attn_drop = attn_drop
55
- self.proj = nn.Linear(dim, dim)
56
- self.proj_drop = nn.Dropout(proj_drop)
57
-
58
- def forward(self, x_q, x_kv):
59
- B, N_q, C = x_q.shape
60
- B, N_kv, _ = x_kv.shape
61
- # [B, N_q, C] -> [B, N_q, H, C/H]
62
- q = self.wq(x_q).reshape(B, N_q, self.num_heads, C // self.num_heads)
63
- # [B, N_kv, C] -> [B, N_kv, H, C/H]
64
- k = self.wk(x_kv).reshape(B, N_kv, self.num_heads, C // self.num_heads)
65
- v = self.wv(x_kv).reshape(B, N_kv, self.num_heads, C // self.num_heads)
66
-
67
- # attention
68
- x = torch.nn.functional.scaled_dot_product_attention(
69
- q.permute(0, 2, 1, 3),
70
- k.permute(0, 2, 1, 3),
71
- v.permute(0, 2, 1, 3),
72
- attn_mask=None,
73
- dropout_p=self.attn_drop,
74
- scale=self.scale,
75
- ).permute(0, 2, 1, 3)
76
-
77
- # [B, N_q, H, C/H] -> [B, N_q, C]
78
- x = x.reshape(B, N_q, C)
79
- x = self.proj(x)
80
- x = self.proj_drop(x)
81
- return x
82
-
83
-
84
- class FeedForward(nn.Module):
85
- def __init__(
86
- self,
87
- dim: int,
88
- dim_out: Optional[int] = None,
89
- mult: int = 4,
90
- dropout: float = 0.0,
91
- ):
92
- super().__init__()
93
- inner_dim = int(dim * mult)
94
- dim_out = dim_out if dim_out is not None else dim
95
- act_fn = GEGLU(dim, inner_dim)
96
- self.net = nn.ModuleList([])
97
- self.net.append(act_fn)
98
- self.net.append(nn.Dropout(dropout))
99
- self.net.append(nn.Linear(inner_dim, dim_out))
100
-
101
- def forward(self, x: torch.Tensor) -> torch.Tensor:
102
- for module in self.net:
103
- x = module(x)
104
- return x
105
-
106
-
107
- class BasicBlock(nn.Module):
108
- def __init__(
109
- self,
110
- dim: int,
111
- kv_dim: Optional[int] = None,
112
- num_heads: int = 16,
113
- qkv_bias: bool = False,
114
- attn_drop: float = 0.0,
115
- proj_drop: float = 0.0,
116
- ff_drop: float = 0.0,
117
- ):
118
- super().__init__()
119
- self.norm1 = nn.LayerNorm(dim)
120
- self.attn1 = CrossAttention(
121
- dim,
122
- kv_dim=dim,
123
- num_heads=num_heads,
124
- qkv_bias=qkv_bias,
125
- attn_drop=attn_drop,
126
- proj_drop=proj_drop,
127
- )
128
- self.norm2 = nn.LayerNorm(dim)
129
- self.attn2 = CrossAttention(
130
- dim,
131
- kv_dim=kv_dim,
132
- num_heads=num_heads,
133
- qkv_bias=qkv_bias,
134
- attn_drop=attn_drop,
135
- proj_drop=proj_drop,
136
- )
137
- self.norm3 = nn.LayerNorm(dim)
138
- self.ff = FeedForward(dim, dropout=ff_drop)
139
-
140
- def forward(self, z, x):
141
- z_norm = self.norm1(z)
142
- z = z + self.attn1(z_norm, z_norm)
143
- # TODO: do we need to have the second attention when x is None?
144
- z_norm = self.norm2(z)
145
- z = z + self.attn2(z_norm, x if x is not None else z_norm)
146
- z_norm = self.norm3(z)
147
- z = z + self.ff(z_norm)
148
- return z
149
-
150
-
151
- class SingleStreamTransformer(BaseModule):
152
- @dataclass
153
- class Config(BaseModule.Config):
154
- num_attention_heads: int = 16
155
- attention_head_dim: int = 88
156
- in_channels: Optional[int] = None
157
- out_channels: Optional[int] = None
158
- num_layers: int = 16
159
- dropout: float = 0.0
160
- norm_num_groups: int = 32
161
- cross_attention_dim: Optional[int] = None
162
- attention_bias: bool = False
163
-
164
- cfg: Config
165
-
166
- def configure(self) -> None:
167
- self.num_attention_heads = self.cfg.num_attention_heads
168
- self.attention_head_dim = self.cfg.attention_head_dim
169
- inner_dim = self.num_attention_heads * self.attention_head_dim
170
-
171
- # Define input layers
172
- self.norm = torch.nn.GroupNorm(
173
- num_groups=self.cfg.norm_num_groups,
174
- num_channels=self.cfg.in_channels,
175
- eps=1e-6,
176
- affine=True,
177
- )
178
- self.proj_in = nn.Linear(self.cfg.in_channels, inner_dim)
179
-
180
- # Define transformers blocks
181
- self.transformer_blocks = nn.ModuleList(
182
- [
183
- BasicBlock(
184
- inner_dim,
185
- kv_dim=self.cfg.cross_attention_dim,
186
- num_heads=self.num_attention_heads,
187
- qkv_bias=self.cfg.attention_bias,
188
- proj_drop=self.cfg.dropout,
189
- ff_drop=self.cfg.dropout,
190
- )
191
- for d in range(self.cfg.num_layers)
192
- ]
193
- )
194
-
195
- # Define output layers
196
- self.proj_out = nn.Linear(inner_dim, self.cfg.in_channels)
197
-
198
- def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
199
- residual = hidden_states
200
- hidden_states = self.norm(hidden_states)
201
- hidden_states = hidden_states.permute(0, 2, 1)
202
- hidden_states = self.proj_in(hidden_states)
203
- for block in self.transformer_blocks:
204
- hidden_states = block(hidden_states, encoder_hidden_states)
205
- hidden_states = self.proj_out(hidden_states).permute(0, 2, 1).contiguous()
206
- # TODO: do we really need to add the residual?
207
- hidden_states = hidden_states + residual
208
- return hidden_states
209
-
210
-
211
- class FuseBlock(nn.Module):
212
- """
213
- Fuse X into Z with cross-attention
214
- """
215
-
216
- def __init__(
217
- self,
218
- dim_z: int,
219
- dim_x: int,
220
- num_heads: int = 16,
221
- qkv_bias: bool = False,
222
- attn_drop: float = 0.0,
223
- proj_drop: float = 0.0,
224
- ff_drop: float = 0.0,
225
- norm_x_input: bool = True,
226
- ):
227
- super().__init__()
228
- self.norm_x_input = norm_x_input
229
- if self.norm_x_input:
230
- self.norm_x = nn.LayerNorm(dim_x)
231
- self.attn = CrossAttention(
232
- dim_z,
233
- kv_dim=dim_x,
234
- num_heads=num_heads,
235
- qkv_bias=qkv_bias,
236
- attn_drop=attn_drop,
237
- proj_drop=proj_drop,
238
- )
239
- self.norm_z1 = nn.LayerNorm(dim_z)
240
- self.norm_z2 = nn.LayerNorm(dim_z)
241
- self.ff = FeedForward(dim_z, dropout=ff_drop)
242
-
243
- def forward(self, z, x):
244
- # TODO: do we need to normalize x?
245
- z = z + self.attn(self.norm_z1(z), self.norm_x(x) if self.norm_x_input else x)
246
- z = z + self.ff(self.norm_z2(z))
247
- return z
248
-
249
-
250
- @torch.no_grad()
251
- def get_triplane_attention_mask(res):
252
- N = 3 * res * res
253
- attn_mask = torch.zeros(3, res, res, 3, res, res)
254
-
255
- i, j = torch.meshgrid(torch.arange(res), torch.arange(res))
256
-
257
- attn_mask[0, i, j, 1, i, :] = 1.0
258
- attn_mask[0, i, j, 2, j, :] = 1.0
259
- attn_mask[1, i, j, 0, i, :] = 1.0
260
- attn_mask[1, i, j, 2, :, j] = 1.0
261
- attn_mask[2, i, j, 0, :, i] = 1.0
262
- attn_mask[2, i, j, 1, :, j] = 1.0
263
- attn_mask = attn_mask.bool()
264
-
265
- attn_bias = torch.empty_like(attn_mask, dtype=torch.float)
266
- attn_bias.masked_fill_(attn_mask, 0.0)
267
- attn_bias.masked_fill_(~attn_mask, float("-inf"))
268
-
269
- return attn_bias.reshape(N, N)
270
-
271
-
272
- class TriplaneAttention(nn.Module):
273
- def __init__(
274
- self,
275
- dim: int,
276
- resolution: int,
277
- num_heads: int = 16,
278
- qkv_bias: bool = False,
279
- attn_drop: float = 0.0,
280
- proj_drop: float = 0.0,
281
- full_attention: bool = False,
282
- ):
283
- super().__init__()
284
- self.num_heads = num_heads
285
- head_dim = dim // num_heads
286
- self.scale = head_dim**-0.5
287
- self.wq = nn.Linear(dim, dim, bias=qkv_bias)
288
- self.wk = nn.Linear(dim, dim, bias=qkv_bias)
289
- self.wv = nn.Linear(dim, dim, bias=qkv_bias)
290
- self.attn_drop = attn_drop
291
- self.proj = nn.Linear(dim, dim)
292
- self.proj_drop = nn.Dropout(proj_drop)
293
-
294
- self.resolution = resolution
295
- self.full_attention = full_attention
296
- self.attn_mask = (
297
- get_triplane_attention_mask(resolution) if not full_attention else None
298
- )
299
-
300
- def forward(self, x):
301
- B, N, C = x.shape
302
- # [B, N, C] -> [B, N, H, C/H]
303
- q = self.wq(x).reshape(B, N, self.num_heads, C // self.num_heads)
304
- k = self.wk(x).reshape(B, N, self.num_heads, C // self.num_heads)
305
- v = self.wv(x).reshape(B, N, self.num_heads, C // self.num_heads)
306
-
307
- # detokenize the planes
308
- assert N == self.resolution**2 * 3
309
- attn_bias = (
310
- self.attn_mask.to(q)
311
- .unsqueeze(0)
312
- .unsqueeze(0)
313
- .expand(B, self.num_heads, -1, -1)
314
- if not self.full_attention
315
- else None
316
- )
317
-
318
- # full attention
319
- x = torch.nn.functional.scaled_dot_product_attention(
320
- q.permute(0, 2, 1, 3),
321
- k.permute(0, 2, 1, 3),
322
- v.permute(0, 2, 1, 3),
323
- attn_mask=attn_bias,
324
- dropout_p=self.attn_drop,
325
- scale=self.scale,
326
- ).permute(0, 2, 1, 3)
327
-
328
- # [B, N_q, H, C/H] -> [B, N_q, C]
329
- x = x.reshape(B, N, C)
330
- x = self.proj(x)
331
- x = self.proj_drop(x)
332
- return x
333
-
334
-
335
- class TwoStreamBlock(nn.Module):
336
- def __init__(
337
- self,
338
- dim_latent: int,
339
- dim_input: int,
340
- num_basic_blocks: int = 4,
341
- num_heads: int = 16,
342
- qkv_bias: bool = False,
343
- attn_drop: float = 0.0,
344
- proj_drop: float = 0.0,
345
- ff_drop: float = 0.0,
346
- norm_x_input: bool = True,
347
- dim_cross: Optional[int] = None,
348
- ):
349
- super().__init__()
350
-
351
- # Define the fuse block that fuse the input into the latent
352
- self.fuse_block_in = FuseBlock(
353
- dim_latent,
354
- dim_input,
355
- num_heads=num_heads,
356
- qkv_bias=qkv_bias,
357
- attn_drop=attn_drop,
358
- proj_drop=proj_drop,
359
- ff_drop=ff_drop,
360
- norm_x_input=norm_x_input,
361
- )
362
-
363
- # Define the transformer block that process the latent
364
- self.transformer_block = nn.ModuleList(
365
- [
366
- BasicBlock(
367
- dim_latent,
368
- kv_dim=dim_cross,
369
- num_heads=num_heads,
370
- qkv_bias=qkv_bias,
371
- proj_drop=proj_drop,
372
- ff_drop=ff_drop,
373
- )
374
- for _ in range(num_basic_blocks)
375
- ]
376
- )
377
-
378
- # Define the fuse block that fuse the latent into the input
379
- self.fuse_block_out = FuseBlock(
380
- dim_input,
381
- dim_latent,
382
- num_heads=num_heads,
383
- qkv_bias=qkv_bias,
384
- attn_drop=attn_drop,
385
- proj_drop=proj_drop,
386
- ff_drop=ff_drop,
387
- norm_x_input=norm_x_input,
388
- )
389
-
390
- def forward(self, latent, input, cross_input):
391
- latent = self.fuse_block_in(latent, input)
392
- for block in self.transformer_block:
393
- latent = block(latent, cross_input)
394
- input = self.fuse_block_out(input, latent)
395
- return latent, input
396
-
397
-
398
- class TwoStreamInterleaveTransformer(BaseModule):
399
- @dataclass
400
- class Config(BaseModule.Config):
401
- num_attention_heads: int = 16
402
- attention_head_dim: int = 64
403
- raw_triplane_channels: int = 1024
404
- triplane_channels: int = 1024
405
- raw_image_channels: int = 1024
406
- num_latents: int = 1792
407
- num_blocks: int = 4
408
- num_basic_blocks: int = 3
409
- dropout: float = 0.0
410
- latent_init_std: float = 0.02
411
- norm_num_groups: int = 32
412
- attention_bias: bool = False
413
- norm_x_input: bool = False
414
- cross_attention_dim: int = 1024
415
- mix_latent: bool = True
416
-
417
- cfg: Config
418
-
419
- def configure(self) -> None:
420
- self.mix_latent = self.cfg.mix_latent
421
-
422
- # Define the dimensions
423
- self.num_attention_heads = self.cfg.num_attention_heads
424
- self.attention_head_dim = self.cfg.attention_head_dim
425
- self.num_latents = self.cfg.num_latents
426
- self.latent_dim = self.num_attention_heads * self.attention_head_dim
427
-
428
- # Define input layers
429
- if self.cfg.norm_num_groups > 0:
430
- self.norm_triplane = torch.nn.GroupNorm(
431
- num_groups=self.cfg.norm_num_groups,
432
- num_channels=self.cfg.raw_triplane_channels,
433
- eps=1e-6,
434
- affine=True,
435
- )
436
- else:
437
- self.norm_triplane = nn.LayerNorm(self.cfg.raw_triplane_channels)
438
- self.proj_triplane = nn.Linear(
439
- self.cfg.raw_triplane_channels, self.cfg.triplane_channels
440
- )
441
- if self.mix_latent:
442
- self.norm_image = nn.LayerNorm(self.cfg.raw_image_channels)
443
- self.proj_image = nn.Linear(self.cfg.raw_image_channels, self.latent_dim)
444
- self.norm_latent = nn.LayerNorm(self.latent_dim)
445
- self.proj_latent = nn.Linear(self.latent_dim, self.latent_dim)
446
-
447
- # Define the latents
448
- self.latent_init = nn.Parameter(
449
- torch.zeros(1, self.num_latents, self.latent_dim)
450
- )
451
- nn.init.normal_(self.latent_init, std=self.cfg.latent_init_std)
452
-
453
- # Define the transformer blocks
454
- self.main_blocks = nn.ModuleList(
455
- [
456
- TwoStreamBlock(
457
- self.latent_dim,
458
- self.cfg.triplane_channels,
459
- num_basic_blocks=self.cfg.num_basic_blocks,
460
- num_heads=self.num_attention_heads,
461
- qkv_bias=self.cfg.attention_bias,
462
- proj_drop=self.cfg.dropout,
463
- ff_drop=self.cfg.dropout,
464
- norm_x_input=self.cfg.norm_x_input,
465
- dim_cross=self.cfg.cross_attention_dim,
466
- )
467
- for _ in range(self.cfg.num_blocks)
468
- ]
469
- )
470
-
471
- # Define output layers
472
- self.proj_out = nn.Linear(
473
- self.cfg.triplane_channels, self.cfg.raw_triplane_channels
474
- )
475
-
476
- def forward(self, hidden_states, encoder_hidden_states, **kwargs):
477
- # hidden_states: [B, triplane_dim, N_triplane] is triplane tokens
478
- # encoder_hidden_states: [B, N_image, image_dim] is the image tokens
479
- if isinstance(self.norm_triplane, nn.GroupNorm):
480
- triplane_tokens = self.norm_triplane(hidden_states)
481
- triplane_tokens = triplane_tokens.permute(
482
- 0, 2, 1
483
- ) # [B, N_triplane, triplane_dim]
484
- elif isinstance(self.norm_triplane, nn.LayerNorm):
485
- triplane_tokens = self.norm_triplane(hidden_states.permute(0, 2, 1))
486
- else:
487
- raise ValueError("Unknown normalization layer")
488
- triplane_tokens = self.proj_triplane(triplane_tokens)
489
- if self.mix_latent:
490
- image_tokens = self.norm_image(
491
- encoder_hidden_states
492
- ) # [B, N_image, image_dim]
493
- image_tokens = self.proj_image(image_tokens)
494
- init_latents = self.latent_init.expand(
495
- hidden_states.shape[0], -1, -1
496
- ) # [B, N_latent_init, latent_dim]
497
- init_latents = self.norm_latent(init_latents)
498
- init_latents = self.proj_latent(init_latents)
499
- if self.mix_latent:
500
- latent_tokens = torch.cat(
501
- [image_tokens, init_latents], dim=1
502
- ) # [B, N_latent, latent_dim]
503
- else:
504
- latent_tokens = init_latents
505
-
506
- # forward the main blocks
507
- for block in self.main_blocks:
508
- latent_tokens, triplane_tokens = block(
509
- latent_tokens, triplane_tokens, encoder_hidden_states
510
- )
511
-
512
- # project the triplane tokens back to the original dimension
513
- triplane_tokens = self.proj_out(triplane_tokens).permute(0, 2, 1).contiguous()
514
- triplane_tokens = triplane_tokens + hidden_states
515
- return triplane_tokens
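
get_triplane_attention_mask above only allows a token to attend to tokens in the other two planes that share its row or column; all other pairs, including same-plane pairs, receive a -inf additive bias. A quick check at a small, illustrative resolution:

import torch
from sf3d.models.transformers.backbone import get_triplane_attention_mask

res = 8
bias = get_triplane_attention_mask(res)      # [3*res*res, 3*res*res] additive attention bias
assert bias.shape == (3 * res * res, 3 * res * res)
print(torch.isfinite(bias).float().mean())   # fraction of token pairs allowed to attend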
 
sf3d/models/utils.py DELETED
@@ -1,292 +0,0 @@
1
- import dataclasses
2
- import importlib
3
- import math
4
- from dataclasses import dataclass
5
- from typing import Any, List, Optional, Tuple, Union
6
-
7
- import numpy as np
8
- import PIL
9
- import torch
10
- import torch.nn as nn
11
- import torch.nn.functional as F
12
- from jaxtyping import Bool, Float, Int, Num
13
- from omegaconf import DictConfig, OmegaConf
14
- from torch import Tensor
15
-
16
-
17
- class BaseModule(nn.Module):
18
- @dataclass
19
- class Config:
20
- pass
21
-
22
- cfg: Config # add this to every subclass of BaseModule to enable static type checking
23
-
24
- def __init__(
25
- self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
26
- ) -> None:
27
- super().__init__()
28
- self.cfg = parse_structured(self.Config, cfg)
29
- self.configure(*args, **kwargs)
30
-
31
- def configure(self, *args, **kwargs) -> None:
32
- raise NotImplementedError
33
-
34
-
35
- def find_class(cls_string):
36
- module_string = ".".join(cls_string.split(".")[:-1])
37
- cls_name = cls_string.split(".")[-1]
38
- module = importlib.import_module(module_string, package=None)
39
- cls = getattr(module, cls_name)
40
- return cls
41
-
42
-
43
- def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
44
- # Check if cfg.keys are in fields
45
- cfg_ = cfg.copy()
46
- keys = list(cfg_.keys())
47
-
48
- field_names = {f.name for f in dataclasses.fields(fields)}
49
- for key in keys:
50
- # This is helpful when swapping out modules from CLI
51
- if key not in field_names:
52
- print(f"Ignoring {key} as it's not supported by {fields}")
53
- cfg_.pop(key)
54
- scfg = OmegaConf.merge(OmegaConf.structured(fields), cfg_)
55
- return scfg
56
-
57
-
58
- EPS_DTYPE = {
59
- torch.float16: 1e-4,
60
- torch.bfloat16: 1e-4,
61
- torch.float32: 1e-7,
62
- torch.float64: 1e-8,
63
- }
64
-
65
-
66
- def dot(x, y, dim=-1):
67
- return torch.sum(x * y, dim, keepdim=True)
68
-
69
-
70
- def reflect(x, n):
71
- return x - 2 * dot(x, n) * n
72
-
73
-
74
- def normalize(x, dim=-1, eps=None):
75
- if eps is None:
76
- eps = EPS_DTYPE[x.dtype]
77
- return F.normalize(x, dim=dim, p=2, eps=eps)
78
-
79
-
80
- def tri_winding(tri: Float[Tensor, "*B 3 2"]) -> Float[Tensor, "*B 3 3"]:
81
- # One pad for determinant
82
- tri_sq = F.pad(tri, (0, 1), "constant", 1.0)
83
- det_tri = torch.det(tri_sq)
84
- tri_rev = torch.cat(
85
- (tri_sq[..., 0:1, :], tri_sq[..., 2:3, :], tri_sq[..., 1:2, :]), -2
86
- )
87
- tri_sq[det_tri < 0] = tri_rev[det_tri < 0]
88
- return tri_sq
89
-
90
-
91
- def triangle_intersection_2d(
92
- t1: Float[Tensor, "*B 3 2"],
93
- t2: Float[Tensor, "*B 3 2"],
94
- eps=1e-12,
95
- ) -> Float[Tensor, "*B"]: # noqa: F821
96
- """Returns True if triangles collide, False otherwise"""
97
-
98
- def chk_edge(x: Float[Tensor, "*B 3 3"]) -> Bool[Tensor, "*B"]: # noqa: F821
99
- logdetx = torch.logdet(x.double())
100
- if eps is None:
101
- return ~torch.isfinite(logdetx)
102
- return ~(torch.isfinite(logdetx) & (logdetx > math.log(eps)))
103
-
104
- t1s = tri_winding(t1)
105
- t2s = tri_winding(t2)
106
-
107
- # Assume the triangles do not collide in the beginning
108
- ret = torch.zeros(t1.shape[0], dtype=torch.bool, device=t1.device)
109
- for i in range(3):
110
- edge = torch.roll(t1s, i, dims=1)[:, :2, :]
111
- # Check if all points of triangle 2 lay on the external side of edge E.
112
- # If this is the case the triangle do not collide
113
- upd = (
114
- chk_edge(torch.cat((edge, t2s[:, 0:1]), 1))
115
- & chk_edge(torch.cat((edge, t2s[:, 1:2]), 1))
116
- & chk_edge(torch.cat((edge, t2s[:, 2:3]), 1))
117
- )
118
- # Here no collision is still True due to inversion
119
- ret = ret | upd
120
-
121
- for i in range(3):
122
- edge = torch.roll(t2s, i, dims=1)[:, :2, :]
123
-
124
- upd = (
125
- chk_edge(torch.cat((edge, t1s[:, 0:1]), 1))
126
- & chk_edge(torch.cat((edge, t1s[:, 1:2]), 1))
127
- & chk_edge(torch.cat((edge, t1s[:, 2:3]), 1))
128
- )
129
- # Here no collision is still True due to inversion
130
- ret = ret | upd
131
-
132
- return ~ret # Do the inversion
133
-
134
-
135
- ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]]
136
-
137
-
138
- def scale_tensor(
139
- dat: Num[Tensor, "... D"], inp_scale: ValidScale, tgt_scale: ValidScale
140
- ):
141
- if inp_scale is None:
142
- inp_scale = (0, 1)
143
- if tgt_scale is None:
144
- tgt_scale = (0, 1)
145
- if isinstance(tgt_scale, Tensor):
146
- assert dat.shape[-1] == tgt_scale.shape[-1]
147
- dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
148
- dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
149
- return dat
150
-
151
-
152
- def dilate_fill(img, mask, iterations=10):
153
- oldMask = mask.float()
154
- oldImg = img
155
-
156
- mask_kernel = torch.ones(
157
- (1, 1, 3, 3),
158
- dtype=oldMask.dtype,
159
- device=oldMask.device,
160
- )
161
-
162
- for i in range(iterations):
163
- newMask = torch.nn.functional.max_pool2d(oldMask, 3, 1, 1)
164
-
165
- # Fill the extension with mean color of old valid regions
166
- img_unfold = F.unfold(oldImg, (3, 3)).view(1, 3, 3 * 3, -1)
167
- mask_unfold = F.unfold(oldMask, (3, 3)).view(1, 1, 3 * 3, -1)
168
- new_mask_unfold = F.unfold(newMask, (3, 3)).view(1, 1, 3 * 3, -1)
169
-
170
- # Average color of the valid region
171
- mean_color = (img_unfold.sum(dim=2) / mask_unfold.sum(dim=2).clip(1)).unsqueeze(
172
- 2
173
- )
174
- # Extend it to the new region
175
- fill_color = (mean_color * new_mask_unfold).view(1, 3 * 3 * 3, -1)
176
-
177
- mask_conv = F.conv2d(
178
- newMask, mask_kernel, padding=1
179
- ) # Get the sum for each kernel patch
180
- newImg = F.fold(
181
- fill_color, (img.shape[-2], img.shape[-1]), (3, 3)
182
- ) / mask_conv.clamp(1)
183
-
184
- diffMask = newMask - oldMask
185
-
186
- oldMask = newMask
187
- oldImg = torch.lerp(oldImg, newImg, diffMask)
188
-
189
- return oldImg
190
-
191
-
192
- def float32_to_uint8_np(
193
- x: Float[np.ndarray, "*B H W C"],
194
- dither: bool = True,
195
- dither_mask: Optional[Float[np.ndarray, "*B H W C"]] = None,
196
- dither_strength: float = 1.0,
197
- ) -> Int[np.ndarray, "*B H W C"]:
198
- if dither:
199
- dither = (
200
- dither_strength * np.random.rand(*x[..., :1].shape).astype(np.float32) - 0.5
201
- )
202
- if dither_mask is not None:
203
- dither = dither * dither_mask
204
- return np.clip(np.floor((256.0 * x + dither)), 0, 255).astype(np.uint8)
205
- return np.clip(np.floor((256.0 * x)), 0, 255).astype(np.uint8)
206
-
207
-
208
- def convert_data(data):
209
- if data is None:
210
- return None
211
- elif isinstance(data, np.ndarray):
212
- return data
213
- elif isinstance(data, torch.Tensor):
214
- if data.dtype in [torch.float16, torch.bfloat16]:
215
- data = data.float()
216
- return data.detach().cpu().numpy()
217
- elif isinstance(data, list):
218
- return [convert_data(d) for d in data]
219
- elif isinstance(data, dict):
220
- return {k: convert_data(v) for k, v in data.items()}
221
- else:
222
- raise TypeError(
223
- "Data must be in type numpy.ndarray, torch.Tensor, list or dict, getting",
224
- type(data),
225
- )
226
-
227
-
228
- class ImageProcessor:
229
- def convert_and_resize(
230
- self,
231
- image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
232
- size: int,
233
- ):
234
- if isinstance(image, PIL.Image.Image):
235
- image = torch.from_numpy(np.array(image).astype(np.float32) / 255.0)
236
- elif isinstance(image, np.ndarray):
237
- if image.dtype == np.uint8:
238
- image = torch.from_numpy(image.astype(np.float32) / 255.0)
239
- else:
240
- image = torch.from_numpy(image)
241
- elif isinstance(image, torch.Tensor):
242
- pass
243
-
244
- batched = image.ndim == 4
245
-
246
- if not batched:
247
- image = image[None, ...]
248
- image = F.interpolate(
249
- image.permute(0, 3, 1, 2),
250
- (size, size),
251
- mode="bilinear",
252
- align_corners=False,
253
- antialias=True,
254
- ).permute(0, 2, 3, 1)
255
- if not batched:
256
- image = image[0]
257
- return image
258
-
259
- def __call__(
260
- self,
261
- image: Union[
262
- PIL.Image.Image,
263
- np.ndarray,
264
- torch.FloatTensor,
265
- List[PIL.Image.Image],
266
- List[np.ndarray],
267
- List[torch.FloatTensor],
268
- ],
269
- size: int,
270
- ) -> Any:
271
- if isinstance(image, (np.ndarray, torch.FloatTensor)) and image.ndim == 4:
272
- image = self.convert_and_resize(image, size)
273
- else:
274
- if not isinstance(image, list):
275
- image = [image]
276
- image = [self.convert_and_resize(im, size) for im in image]
277
- image = torch.stack(image, dim=0)
278
- return image
279
-
280
-
281
- def get_intrinsic_from_fov(fov, H, W, bs=-1):
282
- focal_length = 0.5 * H / np.tan(0.5 * fov)
283
- intrinsic = np.identity(3, dtype=np.float32)
284
- intrinsic[0, 0] = focal_length
285
- intrinsic[1, 1] = focal_length
286
- intrinsic[0, 2] = W / 2.0
287
- intrinsic[1, 2] = H / 2.0
288
-
289
- if bs > 0:
290
- intrinsic = intrinsic[None].repeat(bs, axis=0)
291
-
292
- return torch.from_numpy(intrinsic)
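
Two of the helpers above show up throughout the pipeline; a short sketch of their behaviour (the inputs are illustrative):

import numpy as np
import torch
from sf3d.models.utils import get_intrinsic_from_fov, scale_tensor

# Map points from the tet grid's [0, 1] range into a [-1, 1] bounding box
pts = torch.rand(16, 3)
pts_bbox = scale_tensor(pts, (0.0, 1.0), (-1.0, 1.0))

# Pinhole intrinsics for a 40 degree FOV at 512x512:
# fx = fy = 0.5 * H / tan(fov / 2), principal point at the image center
K = get_intrinsic_from_fov(np.deg2rad(40.0), H=512, W=512)
assert K.shape == (3, 3)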
 
sf3d/system.py DELETED
@@ -1,482 +0,0 @@
1
- import os
2
- from dataclasses import dataclass, field
3
- from typing import Any, List, Optional, Tuple
4
-
5
- import numpy as np
6
- import torch
7
- import torch.nn.functional as F
8
- import trimesh
9
- from einops import rearrange
10
- from huggingface_hub import hf_hub_download
11
- from jaxtyping import Float
12
- from omegaconf import OmegaConf
13
- from PIL import Image
14
- from safetensors.torch import load_model
15
- from torch import Tensor
16
-
17
- from sf3d.models.isosurface import MarchingTetrahedraHelper
18
- from sf3d.models.mesh import Mesh
19
- from sf3d.models.utils import (
20
- BaseModule,
21
- ImageProcessor,
22
- convert_data,
23
- dilate_fill,
24
- dot,
25
- find_class,
26
- float32_to_uint8_np,
27
- normalize,
28
- scale_tensor,
29
- )
30
- from sf3d.utils import create_intrinsic_from_fov_deg, default_cond_c2w
31
-
32
- from .texture_baker import TextureBaker
33
-
34
-
35
- class SF3D(BaseModule):
36
- @dataclass
37
- class Config(BaseModule.Config):
38
- cond_image_size: int
39
- isosurface_resolution: int
40
- isosurface_threshold: float = 10.0
41
- radius: float = 1.0
42
- background_color: list[float] = field(default_factory=lambda: [0.5, 0.5, 0.5])
43
- default_fovy_deg: float = 40.0
44
- default_distance: float = 1.6
45
-
46
- camera_embedder_cls: str = ""
47
- camera_embedder: dict = field(default_factory=dict)
48
-
49
- image_tokenizer_cls: str = ""
50
- image_tokenizer: dict = field(default_factory=dict)
51
-
52
- tokenizer_cls: str = ""
53
- tokenizer: dict = field(default_factory=dict)
54
-
55
- backbone_cls: str = ""
56
- backbone: dict = field(default_factory=dict)
57
-
58
- post_processor_cls: str = ""
59
- post_processor: dict = field(default_factory=dict)
60
-
61
- decoder_cls: str = ""
62
- decoder: dict = field(default_factory=dict)
63
-
64
- image_estimator_cls: str = ""
65
- image_estimator: dict = field(default_factory=dict)
66
-
67
- global_estimator_cls: str = ""
68
- global_estimator: dict = field(default_factory=dict)
69
-
70
- cfg: Config
71
-
72
- @classmethod
73
- def from_pretrained(
74
- cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
75
- ):
76
- if os.path.isdir(pretrained_model_name_or_path):
77
- config_path = os.path.join(pretrained_model_name_or_path, config_name)
78
- weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
79
- else:
80
- config_path = hf_hub_download(
81
- repo_id=pretrained_model_name_or_path, filename=config_name
82
- )
83
- weight_path = hf_hub_download(
84
- repo_id=pretrained_model_name_or_path, filename=weight_name
85
- )
86
-
87
- cfg = OmegaConf.load(config_path)
88
- OmegaConf.resolve(cfg)
89
- model = cls(cfg)
90
- load_model(model, weight_path)
91
- return model
92
-
93
- @property
94
- def device(self):
95
- return next(self.parameters()).device
96
-
97
- def configure(self):
98
- self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)(
99
- self.cfg.image_tokenizer
100
- )
101
- self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer)
102
- self.camera_embedder = find_class(self.cfg.camera_embedder_cls)(
103
- self.cfg.camera_embedder
104
- )
105
- self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone)
106
- self.post_processor = find_class(self.cfg.post_processor_cls)(
107
- self.cfg.post_processor
108
- )
109
- self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder)
110
- self.image_estimator = find_class(self.cfg.image_estimator_cls)(
111
- self.cfg.image_estimator
112
- )
113
- self.global_estimator = find_class(self.cfg.global_estimator_cls)(
114
- self.cfg.global_estimator
115
- )
116
-
117
- self.bbox: Float[Tensor, "2 3"]
118
- self.register_buffer(
119
- "bbox",
120
- torch.as_tensor(
121
- [
122
- [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius],
123
- [self.cfg.radius, self.cfg.radius, self.cfg.radius],
124
- ],
125
- dtype=torch.float32,
126
- ),
127
- )
128
- self.isosurface_helper = MarchingTetrahedraHelper(
129
- self.cfg.isosurface_resolution,
130
- os.path.join(
131
- os.path.dirname(__file__),
132
- "..",
133
- "load",
134
- "tets",
135
- f"{self.cfg.isosurface_resolution}_tets.npz",
136
- ),
137
- )
138
-
139
- self.baker = TextureBaker()
140
- self.image_processor = ImageProcessor()
141
-
142
- def triplane_to_meshes(
143
- self, triplanes: Float[Tensor, "B 3 Cp Hp Wp"]
144
- ) -> list[Mesh]:
145
- meshes = []
146
- for i in range(triplanes.shape[0]):
147
- triplane = triplanes[i]
148
- grid_vertices = scale_tensor(
149
- self.isosurface_helper.grid_vertices.to(triplanes.device),
150
- self.isosurface_helper.points_range,
151
- self.bbox,
152
- )
153
-
154
- values = self.query_triplane(grid_vertices, triplane)
155
- decoded = self.decoder(values, include=["vertex_offset", "density"])
156
- sdf = decoded["density"] - self.cfg.isosurface_threshold
157
-
158
- deform = decoded["vertex_offset"].squeeze(0)
159
-
160
- mesh: Mesh = self.isosurface_helper(
161
- sdf.view(-1, 1), deform.view(-1, 3) if deform is not None else None
162
- )
163
- mesh.v_pos = scale_tensor(
164
- mesh.v_pos, self.isosurface_helper.points_range, self.bbox
165
- )
166
-
167
- meshes.append(mesh)
168
-
169
- return meshes
170
-
171
- def query_triplane(
172
- self,
173
- positions: Float[Tensor, "*B N 3"],
174
- triplanes: Float[Tensor, "*B 3 Cp Hp Wp"],
175
- ) -> Float[Tensor, "*B N F"]:
176
- batched = positions.ndim == 3
177
- if not batched:
178
- # no batch dimension
179
- triplanes = triplanes[None, ...]
180
- positions = positions[None, ...]
181
- assert triplanes.ndim == 5 and positions.ndim == 3
182
-
183
- positions = scale_tensor(
184
- positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
185
- )
186
-
187
- indices2D: Float[Tensor, "B 3 N 2"] = torch.stack(
188
- (positions[..., [0, 1]], positions[..., [0, 2]], positions[..., [1, 2]]),
189
- dim=-3,
190
- ).to(triplanes.dtype)
191
- out: Float[Tensor, "B3 Cp 1 N"] = F.grid_sample(
192
- rearrange(triplanes, "B Np Cp Hp Wp -> (B Np) Cp Hp Wp", Np=3).float(),
193
- rearrange(indices2D, "B Np N Nd -> (B Np) () N Nd", Np=3).float(),
194
- align_corners=True,
195
- mode="bilinear",
196
- )
197
- out = rearrange(out, "(B Np) Cp () N -> B N (Np Cp)", Np=3)
198
-
199
- return out
200
-
201
- def get_scene_codes(self, batch) -> Float[Tensor, "B 3 C H W"]:
202
- # if batch[rgb_cond] is only one view, add a view dimension
203
- if len(batch["rgb_cond"].shape) == 4:
204
- batch["rgb_cond"] = batch["rgb_cond"].unsqueeze(1)
205
- batch["mask_cond"] = batch["mask_cond"].unsqueeze(1)
206
- batch["c2w_cond"] = batch["c2w_cond"].unsqueeze(1)
207
- batch["intrinsic_cond"] = batch["intrinsic_cond"].unsqueeze(1)
208
- batch["intrinsic_normed_cond"] = batch["intrinsic_normed_cond"].unsqueeze(1)
209
- batch_size, n_input_views = batch["rgb_cond"].shape[:2]
210
-
211
- camera_embeds: Optional[Float[Tensor, "B Nv Cc"]]
212
- camera_embeds = self.camera_embedder(**batch)
213
-
214
- input_image_tokens: Float[Tensor, "B Nv Cit Nit"] = self.image_tokenizer(
215
- rearrange(batch["rgb_cond"], "B Nv H W C -> B Nv C H W"),
216
- modulation_cond=camera_embeds,
217
- )
218
-
219
- input_image_tokens = rearrange(
220
- input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=n_input_views
221
- )
222
-
223
- tokens: Float[Tensor, "B Ct Nt"] = self.tokenizer(batch_size)
224
-
225
- tokens = self.backbone(
226
- tokens,
227
- encoder_hidden_states=input_image_tokens,
228
- modulation_cond=None,
229
- )
230
-
231
- direct_codes = self.tokenizer.detokenize(tokens)
232
- scene_codes = self.post_processor(direct_codes)
233
- return scene_codes, direct_codes
234
-
235
- def run_image(
236
- self,
237
- image: Image,
238
- bake_resolution: int,
239
- estimate_illumination: bool = False,
240
- ) -> Tuple[trimesh.Trimesh, dict[str, Any]]:
241
- if image.mode != "RGBA":
242
- raise ValueError("Image must be in RGBA mode")
243
- img_cond = (
244
- torch.from_numpy(
245
- np.asarray(
246
- image.resize((self.cfg.cond_image_size, self.cfg.cond_image_size))
247
- ).astype(np.float32)
248
- / 255.0
249
- )
250
- .float()
251
- .clip(0, 1)
252
- .to(self.device)
253
- )
254
- mask_cond = img_cond[:, :, -1:]
255
- rgb_cond = torch.lerp(
256
- torch.tensor(self.cfg.background_color, device=self.device)[None, None, :],
257
- img_cond[:, :, :3],
258
- mask_cond,
259
- )
260
-
261
- c2w_cond = default_cond_c2w(self.cfg.default_distance).to(self.device)
262
- intrinsic, intrinsic_normed_cond = create_intrinsic_from_fov_deg(
263
- self.cfg.default_fovy_deg,
264
- self.cfg.cond_image_size,
265
- self.cfg.cond_image_size,
266
- )
267
-
268
- batch = {
269
- "rgb_cond": rgb_cond,
270
- "mask_cond": mask_cond,
271
- "c2w_cond": c2w_cond.unsqueeze(0),
272
- "intrinsic_cond": intrinsic.to(self.device).unsqueeze(0),
273
- "intrinsic_normed_cond": intrinsic_normed_cond.to(self.device).unsqueeze(0),
274
- }
275
-
276
- meshes, global_dict = self.generate_mesh(
277
- batch, bake_resolution, estimate_illumination
278
- )
279
- return meshes[0], global_dict
280
-
281
- def generate_mesh(
282
- self,
283
- batch,
284
- bake_resolution: int,
285
- estimate_illumination: bool = False,
286
- ) -> Tuple[List[trimesh.Trimesh], dict[str, Any]]:
287
- batch["rgb_cond"] = self.image_processor(
288
- batch["rgb_cond"], self.cfg.cond_image_size
289
- )
290
- batch["mask_cond"] = self.image_processor(
291
- batch["mask_cond"], self.cfg.cond_image_size
292
- )
293
- scene_codes, non_postprocessed_codes = self.get_scene_codes(batch)
294
-
295
- global_dict = {}
296
- if self.image_estimator is not None:
297
- global_dict.update(
298
- self.image_estimator(batch["rgb_cond"] * batch["mask_cond"])
299
- )
300
- if self.global_estimator is not None and estimate_illumination:
301
- global_dict.update(self.global_estimator(non_postprocessed_codes))
302
-
303
- with torch.no_grad():
304
- with torch.autocast(device_type="cuda", enabled=False):
305
- meshes = self.triplane_to_meshes(scene_codes)
306
-
307
- rets = []
308
- for i, mesh in enumerate(meshes):
309
- # Check for empty mesh
310
- if mesh.v_pos.shape[0] == 0:
311
- rets.append(trimesh.Trimesh())
312
- continue
313
-
314
- mesh.unwrap_uv()
315
-
316
- # Build textures
317
- rast = self.baker.rasterize(
318
- mesh.v_tex, mesh.t_pos_idx, bake_resolution
319
- )
320
- bake_mask = self.baker.get_mask(rast)
321
-
322
- pos_bake = self.baker.interpolate(
323
- mesh.v_pos,
324
- rast,
325
- mesh.t_pos_idx,
326
- mesh.v_tex,
327
- )
328
- gb_pos = pos_bake[bake_mask]
329
-
330
- tri_query = self.query_triplane(gb_pos, scene_codes[i])[0]
331
- decoded = self.decoder(
332
- tri_query, exclude=["density", "vertex_offset"]
333
- )
334
-
335
- nrm = self.baker.interpolate(
336
- mesh.v_nrm,
337
- rast,
338
- mesh.t_pos_idx,
339
- mesh.v_tex,
340
- )
341
- gb_nrm = F.normalize(nrm[bake_mask], dim=-1)
342
- decoded["normal"] = gb_nrm
343
-
344
- # Check if any keys in global_dict start with decoder_
345
- for k, v in global_dict.items():
346
- if k.startswith("decoder_"):
347
- decoded[k.replace("decoder_", "")] = v[i]
348
-
349
- mat_out = {
350
- "albedo": decoded["features"],
351
- "roughness": decoded["roughness"],
352
- "metallic": decoded["metallic"],
353
- "normal": normalize(decoded["perturb_normal"]),
354
- "bump": None,
355
- }
356
-
357
- for k, v in mat_out.items():
358
- if v is None:
359
- continue
360
- if v.shape[0] == 1:
361
- # Skip and directly add a single value
362
- mat_out[k] = v[0]
363
- else:
364
- f = torch.zeros(
365
- bake_resolution,
366
- bake_resolution,
367
- v.shape[-1],
368
- dtype=v.dtype,
369
- device=v.device,
370
- )
371
- if v.shape == f.shape:
372
- continue
373
- if k == "normal":
374
- # Use un-normalized tangents here so that larger or smaller tris
375
- # don't affect the tangents that much
376
- tng = self.baker.interpolate(
377
- mesh.v_tng,
378
- rast,
379
- mesh.t_pos_idx,
380
- mesh.v_tex,
381
- )
382
- gb_tng = tng[bake_mask]
383
- gb_tng = F.normalize(gb_tng, dim=-1)
384
- gb_btng = F.normalize(
385
- torch.cross(gb_tng, gb_nrm, dim=-1), dim=-1
386
- )
387
- normal = F.normalize(mat_out["normal"], dim=-1)
388
-
389
- bump = torch.cat(
390
- # Check if we have to flip some things
391
- (
392
- dot(normal, gb_tng),
393
- dot(normal, gb_btng),
394
- dot(normal, gb_nrm).clip(
395
- 0.3, 1
396
- ), # Never go below 0.3; a lower value would indicate a flipped (or nearly flipped) normal
397
- ),
398
- -1,
399
- )
400
- bump = (bump * 0.5 + 0.5).clamp(0, 1)
401
-
402
- f[bake_mask] = bump.view(-1, 3)
403
- mat_out["bump"] = f
404
- else:
405
- f[bake_mask] = v.view(-1, v.shape[-1])
406
- mat_out[k] = f
407
-
408
- def uv_padding(arr):
409
- if arr.ndim == 1:
410
- return arr
411
- return (
412
- dilate_fill(
413
- arr.permute(2, 0, 1)[None, ...],
414
- bake_mask.unsqueeze(0).unsqueeze(0),
415
- iterations=bake_resolution // 150,
416
- )
417
- .squeeze(0)
418
- .permute(1, 2, 0)
419
- )
420
-
421
- verts_np = convert_data(mesh.v_pos)
422
- faces = convert_data(mesh.t_pos_idx)
423
- uvs = convert_data(mesh.v_tex)
424
-
425
- basecolor_tex = Image.fromarray(
426
- float32_to_uint8_np(convert_data(uv_padding(mat_out["albedo"])))
427
- ).convert("RGB")
428
- basecolor_tex.format = "JPEG"
429
-
430
- metallic = mat_out["metallic"].squeeze().cpu().item()
431
- roughness = mat_out["roughness"].squeeze().cpu().item()
432
-
433
- if "bump" in mat_out and mat_out["bump"] is not None:
434
- bump_np = convert_data(uv_padding(mat_out["bump"]))
435
- bump_up = np.ones_like(bump_np)
436
- bump_up[..., :2] = 0.5
437
- bump_up[..., 2:] = 1
438
- bump_tex = Image.fromarray(
439
- float32_to_uint8_np(
440
- bump_np,
441
- dither=True,
442
- # Do not dither if something is perfectly flat
443
- dither_mask=np.all(
444
- bump_np == bump_up, axis=-1, keepdims=True
445
- ).astype(np.float32),
446
- )
447
- ).convert("RGB")
448
- bump_tex.format = (
449
- "JPEG" # PNG would be better but the assets are larger
450
- )
451
- else:
452
- bump_tex = None
453
-
454
- material = trimesh.visual.material.PBRMaterial(
455
- baseColorTexture=basecolor_tex,
456
- roughnessFactor=roughness,
457
- metallicFactor=metallic,
458
- normalTexture=bump_tex,
459
- )
460
-
461
- tmesh = trimesh.Trimesh(
462
- vertices=verts_np,
463
- faces=faces,
464
- visual=trimesh.visual.texture.TextureVisuals(
465
- uv=uvs, material=material
466
- ),
467
- )
468
- rot = trimesh.transformations.rotation_matrix(
469
- np.radians(-90), [1, 0, 0]
470
- )
471
- tmesh.apply_transform(rot)
472
- tmesh.apply_transform(
473
- trimesh.transformations.rotation_matrix(
474
- np.radians(90), [0, 1, 0]
475
- )
476
- )
477
-
478
- tmesh.invert()
479
-
480
- rets.append(tmesh)
481
-
482
- return rets, global_dict
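
SF3D above is the public entry point; a minimal end-to-end sketch. The repo id, file names and example image follow the ones used elsewhere in this Space, and the input is assumed to already carry an alpha matte (run_image requires RGBA):

import torch
from PIL import Image
from sf3d.system import SF3D

model = SF3D.from_pretrained(
    "stabilityai/stable-fast-3d",
    config_name="config.yaml",
    weight_name="model.safetensors",
).eval().cuda()

image = Image.open("demo_files/examples/chair1.png").convert("RGBA")
with torch.no_grad():
    mesh, global_dict = model.run_image(image, bake_resolution=1024)
mesh.export("chair.glb", file_type="glb", include_normals=True)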
 
sf3d/texture_baker.py DELETED
@@ -1,87 +0,0 @@
1
- import os
2
-
3
- import slangtorch
4
- import torch
5
- import torch.nn as nn
6
- from jaxtyping import Bool, Float
7
- from torch import Tensor
8
-
9
-
10
- class TextureBaker(nn.Module):
11
- def __init__(self):
12
- super().__init__()
13
- self.baker = slangtorch.loadModule(
14
- os.path.join(os.path.dirname(__file__), "texture_baker.slang")
15
- )
16
-
17
- def rasterize(
18
- self,
19
- uv: Float[Tensor, "Nv 2"],
20
- face_indices: Float[Tensor, "Nf 3"],
21
- bake_resolution: int,
22
- ) -> Float[Tensor, "bake_resolution bake_resolution 4"]:
23
- if not face_indices.is_cuda or not uv.is_cuda:
24
- raise ValueError("All input tensors must be on cuda")
25
-
26
- face_indices = face_indices.to(torch.int32)
27
- uv = uv.to(torch.float32)
28
-
29
- rast_result = torch.empty(
30
- bake_resolution, bake_resolution, 4, device=uv.device, dtype=torch.float32
31
- )
32
-
33
- block_size = 16
34
- grid_size = bake_resolution // block_size
35
- self.baker.bake_uv(uv=uv, indices=face_indices, output=rast_result).launchRaw(
36
- blockSize=(block_size, block_size, 1), gridSize=(grid_size, grid_size, 1)
37
- )
38
-
39
- return rast_result
40
-
41
- def get_mask(
42
- self, rast: Float[Tensor, "bake_resolution bake_resolution 4"]
43
- ) -> Bool[Tensor, "bake_resolution bake_resolution"]:
44
- return rast[..., -1] >= 0
45
-
46
- def interpolate(
47
- self,
48
- attr: Float[Tensor, "Nv 3"],
49
- rast: Float[Tensor, "bake_resolution bake_resolution 4"],
50
- face_indices: Float[Tensor, "Nf 3"],
51
- uv: Float[Tensor, "Nv 2"],
52
- ) -> Float[Tensor, "bake_resolution bake_resolution 3"]:
53
- # Make sure all input tensors are on CUDA
54
- if not attr.is_cuda or not face_indices.is_cuda or not rast.is_cuda:
55
- raise ValueError("All input tensors must be on cuda")
56
-
57
- attr = attr.to(torch.float32)
58
- face_indices = face_indices.to(torch.int32)
59
- uv = uv.to(torch.float32)
60
-
61
- pos_bake = torch.zeros(
62
- rast.shape[0],
63
- rast.shape[1],
64
- 3,
65
- device=attr.device,
66
- dtype=attr.dtype,
67
- )
68
-
69
- block_size = 16
70
- grid_size = rast.shape[0] // block_size
71
- self.baker.interpolate(
72
- attr=attr, indices=face_indices, rast=rast, output=pos_bake
73
- ).launchRaw(
74
- blockSize=(block_size, block_size, 1), gridSize=(grid_size, grid_size, 1)
75
- )
76
-
77
- return pos_bake
78
-
79
- def forward(
80
- self,
81
- attr: Float[Tensor, "Nv 3"],
82
- uv: Float[Tensor, "Nv 2"],
83
- face_indices: Float[Tensor, "Nf 3"],
84
- bake_resolution: int,
85
- ) -> Float[Tensor, "bake_resolution bake_resolution 3"]:
86
- rast = self.rasterize(uv, face_indices, bake_resolution)
87
- return self.interpolate(attr, rast, face_indices, uv)
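
A small sketch of driving the baker removed above with a single UV-mapped triangle; the Slang kernels only accept CUDA tensors:

import torch
from sf3d.texture_baker import TextureBaker

baker = TextureBaker()
uv = torch.tensor([[0.1, 0.1], [0.9, 0.1], [0.5, 0.9]], device="cuda")
faces = torch.tensor([[0, 1, 2]], device="cuda", dtype=torch.int32)
attr = torch.rand(3, 3, device="cuda")    # e.g. per-vertex positions or colors

rast = baker.rasterize(uv, faces, bake_resolution=256)   # [256, 256, 4]: barycentrics + triangle id
mask = baker.get_mask(rast)                              # texels covered by a triangle
baked = baker.interpolate(attr, rast, faces, uv)         # [256, 256, 3]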
 
sf3d/texture_baker.slang DELETED
@@ -1,93 +0,0 @@
1
- // xy: 2D test position
2
- // v1: vertex position 1
3
- // v2: vertex position 2
4
- // v3: vertex position 3
5
- //
6
- bool barycentric_coordinates(float2 xy, float2 v1, float2 v2, float2 v3, out float u, out float v, out float w)
7
- {
8
- // Return true if the point (x,y) is inside the triangle defined by the vertices v1, v2, v3.
9
- // If the point is inside the triangle, the barycentric coordinates are stored in u, v, and w.
10
- float2 v1v2 = v2 - v1;
11
- float2 v1v3 = v3 - v1;
12
- float2 xyv1 = xy - v1;
13
-
14
- float d00 = dot(v1v2, v1v2);
15
- float d01 = dot(v1v2, v1v3);
16
- float d11 = dot(v1v3, v1v3);
17
- float d20 = dot(xyv1, v1v2);
18
- float d21 = dot(xyv1, v1v3);
19
-
20
- float denom = d00 * d11 - d01 * d01;
21
- v = (d11 * d20 - d01 * d21) / denom;
22
- w = (d00 * d21 - d01 * d20) / denom;
23
- u = 1.0 - v - w;
24
-
25
- return (v >= 0.0) && (w >= 0.0) && (v + w <= 1.0);
26
- }
27
-
28
- [AutoPyBindCUDA]
29
- [CUDAKernel]
30
- void interpolate(
31
- TensorView<float3> attr,
32
- TensorView<int3> indices,
33
- TensorView<float4> rast,
34
- TensorView<float3> output)
35
- {
36
- // Interpolate the attr into output based on the rast result (barycentric coordinates, + triangle idx)
37
-
38
- uint3 dispatch_id = cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx();
39
-
40
- if (dispatch_id.x > output.size(0) || dispatch_id.y > output.size(1))
41
- return;
42
-
43
- float4 barycentric = rast[dispatch_id.x, dispatch_id.y];
44
- int triangle_idx = int(barycentric.w);
45
-
46
- if (triangle_idx < 0) {
47
- output[dispatch_id.x, dispatch_id.y] = float3(0.0, 0.0, 0.0);
48
- return;
49
- }
50
-
51
- float3 v1 = attr[indices[triangle_idx].x];
52
- float3 v2 = attr[indices[triangle_idx].y];
53
- float3 v3 = attr[indices[triangle_idx].z];
54
-
55
- output[dispatch_id.x, dispatch_id.y] = v1 * barycentric.x + v2 * barycentric.y + v3 * barycentric.z;
56
- }
57
-
58
- [AutoPyBindCUDA]
59
- [CUDAKernel]
60
- void bake_uv(
61
- TensorView<float2> uv,
62
- TensorView<int3> indices,
63
- TensorView<float4> output)
64
- {
65
- uint3 dispatch_id = cudaBlockIdx() * cudaBlockDim() + cudaThreadIdx();
66
-
67
- if (dispatch_id.y > output.size(0) || dispatch_id.x > output.size(1))
68
- return;
69
-
70
- // We index x,y but the original coords are HW. So swap them
71
- float2 pixel_coord = float2(dispatch_id.y, dispatch_id.x);
72
- // Normalize to [0, 1]
73
- pixel_coord /= float2(output.size(1), output.size(0));
74
- pixel_coord = clamp(pixel_coord, 0.0, 1.0);
75
- // Flip x-axis
76
- pixel_coord.y = 1 - pixel_coord.y;
77
-
78
- for (int i = 0; i < indices.size(0); i++) {
79
- float2 v1 = float2(uv[indices[i].x].x, uv[indices[i].x].y);
80
- float2 v2 = float2(uv[indices[i].y].x, uv[indices[i].y].y);
81
- float2 v3 = float2(uv[indices[i].z].x, uv[indices[i].z].y);
82
-
83
- float u, v, w;
84
- bool hit = barycentric_coordinates(pixel_coord, v1, v2, v3, u, v, w);
85
-
86
- if (hit){
87
- output[dispatch_id.x, dispatch_id.y] = float4(u, v, w, i);
88
- return;
89
- }
90
- }
91
-
92
- output[dispatch_id.x, dispatch_id.y] = float4(0.0, 0.0, 0.0, -1);
93
- }
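
The barycentric test in the kernel above is the standard one; a NumPy equivalent is handy for checking it on the CPU (purely illustrative, not part of the repo):

import numpy as np

def barycentric_coordinates(xy, v1, v2, v3):
    # Mirrors the Slang helper: returns (inside, u, v, w)
    v1v2, v1v3, xyv1 = v2 - v1, v3 - v1, xy - v1
    d00, d01, d11 = v1v2 @ v1v2, v1v2 @ v1v3, v1v3 @ v1v3
    d20, d21 = xyv1 @ v1v2, xyv1 @ v1v3
    denom = d00 * d11 - d01 * d01
    v = (d11 * d20 - d01 * d21) / denom
    w = (d00 * d21 - d01 * d20) / denom
    u = 1.0 - v - w
    return (v >= 0.0) and (w >= 0.0) and (v + w <= 1.0), u, v, w

inside, u, v, w = barycentric_coordinates(
    np.array([0.4, 0.3]), np.array([0.0, 0.0]), np.array([1.0, 0.0]), np.array([0.0, 1.0])
)
assert inside and abs(u + v + w - 1.0) < 1e-6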
 
sf3d/utils.py DELETED
@@ -1,91 +0,0 @@
1
- from typing import Any
2
-
3
- import numpy as np
4
- import rembg
5
- import torch
6
- from PIL import Image
7
-
8
- import sf3d.models.utils as sf3d_utils
9
-
10
-
11
- def create_intrinsic_from_fov_deg(fov_deg: float, cond_height: int, cond_width: int):
12
- intrinsic = sf3d_utils.get_intrinsic_from_fov(
13
- np.deg2rad(fov_deg),
14
- H=cond_height,
15
- W=cond_width,
16
- )
17
- intrinsic_normed_cond = intrinsic.clone()
18
- intrinsic_normed_cond[..., 0, 2] /= cond_width
19
- intrinsic_normed_cond[..., 1, 2] /= cond_height
20
- intrinsic_normed_cond[..., 0, 0] /= cond_width
21
- intrinsic_normed_cond[..., 1, 1] /= cond_height
22
-
23
- return intrinsic, intrinsic_normed_cond
24
-
25
-
26
- def default_cond_c2w(distance: float):
27
- c2w_cond = torch.as_tensor(
28
- [
29
- [0, 0, 1, distance],
30
- [1, 0, 0, 0],
31
- [0, 1, 0, 0],
32
- [0, 0, 0, 1],
33
- ]
34
- ).float()
35
- return c2w_cond
36
-
37
-
38
- def remove_background(
39
- image: Image,
40
- rembg_session: Any = None,
41
- force: bool = False,
42
- **rembg_kwargs,
43
- ) -> Image:
44
- do_remove = True
45
- if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
46
- do_remove = False
47
- do_remove = do_remove or force
48
- if do_remove:
49
- image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
50
- return image
51
-
52
-
53
- def resize_foreground(
54
- image: Image,
55
- ratio: float,
56
- ) -> Image:
57
- image = np.array(image)
58
- assert image.shape[-1] == 4
59
- alpha = np.where(image[..., 3] > 0)
60
- y1, y2, x1, x2 = (
61
- alpha[0].min(),
62
- alpha[0].max(),
63
- alpha[1].min(),
64
- alpha[1].max(),
65
- )
66
- # crop the foreground
67
- fg = image[y1:y2, x1:x2]
68
- # pad to square
69
- size = max(fg.shape[0], fg.shape[1])
70
- ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
71
- ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
72
- new_image = np.pad(
73
- fg,
74
- ((ph0, ph1), (pw0, pw1), (0, 0)),
75
- mode="constant",
76
- constant_values=((0, 0), (0, 0), (0, 0)),
77
- )
78
-
79
- # compute padding according to the ratio
80
- new_size = int(new_image.shape[0] / ratio)
81
- # pad to size, double side
82
- ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
83
- ph1, pw1 = new_size - size - ph0, new_size - size - pw0
84
- new_image = np.pad(
85
- new_image,
86
- ((ph0, ph1), (pw0, pw1), (0, 0)),
87
- mode="constant",
88
- constant_values=((0, 0), (0, 0), (0, 0)),
89
- )
90
- new_image = Image.fromarray(new_image, mode="RGBA")
91
- return new_image
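
A short sketch of chaining the two helpers above to prepare an input photo; the file name, session and ratio are illustrative:

import rembg
from PIL import Image
from sf3d.utils import remove_background, resize_foreground

session = rembg.new_session()
img = Image.open("photo.png").convert("RGBA")      # hypothetical input
img = remove_background(img, rembg_session=session)
img = resize_foreground(img, ratio=0.85)           # crop to the alpha bbox, pad back to a square
img.save("prepared.png")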
 
stable_fast.py DELETED
@@ -1,355 +0,0 @@
1
- import os
2
- import tempfile
3
- import time
4
- from functools import lru_cache
5
- from typing import Any
6
-
7
- import gradio as gr
8
- import numpy as np
9
- import rembg
10
- import torch
11
- from gradio_litmodel3d import LitModel3D
12
- from PIL import Image
13
-
14
- import sf3d.utils as sf3d_utils
15
- from sf3d.system import SF3D
16
-
17
- rembg_session = rembg.new_session()
18
-
19
- COND_WIDTH = 512
20
- COND_HEIGHT = 512
21
- COND_DISTANCE = 1.6
22
- COND_FOVY_DEG = 40
23
- BACKGROUND_COLOR = [0.5, 0.5, 0.5]
24
-
25
- # Cached. Doesn't change
26
- c2w_cond = sf3d_utils.default_cond_c2w(COND_DISTANCE)
27
- intrinsic, intrinsic_normed_cond = sf3d_utils.create_intrinsic_from_fov_deg(
28
- COND_FOVY_DEG, COND_HEIGHT, COND_WIDTH
29
- )
30
-
31
-
32
- model = SF3D.from_pretrained(
33
- "stabilityai/stable-fast-3d",
34
- config_name="config.yaml",
35
- weight_name="model.safetensors",
36
- )
37
- model.eval().cuda()
38
-
39
- example_files = [
40
- os.path.join("demo_files/examples", f) for f in os.listdir("demo_files/examples")
41
- ]
42
-
43
-
- def run_model(input_image):
-     start = time.time()
-     with torch.no_grad():
-         with torch.autocast(device_type="cuda", dtype=torch.float16):
-             model_batch = create_batch(input_image)
-             model_batch = {k: v.cuda() for k, v in model_batch.items()}
-             trimesh_mesh, _glob_dict = model.generate_mesh(model_batch, 1024)
-             trimesh_mesh = trimesh_mesh[0]
-
-     # Create new tmp file
-     tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".glb")
-
-     trimesh_mesh.export(tmp_file.name, file_type="glb", include_normals=True)
-
-     print("Generation took:", time.time() - start, "s")
-
-     return tmp_file.name
-
-
- def create_batch(input_image: Image) -> dict[str, Any]:
-     img_cond = (
-         torch.from_numpy(
-             np.asarray(input_image.resize((COND_WIDTH, COND_HEIGHT))).astype(np.float32)
-             / 255.0
-         )
-         .float()
-         .clip(0, 1)
-     )
-     mask_cond = img_cond[:, :, -1:]
-     rgb_cond = torch.lerp(
-         torch.tensor(BACKGROUND_COLOR)[None, None, :], img_cond[:, :, :3], mask_cond
-     )
-
-     batch_elem = {
-         "rgb_cond": rgb_cond,
-         "mask_cond": mask_cond,
-         "c2w_cond": c2w_cond.unsqueeze(0),
-         "intrinsic_cond": intrinsic.unsqueeze(0),
-         "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
-     }
-     # Add batch dim
-     batched = {k: v.unsqueeze(0) for k, v in batch_elem.items()}
-     return batched
-
-
- @lru_cache
- def checkerboard(squares: int, size: int, min_value: float = 0.5):
-     base = np.zeros((squares, squares)) + min_value
-     base[1::2, ::2] = 1
-     base[::2, 1::2] = 1
-
-     repeat_mult = size // squares
-     return (
-         base.repeat(repeat_mult, axis=0)
-         .repeat(repeat_mult, axis=1)[:, :, None]
-         .repeat(3, axis=-1)
-     )
-
-
- def remove_background(input_image: Image) -> Image:
-     return rembg.remove(input_image, session=rembg_session)
-
-
- def resize_foreground(
-     image: Image,
-     ratio: float,
- ) -> Image:
-     image = np.array(image)
-     assert image.shape[-1] == 4
-     alpha = np.where(image[..., 3] > 0)
-     y1, y2, x1, x2 = (
-         alpha[0].min(),
-         alpha[0].max(),
-         alpha[1].min(),
-         alpha[1].max(),
-     )
-     # crop the foreground
-     fg = image[y1:y2, x1:x2]
-     # pad to square
-     size = max(fg.shape[0], fg.shape[1])
-     ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
-     ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
-     new_image = np.pad(
-         fg,
-         ((ph0, ph1), (pw0, pw1), (0, 0)),
-         mode="constant",
-         constant_values=((0, 0), (0, 0), (0, 0)),
-     )
-
-     # compute padding according to the ratio
-     new_size = int(new_image.shape[0] / ratio)
-     # pad to size, double side
-     ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
-     ph1, pw1 = new_size - size - ph0, new_size - size - pw0
-     new_image = np.pad(
-         new_image,
-         ((ph0, ph1), (pw0, pw1), (0, 0)),
-         mode="constant",
-         constant_values=((0, 0), (0, 0), (0, 0)),
-     )
-     new_image = Image.fromarray(new_image, mode="RGBA").resize(
-         (COND_WIDTH, COND_HEIGHT)
-     )
-     return new_image
-
-
- def square_crop(input_image: Image) -> Image:
-     # Perform a center square crop
-     min_size = min(input_image.size)
-     left = (input_image.size[0] - min_size) // 2
-     top = (input_image.size[1] - min_size) // 2
-     right = (input_image.size[0] + min_size) // 2
-     bottom = (input_image.size[1] + min_size) // 2
-     return input_image.crop((left, top, right, bottom)).resize(
-         (COND_WIDTH, COND_HEIGHT)
-     )
-
-
- def show_mask_img(input_image: Image) -> Image:
-     img_numpy = np.array(input_image)
-     alpha = img_numpy[:, :, 3] / 255.0
-     chkb = checkerboard(32, 512) * 255
-     new_img = img_numpy[..., :3] * alpha[:, :, None] + chkb * (1 - alpha[:, :, None])
-     return Image.fromarray(new_img.astype(np.uint8), mode="RGB")
-
-
- def run_button(run_btn, input_image, background_state, foreground_ratio):
-     if run_btn == "Run":
-         glb_file: str = run_model(background_state)
-
-         return (
-             gr.update(),
-             gr.update(),
-             gr.update(),
-             gr.update(),
-             gr.update(value=glb_file, visible=True),
-             gr.update(visible=True),
-         )
-     elif run_btn == "Remove Background":
-         rem_removed = remove_background(input_image)
-
-         sqr_crop = square_crop(rem_removed)
-         fr_res = resize_foreground(sqr_crop, foreground_ratio)
-
-         return (
-             gr.update(value="Run", visible=True),
-             sqr_crop,
-             fr_res,
-             gr.update(value=show_mask_img(fr_res), visible=True),
-             gr.update(value=None, visible=False),
-             gr.update(visible=False),
-         )
-
-
- def requires_bg_remove(image, fr):
-     if image is None:
-         return (
-             gr.update(visible=False, value="Run"),
-             None,
-             None,
-             gr.update(value=None, visible=False),
-             gr.update(visible=False),
-             gr.update(visible=False),
-         )
-     alpha_channel = np.array(image.getchannel("A"))
-     min_alpha = alpha_channel.min()
-
-     if min_alpha == 0:
-         print("Already has alpha")
-         sqr_crop = square_crop(image)
-         fr_res = resize_foreground(sqr_crop, fr)
-         return (
-             gr.update(value="Run", visible=True),
-             sqr_crop,
-             fr_res,
-             gr.update(value=show_mask_img(fr_res), visible=True),
-             gr.update(visible=False),
-             gr.update(visible=False),
-         )
-     return (
-         gr.update(value="Remove Background", visible=True),
-         None,
-         None,
-         gr.update(value=None, visible=False),
-         gr.update(visible=False),
-         gr.update(visible=False),
-     )
-
-
- def update_foreground_ratio(img_proc, fr):
-     foreground_res = resize_foreground(img_proc, fr)
-     return (
-         foreground_res,
-         gr.update(value=show_mask_img(foreground_res)),
-     )
-
-
- with gr.Blocks() as demo:
-     img_proc_state = gr.State()
-     background_remove_state = gr.State()
-     gr.Markdown("""
- # SF3D: Stable Fast 3D Mesh Reconstruction with UV-unwrapping and Illumination Disentanglement
-
- **SF3D** is a state-of-the-art method for 3D mesh reconstruction from a single image.
- This demo allows you to upload an image and generate a 3D mesh model from it.
-
- **Tips**
- 1. If the image already has an alpha channel, you can skip the background removal step.
- 2. You can adjust the foreground ratio to control the size of the foreground object. This can influence the shape
- 3. You can upload your own HDR environment map to light the 3D model.
-     """)
-     with gr.Row(variant="panel"):
-         with gr.Column():
-             with gr.Row():
-                 input_img = gr.Image(
-                     type="pil", label="Input Image", sources="upload", image_mode="RGBA"
-                 )
-                 preview_removal = gr.Image(
-                     label="Preview Background Removal",
-                     type="pil",
-                     image_mode="RGB",
-                     interactive=False,
-                     visible=False,
-                 )
-
-             foreground_ratio = gr.Slider(
-                 label="Foreground Ratio",
-                 minimum=0.5,
-                 maximum=1.0,
-                 value=0.85,
-                 step=0.05,
-             )
-
-             foreground_ratio.change(
-                 update_foreground_ratio,
-                 inputs=[img_proc_state, foreground_ratio],
-                 outputs=[background_remove_state, preview_removal],
-             )
-
-             run_btn = gr.Button("Run", variant="primary", visible=False)
-
-         with gr.Column():
-             output_3d = LitModel3D(
-                 label="3D Model",
-                 visible=False,
-                 clear_color=[0.0, 0.0, 0.0, 0.0],
-                 tonemapping="aces",
-                 contrast=1.0,
-                 scale=1.0,
-             )
-             with gr.Column(visible=False, scale=1.0) as hdr_row:
-                 gr.Markdown("""## HDR Environment Map
-
- Select an HDR environment map to light the 3D model. You can also upload your own HDR environment maps.
-                 """)
-
-                 with gr.Row():
-                     hdr_illumination_file = gr.File(
-                         label="HDR Env Map", file_types=[".hdr"], file_count="single"
-                     )
-                     example_hdris = [
-                         os.path.join("demo_files/hdri", f)
-                         for f in os.listdir("demo_files/hdri")
-                     ]
-                     hdr_illumination_example = gr.Examples(
-                         examples=example_hdris,
-                         inputs=hdr_illumination_file,
-                     )
-
-                 hdr_illumination_file.change(
-                     lambda x: gr.update(env_map=x.name if x is not None else None),
-                     inputs=hdr_illumination_file,
-                     outputs=[output_3d],
-                 )
-
-     examples = gr.Examples(
-         examples=example_files,
-         inputs=input_img,
-     )
-
-     input_img.change(
-         requires_bg_remove,
-         inputs=[input_img, foreground_ratio],
-         outputs=[
-             run_btn,
-             img_proc_state,
-             background_remove_state,
-             preview_removal,
-             output_3d,
-             hdr_row,
-         ],
-     )
-
-     run_btn.click(
-         run_button,
-         inputs=[
-             run_btn,
-             input_img,
-             background_remove_state,
-             foreground_ratio,
-         ],
-         outputs=[
-             run_btn,
-             img_proc_state,
-             background_remove_state,
-             preview_removal,
-             output_3d,
-             hdr_row,
-         ],
-     )
-
- demo.launch()
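
For anyone reconstructing the deleted demo, its control flow boils down to: remove the background when the upload has no transparency, square-crop, resize the foreground to the chosen ratio, then hand the result to run_model, which bakes a GLB. The following headless sketch is not part of the commit; it assumes the stabilityai/stable-fast-3d weights are downloadable, a CUDA device is present, and that the helpers above are importable from a module named stable_fast (that module name is an assumption, since this commit removes the file):

import numpy as np
from PIL import Image

# Hypothetical headless driver; reuses the helpers defined in the deleted stable_fast.py above.
from stable_fast import remove_background, resize_foreground, run_model, square_crop

img = Image.open("demo_files/examples/raccoon_wizard.png").convert("RGBA")

# Mirror requires_bg_remove(): a fully opaque alpha channel means the object still
# needs cutting out; any existing transparency means rembg can be skipped.
if np.array(img.getchannel("A")).min() > 0:
    img = remove_background(img)

img = resize_foreground(square_crop(img), 0.85)  # same default as the Gradio slider
glb_path = run_model(img)                        # writes a temporary .glb and returns its path
print("Mesh written to", glb_path)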