huanngzh committed
Commit c9724af · 0 parents
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full list.
Files changed (50)
  1. .gitattributes +35 -0
  2. README.md +14 -0
  3. app.py +334 -0
  4. assets/example_data/3D-Front/ffb067ad-cf9a-4321-82ae-4e684c59ea3e_KidsRoom-5300_rgb.png +0 -0
  5. assets/example_data/3D-Front/ffb067ad-cf9a-4321-82ae-4e684c59ea3e_KidsRoom-5300_seg.png +0 -0
  6. assets/example_data/3D-Front/ffd98024-7200-429e-8b9a-1234a5937826_LivingRoom-360_rgb.png +0 -0
  7. assets/example_data/3D-Front/ffd98024-7200-429e-8b9a-1234a5937826_LivingRoom-360_seg.png +0 -0
  8. assets/example_data/3D-Front/fff98d42-99a4-43fc-9639-5761cb4f87df_SecondBedroom-127961_rgb.png +0 -0
  9. assets/example_data/3D-Front/fff98d42-99a4-43fc-9639-5761cb4f87df_SecondBedroom-127961_seg.png +0 -0
  10. assets/example_data/Cartoon-Style/00_rgb.png +0 -0
  11. assets/example_data/Cartoon-Style/00_seg.png +0 -0
  12. assets/example_data/Cartoon-Style/01_rgb.png +0 -0
  13. assets/example_data/Cartoon-Style/01_seg.png +0 -0
  14. assets/example_data/Cartoon-Style/02_rgb.png +0 -0
  15. assets/example_data/Cartoon-Style/02_seg.png +0 -0
  16. assets/example_data/Cartoon-Style/03_rgb.png +0 -0
  17. assets/example_data/Cartoon-Style/03_seg.png +0 -0
  18. assets/example_data/Cartoon-Style/04_rgb.png +0 -0
  19. assets/example_data/Cartoon-Style/04_seg.png +0 -0
  20. assets/example_data/Realistic-Style/00_rgb.png +0 -0
  21. assets/example_data/Realistic-Style/00_seg.png +0 -0
  22. assets/example_data/Realistic-Style/01_rgb.png +0 -0
  23. assets/example_data/Realistic-Style/01_seg.png +0 -0
  24. assets/example_data/Realistic-Style/02_rgb.png +0 -0
  25. assets/example_data/Realistic-Style/02_seg.png +0 -0
  26. assets/example_data/Realistic-Style/03_rgb.png +0 -0
  27. assets/example_data/Realistic-Style/03_seg.png +0 -0
  28. assets/example_data/Realistic-Style/04_rgb.png +0 -0
  29. assets/example_data/Realistic-Style/04_seg.png +0 -0
  30. assets/example_data/Realistic-Style/05_rgb.png +0 -0
  31. assets/example_data/Realistic-Style/05_seg.png +0 -0
  32. assets/example_data/Realistic-Style/06_rgb.png +0 -0
  33. assets/example_data/Realistic-Style/06_seg.png +0 -0
  34. midi/inference_utils.py +22 -0
  35. midi/loaders/__init__.py +1 -0
  36. midi/loaders/custom_adapter.py +99 -0
  37. midi/models/attention_processor.py +412 -0
  38. midi/models/autoencoders/__init__.py +1 -0
  39. midi/models/autoencoders/autoencoder_kl_triposg.py +541 -0
  40. midi/models/autoencoders/vae.py +69 -0
  41. midi/models/embeddings.py +96 -0
  42. midi/models/transformers/__init__.py +61 -0
  43. midi/models/transformers/modeling_outputs.py +8 -0
  44. midi/models/transformers/triposg_transformer.py +690 -0
  45. midi/pipelines/pipeline_midi.py +497 -0
  46. midi/pipelines/pipeline_triposg_output.py +25 -0
  47. midi/pipelines/pipeline_utils.py +96 -0
  48. midi/schedulers/__init__.py +5 -0
  49. midi/schedulers/scheduling_rectified_flow.py +327 -0
  50. midi/utils/smoothing.py +615 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
1
+ ---
2
+ title: MIDI 3D
3
+ emoji: 📚
4
+ colorFrom: purple
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 4.44.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ short_description: Image to Compositional 3D Scene Generation
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,334 @@
1
+ import json
2
+ import os
3
+ import random
4
+ import tempfile
5
+ from typing import Any, List, Union
6
+
7
+ import gradio as gr
8
+ import numpy as np
9
+ import spaces
10
+ import torch
11
+ import trimesh
12
+ from gradio_image_prompter import ImagePrompter
13
+ from gradio_litmodel3d import LitModel3D
14
+ from huggingface_hub import snapshot_download
15
+ from PIL import Image
16
+ from skimage import measure
17
+ from transformers import AutoModelForMaskGeneration, AutoProcessor
18
+
19
+ from midi.pipelines.pipeline_midi import MIDIPipeline
20
+ from midi.utils.smoothing import smooth_gpu
21
+ from scripts.grounding_sam import plot_segmentation, segment
22
+ from scripts.inference_midi import preprocess_image, split_rgb_mask
23
+
24
+ # Constants
25
+ MAX_SEED = np.iinfo(np.int32).max
26
+ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
27
+ DTYPE = torch.bfloat16
28
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
29
+ REPO_ID = "VAST-AI/MIDI-3D"
30
+
31
+ MARKDOWN = """
32
+ ## Image to 3D Scene with [MIDI-3D](https://huanngzh.github.io/MIDI-Page/)
33
+ <b>Important!</b> Please check out our [instruction video](https://github.com/user-attachments/assets/4fc8aea4-010f-40c7-989d-6b1d9d3e3e09)!
34
+ 1. Upload an image, and draw bounding boxes for each instance by holding and dragging the mouse. Then click "Run Segmentation" to generate the segmentation result. <b>Ensure that instances are not too small and that bounding boxes fit snugly around each instance.</b>
35
+ 2. <b>Check "Do image padding" in "Generation Settings" if instances in your image are too close to the image border.</b> Then click "Run Generation" to generate a 3D scene from the image and segmentation result.
36
+ 3. If you find the generated 3D scene satisfactory, download it by clicking the "Download GLB" button.
37
+ """
38
+
39
+ EXAMPLES = [
40
+ [
41
+ {
42
+ "image": "assets/example_data/Cartoon-Style/03_rgb.png",
43
+ },
44
+ "assets/example_data/Cartoon-Style/03_seg.png",
45
+ 42,
46
+ False,
47
+ False,
48
+ ],
49
+ [
50
+ {
51
+ "image": "assets/example_data/Cartoon-Style/01_rgb.png",
52
+ },
53
+ "assets/example_data/Cartoon-Style/01_seg.png",
54
+ 42,
55
+ False,
56
+ False,
57
+ ],
58
+ [
59
+ {
60
+ "image": "assets/example_data/Realistic-Style/02_rgb.png",
61
+ },
62
+ "assets/example_data/Realistic-Style/02_seg.png",
63
+ 42,
64
+ False,
65
+ False,
66
+ ],
67
+ [
68
+ {
69
+ "image": "assets/example_data/Cartoon-Style/00_rgb.png",
70
+ },
71
+ "assets/example_data/Cartoon-Style/00_seg.png",
72
+ 42,
73
+ False,
74
+ False,
75
+ ],
76
+ [
77
+ {
78
+ "image": "assets/example_data/Realistic-Style/00_rgb.png",
79
+ },
80
+ "assets/example_data/Realistic-Style/00_seg.png",
81
+ 42,
82
+ False,
83
+ True,
84
+ ],
85
+ [
86
+ {
87
+ "image": "assets/example_data/Realistic-Style/01_rgb.png",
88
+ },
89
+ "assets/example_data/Realistic-Style/01_seg.png",
90
+ 42,
91
+ False,
92
+ True,
93
+ ],
94
+ [
95
+ {
96
+ "image": "assets/example_data/Realistic-Style/05_rgb.png",
97
+ },
98
+ "assets/example_data/Realistic-Style/05_seg.png",
99
+ 42,
100
+ False,
101
+ False,
102
+ ],
103
+ ]
104
+
105
+ os.makedirs(TMP_DIR, exist_ok=True)
106
+
107
+ # Prepare models
108
+ ## Grounding SAM
109
+ segmenter_id = "facebook/sam-vit-base"
110
+ sam_processor = AutoProcessor.from_pretrained(segmenter_id)
111
+ sam_segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(
112
+ DEVICE, DTYPE
113
+ )
114
+ ## MIDI-3D
115
+ local_dir = "pretrained_weights/MIDI-3D"
116
+ snapshot_download(repo_id=REPO_ID, local_dir=local_dir)
117
+ pipe: MIDIPipeline = MIDIPipeline.from_pretrained(local_dir).to(DEVICE, DTYPE)
118
+ pipe.init_custom_adapter(
119
+ set_self_attn_module_names=[
120
+ "blocks.8",
121
+ "blocks.9",
122
+ "blocks.10",
123
+ "blocks.11",
124
+ "blocks.12",
125
+ ]
126
+ )
127
+
128
+
129
+ # Utils
130
+ def get_random_hex():
131
+ random_bytes = os.urandom(8)
132
+ random_hex = random_bytes.hex()
133
+ return random_hex
134
+
135
+
136
+ @spaces.GPU()
137
+ @torch.no_grad()
138
+ @torch.autocast(device_type=DEVICE, dtype=torch.bfloat16)
139
+ def run_segmentation(image_prompts: Any, polygon_refinement: bool) -> Image.Image:
140
+ rgb_image = image_prompts["image"].convert("RGB")
141
+
142
+ # pre-process the layers and get the xyxy boxes of each layer
143
+ if len(image_prompts["points"]) == 0:
144
+ raise gr.Error("Please draw bounding boxes for each instance on the image.")
145
+ boxes = [
146
+ [
147
+ [int(box[0]), int(box[1]), int(box[3]), int(box[4])]
148
+ for box in image_prompts["points"]
149
+ ]
150
+ ]
151
+
152
+ # run the segmentation
153
+ detections = segment(
154
+ sam_processor,
155
+ sam_segmentator,
156
+ rgb_image,
157
+ boxes=[boxes],
158
+ polygon_refinement=polygon_refinement,
159
+ )
160
+ seg_map_pil = plot_segmentation(rgb_image, detections)
161
+
162
+ torch.cuda.empty_cache()
163
+
164
+ return seg_map_pil
165
+
166
+
167
+ @torch.no_grad()
168
+ def run_midi(
169
+ pipe: Any,
170
+ rgb_image: Union[str, Image.Image],
171
+ seg_image: Union[str, Image.Image],
172
+ seed: int,
173
+ num_inference_steps: int = 50,
174
+ guidance_scale: float = 7.0,
175
+ do_image_padding: bool = False,
176
+ ) -> trimesh.Scene:
177
+ if do_image_padding:
178
+ rgb_image, seg_image = preprocess_image(rgb_image, seg_image)
179
+ instance_rgbs, instance_masks, scene_rgbs = split_rgb_mask(rgb_image, seg_image)
180
+
181
+ num_instances = len(instance_rgbs)
182
+ outputs = pipe(
183
+ image=instance_rgbs,
184
+ mask=instance_masks,
185
+ image_scene=scene_rgbs,
186
+ attention_kwargs={"num_instances": num_instances},
187
+ generator=torch.Generator(device=pipe.device).manual_seed(seed),
188
+ num_inference_steps=num_inference_steps,
189
+ guidance_scale=guidance_scale,
190
+ decode_progressive=True,
191
+ return_dict=False,
192
+ )
193
+
194
+ return outputs
195
+
196
+
197
+ @spaces.GPU(duration=300)
198
+ @torch.no_grad()
199
+ @torch.autocast(device_type=DEVICE, dtype=torch.bfloat16)
200
+ def run_generation(
201
+ rgb_image: Any,
202
+ seg_image: Union[str, Image.Image],
203
+ seed: int,
204
+ randomize_seed: bool = False,
205
+ num_inference_steps: int = 50,
206
+ guidance_scale: float = 7.0,
207
+ do_image_padding: bool = False,
208
+ ):
209
+ if randomize_seed:
210
+ seed = random.randint(0, MAX_SEED)
211
+
212
+ if not isinstance(rgb_image, Image.Image) and "image" in rgb_image:
213
+ rgb_image = rgb_image["image"]
214
+
215
+ outputs = run_midi(
216
+ pipe,
217
+ rgb_image,
218
+ seg_image,
219
+ seed,
220
+ num_inference_steps,
221
+ guidance_scale,
222
+ do_image_padding,
223
+ )
224
+
225
+ # marching cubes
226
+ trimeshes = []
227
+ for _, (logits_, grid_size, bbox_size, bbox_min, bbox_max) in enumerate(
228
+ zip(*outputs)
229
+ ):
230
+ grid_logits = logits_.view(grid_size)
231
+ grid_logits = smooth_gpu(grid_logits, method="gaussian", sigma=1)
232
+ torch.cuda.empty_cache()
233
+ vertices, faces, normals, _ = measure.marching_cubes(
234
+ grid_logits.float().cpu().numpy(), 0, method="lewiner"
235
+ )
236
+ vertices = vertices / grid_size * bbox_size + bbox_min
237
+
238
+ # Trimesh
239
+ mesh = trimesh.Trimesh(vertices.astype(np.float32), np.ascontiguousarray(faces))
240
+ trimeshes.append(mesh)
241
+
242
+ # compose the output meshes
243
+ scene = trimesh.Scene(trimeshes)
244
+
245
+ tmp_path = os.path.join(TMP_DIR, f"midi3d_{get_random_hex()}.glb")
246
+ scene.export(tmp_path)
247
+
248
+ torch.cuda.empty_cache()
249
+
250
+ return tmp_path, tmp_path, seed
251
+
252
+
253
+ # Demo
254
+ with gr.Blocks() as demo:
255
+ gr.Markdown(MARKDOWN)
256
+
257
+ with gr.Row():
258
+ with gr.Column():
259
+ with gr.Row():
260
+ image_prompts = ImagePrompter(label="Input Image", type="pil")
261
+ seg_image = gr.Image(
262
+ label="Segmentation Result", type="pil", format="png"
263
+ )
264
+
265
+ with gr.Accordion("Segmentation Settings", open=False):
266
+ polygon_refinement = gr.Checkbox(
267
+ label="Polygon Refinement", value=False
268
+ )
269
+ seg_button = gr.Button("Run Segmentation")
270
+
271
+ with gr.Accordion("Generation Settings", open=False):
272
+ do_image_padding = gr.Checkbox(label="Do image padding", value=False)
273
+ seed = gr.Slider(
274
+ label="Seed",
275
+ minimum=0,
276
+ maximum=MAX_SEED,
277
+ step=1,
278
+ value=0,
279
+ )
280
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
281
+ num_inference_steps = gr.Slider(
282
+ label="Number of inference steps",
283
+ minimum=1,
284
+ maximum=50,
285
+ step=1,
286
+ value=50,
287
+ )
288
+ guidance_scale = gr.Slider(
289
+ label="CFG scale",
290
+ minimum=0.0,
291
+ maximum=10.0,
292
+ step=0.1,
293
+ value=7.0,
294
+ )
295
+ gen_button = gr.Button("Run Generation", variant="primary")
296
+
297
+ with gr.Column():
298
+ model_output = LitModel3D(label="Generated GLB", exposure=1.0, height=500)
299
+ download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
300
+
301
+ with gr.Row():
302
+ gr.Examples(
303
+ examples=EXAMPLES,
304
+ fn=run_generation,
305
+ inputs=[image_prompts, seg_image, seed, randomize_seed, do_image_padding],
306
+ outputs=[model_output, download_glb, seed],
307
+ cache_examples=False,
308
+ )
309
+
310
+ seg_button.click(
311
+ run_segmentation,
312
+ inputs=[
313
+ image_prompts,
314
+ polygon_refinement,
315
+ ],
316
+ outputs=[seg_image],
317
+ ).then(lambda: gr.Button(interactive=True), outputs=[gen_button])
318
+
319
+ gen_button.click(
320
+ run_generation,
321
+ inputs=[
322
+ image_prompts,
323
+ seg_image,
324
+ seed,
325
+ randomize_seed,
326
+ num_inference_steps,
327
+ guidance_scale,
328
+ do_image_padding,
329
+ ],
330
+ outputs=[model_output, download_glb, seed],
331
+ ).then(lambda: gr.Button(interactive=True), outputs=[download_glb])
332
+
333
+
334
+ demo.launch()
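
Note for readers skimming the diff: `run_generation` above converts the pipeline output into meshes by reshaping each instance's flat logits onto its dense grid, smoothing, running marching cubes, and composing a `trimesh.Scene`. A minimal headless sketch of that post-processing step (illustrative only, not part of the committed files; it assumes `outputs` has the structure returned by `run_midi(..., return_dict=False)`, i.e. zipped lists of logits, grid sizes, bbox sizes, and bbox bounds):

```python
# Illustrative sketch -- mirrors the marching-cubes loop in run_generation above.
# The smooth_gpu Gaussian smoothing used in run_generation is omitted here for brevity.
import numpy as np
import trimesh
from skimage import measure

def scene_from_midi_outputs(outputs) -> trimesh.Scene:
    meshes = []
    for logits_, grid_size, bbox_size, bbox_min, bbox_max in zip(*outputs):
        grid_logits = logits_.view(grid_size)                        # flat logits -> dense 3D grid
        verts, faces, _normals, _ = measure.marching_cubes(
            grid_logits.float().cpu().numpy(), 0, method="lewiner"   # extract the zero level set
        )
        verts = verts / grid_size * bbox_size + bbox_min             # grid coords -> world coords
        meshes.append(
            trimesh.Trimesh(verts.astype(np.float32), np.ascontiguousarray(faces))
        )
    return trimesh.Scene(meshes)  # one mesh per instance, composed into a single scene

# scene_from_midi_outputs(outputs).export("scene.glb")
```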
assets/example_data/3D-Front/ffb067ad-cf9a-4321-82ae-4e684c59ea3e_KidsRoom-5300_rgb.png ADDED
assets/example_data/3D-Front/ffb067ad-cf9a-4321-82ae-4e684c59ea3e_KidsRoom-5300_seg.png ADDED
assets/example_data/3D-Front/ffd98024-7200-429e-8b9a-1234a5937826_LivingRoom-360_rgb.png ADDED
assets/example_data/3D-Front/ffd98024-7200-429e-8b9a-1234a5937826_LivingRoom-360_seg.png ADDED
assets/example_data/3D-Front/fff98d42-99a4-43fc-9639-5761cb4f87df_SecondBedroom-127961_rgb.png ADDED
assets/example_data/3D-Front/fff98d42-99a4-43fc-9639-5761cb4f87df_SecondBedroom-127961_seg.png ADDED
assets/example_data/Cartoon-Style/00_rgb.png ADDED
assets/example_data/Cartoon-Style/00_seg.png ADDED
assets/example_data/Cartoon-Style/01_rgb.png ADDED
assets/example_data/Cartoon-Style/01_seg.png ADDED
assets/example_data/Cartoon-Style/02_rgb.png ADDED
assets/example_data/Cartoon-Style/02_seg.png ADDED
assets/example_data/Cartoon-Style/03_rgb.png ADDED
assets/example_data/Cartoon-Style/03_seg.png ADDED
assets/example_data/Cartoon-Style/04_rgb.png ADDED
assets/example_data/Cartoon-Style/04_seg.png ADDED
assets/example_data/Realistic-Style/00_rgb.png ADDED
assets/example_data/Realistic-Style/00_seg.png ADDED
assets/example_data/Realistic-Style/01_rgb.png ADDED
assets/example_data/Realistic-Style/01_seg.png ADDED
assets/example_data/Realistic-Style/02_rgb.png ADDED
assets/example_data/Realistic-Style/02_seg.png ADDED
assets/example_data/Realistic-Style/03_rgb.png ADDED
assets/example_data/Realistic-Style/03_seg.png ADDED
assets/example_data/Realistic-Style/04_rgb.png ADDED
assets/example_data/Realistic-Style/04_seg.png ADDED
assets/example_data/Realistic-Style/05_rgb.png ADDED
assets/example_data/Realistic-Style/05_seg.png ADDED
assets/example_data/Realistic-Style/06_rgb.png ADDED
assets/example_data/Realistic-Style/06_seg.png ADDED
midi/inference_utils.py ADDED
@@ -0,0 +1,22 @@
1
+ from typing import List, Tuple
2
+
3
+ import numpy as np
4
+ import PIL
5
+ import torch.nn.functional as F
6
+ from PIL import Image
7
+
8
+
9
+ def generate_dense_grid_points(
10
+ bbox_min: np.ndarray, bbox_max: np.ndarray, octree_depth: int, indexing: str = "ij"
11
+ ):
12
+ length = bbox_max - bbox_min
13
+ num_cells = np.exp2(octree_depth)
14
+ x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
15
+ y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
16
+ z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
17
+ [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
18
+ xyz = np.stack((xs, ys, zs), axis=-1)
19
+ xyz = xyz.reshape(-1, 3)
20
+ grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
21
+
22
+ return xyz, grid_size, length
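
A quick usage note for `generate_dense_grid_points`: with `octree_depth = d` it builds a `(2^d + 1)^3` lattice of query points spanning the bounding box. A small illustrative example (not part of the commit), assuming `octree_depth=7`:

```python
# Illustrative usage sketch for generate_dense_grid_points.
import numpy as np
from midi.inference_utils import generate_dense_grid_points

bbox_min = np.array([-1.0, -1.0, -1.0], dtype=np.float32)
bbox_max = np.array([1.0, 1.0, 1.0], dtype=np.float32)

xyz, grid_size, length = generate_dense_grid_points(bbox_min, bbox_max, octree_depth=7)

print(xyz.shape)   # (2146689, 3)  == ((2**7 + 1)**3, 3) flattened query points
print(grid_size)   # [129, 129, 129]
print(length)      # [2. 2. 2.]    -- bbox extent per axis (bbox_max - bbox_min)
```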
midi/loaders/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .custom_adapter import CustomAdapterMixin
midi/loaders/custom_adapter.py ADDED
@@ -0,0 +1,99 @@
1
+ import os
2
+ from typing import Dict, Optional, Union
3
+
4
+ import safetensors
5
+ import torch
6
+ from diffusers.utils import _get_model_file, logging
7
+ from safetensors import safe_open
8
+
9
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
10
+
11
+
12
+ class CustomAdapterMixin:
13
+ def init_custom_adapter(self, *args, **kwargs):
14
+ self._init_custom_adapter(*args, **kwargs)
15
+
16
+ def _init_custom_adapter(self, *args, **kwargs):
17
+ raise NotImplementedError
18
+
19
+ def load_custom_adapter(
20
+ self,
21
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
22
+ weight_name: str,
23
+ subfolder: Optional[str] = None,
24
+ **kwargs,
25
+ ):
26
+ # Load the main state dict first.
27
+ cache_dir = kwargs.pop("cache_dir", None)
28
+ force_download = kwargs.pop("force_download", False)
29
+ resume_download = kwargs.pop("resume_download", False)
30
+ proxies = kwargs.pop("proxies", None)
31
+ local_files_only = kwargs.pop("local_files_only", None)
32
+ token = kwargs.pop("token", None)
33
+ revision = kwargs.pop("revision", None)
34
+
35
+ user_agent = {
36
+ "file_type": "attn_procs_weights",
37
+ "framework": "pytorch",
38
+ }
39
+
40
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
41
+ model_file = _get_model_file(
42
+ pretrained_model_name_or_path_or_dict,
43
+ weights_name=weight_name,
44
+ cache_dir=cache_dir,
45
+ force_download=force_download,
46
+ proxies=proxies,
47
+ local_files_only=local_files_only,
48
+ token=token,
49
+ revision=revision,
50
+ subfolder=subfolder,
51
+ user_agent=user_agent,
52
+ )
53
+ if weight_name.endswith(".safetensors"):
54
+ state_dict = {}
55
+ with safe_open(model_file, framework="pt", device="cpu") as f:
56
+ for key in f.keys():
57
+ state_dict[key] = f.get_tensor(key)
58
+ else:
59
+ state_dict = torch.load(model_file, map_location="cpu")
60
+ else:
61
+ state_dict = pretrained_model_name_or_path_or_dict
62
+
63
+ self._load_custom_adapter(state_dict)
64
+
65
+ def _load_custom_adapter(self, state_dict):
66
+ raise NotImplementedError
67
+
68
+ def save_custom_adapter(
69
+ self,
70
+ save_directory: Union[str, os.PathLike],
71
+ weight_name: str,
72
+ safe_serialization: bool = False,
73
+ **kwargs,
74
+ ):
75
+ if os.path.isfile(save_directory):
76
+ logger.error(
77
+ f"Provided path ({save_directory}) should be a directory, not a file"
78
+ )
79
+ return
80
+
81
+ if safe_serialization:
82
+
83
+ def save_function(weights, filename):
84
+ return safetensors.torch.save_file(
85
+ weights, filename, metadata={"format": "pt"}
86
+ )
87
+
88
+ else:
89
+ save_function = torch.save
90
+
91
+ # Save the model
92
+ state_dict = self._save_custom_adapter(**kwargs)
93
+ save_function(state_dict, os.path.join(save_directory, weight_name))
94
+ logger.info(
95
+ f"Custom adapter weights saved in {os.path.join(save_directory, weight_name)}"
96
+ )
97
+
98
+ def _save_custom_adapter(self):
99
+ raise NotImplementedError
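
`CustomAdapterMixin` is a small template-method mixin: `init_custom_adapter`, `load_custom_adapter`, and `save_custom_adapter` handle argument plumbing and (de)serialization, while subclasses supply the underscore hooks (the real consumer is `MIDIPipeline` in `pipeline_midi.py`). A toy sketch of the expected subclass shape; the class and attribute names below are hypothetical:

```python
# Illustrative sketch of the hook pattern, not part of the committed files.
import torch
from midi.loaders import CustomAdapterMixin

class ToyAdapterPipeline(CustomAdapterMixin):
    def _init_custom_adapter(self, hidden_dim: int = 8):
        # create the extra trainable module that the adapter weights will populate
        self.adapter = torch.nn.Linear(hidden_dim, hidden_dim)

    def _load_custom_adapter(self, state_dict):
        # state_dict was read from disk (or passed as a dict) by load_custom_adapter()
        self.adapter.load_state_dict(state_dict)

    def _save_custom_adapter(self):
        # the returned dict is written to disk by save_custom_adapter()
        return self.adapter.state_dict()

pipe = ToyAdapterPipeline()
pipe.init_custom_adapter(hidden_dim=8)
pipe.save_custom_adapter(".", "adapter.safetensors", safe_serialization=True)
pipe.load_custom_adapter(".", weight_name="adapter.safetensors")
```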
midi/models/attention_processor.py ADDED
@@ -0,0 +1,412 @@
1
+ from typing import Callable, List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from diffusers.models.attention_processor import Attention
6
+ from diffusers.utils import logging
7
+ from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
8
+ from diffusers.utils.torch_utils import is_torch_version, maybe_allow_in_graph
9
+ from einops import rearrange
10
+ from torch import nn
11
+
12
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
13
+
14
+
15
+ class TripoSGAttnProcessor2_0:
16
+ r"""
17
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
18
+ used in the TripoSG model. It applies a normalization layer and rotary embedding to the query and key vectors.
19
+ """
20
+
21
+ def __init__(self):
22
+ if not hasattr(F, "scaled_dot_product_attention"):
23
+ raise ImportError(
24
+ "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
25
+ )
26
+
27
+ def __call__(
28
+ self,
29
+ attn: Attention,
30
+ hidden_states: torch.Tensor,
31
+ encoder_hidden_states: Optional[torch.Tensor] = None,
32
+ attention_mask: Optional[torch.Tensor] = None,
33
+ temb: Optional[torch.Tensor] = None,
34
+ image_rotary_emb: Optional[torch.Tensor] = None,
35
+ ) -> torch.Tensor:
36
+ from diffusers.models.embeddings import apply_rotary_emb
37
+
38
+ residual = hidden_states
39
+ if attn.spatial_norm is not None:
40
+ hidden_states = attn.spatial_norm(hidden_states, temb)
41
+
42
+ input_ndim = hidden_states.ndim
43
+
44
+ if input_ndim == 4:
45
+ batch_size, channel, height, width = hidden_states.shape
46
+ hidden_states = hidden_states.view(
47
+ batch_size, channel, height * width
48
+ ).transpose(1, 2)
49
+
50
+ batch_size, sequence_length, _ = (
51
+ hidden_states.shape
52
+ if encoder_hidden_states is None
53
+ else encoder_hidden_states.shape
54
+ )
55
+
56
+ if attention_mask is not None:
57
+ attention_mask = attn.prepare_attention_mask(
58
+ attention_mask, sequence_length, batch_size
59
+ )
60
+ # scaled_dot_product_attention expects attention_mask shape to be
61
+ # (batch, heads, source_length, target_length)
62
+ attention_mask = attention_mask.view(
63
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
64
+ )
65
+
66
+ if attn.group_norm is not None:
67
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
68
+ 1, 2
69
+ )
70
+
71
+ query = attn.to_q(hidden_states)
72
+
73
+ if encoder_hidden_states is None:
74
+ encoder_hidden_states = hidden_states
75
+ elif attn.norm_cross:
76
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
77
+ encoder_hidden_states
78
+ )
79
+
80
+ key = attn.to_k(encoder_hidden_states)
81
+ value = attn.to_v(encoder_hidden_states)
82
+
83
+ # NOTE that pre-trained models split heads first then split qkv or kv, like .view(..., attn.heads, 3, dim)
84
+ # instead of .view(..., 3, attn.heads, dim). So we need to re-split here.
85
+ if not attn.is_cross_attention:
86
+ qkv = torch.cat((query, key, value), dim=-1)
87
+ split_size = qkv.shape[-1] // attn.heads // 3
88
+ qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
89
+ query, key, value = torch.split(qkv, split_size, dim=-1)
90
+ else:
91
+ kv = torch.cat((key, value), dim=-1)
92
+ split_size = kv.shape[-1] // attn.heads // 2
93
+ kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
94
+ key, value = torch.split(kv, split_size, dim=-1)
95
+
96
+ head_dim = key.shape[-1]
97
+
98
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
99
+
100
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
101
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
102
+
103
+ if attn.norm_q is not None:
104
+ query = attn.norm_q(query)
105
+ if attn.norm_k is not None:
106
+ key = attn.norm_k(key)
107
+
108
+ # Apply RoPE if needed
109
+ if image_rotary_emb is not None:
110
+ query = apply_rotary_emb(query, image_rotary_emb)
111
+ if not attn.is_cross_attention:
112
+ key = apply_rotary_emb(key, image_rotary_emb)
113
+
114
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
115
+ # TODO: add support for attn.scale when we move to Torch 2.1
116
+ hidden_states = F.scaled_dot_product_attention(
117
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
118
+ )
119
+
120
+ hidden_states = hidden_states.transpose(1, 2).reshape(
121
+ batch_size, -1, attn.heads * head_dim
122
+ )
123
+ hidden_states = hidden_states.to(query.dtype)
124
+
125
+ # linear proj
126
+ hidden_states = attn.to_out[0](hidden_states)
127
+ # dropout
128
+ hidden_states = attn.to_out[1](hidden_states)
129
+
130
+ if input_ndim == 4:
131
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
132
+ batch_size, channel, height, width
133
+ )
134
+
135
+ if attn.residual_connection:
136
+ hidden_states = hidden_states + residual
137
+
138
+ hidden_states = hidden_states / attn.rescale_output_factor
139
+
140
+ return hidden_states
141
+
142
+
143
+ class FusedTripoSGAttnProcessor2_0:
144
+ r"""
145
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0) with fused
146
+ projection layers. This is used in the TripoSG model. It applies a normalization layer and rotary embedding to the
147
+ query and key vectors.
148
+ """
149
+
150
+ def __init__(self):
151
+ if not hasattr(F, "scaled_dot_product_attention"):
152
+ raise ImportError(
153
+ "FusedTripoSGAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
154
+ )
155
+
156
+ def __call__(
157
+ self,
158
+ attn: Attention,
159
+ hidden_states: torch.Tensor,
160
+ encoder_hidden_states: Optional[torch.Tensor] = None,
161
+ attention_mask: Optional[torch.Tensor] = None,
162
+ temb: Optional[torch.Tensor] = None,
163
+ image_rotary_emb: Optional[torch.Tensor] = None,
164
+ ) -> torch.Tensor:
165
+ from diffusers.models.embeddings import apply_rotary_emb
166
+
167
+ residual = hidden_states
168
+ if attn.spatial_norm is not None:
169
+ hidden_states = attn.spatial_norm(hidden_states, temb)
170
+
171
+ input_ndim = hidden_states.ndim
172
+
173
+ if input_ndim == 4:
174
+ batch_size, channel, height, width = hidden_states.shape
175
+ hidden_states = hidden_states.view(
176
+ batch_size, channel, height * width
177
+ ).transpose(1, 2)
178
+
179
+ batch_size, sequence_length, _ = (
180
+ hidden_states.shape
181
+ if encoder_hidden_states is None
182
+ else encoder_hidden_states.shape
183
+ )
184
+
185
+ if attention_mask is not None:
186
+ attention_mask = attn.prepare_attention_mask(
187
+ attention_mask, sequence_length, batch_size
188
+ )
189
+ # scaled_dot_product_attention expects attention_mask shape to be
190
+ # (batch, heads, source_length, target_length)
191
+ attention_mask = attention_mask.view(
192
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
193
+ )
194
+
195
+ if attn.group_norm is not None:
196
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
197
+ 1, 2
198
+ )
199
+
200
+ # NOTE that pre-trained models split heads first, then split qkv
201
+ if encoder_hidden_states is None:
202
+ qkv = attn.to_qkv(hidden_states)
203
+ split_size = qkv.shape[-1] // attn.heads // 3
204
+ qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
205
+ query, key, value = torch.split(qkv, split_size, dim=-1)
206
+ else:
207
+ if attn.norm_cross:
208
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
209
+ encoder_hidden_states
210
+ )
211
+ query = attn.to_q(hidden_states)
212
+
213
+ kv = attn.to_kv(encoder_hidden_states)
214
+ split_size = kv.shape[-1] // attn.heads // 2
215
+ kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
216
+ key, value = torch.split(kv, split_size, dim=-1)
217
+
218
+ head_dim = key.shape[-1]
219
+
220
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
221
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
222
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
223
+
224
+ if attn.norm_q is not None:
225
+ query = attn.norm_q(query)
226
+ if attn.norm_k is not None:
227
+ key = attn.norm_k(key)
228
+
229
+ # Apply RoPE if needed
230
+ if image_rotary_emb is not None:
231
+ query = apply_rotary_emb(query, image_rotary_emb)
232
+ if not attn.is_cross_attention:
233
+ key = apply_rotary_emb(key, image_rotary_emb)
234
+
235
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
236
+ # TODO: add support for attn.scale when we move to Torch 2.1
237
+ hidden_states = F.scaled_dot_product_attention(
238
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
239
+ )
240
+
241
+ hidden_states = hidden_states.transpose(1, 2).reshape(
242
+ batch_size, -1, attn.heads * head_dim
243
+ )
244
+ hidden_states = hidden_states.to(query.dtype)
245
+
246
+ # linear proj
247
+ hidden_states = attn.to_out[0](hidden_states)
248
+ # dropout
249
+ hidden_states = attn.to_out[1](hidden_states)
250
+
251
+ if input_ndim == 4:
252
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
253
+ batch_size, channel, height, width
254
+ )
255
+
256
+ if attn.residual_connection:
257
+ hidden_states = hidden_states + residual
258
+
259
+ hidden_states = hidden_states / attn.rescale_output_factor
260
+
261
+ return hidden_states
262
+
263
+
264
+ class MIAttnProcessor2_0:
265
+ r"""
266
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
267
+ used in the MIDI model. It applies a normalization layer and rotary embedding to the query and key vectors.
268
+ """
269
+
270
+ def __init__(self, use_mi: bool = True):
271
+ if not hasattr(F, "scaled_dot_product_attention"):
272
+ raise ImportError(
273
+ "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
274
+ )
275
+
276
+ self.use_mi = use_mi
277
+
278
+ def __call__(
279
+ self,
280
+ attn: Attention,
281
+ hidden_states: torch.Tensor,
282
+ encoder_hidden_states: Optional[torch.Tensor] = None,
283
+ attention_mask: Optional[torch.Tensor] = None,
284
+ temb: Optional[torch.Tensor] = None,
285
+ image_rotary_emb: Optional[torch.Tensor] = None,
286
+ num_instances: Optional[torch.IntTensor] = None,
287
+ ) -> torch.Tensor:
288
+ from diffusers.models.embeddings import apply_rotary_emb
289
+
290
+ residual = hidden_states
291
+ if attn.spatial_norm is not None:
292
+ hidden_states = attn.spatial_norm(hidden_states, temb)
293
+
294
+ input_ndim = hidden_states.ndim
295
+
296
+ if input_ndim == 4:
297
+ batch_size, channel, height, width = hidden_states.shape
298
+ hidden_states = hidden_states.view(
299
+ batch_size, channel, height * width
300
+ ).transpose(1, 2)
301
+
302
+ batch_size, sequence_length, _ = (
303
+ hidden_states.shape
304
+ if encoder_hidden_states is None
305
+ else encoder_hidden_states.shape
306
+ )
307
+
308
+ if attention_mask is not None:
309
+ attention_mask = attn.prepare_attention_mask(
310
+ attention_mask, sequence_length, batch_size
311
+ )
312
+ # scaled_dot_product_attention expects attention_mask shape to be
313
+ # (batch, heads, source_length, target_length)
314
+ attention_mask = attention_mask.view(
315
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
316
+ )
317
+
318
+ if attn.group_norm is not None:
319
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
320
+ 1, 2
321
+ )
322
+
323
+ query = attn.to_q(hidden_states)
324
+
325
+ if encoder_hidden_states is None:
326
+ encoder_hidden_states = hidden_states
327
+ elif attn.norm_cross:
328
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
329
+ encoder_hidden_states
330
+ )
331
+
332
+ key = attn.to_k(encoder_hidden_states)
333
+ value = attn.to_v(encoder_hidden_states)
334
+
335
+ # NOTE that pre-trained models split heads first then split qkv or kv, like .view(..., attn.heads, 3, dim)
336
+ # instead of .view(..., 3, attn.heads, dim). So we need to re-split here.
337
+ if not attn.is_cross_attention:
338
+ qkv = torch.cat((query, key, value), dim=-1)
339
+ split_size = qkv.shape[-1] // attn.heads // 3
340
+ qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
341
+ query, key, value = torch.split(qkv, split_size, dim=-1)
342
+ else:
343
+ kv = torch.cat((key, value), dim=-1)
344
+ split_size = kv.shape[-1] // attn.heads // 2
345
+ kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
346
+ key, value = torch.split(kv, split_size, dim=-1)
347
+
348
+ head_dim = key.shape[-1]
349
+
350
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
351
+
352
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
353
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
354
+
355
+ if attn.norm_q is not None:
356
+ query = attn.norm_q(query)
357
+ if attn.norm_k is not None:
358
+ key = attn.norm_k(key)
359
+
360
+ # Apply RoPE if needed
361
+ if image_rotary_emb is not None:
362
+ query = apply_rotary_emb(query, image_rotary_emb)
363
+ if not attn.is_cross_attention:
364
+ key = apply_rotary_emb(key, image_rotary_emb)
365
+
366
+ if self.use_mi and num_instances is not None:
367
+ key = rearrange(
368
+ key, "(b ni) h nt c -> b h (ni nt) c", ni=num_instances
369
+ ).repeat_interleave(num_instances, dim=0)
370
+ value = rearrange(
371
+ value, "(b ni) h nt c -> b h (ni nt) c", ni=num_instances
372
+ ).repeat_interleave(num_instances, dim=0)
373
+
374
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
375
+ hidden_states = F.scaled_dot_product_attention(
376
+ query,
377
+ key,
378
+ value,
379
+ dropout_p=0.0,
380
+ is_causal=False,
381
+ )
382
+ else:
383
+ hidden_states = F.scaled_dot_product_attention(
384
+ query,
385
+ key,
386
+ value,
387
+ attn_mask=attention_mask,
388
+ dropout_p=0.0,
389
+ is_causal=False,
390
+ )
391
+
392
+ hidden_states = hidden_states.transpose(1, 2).reshape(
393
+ batch_size, -1, attn.heads * head_dim
394
+ )
395
+ hidden_states = hidden_states.to(query.dtype)
396
+
397
+ # linear proj
398
+ hidden_states = attn.to_out[0](hidden_states)
399
+ # dropout
400
+ hidden_states = attn.to_out[1](hidden_states)
401
+
402
+ if input_ndim == 4:
403
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
404
+ batch_size, channel, height, width
405
+ )
406
+
407
+ if attn.residual_connection:
408
+ hidden_states = hidden_states + residual
409
+
410
+ hidden_states = hidden_states / attn.rescale_output_factor
411
+
412
+ return hidden_states
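
The multi-instance (`use_mi`) branch of `MIAttnProcessor2_0` is the heart of MIDI's scene-level attention: keys and values from all instances of a scene are flattened into one token sequence and shared across instances, while queries stay per-instance, so every instance attends over the whole scene. A shape-only sketch with toy sizes (illustrative only, not part of the commit):

```python
# Shape sketch of the multi-instance key/value sharing in MIAttnProcessor2_0.
import torch
from einops import rearrange

b, ni, h, nt, c = 1, 5, 2, 4, 8          # scenes, instances, heads, tokens, head_dim
key = torch.randn(b * ni, h, nt, c)      # per-instance keys, as inside the processor

key = rearrange(key, "(b ni) h nt c -> b h (ni nt) c", ni=ni)   # -> (1, 2, 20, 8)
key = key.repeat_interleave(ni, dim=0)                          # -> (5, 2, 20, 8)

# The query keeps its per-instance shape (b*ni, h, nt, c) = (5, 2, 4, 8), so
# scaled_dot_product_attention lets each instance attend to all instance tokens.
print(key.shape)   # torch.Size([5, 2, 20, 8])
```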
midi/models/autoencoders/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .autoencoder_kl_triposg import TripoSGVAEModel
midi/models/autoencoders/autoencoder_kl_triposg.py ADDED
@@ -0,0 +1,541 @@
1
+ from typing import Dict, List, Optional, Tuple, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
7
+ from diffusers.models.attention_processor import Attention, AttentionProcessor
8
+ from diffusers.models.autoencoders.vae import DecoderOutput
9
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
10
+ from diffusers.models.modeling_utils import ModelMixin
11
+ from diffusers.models.normalization import FP32LayerNorm, LayerNorm
12
+ from diffusers.utils import logging
13
+ from diffusers.utils.accelerate_utils import apply_forward_hook
14
+ from einops import repeat
15
+ from tqdm import tqdm
16
+ from torch_cluster import fps
17
+
18
+ from ..attention_processor import FusedTripoSGAttnProcessor2_0, TripoSGAttnProcessor2_0
19
+ from ..embeddings import FrequencyPositionalEmbedding
20
+ from ..transformers.triposg_transformer import DiTBlock
21
+ from .vae import DiagonalGaussianDistribution
22
+
23
+ import subprocess
24
+ import sys
25
+
26
+
27
+ def install_package(package_name):
28
+ try:
29
+ subprocess.check_call([sys.executable, "-m", "pip", "install", package_name])
30
+ return True
31
+ except subprocess.CalledProcessError:
32
+ return False
33
+
34
+
35
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
+
37
+
38
+ class TripoSGEncoder(nn.Module):
39
+ def __init__(
40
+ self,
41
+ in_channels: int = 3,
42
+ dim: int = 512,
43
+ num_attention_heads: int = 8,
44
+ num_layers: int = 8,
45
+ ):
46
+ super().__init__()
47
+
48
+ self.proj_in = nn.Linear(in_channels, dim, bias=True)
49
+
50
+ self.blocks = nn.ModuleList(
51
+ [
52
+ DiTBlock(
53
+ dim=dim,
54
+ num_attention_heads=num_attention_heads,
55
+ use_self_attention=False,
56
+ use_cross_attention=True,
57
+ cross_attention_dim=dim,
58
+ cross_attention_norm_type="layer_norm",
59
+ activation_fn="gelu",
60
+ norm_type="fp32_layer_norm",
61
+ norm_eps=1e-5,
62
+ qk_norm=False,
63
+ qkv_bias=False,
64
+ ) # cross attention
65
+ ]
66
+ + [
67
+ DiTBlock(
68
+ dim=dim,
69
+ num_attention_heads=num_attention_heads,
70
+ use_self_attention=True,
71
+ self_attention_norm_type="fp32_layer_norm",
72
+ use_cross_attention=False,
73
+ use_cross_attention_2=False,
74
+ activation_fn="gelu",
75
+ norm_type="fp32_layer_norm",
76
+ norm_eps=1e-5,
77
+ qk_norm=False,
78
+ qkv_bias=False,
79
+ )
80
+ for _ in range(num_layers) # self attention
81
+ ]
82
+ )
83
+
84
+ self.norm_out = LayerNorm(dim)
85
+
86
+ def forward(self, sample_1: torch.Tensor, sample_2: torch.Tensor):
87
+ hidden_states = self.proj_in(sample_1)
88
+ encoder_hidden_states = self.proj_in(sample_2)
89
+
90
+ for layer, block in enumerate(self.blocks):
91
+ if layer == 0:
92
+ hidden_states = block(
93
+ hidden_states, encoder_hidden_states=encoder_hidden_states
94
+ )
95
+ else:
96
+ hidden_states = block(hidden_states)
97
+
98
+ hidden_states = self.norm_out(hidden_states)
99
+
100
+ return hidden_states
101
+
102
+
103
+ class TripoSGDecoder(nn.Module):
104
+ def __init__(
105
+ self,
106
+ in_channels: int = 3,
107
+ out_channels: int = 1,
108
+ dim: int = 512,
109
+ num_attention_heads: int = 8,
110
+ num_layers: int = 16,
111
+ grad_type: str = "analytical",
112
+ grad_interval: float = 0.001,
113
+ ):
114
+ super().__init__()
115
+
116
+ if grad_type not in ["numerical", "analytical"]:
117
+ raise ValueError(f"grad_type must be one of ['numerical', 'analytical']")
118
+ self.grad_type = grad_type
119
+ self.grad_interval = grad_interval
120
+
121
+ self.blocks = nn.ModuleList(
122
+ [
123
+ DiTBlock(
124
+ dim=dim,
125
+ num_attention_heads=num_attention_heads,
126
+ use_self_attention=True,
127
+ self_attention_norm_type="fp32_layer_norm",
128
+ use_cross_attention=False,
129
+ use_cross_attention_2=False,
130
+ activation_fn="gelu",
131
+ norm_type="fp32_layer_norm",
132
+ norm_eps=1e-5,
133
+ qk_norm=False,
134
+ qkv_bias=False,
135
+ )
136
+ for _ in range(num_layers) # self attention
137
+ ]
138
+ + [
139
+ DiTBlock(
140
+ dim=dim,
141
+ num_attention_heads=num_attention_heads,
142
+ use_self_attention=False,
143
+ use_cross_attention=True,
144
+ cross_attention_dim=dim,
145
+ cross_attention_norm_type="layer_norm",
146
+ activation_fn="gelu",
147
+ norm_type="fp32_layer_norm",
148
+ norm_eps=1e-5,
149
+ qk_norm=False,
150
+ qkv_bias=False,
151
+ ) # cross attention
152
+ ]
153
+ )
154
+
155
+ self.proj_query = nn.Linear(in_channels, dim, bias=True)
156
+
157
+ self.norm_out = LayerNorm(dim)
158
+ self.proj_out = nn.Linear(dim, out_channels, bias=True)
159
+
160
+ def query_geometry(
161
+ self,
162
+ model_fn: callable,
163
+ queries: torch.Tensor,
164
+ sample: torch.Tensor,
165
+ grad: bool = False,
166
+ ):
167
+ logits = model_fn(queries, sample)
168
+ if grad:
169
+ with torch.autocast(device_type="cuda", dtype=torch.float32):
170
+ if self.grad_type == "numerical":
171
+ interval = self.grad_interval
172
+ grad_value = []
173
+ for offset in [
174
+ (interval, 0, 0),
175
+ (0, interval, 0),
176
+ (0, 0, interval),
177
+ ]:
178
+ offset_tensor = torch.tensor(offset, device=queries.device)[
179
+ None, :
180
+ ]
181
+ res_p = model_fn(queries + offset_tensor, sample)[..., 0]
182
+ res_n = model_fn(queries - offset_tensor, sample)[..., 0]
183
+ grad_value.append((res_p - res_n) / (2 * interval))
184
+ grad_value = torch.stack(grad_value, dim=-1)
185
+ else:
186
+ queries_d = torch.clone(queries)
187
+ queries_d.requires_grad = True
188
+ with torch.enable_grad():
189
+ res_d = model_fn(queries_d, sample)
190
+ grad_value = torch.autograd.grad(
191
+ res_d,
192
+ [queries_d],
193
+ grad_outputs=torch.ones_like(res_d),
194
+ create_graph=self.training,
195
+ )[0]
196
+ else:
197
+ grad_value = None
198
+
199
+ return logits, grad_value
200
+
201
+ def forward(
202
+ self,
203
+ sample: torch.Tensor,
204
+ queries: torch.Tensor,
205
+ kv_cache: Optional[torch.Tensor] = None,
206
+ ):
207
+ if kv_cache is None:
208
+ hidden_states = sample
209
+ for _, block in enumerate(self.blocks[:-1]):
210
+ hidden_states = block(hidden_states)
211
+ kv_cache = hidden_states
212
+
213
+ # query grid logits by cross attention
214
+ def query_fn(q, kv):
215
+ q = self.proj_query(q)
216
+ l = self.blocks[-1](q, encoder_hidden_states=kv)
217
+ return self.proj_out(self.norm_out(l))
218
+
219
+ logits, grad = self.query_geometry(
220
+ query_fn, queries, kv_cache, grad=self.training
221
+ )
222
+ logits = logits * -1 if not isinstance(logits, Tuple) else logits[0] * -1
223
+
224
+ return logits, kv_cache
225
+
226
+
227
+ class TripoSGVAEModel(ModelMixin, ConfigMixin):
228
+ @register_to_config
229
+ def __init__(
230
+ self,
231
+ in_channels: int = 3, # NOTE xyz instead of feature dim
232
+ latent_channels: int = 64,
233
+ num_attention_heads: int = 8,
234
+ width_encoder: int = 512,
235
+ width_decoder: int = 1024,
236
+ num_layers_encoder: int = 8,
237
+ num_layers_decoder: int = 16,
238
+ embedding_type: str = "frequency",
239
+ embed_frequency: int = 8,
240
+ embed_include_pi: bool = False,
241
+ ):
242
+ super().__init__()
243
+
244
+ self.out_channels = 1
245
+
246
+ if embedding_type == "frequency":
247
+ self.embedder = FrequencyPositionalEmbedding(
248
+ num_freqs=embed_frequency,
249
+ logspace=True,
250
+ input_dim=in_channels,
251
+ include_pi=embed_include_pi,
252
+ )
253
+ else:
254
+ raise NotImplementedError(
255
+ f"Embedding type {embedding_type} is not supported."
256
+ )
257
+
258
+ self.encoder = TripoSGEncoder(
259
+ in_channels=in_channels + self.embedder.out_dim,
260
+ dim=width_encoder,
261
+ num_attention_heads=num_attention_heads,
262
+ num_layers=num_layers_encoder,
263
+ )
264
+ self.decoder = TripoSGDecoder(
265
+ in_channels=self.embedder.out_dim,
266
+ out_channels=self.out_channels,
267
+ dim=width_decoder,
268
+ num_attention_heads=num_attention_heads,
269
+ num_layers=num_layers_decoder,
270
+ )
271
+
272
+ self.quant = nn.Linear(width_encoder, latent_channels * 2, bias=True)
273
+ self.post_quant = nn.Linear(latent_channels, width_decoder, bias=True)
274
+
275
+ self.use_slicing = False
276
+ self.slicing_length = 1
277
+
278
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedTripoSGAttnProcessor2_0
279
+ def fuse_qkv_projections(self):
280
+ """
281
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
282
+ are fused. For cross-attention modules, key and value projection matrices are fused.
283
+
284
+ <Tip warning={true}>
285
+
286
+ This API is 🧪 experimental.
287
+
288
+ </Tip>
289
+ """
290
+ self.original_attn_processors = None
291
+
292
+ for _, attn_processor in self.attn_processors.items():
293
+ if "Added" in str(attn_processor.__class__.__name__):
294
+ raise ValueError(
295
+ "`fuse_qkv_projections()` is not supported for models having added KV projections."
296
+ )
297
+
298
+ self.original_attn_processors = self.attn_processors
299
+
300
+ for module in self.modules():
301
+ if isinstance(module, Attention):
302
+ module.fuse_projections(fuse=True)
303
+
304
+ self.set_attn_processor(FusedTripoSGAttnProcessor2_0())
305
+
306
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
307
+ def unfuse_qkv_projections(self):
308
+ """Disables the fused QKV projection if enabled.
309
+
310
+ <Tip warning={true}>
311
+
312
+ This API is 🧪 experimental.
313
+
314
+ </Tip>
315
+
316
+ """
317
+ if self.original_attn_processors is not None:
318
+ self.set_attn_processor(self.original_attn_processors)
319
+
320
+ @property
321
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
322
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
323
+ r"""
324
+ Returns:
325
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
326
+ indexed by their weight names.
327
+ """
328
+ # set recursively
329
+ processors = {}
330
+
331
+ def fn_recursive_add_processors(
332
+ name: str,
333
+ module: torch.nn.Module,
334
+ processors: Dict[str, AttentionProcessor],
335
+ ):
336
+ if hasattr(module, "get_processor"):
337
+ processors[f"{name}.processor"] = module.get_processor()
338
+
339
+ for sub_name, child in module.named_children():
340
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
341
+
342
+ return processors
343
+
344
+ for name, module in self.named_children():
345
+ fn_recursive_add_processors(name, module, processors)
346
+
347
+ return processors
348
+
349
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
350
+ def set_attn_processor(
351
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
352
+ ):
353
+ r"""
354
+ Sets the attention processor to use to compute attention.
355
+
356
+ Parameters:
357
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
358
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
359
+ for **all** `Attention` layers.
360
+
361
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
362
+ processor. This is strongly recommended when setting trainable attention processors.
363
+
364
+ """
365
+ count = len(self.attn_processors.keys())
366
+
367
+ if isinstance(processor, dict) and len(processor) != count:
368
+ raise ValueError(
369
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
370
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
371
+ )
372
+
373
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
374
+ if hasattr(module, "set_processor"):
375
+ if not isinstance(processor, dict):
376
+ module.set_processor(processor)
377
+ else:
378
+ module.set_processor(processor.pop(f"{name}.processor"))
379
+
380
+ for sub_name, child in module.named_children():
381
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
382
+
383
+ for name, module in self.named_children():
384
+ fn_recursive_attn_processor(name, module, processor)
385
+
386
+ def set_default_attn_processor(self):
387
+ """
388
+ Disables custom attention processors and sets the default attention implementation.
389
+ """
390
+ self.set_attn_processor(TripoSGAttnProcessor2_0())
391
+
392
+ def enable_slicing(self, slicing_length: int = 1) -> None:
393
+ r"""
394
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
395
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
396
+ """
397
+ self.use_slicing = True
398
+ self.slicing_length = slicing_length
399
+
400
+ def disable_slicing(self) -> None:
401
+ r"""
402
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
403
+ decoding in one step.
404
+ """
405
+ self.use_slicing = False
406
+
407
+ def _sample_features(
408
+ self, x: torch.Tensor, num_tokens: int = 2048, seed: Optional[int] = None
409
+ ):
410
+ """
411
+ Sample points from features of the input point cloud.
412
+
413
+ Args:
414
+ x (torch.Tensor): The input point cloud. shape: (B, N, C)
415
+ num_tokens (int, optional): The number of points to sample. Defaults to 2048.
416
+ seed (Optional[int], optional): The random seed. Defaults to None.
417
+ """
418
+ rng = np.random.default_rng(seed)
419
+ indices = rng.choice(
420
+ x.shape[1], num_tokens * 4, replace=num_tokens * 4 > x.shape[1]
421
+ )
422
+ selected_points = x[:, indices]
423
+
424
+ batch_size, num_points, num_channels = selected_points.shape
425
+ flattened_points = selected_points.view(batch_size * num_points, num_channels)
426
+ batch_indices = (
427
+ torch.arange(batch_size).to(x.device).repeat_interleave(num_points)
428
+ )
429
+
430
+ # fps sampling
431
+ sampling_ratio = 1.0 / 4
432
+ sampled_indices = fps(
433
+ flattened_points[:, :3],
434
+ batch_indices,
435
+ ratio=sampling_ratio,
436
+ random_start=self.training,
437
+ )
438
+ sampled_points = flattened_points[sampled_indices].view(
439
+ batch_size, -1, num_channels
440
+ )
441
+
442
+ return sampled_points
443
+
444
+ def _encode(
445
+ self, x: torch.Tensor, num_tokens: int = 2048, seed: Optional[int] = None
446
+ ):
447
+ position_channels = self.config.in_channels
448
+ positions, features = x[..., :position_channels], x[..., position_channels:]
449
+ x_kv = torch.cat([self.embedder(positions), features], dim=-1)
450
+
451
+ sampled_x = self._sample_features(x, num_tokens, seed)
452
+ positions, features = (
453
+ sampled_x[..., :position_channels],
454
+ sampled_x[..., position_channels:],
455
+ )
456
+ x_q = torch.cat([self.embedder(positions), features], dim=-1)
457
+
458
+ x = self.encoder(x_q, x_kv)
459
+
460
+ x = self.quant(x)
461
+
462
+ return x
463
+
464
+ @apply_forward_hook
465
+ def encode(
466
+ self, x: torch.Tensor, return_dict: bool = True, **kwargs
467
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
468
+ """
469
+ Encode a batch of point features into latents.
470
+ """
471
+ if self.use_slicing and x.shape[0] > 1:
472
+ encoded_slices = [
473
+ self._encode(x_slice, **kwargs)
474
+ for x_slice in x.split(self.slicing_length)
475
+ ]
476
+ h = torch.cat(encoded_slices)
477
+ else:
478
+ h = self._encode(x, **kwargs)
479
+
480
+ posterior = DiagonalGaussianDistribution(h, feature_dim=-1)
481
+
482
+ if not return_dict:
483
+ return (posterior,)
484
+ return AutoencoderKLOutput(latent_dist=posterior)
485
+
486
+ def _decode(
487
+ self,
488
+ z: torch.Tensor,
489
+ sampled_points: torch.Tensor,
490
+ num_chunks: int = 50000,
491
+ to_cpu: bool = False,
492
+ return_dict: bool = True,
493
+ ) -> Union[DecoderOutput, torch.Tensor]:
494
+ xyz_samples = sampled_points
495
+
496
+ z = self.post_quant(z)
497
+
498
+ num_points = xyz_samples.shape[1]
499
+ kv_cache = None
500
+ dec = []
501
+
502
+ for i in range(0, num_points, num_chunks):
503
+ queries = xyz_samples[:, i : i + num_chunks, :].to(z.device, dtype=z.dtype)
504
+ queries = self.embedder(queries)
505
+
506
+ z_, kv_cache = self.decoder(z, queries, kv_cache)
507
+ dec.append(z_ if not to_cpu else z_.cpu())
508
+
509
+ z = torch.cat(dec, dim=1)
510
+
511
+ if not return_dict:
512
+ return (z,)
513
+
514
+ return DecoderOutput(sample=z)
515
+
516
+ @apply_forward_hook
517
+ def decode(
518
+ self,
519
+ z: torch.Tensor,
520
+ sampled_points: torch.Tensor,
521
+ return_dict: bool = True,
522
+ **kwargs,
523
+ ) -> Union[DecoderOutput, torch.Tensor]:
524
+ if self.use_slicing and z.shape[0] > 1:
525
+ decoded_slices = [
526
+ self._decode(z_slice, p_slice, **kwargs).sample
527
+ for z_slice, p_slice in zip(
528
+ z.split(self.slicing_length),
529
+ sampled_points.split(self.slicing_length),
530
+ )
531
+ ]
532
+ decoded = torch.cat(decoded_slices)
533
+ else:
534
+ decoded = self._decode(z, sampled_points, **kwargs).sample
535
+
536
+ if not return_dict:
537
+ return (decoded,)
538
+ return DecoderOutput(sample=decoded)
539
+
540
+ def forward(self, x: torch.Tensor):
541
+ pass
midi/models/autoencoders/vae.py ADDED
@@ -0,0 +1,69 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ from diffusers.utils.torch_utils import randn_tensor
6
+
7
+
8
+ class DiagonalGaussianDistribution(object):
9
+ def __init__(
10
+ self,
11
+ parameters: torch.Tensor,
12
+ deterministic: bool = False,
13
+ feature_dim: int = 1,
14
+ ):
15
+ self.parameters = parameters
16
+ self.feature_dim = feature_dim
17
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=feature_dim)
18
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
19
+ self.deterministic = deterministic
20
+ self.std = torch.exp(0.5 * self.logvar)
21
+ self.var = torch.exp(self.logvar)
22
+ if self.deterministic:
23
+ self.var = self.std = torch.zeros_like(
24
+ self.mean, device=self.parameters.device, dtype=self.parameters.dtype
25
+ )
26
+
27
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
28
+ # make sure sample is on the same device as the parameters and has same dtype
29
+ sample = randn_tensor(
30
+ self.mean.shape,
31
+ generator=generator,
32
+ device=self.parameters.device,
33
+ dtype=self.parameters.dtype,
34
+ )
35
+ x = self.mean + self.std * sample
36
+ return x
37
+
38
+ def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
39
+ if self.deterministic:
40
+ return torch.Tensor([0.0])
41
+ else:
42
+ if other is None:
43
+ return 0.5 * torch.sum(
44
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
45
+ dim=[1, 2, 3],
46
+ )
47
+ else:
48
+ return 0.5 * torch.sum(
49
+ torch.pow(self.mean - other.mean, 2) / other.var
50
+ + self.var / other.var
51
+ - 1.0
52
+ - self.logvar
53
+ + other.logvar,
54
+ dim=[1, 2, 3],
55
+ )
56
+
57
+ def nll(
58
+ self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]
59
+ ) -> torch.Tensor:
60
+ if self.deterministic:
61
+ return torch.Tensor([0.0])
62
+ logtwopi = np.log(2.0 * np.pi)
63
+ return 0.5 * torch.sum(
64
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
65
+ dim=dims,
66
+ )
67
+
68
+ def mode(self) -> torch.Tensor:
69
+ return self.mean
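A self-contained sketch of the distribution above; the 4-D parameter tensor is arbitrary and only chosen so the default `dim=[1, 2, 3]` reductions in `kl` and `nll` apply:

import torch
from midi.models.autoencoders.vae import DiagonalGaussianDistribution

# Mean and log-variance are stacked along `feature_dim` (dim=1 here): 8 latent channels.
params = torch.randn(4, 2 * 8, 16, 16)
dist = DiagonalGaussianDistribution(params, feature_dim=1)

z = dist.sample()          # stochastic sample, shape (4, 8, 16, 16)
z_mode = dist.mode()       # the mean, same shape
kl = dist.kl()             # KL(q || N(0, I)) per batch element, shape (4,)
nll = dist.nll(z)          # negative log-likelihood of a sample, shape (4,)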
midi/models/embeddings.py ADDED
@@ -0,0 +1,96 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class FrequencyPositionalEmbedding(nn.Module):
6
+ """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
7
+ each feature dimension of `x[..., i]` into:
8
+ [
9
+ sin(x[..., i]),
10
+ sin(f_1*x[..., i]),
11
+ sin(f_2*x[..., i]),
12
+ ...
13
+ sin(f_N * x[..., i]),
14
+ cos(x[..., i]),
15
+ cos(f_1*x[..., i]),
16
+ cos(f_2*x[..., i]),
17
+ ...
18
+ cos(f_N * x[..., i]),
19
+ x[..., i] # only present if include_input is True.
20
+ ], here f_i is the frequency.
21
+
22
+ The frequencies f_1, ..., f_N (N = num_freqs) are fixed at construction time.
23
+ If logspace is True, they are powers of two, [2^0, 2^1, ..., 2^(num_freqs - 1)];
24
+ otherwise, they are linearly spaced between 1.0 and 2^(num_freqs - 1).
25
+
26
+ Args:
27
+ num_freqs (int): the number of frequencies, default is 6;
28
+ logspace (bool): If True, the frequencies are powers of two, [2^0, ..., 2^(num_freqs - 1)];
29
+ otherwise, the frequencies are linearly spaced between 1.0 and 2^(num_freqs - 1);
30
+ input_dim (int): the input dimension, default is 3;
31
+ include_input (bool): include the input tensor or not, default is True.
32
+
33
+ Attributes:
34
+ frequencies (torch.Tensor): the frequency values, powers of two if logspace is True,
36
+ otherwise linearly spaced between 1.0 and 2^(num_freqs - 1);
36
+
37
+ out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
38
+ otherwise, it is input_dim * num_freqs * 2.
39
+
40
+ """
41
+
42
+ def __init__(
43
+ self,
44
+ num_freqs: int = 6,
45
+ logspace: bool = True,
46
+ input_dim: int = 3,
47
+ include_input: bool = True,
48
+ include_pi: bool = True,
49
+ ) -> None:
50
+ """The initialization"""
51
+
52
+ super().__init__()
53
+
54
+ if logspace:
55
+ frequencies = 2.0 ** torch.arange(num_freqs, dtype=torch.float32)
56
+ else:
57
+ frequencies = torch.linspace(
58
+ 1.0, 2.0 ** (num_freqs - 1), num_freqs, dtype=torch.float32
59
+ )
60
+
61
+ if include_pi:
62
+ frequencies *= torch.pi
63
+
64
+ self.register_buffer("frequencies", frequencies, persistent=False)
65
+ self.include_input = include_input
66
+ self.num_freqs = num_freqs
67
+
68
+ self.out_dim = self.get_dims(input_dim)
69
+
70
+ def get_dims(self, input_dim):
71
+ temp = 1 if self.include_input or self.num_freqs == 0 else 0
72
+ out_dim = input_dim * (self.num_freqs * 2 + temp)
73
+
74
+ return out_dim
75
+
76
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
77
+ """Forward process.
78
+
79
+ Args:
80
+ x: tensor of shape [..., dim]
81
+
82
+ Returns:
83
+ embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
84
+ where temp is 1 if include_input is True and 0 otherwise.
85
+ """
86
+
87
+ if self.num_freqs > 0:
88
+ embed = (x[..., None].contiguous() * self.frequencies).view(
89
+ *x.shape[:-1], -1
90
+ )
91
+ if self.include_input:
92
+ return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
93
+ else:
94
+ return torch.cat((embed.sin(), embed.cos()), dim=-1)
95
+ else:
96
+ return x
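A quick, self-contained shape check for the embedding above:

import torch
from midi.models.embeddings import FrequencyPositionalEmbedding

embedder = FrequencyPositionalEmbedding(num_freqs=8, input_dim=3, include_input=True)
xyz = torch.rand(2, 1024, 3) * 2.0 - 1.0   # 1024 points per batch element in [-1, 1]^3

emb = embedder(xyz)
assert embedder.out_dim == 3 * (2 * 8 + 1) == 51   # input_dim * (2 * num_freqs + 1)
assert emb.shape == (2, 1024, 51)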
midi/models/transformers/__init__.py ADDED
@@ -0,0 +1,61 @@
1
+ from typing import Callable, Optional
2
+
3
+ from .triposg_transformer import TripoSGDiTModel
4
+
5
+
6
+ def default_set_attn_proc_func(
7
+ name: str,
8
+ hidden_size: int,
9
+ cross_attention_dim: Optional[int],
10
+ ori_attn_proc: object,
11
+ ) -> object:
12
+ return ori_attn_proc
13
+
14
+
15
+ def set_transformer_attn_processor(
16
+ transformer: TripoSGDiTModel,
17
+ set_self_attn_proc_func: Callable = default_set_attn_proc_func,
18
+ set_cross_attn_1_proc_func: Callable = default_set_attn_proc_func,
19
+ set_cross_attn_2_proc_func: Callable = default_set_attn_proc_func,
20
+ set_self_attn_module_names: Optional[list[str]] = None,
21
+ set_cross_attn_1_module_names: Optional[list[str]] = None,
22
+ set_cross_attn_2_module_names: Optional[list[str]] = None,
23
+ ) -> None:
24
+ do_set_processor = lambda name, module_names: (
25
+ any([name.startswith(module_name) for module_name in module_names])
26
+ if module_names is not None
27
+ else True
28
+ ) # prefix match
29
+
30
+ attn_procs = {}
31
+ for name, attn_processor in transformer.attn_processors.items():
32
+ hidden_size = transformer.config.width
33
+ if name.endswith("attn1.processor"):
34
+ # self attention
35
+ attn_procs[name] = (
36
+ set_self_attn_proc_func(name, hidden_size, None, attn_processor)
37
+ if do_set_processor(name, set_self_attn_module_names)
38
+ else attn_processor
39
+ )
40
+ elif name.endswith("attn2.processor"):
41
+ # cross attention
42
+ cross_attention_dim = transformer.config.cross_attention_dim
43
+ attn_procs[name] = (
44
+ set_cross_attn_1_proc_func(
45
+ name, hidden_size, cross_attention_dim, attn_processor
46
+ )
47
+ if do_set_processor(name, set_cross_attn_1_module_names)
48
+ else attn_processor
49
+ )
50
+ elif name.endswith("attn2_2.processor"):
51
+ # cross attention 2
52
+ cross_attention_dim = transformer.config.cross_attention_2_dim
53
+ attn_procs[name] = (
54
+ set_cross_attn_2_proc_func(
55
+ name, hidden_size, cross_attention_dim, attn_processor
56
+ )
57
+ if do_set_processor(name, set_cross_attn_2_module_names)
58
+ else attn_processor
59
+ )
60
+
61
+ transformer.set_attn_processor(attn_procs)
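A sketch of driving this helper, mirroring how `MIDIPipeline._init_custom_adapter` uses it later in this commit; the tiny model config and block names are purely illustrative, and `MIAttnProcessor2_0` comes from midi/models/attention_processor.py in this commit:

from midi.models.attention_processor import MIAttnProcessor2_0
from midi.models.transformers import TripoSGDiTModel, set_transformer_attn_processor

transformer = TripoSGDiTModel(
    num_attention_heads=4, width=128, in_channels=8, num_layers=5,
    cross_attention_dim=16, cross_attention_2_dim=32,
)

# Swap only the self-attention processors of blocks 0 and 1. The match is a plain
# prefix match, so the trailing dots keep "blocks.1" from also catching "blocks.10".
set_transformer_attn_processor(
    transformer,
    set_self_attn_proc_func=lambda name, hidden_size, cross_dim, old_proc: MIAttnProcessor2_0(),
    set_self_attn_module_names=["blocks.0.", "blocks.1."],
)

# Cross-attention processors are left untouched: the default funcs return the originals.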
midi/models/transformers/modeling_outputs.py ADDED
@@ -0,0 +1,8 @@
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+
5
+
6
+ @dataclass
7
+ class Transformer1DModelOutput:
8
+ sample: torch.FloatTensor
midi/models/transformers/triposg_transformer.py ADDED
@@ -0,0 +1,690 @@
1
+ # Copyright 2024 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional, Tuple, Union
15
+
16
+ import torch
17
+ import torch.utils.checkpoint
18
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
19
+ from diffusers.loaders import PeftAdapterMixin
20
+ from diffusers.models.attention import FeedForward
21
+ from diffusers.models.attention_processor import Attention, AttentionProcessor
22
+ from diffusers.models.embeddings import (
23
+ GaussianFourierProjection,
24
+ TimestepEmbedding,
25
+ Timesteps,
26
+ )
27
+ from diffusers.models.modeling_utils import ModelMixin
28
+ from diffusers.models.normalization import (
29
+ AdaLayerNormContinuous,
30
+ FP32LayerNorm,
31
+ LayerNorm,
32
+ )
33
+ from diffusers.utils import (
34
+ USE_PEFT_BACKEND,
35
+ is_torch_version,
36
+ logging,
37
+ scale_lora_layers,
38
+ unscale_lora_layers,
39
+ )
40
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
41
+ from torch import nn
42
+
43
+ from ..attention_processor import FusedTripoSGAttnProcessor2_0, TripoSGAttnProcessor2_0
44
+ from .modeling_outputs import Transformer1DModelOutput
45
+
46
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
47
+
48
+
49
+ @maybe_allow_in_graph
50
+ class DiTBlock(nn.Module):
51
+ r"""
52
+ Transformer block used in the Hunyuan-DiT model (https://github.com/Tencent/HunyuanDiT). Allows skip connections and
53
+ QK normalization.
54
+
55
+ Parameters:
56
+ dim (`int`):
57
+ The number of channels in the input and output.
58
+ num_attention_heads (`int`):
59
+ The number of heads to use for multi-head attention.
60
+ cross_attention_dim (`int`, *optional*):
61
+ The size of the encoder_hidden_states vector for cross attention.
62
+ dropout (`float`, *optional*, defaults to 0.0):
63
+ The dropout probability to use.
64
+ activation_fn (`str`, *optional*, defaults to `"geglu"`):
65
+ Activation function to be used in the feed-forward block.
66
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
67
+ Whether to use learnable elementwise affine parameters for normalization.
68
+ norm_eps (`float`, *optional*, defaults to 1e-6):
69
+ A small constant added to the denominator in normalization layers to prevent division by zero.
70
+ final_dropout (`bool`, *optional*, defaults to `False`):
71
+ Whether to apply a final dropout after the last feed-forward layer.
72
+ ff_inner_dim (`int`, *optional*):
73
+ The size of the hidden layer in the feed-forward block. Defaults to `None`.
74
+ ff_bias (`bool`, *optional*, defaults to `True`):
75
+ Whether to use bias in the feed-forward block.
76
+ skip (`bool`, *optional*, defaults to `False`):
77
+ Whether to use skip connection. Defaults to `False` for down-blocks and mid-blocks.
78
+ qk_norm (`bool`, *optional*, defaults to `True`):
79
+ Whether to use normalization in QK calculation. Defaults to `True`.
80
+ """
81
+
82
+ def __init__(
83
+ self,
84
+ dim: int,
85
+ num_attention_heads: int,
86
+ use_self_attention: bool = True,
87
+ use_cross_attention: bool = False,
88
+ self_attention_norm_type: Optional[str] = None, # ada layer norm
89
+ cross_attention_dim: Optional[int] = None,
90
+ cross_attention_norm_type: Optional[str] = "fp32_layer_norm",
91
+ # parallel second cross attention
92
+ use_cross_attention_2: bool = False,
93
+ cross_attention_2_dim: Optional[int] = None,
94
+ cross_attention_2_norm_type: Optional[str] = None,
95
+ dropout=0.0,
96
+ activation_fn: str = "gelu",
97
+ norm_type: str = "fp32_layer_norm", # TODO
98
+ norm_elementwise_affine: bool = True,
99
+ norm_eps: float = 1e-5,
100
+ final_dropout: bool = False,
101
+ ff_inner_dim: Optional[int] = None, # int(dim * 4) if None
102
+ ff_bias: bool = True,
103
+ skip: bool = False,
104
+ skip_concat_front: bool = False, # [x, skip] or [skip, x]
105
+ skip_norm_last: bool = False, # this is an error
106
+ qk_norm: bool = True,
107
+ qkv_bias: bool = True,
108
+ ):
109
+ super().__init__()
110
+
111
+ self.use_self_attention = use_self_attention
112
+ self.use_cross_attention = use_cross_attention
113
+ self.use_cross_attention_2 = use_cross_attention_2
114
+ self.skip_concat_front = skip_concat_front
115
+ self.skip_norm_last = skip_norm_last
116
+ # Define 3 blocks. Each block has its own normalization layer.
117
+ # NOTE: when new version comes, check norm2 and norm 3
118
+ # 1. Self-Attn
119
+ if use_self_attention:
120
+ if (
121
+ self_attention_norm_type == "fp32_layer_norm"
122
+ or self_attention_norm_type is None
123
+ ):
124
+ self.norm1 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
125
+ else:
126
+ raise NotImplementedError
127
+
128
+ self.attn1 = Attention(
129
+ query_dim=dim,
130
+ cross_attention_dim=None,
131
+ dim_head=dim // num_attention_heads,
132
+ heads=num_attention_heads,
133
+ qk_norm="rms_norm" if qk_norm else None,
134
+ eps=1e-6,
135
+ bias=qkv_bias,
136
+ processor=TripoSGAttnProcessor2_0(),
137
+ )
138
+
139
+ # 2. Cross-Attn
140
+ if use_cross_attention:
141
+ assert cross_attention_dim is not None
142
+
143
+ self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
144
+
145
+ self.attn2 = Attention(
146
+ query_dim=dim,
147
+ cross_attention_dim=cross_attention_dim,
148
+ dim_head=dim // num_attention_heads,
149
+ heads=num_attention_heads,
150
+ qk_norm="rms_norm" if qk_norm else None,
151
+ cross_attention_norm=cross_attention_norm_type,
152
+ eps=1e-6,
153
+ bias=qkv_bias,
154
+ processor=TripoSGAttnProcessor2_0(),
155
+ )
156
+
157
+ # 2'. Parallel Second Cross-Attn
158
+ if use_cross_attention_2:
159
+ assert cross_attention_2_dim is not None
160
+
161
+ self.norm2_2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
162
+
163
+ self.attn2_2 = Attention(
164
+ query_dim=dim,
165
+ cross_attention_dim=cross_attention_2_dim,
166
+ dim_head=dim // num_attention_heads,
167
+ heads=num_attention_heads,
168
+ qk_norm="rms_norm" if qk_norm else None,
169
+ cross_attention_norm=cross_attention_2_norm_type,
170
+ eps=1e-6,
171
+ bias=qkv_bias,
172
+ processor=TripoSGAttnProcessor2_0(),
173
+ )
174
+
175
+ # 3. Feed-forward
176
+ self.norm3 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
177
+
178
+ self.ff = FeedForward(
179
+ dim,
180
+ dropout=dropout, ### 0.0
181
+ activation_fn=activation_fn, ### approx GeLU
182
+ final_dropout=final_dropout, ### 0.0
183
+ inner_dim=ff_inner_dim, ### int(dim * mlp_ratio)
184
+ bias=ff_bias,
185
+ )
186
+
187
+ # 4. Skip Connection
188
+ if skip:
189
+ self.skip_norm = FP32LayerNorm(dim, norm_eps, elementwise_affine=True)
190
+ self.skip_linear = nn.Linear(2 * dim, dim)
191
+ else:
192
+ self.skip_linear = None
193
+
194
+ # let chunk size default to None
195
+ self._chunk_size = None
196
+ self._chunk_dim = 0
197
+
198
+ # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
199
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
200
+ # Sets chunk feed-forward
201
+ self._chunk_size = chunk_size
202
+ self._chunk_dim = dim
203
+
204
+ def forward(
205
+ self,
206
+ hidden_states: torch.Tensor,
207
+ encoder_hidden_states: Optional[torch.Tensor] = None,
208
+ encoder_hidden_states_2: Optional[torch.Tensor] = None,
209
+ temb: Optional[torch.Tensor] = None,
210
+ image_rotary_emb: Optional[torch.Tensor] = None,
211
+ skip: Optional[torch.Tensor] = None,
212
+ attention_kwargs: Optional[Dict[str, Any]] = None,
213
+ ) -> torch.Tensor:
214
+ # Prepare attention kwargs
215
+ attention_kwargs = attention_kwargs or {}
216
+
217
+ # Notice that normalization is always applied before the real computation in the following blocks.
218
+ # 0. Long Skip Connection
219
+ if self.skip_linear is not None:
220
+ cat = torch.cat(
221
+ (
222
+ [skip, hidden_states]
223
+ if self.skip_concat_front
224
+ else [hidden_states, skip]
225
+ ),
226
+ dim=-1,
227
+ )
228
+ if self.skip_norm_last:
229
+ # don't do this
230
+ hidden_states = self.skip_linear(cat)
231
+ hidden_states = self.skip_norm(hidden_states)
232
+ else:
233
+ cat = self.skip_norm(cat)
234
+ hidden_states = self.skip_linear(cat)
235
+
236
+ # 1. Self-Attention
237
+ if self.use_self_attention:
238
+ norm_hidden_states = self.norm1(hidden_states)
239
+ attn_output = self.attn1(
240
+ norm_hidden_states,
241
+ image_rotary_emb=image_rotary_emb,
242
+ **attention_kwargs,
243
+ )
244
+ hidden_states = hidden_states + attn_output
245
+
246
+ # 2. Cross-Attention
247
+ if self.use_cross_attention:
248
+ if self.use_cross_attention_2:
249
+ hidden_states = (
250
+ hidden_states
251
+ + self.attn2(
252
+ self.norm2(hidden_states),
253
+ encoder_hidden_states=encoder_hidden_states,
254
+ image_rotary_emb=image_rotary_emb,
255
+ **attention_kwargs,
256
+ )
257
+ + self.attn2_2(
258
+ self.norm2_2(hidden_states),
259
+ encoder_hidden_states=encoder_hidden_states_2,
260
+ image_rotary_emb=image_rotary_emb,
261
+ **attention_kwargs,
262
+ )
263
+ )
264
+ else:
265
+ hidden_states = hidden_states + self.attn2(
266
+ self.norm2(hidden_states),
267
+ encoder_hidden_states=encoder_hidden_states,
268
+ image_rotary_emb=image_rotary_emb,
269
+ **attention_kwargs,
270
+ )
271
+
272
+ # FFN Layer ### TODO: switch norm2 and norm3 in the state dict
273
+ mlp_inputs = self.norm3(hidden_states)
274
+ hidden_states = hidden_states + self.ff(mlp_inputs)
275
+
276
+ return hidden_states
277
+
278
+
279
+ class TripoSGDiTModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
280
+ """
281
+ TripoSG: Diffusion model with a Transformer backbone.
282
+
283
+ Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
284
+
285
+ Parameters:
286
+ num_attention_heads (`int`, *optional*, defaults to 16):
287
+ The number of heads to use for multi-head attention.
288
+ attention_head_dim (`int`, *optional*, defaults to 88):
289
+ The number of channels in each head.
290
+ in_channels (`int`, *optional*):
291
+ The number of channels in the input and output (specify if the input is **continuous**).
292
+ patch_size (`int`, *optional*):
293
+ The size of the patch to use for the input.
294
+ activation_fn (`str`, *optional*, defaults to `"geglu"`):
295
+ Activation function to use in feed-forward.
296
+ sample_size (`int`, *optional*):
297
+ The width of the latent images. This is fixed during training since it is used to learn a number of
298
+ position embeddings.
299
+ dropout (`float`, *optional*, defaults to 0.0):
300
+ The dropout probability to use.
301
+ cross_attention_dim (`int`, *optional*):
302
+ The number of dimension in the clip text embedding.
303
+ hidden_size (`int`, *optional*):
304
+ The size of hidden layer in the conditioning embedding layers.
305
+ num_layers (`int`, *optional*, defaults to 1):
306
+ The number of layers of Transformer blocks to use.
307
+ mlp_ratio (`float`, *optional*, defaults to 4.0):
308
+ The ratio of the hidden layer size to the input size.
309
+ learn_sigma (`bool`, *optional*, defaults to `True`):
310
+ Whether to predict variance.
311
+ cross_attention_dim_t5 (`int`, *optional*):
312
+ The number dimensions in t5 text embedding.
313
+ pooled_projection_dim (`int`, *optional*):
314
+ The size of the pooled projection.
315
+ text_len (`int`, *optional*):
316
+ The length of the clip text embedding.
317
+ text_len_t5 (`int`, *optional*):
318
+ The length of the T5 text embedding.
319
+ use_style_cond_and_image_meta_size (`bool`, *optional*):
320
+ Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2
321
+ """
322
+
323
+ _supports_gradient_checkpointing = True
324
+
325
+ @register_to_config
326
+ def __init__(
327
+ self,
328
+ num_attention_heads: int = 16,
329
+ width: int = 2048,
330
+ in_channels: int = 64,
331
+ num_layers: int = 21,
332
+ cross_attention_dim: int = 768,
333
+ cross_attention_2_dim: int = 1024,
334
+ ):
335
+ super().__init__()
336
+ self.out_channels = in_channels
337
+ self.num_heads = num_attention_heads
338
+ self.inner_dim = width
339
+ self.mlp_ratio = 4.0
340
+
341
+ time_embed_dim, timestep_input_dim = self._set_time_proj(
342
+ "positional",
343
+ inner_dim=self.inner_dim,
344
+ flip_sin_to_cos=False,
345
+ freq_shift=0,
346
+ time_embedding_dim=None,
347
+ )
348
+ self.time_proj = TimestepEmbedding(
349
+ timestep_input_dim, time_embed_dim, act_fn="gelu", out_dim=self.inner_dim
350
+ )
351
+ self.proj_in = nn.Linear(self.config.in_channels, self.inner_dim, bias=True)
352
+
353
+ self.blocks = nn.ModuleList(
354
+ [
355
+ DiTBlock(
356
+ dim=self.inner_dim,
357
+ num_attention_heads=self.config.num_attention_heads,
358
+ use_self_attention=True,
359
+ use_cross_attention=True,
360
+ self_attention_norm_type="fp32_layer_norm",
361
+ cross_attention_dim=self.config.cross_attention_dim,
362
+ cross_attention_norm_type=None,
363
+ use_cross_attention_2=True,
364
+ cross_attention_2_dim=self.config.cross_attention_2_dim,
365
+ cross_attention_2_norm_type=None,
366
+ activation_fn="gelu",
367
+ norm_type="fp32_layer_norm", # TODO
368
+ norm_eps=1e-5,
369
+ ff_inner_dim=int(self.inner_dim * self.mlp_ratio),
370
+ skip=layer > num_layers // 2,
371
+ skip_concat_front=True,
372
+ skip_norm_last=True, # this is an error
373
+ qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details.
374
+ qkv_bias=False,
375
+ )
376
+ for layer in range(num_layers)
377
+ ]
378
+ )
379
+
380
+ self.norm_out = LayerNorm(self.inner_dim)
381
+ self.proj_out = nn.Linear(self.inner_dim, self.out_channels, bias=True)
382
+
383
+ self.gradient_checkpointing = False
384
+
385
+ def _set_gradient_checkpointing(self, module, value=False):
386
+ self.gradient_checkpointing = value
387
+
388
+ def _set_time_proj(
389
+ self,
390
+ time_embedding_type: str,
391
+ inner_dim: int,
392
+ flip_sin_to_cos: bool,
393
+ freq_shift: float,
394
+ time_embedding_dim: int,
395
+ ) -> Tuple[int, int]:
396
+ if time_embedding_type == "fourier":
397
+ time_embed_dim = time_embedding_dim or inner_dim * 2
398
+ if time_embed_dim % 2 != 0:
399
+ raise ValueError(
400
+ f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}."
401
+ )
402
+ self.time_embed = GaussianFourierProjection(
403
+ time_embed_dim // 2,
404
+ set_W_to_weight=False,
405
+ log=False,
406
+ flip_sin_to_cos=flip_sin_to_cos,
407
+ )
408
+ timestep_input_dim = time_embed_dim
409
+ elif time_embedding_type == "positional":
410
+ time_embed_dim = time_embedding_dim or inner_dim * 4
411
+
412
+ self.time_embed = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
413
+ timestep_input_dim = inner_dim
414
+ else:
415
+ raise ValueError(
416
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
417
+ )
418
+
419
+ return time_embed_dim, timestep_input_dim
420
+
421
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedTripoSGAttnProcessor2_0
422
+ def fuse_qkv_projections(self):
423
+ """
424
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
425
+ are fused. For cross-attention modules, key and value projection matrices are fused.
426
+
427
+ <Tip warning={true}>
428
+
429
+ This API is 🧪 experimental.
430
+
431
+ </Tip>
432
+ """
433
+ self.original_attn_processors = None
434
+
435
+ for _, attn_processor in self.attn_processors.items():
436
+ if "Added" in str(attn_processor.__class__.__name__):
437
+ raise ValueError(
438
+ "`fuse_qkv_projections()` is not supported for models having added KV projections."
439
+ )
440
+
441
+ self.original_attn_processors = self.attn_processors
442
+
443
+ for module in self.modules():
444
+ if isinstance(module, Attention):
445
+ module.fuse_projections(fuse=True)
446
+
447
+ self.set_attn_processor(FusedTripoSGAttnProcessor2_0())
448
+
449
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
450
+ def unfuse_qkv_projections(self):
451
+ """Disables the fused QKV projection if enabled.
452
+
453
+ <Tip warning={true}>
454
+
455
+ This API is 🧪 experimental.
456
+
457
+ </Tip>
458
+
459
+ """
460
+ if self.original_attn_processors is not None:
461
+ self.set_attn_processor(self.original_attn_processors)
462
+
463
+ @property
464
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
465
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
466
+ r"""
467
+ Returns:
468
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
469
+ indexed by its weight name.
470
+ """
471
+ # set recursively
472
+ processors = {}
473
+
474
+ def fn_recursive_add_processors(
475
+ name: str,
476
+ module: torch.nn.Module,
477
+ processors: Dict[str, AttentionProcessor],
478
+ ):
479
+ if hasattr(module, "get_processor"):
480
+ processors[f"{name}.processor"] = module.get_processor()
481
+
482
+ for sub_name, child in module.named_children():
483
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
484
+
485
+ return processors
486
+
487
+ for name, module in self.named_children():
488
+ fn_recursive_add_processors(name, module, processors)
489
+
490
+ return processors
491
+
492
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
493
+ def set_attn_processor(
494
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
495
+ ):
496
+ r"""
497
+ Sets the attention processor to use to compute attention.
498
+
499
+ Parameters:
500
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
501
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
502
+ for **all** `Attention` layers.
503
+
504
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
505
+ processor. This is strongly recommended when setting trainable attention processors.
506
+
507
+ """
508
+ count = len(self.attn_processors.keys())
509
+
510
+ if isinstance(processor, dict) and len(processor) != count:
511
+ raise ValueError(
512
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
513
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
514
+ )
515
+
516
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
517
+ if hasattr(module, "set_processor"):
518
+ if not isinstance(processor, dict):
519
+ module.set_processor(processor)
520
+ else:
521
+ module.set_processor(processor.pop(f"{name}.processor"))
522
+
523
+ for sub_name, child in module.named_children():
524
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
525
+
526
+ for name, module in self.named_children():
527
+ fn_recursive_attn_processor(name, module, processor)
528
+
529
+ def set_default_attn_processor(self):
530
+ """
531
+ Disables custom attention processors and sets the default attention implementation.
532
+ """
533
+ self.set_attn_processor(TripoSGAttnProcessor2_0())
534
+
535
+ def forward(
536
+ self,
537
+ hidden_states: Optional[torch.Tensor],
538
+ timestep: Union[int, float, torch.LongTensor],
539
+ encoder_hidden_states: Optional[torch.Tensor] = None,
540
+ encoder_hidden_states_2: Optional[torch.Tensor] = None,
541
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
542
+ attention_kwargs: Optional[Dict[str, Any]] = None,
543
+ return_dict: bool = True,
544
+ ):
545
+ """
546
+ The [`TripoSGDiTModel`] forward method.
547
+
548
+ Args:
549
+ hidden_states (`torch.Tensor` of shape `(batch_size, num_tokens, in_channels)`):
550
+ The input tensor.
551
+ timestep ( `torch.LongTensor`, *optional*):
552
+ Used to indicate denoising step.
553
+ encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
554
+ Conditional embeddings for cross attention layer.
555
+ encoder_hidden_states_2 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
556
+ Conditional embeddings for cross attention layer.
557
+ return_dict: bool
558
+ Whether to return a dictionary.
559
+ """
560
+
561
+ if attention_kwargs is not None:
562
+ attention_kwargs = attention_kwargs.copy()
563
+ lora_scale = attention_kwargs.pop("scale", 1.0)
564
+ else:
565
+ lora_scale = 1.0
566
+
567
+ if USE_PEFT_BACKEND:
568
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
569
+ scale_lora_layers(self, lora_scale)
570
+ else:
571
+ if (
572
+ attention_kwargs is not None
573
+ and attention_kwargs.get("scale", None) is not None
574
+ ):
575
+ logger.warning(
576
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
577
+ )
578
+
579
+ _, N, _ = hidden_states.shape
580
+
581
+ temb = self.time_embed(timestep).to(hidden_states.dtype)
582
+ temb = self.time_proj(temb)
583
+ temb = temb.unsqueeze(dim=1) # unsqueeze to concat with hidden_states
584
+
585
+ hidden_states = self.proj_in(hidden_states)
586
+
587
+ # N + 1 token
588
+ hidden_states = torch.cat([temb, hidden_states], dim=1)
589
+
590
+ skips = []
591
+ for layer, block in enumerate(self.blocks):
592
+ skip = None if layer <= self.config.num_layers // 2 else skips.pop()
593
+
594
+ if self.training and self.gradient_checkpointing:
595
+
596
+ def create_custom_forward(module):
597
+ def custom_forward(*inputs):
598
+ return module(*inputs)
599
+
600
+ return custom_forward
601
+
602
+ ckpt_kwargs: Dict[str, Any] = (
603
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
604
+ )
605
+ hidden_states = torch.utils.checkpoint.checkpoint(
606
+ create_custom_forward(block),
607
+ hidden_states,
608
+ encoder_hidden_states,
609
+ encoder_hidden_states_2,
610
+ temb,
611
+ image_rotary_emb,
612
+ skip,
613
+ attention_kwargs,
614
+ **ckpt_kwargs,
615
+ )
616
+ else:
617
+ hidden_states = block(
618
+ hidden_states,
619
+ encoder_hidden_states=encoder_hidden_states,
620
+ encoder_hidden_states_2=encoder_hidden_states_2,
621
+ temb=temb,
622
+ image_rotary_emb=image_rotary_emb,
623
+ skip=skip,
624
+ attention_kwargs=attention_kwargs,
625
+ ) # (N, L, D)
626
+
627
+ if layer < self.config.num_layers // 2:
628
+ skips.append(hidden_states)
629
+
630
+ # final layer
631
+ hidden_states = self.norm_out(hidden_states)
632
+ hidden_states = hidden_states[:, -N:]
633
+ hidden_states = self.proj_out(hidden_states)
634
+
635
+ if USE_PEFT_BACKEND:
636
+ # remove `lora_scale` from each PEFT layer
637
+ unscale_lora_layers(self, lora_scale)
638
+
639
+ if not return_dict:
640
+ return (hidden_states,)
641
+
642
+ return Transformer1DModelOutput(sample=hidden_states)
643
+
644
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
645
+ def enable_forward_chunking(
646
+ self, chunk_size: Optional[int] = None, dim: int = 0
647
+ ) -> None:
648
+ """
649
+ Sets the attention processor to use [feed forward
650
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
651
+
652
+ Parameters:
653
+ chunk_size (`int`, *optional*):
654
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
655
+ over each tensor of dim=`dim`.
656
+ dim (`int`, *optional*, defaults to `0`):
657
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
658
+ or dim=1 (sequence length).
659
+ """
660
+ if dim not in [0, 1]:
661
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
662
+
663
+ # By default chunk size is 1
664
+ chunk_size = chunk_size or 1
665
+
666
+ def fn_recursive_feed_forward(
667
+ module: torch.nn.Module, chunk_size: int, dim: int
668
+ ):
669
+ if hasattr(module, "set_chunk_feed_forward"):
670
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
671
+
672
+ for child in module.children():
673
+ fn_recursive_feed_forward(child, chunk_size, dim)
674
+
675
+ for module in self.children():
676
+ fn_recursive_feed_forward(module, chunk_size, dim)
677
+
678
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
679
+ def disable_forward_chunking(self):
680
+ def fn_recursive_feed_forward(
681
+ module: torch.nn.Module, chunk_size: int, dim: int
682
+ ):
683
+ if hasattr(module, "set_chunk_feed_forward"):
684
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
685
+
686
+ for child in module.children():
687
+ fn_recursive_feed_forward(child, chunk_size, dim)
688
+
689
+ for module in self.children():
690
+ fn_recursive_feed_forward(module, None, 0)
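A shape-level sketch of the forward pass, with deliberately tiny non-default hyperparameters chosen only to keep the example light; the conditioning sizes loosely mimic a pooled CLIP embedding and DINOv2 patch tokens:

import torch
from midi.models.transformers import TripoSGDiTModel

model = TripoSGDiTModel(
    num_attention_heads=4, width=128, in_channels=8, num_layers=5,
    cross_attention_dim=16, cross_attention_2_dim=32,
)

latents = torch.randn(2, 64, 8)       # (batch, num_tokens, in_channels)
timestep = torch.tensor([999, 999])   # one timestep per batch element
cond_1 = torch.randn(2, 1, 16)        # first cross-attention condition
cond_2 = torch.randn(2, 257, 32)      # second (parallel) cross-attention condition

out = model(
    latents,
    timestep,
    encoder_hidden_states=cond_1,
    encoder_hidden_states_2=cond_2,
).sample
assert out.shape == latents.shape     # the time token is stripped before proj_out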
midi/pipelines/pipeline_midi.py ADDED
@@ -0,0 +1,497 @@
1
+ import inspect
2
+ import math
3
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
4
+
5
+ import numpy as np
6
+ import PIL
7
+ import PIL.Image
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from diffusers.image_processor import PipelineImageInput
11
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
12
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler # not sure
13
+ from diffusers.utils import logging
14
+ from diffusers.utils.torch_utils import randn_tensor
15
+ from peft import LoraConfig, get_peft_model_state_dict
16
+ from transformers import (
17
+ BitImageProcessor,
18
+ CLIPImageProcessor,
19
+ CLIPVisionModelWithProjection,
20
+ Dinov2Model,
21
+ )
22
+
23
+ from ..inference_utils import generate_dense_grid_points
24
+ from ..loaders import CustomAdapterMixin
25
+ from ..models.attention_processor import MIAttnProcessor2_0
26
+ from ..models.autoencoders import TripoSGVAEModel
27
+ from ..models.transformers import TripoSGDiTModel, set_transformer_attn_processor
28
+ from .pipeline_triposg_output import TripoSGPipelineOutput
29
+ from .pipeline_utils import TransformerDiffusionMixin
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+
34
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
35
+ def retrieve_timesteps(
36
+ scheduler,
37
+ num_inference_steps: Optional[int] = None,
38
+ device: Optional[Union[str, torch.device]] = None,
39
+ timesteps: Optional[List[int]] = None,
40
+ sigmas: Optional[List[float]] = None,
41
+ **kwargs,
42
+ ):
43
+ """
44
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
45
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
46
+
47
+ Args:
48
+ scheduler (`SchedulerMixin`):
49
+ The scheduler to get timesteps from.
50
+ num_inference_steps (`int`):
51
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
52
+ must be `None`.
53
+ device (`str` or `torch.device`, *optional*):
54
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
55
+ timesteps (`List[int]`, *optional*):
56
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
57
+ `num_inference_steps` and `sigmas` must be `None`.
58
+ sigmas (`List[float]`, *optional*):
59
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
60
+ `num_inference_steps` and `timesteps` must be `None`.
61
+
62
+ Returns:
63
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
64
+ second element is the number of inference steps.
65
+ """
66
+ if timesteps is not None and sigmas is not None:
67
+ raise ValueError(
68
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
69
+ )
70
+ if timesteps is not None:
71
+ accepts_timesteps = "timesteps" in set(
72
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
73
+ )
74
+ if not accepts_timesteps:
75
+ raise ValueError(
76
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
77
+ f" timestep schedules. Please check whether you are using the correct scheduler."
78
+ )
79
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
80
+ timesteps = scheduler.timesteps
81
+ num_inference_steps = len(timesteps)
82
+ elif sigmas is not None:
83
+ accept_sigmas = "sigmas" in set(
84
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
85
+ )
86
+ if not accept_sigmas:
87
+ raise ValueError(
88
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
89
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
90
+ )
91
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
92
+ timesteps = scheduler.timesteps
93
+ num_inference_steps = len(timesteps)
94
+ else:
95
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
96
+ timesteps = scheduler.timesteps
97
+ return timesteps, num_inference_steps
98
+
99
+
100
+ class MIDIPipeline(DiffusionPipeline, TransformerDiffusionMixin, CustomAdapterMixin):
101
+ """
102
+ Pipeline for image-to-scene generation based on pre-trained shape diffusion.
103
+ """
104
+
105
+ def __init__(
106
+ self,
107
+ vae: TripoSGVAEModel,
108
+ transformer: TripoSGDiTModel,
109
+ scheduler: FlowMatchEulerDiscreteScheduler,
110
+ image_encoder_1: CLIPVisionModelWithProjection,
111
+ image_encoder_2: Dinov2Model,
112
+ feature_extractor_1: CLIPImageProcessor,
113
+ feature_extractor_2: BitImageProcessor,
114
+ ):
115
+ super().__init__()
116
+
117
+ self.register_modules(
118
+ vae=vae,
119
+ transformer=transformer,
120
+ scheduler=scheduler,
121
+ image_encoder_1=image_encoder_1,
122
+ image_encoder_2=image_encoder_2,
123
+ feature_extractor_1=feature_extractor_1,
124
+ feature_extractor_2=feature_extractor_2,
125
+ )
126
+
127
+ @property
128
+ def guidance_scale(self):
129
+ return self._guidance_scale
130
+
131
+ @property
132
+ def do_classifier_free_guidance(self):
133
+ return self._guidance_scale > 1
134
+
135
+ @property
136
+ def num_timesteps(self):
137
+ return self._num_timesteps
138
+
139
+ @property
140
+ def attention_kwargs(self):
141
+ return self._attention_kwargs
142
+
143
+ @property
144
+ def interrupt(self):
145
+ return self._interrupt
146
+
147
+ @property
148
+ def decode_progressive(self):
149
+ return self._decode_progressive
150
+
151
+ def encode_image_1(self, image, device, num_images_per_prompt):
152
+ dtype = next(self.image_encoder_1.parameters()).dtype
153
+
154
+ if not isinstance(image, torch.Tensor):
155
+ image = self.feature_extractor_1(image, return_tensors="pt").pixel_values
156
+
157
+ image = image.to(device=device, dtype=dtype)
158
+ image_embeds = self.image_encoder_1(image).image_embeds
159
+ image_embeds = image_embeds.repeat_interleave(
160
+ num_images_per_prompt, dim=0
161
+ ).unsqueeze(1)
162
+ uncond_image_embeds = torch.zeros_like(image_embeds)
163
+
164
+ return image_embeds, uncond_image_embeds
165
+
166
+ def encode_image_2(
167
+ self,
168
+ image_one,
169
+ image_two,
170
+ mask,
171
+ device,
172
+ num_images_per_prompt,
173
+ ):
174
+ dtype = next(self.image_encoder_2.parameters()).dtype
175
+
176
+ images = [image_one, image_two, mask]
177
+ images_new = []
178
+ for i, image in enumerate(images):
179
+ if not isinstance(image, torch.Tensor):
180
+ if i <= 1:
181
+ images_new.append(
182
+ self.feature_extractor_2(
183
+ image, return_tensors="pt"
184
+ ).pixel_values
185
+ )
186
+ else:
187
+ image = [
188
+ torch.from_numpy(
189
+ (np.array(im) / 255.0).astype(np.float32)
190
+ ).unsqueeze(0)
191
+ for im in image
192
+ ]
193
+ image = torch.stack(image, dim=0)
194
+ images_new.append(
195
+ F.interpolate(
196
+ image, size=images_new[0].shape[-2:], mode="nearest"
197
+ )
198
+ )
199
+
200
+ image = torch.cat(images_new, dim=1).to(device=device, dtype=dtype)
201
+ image_embeds = self.image_encoder_2(image).last_hidden_state
202
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
203
+ uncond_image_embeds = torch.zeros_like(image_embeds)
204
+
205
+ return image_embeds, uncond_image_embeds
206
+
207
+ def prepare_latents(
208
+ self,
209
+ batch_size,
210
+ num_tokens,
211
+ num_channels_latents,
212
+ dtype,
213
+ device,
214
+ generator,
215
+ latents: Optional[torch.Tensor] = None,
216
+ ):
217
+ if latents is not None:
218
+ return latents.to(device=device, dtype=dtype)
219
+
220
+ shape = (batch_size, num_tokens, num_channels_latents)
221
+
222
+ if isinstance(generator, list) and len(generator) != batch_size:
223
+ raise ValueError(
224
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
225
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
226
+ )
227
+
228
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
229
+
230
+ return latents
231
+
232
+ @torch.no_grad()
233
+ def decode_latents(
234
+ self,
235
+ latents: torch.Tensor,
236
+ sampled_points: torch.Tensor,
237
+ decode_progressive: bool = False,
238
+ decode_to_cpu: bool = False,
239
+ # Params for sampling points
240
+ bbox_min: np.ndarray = np.array([-1.005, -1.005, -1.005]),
241
+ bbox_max: np.ndarray = np.array([1.005, 1.005, 1.005]),
242
+ octree_depth: int = 8,
243
+ indexing: str = "ij",
244
+ padding: float = 0.05,
245
+ ):
246
+ device, dtype = latents.device, latents.dtype
247
+ batch_size = latents.shape[0]
248
+
249
+ grid_sizes, bbox_sizes, bbox_mins, bbox_maxs = [], [], [], []
250
+
251
+ if sampled_points is None:
252
+ sampled_points, grid_size, bbox_size = generate_dense_grid_points(
253
+ bbox_min, bbox_max, octree_depth, indexing
254
+ )
255
+ sampled_points = torch.FloatTensor(sampled_points).to(
256
+ device=device, dtype=dtype
257
+ )
258
+ sampled_points = sampled_points.unsqueeze(0).expand(batch_size, -1, -1)
259
+
260
+ grid_sizes.append(grid_size)
261
+ bbox_sizes.append(bbox_size)
262
+ bbox_mins.append(bbox_min)
263
+ bbox_maxs.append(bbox_max)
264
+
265
+ self.vae: TripoSGVAEModel
266
+ output = self.vae.decode(
267
+ latents, sampled_points=sampled_points, to_cpu=decode_to_cpu
268
+ ).sample
269
+
270
+ if not decode_progressive:
271
+ return (output, grid_sizes, bbox_sizes, bbox_mins, bbox_maxs)
272
+
273
+ grid_sizes, bbox_sizes, bbox_mins, bbox_maxs = [], [], [], []
274
+ sampled_points_list = []
275
+
276
+ for i in range(batch_size):
277
+ sdf_ = output[i].squeeze(-1) # [num_points]
278
+ sampled_points_ = sampled_points[i]
279
+ occupied_points = sampled_points_[sdf_ <= 0] # [num_occupied_points, 3]
280
+
281
+ if occupied_points.shape[0] == 0:
282
+ logger.warning(
283
+ f"No occupied points found in batch {i}. Using original bounding box."
284
+ )
285
+ else:
286
+ bbox_min = occupied_points.min(dim=0).values
287
+ bbox_max = occupied_points.max(dim=0).values
288
+ bbox_min = (bbox_min - padding).float().cpu().numpy()
289
+ bbox_max = (bbox_max + padding).float().cpu().numpy()
290
+
291
+ sampled_points_, grid_size, bbox_size = generate_dense_grid_points(
292
+ bbox_min, bbox_max, octree_depth, indexing
293
+ )
294
+ sampled_points_ = torch.FloatTensor(sampled_points_).to(
295
+ device=device, dtype=dtype
296
+ )
297
+ sampled_points_list.append(sampled_points_)
298
+
299
+ grid_sizes.append(grid_size)
300
+ bbox_sizes.append(bbox_size)
301
+ bbox_mins.append(bbox_min)
302
+ bbox_maxs.append(bbox_max)
303
+
304
+ sampled_points = torch.stack(sampled_points_list, dim=0)
305
+
306
+ # Re-decode the new sampled points
307
+ output = self.vae.decode(
308
+ latents, sampled_points=sampled_points, to_cpu=decode_to_cpu
309
+ ).sample
310
+
311
+ return (output, grid_sizes, bbox_sizes, bbox_mins, bbox_maxs)
312
+
313
+ @torch.no_grad()
314
+ def __call__(
315
+ self,
316
+ image: PipelineImageInput,
317
+ mask: PipelineImageInput,
318
+ image_scene: PipelineImageInput,
319
+ num_inference_steps: int = 50,
320
+ timesteps: List[int] = None,
321
+ guidance_scale: float = 7.0,
322
+ num_images_per_prompt: int = 1,
323
+ sampled_points: Optional[torch.Tensor] = None,
324
+ decode_progressive: bool = False,
325
+ decode_to_cpu: bool = False,
326
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
327
+ latents: Optional[torch.FloatTensor] = None,
328
+ attention_kwargs: Optional[Dict[str, Any]] = None,
329
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
330
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
331
+ output_type: Optional[str] = "mesh_vf",
332
+ return_dict: bool = True,
333
+ ):
334
+ # 1. Check inputs. Raise error if not correct
335
+ # TODO
336
+
337
+ self._decode_progressive = decode_progressive
338
+ self._guidance_scale = guidance_scale
339
+ self._attention_kwargs = attention_kwargs
340
+ self._interrupt = False
341
+
342
+ # 2. Define call parameters
343
+ if isinstance(image, PIL.Image.Image):
344
+ batch_size = 1
345
+ elif isinstance(image, list):
346
+ batch_size = len(image)
347
+ elif isinstance(image, torch.Tensor):
348
+ batch_size = image.shape[0]
349
+ else:
350
+ raise ValueError("Invalid input type for image")
351
+
352
+ device = self._execution_device
353
+
354
+ # 3. Encode condition
355
+ image_embeds_1, negative_image_embeds_1 = self.encode_image_1(
356
+ image, device, num_images_per_prompt
357
+ )
358
+ image_embeds_2, negative_image_embeds_2 = self.encode_image_2(
359
+ image, image_scene, mask, device, num_images_per_prompt
360
+ )
361
+
362
+ if self.do_classifier_free_guidance:
363
+ image_embeds_1 = torch.cat([negative_image_embeds_1, image_embeds_1], dim=0)
364
+ image_embeds_2 = torch.cat([negative_image_embeds_2, image_embeds_2], dim=0)
365
+
366
+ # 4. Prepare timesteps
367
+ timesteps, num_inference_steps = retrieve_timesteps(
368
+ self.scheduler, num_inference_steps, device, timesteps
369
+ )
370
+ num_warmup_steps = max(
371
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
372
+ )
373
+ self._num_timesteps = len(timesteps)
374
+
375
+ # 5. Prepare latent variables
376
+ num_tokens = self.transformer.config.width
377
+ num_channels_latents = self.transformer.config.in_channels
378
+ latents = self.prepare_latents(
379
+ batch_size * num_images_per_prompt,
380
+ num_tokens,
381
+ num_channels_latents,
382
+ image_embeds_1.dtype,
383
+ device,
384
+ generator,
385
+ latents,
386
+ )
387
+
388
+ # 6. Denoising loop
389
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
390
+ for i, t in enumerate(timesteps):
391
+ if self.interrupt:
392
+ continue
393
+
394
+ # expand the latents if we are doing classifier free guidance
395
+ latent_model_input = (
396
+ torch.cat([latents] * 2)
397
+ if self.do_classifier_free_guidance
398
+ else latents
399
+ )
400
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
401
+ timestep = t.expand(latent_model_input.shape[0])
402
+
403
+ noise_pred = self.transformer(
404
+ latent_model_input,
405
+ timestep,
406
+ encoder_hidden_states=image_embeds_1,
407
+ encoder_hidden_states_2=image_embeds_2,
408
+ attention_kwargs=attention_kwargs,
409
+ return_dict=False,
410
+ )[0]
411
+
412
+ # perform guidance
413
+ if self.do_classifier_free_guidance:
414
+ noise_pred_uncond, noise_pred_image = noise_pred.chunk(2)
415
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
416
+ noise_pred_image - noise_pred_uncond
417
+ )
418
+
419
+ # compute the previous noisy sample x_t -> x_t-1
420
+ latents_dtype = latents.dtype
421
+ latents = self.scheduler.step(
422
+ noise_pred, t, latents, return_dict=False
423
+ )[0]
424
+
425
+ if latents.dtype != latents_dtype:
426
+ if torch.backends.mps.is_available():
427
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
428
+ latents = latents.to(latents_dtype)
429
+
430
+ if callback_on_step_end is not None:
431
+ callback_kwargs = {}
432
+ for k in callback_on_step_end_tensor_inputs:
433
+ callback_kwargs[k] = locals()[k]
434
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
435
+
436
+ latents = callback_outputs.pop("latents", latents)
437
+ image_embeds_1 = callback_outputs.pop(
438
+ "image_embeds_1", image_embeds_1
439
+ )
440
+ negative_image_embeds_1 = callback_outputs.pop(
441
+ "negative_image_embeds_1", negative_image_embeds_1
442
+ )
443
+ image_embeds_2 = callback_outputs.pop(
444
+ "image_embeds_2", image_embeds_2
445
+ )
446
+ negative_image_embeds_2 = callback_outputs.pop(
447
+ "negative_image_embeds_2", negative_image_embeds_2
448
+ )
449
+
450
+ # call the callback, if provided
451
+ if i == len(timesteps) - 1 or (
452
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
453
+ ):
454
+ progress_bar.update()
455
+
456
+ grid_sizes, bbox_sizes, bbox_mins, bbox_maxs = None, None, None, None
457
+
458
+ if output_type == "latent":
459
+ output = latents
460
+ else:
461
+ output, grid_sizes, bbox_sizes, bbox_mins, bbox_maxs = self.decode_latents(
462
+ latents,
463
+ sampled_points=sampled_points,
464
+ decode_progressive=decode_progressive,
465
+ decode_to_cpu=decode_to_cpu,
466
+ )
467
+
468
+ # Offload all models
469
+ self.maybe_free_model_hooks()
470
+
471
+ if not return_dict:
472
+ return (output, grid_sizes, bbox_sizes, bbox_mins, bbox_maxs)
473
+
474
+ return TripoSGPipelineOutput(
475
+ samples=output,
476
+ grid_sizes=grid_sizes,
477
+ bbox_sizes=bbox_sizes,
478
+ bbox_mins=bbox_mins,
479
+ bbox_maxs=bbox_maxs,
480
+ )
481
+
482
+ def _init_custom_adapter(
483
+ self, set_self_attn_module_names: Optional[List[str]] = None
484
+ ):
485
+ # Set attention processor
486
+ func_default = lambda name, hs, cad, ap: MIAttnProcessor2_0(use_mi=False)
487
+ set_transformer_attn_processor( # avoid warning
488
+ self.transformer,
489
+ set_self_attn_proc_func=func_default,
490
+ set_cross_attn_1_proc_func=func_default,
491
+ set_cross_attn_2_proc_func=func_default,
492
+ )
493
+ set_transformer_attn_processor(
494
+ self.transformer,
495
+ set_self_attn_proc_func=lambda name, hs, cad, ap: MIAttnProcessor2_0(),
496
+ set_self_attn_module_names=set_self_attn_module_names,
497
+ )
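A hedged end-to-end sketch of calling the pipeline; the checkpoint path is a placeholder, the per-instance inputs (a crop, its mask, and the full scene image) are assumed to be equal-length lists of PIL images, and adapter setup via `CustomAdapterMixin` (midi/loaders/custom_adapter.py) is omitted here:

import torch
from PIL import Image
from midi.pipelines.pipeline_midi import MIDIPipeline

pipe = MIDIPipeline.from_pretrained(
    "path/to/midi-checkpoint",            # placeholder path, not a real repo id
    torch_dtype=torch.float16,
).to("cuda")

crops = [Image.open(f"instance_{i}_rgb.png") for i in range(3)]
masks = [Image.open(f"instance_{i}_mask.png").convert("L") for i in range(3)]  # single channel
scene = [Image.open("scene_rgb.png")] * 3   # the same scene image for every instance

out = pipe(
    image=crops,
    mask=masks,
    image_scene=scene,
    num_inference_steps=50,
    guidance_scale=7.0,
    decode_progressive=True,
)
# out.samples holds the decoded SDF values; grid_sizes / bbox_mins / bbox_maxs describe
# each instance's query lattice for a subsequent surface-extraction step.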
midi/pipelines/pipeline_triposg_output.py ADDED
@@ -0,0 +1,25 @@
1
+ from dataclasses import dataclass
2
+ from typing import List, Optional, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ from diffusers.utils import BaseOutput
7
+
8
+ PipelineBoxOutput = Union[
9
+ List[List[int]], # [[257, 257, 257], ...]
10
+ List[List[float]], # [[-1.05, -1.05, -1.05], ...]
11
+ List[np.ndarray],
12
+ ]
13
+
14
+
15
+ @dataclass
16
+ class TripoSGPipelineOutput(BaseOutput):
17
+ r"""
18
+ Output class for TripoSG pipelines.
19
+ """
20
+
21
+ samples: torch.Tensor
22
+ grid_sizes: Optional[PipelineBoxOutput] = None
23
+ bbox_sizes: Optional[PipelineBoxOutput] = None
24
+ bbox_mins: Optional[PipelineBoxOutput] = None
25
+ bbox_maxs: Optional[PipelineBoxOutput] = None
midi/pipelines/pipeline_utils.py ADDED
@@ -0,0 +1,96 @@
1
+ from diffusers.utils import logging
2
+
3
+ logger = logging.get_logger(__name__)
4
+
5
+
6
+ class TransformerDiffusionMixin:
7
+ r"""
8
+ Helper mixin for DiffusionPipeline subclasses built around a VAE and a transformer (mainly for DiT-style models).
9
+ """
10
+
11
+ def enable_vae_slicing(self):
12
+ r"""
13
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
14
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
15
+ """
16
+ self.vae.enable_slicing()
17
+
18
+ def disable_vae_slicing(self):
19
+ r"""
20
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
21
+ computing decoding in one step.
22
+ """
23
+ self.vae.disable_slicing()
24
+
25
+ def enable_vae_tiling(self):
26
+ r"""
27
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
28
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
29
+ processing larger images.
30
+ """
31
+ self.vae.enable_tiling()
32
+
33
+ def disable_vae_tiling(self):
34
+ r"""
35
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
36
+ computing decoding in one step.
37
+ """
38
+ self.vae.disable_tiling()
39
+
40
+ def fuse_qkv_projections(self, transformer: bool = True, vae: bool = True):
41
+ """
42
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
43
+ are fused. For cross-attention modules, key and value projection matrices are fused.
44
+
45
+ <Tip warning={true}>
46
+
47
+ This API is 🧪 experimental.
48
+
49
+ </Tip>
50
+
51
+ Args:
52
+ transformer (`bool`, defaults to `True`): To apply fusion on the Transformer.
53
+ vae (`bool`, defaults to `True`): To apply fusion on the VAE.
54
+ """
55
+ self.fusing_transformer = False
56
+ self.fusing_vae = False
57
+
58
+ if transformer:
59
+ self.fusing_transformer = True
60
+ self.transformer.fuse_qkv_projections()
61
+
62
+ if vae:
63
+ self.fusing_vae = True
64
+ self.vae.fuse_qkv_projections()
65
+
66
+ def unfuse_qkv_projections(self, transformer: bool = True, vae: bool = True):
67
+ """Disable QKV projection fusion if enabled.
68
+
69
+ <Tip warning={true}>
70
+
71
+ This API is 🧪 experimental.
72
+
73
+ </Tip>
74
+
75
+ Args:
76
+ transformer (`bool`, defaults to `True`): To apply fusion on the Transformer.
77
+ vae (`bool`, defaults to `True`): To apply fusion on the VAE.
78
+
79
+ """
80
+ if transformer:
81
+ if not self.fusing_transformer:
82
+ logger.warning(
83
+ "The Transformer was not initially fused for QKV projections. Doing nothing."
84
+ )
85
+ else:
86
+ self.transformer.unfuse_qkv_projections()
87
+ self.fusing_transformer = False
88
+
89
+ if vae:
90
+ if not self.fusing_vae:
91
+ logger.warning(
92
+ "The VAE was not initially fused for QKV projections. Doing nothing."
93
+ )
94
+ else:
95
+ self.vae.unfuse_qkv_projections()
96
+ self.fusing_vae = False
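A short sketch of combining these switches on an assembled pipeline; `pipe` is assumed to be a constructed `MIDIPipeline`, and `enable_vae_slicing` relies on the slicing plumbing of the VAE added earlier in this commit:

pipe.enable_vae_slicing()                               # decode latents slice by slice
pipe.fuse_qkv_projections(transformer=True, vae=False)  # fuse QKV only in the DiT

# ... run inference ...

pipe.unfuse_qkv_projections(transformer=True, vae=False)
pipe.disable_vae_slicing()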
midi/schedulers/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .scheduling_rectified_flow import (
2
+ RectifiedFlowScheduler,
3
+ compute_density_for_timestep_sampling,
4
+ compute_loss_weighting,
5
+ )
midi/schedulers/scheduling_rectified_flow.py ADDED
@@ -0,0 +1,327 @@
1
+ """
2
+ Adapted from https://github.com/huggingface/diffusers/blob/v0.30.3/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py.
3
+ """
4
+
5
+ import math
6
+ from dataclasses import dataclass
7
+ from typing import List, Optional, Tuple, Union
8
+
9
+ import numpy as np
10
+ import torch
11
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
13
+ from diffusers.utils import BaseOutput, logging
14
+ from torch.distributions import LogisticNormal
15
+
16
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
17
+
18
+
19
+ # TODO: may move to training_utils.py
20
+ def compute_density_for_timestep_sampling(
21
+ weighting_scheme: str,
22
+ batch_size: int,
23
+ logit_mean: float = 0.0,
24
+ logit_std: float = 1.0,
25
+ mode_scale: Optional[float] = None,
26
+ ):
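+ """
+ Sample a value `u` in [0, 1] per batch element, used to pick training timesteps.
+ Schemes follow SD3: "logit_normal", "logit_normal_dist" (logistic-normal), "mode", or
+ uniform sampling for any other value of `weighting_scheme`.
+ """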
27
+ if weighting_scheme == "logit_normal":
28
+ # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
29
+ u = torch.normal(
30
+ mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu"
31
+ )
32
+ u = torch.nn.functional.sigmoid(u)
33
+ elif weighting_scheme == "logit_normal_dist":
34
+ u = (
35
+ LogisticNormal(loc=logit_mean, scale=logit_std)
36
+ .sample((batch_size,))[:, 0]
37
+ .to("cpu")
38
+ )
39
+ elif weighting_scheme == "mode":
40
+ u = torch.rand(size=(batch_size,), device="cpu")
41
+ u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
42
+ else:
43
+ u = torch.rand(size=(batch_size,), device="cpu")
44
+ return u
45
+
46
+
47
+ def compute_loss_weighting(weighting_scheme: str, sigmas=None):
48
+ """
49
+ Computes loss weighting scheme for SD3 training.
50
+
51
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
52
+
53
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
54
+ """
55
+ if weighting_scheme == "sigma_sqrt":
56
+ weighting = (sigmas**-2.0).float()
57
+ elif weighting_scheme == "cosmap":
58
+ bot = 1 - 2 * sigmas + 2 * sigmas**2
59
+ weighting = 2 / (math.pi * bot)
60
+ else:
61
+ weighting = torch.ones_like(sigmas)
62
+ return weighting
63
+
64
+
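+ # Illustrative training-side usage (assumed; `model`, `latents` and `target` are placeholders,
+ # and `scheduler` is a RectifiedFlowScheduler as defined below):
+ #   u = compute_density_for_timestep_sampling("logit_normal", batch_size=latents.shape[0])
+ #   timesteps = (u * scheduler.config.num_train_timesteps).to(latents.device)
+ #   noise = torch.randn_like(latents)
+ #   noisy_latents = scheduler.scale_noise(latents, noise, timesteps)
+ #   pred = model(noisy_latents, timesteps)
+ #   weights = compute_loss_weighting("cosmap", sigmas=timesteps / scheduler.config.num_train_timesteps)
+ #   per_sample = ((pred - target) ** 2).reshape(pred.shape[0], -1).mean(dim=1)
+ #   loss = (weights * per_sample).mean()
+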
65
+ @dataclass
66
+ class RectifiedFlowSchedulerOutput(BaseOutput):
67
+ """
68
+ Output class for the scheduler's `step` function output.
69
+
70
+ Args:
71
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
72
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
73
+ denoising loop.
74
+ """
75
+
76
+ prev_sample: torch.FloatTensor
77
+
78
+
79
+ class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin):
80
+ """
81
+ A scheduler that propagates the denoising process under the rectified-flow (flow-matching) formulation.
82
+
83
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
84
+ methods the library implements for all schedulers such as loading and saving.
85
+
86
+ Args:
87
+ num_train_timesteps (`int`, defaults to 1000):
88
+ The number of diffusion steps to train the model.
89
+ timestep_spacing (`str`, defaults to `"linspace"`):
90
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
91
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
92
+ shift (`float`, defaults to 1.0):
93
+ The shift value for the timestep schedule.
94
+ """
95
+
96
+ _compatibles = []
97
+ order = 1
98
+
99
+ @register_to_config
100
+ def __init__(
101
+ self,
102
+ num_train_timesteps: int = 1000,
103
+ shift: float = 1.0,
104
+ use_dynamic_shifting: bool = False,
105
+ ):
106
+ # Pre-compute timesteps and sigmas (not actually used during training).
107
+ # NOTE: shape diffusion samples training timesteps randomly or from a distribution
108
+ # instead of drawing them from this pre-defined linspace.
109
+ timesteps = np.array(
110
+ [
111
+ (1.0 - i / num_train_timesteps) * num_train_timesteps
112
+ for i in range(num_train_timesteps)
113
+ ]
114
+ )
115
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
116
+
117
+ sigmas = timesteps / num_train_timesteps
118
+ if not use_dynamic_shifting:
119
+ # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
120
+ sigmas = self.time_shift(sigmas)
121
+
122
+ self.timesteps = sigmas * num_train_timesteps
123
+
124
+ self._step_index = None
125
+ self._begin_index = None
126
+
127
+ self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
128
+
129
+ @property
130
+ def step_index(self):
131
+ """
132
+ The index counter for current timestep. It will increase 1 after each scheduler step.
133
+ """
134
+ return self._step_index
135
+
136
+ @property
137
+ def begin_index(self):
138
+ """
139
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
140
+ """
141
+ return self._begin_index
142
+
143
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
144
+ def set_begin_index(self, begin_index: int = 0):
145
+ """
146
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
147
+
148
+ Args:
149
+ begin_index (`int`):
150
+ The begin index for the scheduler.
151
+ """
152
+ self._begin_index = begin_index
153
+
154
+ def _sigma_to_t(self, sigma):
155
+ return sigma * self.config.num_train_timesteps
156
+
157
+ def _t_to_sigma(self, timestep):
158
+ return timestep / self.config.num_train_timesteps
159
+
160
+ def time_shift_dynamic(self, mu: float, sigma: float, t: torch.Tensor):
161
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
162
+
163
+ def time_shift(self, t: torch.Tensor):
164
+ return self.config.shift * t / (1 + (self.config.shift - 1) * t)
165
+
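+ # Note on `time_shift` above: with shift = 3, a mid-schedule sigma of 0.5 maps to
+ # 3 * 0.5 / (1 + 2 * 0.5) = 0.75, so more of the evenly spaced inference steps are spent
+ # at high noise levels.
+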
166
+ def set_timesteps(
167
+ self,
168
+ num_inference_steps: int = None,
169
+ device: Union[str, torch.device] = None,
170
+ sigmas: Optional[List[float]] = None,
171
+ mu: Optional[float] = None,
172
+ ):
173
+ """
174
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
175
+
176
+ Args:
177
+ num_inference_steps (`int`):
178
+ The number of diffusion steps used when generating samples with a pre-trained model.
179
+ device (`str` or `torch.device`, *optional*):
180
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
181
+ """
182
+
183
+ if self.config.use_dynamic_shifting and mu is None:
184
+ raise ValueError(
185
+ "You have to pass a value for `mu` when `use_dynamic_shifting` is set to `True`."
186
+ )
187
+
188
+ if sigmas is None:
189
+ self.num_inference_steps = num_inference_steps
190
+ timesteps = np.array(
191
+ [
192
+ (1.0 - i / num_inference_steps) * self.config.num_train_timesteps
193
+ for i in range(num_inference_steps)
194
+ ]
195
+ ) # different from the original code in SD3
196
+ sigmas = timesteps / self.config.num_train_timesteps
197
+
198
+ if self.config.use_dynamic_shifting:
199
+ sigmas = self.time_shift_dynamic(mu, 1.0, sigmas)
200
+ else:
201
+ sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
202
+
203
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
204
+ timesteps = sigmas * self.config.num_train_timesteps
205
+
206
+ self.timesteps = timesteps.to(device=device)
207
+ self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
208
+
209
+ self._step_index = None
210
+ self._begin_index = None
211
+
212
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
213
+ if schedule_timesteps is None:
214
+ schedule_timesteps = self.timesteps
215
+
216
+ indices = (schedule_timesteps == timestep).nonzero()
217
+
218
+ # The sigma index that is taken for the **very** first `step`
219
+ # is always the second index (or the last index if there is only 1)
220
+ # This way we can ensure we don't accidentally skip a sigma in
221
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
222
+ pos = 1 if len(indices) > 1 else 0
223
+
224
+ return indices[pos].item()
225
+
226
+ def _init_step_index(self, timestep):
227
+ if self.begin_index is None:
228
+ if isinstance(timestep, torch.Tensor):
229
+ timestep = timestep.to(self.timesteps.device)
230
+ self._step_index = self.index_for_timestep(timestep)
231
+ else:
232
+ self._step_index = self._begin_index
233
+
234
+ def step(
235
+ self,
236
+ model_output: torch.FloatTensor,
237
+ timestep: Union[float, torch.FloatTensor],
238
+ sample: torch.FloatTensor,
239
+ s_churn: float = 0.0,
240
+ s_tmin: float = 0.0,
241
+ s_tmax: float = float("inf"),
242
+ s_noise: float = 1.0,
243
+ generator: Optional[torch.Generator] = None,
244
+ return_dict: bool = True,
245
+ ) -> Union[RectifiedFlowSchedulerOutput, Tuple]:
246
+ """
247
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
248
+ process from the learned model outputs (most often the predicted noise).
249
+
250
+ Args:
251
+ model_output (`torch.FloatTensor`):
252
+ The direct output from learned diffusion model.
253
+ timestep (`float`):
254
+ The current discrete timestep in the diffusion chain.
255
+ sample (`torch.FloatTensor`):
256
+ A current instance of a sample created by the diffusion process.
257
+ s_churn (`float`): Unused here; kept for API compatibility.
258
+ s_tmin (`float`): Unused here; kept for API compatibility.
259
+ s_tmax (`float`): Unused here; kept for API compatibility.
260
+ s_noise (`float`, defaults to 1.0):
261
+ Scaling factor for noise added to the sample.
262
+ generator (`torch.Generator`, *optional*):
263
+ A random number generator.
264
+ return_dict (`bool`):
265
+ Whether or not to return a [`RectifiedFlowSchedulerOutput`] or
266
+ tuple.
267
+
268
+ Returns:
269
+ [`RectifiedFlowSchedulerOutput`] or `tuple`:
270
+ If return_dict is `True`, [`RectifiedFlowSchedulerOutput`] is
271
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
272
+ """
273
+
274
+ if (
275
+ isinstance(timestep, int)
276
+ or isinstance(timestep, torch.IntTensor)
277
+ or isinstance(timestep, torch.LongTensor)
278
+ ):
279
+ raise ValueError(
280
+ (
281
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
282
+ " `RectifiedFlowScheduler.step()` is not supported. Make sure to pass"
283
+ " one of the `scheduler.timesteps` as a timestep."
284
+ ),
285
+ )
286
+
287
+ if self.step_index is None:
288
+ self._init_step_index(timestep)
289
+
290
+ # Upcast to avoid precision issues when computing prev_sample
291
+ sample = sample.to(torch.float32)
292
+
293
+ sigma = self.sigmas[self.step_index]
294
+ sigma_next = self.sigmas[self.step_index + 1]
295
+
296
+ # Sign convention: with x_sigma = (1 - sigma) * x0 + sigma * noise (see `scale_noise`),
+ # this update assumes the model predicts (x0 - noise), the opposite direction to standard flow matching.
297
+ prev_sample = sample + (sigma - sigma_next) * model_output
298
+
299
+ # Cast sample back to model compatible dtype
300
+ prev_sample = prev_sample.to(model_output.dtype)
301
+
302
+ # upon completion increase step index by one
303
+ self._step_index += 1
304
+
305
+ if not return_dict:
306
+ return (prev_sample,)
307
+
308
+ return RectifiedFlowSchedulerOutput(prev_sample=prev_sample)
309
+
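+ # Illustrative sampling loop (assumed; `transformer`, `latents` and `cond` are placeholders):
+ #   scheduler = RectifiedFlowScheduler(num_train_timesteps=1000, shift=1.0)
+ #   scheduler.set_timesteps(num_inference_steps=50, device=latents.device)
+ #   for t in scheduler.timesteps:
+ #       model_output = transformer(latents, timestep=t, cond=cond)
+ #       latents = scheduler.step(model_output, t, latents).prev_sample
+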
310
+ def scale_noise(
311
+ self,
312
+ original_samples: torch.Tensor,
313
+ noise: torch.Tensor,
314
+ timesteps: torch.IntTensor,
315
+ ) -> torch.Tensor:
316
+ """
317
+ Forward function for the noise scaling in the flow matching.
318
+ """
319
+ sigmas = self._t_to_sigma(timesteps.to(dtype=torch.float32))
320
+
321
+ while len(sigmas.shape) < len(original_samples.shape):
322
+ sigmas = sigmas.unsqueeze(-1)
323
+
324
+ return (1.0 - sigmas) * original_samples + sigmas * noise
325
+
326
+ def __len__(self):
327
+ return self.config.num_train_timesteps
midi/utils/smoothing.py ADDED
@@ -0,0 +1,615 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ """
4
+ Utilities for smoothing occupancy / signed-distance (occ/SDF) grids.
5
+ """
6
+
7
+ import logging
8
+ from typing import Tuple
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from scipy import ndimage as ndi
14
+ from scipy import sparse
15
+
16
+ __all__ = [
17
+ "smooth",
18
+ "smooth_constrained",
19
+ "smooth_gaussian",
20
+ "signed_distance_function",
21
+ "smooth_gpu",
22
+ "smooth_constrained_gpu",
23
+ "smooth_gaussian_gpu",
24
+ "signed_distance_function_gpu",
25
+ ]
26
+
27
+
28
+ def _build_variable_indices(band: np.ndarray) -> np.ndarray:
29
+ num_variables = np.count_nonzero(band)
30
+ variable_indices = np.full(band.shape, -1, dtype=np.int_)
31
+ variable_indices[band] = np.arange(num_variables)
32
+ return variable_indices
33
+
34
+
35
+ def _buildq3d(variable_indices: np.ndarray):
36
+ """
37
+ Builds the smoothness matrix for the in-band variables. Each row of the intermediate
+ `filterq` is a 1-D second-difference stencil along one axis (-2 on the center voxel,
+ +1 on each in-band neighbour, folded back into the center at the band boundary);
+ the returned matrix is `filterq.T @ filterq`.
38
+ """
39
+
40
+ num_variables = variable_indices.max() + 1
41
+ filterq = sparse.lil_matrix((3 * num_variables, num_variables))
42
+
43
+ # Pad variable_indices to simplify out-of-bounds accesses
44
+ variable_indices = np.pad(
45
+ variable_indices, [(0, 1), (0, 1), (0, 1)], mode="constant", constant_values=-1
46
+ )
47
+
48
+ coords = np.nonzero(variable_indices >= 0)
49
+ for count, (i, j, k) in enumerate(zip(*coords)):
50
+
51
+ assert variable_indices[i, j, k] == count
52
+
53
+ filterq[3 * count, count] = -2
54
+ neighbor = variable_indices[i - 1, j, k]
55
+ if neighbor >= 0:
56
+ filterq[3 * count, neighbor] = 1
57
+ else:
58
+ filterq[3 * count, count] += 1
59
+
60
+ neighbor = variable_indices[i + 1, j, k]
61
+ if neighbor >= 0:
62
+ filterq[3 * count, neighbor] = 1
63
+ else:
64
+ filterq[3 * count, count] += 1
65
+
66
+ filterq[3 * count + 1, count] = -2
67
+ neighbor = variable_indices[i, j - 1, k]
68
+ if neighbor >= 0:
69
+ filterq[3 * count + 1, neighbor] = 1
70
+ else:
71
+ filterq[3 * count + 1, count] += 1
72
+
73
+ neighbor = variable_indices[i, j + 1, k]
74
+ if neighbor >= 0:
75
+ filterq[3 * count + 1, neighbor] = 1
76
+ else:
77
+ filterq[3 * count + 1, count] += 1
78
+
79
+ filterq[3 * count + 2, count] = -2
80
+ neighbor = variable_indices[i, j, k - 1]
81
+ if neighbor >= 0:
82
+ filterq[3 * count + 2, neighbor] = 1
83
+ else:
84
+ filterq[3 * count + 2, count] += 1
85
+
86
+ neighbor = variable_indices[i, j, k + 1]
87
+ if neighbor >= 0:
88
+ filterq[3 * count + 2, neighbor] = 1
89
+ else:
90
+ filterq[3 * count + 2, count] += 1
91
+
92
+ filterq = filterq.tocsr()
93
+ return filterq.T.dot(filterq)
94
+
95
+
96
+ def _buildq3d_gpu(variable_indices: torch.Tensor, chunk_size=10000):
97
+ """
98
+ Builds the filterq matrix for the given variables on GPU, using chunking to reduce memory usage.
99
+ """
100
+ device = variable_indices.device
101
+ num_variables = variable_indices.max().item() + 1
102
+
103
+ # Pad variable_indices to simplify out-of-bounds accesses
104
+ variable_indices = torch.nn.functional.pad(
105
+ variable_indices, (0, 1, 0, 1, 0, 1), mode="constant", value=-1
106
+ )
107
+
108
+ coords = torch.nonzero(variable_indices >= 0)
109
+ i, j, k = coords[:, 0], coords[:, 1], coords[:, 2]
110
+
111
+ # Function to process a chunk of data
112
+ def process_chunk(start, end):
113
+ row_indices = []
114
+ col_indices = []
115
+ values = []
116
+
117
+ for axis in range(3):
118
+ row_indices.append(3 * torch.arange(start, end, device=device) + axis)
119
+ col_indices.append(
120
+ variable_indices[i[start:end], j[start:end], k[start:end]]
121
+ )
122
+ values.append(torch.full((end - start,), -2, device=device))
123
+
124
+ for offset in [-1, 1]:
125
+ if axis == 0:
126
+ neighbor = variable_indices[
127
+ i[start:end] + offset, j[start:end], k[start:end]
128
+ ]
129
+ elif axis == 1:
130
+ neighbor = variable_indices[
131
+ i[start:end], j[start:end] + offset, k[start:end]
132
+ ]
133
+ else:
134
+ neighbor = variable_indices[
135
+ i[start:end], j[start:end], k[start:end] + offset
136
+ ]
137
+
138
+ mask = neighbor >= 0
139
+ row_indices.append(
140
+ 3 * torch.arange(start, end, device=device)[mask] + axis
141
+ )
142
+ col_indices.append(neighbor[mask])
143
+ values.append(torch.ones(mask.sum(), device=device))
144
+
145
+ # Add 1 to the diagonal for out-of-bounds neighbors
146
+ row_indices.append(
147
+ 3 * torch.arange(start, end, device=device)[~mask] + axis
148
+ )
149
+ col_indices.append(
150
+ variable_indices[i[start:end], j[start:end], k[start:end]][~mask]
151
+ )
152
+ values.append(torch.ones((~mask).sum(), device=device))
153
+
154
+ return torch.cat(row_indices), torch.cat(col_indices), torch.cat(values)
155
+
156
+ # Process data in chunks
157
+ all_row_indices = []
158
+ all_col_indices = []
159
+ all_values = []
160
+
161
+ for start in range(0, coords.shape[0], chunk_size):
162
+ end = min(start + chunk_size, coords.shape[0])
163
+ row_indices, col_indices, values = process_chunk(start, end)
164
+ all_row_indices.append(row_indices)
165
+ all_col_indices.append(col_indices)
166
+ all_values.append(values)
167
+
168
+ # Concatenate all chunks
169
+ row_indices = torch.cat(all_row_indices)
170
+ col_indices = torch.cat(all_col_indices)
171
+ values = torch.cat(all_values)
172
+
173
+ # Create sparse tensor
174
+ indices = torch.stack([row_indices, col_indices])
175
+ filterq = torch.sparse_coo_tensor(
176
+ indices, values, (3 * num_variables, num_variables)
177
+ )
178
+
179
+ # Compute filterq.T @ filterq
180
+ return torch.sparse.mm(filterq.t(), filterq)
181
+
182
+
183
+ # Usage example:
184
+ # variable_indices = torch.tensor(...).cuda() # Your input tensor on GPU
185
+ # result = _buildq3d_gpu(variable_indices)
186
+
187
+
188
+ def _buildq2d(variable_indices: np.ndarray):
189
+ """
190
+ Builds the filterq matrix for the given variables.
191
+
192
+ Version for 2 dimensions.
193
+ """
194
+
195
+ num_variables = variable_indices.max() + 1
196
+ filterq = sparse.lil_matrix((2 * num_variables, num_variables))  # two axes in 2-D
197
+
198
+ # Pad variable_indices to simplify out-of-bounds accesses
199
+ variable_indices = np.pad(
200
+ variable_indices, [(0, 1), (0, 1)], mode="constant", constant_values=-1
201
+ )
202
+
203
+ coords = np.nonzero(variable_indices >= 0)
204
+ for count, (i, j) in enumerate(zip(*coords)):
205
+ assert variable_indices[i, j] == count
206
+
207
+ filterq[2 * count, count] = -2
208
+ neighbor = variable_indices[i - 1, j]
209
+ if neighbor >= 0:
210
+ filterq[2 * count, neighbor] = 1
211
+ else:
212
+ filterq[2 * count, count] += 1
213
+
214
+ neighbor = variable_indices[i + 1, j]
215
+ if neighbor >= 0:
216
+ filterq[2 * count, neighbor] = 1
217
+ else:
218
+ filterq[2 * count, count] += 1
219
+
220
+ filterq[2 * count + 1, count] = -2
221
+ neighbor = variable_indices[i, j - 1]
222
+ if neighbor >= 0:
223
+ filterq[2 * count + 1, neighbor] = 1
224
+ else:
225
+ filterq[2 * count + 1, count] += 1
226
+
227
+ neighbor = variable_indices[i, j + 1]
228
+ if neighbor >= 0:
229
+ filterq[2 * count + 1, neighbor] = 1
230
+ else:
231
+ filterq[2 * count + 1, count] += 1
232
+
233
+ filterq = filterq.tocsr()
234
+ return filterq.T.dot(filterq)
235
+
236
+
237
+ def _jacobi(
238
+ filterq,
239
+ x0: np.ndarray,
240
+ lower_bound: np.ndarray,
241
+ upper_bound: np.ndarray,
242
+ max_iters: int = 10,
243
+ rel_tol: float = 1e-6,
244
+ weight: float = 0.5,
245
+ ):
246
+ """Weighted Jacobi iterations for `filterq @ x = 0`, clamping `x` to `[lower_bound, upper_bound]` after every sweep."""
247
+
248
+ jacobi_r = sparse.lil_matrix(filterq)
249
+ shp = jacobi_r.shape
250
+ jacobi_d = 1.0 / filterq.diagonal()
251
+ jacobi_r.setdiag((0,) * shp[0])
252
+ jacobi_r = jacobi_r.tocsr()
253
+
254
+ x = x0
255
+
256
+ # We check the stopping criterion each 10 iterations
257
+ check_each = 10
258
+ cum_rel_tol = 1 - (1 - rel_tol) ** check_each
259
+
260
+ energy_now = np.dot(x, filterq.dot(x)) / 2
261
+ logging.info("Energy at iter %d: %.6g", 0, energy_now)
262
+ for i in range(max_iters):
263
+
264
+ x_1 = -jacobi_d * jacobi_r.dot(x)
265
+ x = weight * x_1 + (1 - weight) * x
266
+
267
+ # Constraints.
268
+ x = np.maximum(x, lower_bound)
269
+ x = np.minimum(x, upper_bound)
270
+
271
+ # Stopping criterion
272
+ if (i + 1) % check_each == 0:
273
+ # Update energy
274
+ energy_before = energy_now
275
+ energy_now = np.dot(x, filterq.dot(x)) / 2
276
+
277
+ logging.info("Energy at iter %d: %.6g", i + 1, energy_now)
278
+
279
+ # Check stopping criterion
280
+ cum_rel_improvement = (energy_before - energy_now) / energy_before
281
+ if cum_rel_improvement < cum_rel_tol:
282
+ break
283
+
284
+ return x
285
+
286
+
287
+ def signed_distance_function(
288
+ levelset: np.ndarray, band_radius: int
289
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
290
+ """
291
+ Return the distance to the 0.5 levelset of a function, the mask of the
292
+ border (i.e., the nearest cells to the 0.5 level-set) and the mask of the
293
+ band (i.e., the cells of the function whose distance to the 0.5 level-set
294
+ is less than or equal to `band_radius`).
295
+ """
296
+
297
+ binary_array = np.where(levelset > 0, True, False)
298
+
299
+ # Compute the band and the border.
300
+ dist_func = ndi.distance_transform_edt
301
+ distance = np.where(
302
+ binary_array, dist_func(binary_array) - 0.5, -dist_func(~binary_array) + 0.5
303
+ )
304
+ border = np.abs(distance) < 1
305
+ band = np.abs(distance) <= band_radius
306
+
307
+ return distance, border, band
308
+
309
+
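+ # Illustrative check (assumed usage):
+ #   vol = np.zeros((5, 5, 5), dtype=bool)
+ #   vol[2] = True  # a one-voxel-thick slab
+ #   distance, border, band = signed_distance_function(vol, band_radius=2)
+ #   # distance is +0.5 on the slab and negative outside; border = |distance| < 1;
+ #   # band = |distance| <= 2
+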
310
+ def signed_distance_function_iso0(
311
+ levelset: np.ndarray, band_radius: int
312
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
313
+ """
314
+ Return the distance to the 0 levelset of a function, the mask of the
315
+ border (i.e., the nearest cells to the 0 level-set) and the mask of the
316
+ band (i.e., the cells of the function whose distance to the 0 level-set
317
+ is less than or equal to `band_radius`).
318
+ """
319
+
320
+ binary_array = levelset > 0
321
+
322
+ # Compute the band and the border.
323
+ dist_func = ndi.distance_transform_edt
324
+ distance = np.where(
325
+ binary_array, dist_func(binary_array), -dist_func(~binary_array)
326
+ )
327
+ border = np.zeros_like(levelset, dtype=bool)
328
+ border[:-1, :, :] |= levelset[:-1, :, :] * levelset[1:, :, :] <= 0
329
+ border[:, :-1, :] |= levelset[:, :-1, :] * levelset[:, 1:, :] <= 0
330
+ border[:, :, :-1] |= levelset[:, :, :-1] * levelset[:, :, 1:] <= 0
331
+ band = np.abs(distance) <= band_radius
332
+
333
+ return distance, border, band
334
+
335
+
336
+ def signed_distance_function_gpu(levelset: torch.Tensor, band_radius: int):
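+ """
+ Cheap GPU stand-in for `signed_distance_function`: a single 3x3x3 max-pooling pass is used
+ instead of an exact Euclidean distance transform, so `distance` mostly encodes whether a
+ cell lies next to the interface rather than a true signed distance.
+ """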
337
+ binary_array = (levelset > 0).float()
338
+
339
+ # Compute distance transform
340
+ dist_pos = (
341
+ F.max_pool3d(
342
+ -binary_array.unsqueeze(0).unsqueeze(0), kernel_size=3, stride=1, padding=1
343
+ )
344
+ .squeeze(0)
345
+ .squeeze(0)
346
+ + binary_array
347
+ )
348
+ dist_neg = F.max_pool3d(
349
+ (binary_array - 1).unsqueeze(0).unsqueeze(0), kernel_size=3, stride=1, padding=1
350
+ ).squeeze(0).squeeze(0) + (1 - binary_array)
351
+
352
+ distance = torch.where(binary_array > 0, dist_pos - 0.5, -dist_neg + 0.5)
353
+
354
+ # breakpoint()
355
+
356
+ # Use levelset as distance directly
357
+ # distance = levelset
358
+ # print(distance.shape)
359
+ # Compute border and band
360
+ border = torch.abs(distance) < 1
361
+ band = torch.abs(distance) <= band_radius
362
+
363
+ return distance, border, band
364
+
365
+
366
+ def smooth_constrained(
367
+ binary_array: np.ndarray,
368
+ band_radius: int = 4,
369
+ max_iters: int = 250,
370
+ rel_tol: float = 1e-6,
371
+ ) -> np.ndarray:
372
+ """
373
+ Implementation of the smoothing method from
374
+
375
+ "Surface Extraction from Binary Volumes with Higher-Order Smoothness"
376
+ Victor Lempitsky, CVPR10
377
+ """
378
+
379
+ # # Compute the distance map, the border and the band.
380
+ logging.info("Computing distance transform...")
381
+ # distance, _, band = signed_distance_function(binary_array, band_radius)
382
+ binary_array_gpu = torch.from_numpy(binary_array).cuda()
383
+ distance, _, band = signed_distance_function_gpu(binary_array_gpu, band_radius)
384
+ distance = distance.cpu().numpy()
385
+ band = band.cpu().numpy()
386
+
387
+ variable_indices = _build_variable_indices(band)
388
+
389
+ # Compute filterq.
390
+ logging.info("Building matrix filterq...")
391
+ if binary_array.ndim == 3:
392
+ filterq = _buildq3d(variable_indices)
393
+ # variable_indices_gpu = torch.from_numpy(variable_indices).cuda()
394
+ # filterq_gpu = _buildq3d_gpu(variable_indices_gpu)
395
+ # filterq = filterq_gpu.cpu().numpy()
396
+ elif binary_array.ndim == 2:
397
+ filterq = _buildq2d(variable_indices)
398
+ else:
399
+ raise ValueError("binary_array.ndim not in [2, 3]")
400
+
401
+ # Initialize the variables.
402
+ res = np.asarray(distance, dtype=np.double)
403
+ x = res[band]
404
+ upper_bound = np.where(x < 0, x, np.inf)
405
+ lower_bound = np.where(x > 0, x, -np.inf)
406
+
407
+ upper_bound[np.abs(upper_bound) < 1] = 0
408
+ lower_bound[np.abs(lower_bound) < 1] = 0
409
+
410
+ # Solve.
411
+ logging.info("Minimizing energy...")
412
+ x = _jacobi(
413
+ filterq=filterq,
414
+ x0=x,
415
+ lower_bound=lower_bound,
416
+ upper_bound=upper_bound,
417
+ max_iters=max_iters,
418
+ rel_tol=rel_tol,
419
+ )
420
+
421
+ res[band] = x
422
+ return res
423
+
424
+
425
+ def total_variation_denoising(x, weight=0.1, num_iterations=5, eps=1e-8):
426
+ diff_x = torch.diff(x, dim=0, prepend=x[:1])
427
+ diff_y = torch.diff(x, dim=1, prepend=x[:, :1])
428
+ diff_z = torch.diff(x, dim=2, prepend=x[:, :, :1])
429
+
430
+ norm = torch.sqrt(diff_x**2 + diff_y**2 + diff_z**2 + eps)
431
+
432
+ div_x = torch.diff(diff_x / norm, dim=0, append=diff_x[-1:] / norm[-1:])
433
+ div_y = torch.diff(diff_y / norm, dim=1, append=diff_y[:, -1:] / norm[:, -1:])
434
+ div_z = torch.diff(diff_z / norm, dim=2, append=diff_z[:, :, -1:] / norm[:, :, -1:])
435
+
436
+ return x - weight * (div_x + div_y + div_z)
437
+
438
+
439
+ def smooth_constrained_gpu(
440
+ binary_array: torch.Tensor,
441
+ band_radius: int = 4,
442
+ max_iters: int = 250,
443
+ rel_tol: float = 1e-4,
444
+ ):
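+ """
+ GPU variant of `smooth_constrained`: rather than minimizing the banded smoothness energy
+ with constrained Jacobi iterations, it iterates total-variation denoising and clamps the
+ result to the sign constraints derived from the input volume.
+ """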
445
+ distance, _, band = signed_distance_function_gpu(binary_array, band_radius)
446
+
447
+ # Initialize variables
448
+ x = distance[band]
449
+ upper_bound = torch.where(x < 0, x, torch.tensor(float("inf"), device=x.device))
450
+ lower_bound = torch.where(x > 0, x, torch.tensor(float("-inf"), device=x.device))
451
+
452
+ upper_bound[torch.abs(upper_bound) < 1] = 0
453
+ lower_bound[torch.abs(lower_bound) < 1] = 0
454
+
455
+ # Define the 3D Laplacian kernel
456
+ laplacian_kernel = torch.tensor(
457
+ [
458
+ [
459
+ [
460
+ [[0, 1, 0], [1, -6, 1], [0, 1, 0]],
461
+ [[1, 0, 1], [0, 0, 0], [1, 0, 1]],
462
+ [[0, 1, 0], [1, 0, 1], [0, 1, 0]],
463
+ ]
464
+ ]
465
+ ],
466
+ device=x.device,
467
+ ).float()
468
+
469
+ laplacian_kernel = laplacian_kernel / laplacian_kernel.abs().sum()
470
+
471
+ breakpoint()
472
+
473
+ # Simplified Jacobi iteration
474
+ for i in range(max_iters):
475
+ # Reshape x to 5D tensor (batch, channel, depth, height, width)
476
+ x_5d = x.view(1, 1, *band.shape)
477
+ x_3d = x.view(*band.shape)
478
+
479
+ # Apply 3D convolution
480
+ laplacian = F.conv3d(x_5d, laplacian_kernel, padding=1)
481
+
482
+ # Reshape back to original dimensions
483
+ laplacian = laplacian.view(x.shape)
484
+
485
+ # The Laplacian relaxation update is kept for reference but disabled; total-variation denoising is used instead.
486
+ relaxation_factor = 0.1
487
+ tv_weight = 0.1
488
+ # x_new = x + relaxation_factor * laplacian
489
+ x_new = total_variation_denoising(x_3d, weight=tv_weight)
490
+ # Print laplacian min and max
491
+ # print(f"Laplacian min: {laplacian.min().item():.4f}, max: {laplacian.max().item():.4f}")
492
+
493
+ # Apply constraints
494
+ # Reshape x_new to match the dimensions of lower_bound and upper_bound
495
+ x_new = x_new.view(x.shape)
496
+ x_new = torch.clamp(x_new, min=lower_bound, max=upper_bound)
497
+
498
+ # Check for convergence
499
+ diff_norm = torch.norm(x_new - x)
500
+ logging.debug("smooth_constrained_gpu iter %d: update norm %.6g", i, diff_norm.item())
501
+ x_norm = torch.norm(x)
502
+
503
+ if x_norm > 1e-8: # Avoid division by very small numbers
504
+ relative_change = diff_norm / x_norm
505
+ if relative_change < rel_tol:
506
+ break
507
+ elif diff_norm < rel_tol: # If x_norm is very small, check absolute change
508
+ break
509
+
510
+ x = x_new
511
+
512
+ # Check for NaN and break if found, also check for inf
513
+ if torch.isnan(x).any() or torch.isinf(x).any():
514
+ logging.warning("NaN or Inf detected at iteration %d; stopping early", i)
515
516
+ break
517
+
518
+ result = distance.clone()
519
+ result[band] = x
520
+ return result
521
+
522
+
523
+ def smooth_gaussian(binary_array: np.ndarray, sigma: float = 3) -> np.ndarray:
524
+ vol = np.asarray(binary_array, dtype=np.float64) - 0.5  # np.float_ was removed in NumPy 2.0
525
+ return ndi.gaussian_filter(vol, sigma=sigma)
526
+
527
+
528
+ def smooth_gaussian_gpu(binary_array: torch.Tensor, sigma: float = 3):
529
+ # vol = binary_array.float()
530
+ vol = binary_array
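+ # NOTE: the kernel below is a uniform (box) filter of width ~2*sigma+1, a fast GPU
+ # approximation of the Gaussian blur used in `smooth_gaussian`.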
531
+ kernel_size = int(2 * sigma + 1)
532
+ kernel = torch.ones(
533
+ 1,
534
+ 1,
535
+ kernel_size,
536
+ kernel_size,
537
+ kernel_size,
538
+ device=binary_array.device,
539
+ dtype=vol.dtype,
540
+ ) / (kernel_size**3)
541
+ return F.conv3d(
542
+ vol.unsqueeze(0).unsqueeze(0), kernel, padding=kernel_size // 2
543
+ ).squeeze()
544
+
545
+
546
+ def smooth(binary_array: np.ndarray, method: str = "auto", **kwargs) -> np.ndarray:
547
+ """
548
+ Smooths the 0.5 level-set of a binary array. Returns a floating-point
549
+ array with a smoothed version of the original level-set at the 0 isovalue.
550
+
551
+ This function can apply two different methods:
552
+
553
+ - A constrained smoothing method which preserves details and fine
554
+ structures, but it is slow and requires a large amount of memory. This
555
+ method is recommended when the input array is small (smaller than
556
+ (500, 500, 500)).
557
+ - A Gaussian filter applied over the binary array. This method is fast, but
558
+ not very precise, as it can destroy fine details. It is only recommended
559
+ when the input array is large and the 0.5 level-set does not contain
560
+ thin structures.
561
+
562
+ Parameters
563
+ ----------
564
+ binary_array : ndarray
565
+ Input binary array with the 0.5 level-set to smooth.
566
+ method : str, one of ['auto', 'gaussian', 'constrained']
567
+ Smoothing method. If 'auto' is given, the method will be automatically
568
+ chosen based on the size of `binary_array`.
569
+
570
+ Parameters for 'gaussian'
571
+ -------------------------
572
+ sigma : float
573
+ Size of the Gaussian filter (default 3).
574
+
575
+ Parameters for 'constrained'
576
+ ----------------------------
577
+ max_iters : positive integer
578
+ Number of iterations of the constrained optimization method
579
+ (default 250).
580
+ rel_tol: float
581
+ Relative tolerance as a stopping criterion (default 1e-6).
582
+
583
+ Output
584
+ ------
585
+ res : ndarray
586
+ Floating-point array with a smoothed 0 level-set.
587
+ """
588
+
589
+ binary_array = np.asarray(binary_array)
590
+
591
+ if method == "auto":
592
+ if binary_array.size > 512**3:
593
+ method = "gaussian"
594
+ else:
595
+ method = "constrained"
596
+
597
+ if method == "gaussian":
598
+ return smooth_gaussian(binary_array, **kwargs)
599
+
600
+ if method == "constrained":
601
+ return smooth_constrained(binary_array, **kwargs)
602
+
603
+ raise ValueError("Unknown method '{}'".format(method))
604
+
605
+
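+ # Illustrative usage (assumed; scikit-image is not a dependency of this module):
+ #   occupancy = np.zeros((64, 64, 64), dtype=bool)
+ #   occupancy[16:48, 16:48, 16:48] = True      # a solid cube
+ #   level = smooth(occupancy, method="gaussian", sigma=2)
+ #   verts, faces, normals, values = skimage.measure.marching_cubes(level, level=0.0)
+ # The "constrained" path additionally requires CUDA, since `smooth_constrained` moves the
+ # array to the GPU for the distance computation.
+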
606
+ def smooth_gpu(binary_array: torch.Tensor, method: str = "auto", **kwargs):
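+ """GPU counterpart of `smooth`: dispatches to `smooth_gaussian_gpu` (box-filter
+ approximation) or `smooth_constrained_gpu` based on the number of voxels."""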
607
+ if method == "auto":
608
+ method = "gaussian" if binary_array.numel() > 512**3 else "constrained"
609
+
610
+ if method == "gaussian":
611
+ return smooth_gaussian_gpu(binary_array, **kwargs)
612
+ elif method == "constrained":
613
+ return smooth_constrained_gpu(binary_array, **kwargs)
614
+ else:
615
+ raise ValueError(f"Unknown method '{method}'")