Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
c9724af
0
Parent(s):
update
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +35 -0
- README.md +14 -0
- app.py +334 -0
- assets/example_data/3D-Front/ffb067ad-cf9a-4321-82ae-4e684c59ea3e_KidsRoom-5300_rgb.png +0 -0
- assets/example_data/3D-Front/ffb067ad-cf9a-4321-82ae-4e684c59ea3e_KidsRoom-5300_seg.png +0 -0
- assets/example_data/3D-Front/ffd98024-7200-429e-8b9a-1234a5937826_LivingRoom-360_rgb.png +0 -0
- assets/example_data/3D-Front/ffd98024-7200-429e-8b9a-1234a5937826_LivingRoom-360_seg.png +0 -0
- assets/example_data/3D-Front/fff98d42-99a4-43fc-9639-5761cb4f87df_SecondBedroom-127961_rgb.png +0 -0
- assets/example_data/3D-Front/fff98d42-99a4-43fc-9639-5761cb4f87df_SecondBedroom-127961_seg.png +0 -0
- assets/example_data/Cartoon-Style/00_rgb.png +0 -0
- assets/example_data/Cartoon-Style/00_seg.png +0 -0
- assets/example_data/Cartoon-Style/01_rgb.png +0 -0
- assets/example_data/Cartoon-Style/01_seg.png +0 -0
- assets/example_data/Cartoon-Style/02_rgb.png +0 -0
- assets/example_data/Cartoon-Style/02_seg.png +0 -0
- assets/example_data/Cartoon-Style/03_rgb.png +0 -0
- assets/example_data/Cartoon-Style/03_seg.png +0 -0
- assets/example_data/Cartoon-Style/04_rgb.png +0 -0
- assets/example_data/Cartoon-Style/04_seg.png +0 -0
- assets/example_data/Realistic-Style/00_rgb.png +0 -0
- assets/example_data/Realistic-Style/00_seg.png +0 -0
- assets/example_data/Realistic-Style/01_rgb.png +0 -0
- assets/example_data/Realistic-Style/01_seg.png +0 -0
- assets/example_data/Realistic-Style/02_rgb.png +0 -0
- assets/example_data/Realistic-Style/02_seg.png +0 -0
- assets/example_data/Realistic-Style/03_rgb.png +0 -0
- assets/example_data/Realistic-Style/03_seg.png +0 -0
- assets/example_data/Realistic-Style/04_rgb.png +0 -0
- assets/example_data/Realistic-Style/04_seg.png +0 -0
- assets/example_data/Realistic-Style/05_rgb.png +0 -0
- assets/example_data/Realistic-Style/05_seg.png +0 -0
- assets/example_data/Realistic-Style/06_rgb.png +0 -0
- assets/example_data/Realistic-Style/06_seg.png +0 -0
- midi/inference_utils.py +22 -0
- midi/loaders/__init__.py +1 -0
- midi/loaders/custom_adapter.py +99 -0
- midi/models/attention_processor.py +412 -0
- midi/models/autoencoders/__init__.py +1 -0
- midi/models/autoencoders/autoencoder_kl_triposg.py +541 -0
- midi/models/autoencoders/vae.py +69 -0
- midi/models/embeddings.py +96 -0
- midi/models/transformers/__init__.py +61 -0
- midi/models/transformers/modeling_outputs.py +8 -0
- midi/models/transformers/triposg_transformer.py +690 -0
- midi/pipelines/pipeline_midi.py +497 -0
- midi/pipelines/pipeline_triposg_output.py +25 -0
- midi/pipelines/pipeline_utils.py +96 -0
- midi/schedulers/__init__.py +5 -0
- midi/schedulers/scheduling_rectified_flow.py +327 -0
- midi/utils/smoothing.py +615 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: MIDI 3D
|
3 |
+
emoji: 📚
|
4 |
+
colorFrom: purple
|
5 |
+
colorTo: red
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 4.44.1
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: apache-2.0
|
11 |
+
short_description: Image to Compositional 3D Scene Generation
|
12 |
+
---
|
13 |
+
|
14 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
import tempfile
|
5 |
+
from typing import Any, List, Union
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
import numpy as np
|
9 |
+
import spaces
|
10 |
+
import torch
|
11 |
+
import trimesh
|
12 |
+
from gradio_image_prompter import ImagePrompter
|
13 |
+
from gradio_litmodel3d import LitModel3D
|
14 |
+
from huggingface_hub import snapshot_download
|
15 |
+
from PIL import Image
|
16 |
+
from skimage import measure
|
17 |
+
from transformers import AutoModelForMaskGeneration, AutoProcessor
|
18 |
+
|
19 |
+
from midi.pipelines.pipeline_midi import MIDIPipeline
|
20 |
+
from midi.utils.smoothing import smooth_gpu
|
21 |
+
from scripts.grounding_sam import plot_segmentation, segment
|
22 |
+
from scripts.inference_midi import preprocess_image, split_rgb_mask
|
23 |
+
|
24 |
+
# Constants
|
25 |
+
MAX_SEED = np.iinfo(np.int32).max
|
26 |
+
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
|
27 |
+
DTYPE = torch.bfloat16
|
28 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
29 |
+
REPO_ID = "VAST-AI/MIDI-3D"
|
30 |
+
|
31 |
+
MARKDOWN = """
|
32 |
+
## Image to 3D Scene with [MIDI-3D](https://huanngzh.github.io/MIDI-Page/)
|
33 |
+
<b>Important!</b> Please check out our [instruction video](https://github.com/user-attachments/assets/4fc8aea4-010f-40c7-989d-6b1d9d3e3e09)!
|
34 |
+
1. Upload an image, and draw bounding boxes for each instance by holding and dragging the mouse. Then clik "Run Segmentation" to generate the segmentation result. <b>Ensure instances should not be too small and bounding boxes fit snugly around each instance.</b>
|
35 |
+
2. <b>Check "Do image padding" in "Generation Settings" if instances in your image are too close to the image border.</b> Then click "Run Generation" to generate a 3D scene from the image and segmentation result.
|
36 |
+
3. If you find the generated 3D scene satisfactory, download it by clicking the "Download GLB" button.
|
37 |
+
"""
|
38 |
+
|
39 |
+
EXAMPLES = [
|
40 |
+
[
|
41 |
+
{
|
42 |
+
"image": "assets/example_data/Cartoon-Style/03_rgb.png",
|
43 |
+
},
|
44 |
+
"assets/example_data/Cartoon-Style/03_seg.png",
|
45 |
+
42,
|
46 |
+
False,
|
47 |
+
False,
|
48 |
+
],
|
49 |
+
[
|
50 |
+
{
|
51 |
+
"image": "assets/example_data/Cartoon-Style/01_rgb.png",
|
52 |
+
},
|
53 |
+
"assets/example_data/Cartoon-Style/01_seg.png",
|
54 |
+
42,
|
55 |
+
False,
|
56 |
+
False,
|
57 |
+
],
|
58 |
+
[
|
59 |
+
{
|
60 |
+
"image": "assets/example_data/Realistic-Style/02_rgb.png",
|
61 |
+
},
|
62 |
+
"assets/example_data/Realistic-Style/02_seg.png",
|
63 |
+
42,
|
64 |
+
False,
|
65 |
+
False,
|
66 |
+
],
|
67 |
+
[
|
68 |
+
{
|
69 |
+
"image": "assets/example_data/Cartoon-Style/00_rgb.png",
|
70 |
+
},
|
71 |
+
"assets/example_data/Cartoon-Style/00_seg.png",
|
72 |
+
42,
|
73 |
+
False,
|
74 |
+
False,
|
75 |
+
],
|
76 |
+
[
|
77 |
+
{
|
78 |
+
"image": "assets/example_data/Realistic-Style/00_rgb.png",
|
79 |
+
},
|
80 |
+
"assets/example_data/Realistic-Style/00_seg.png",
|
81 |
+
42,
|
82 |
+
False,
|
83 |
+
True,
|
84 |
+
],
|
85 |
+
[
|
86 |
+
{
|
87 |
+
"image": "assets/example_data/Realistic-Style/01_rgb.png",
|
88 |
+
},
|
89 |
+
"assets/example_data/Realistic-Style/01_seg.png",
|
90 |
+
42,
|
91 |
+
False,
|
92 |
+
True,
|
93 |
+
],
|
94 |
+
[
|
95 |
+
{
|
96 |
+
"image": "assets/example_data/Realistic-Style/05_rgb.png",
|
97 |
+
},
|
98 |
+
"assets/example_data/Realistic-Style/05_seg.png",
|
99 |
+
42,
|
100 |
+
False,
|
101 |
+
False,
|
102 |
+
],
|
103 |
+
]
|
104 |
+
|
105 |
+
os.makedirs(TMP_DIR, exist_ok=True)
|
106 |
+
|
107 |
+
# Prepare models
|
108 |
+
## Grounding SAM
|
109 |
+
segmenter_id = "facebook/sam-vit-base"
|
110 |
+
sam_processor = AutoProcessor.from_pretrained(segmenter_id)
|
111 |
+
sam_segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(
|
112 |
+
DEVICE, DTYPE
|
113 |
+
)
|
114 |
+
## MIDI-3D
|
115 |
+
local_dir = "pretrained_weights/MIDI-3D"
|
116 |
+
snapshot_download(repo_id=REPO_ID, local_dir=local_dir)
|
117 |
+
pipe: MIDIPipeline = MIDIPipeline.from_pretrained(local_dir).to(DEVICE, DTYPE)
|
118 |
+
pipe.init_custom_adapter(
|
119 |
+
set_self_attn_module_names=[
|
120 |
+
"blocks.8",
|
121 |
+
"blocks.9",
|
122 |
+
"blocks.10",
|
123 |
+
"blocks.11",
|
124 |
+
"blocks.12",
|
125 |
+
]
|
126 |
+
)
|
127 |
+
|
128 |
+
|
129 |
+
# Utils
|
130 |
+
def get_random_hex():
|
131 |
+
random_bytes = os.urandom(8)
|
132 |
+
random_hex = random_bytes.hex()
|
133 |
+
return random_hex
|
134 |
+
|
135 |
+
|
136 |
+
@spaces.GPU()
|
137 |
+
@torch.no_grad()
|
138 |
+
@torch.autocast(device_type=DEVICE, dtype=torch.bfloat16)
|
139 |
+
def run_segmentation(image_prompts: Any, polygon_refinement: bool) -> Image.Image:
|
140 |
+
rgb_image = image_prompts["image"].convert("RGB")
|
141 |
+
|
142 |
+
# pre-process the layers and get the xyxy boxes of each layer
|
143 |
+
if len(image_prompts["points"]) == 0:
|
144 |
+
gr.Error("Please draw bounding boxes for each instance on the image.")
|
145 |
+
boxes = [
|
146 |
+
[
|
147 |
+
[int(box[0]), int(box[1]), int(box[3]), int(box[4])]
|
148 |
+
for box in image_prompts["points"]
|
149 |
+
]
|
150 |
+
]
|
151 |
+
|
152 |
+
# run the segmentation
|
153 |
+
detections = segment(
|
154 |
+
sam_processor,
|
155 |
+
sam_segmentator,
|
156 |
+
rgb_image,
|
157 |
+
boxes=[boxes],
|
158 |
+
polygon_refinement=polygon_refinement,
|
159 |
+
)
|
160 |
+
seg_map_pil = plot_segmentation(rgb_image, detections)
|
161 |
+
|
162 |
+
torch.cuda.empty_cache()
|
163 |
+
|
164 |
+
return seg_map_pil
|
165 |
+
|
166 |
+
|
167 |
+
@torch.no_grad()
|
168 |
+
def run_midi(
|
169 |
+
pipe: Any,
|
170 |
+
rgb_image: Union[str, Image.Image],
|
171 |
+
seg_image: Union[str, Image.Image],
|
172 |
+
seed: int,
|
173 |
+
num_inference_steps: int = 50,
|
174 |
+
guidance_scale: float = 7.0,
|
175 |
+
do_image_padding: bool = False,
|
176 |
+
) -> trimesh.Scene:
|
177 |
+
if do_image_padding:
|
178 |
+
rgb_image, seg_image = preprocess_image(rgb_image, seg_image)
|
179 |
+
instance_rgbs, instance_masks, scene_rgbs = split_rgb_mask(rgb_image, seg_image)
|
180 |
+
|
181 |
+
num_instances = len(instance_rgbs)
|
182 |
+
outputs = pipe(
|
183 |
+
image=instance_rgbs,
|
184 |
+
mask=instance_masks,
|
185 |
+
image_scene=scene_rgbs,
|
186 |
+
attention_kwargs={"num_instances": num_instances},
|
187 |
+
generator=torch.Generator(device=pipe.device).manual_seed(seed),
|
188 |
+
num_inference_steps=num_inference_steps,
|
189 |
+
guidance_scale=guidance_scale,
|
190 |
+
decode_progressive=True,
|
191 |
+
return_dict=False,
|
192 |
+
)
|
193 |
+
|
194 |
+
return outputs
|
195 |
+
|
196 |
+
|
197 |
+
@spaces.GPU(duration=300)
|
198 |
+
@torch.no_grad()
|
199 |
+
@torch.autocast(device_type=DEVICE, dtype=torch.bfloat16)
|
200 |
+
def run_generation(
|
201 |
+
rgb_image: Any,
|
202 |
+
seg_image: Union[str, Image.Image],
|
203 |
+
seed: int,
|
204 |
+
randomize_seed: bool = False,
|
205 |
+
num_inference_steps: int = 50,
|
206 |
+
guidance_scale: float = 7.0,
|
207 |
+
do_image_padding: bool = False,
|
208 |
+
):
|
209 |
+
if randomize_seed:
|
210 |
+
seed = random.randint(0, MAX_SEED)
|
211 |
+
|
212 |
+
if not isinstance(rgb_image, Image.Image) and "image" in rgb_image:
|
213 |
+
rgb_image = rgb_image["image"]
|
214 |
+
|
215 |
+
outputs = run_midi(
|
216 |
+
pipe,
|
217 |
+
rgb_image,
|
218 |
+
seg_image,
|
219 |
+
seed,
|
220 |
+
num_inference_steps,
|
221 |
+
guidance_scale,
|
222 |
+
do_image_padding,
|
223 |
+
)
|
224 |
+
|
225 |
+
# marching cubes
|
226 |
+
trimeshes = []
|
227 |
+
for _, (logits_, grid_size, bbox_size, bbox_min, bbox_max) in enumerate(
|
228 |
+
zip(*outputs)
|
229 |
+
):
|
230 |
+
grid_logits = logits_.view(grid_size)
|
231 |
+
grid_logits = smooth_gpu(grid_logits, method="gaussian", sigma=1)
|
232 |
+
torch.cuda.empty_cache()
|
233 |
+
vertices, faces, normals, _ = measure.marching_cubes(
|
234 |
+
grid_logits.float().cpu().numpy(), 0, method="lewiner"
|
235 |
+
)
|
236 |
+
vertices = vertices / grid_size * bbox_size + bbox_min
|
237 |
+
|
238 |
+
# Trimesh
|
239 |
+
mesh = trimesh.Trimesh(vertices.astype(np.float32), np.ascontiguousarray(faces))
|
240 |
+
trimeshes.append(mesh)
|
241 |
+
|
242 |
+
# compose the output meshes
|
243 |
+
scene = trimesh.Scene(trimeshes)
|
244 |
+
|
245 |
+
tmp_path = os.path.join(TMP_DIR, f"midi3d_{get_random_hex()}.glb")
|
246 |
+
scene.export(tmp_path)
|
247 |
+
|
248 |
+
torch.cuda.empty_cache()
|
249 |
+
|
250 |
+
return tmp_path, tmp_path, seed
|
251 |
+
|
252 |
+
|
253 |
+
# Demo
|
254 |
+
with gr.Blocks() as demo:
|
255 |
+
gr.Markdown(MARKDOWN)
|
256 |
+
|
257 |
+
with gr.Row():
|
258 |
+
with gr.Column():
|
259 |
+
with gr.Row():
|
260 |
+
image_prompts = ImagePrompter(label="Input Image", type="pil")
|
261 |
+
seg_image = gr.Image(
|
262 |
+
label="Segmentation Result", type="pil", format="png"
|
263 |
+
)
|
264 |
+
|
265 |
+
with gr.Accordion("Segmentation Settings", open=False):
|
266 |
+
polygon_refinement = gr.Checkbox(
|
267 |
+
label="Polygon Refinement", value=False
|
268 |
+
)
|
269 |
+
seg_button = gr.Button("Run Segmentation")
|
270 |
+
|
271 |
+
with gr.Accordion("Generation Settings", open=False):
|
272 |
+
do_image_padding = gr.Checkbox(label="Do image padding", value=False)
|
273 |
+
seed = gr.Slider(
|
274 |
+
label="Seed",
|
275 |
+
minimum=0,
|
276 |
+
maximum=MAX_SEED,
|
277 |
+
step=1,
|
278 |
+
value=0,
|
279 |
+
)
|
280 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
281 |
+
num_inference_steps = gr.Slider(
|
282 |
+
label="Number of inference steps",
|
283 |
+
minimum=1,
|
284 |
+
maximum=50,
|
285 |
+
step=1,
|
286 |
+
value=50,
|
287 |
+
)
|
288 |
+
guidance_scale = gr.Slider(
|
289 |
+
label="CFG scale",
|
290 |
+
minimum=0.0,
|
291 |
+
maximum=10.0,
|
292 |
+
step=0.1,
|
293 |
+
value=7.0,
|
294 |
+
)
|
295 |
+
gen_button = gr.Button("Run Generation", variant="primary")
|
296 |
+
|
297 |
+
with gr.Column():
|
298 |
+
model_output = LitModel3D(label="Generated GLB", exposure=1.0, height=500)
|
299 |
+
download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
|
300 |
+
|
301 |
+
with gr.Row():
|
302 |
+
gr.Examples(
|
303 |
+
examples=EXAMPLES,
|
304 |
+
fn=run_generation,
|
305 |
+
inputs=[image_prompts, seg_image, seed, randomize_seed, do_image_padding],
|
306 |
+
outputs=[model_output, download_glb, seed],
|
307 |
+
cache_examples=False,
|
308 |
+
)
|
309 |
+
|
310 |
+
seg_button.click(
|
311 |
+
run_segmentation,
|
312 |
+
inputs=[
|
313 |
+
image_prompts,
|
314 |
+
polygon_refinement,
|
315 |
+
],
|
316 |
+
outputs=[seg_image],
|
317 |
+
).then(lambda: gr.Button(interactive=True), outputs=[gen_button])
|
318 |
+
|
319 |
+
gen_button.click(
|
320 |
+
run_generation,
|
321 |
+
inputs=[
|
322 |
+
image_prompts,
|
323 |
+
seg_image,
|
324 |
+
seed,
|
325 |
+
randomize_seed,
|
326 |
+
num_inference_steps,
|
327 |
+
guidance_scale,
|
328 |
+
do_image_padding,
|
329 |
+
],
|
330 |
+
outputs=[model_output, download_glb, seed],
|
331 |
+
).then(lambda: gr.Button(interactive=True), outputs=[download_glb])
|
332 |
+
|
333 |
+
|
334 |
+
demo.launch()
|
assets/example_data/3D-Front/ffb067ad-cf9a-4321-82ae-4e684c59ea3e_KidsRoom-5300_rgb.png
ADDED
![]() |
assets/example_data/3D-Front/ffb067ad-cf9a-4321-82ae-4e684c59ea3e_KidsRoom-5300_seg.png
ADDED
![]() |
assets/example_data/3D-Front/ffd98024-7200-429e-8b9a-1234a5937826_LivingRoom-360_rgb.png
ADDED
![]() |
assets/example_data/3D-Front/ffd98024-7200-429e-8b9a-1234a5937826_LivingRoom-360_seg.png
ADDED
![]() |
assets/example_data/3D-Front/fff98d42-99a4-43fc-9639-5761cb4f87df_SecondBedroom-127961_rgb.png
ADDED
![]() |
assets/example_data/3D-Front/fff98d42-99a4-43fc-9639-5761cb4f87df_SecondBedroom-127961_seg.png
ADDED
![]() |
assets/example_data/Cartoon-Style/00_rgb.png
ADDED
![]() |
assets/example_data/Cartoon-Style/00_seg.png
ADDED
![]() |
assets/example_data/Cartoon-Style/01_rgb.png
ADDED
![]() |
assets/example_data/Cartoon-Style/01_seg.png
ADDED
![]() |
assets/example_data/Cartoon-Style/02_rgb.png
ADDED
![]() |
assets/example_data/Cartoon-Style/02_seg.png
ADDED
![]() |
assets/example_data/Cartoon-Style/03_rgb.png
ADDED
![]() |
assets/example_data/Cartoon-Style/03_seg.png
ADDED
![]() |
assets/example_data/Cartoon-Style/04_rgb.png
ADDED
![]() |
assets/example_data/Cartoon-Style/04_seg.png
ADDED
![]() |
assets/example_data/Realistic-Style/00_rgb.png
ADDED
![]() |
assets/example_data/Realistic-Style/00_seg.png
ADDED
![]() |
assets/example_data/Realistic-Style/01_rgb.png
ADDED
![]() |
assets/example_data/Realistic-Style/01_seg.png
ADDED
![]() |
assets/example_data/Realistic-Style/02_rgb.png
ADDED
![]() |
assets/example_data/Realistic-Style/02_seg.png
ADDED
![]() |
assets/example_data/Realistic-Style/03_rgb.png
ADDED
![]() |
assets/example_data/Realistic-Style/03_seg.png
ADDED
![]() |
assets/example_data/Realistic-Style/04_rgb.png
ADDED
![]() |
assets/example_data/Realistic-Style/04_seg.png
ADDED
![]() |
assets/example_data/Realistic-Style/05_rgb.png
ADDED
![]() |
assets/example_data/Realistic-Style/05_seg.png
ADDED
![]() |
assets/example_data/Realistic-Style/06_rgb.png
ADDED
![]() |
assets/example_data/Realistic-Style/06_seg.png
ADDED
![]() |
midi/inference_utils.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Tuple
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import PIL
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
|
9 |
+
def generate_dense_grid_points(
|
10 |
+
bbox_min: np.ndarray, bbox_max: np.ndarray, octree_depth: int, indexing: str = "ij"
|
11 |
+
):
|
12 |
+
length = bbox_max - bbox_min
|
13 |
+
num_cells = np.exp2(octree_depth)
|
14 |
+
x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
|
15 |
+
y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
|
16 |
+
z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
|
17 |
+
[xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
|
18 |
+
xyz = np.stack((xs, ys, zs), axis=-1)
|
19 |
+
xyz = xyz.reshape(-1, 3)
|
20 |
+
grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
|
21 |
+
|
22 |
+
return xyz, grid_size, length
|
midi/loaders/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .custom_adapter import CustomAdapterMixin
|
midi/loaders/custom_adapter.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import Dict, Optional, Union
|
3 |
+
|
4 |
+
import safetensors
|
5 |
+
import torch
|
6 |
+
from diffusers.utils import _get_model_file, logging
|
7 |
+
from safetensors import safe_open
|
8 |
+
|
9 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
10 |
+
|
11 |
+
|
12 |
+
class CustomAdapterMixin:
|
13 |
+
def init_custom_adapter(self, *args, **kwargs):
|
14 |
+
self._init_custom_adapter(*args, **kwargs)
|
15 |
+
|
16 |
+
def _init_custom_adapter(self, *args, **kwargs):
|
17 |
+
raise NotImplementedError
|
18 |
+
|
19 |
+
def load_custom_adapter(
|
20 |
+
self,
|
21 |
+
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
22 |
+
weight_name: str,
|
23 |
+
subfolder: Optional[str] = None,
|
24 |
+
**kwargs,
|
25 |
+
):
|
26 |
+
# Load the main state dict first.
|
27 |
+
cache_dir = kwargs.pop("cache_dir", None)
|
28 |
+
force_download = kwargs.pop("force_download", False)
|
29 |
+
resume_download = kwargs.pop("resume_download", False)
|
30 |
+
proxies = kwargs.pop("proxies", None)
|
31 |
+
local_files_only = kwargs.pop("local_files_only", None)
|
32 |
+
token = kwargs.pop("token", None)
|
33 |
+
revision = kwargs.pop("revision", None)
|
34 |
+
|
35 |
+
user_agent = {
|
36 |
+
"file_type": "attn_procs_weights",
|
37 |
+
"framework": "pytorch",
|
38 |
+
}
|
39 |
+
|
40 |
+
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
41 |
+
model_file = _get_model_file(
|
42 |
+
pretrained_model_name_or_path_or_dict,
|
43 |
+
weights_name=weight_name,
|
44 |
+
cache_dir=cache_dir,
|
45 |
+
force_download=force_download,
|
46 |
+
proxies=proxies,
|
47 |
+
local_files_only=local_files_only,
|
48 |
+
token=token,
|
49 |
+
revision=revision,
|
50 |
+
subfolder=subfolder,
|
51 |
+
user_agent=user_agent,
|
52 |
+
)
|
53 |
+
if weight_name.endswith(".safetensors"):
|
54 |
+
state_dict = {}
|
55 |
+
with safe_open(model_file, framework="pt", device="cpu") as f:
|
56 |
+
for key in f.keys():
|
57 |
+
state_dict[key] = f.get_tensor(key)
|
58 |
+
else:
|
59 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
60 |
+
else:
|
61 |
+
state_dict = pretrained_model_name_or_path_or_dict
|
62 |
+
|
63 |
+
self._load_custom_adapter(state_dict)
|
64 |
+
|
65 |
+
def _load_custom_adapter(self, state_dict):
|
66 |
+
raise NotImplementedError
|
67 |
+
|
68 |
+
def save_custom_adapter(
|
69 |
+
self,
|
70 |
+
save_directory: Union[str, os.PathLike],
|
71 |
+
weight_name: str,
|
72 |
+
safe_serialization: bool = False,
|
73 |
+
**kwargs,
|
74 |
+
):
|
75 |
+
if os.path.isfile(save_directory):
|
76 |
+
logger.error(
|
77 |
+
f"Provided path ({save_directory}) should be a directory, not a file"
|
78 |
+
)
|
79 |
+
return
|
80 |
+
|
81 |
+
if safe_serialization:
|
82 |
+
|
83 |
+
def save_function(weights, filename):
|
84 |
+
return safetensors.torch.save_file(
|
85 |
+
weights, filename, metadata={"format": "pt"}
|
86 |
+
)
|
87 |
+
|
88 |
+
else:
|
89 |
+
save_function = torch.save
|
90 |
+
|
91 |
+
# Save the model
|
92 |
+
state_dict = self._save_custom_adapter(**kwargs)
|
93 |
+
save_function(state_dict, os.path.join(save_directory, weight_name))
|
94 |
+
logger.info(
|
95 |
+
f"Custom adapter weights saved in {os.path.join(save_directory, weight_name)}"
|
96 |
+
)
|
97 |
+
|
98 |
+
def _save_custom_adapter(self):
|
99 |
+
raise NotImplementedError
|
midi/models/attention_processor.py
ADDED
@@ -0,0 +1,412 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable, List, Optional, Tuple, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from diffusers.models.attention_processor import Attention
|
6 |
+
from diffusers.utils import logging
|
7 |
+
from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
|
8 |
+
from diffusers.utils.torch_utils import is_torch_version, maybe_allow_in_graph
|
9 |
+
from einops import rearrange
|
10 |
+
from torch import nn
|
11 |
+
|
12 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
13 |
+
|
14 |
+
|
15 |
+
class TripoSGAttnProcessor2_0:
|
16 |
+
r"""
|
17 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
|
18 |
+
used in the TripoSG model. It applies a s normalization layer and rotary embedding on query and key vector.
|
19 |
+
"""
|
20 |
+
|
21 |
+
def __init__(self):
|
22 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
23 |
+
raise ImportError(
|
24 |
+
"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
25 |
+
)
|
26 |
+
|
27 |
+
def __call__(
|
28 |
+
self,
|
29 |
+
attn: Attention,
|
30 |
+
hidden_states: torch.Tensor,
|
31 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
32 |
+
attention_mask: Optional[torch.Tensor] = None,
|
33 |
+
temb: Optional[torch.Tensor] = None,
|
34 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
35 |
+
) -> torch.Tensor:
|
36 |
+
from diffusers.models.embeddings import apply_rotary_emb
|
37 |
+
|
38 |
+
residual = hidden_states
|
39 |
+
if attn.spatial_norm is not None:
|
40 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
41 |
+
|
42 |
+
input_ndim = hidden_states.ndim
|
43 |
+
|
44 |
+
if input_ndim == 4:
|
45 |
+
batch_size, channel, height, width = hidden_states.shape
|
46 |
+
hidden_states = hidden_states.view(
|
47 |
+
batch_size, channel, height * width
|
48 |
+
).transpose(1, 2)
|
49 |
+
|
50 |
+
batch_size, sequence_length, _ = (
|
51 |
+
hidden_states.shape
|
52 |
+
if encoder_hidden_states is None
|
53 |
+
else encoder_hidden_states.shape
|
54 |
+
)
|
55 |
+
|
56 |
+
if attention_mask is not None:
|
57 |
+
attention_mask = attn.prepare_attention_mask(
|
58 |
+
attention_mask, sequence_length, batch_size
|
59 |
+
)
|
60 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
61 |
+
# (batch, heads, source_length, target_length)
|
62 |
+
attention_mask = attention_mask.view(
|
63 |
+
batch_size, attn.heads, -1, attention_mask.shape[-1]
|
64 |
+
)
|
65 |
+
|
66 |
+
if attn.group_norm is not None:
|
67 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
|
68 |
+
1, 2
|
69 |
+
)
|
70 |
+
|
71 |
+
query = attn.to_q(hidden_states)
|
72 |
+
|
73 |
+
if encoder_hidden_states is None:
|
74 |
+
encoder_hidden_states = hidden_states
|
75 |
+
elif attn.norm_cross:
|
76 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
77 |
+
encoder_hidden_states
|
78 |
+
)
|
79 |
+
|
80 |
+
key = attn.to_k(encoder_hidden_states)
|
81 |
+
value = attn.to_v(encoder_hidden_states)
|
82 |
+
|
83 |
+
# NOTE that pre-trained models split heads first then split qkv or kv, like .view(..., attn.heads, 3, dim)
|
84 |
+
# instead of .view(..., 3, attn.heads, dim). So we need to re-split here.
|
85 |
+
if not attn.is_cross_attention:
|
86 |
+
qkv = torch.cat((query, key, value), dim=-1)
|
87 |
+
split_size = qkv.shape[-1] // attn.heads // 3
|
88 |
+
qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
|
89 |
+
query, key, value = torch.split(qkv, split_size, dim=-1)
|
90 |
+
else:
|
91 |
+
kv = torch.cat((key, value), dim=-1)
|
92 |
+
split_size = kv.shape[-1] // attn.heads // 2
|
93 |
+
kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
|
94 |
+
key, value = torch.split(kv, split_size, dim=-1)
|
95 |
+
|
96 |
+
head_dim = key.shape[-1]
|
97 |
+
|
98 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
99 |
+
|
100 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
101 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
102 |
+
|
103 |
+
if attn.norm_q is not None:
|
104 |
+
query = attn.norm_q(query)
|
105 |
+
if attn.norm_k is not None:
|
106 |
+
key = attn.norm_k(key)
|
107 |
+
|
108 |
+
# Apply RoPE if needed
|
109 |
+
if image_rotary_emb is not None:
|
110 |
+
query = apply_rotary_emb(query, image_rotary_emb)
|
111 |
+
if not attn.is_cross_attention:
|
112 |
+
key = apply_rotary_emb(key, image_rotary_emb)
|
113 |
+
|
114 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
115 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
116 |
+
hidden_states = F.scaled_dot_product_attention(
|
117 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
118 |
+
)
|
119 |
+
|
120 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(
|
121 |
+
batch_size, -1, attn.heads * head_dim
|
122 |
+
)
|
123 |
+
hidden_states = hidden_states.to(query.dtype)
|
124 |
+
|
125 |
+
# linear proj
|
126 |
+
hidden_states = attn.to_out[0](hidden_states)
|
127 |
+
# dropout
|
128 |
+
hidden_states = attn.to_out[1](hidden_states)
|
129 |
+
|
130 |
+
if input_ndim == 4:
|
131 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
132 |
+
batch_size, channel, height, width
|
133 |
+
)
|
134 |
+
|
135 |
+
if attn.residual_connection:
|
136 |
+
hidden_states = hidden_states + residual
|
137 |
+
|
138 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
139 |
+
|
140 |
+
return hidden_states
|
141 |
+
|
142 |
+
|
143 |
+
class FusedTripoSGAttnProcessor2_0:
|
144 |
+
r"""
|
145 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0) with fused
|
146 |
+
projection layers. This is used in the HunyuanDiT model. It applies a s normalization layer and rotary embedding on
|
147 |
+
query and key vector.
|
148 |
+
"""
|
149 |
+
|
150 |
+
def __init__(self):
|
151 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
152 |
+
raise ImportError(
|
153 |
+
"FusedTripoSGAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
154 |
+
)
|
155 |
+
|
156 |
+
def __call__(
|
157 |
+
self,
|
158 |
+
attn: Attention,
|
159 |
+
hidden_states: torch.Tensor,
|
160 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
161 |
+
attention_mask: Optional[torch.Tensor] = None,
|
162 |
+
temb: Optional[torch.Tensor] = None,
|
163 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
164 |
+
) -> torch.Tensor:
|
165 |
+
from diffusers.models.embeddings import apply_rotary_emb
|
166 |
+
|
167 |
+
residual = hidden_states
|
168 |
+
if attn.spatial_norm is not None:
|
169 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
170 |
+
|
171 |
+
input_ndim = hidden_states.ndim
|
172 |
+
|
173 |
+
if input_ndim == 4:
|
174 |
+
batch_size, channel, height, width = hidden_states.shape
|
175 |
+
hidden_states = hidden_states.view(
|
176 |
+
batch_size, channel, height * width
|
177 |
+
).transpose(1, 2)
|
178 |
+
|
179 |
+
batch_size, sequence_length, _ = (
|
180 |
+
hidden_states.shape
|
181 |
+
if encoder_hidden_states is None
|
182 |
+
else encoder_hidden_states.shape
|
183 |
+
)
|
184 |
+
|
185 |
+
if attention_mask is not None:
|
186 |
+
attention_mask = attn.prepare_attention_mask(
|
187 |
+
attention_mask, sequence_length, batch_size
|
188 |
+
)
|
189 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
190 |
+
# (batch, heads, source_length, target_length)
|
191 |
+
attention_mask = attention_mask.view(
|
192 |
+
batch_size, attn.heads, -1, attention_mask.shape[-1]
|
193 |
+
)
|
194 |
+
|
195 |
+
if attn.group_norm is not None:
|
196 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
|
197 |
+
1, 2
|
198 |
+
)
|
199 |
+
|
200 |
+
# NOTE that pre-trained split heads first, then split qkv
|
201 |
+
if encoder_hidden_states is None:
|
202 |
+
qkv = attn.to_qkv(hidden_states)
|
203 |
+
split_size = qkv.shape[-1] // attn.heads // 3
|
204 |
+
qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
|
205 |
+
query, key, value = torch.split(qkv, split_size, dim=-1)
|
206 |
+
else:
|
207 |
+
if attn.norm_cross:
|
208 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
209 |
+
encoder_hidden_states
|
210 |
+
)
|
211 |
+
query = attn.to_q(hidden_states)
|
212 |
+
|
213 |
+
kv = attn.to_kv(encoder_hidden_states)
|
214 |
+
split_size = kv.shape[-1] // attn.heads // 2
|
215 |
+
kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
|
216 |
+
key, value = torch.split(kv, split_size, dim=-1)
|
217 |
+
|
218 |
+
head_dim = key.shape[-1]
|
219 |
+
|
220 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
221 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
222 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
223 |
+
|
224 |
+
if attn.norm_q is not None:
|
225 |
+
query = attn.norm_q(query)
|
226 |
+
if attn.norm_k is not None:
|
227 |
+
key = attn.norm_k(key)
|
228 |
+
|
229 |
+
# Apply RoPE if needed
|
230 |
+
if image_rotary_emb is not None:
|
231 |
+
query = apply_rotary_emb(query, image_rotary_emb)
|
232 |
+
if not attn.is_cross_attention:
|
233 |
+
key = apply_rotary_emb(key, image_rotary_emb)
|
234 |
+
|
235 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
236 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
237 |
+
hidden_states = F.scaled_dot_product_attention(
|
238 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
239 |
+
)
|
240 |
+
|
241 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(
|
242 |
+
batch_size, -1, attn.heads * head_dim
|
243 |
+
)
|
244 |
+
hidden_states = hidden_states.to(query.dtype)
|
245 |
+
|
246 |
+
# linear proj
|
247 |
+
hidden_states = attn.to_out[0](hidden_states)
|
248 |
+
# dropout
|
249 |
+
hidden_states = attn.to_out[1](hidden_states)
|
250 |
+
|
251 |
+
if input_ndim == 4:
|
252 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
253 |
+
batch_size, channel, height, width
|
254 |
+
)
|
255 |
+
|
256 |
+
if attn.residual_connection:
|
257 |
+
hidden_states = hidden_states + residual
|
258 |
+
|
259 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
260 |
+
|
261 |
+
return hidden_states
|
262 |
+
|
263 |
+
|
264 |
+
class MIAttnProcessor2_0:
|
265 |
+
r"""
|
266 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
|
267 |
+
used in the MIDI model. It applies a normalization layer and rotary embedding on query and key vector.
|
268 |
+
"""
|
269 |
+
|
270 |
+
def __init__(self, use_mi: bool = True):
|
271 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
272 |
+
raise ImportError(
|
273 |
+
"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
274 |
+
)
|
275 |
+
|
276 |
+
self.use_mi = use_mi
|
277 |
+
|
278 |
+
def __call__(
|
279 |
+
self,
|
280 |
+
attn: Attention,
|
281 |
+
hidden_states: torch.Tensor,
|
282 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
283 |
+
attention_mask: Optional[torch.Tensor] = None,
|
284 |
+
temb: Optional[torch.Tensor] = None,
|
285 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
286 |
+
num_instances: Optional[torch.IntTensor] = None,
|
287 |
+
) -> torch.Tensor:
|
288 |
+
from diffusers.models.embeddings import apply_rotary_emb
|
289 |
+
|
290 |
+
residual = hidden_states
|
291 |
+
if attn.spatial_norm is not None:
|
292 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
293 |
+
|
294 |
+
input_ndim = hidden_states.ndim
|
295 |
+
|
296 |
+
if input_ndim == 4:
|
297 |
+
batch_size, channel, height, width = hidden_states.shape
|
298 |
+
hidden_states = hidden_states.view(
|
299 |
+
batch_size, channel, height * width
|
300 |
+
).transpose(1, 2)
|
301 |
+
|
302 |
+
batch_size, sequence_length, _ = (
|
303 |
+
hidden_states.shape
|
304 |
+
if encoder_hidden_states is None
|
305 |
+
else encoder_hidden_states.shape
|
306 |
+
)
|
307 |
+
|
308 |
+
if attention_mask is not None:
|
309 |
+
attention_mask = attn.prepare_attention_mask(
|
310 |
+
attention_mask, sequence_length, batch_size
|
311 |
+
)
|
312 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
313 |
+
# (batch, heads, source_length, target_length)
|
314 |
+
attention_mask = attention_mask.view(
|
315 |
+
batch_size, attn.heads, -1, attention_mask.shape[-1]
|
316 |
+
)
|
317 |
+
|
318 |
+
if attn.group_norm is not None:
|
319 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
|
320 |
+
1, 2
|
321 |
+
)
|
322 |
+
|
323 |
+
query = attn.to_q(hidden_states)
|
324 |
+
|
325 |
+
if encoder_hidden_states is None:
|
326 |
+
encoder_hidden_states = hidden_states
|
327 |
+
elif attn.norm_cross:
|
328 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
329 |
+
encoder_hidden_states
|
330 |
+
)
|
331 |
+
|
332 |
+
key = attn.to_k(encoder_hidden_states)
|
333 |
+
value = attn.to_v(encoder_hidden_states)
|
334 |
+
|
335 |
+
# NOTE that pre-trained models split heads first then split qkv or kv, like .view(..., attn.heads, 3, dim)
|
336 |
+
# instead of .view(..., 3, attn.heads, dim). So we need to re-split here.
|
337 |
+
if not attn.is_cross_attention:
|
338 |
+
qkv = torch.cat((query, key, value), dim=-1)
|
339 |
+
split_size = qkv.shape[-1] // attn.heads // 3
|
340 |
+
qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
|
341 |
+
query, key, value = torch.split(qkv, split_size, dim=-1)
|
342 |
+
else:
|
343 |
+
kv = torch.cat((key, value), dim=-1)
|
344 |
+
split_size = kv.shape[-1] // attn.heads // 2
|
345 |
+
kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
|
346 |
+
key, value = torch.split(kv, split_size, dim=-1)
|
347 |
+
|
348 |
+
head_dim = key.shape[-1]
|
349 |
+
|
350 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
351 |
+
|
352 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
353 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
354 |
+
|
355 |
+
if attn.norm_q is not None:
|
356 |
+
query = attn.norm_q(query)
|
357 |
+
if attn.norm_k is not None:
|
358 |
+
key = attn.norm_k(key)
|
359 |
+
|
360 |
+
# Apply RoPE if needed
|
361 |
+
if image_rotary_emb is not None:
|
362 |
+
query = apply_rotary_emb(query, image_rotary_emb)
|
363 |
+
if not attn.is_cross_attention:
|
364 |
+
key = apply_rotary_emb(key, image_rotary_emb)
|
365 |
+
|
366 |
+
if self.use_mi and num_instances is not None:
|
367 |
+
key = rearrange(
|
368 |
+
key, "(b ni) h nt c -> b h (ni nt) c", ni=num_instances
|
369 |
+
).repeat_interleave(num_instances, dim=0)
|
370 |
+
value = rearrange(
|
371 |
+
value, "(b ni) h nt c -> b h (ni nt) c", ni=num_instances
|
372 |
+
).repeat_interleave(num_instances, dim=0)
|
373 |
+
|
374 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
375 |
+
hidden_states = F.scaled_dot_product_attention(
|
376 |
+
query,
|
377 |
+
key,
|
378 |
+
value,
|
379 |
+
dropout_p=0.0,
|
380 |
+
is_causal=False,
|
381 |
+
)
|
382 |
+
else:
|
383 |
+
hidden_states = F.scaled_dot_product_attention(
|
384 |
+
query,
|
385 |
+
key,
|
386 |
+
value,
|
387 |
+
attn_mask=attention_mask,
|
388 |
+
dropout_p=0.0,
|
389 |
+
is_causal=False,
|
390 |
+
)
|
391 |
+
|
392 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(
|
393 |
+
batch_size, -1, attn.heads * head_dim
|
394 |
+
)
|
395 |
+
hidden_states = hidden_states.to(query.dtype)
|
396 |
+
|
397 |
+
# linear proj
|
398 |
+
hidden_states = attn.to_out[0](hidden_states)
|
399 |
+
# dropout
|
400 |
+
hidden_states = attn.to_out[1](hidden_states)
|
401 |
+
|
402 |
+
if input_ndim == 4:
|
403 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
404 |
+
batch_size, channel, height, width
|
405 |
+
)
|
406 |
+
|
407 |
+
if attn.residual_connection:
|
408 |
+
hidden_states = hidden_states + residual
|
409 |
+
|
410 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
411 |
+
|
412 |
+
return hidden_states
|
midi/models/autoencoders/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .autoencoder_kl_triposg import TripoSGVAEModel
|
midi/models/autoencoders/autoencoder_kl_triposg.py
ADDED
@@ -0,0 +1,541 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, Optional, Tuple, Union
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
7 |
+
from diffusers.models.attention_processor import Attention, AttentionProcessor
|
8 |
+
from diffusers.models.autoencoders.vae import DecoderOutput
|
9 |
+
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
10 |
+
from diffusers.models.modeling_utils import ModelMixin
|
11 |
+
from diffusers.models.normalization import FP32LayerNorm, LayerNorm
|
12 |
+
from diffusers.utils import logging
|
13 |
+
from diffusers.utils.accelerate_utils import apply_forward_hook
|
14 |
+
from einops import repeat
|
15 |
+
from tqdm import tqdm
|
16 |
+
from torch_cluster import fps
|
17 |
+
|
18 |
+
from ..attention_processor import FusedTripoSGAttnProcessor2_0, TripoSGAttnProcessor2_0
|
19 |
+
from ..embeddings import FrequencyPositionalEmbedding
|
20 |
+
from ..transformers.triposg_transformer import DiTBlock
|
21 |
+
from .vae import DiagonalGaussianDistribution
|
22 |
+
|
23 |
+
import subprocess
|
24 |
+
import sys
|
25 |
+
|
26 |
+
|
27 |
+
def install_package(package_name):
|
28 |
+
try:
|
29 |
+
subprocess.check_call([sys.executable, "-m", "pip", "install", package_name])
|
30 |
+
return True
|
31 |
+
except subprocess.CalledProcessError:
|
32 |
+
return False
|
33 |
+
|
34 |
+
|
35 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
36 |
+
|
37 |
+
|
38 |
+
class TripoSGEncoder(nn.Module):
|
39 |
+
def __init__(
|
40 |
+
self,
|
41 |
+
in_channels: int = 3,
|
42 |
+
dim: int = 512,
|
43 |
+
num_attention_heads: int = 8,
|
44 |
+
num_layers: int = 8,
|
45 |
+
):
|
46 |
+
super().__init__()
|
47 |
+
|
48 |
+
self.proj_in = nn.Linear(in_channels, dim, bias=True)
|
49 |
+
|
50 |
+
self.blocks = nn.ModuleList(
|
51 |
+
[
|
52 |
+
DiTBlock(
|
53 |
+
dim=dim,
|
54 |
+
num_attention_heads=num_attention_heads,
|
55 |
+
use_self_attention=False,
|
56 |
+
use_cross_attention=True,
|
57 |
+
cross_attention_dim=dim,
|
58 |
+
cross_attention_norm_type="layer_norm",
|
59 |
+
activation_fn="gelu",
|
60 |
+
norm_type="fp32_layer_norm",
|
61 |
+
norm_eps=1e-5,
|
62 |
+
qk_norm=False,
|
63 |
+
qkv_bias=False,
|
64 |
+
) # cross attention
|
65 |
+
]
|
66 |
+
+ [
|
67 |
+
DiTBlock(
|
68 |
+
dim=dim,
|
69 |
+
num_attention_heads=num_attention_heads,
|
70 |
+
use_self_attention=True,
|
71 |
+
self_attention_norm_type="fp32_layer_norm",
|
72 |
+
use_cross_attention=False,
|
73 |
+
use_cross_attention_2=False,
|
74 |
+
activation_fn="gelu",
|
75 |
+
norm_type="fp32_layer_norm",
|
76 |
+
norm_eps=1e-5,
|
77 |
+
qk_norm=False,
|
78 |
+
qkv_bias=False,
|
79 |
+
)
|
80 |
+
for _ in range(num_layers) # self attention
|
81 |
+
]
|
82 |
+
)
|
83 |
+
|
84 |
+
self.norm_out = LayerNorm(dim)
|
85 |
+
|
86 |
+
def forward(self, sample_1: torch.Tensor, sample_2: torch.Tensor):
|
87 |
+
hidden_states = self.proj_in(sample_1)
|
88 |
+
encoder_hidden_states = self.proj_in(sample_2)
|
89 |
+
|
90 |
+
for layer, block in enumerate(self.blocks):
|
91 |
+
if layer == 0:
|
92 |
+
hidden_states = block(
|
93 |
+
hidden_states, encoder_hidden_states=encoder_hidden_states
|
94 |
+
)
|
95 |
+
else:
|
96 |
+
hidden_states = block(hidden_states)
|
97 |
+
|
98 |
+
hidden_states = self.norm_out(hidden_states)
|
99 |
+
|
100 |
+
return hidden_states
|
101 |
+
|
102 |
+
|
103 |
+
class TripoSGDecoder(nn.Module):
|
104 |
+
def __init__(
|
105 |
+
self,
|
106 |
+
in_channels: int = 3,
|
107 |
+
out_channels: int = 1,
|
108 |
+
dim: int = 512,
|
109 |
+
num_attention_heads: int = 8,
|
110 |
+
num_layers: int = 16,
|
111 |
+
grad_type: str = "analytical",
|
112 |
+
grad_interval: float = 0.001,
|
113 |
+
):
|
114 |
+
super().__init__()
|
115 |
+
|
116 |
+
if grad_type not in ["numerical", "analytical"]:
|
117 |
+
raise ValueError(f"grad_type must be one of ['numerical', 'analytical']")
|
118 |
+
self.grad_type = grad_type
|
119 |
+
self.grad_interval = grad_interval
|
120 |
+
|
121 |
+
self.blocks = nn.ModuleList(
|
122 |
+
[
|
123 |
+
DiTBlock(
|
124 |
+
dim=dim,
|
125 |
+
num_attention_heads=num_attention_heads,
|
126 |
+
use_self_attention=True,
|
127 |
+
self_attention_norm_type="fp32_layer_norm",
|
128 |
+
use_cross_attention=False,
|
129 |
+
use_cross_attention_2=False,
|
130 |
+
activation_fn="gelu",
|
131 |
+
norm_type="fp32_layer_norm",
|
132 |
+
norm_eps=1e-5,
|
133 |
+
qk_norm=False,
|
134 |
+
qkv_bias=False,
|
135 |
+
)
|
136 |
+
for _ in range(num_layers) # self attention
|
137 |
+
]
|
138 |
+
+ [
|
139 |
+
DiTBlock(
|
140 |
+
dim=dim,
|
141 |
+
num_attention_heads=num_attention_heads,
|
142 |
+
use_self_attention=False,
|
143 |
+
use_cross_attention=True,
|
144 |
+
cross_attention_dim=dim,
|
145 |
+
cross_attention_norm_type="layer_norm",
|
146 |
+
activation_fn="gelu",
|
147 |
+
norm_type="fp32_layer_norm",
|
148 |
+
norm_eps=1e-5,
|
149 |
+
qk_norm=False,
|
150 |
+
qkv_bias=False,
|
151 |
+
) # cross attention
|
152 |
+
]
|
153 |
+
)
|
154 |
+
|
155 |
+
self.proj_query = nn.Linear(in_channels, dim, bias=True)
|
156 |
+
|
157 |
+
self.norm_out = LayerNorm(dim)
|
158 |
+
self.proj_out = nn.Linear(dim, out_channels, bias=True)
|
159 |
+
|
160 |
+
def query_geometry(
|
161 |
+
self,
|
162 |
+
model_fn: callable,
|
163 |
+
queries: torch.Tensor,
|
164 |
+
sample: torch.Tensor,
|
165 |
+
grad: bool = False,
|
166 |
+
):
|
167 |
+
logits = model_fn(queries, sample)
|
168 |
+
if grad:
|
169 |
+
with torch.autocast(device_type="cuda", dtype=torch.float32):
|
170 |
+
if self.grad_type == "numerical":
|
171 |
+
interval = self.grad_interval
|
172 |
+
grad_value = []
|
173 |
+
for offset in [
|
174 |
+
(interval, 0, 0),
|
175 |
+
(0, interval, 0),
|
176 |
+
(0, 0, interval),
|
177 |
+
]:
|
178 |
+
offset_tensor = torch.tensor(offset, device=queries.device)[
|
179 |
+
None, :
|
180 |
+
]
|
181 |
+
res_p = model_fn(queries + offset_tensor, sample)[..., 0]
|
182 |
+
res_n = model_fn(queries - offset_tensor, sample)[..., 0]
|
183 |
+
grad_value.append((res_p - res_n) / (2 * interval))
|
184 |
+
grad_value = torch.stack(grad_value, dim=-1)
|
185 |
+
else:
|
186 |
+
queries_d = torch.clone(queries)
|
187 |
+
queries_d.requires_grad = True
|
188 |
+
with torch.enable_grad():
|
189 |
+
res_d = model_fn(queries_d, sample)
|
190 |
+
grad_value = torch.autograd.grad(
|
191 |
+
res_d,
|
192 |
+
[queries_d],
|
193 |
+
grad_outputs=torch.ones_like(res_d),
|
194 |
+
create_graph=self.training,
|
195 |
+
)[0]
|
196 |
+
else:
|
197 |
+
grad_value = None
|
198 |
+
|
199 |
+
return logits, grad_value
|
200 |
+
|
201 |
+
def forward(
|
202 |
+
self,
|
203 |
+
sample: torch.Tensor,
|
204 |
+
queries: torch.Tensor,
|
205 |
+
kv_cache: Optional[torch.Tensor] = None,
|
206 |
+
):
|
207 |
+
if kv_cache is None:
|
208 |
+
hidden_states = sample
|
209 |
+
for _, block in enumerate(self.blocks[:-1]):
|
210 |
+
hidden_states = block(hidden_states)
|
211 |
+
kv_cache = hidden_states
|
212 |
+
|
213 |
+
# query grid logits by cross attention
|
214 |
+
def query_fn(q, kv):
|
215 |
+
q = self.proj_query(q)
|
216 |
+
l = self.blocks[-1](q, encoder_hidden_states=kv)
|
217 |
+
return self.proj_out(self.norm_out(l))
|
218 |
+
|
219 |
+
logits, grad = self.query_geometry(
|
220 |
+
query_fn, queries, kv_cache, grad=self.training
|
221 |
+
)
|
222 |
+
logits = logits * -1 if not isinstance(logits, Tuple) else logits[0] * -1
|
223 |
+
|
224 |
+
return logits, kv_cache
|
225 |
+
|
226 |
+
|
227 |
+
class TripoSGVAEModel(ModelMixin, ConfigMixin):
|
228 |
+
@register_to_config
|
229 |
+
def __init__(
|
230 |
+
self,
|
231 |
+
in_channels: int = 3, # NOTE xyz instead of feature dim
|
232 |
+
latent_channels: int = 64,
|
233 |
+
num_attention_heads: int = 8,
|
234 |
+
width_encoder: int = 512,
|
235 |
+
width_decoder: int = 1024,
|
236 |
+
num_layers_encoder: int = 8,
|
237 |
+
num_layers_decoder: int = 16,
|
238 |
+
embedding_type: str = "frequency",
|
239 |
+
embed_frequency: int = 8,
|
        embed_include_pi: bool = False,
    ):
        super().__init__()

        self.out_channels = 1

        if embedding_type == "frequency":
            self.embedder = FrequencyPositionalEmbedding(
                num_freqs=embed_frequency,
                logspace=True,
                input_dim=in_channels,
                include_pi=embed_include_pi,
            )
        else:
            raise NotImplementedError(
                f"Embedding type {embedding_type} is not supported."
            )

        self.encoder = TripoSGEncoder(
            in_channels=in_channels + self.embedder.out_dim,
            dim=width_encoder,
            num_attention_heads=num_attention_heads,
            num_layers=num_layers_encoder,
        )
        self.decoder = TripoSGDecoder(
            in_channels=self.embedder.out_dim,
            out_channels=self.out_channels,
            dim=width_decoder,
            num_attention_heads=num_attention_heads,
            num_layers=num_layers_decoder,
        )

        self.quant = nn.Linear(width_encoder, latent_channels * 2, bias=True)
        self.post_quant = nn.Linear(latent_channels, width_decoder, bias=True)

        self.use_slicing = False
        self.slicing_length = 1

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedTripoSGAttnProcessor2_0
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError(
                    "`fuse_qkv_projections()` is not supported for models having added KV projections."
                )

        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedTripoSGAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        """
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(
            name: str,
            module: torch.nn.Module,
            processors: Dict[str, AttentionProcessor],
        ):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(
        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
    ):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        self.set_attn_processor(TripoSGAttnProcessor2_0())

    def enable_slicing(self, slicing_length: int = 1) -> None:
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.use_slicing = True
        self.slicing_length = slicing_length

    def disable_slicing(self) -> None:
        r"""
        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_slicing = False

    def _sample_features(
        self, x: torch.Tensor, num_tokens: int = 2048, seed: Optional[int] = None
    ):
        """
        Sample points from features of the input point cloud.

        Args:
            x (torch.Tensor): The input point cloud. shape: (B, N, C)
            num_tokens (int, optional): The number of points to sample. Defaults to 2048.
            seed (Optional[int], optional): The random seed. Defaults to None.
        """
        rng = np.random.default_rng(seed)
        indices = rng.choice(
            x.shape[1], num_tokens * 4, replace=num_tokens * 4 > x.shape[1]
        )
        selected_points = x[:, indices]

        batch_size, num_points, num_channels = selected_points.shape
        flattened_points = selected_points.view(batch_size * num_points, num_channels)
        batch_indices = (
            torch.arange(batch_size).to(x.device).repeat_interleave(num_points)
        )

        # fps sampling
        sampling_ratio = 1.0 / 4
        sampled_indices = fps(
            flattened_points[:, :3],
            batch_indices,
            ratio=sampling_ratio,
            random_start=self.training,
        )
        sampled_points = flattened_points[sampled_indices].view(
            batch_size, -1, num_channels
        )

        return sampled_points

    def _encode(
        self, x: torch.Tensor, num_tokens: int = 2048, seed: Optional[int] = None
    ):
        position_channels = self.config.in_channels
        positions, features = x[..., :position_channels], x[..., position_channels:]
        x_kv = torch.cat([self.embedder(positions), features], dim=-1)

        sampled_x = self._sample_features(x, num_tokens, seed)
        positions, features = (
            sampled_x[..., :position_channels],
            sampled_x[..., position_channels:],
        )
        x_q = torch.cat([self.embedder(positions), features], dim=-1)

        x = self.encoder(x_q, x_kv)

        x = self.quant(x)

        return x

    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True, **kwargs
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        """
        Encode a batch of point features into latents.
        """
        if self.use_slicing and x.shape[0] > 1:
            encoded_slices = [
                self._encode(x_slice, **kwargs)
                for x_slice in x.split(self.slicing_length)
            ]
            h = torch.cat(encoded_slices)
        else:
            h = self._encode(x, **kwargs)

        posterior = DiagonalGaussianDistribution(h, feature_dim=-1)

        if not return_dict:
            return (posterior,)
        return AutoencoderKLOutput(latent_dist=posterior)

    def _decode(
        self,
        z: torch.Tensor,
        sampled_points: torch.Tensor,
        num_chunks: int = 50000,
        to_cpu: bool = False,
        return_dict: bool = True,
    ) -> Union[DecoderOutput, torch.Tensor]:
        xyz_samples = sampled_points

        z = self.post_quant(z)

        num_points = xyz_samples.shape[1]
        kv_cache = None
        dec = []

        for i in range(0, num_points, num_chunks):
            queries = xyz_samples[:, i : i + num_chunks, :].to(z.device, dtype=z.dtype)
            queries = self.embedder(queries)

            z_, kv_cache = self.decoder(z, queries, kv_cache)
            dec.append(z_ if not to_cpu else z_.cpu())

        z = torch.cat(dec, dim=1)

        if not return_dict:
            return (z,)

        return DecoderOutput(sample=z)

    @apply_forward_hook
    def decode(
        self,
        z: torch.Tensor,
        sampled_points: torch.Tensor,
        return_dict: bool = True,
        **kwargs,
    ) -> Union[DecoderOutput, torch.Tensor]:
        if self.use_slicing and z.shape[0] > 1:
            decoded_slices = [
                self._decode(z_slice, p_slice, **kwargs).sample
                for z_slice, p_slice in zip(
                    z.split(self.slicing_length),
                    sampled_points.split(self.slicing_length),
                )
            ]
            decoded = torch.cat(decoded_slices)
        else:
            decoded = self._decode(z, sampled_points, **kwargs).sample

        if not return_dict:
            return (decoded,)
        return DecoderOutput(sample=decoded)

    def forward(self, x: torch.Tensor):
        pass
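The sliced `encode`/`decode` path above is purely a memory optimization: with slicing enabled, the batch is split into chunks of `slicing_length` and each chunk is encoded or decoded separately before the results are concatenated. A minimal usage sketch, assuming `vae` is an already-loaded `TripoSGVAEModel` and that the last input dimension holds xyz positions followed by extra per-point features (the channel layout and shapes here are illustrative assumptions, not values taken from this repo):

import torch

# `vae` is assumed to be a constructed/loaded TripoSGVAEModel (see the class above).
surface = torch.randn(2, 20480, 3 + 3)        # (B, N points, xyz + assumed extra features)
queries = torch.rand(2, 100000, 3) * 2 - 1    # query positions in [-1, 1]^3

vae.enable_slicing(slicing_length=1)          # encode/decode one sample at a time
posterior = vae.encode(surface, num_tokens=2048).latent_dist
latents = posterior.sample()                  # (B, ~num_tokens, latent_channels)
field = vae.decode(latents, sampled_points=queries).sample  # (B, 100000, 1) occupancy/SDF-style values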
midi/models/autoencoders/vae.py
ADDED
@@ -0,0 +1,69 @@
from typing import Optional, Tuple

import numpy as np
import torch
from diffusers.utils.torch_utils import randn_tensor


class DiagonalGaussianDistribution(object):
    def __init__(
        self,
        parameters: torch.Tensor,
        deterministic: bool = False,
        feature_dim: int = 1,
    ):
        self.parameters = parameters
        self.feature_dim = feature_dim
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=feature_dim)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(
                self.mean, device=self.parameters.device, dtype=self.parameters.dtype
            )

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
        # make sure sample is on the same device as the parameters and has same dtype
        sample = randn_tensor(
            self.mean.shape,
            generator=generator,
            device=self.parameters.device,
            dtype=self.parameters.dtype,
        )
        x = self.mean + self.std * sample
        return x

    def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.sum(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=[1, 2, 3],
                )
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=[1, 2, 3],
                )

    def nll(
        self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]
    ) -> torch.Tensor:
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self) -> torch.Tensor:
        return self.mean
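A small self-contained check of the distribution above: `parameters` is the concatenation of mean and log-variance along `feature_dim`, so a tensor with 2 * C channels on that axis yields C-channel samples via the reparameterization in `sample()`. The shapes below are illustrative only:

import torch

params = torch.randn(2, 512, 2 * 64)                        # mean and logvar stacked on the last axis
dist = DiagonalGaussianDistribution(params, feature_dim=-1)

z = dist.sample()   # mean + std * noise, shape (2, 512, 64)
m = dist.mode()     # just the mean, same shape
print(z.shape, m.shape)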
midi/models/embeddings.py
ADDED
@@ -0,0 +1,96 @@
import torch
import torch.nn as nn


class FrequencyPositionalEmbedding(nn.Module):
    """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
    each feature dimension of `x[..., i]` into:
        [
            sin(x[..., i]),
            sin(f_1*x[..., i]),
            sin(f_2*x[..., i]),
            ...
            sin(f_N * x[..., i]),
            cos(x[..., i]),
            cos(f_1*x[..., i]),
            cos(f_2*x[..., i]),
            ...
            cos(f_N * x[..., i]),
            x[..., i]  # only present if include_input is True.
        ], here f_i is the frequency.

    Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs].
    If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...];
    Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)].

    Args:
        num_freqs (int): the number of frequencies, default is 6;
        logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
            otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];
        input_dim (int): the input dimension, default is 3;
        include_input (bool): include the input tensor or not, default is True.

    Attributes:
        frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
            otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];

        out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
            otherwise, it is input_dim * num_freqs * 2.

    """

    def __init__(
        self,
        num_freqs: int = 6,
        logspace: bool = True,
        input_dim: int = 3,
        include_input: bool = True,
        include_pi: bool = True,
    ) -> None:
        """The initialization"""

        super().__init__()

        if logspace:
            frequencies = 2.0 ** torch.arange(num_freqs, dtype=torch.float32)
        else:
            frequencies = torch.linspace(
                1.0, 2.0 ** (num_freqs - 1), num_freqs, dtype=torch.float32
            )

        if include_pi:
            frequencies *= torch.pi

        self.register_buffer("frequencies", frequencies, persistent=False)
        self.include_input = include_input
        self.num_freqs = num_freqs

        self.out_dim = self.get_dims(input_dim)

    def get_dims(self, input_dim):
        temp = 1 if self.include_input or self.num_freqs == 0 else 0
        out_dim = input_dim * (self.num_freqs * 2 + temp)

        return out_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward process.

        Args:
            x: tensor of shape [..., dim]

        Returns:
            embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
                where temp is 1 if include_input is True and 0 otherwise.
        """

        if self.num_freqs > 0:
            embed = (x[..., None].contiguous() * self.frequencies).view(
                *x.shape[:-1], -1
            )
            if self.include_input:
                return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
            else:
                return torch.cat((embed.sin(), embed.cos()), dim=-1)
        else:
            return x
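As a quick sanity check of the `out_dim` bookkeeping above: with `input_dim=3`, `num_freqs=8`, and `include_input=True`, every coordinate expands into 8 sines, 8 cosines, and the raw value, i.e. 3 * (2 * 8 + 1) = 51 output channels. The numbers below are illustrative and not tied to the model config in this repo:

import torch

embedder = FrequencyPositionalEmbedding(num_freqs=8, logspace=True, input_dim=3, include_pi=False)
points = torch.rand(2, 1024, 3)    # a batch of 3D positions

emb = embedder(points)
assert embedder.out_dim == 3 * (2 * 8 + 1) == 51
assert emb.shape == (2, 1024, 51)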
midi/models/transformers/__init__.py
ADDED
@@ -0,0 +1,61 @@
from typing import Callable, Optional

from .triposg_transformer import TripoSGDiTModel


def default_set_attn_proc_func(
    name: str,
    hidden_size: int,
    cross_attention_dim: Optional[int],
    ori_attn_proc: object,
) -> object:
    return ori_attn_proc


def set_transformer_attn_processor(
    transformer: TripoSGDiTModel,
    set_self_attn_proc_func: Callable = default_set_attn_proc_func,
    set_cross_attn_1_proc_func: Callable = default_set_attn_proc_func,
    set_cross_attn_2_proc_func: Callable = default_set_attn_proc_func,
    set_self_attn_module_names: Optional[list[str]] = None,
    set_cross_attn_1_module_names: Optional[list[str]] = None,
    set_cross_attn_2_module_names: Optional[list[str]] = None,
) -> None:
    do_set_processor = lambda name, module_names: (
        any([name.startswith(module_name) for module_name in module_names])
        if module_names is not None
        else True
    )  # prefix match

    attn_procs = {}
    for name, attn_processor in transformer.attn_processors.items():
        hidden_size = transformer.config.width
        if name.endswith("attn1.processor"):
            # self attention
            attn_procs[name] = (
                set_self_attn_proc_func(name, hidden_size, None, attn_processor)
                if do_set_processor(name, set_self_attn_module_names)
                else attn_processor
            )
        elif name.endswith("attn2.processor"):
            # cross attention
            cross_attention_dim = transformer.config.cross_attention_dim
            attn_procs[name] = (
                set_cross_attn_1_proc_func(
                    name, hidden_size, cross_attention_dim, attn_processor
                )
                if do_set_processor(name, set_cross_attn_1_module_names)
                else attn_processor
            )
        elif name.endswith("attn2_2.processor"):
            # cross attention 2
            cross_attention_dim = transformer.config.cross_attention_2_dim
            attn_procs[name] = (
                set_cross_attn_2_proc_func(
                    name, hidden_size, cross_attention_dim, attn_processor
                )
                if do_set_processor(name, set_cross_attn_2_module_names)
                else attn_processor
            )

    transformer.set_attn_processor(attn_procs)
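`set_transformer_attn_processor` walks `transformer.attn_processors` and rewrites entries by suffix (`attn1` self-attention, `attn2` first cross-attention, `attn2_2` second cross-attention), optionally restricted by module-name prefixes. A hypothetical sketch that swaps only the second cross-attention processors of the first two blocks, assuming `transformer` is an already-loaded `TripoSGDiTModel`:

from midi.models.attention_processor import TripoSGAttnProcessor2_0

def swap_cross_attn_2(name, hidden_size, cross_attention_dim, ori_attn_proc):
    # return a fresh processor for matched modules; everything else keeps its original one
    return TripoSGAttnProcessor2_0()

set_transformer_attn_processor(
    transformer,
    set_cross_attn_2_proc_func=swap_cross_attn_2,
    set_cross_attn_2_module_names=["blocks.0", "blocks.1"],  # prefix match
)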
midi/models/transformers/modeling_outputs.py
ADDED
@@ -0,0 +1,8 @@
from dataclasses import dataclass

import torch


@dataclass
class Transformer1DModelOutput:
    sample: torch.FloatTensor
midi/models/transformers/triposg_transformer.py
ADDED
@@ -0,0 +1,690 @@
# Copyright 2024 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import PeftAdapterMixin
from diffusers.models.attention import FeedForward
from diffusers.models.attention_processor import Attention, AttentionProcessor
from diffusers.models.embeddings import (
    GaussianFourierProjection,
    TimestepEmbedding,
    Timesteps,
)
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import (
    AdaLayerNormContinuous,
    FP32LayerNorm,
    LayerNorm,
)
from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_torch_version,
    logging,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.utils.torch_utils import maybe_allow_in_graph
from torch import nn

from ..attention_processor import FusedTripoSGAttnProcessor2_0, TripoSGAttnProcessor2_0
from .modeling_outputs import Transformer1DModelOutput

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@maybe_allow_in_graph
class DiTBlock(nn.Module):
    r"""
    Transformer block used in the Hunyuan-DiT model (https://github.com/Tencent/HunyuanDiT). Allows skip connection and
    QKNorm.

    Parameters:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        cross_attention_dim (`int`, *optional*):
            The size of the encoder_hidden_states vector for cross attention.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`):
            Activation function to be used in feed-forward.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_eps (`float`, *optional*, defaults to 1e-6):
            A small constant added to the denominator in normalization layers to prevent division by zero.
        final_dropout (`bool`, *optional*, defaults to False):
            Whether to apply a final dropout after the last feed-forward layer.
        ff_inner_dim (`int`, *optional*):
            The size of the hidden layer in the feed-forward block. Defaults to `None`.
        ff_bias (`bool`, *optional*, defaults to `True`):
            Whether to use bias in the feed-forward block.
        skip (`bool`, *optional*, defaults to `False`):
            Whether to use skip connection. Defaults to `False` for down-blocks and mid-blocks.
        qk_norm (`bool`, *optional*, defaults to `True`):
            Whether to use normalization in QK calculation. Defaults to `True`.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        use_self_attention: bool = True,
        use_cross_attention: bool = False,
        self_attention_norm_type: Optional[str] = None,  # ada layer norm
        cross_attention_dim: Optional[int] = None,
        cross_attention_norm_type: Optional[str] = "fp32_layer_norm",
        # parallel second cross attention
        use_cross_attention_2: bool = False,
        cross_attention_2_dim: Optional[int] = None,
        cross_attention_2_norm_type: Optional[str] = None,
        dropout=0.0,
        activation_fn: str = "gelu",
        norm_type: str = "fp32_layer_norm",  # TODO
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = False,
        ff_inner_dim: Optional[int] = None,  # int(dim * 4) if None
        ff_bias: bool = True,
        skip: bool = False,
        skip_concat_front: bool = False,  # [x, skip] or [skip, x]
        skip_norm_last: bool = False,  # this is an error
        qk_norm: bool = True,
        qkv_bias: bool = True,
    ):
        super().__init__()

        self.use_self_attention = use_self_attention
        self.use_cross_attention = use_cross_attention
        self.use_cross_attention_2 = use_cross_attention_2
        self.skip_concat_front = skip_concat_front
        self.skip_norm_last = skip_norm_last
        # Define 3 blocks. Each block has its own normalization layer.
        # NOTE: when new version comes, check norm2 and norm 3
        # 1. Self-Attn
        if use_self_attention:
            if (
                self_attention_norm_type == "fp32_layer_norm"
                or self_attention_norm_type is None
            ):
                self.norm1 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
            else:
                raise NotImplementedError

            self.attn1 = Attention(
                query_dim=dim,
                cross_attention_dim=None,
                dim_head=dim // num_attention_heads,
                heads=num_attention_heads,
                qk_norm="rms_norm" if qk_norm else None,
                eps=1e-6,
                bias=qkv_bias,
                processor=TripoSGAttnProcessor2_0(),
            )

        # 2. Cross-Attn
        if use_cross_attention:
            assert cross_attention_dim is not None

            self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)

            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim,
                dim_head=dim // num_attention_heads,
                heads=num_attention_heads,
                qk_norm="rms_norm" if qk_norm else None,
                cross_attention_norm=cross_attention_norm_type,
                eps=1e-6,
                bias=qkv_bias,
                processor=TripoSGAttnProcessor2_0(),
            )

        # 2'. Parallel Second Cross-Attn
        if use_cross_attention_2:
            assert cross_attention_2_dim is not None

            self.norm2_2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)

            self.attn2_2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_2_dim,
                dim_head=dim // num_attention_heads,
                heads=num_attention_heads,
                qk_norm="rms_norm" if qk_norm else None,
                cross_attention_norm=cross_attention_2_norm_type,
                eps=1e-6,
                bias=qkv_bias,
                processor=TripoSGAttnProcessor2_0(),
            )

        # 3. Feed-forward
        self.norm3 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)

        self.ff = FeedForward(
            dim,
            dropout=dropout,  ### 0.0
            activation_fn=activation_fn,  ### approx GeLU
            final_dropout=final_dropout,  ### 0.0
            inner_dim=ff_inner_dim,  ### int(dim * mlp_ratio)
            bias=ff_bias,
        )

        # 4. Skip Connection
        if skip:
            self.skip_norm = FP32LayerNorm(dim, norm_eps, elementwise_affine=True)
            self.skip_linear = nn.Linear(2 * dim, dim)
        else:
            self.skip_linear = None

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_hidden_states_2: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        skip: Optional[torch.Tensor] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        # Prepare attention kwargs
        attention_kwargs = attention_kwargs or {}

        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Long Skip Connection
        if self.skip_linear is not None:
            cat = torch.cat(
                (
                    [skip, hidden_states]
                    if self.skip_concat_front
                    else [hidden_states, skip]
                ),
                dim=-1,
            )
            if self.skip_norm_last:
                # don't do this
                hidden_states = self.skip_linear(cat)
                hidden_states = self.skip_norm(hidden_states)
            else:
                cat = self.skip_norm(cat)
                hidden_states = self.skip_linear(cat)

        # 1. Self-Attention
        if self.use_self_attention:
            norm_hidden_states = self.norm1(hidden_states)
            attn_output = self.attn1(
                norm_hidden_states,
                image_rotary_emb=image_rotary_emb,
                **attention_kwargs,
            )
            hidden_states = hidden_states + attn_output

        # 2. Cross-Attention
        if self.use_cross_attention:
            if self.use_cross_attention_2:
                hidden_states = (
                    hidden_states
                    + self.attn2(
                        self.norm2(hidden_states),
                        encoder_hidden_states=encoder_hidden_states,
                        image_rotary_emb=image_rotary_emb,
                        **attention_kwargs,
                    )
                    + self.attn2_2(
                        self.norm2_2(hidden_states),
                        encoder_hidden_states=encoder_hidden_states_2,
                        image_rotary_emb=image_rotary_emb,
                        **attention_kwargs,
                    )
                )
            else:
                hidden_states = hidden_states + self.attn2(
                    self.norm2(hidden_states),
                    encoder_hidden_states=encoder_hidden_states,
                    image_rotary_emb=image_rotary_emb,
                    **attention_kwargs,
                )

        # FFN Layer ### TODO: switch norm2 and norm3 in the state dict
        mlp_inputs = self.norm3(hidden_states)
        hidden_states = hidden_states + self.ff(mlp_inputs)

        return hidden_states


class TripoSGDiTModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
    """
    TripoSG: Diffusion model with a Transformer backbone.

    Inherits ModelMixin and ConfigMixin to be compatible with the samplers of diffusers' StableDiffusionPipeline.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88):
            The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input and output (specify if the input is **continuous**).
        patch_size (`int`, *optional*):
            The size of the patch to use for the input.
        activation_fn (`str`, *optional*, defaults to `"geglu"`):
            Activation function to use in feed-forward.
        sample_size (`int`, *optional*):
            The width of the latent images. This is fixed during training since it is used to learn a number of
            position embeddings.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability to use.
        cross_attention_dim (`int`, *optional*):
            The number of dimensions in the CLIP text embedding.
        hidden_size (`int`, *optional*):
            The size of the hidden layer in the conditioning embedding layers.
        num_layers (`int`, *optional*, defaults to 1):
            The number of layers of Transformer blocks to use.
        mlp_ratio (`float`, *optional*, defaults to 4.0):
            The ratio of the hidden layer size to the input size.
        learn_sigma (`bool`, *optional*, defaults to `True`):
            Whether to predict variance.
        cross_attention_dim_t5 (`int`, *optional*):
            The number of dimensions in the T5 text embedding.
        pooled_projection_dim (`int`, *optional*):
            The size of the pooled projection.
        text_len (`int`, *optional*):
            The length of the CLIP text embedding.
        text_len_t5 (`int`, *optional*):
            The length of the T5 text embedding.
        use_style_cond_and_image_meta_size (`bool`, *optional*):
            Whether or not to use style condition and image meta size. True for version <= 1.1, False for version >= 1.2.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        width: int = 2048,
        in_channels: int = 64,
        num_layers: int = 21,
        cross_attention_dim: int = 768,
        cross_attention_2_dim: int = 1024,
    ):
        super().__init__()
        self.out_channels = in_channels
        self.num_heads = num_attention_heads
        self.inner_dim = width
        self.mlp_ratio = 4.0

        time_embed_dim, timestep_input_dim = self._set_time_proj(
            "positional",
            inner_dim=self.inner_dim,
            flip_sin_to_cos=False,
            freq_shift=0,
            time_embedding_dim=None,
        )
        self.time_proj = TimestepEmbedding(
            timestep_input_dim, time_embed_dim, act_fn="gelu", out_dim=self.inner_dim
        )
        self.proj_in = nn.Linear(self.config.in_channels, self.inner_dim, bias=True)

        self.blocks = nn.ModuleList(
            [
                DiTBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    use_self_attention=True,
                    use_cross_attention=True,
                    self_attention_norm_type="fp32_layer_norm",
                    cross_attention_dim=self.config.cross_attention_dim,
                    cross_attention_norm_type=None,
                    use_cross_attention_2=True,
                    cross_attention_2_dim=self.config.cross_attention_2_dim,
                    cross_attention_2_norm_type=None,
                    activation_fn="gelu",
                    norm_type="fp32_layer_norm",  # TODO
                    norm_eps=1e-5,
                    ff_inner_dim=int(self.inner_dim * self.mlp_ratio),
                    skip=layer > num_layers // 2,
                    skip_concat_front=True,
                    skip_norm_last=True,  # this is an error
                    qk_norm=True,  # See http://arxiv.org/abs/2302.05442 for details.
                    qkv_bias=False,
                )
                for layer in range(num_layers)
            ]
        )

        self.norm_out = LayerNorm(self.inner_dim)
        self.proj_out = nn.Linear(self.inner_dim, self.out_channels, bias=True)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        self.gradient_checkpointing = value

    def _set_time_proj(
        self,
        time_embedding_type: str,
        inner_dim: int,
        flip_sin_to_cos: bool,
        freq_shift: float,
        time_embedding_dim: int,
    ) -> Tuple[int, int]:
        if time_embedding_type == "fourier":
            time_embed_dim = time_embedding_dim or inner_dim * 2
            if time_embed_dim % 2 != 0:
                raise ValueError(
                    f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}."
                )
            self.time_embed = GaussianFourierProjection(
                time_embed_dim // 2,
                set_W_to_weight=False,
                log=False,
                flip_sin_to_cos=flip_sin_to_cos,
            )
            timestep_input_dim = time_embed_dim
        elif time_embedding_type == "positional":
            time_embed_dim = time_embedding_dim or inner_dim * 4

            self.time_embed = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
            timestep_input_dim = inner_dim
        else:
            raise ValueError(
                f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
            )

        return time_embed_dim, timestep_input_dim

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedTripoSGAttnProcessor2_0
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError(
                    "`fuse_qkv_projections()` is not supported for models having added KV projections."
                )

        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedTripoSGAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        """
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(
            name: str,
            module: torch.nn.Module,
            processors: Dict[str, AttentionProcessor],
        ):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(
        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
    ):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        self.set_attn_processor(TripoSGAttnProcessor2_0())

    def forward(
        self,
        hidden_states: Optional[torch.Tensor],
        timestep: Union[int, float, torch.LongTensor],
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_hidden_states_2: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ):
        """
        The [`TripoSGDiTModel`] forward method.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`):
                The input tensor.
            timestep (`torch.LongTensor`, *optional*):
                Used to indicate the denoising step.
            encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for the first cross attention layer.
            encoder_hidden_states_2 (`torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for the second cross attention layer.
            return_dict: bool
                Whether to return a dictionary.
        """

        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if (
                attention_kwargs is not None
                and attention_kwargs.get("scale", None) is not None
            ):
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        _, N, _ = hidden_states.shape

        temb = self.time_embed(timestep).to(hidden_states.dtype)
        temb = self.time_proj(temb)
        temb = temb.unsqueeze(dim=1)  # unsqueeze to concat with hidden_states

        hidden_states = self.proj_in(hidden_states)

        # N + 1 token
        hidden_states = torch.cat([temb, hidden_states], dim=1)

        skips = []
        for layer, block in enumerate(self.blocks):
            skip = None if layer <= self.config.num_layers // 2 else skips.pop()

            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = (
                    {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                )
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    encoder_hidden_states_2,
                    temb,
                    image_rotary_emb,
                    skip,
                    attention_kwargs,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_hidden_states_2=encoder_hidden_states_2,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    skip=skip,
                    attention_kwargs=attention_kwargs,
                )  # (N, L, D)

            if layer < self.config.num_layers // 2:
                skips.append(hidden_states)

        # final layer
        hidden_states = self.norm_out(hidden_states)
        hidden_states = hidden_states[:, -N:]
        hidden_states = self.proj_out(hidden_states)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (hidden_states,)

        return Transformer1DModelOutput(sample=hidden_states)

    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
    def enable_forward_chunking(
        self, chunk_size: Optional[int] = None, dim: int = 0
    ) -> None:
        """
        Sets the attention processor to use [feed forward
        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).

        Parameters:
            chunk_size (`int`, *optional*):
                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
                over each tensor of dim=`dim`.
            dim (`int`, *optional*, defaults to `0`):
                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
                or dim=1 (sequence length).
        """
        if dim not in [0, 1]:
            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

        # By default chunk size is 1
        chunk_size = chunk_size or 1

        def fn_recursive_feed_forward(
            module: torch.nn.Module, chunk_size: int, dim: int
        ):
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        for module in self.children():
            fn_recursive_feed_forward(module, chunk_size, dim)

    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
    def disable_forward_chunking(self):
        def fn_recursive_feed_forward(
            module: torch.nn.Module, chunk_size: int, dim: int
        ):
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        for module in self.children():
            fn_recursive_feed_forward(module, None, 0)
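To make the tensor contract of `TripoSGDiTModel.forward` concrete, here is a minimal shape sketch with a deliberately tiny configuration (the released checkpoint uses the larger defaults above; the conditioning shapes are assumptions chosen to mimic a global CLIP embedding and patch-level DINOv2 features):

import torch

# Assumed small config for illustration only, not the real model sizes.
model = TripoSGDiTModel(
    num_attention_heads=4,
    width=128,
    in_channels=16,
    num_layers=3,
    cross_attention_dim=32,
    cross_attention_2_dim=48,
)

latents = torch.randn(2, 256, 16)      # (B, N latent tokens, in_channels)
timestep = torch.tensor([0.5, 0.5])    # one timestep per batch element
cond_1 = torch.randn(2, 1, 32)         # global image embedding (assumed shape)
cond_2 = torch.randn(2, 257, 48)       # patch-level image features (assumed shape)

out = model(
    latents,
    timestep,
    encoder_hidden_states=cond_1,
    encoder_hidden_states_2=cond_2,
)
print(out.sample.shape)                # torch.Size([2, 256, 16]); the prepended time token is stripped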
midi/pipelines/pipeline_midi.py
ADDED
@@ -0,0 +1,497 @@
import inspect
import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL
import PIL.Image
import torch
import torch.nn.functional as F
from diffusers.image_processor import PipelineImageInput
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler  # not sure
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
from peft import LoraConfig, get_peft_model_state_dict
from transformers import (
    BitImageProcessor,
    CLIPImageProcessor,
    CLIPVisionModelWithProjection,
    Dinov2Model,
)

from ..inference_utils import generate_dense_grid_points
from ..loaders import CustomAdapterMixin
from ..models.attention_processor import MIAttnProcessor2_0
from ..models.autoencoders import TripoSGVAEModel
from ..models.transformers import TripoSGDiTModel, set_transformer_attn_processor
from .pipeline_triposg_output import TripoSGPipelineOutput
from .pipeline_utils import TransformerDiffusionMixin

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError(
            "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
        )
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(
            inspect.signature(scheduler.set_timesteps).parameters.keys()
        )
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(
            inspect.signature(scheduler.set_timesteps).parameters.keys()
        )
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


class MIDIPipeline(DiffusionPipeline, TransformerDiffusionMixin, CustomAdapterMixin):
    """
    Pipeline for image-to-scene generation based on pre-trained shape diffusion.
    """

    def __init__(
        self,
        vae: TripoSGVAEModel,
        transformer: TripoSGDiTModel,
        scheduler: FlowMatchEulerDiscreteScheduler,
        image_encoder_1: CLIPVisionModelWithProjection,
        image_encoder_2: Dinov2Model,
        feature_extractor_1: CLIPImageProcessor,
        feature_extractor_2: BitImageProcessor,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            transformer=transformer,
            scheduler=scheduler,
            image_encoder_1=image_encoder_1,
            image_encoder_2=image_encoder_2,
            feature_extractor_1=feature_extractor_1,
            feature_extractor_2=feature_extractor_2,
        )

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def attention_kwargs(self):
        return self._attention_kwargs

    @property
    def interrupt(self):
        return self._interrupt

    @property
    def decode_progressive(self):
        return self._decode_progressive

    def encode_image_1(self, image, device, num_images_per_prompt):
        dtype = next(self.image_encoder_1.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            image = self.feature_extractor_1(image, return_tensors="pt").pixel_values

        image = image.to(device=device, dtype=dtype)
        image_embeds = self.image_encoder_1(image).image_embeds
        image_embeds = image_embeds.repeat_interleave(
            num_images_per_prompt, dim=0
        ).unsqueeze(1)
        uncond_image_embeds = torch.zeros_like(image_embeds)

        return image_embeds, uncond_image_embeds

    def encode_image_2(
        self,
        image_one,
        image_two,
        mask,
        device,
        num_images_per_prompt,
    ):
        dtype = next(self.image_encoder_2.parameters()).dtype

        images = [image_one, image_two, mask]
        images_new = []
        for i, image in enumerate(images):
            if not isinstance(image, torch.Tensor):
                if i <= 1:
                    images_new.append(
                        self.feature_extractor_2(
                            image, return_tensors="pt"
                        ).pixel_values
                    )
                else:
                    image = [
                        torch.from_numpy(
                            (np.array(im) / 255.0).astype(np.float32)
                        ).unsqueeze(0)
                        for im in image
                    ]
                    image = torch.stack(image, dim=0)
                    images_new.append(
                        F.interpolate(
                            image, size=images_new[0].shape[-2:], mode="nearest"
                        )
                    )

        image = torch.cat(images_new, dim=1).to(device=device, dtype=dtype)
        image_embeds = self.image_encoder_2(image).last_hidden_state
        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        uncond_image_embeds = torch.zeros_like(image_embeds)

        return image_embeds, uncond_image_embeds

    def prepare_latents(
        self,
        batch_size,
        num_tokens,
        num_channels_latents,
        dtype,
        device,
        generator,
        latents: Optional[torch.Tensor] = None,
    ):
        if latents is not None:
            return latents.to(device=device, dtype=dtype)

        shape = (batch_size, num_tokens, num_channels_latents)

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        return latents

    @torch.no_grad()
    def decode_latents(
        self,
        latents: torch.Tensor,
        sampled_points: torch.Tensor,
        decode_progressive: bool = False,
        decode_to_cpu: bool = False,
        # Params for sampling points
        bbox_min: np.ndarray = np.array([-1.005, -1.005, -1.005]),
        bbox_max: np.ndarray = np.array([1.005, 1.005, 1.005]),
        octree_depth: int = 8,
        indexing: str = "ij",
        padding: float = 0.05,
    ):
        device, dtype = latents.device, latents.dtype
        batch_size = latents.shape[0]

        grid_sizes, bbox_sizes, bbox_mins, bbox_maxs = [], [], [], []
+
|
251 |
+
if sampled_points is None:
|
252 |
+
sampled_points, grid_size, bbox_size = generate_dense_grid_points(
|
253 |
+
bbox_min, bbox_max, octree_depth, indexing
|
254 |
+
)
|
255 |
+
sampled_points = torch.FloatTensor(sampled_points).to(
|
256 |
+
device=device, dtype=dtype
|
257 |
+
)
|
258 |
+
sampled_points = sampled_points.unsqueeze(0).expand(batch_size, -1, -1)
|
259 |
+
|
260 |
+
grid_sizes.append(grid_size)
|
261 |
+
bbox_sizes.append(bbox_size)
|
262 |
+
bbox_mins.append(bbox_min)
|
263 |
+
bbox_maxs.append(bbox_max)
|
264 |
+
|
265 |
+
self.vae: TripoSGVAEModel
|
266 |
+
output = self.vae.decode(
|
267 |
+
latents, sampled_points=sampled_points, to_cpu=decode_to_cpu
|
268 |
+
).sample
|
269 |
+
|
270 |
+
if not decode_progressive:
|
271 |
+
return (output, grid_sizes, bbox_sizes, bbox_mins, bbox_maxs)
|
272 |
+
|
273 |
+
grid_sizes, bbox_sizes, bbox_mins, bbox_maxs = [], [], [], []
|
274 |
+
sampled_points_list = []
|
275 |
+
|
276 |
+
for i in range(batch_size):
|
277 |
+
sdf_ = output[i].squeeze(-1) # [num_points]
|
278 |
+
sampled_points_ = sampled_points[i]
|
279 |
+
occupied_points = sampled_points_[sdf_ <= 0] # [num_occupied_points, 3]
|
280 |
+
|
281 |
+
if occupied_points.shape[0] == 0:
|
282 |
+
logger.warning(
|
283 |
+
f"No occupied points found in batch {i}. Using original bounding box."
|
284 |
+
)
|
285 |
+
else:
|
286 |
+
bbox_min = occupied_points.min(dim=0).values
|
287 |
+
bbox_max = occupied_points.max(dim=0).values
|
288 |
+
bbox_min = (bbox_min - padding).float().cpu().numpy()
|
289 |
+
bbox_max = (bbox_max + padding).float().cpu().numpy()
|
290 |
+
|
291 |
+
sampled_points_, grid_size, bbox_size = generate_dense_grid_points(
|
292 |
+
bbox_min, bbox_max, octree_depth, indexing
|
293 |
+
)
|
294 |
+
sampled_points_ = torch.FloatTensor(sampled_points_).to(
|
295 |
+
device=device, dtype=dtype
|
296 |
+
)
|
297 |
+
sampled_points_list.append(sampled_points_)
|
298 |
+
|
299 |
+
grid_sizes.append(grid_size)
|
300 |
+
bbox_sizes.append(bbox_size)
|
301 |
+
bbox_mins.append(bbox_min)
|
302 |
+
bbox_maxs.append(bbox_max)
|
303 |
+
|
304 |
+
sampled_points = torch.stack(sampled_points_list, dim=0)
|
305 |
+
|
306 |
+
# Re-decode the new sampled points
|
307 |
+
output = self.vae.decode(
|
308 |
+
latents, sampled_points=sampled_points, to_cpu=decode_to_cpu
|
309 |
+
).sample
|
310 |
+
|
311 |
+
return (output, grid_sizes, bbox_sizes, bbox_mins, bbox_maxs)
|
312 |
+
|
313 |
+
@torch.no_grad()
|
314 |
+
def __call__(
|
315 |
+
self,
|
316 |
+
image: PipelineImageInput,
|
317 |
+
mask: PipelineImageInput,
|
318 |
+
image_scene: PipelineImageInput,
|
319 |
+
num_inference_steps: int = 50,
|
320 |
+
timesteps: List[int] = None,
|
321 |
+
guidance_scale: float = 7.0,
|
322 |
+
num_images_per_prompt: int = 1,
|
323 |
+
sampled_points: Optional[torch.Tensor] = None,
|
324 |
+
decode_progressive: bool = False,
|
325 |
+
decode_to_cpu: bool = False,
|
326 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
327 |
+
latents: Optional[torch.FloatTensor] = None,
|
328 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
329 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
330 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
331 |
+
output_type: Optional[str] = "mesh_vf",
|
332 |
+
return_dict: bool = True,
|
333 |
+
):
|
334 |
+
# 1. Check inputs. Raise error if not correct
|
335 |
+
# TODO
|
336 |
+
|
337 |
+
self._decode_progressive = decode_progressive
|
338 |
+
self._guidance_scale = guidance_scale
|
339 |
+
self._attention_kwargs = attention_kwargs
|
340 |
+
self._interrupt = False
|
341 |
+
|
342 |
+
# 2. Define call parameters
|
343 |
+
if isinstance(image, PIL.Image.Image):
|
344 |
+
batch_size = 1
|
345 |
+
elif isinstance(image, list):
|
346 |
+
batch_size = len(image)
|
347 |
+
elif isinstance(image, torch.Tensor):
|
348 |
+
batch_size = image.shape[0]
|
349 |
+
else:
|
350 |
+
raise ValueError("Invalid input type for image")
|
351 |
+
|
352 |
+
device = self._execution_device
|
353 |
+
|
354 |
+
# 3. Encode condition
|
355 |
+
image_embeds_1, negative_image_embeds_1 = self.encode_image_1(
|
356 |
+
image, device, num_images_per_prompt
|
357 |
+
)
|
358 |
+
image_embeds_2, negative_image_embeds_2 = self.encode_image_2(
|
359 |
+
image, image_scene, mask, device, num_images_per_prompt
|
360 |
+
)
|
361 |
+
|
362 |
+
if self.do_classifier_free_guidance:
|
363 |
+
image_embeds_1 = torch.cat([negative_image_embeds_1, image_embeds_1], dim=0)
|
364 |
+
image_embeds_2 = torch.cat([negative_image_embeds_2, image_embeds_2], dim=0)
|
365 |
+
|
366 |
+
# 4. Prepare timesteps
|
367 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
368 |
+
self.scheduler, num_inference_steps, device, timesteps
|
369 |
+
)
|
370 |
+
num_warmup_steps = max(
|
371 |
+
len(timesteps) - num_inference_steps * self.scheduler.order, 0
|
372 |
+
)
|
373 |
+
self._num_timesteps = len(timesteps)
|
374 |
+
|
375 |
+
# 5. Prepare latent variables
|
376 |
+
num_tokens = self.transformer.config.width
|
377 |
+
num_channels_latents = self.transformer.config.in_channels
|
378 |
+
latents = self.prepare_latents(
|
379 |
+
batch_size * num_images_per_prompt,
|
380 |
+
num_tokens,
|
381 |
+
num_channels_latents,
|
382 |
+
image_embeds_1.dtype,
|
383 |
+
device,
|
384 |
+
generator,
|
385 |
+
latents,
|
386 |
+
)
|
387 |
+
|
388 |
+
# 6. Denoising loop
|
389 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
390 |
+
for i, t in enumerate(timesteps):
|
391 |
+
if self.interrupt:
|
392 |
+
continue
|
393 |
+
|
394 |
+
# expand the latents if we are doing classifier free guidance
|
395 |
+
latent_model_input = (
|
396 |
+
torch.cat([latents] * 2)
|
397 |
+
if self.do_classifier_free_guidance
|
398 |
+
else latents
|
399 |
+
)
|
400 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
401 |
+
timestep = t.expand(latent_model_input.shape[0])
|
402 |
+
|
403 |
+
noise_pred = self.transformer(
|
404 |
+
latent_model_input,
|
405 |
+
timestep,
|
406 |
+
encoder_hidden_states=image_embeds_1,
|
407 |
+
encoder_hidden_states_2=image_embeds_2,
|
408 |
+
attention_kwargs=attention_kwargs,
|
409 |
+
return_dict=False,
|
410 |
+
)[0]
|
411 |
+
|
412 |
+
# perform guidance
|
413 |
+
if self.do_classifier_free_guidance:
|
414 |
+
noise_pred_uncond, noise_pred_image = noise_pred.chunk(2)
|
415 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (
|
416 |
+
noise_pred_image - noise_pred_uncond
|
417 |
+
)
|
418 |
+
|
419 |
+
# compute the previous noisy sample x_t -> x_t-1
|
420 |
+
latents_dtype = latents.dtype
|
421 |
+
latents = self.scheduler.step(
|
422 |
+
noise_pred, t, latents, return_dict=False
|
423 |
+
)[0]
|
424 |
+
|
425 |
+
if latents.dtype != latents_dtype:
|
426 |
+
if torch.backends.mps.is_available():
|
427 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
428 |
+
latents = latents.to(latents_dtype)
|
429 |
+
|
430 |
+
if callback_on_step_end is not None:
|
431 |
+
callback_kwargs = {}
|
432 |
+
for k in callback_on_step_end_tensor_inputs:
|
433 |
+
callback_kwargs[k] = locals()[k]
|
434 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
435 |
+
|
436 |
+
latents = callback_outputs.pop("latents", latents)
|
437 |
+
image_embeds_1 = callback_outputs.pop(
|
438 |
+
"image_embeds_1", image_embeds_1
|
439 |
+
)
|
440 |
+
negative_image_embeds_1 = callback_outputs.pop(
|
441 |
+
"negative_image_embeds_1", negative_image_embeds_1
|
442 |
+
)
|
443 |
+
image_embeds_2 = callback_outputs.pop(
|
444 |
+
"image_embeds_2", image_embeds_2
|
445 |
+
)
|
446 |
+
negative_image_embeds_2 = callback_outputs.pop(
|
447 |
+
"negative_image_embeds_2", negative_image_embeds_2
|
448 |
+
)
|
449 |
+
|
450 |
+
# call the callback, if provided
|
451 |
+
if i == len(timesteps) - 1 or (
|
452 |
+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
453 |
+
):
|
454 |
+
progress_bar.update()
|
455 |
+
|
456 |
+
grid_sizes, bbox_sizes, bbox_mins, bbox_maxs = None, None, None, None
|
457 |
+
|
458 |
+
if output_type == "latent":
|
459 |
+
output = latents
|
460 |
+
else:
|
461 |
+
output, grid_sizes, bbox_sizes, bbox_mins, bbox_maxs = self.decode_latents(
|
462 |
+
latents,
|
463 |
+
sampled_points=sampled_points,
|
464 |
+
decode_progressive=decode_progressive,
|
465 |
+
decode_to_cpu=decode_to_cpu,
|
466 |
+
)
|
467 |
+
|
468 |
+
# Offload all models
|
469 |
+
self.maybe_free_model_hooks()
|
470 |
+
|
471 |
+
if not return_dict:
|
472 |
+
return (output, grid_sizes, bbox_sizes, bbox_mins, bbox_maxs)
|
473 |
+
|
474 |
+
return TripoSGPipelineOutput(
|
475 |
+
samples=output,
|
476 |
+
grid_sizes=grid_sizes,
|
477 |
+
bbox_sizes=bbox_sizes,
|
478 |
+
bbox_mins=bbox_mins,
|
479 |
+
bbox_maxs=bbox_maxs,
|
480 |
+
)
|
481 |
+
|
482 |
+
def _init_custom_adapter(
|
483 |
+
self, set_self_attn_module_names: Optional[List[str]] = None
|
484 |
+
):
|
485 |
+
# Set attention processor
|
486 |
+
func_default = lambda name, hs, cad, ap: MIAttnProcessor2_0(use_mi=False)
|
487 |
+
set_transformer_attn_processor( # avoid warning
|
488 |
+
self.transformer,
|
489 |
+
set_self_attn_proc_func=func_default,
|
490 |
+
set_cross_attn_1_proc_func=func_default,
|
491 |
+
set_cross_attn_2_proc_func=func_default,
|
492 |
+
)
|
493 |
+
set_transformer_attn_processor(
|
494 |
+
self.transformer,
|
495 |
+
set_self_attn_proc_func=lambda name, hs, cad, ap: MIAttnProcessor2_0(),
|
496 |
+
set_self_attn_module_names=set_self_attn_module_names,
|
497 |
+
)
|
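# Illustrative end-to-end sketch (assumptions: the checkpoint id, the file names and the
# post-processing below are placeholders, not confirmed by this file). It shows the call
# pattern: one RGB image, one instance mask and the scene image per object, followed by
# marching cubes on the decoded SDF grid.
import numpy as np
import torch
from PIL import Image
from skimage import measure

pipe = MIDIPipeline.from_pretrained("VAST-AI/MIDI-3D", torch_dtype=torch.float16).to("cuda")

rgb = Image.open("scene_rgb.png").convert("RGB")
mask = Image.open("instance_mask.png").convert("L")
out = pipe(
    image=[rgb], mask=[mask], image_scene=[rgb],
    num_inference_steps=50, guidance_scale=7.0, decode_progressive=True,
)

# Each sample is an SDF evaluated on a dense grid; reshape it and extract the 0 level-set.
sdf = out.samples[0].float().cpu().numpy().reshape(out.grid_sizes[0])
verts, faces, _, _ = measure.marching_cubes(sdf, level=0)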
midi/pipelines/pipeline_triposg_output.py
ADDED
@@ -0,0 +1,25 @@

|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import List, Optional, Union
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from diffusers.utils import BaseOutput
|
7 |
+
|
8 |
+
PipelineBoxOutput = Union[
|
9 |
+
List[List[int]], # [[257, 257, 257], ...]
|
10 |
+
List[List[float]], # [[-1.05, -1.05, -1.05], ...]
|
11 |
+
List[np.ndarray],
|
12 |
+
]
|
13 |
+
|
14 |
+
|
15 |
+
@dataclass
|
16 |
+
class TripoSGPipelineOutput(BaseOutput):
|
17 |
+
r"""
|
18 |
+
Output class for TripoSG pipelines.
|
19 |
+
"""
|
20 |
+
|
21 |
+
samples: torch.Tensor
|
22 |
+
grid_sizes: Optional[PipelineBoxOutput] = None
|
23 |
+
bbox_sizes: Optional[PipelineBoxOutput] = None
|
24 |
+
bbox_mins: Optional[PipelineBoxOutput] = None
|
25 |
+
bbox_maxs: Optional[PipelineBoxOutput] = None
|
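# Small sketch: diffusers' BaseOutput behaves like both a dataclass and a read-only
# dict, so the fields above can be read either way (values here are dummies).
import torch

_out = TripoSGPipelineOutput(samples=torch.zeros(1, 8, 1), grid_sizes=[[2, 2, 2]])
_sdf = _out.samples                          # attribute access
_sdf = _out["samples"]                       # dict-style access
_sdf, _grid_sizes, *_ = _out.to_tuple()      # ordered tuple of the populated fields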
midi/pipelines/pipeline_utils.py
ADDED
@@ -0,0 +1,96 @@
|
1 |
+
from diffusers.utils import logging
|
2 |
+
|
3 |
+
logger = logging.get_logger(__name__)
|
4 |
+
|
5 |
+
|
6 |
+
class TransformerDiffusionMixin:
|
7 |
+
r"""
|
8 |
+
Helper mixin for DiffusionPipeline objects built around a VAE and a transformer (mainly for DiT-style models).
|
9 |
+
"""
|
10 |
+
|
11 |
+
def enable_vae_slicing(self):
|
12 |
+
r"""
|
13 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
14 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
15 |
+
"""
|
16 |
+
self.vae.enable_slicing()
|
17 |
+
|
18 |
+
def disable_vae_slicing(self):
|
19 |
+
r"""
|
20 |
+
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
21 |
+
computing decoding in one step.
|
22 |
+
"""
|
23 |
+
self.vae.disable_slicing()
|
24 |
+
|
25 |
+
def enable_vae_tiling(self):
|
26 |
+
r"""
|
27 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
28 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
29 |
+
processing larger images.
|
30 |
+
"""
|
31 |
+
self.vae.enable_tiling()
|
32 |
+
|
33 |
+
def disable_vae_tiling(self):
|
34 |
+
r"""
|
35 |
+
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
36 |
+
computing decoding in one step.
|
37 |
+
"""
|
38 |
+
self.vae.disable_tiling()
|
39 |
+
|
40 |
+
def fuse_qkv_projections(self, transformer: bool = True, vae: bool = True):
|
41 |
+
"""
|
42 |
+
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
43 |
+
are fused. For cross-attention modules, key and value projection matrices are fused.
|
44 |
+
|
45 |
+
<Tip warning={true}>
|
46 |
+
|
47 |
+
This API is 🧪 experimental.
|
48 |
+
|
49 |
+
</Tip>
|
50 |
+
|
51 |
+
Args:
|
52 |
+
transformer (`bool`, defaults to `True`): To apply fusion on the Transformer.
|
53 |
+
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
|
54 |
+
"""
|
55 |
+
self.fusing_transformer = False
|
56 |
+
self.fusing_vae = False
|
57 |
+
|
58 |
+
if transformer:
|
59 |
+
self.fusing_transformer = True
|
60 |
+
self.transformer.fuse_qkv_projections()
|
61 |
+
|
62 |
+
if vae:
|
63 |
+
self.fusing_vae = True
|
64 |
+
self.vae.fuse_qkv_projections()
|
65 |
+
|
66 |
+
def unfuse_qkv_projections(self, transformer: bool = True, vae: bool = True):
|
67 |
+
"""Disable QKV projection fusion if enabled.
|
68 |
+
|
69 |
+
<Tip warning={true}>
|
70 |
+
|
71 |
+
This API is 🧪 experimental.
|
72 |
+
|
73 |
+
</Tip>
|
74 |
+
|
75 |
+
Args:
|
76 |
+
transformer (`bool`, defaults to `True`): To apply fusion on the Transformer.
|
77 |
+
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
|
78 |
+
|
79 |
+
"""
|
80 |
+
if transformer:
|
81 |
+
if not self.fusing_transformer:
|
82 |
+
logger.warning(
|
83 |
+
"The UNet was not initially fused for QKV projections. Doing nothing."
|
84 |
+
)
|
85 |
+
else:
|
86 |
+
self.transformer.unfuse_qkv_projections()
|
87 |
+
self.fusing_transformer = False
|
88 |
+
|
89 |
+
if vae:
|
90 |
+
if not self.fusing_vae:
|
91 |
+
logger.warning(
|
92 |
+
"The VAE was not initially fused for QKV projections. Doing nothing."
|
93 |
+
)
|
94 |
+
else:
|
95 |
+
self.vae.unfuse_qkv_projections()
|
96 |
+
self.fusing_vae = False
|
midi/schedulers/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
1 |
+
from .scheduling_rectified_flow import (
|
2 |
+
RectifiedFlowScheduler,
|
3 |
+
compute_density_for_timestep_sampling,
|
4 |
+
compute_loss_weighting,
|
5 |
+
)
|
midi/schedulers/scheduling_rectified_flow.py
ADDED
@@ -0,0 +1,327 @@
|
1 |
+
"""
|
2 |
+
Adapted from https://github.com/huggingface/diffusers/blob/v0.30.3/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import math
|
6 |
+
from dataclasses import dataclass
|
7 |
+
from typing import List, Optional, Tuple, Union
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
12 |
+
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
13 |
+
from diffusers.utils import BaseOutput, logging
|
14 |
+
from torch.distributions import LogisticNormal
|
15 |
+
|
16 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
17 |
+
|
18 |
+
|
19 |
+
# TODO: may move to training_utils.py
|
20 |
+
def compute_density_for_timestep_sampling(
|
21 |
+
weighting_scheme: str,
|
22 |
+
batch_size: int,
|
23 |
+
logit_mean: float = 0.0,
|
24 |
+
logit_std: float = 1.0,
|
25 |
+
mode_scale: float = None,
|
26 |
+
):
|
27 |
+
if weighting_scheme == "logit_normal":
|
28 |
+
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
|
29 |
+
u = torch.normal(
|
30 |
+
mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu"
|
31 |
+
)
|
32 |
+
u = torch.nn.functional.sigmoid(u)
|
33 |
+
elif weighting_scheme == "logit_normal_dist":
|
34 |
+
u = (
|
35 |
+
LogisticNormal(loc=logit_mean, scale=logit_std)
|
36 |
+
.sample((batch_size,))[:, 0]
|
37 |
+
.to("cpu")
|
38 |
+
)
|
39 |
+
elif weighting_scheme == "mode":
|
40 |
+
u = torch.rand(size=(batch_size,), device="cpu")
|
41 |
+
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
|
42 |
+
else:
|
43 |
+
u = torch.rand(size=(batch_size,), device="cpu")
|
44 |
+
return u
|
45 |
+
|
46 |
+
|
47 |
+
def compute_loss_weighting(weighting_scheme: str, sigmas=None):
|
48 |
+
"""
|
49 |
+
Computes loss weighting scheme for SD3 training.
|
50 |
+
|
51 |
+
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
52 |
+
|
53 |
+
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
|
54 |
+
"""
|
55 |
+
if weighting_scheme == "sigma_sqrt":
|
56 |
+
weighting = (sigmas**-2.0).float()
|
57 |
+
elif weighting_scheme == "cosmap":
|
58 |
+
bot = 1 - 2 * sigmas + 2 * sigmas**2
|
59 |
+
weighting = 2 / (math.pi * bot)
|
60 |
+
else:
|
61 |
+
weighting = torch.ones_like(sigmas)
|
62 |
+
return weighting
|
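# Sketch of how the two helpers above are typically combined in a rectified-flow
# training step (SD3-style). The latent shape and the model call are placeholders; the
# target sign follows the step direction of the RectifiedFlowScheduler defined below.
import torch

_bsz, _num_train_timesteps = 4, 1000
_u = compute_density_for_timestep_sampling("logit_normal", _bsz, logit_mean=0.0, logit_std=1.0)
_sigmas = _u.view(-1, 1, 1)                            # u already lies in (0, 1)
_timesteps = _u * _num_train_timesteps

_x0 = torch.randn(_bsz, 2048, 64)                      # clean latents (illustrative shape)
_noise = torch.randn_like(_x0)
_noisy = (1.0 - _sigmas) * _x0 + _sigmas * _noise      # same interpolation as scale_noise
_target = _x0 - _noise                                 # velocity the model should predict

_weighting = compute_loss_weighting("cosmap", sigmas=_sigmas)
# loss = (_weighting * (model(_noisy, _timesteps) - _target) ** 2).mean()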
63 |
+
|
64 |
+
|
65 |
+
@dataclass
|
66 |
+
class RectifiedFlowSchedulerOutput(BaseOutput):
|
67 |
+
"""
|
68 |
+
Output class for the scheduler's `step` function output.
|
69 |
+
|
70 |
+
Args:
|
71 |
+
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
72 |
+
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
73 |
+
denoising loop.
|
74 |
+
"""
|
75 |
+
|
76 |
+
prev_sample: torch.FloatTensor
|
77 |
+
|
78 |
+
|
79 |
+
class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin):
|
80 |
+
"""
|
81 |
+
The rectified flow scheduler is a scheduler that is used to propagate the diffusion process in the rectified flow.
|
82 |
+
|
83 |
+
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
84 |
+
methods the library implements for all schedulers such as loading and saving.
|
85 |
+
|
86 |
+
Args:
|
87 |
+
num_train_timesteps (`int`, defaults to 1000):
|
88 |
+
The number of diffusion steps to train the model.
|
89 |
+
timestep_spacing (`str`, defaults to `"linspace"`):
|
90 |
+
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
91 |
+
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
92 |
+
shift (`float`, defaults to 1.0):
|
93 |
+
The shift value for the timestep schedule.
|
94 |
+
"""
|
95 |
+
|
96 |
+
_compatibles = []
|
97 |
+
order = 1
|
98 |
+
|
99 |
+
@register_to_config
|
100 |
+
def __init__(
|
101 |
+
self,
|
102 |
+
num_train_timesteps: int = 1000,
|
103 |
+
shift: float = 1.0,
|
104 |
+
use_dynamic_shifting: bool = False,
|
105 |
+
):
|
106 |
+
# pre-compute timesteps and sigmas; they are not actually used in practice
|
107 |
+
# NOTE that shape diffusion samples timesteps randomly or from a distribution,
|
108 |
+
# instead of sampling from the pre-defined linspace
|
109 |
+
timesteps = np.array(
|
110 |
+
[
|
111 |
+
(1.0 - i / num_train_timesteps) * num_train_timesteps
|
112 |
+
for i in range(num_train_timesteps)
|
113 |
+
]
|
114 |
+
)
|
115 |
+
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
|
116 |
+
|
117 |
+
sigmas = timesteps / num_train_timesteps
|
118 |
+
if not use_dynamic_shifting:
|
119 |
+
# when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
|
120 |
+
sigmas = self.time_shift(sigmas)
|
121 |
+
|
122 |
+
self.timesteps = sigmas * num_train_timesteps
|
123 |
+
|
124 |
+
self._step_index = None
|
125 |
+
self._begin_index = None
|
126 |
+
|
127 |
+
self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
128 |
+
|
129 |
+
@property
|
130 |
+
def step_index(self):
|
131 |
+
"""
|
132 |
+
The index counter for the current timestep. It increases by 1 after each scheduler step.
|
133 |
+
"""
|
134 |
+
return self._step_index
|
135 |
+
|
136 |
+
@property
|
137 |
+
def begin_index(self):
|
138 |
+
"""
|
139 |
+
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
140 |
+
"""
|
141 |
+
return self._begin_index
|
142 |
+
|
143 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
144 |
+
def set_begin_index(self, begin_index: int = 0):
|
145 |
+
"""
|
146 |
+
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
147 |
+
|
148 |
+
Args:
|
149 |
+
begin_index (`int`):
|
150 |
+
The begin index for the scheduler.
|
151 |
+
"""
|
152 |
+
self._begin_index = begin_index
|
153 |
+
|
154 |
+
def _sigma_to_t(self, sigma):
|
155 |
+
return sigma * self.config.num_train_timesteps
|
156 |
+
|
157 |
+
def _t_to_sigma(self, timestep):
|
158 |
+
return timestep / self.config.num_train_timesteps
|
159 |
+
|
160 |
+
def time_shift_dynamic(self, mu: float, sigma: float, t: torch.Tensor):
|
161 |
+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
162 |
+
|
163 |
+
def time_shift(self, t: torch.Tensor):
|
164 |
+
return self.config.shift * t / (1 + (self.config.shift - 1) * t)
|
165 |
+
|
166 |
+
def set_timesteps(
|
167 |
+
self,
|
168 |
+
num_inference_steps: int = None,
|
169 |
+
device: Union[str, torch.device] = None,
|
170 |
+
sigmas: Optional[List[float]] = None,
|
171 |
+
mu: Optional[float] = None,
|
172 |
+
):
|
173 |
+
"""
|
174 |
+
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
175 |
+
|
176 |
+
Args:
|
177 |
+
num_inference_steps (`int`):
|
178 |
+
The number of diffusion steps used when generating samples with a pre-trained model.
|
179 |
+
device (`str` or `torch.device`, *optional*):
|
180 |
+
The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
|
181 |
+
"""
|
182 |
+
|
183 |
+
if self.config.use_dynamic_shifting and mu is None:
|
184 |
+
raise ValueError(
|
185 |
+
" you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
|
186 |
+
)
|
187 |
+
|
188 |
+
if sigmas is None:
|
189 |
+
self.num_inference_steps = num_inference_steps
|
190 |
+
timesteps = np.array(
|
191 |
+
[
|
192 |
+
(1.0 - i / num_inference_steps) * self.config.num_train_timesteps
|
193 |
+
for i in range(num_inference_steps)
|
194 |
+
]
|
195 |
+
) # different from the original code in SD3
|
196 |
+
sigmas = timesteps / self.config.num_train_timesteps
|
197 |
+
|
198 |
+
if self.config.use_dynamic_shifting:
|
199 |
+
sigmas = self.time_shift_dynamic(mu, 1.0, sigmas)
|
200 |
+
else:
|
201 |
+
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
|
202 |
+
|
203 |
+
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
|
204 |
+
timesteps = sigmas * self.config.num_train_timesteps
|
205 |
+
|
206 |
+
self.timesteps = timesteps.to(device=device)
|
207 |
+
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
208 |
+
|
209 |
+
self._step_index = None
|
210 |
+
self._begin_index = None
|
211 |
+
|
212 |
+
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
213 |
+
if schedule_timesteps is None:
|
214 |
+
schedule_timesteps = self.timesteps
|
215 |
+
|
216 |
+
indices = (schedule_timesteps == timestep).nonzero()
|
217 |
+
|
218 |
+
# The sigma index that is taken for the **very** first `step`
|
219 |
+
# is always the second index (or the last index if there is only 1)
|
220 |
+
# This way we can ensure we don't accidentally skip a sigma in
|
221 |
+
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
222 |
+
pos = 1 if len(indices) > 1 else 0
|
223 |
+
|
224 |
+
return indices[pos].item()
|
225 |
+
|
226 |
+
def _init_step_index(self, timestep):
|
227 |
+
if self.begin_index is None:
|
228 |
+
if isinstance(timestep, torch.Tensor):
|
229 |
+
timestep = timestep.to(self.timesteps.device)
|
230 |
+
self._step_index = self.index_for_timestep(timestep)
|
231 |
+
else:
|
232 |
+
self._step_index = self._begin_index
|
233 |
+
|
234 |
+
def step(
|
235 |
+
self,
|
236 |
+
model_output: torch.FloatTensor,
|
237 |
+
timestep: Union[float, torch.FloatTensor],
|
238 |
+
sample: torch.FloatTensor,
|
239 |
+
s_churn: float = 0.0,
|
240 |
+
s_tmin: float = 0.0,
|
241 |
+
s_tmax: float = float("inf"),
|
242 |
+
s_noise: float = 1.0,
|
243 |
+
generator: Optional[torch.Generator] = None,
|
244 |
+
return_dict: bool = True,
|
245 |
+
) -> Union[RectifiedFlowSchedulerOutput, Tuple]:
|
246 |
+
"""
|
247 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
248 |
+
process from the learned model outputs (most often the predicted noise).
|
249 |
+
|
250 |
+
Args:
|
251 |
+
model_output (`torch.FloatTensor`):
|
252 |
+
The direct output from learned diffusion model.
|
253 |
+
timestep (`float`):
|
254 |
+
The current discrete timestep in the diffusion chain.
|
255 |
+
sample (`torch.FloatTensor`):
|
256 |
+
A current instance of a sample created by the diffusion process.
|
257 |
+
s_churn (`float`):
|
258 |
+
s_tmin (`float`):
|
259 |
+
s_tmax (`float`):
|
260 |
+
s_noise (`float`, defaults to 1.0):
|
261 |
+
Scaling factor for noise added to the sample.
|
262 |
+
generator (`torch.Generator`, *optional*):
|
263 |
+
A random number generator.
|
264 |
+
return_dict (`bool`):
|
265 |
+
Whether or not to return a [`RectifiedFlowSchedulerOutput`] or
|
266 |
+
tuple.
|
267 |
+
|
268 |
+
Returns:
|
269 |
+
[`RectifiedFlowSchedulerOutput`] or `tuple`:
|
270 |
+
If return_dict is `True`, [`RectifiedFlowSchedulerOutput`] is
|
271 |
+
returned, otherwise a tuple is returned where the first element is the sample tensor.
|
272 |
+
"""
|
273 |
+
|
274 |
+
if (
|
275 |
+
isinstance(timestep, int)
|
276 |
+
or isinstance(timestep, torch.IntTensor)
|
277 |
+
or isinstance(timestep, torch.LongTensor)
|
278 |
+
):
|
279 |
+
raise ValueError(
|
280 |
+
(
|
281 |
+
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
282 |
+
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
283 |
+
" one of the `scheduler.timesteps` as a timestep."
|
284 |
+
),
|
285 |
+
)
|
286 |
+
|
287 |
+
if self.step_index is None:
|
288 |
+
self._init_step_index(timestep)
|
289 |
+
|
290 |
+
# Upcast to avoid precision issues when computing prev_sample
|
291 |
+
sample = sample.to(torch.float32)
|
292 |
+
|
293 |
+
sigma = self.sigmas[self.step_index]
|
294 |
+
sigma_next = self.sigmas[self.step_index + 1]
|
295 |
+
|
296 |
+
# NOTE: the sign convention differs from SD3-style flow matching: the update uses (sigma - sigma_next), so the model output is the velocity pointing from noise towards data
|
297 |
+
prev_sample = sample + (sigma - sigma_next) * model_output
|
298 |
+
|
299 |
+
# Cast sample back to model compatible dtype
|
300 |
+
prev_sample = prev_sample.to(model_output.dtype)
|
301 |
+
|
302 |
+
# upon completion increase step index by one
|
303 |
+
self._step_index += 1
|
304 |
+
|
305 |
+
if not return_dict:
|
306 |
+
return (prev_sample,)
|
307 |
+
|
308 |
+
return RectifiedFlowSchedulerOutput(prev_sample=prev_sample)
|
309 |
+
|
310 |
+
def scale_noise(
|
311 |
+
self,
|
312 |
+
original_samples: torch.Tensor,
|
313 |
+
noise: torch.Tensor,
|
314 |
+
timesteps: torch.IntTensor,
|
315 |
+
) -> torch.Tensor:
|
316 |
+
"""
|
317 |
+
Forward function for the noise scaling in the flow matching.
|
318 |
+
"""
|
319 |
+
sigmas = self._t_to_sigma(timesteps.to(dtype=torch.float32))
|
320 |
+
|
321 |
+
while len(sigmas.shape) < len(original_samples.shape):
|
322 |
+
sigmas = sigmas.unsqueeze(-1)
|
323 |
+
|
324 |
+
return (1.0 - sigmas) * original_samples + sigmas * noise
|
325 |
+
|
326 |
+
def __len__(self):
|
327 |
+
return self.config.num_train_timesteps
|
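# Minimal sketch of the scheduler API above: build an inference schedule, then take a
# single Euler step with a dummy velocity prediction standing in for the transformer.
import torch

_sched = RectifiedFlowScheduler(num_train_timesteps=1000, shift=1.0)
_sched.set_timesteps(num_inference_steps=25, device="cpu")

_sample = torch.randn(1, 2048, 64)          # latent shape is illustrative
_t = _sched.timesteps[0]
_velocity = torch.zeros_like(_sample)       # stand-in for transformer(sample, t, ...)
_sample = _sched.step(_velocity, _t, _sample).prev_sample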
midi/utils/smoothing.py
ADDED
@@ -0,0 +1,615 @@
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
"""
|
4 |
+
Utilities for smoothing occupancy/SDF grids.
|
5 |
+
"""
|
6 |
+
|
7 |
+
import logging
|
8 |
+
from typing import Tuple
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from scipy import ndimage as ndi
|
14 |
+
from scipy import sparse
|
15 |
+
|
16 |
+
__all__ = [
|
17 |
+
"smooth",
|
18 |
+
"smooth_constrained",
|
19 |
+
"smooth_gaussian",
|
20 |
+
"signed_distance_function",
|
21 |
+
"smooth_gpu",
|
22 |
+
"smooth_constrained_gpu",
|
23 |
+
"smooth_gaussian_gpu",
|
24 |
+
"signed_distance_function_gpu",
|
25 |
+
]
|
26 |
+
|
27 |
+
|
28 |
+
def _build_variable_indices(band: np.ndarray) -> np.ndarray:
|
29 |
+
num_variables = np.count_nonzero(band)
|
30 |
+
variable_indices = np.full(band.shape, -1, dtype=np.int_)
|
31 |
+
variable_indices[band] = np.arange(num_variables)
|
32 |
+
return variable_indices
|
33 |
+
|
34 |
+
|
35 |
+
def _buildq3d(variable_indices: np.ndarray):
|
36 |
+
"""
|
37 |
+
Builds the filterq matrix for the given variables.
|
38 |
+
"""
|
39 |
+
|
40 |
+
num_variables = variable_indices.max() + 1
|
41 |
+
filterq = sparse.lil_matrix((3 * num_variables, num_variables))
|
42 |
+
|
43 |
+
# Pad variable_indices to simplify out-of-bounds accesses
|
44 |
+
variable_indices = np.pad(
|
45 |
+
variable_indices, [(0, 1), (0, 1), (0, 1)], mode="constant", constant_values=-1
|
46 |
+
)
|
47 |
+
|
48 |
+
coords = np.nonzero(variable_indices >= 0)
|
49 |
+
for count, (i, j, k) in enumerate(zip(*coords)):
|
50 |
+
|
51 |
+
assert variable_indices[i, j, k] == count
|
52 |
+
|
53 |
+
filterq[3 * count, count] = -2
|
54 |
+
neighbor = variable_indices[i - 1, j, k]
|
55 |
+
if neighbor >= 0:
|
56 |
+
filterq[3 * count, neighbor] = 1
|
57 |
+
else:
|
58 |
+
filterq[3 * count, count] += 1
|
59 |
+
|
60 |
+
neighbor = variable_indices[i + 1, j, k]
|
61 |
+
if neighbor >= 0:
|
62 |
+
filterq[3 * count, neighbor] = 1
|
63 |
+
else:
|
64 |
+
filterq[3 * count, count] += 1
|
65 |
+
|
66 |
+
filterq[3 * count + 1, count] = -2
|
67 |
+
neighbor = variable_indices[i, j - 1, k]
|
68 |
+
if neighbor >= 0:
|
69 |
+
filterq[3 * count + 1, neighbor] = 1
|
70 |
+
else:
|
71 |
+
filterq[3 * count + 1, count] += 1
|
72 |
+
|
73 |
+
neighbor = variable_indices[i, j + 1, k]
|
74 |
+
if neighbor >= 0:
|
75 |
+
filterq[3 * count + 1, neighbor] = 1
|
76 |
+
else:
|
77 |
+
filterq[3 * count + 1, count] += 1
|
78 |
+
|
79 |
+
filterq[3 * count + 2, count] = -2
|
80 |
+
neighbor = variable_indices[i, j, k - 1]
|
81 |
+
if neighbor >= 0:
|
82 |
+
filterq[3 * count + 2, neighbor] = 1
|
83 |
+
else:
|
84 |
+
filterq[3 * count + 2, count] += 1
|
85 |
+
|
86 |
+
neighbor = variable_indices[i, j, k + 1]
|
87 |
+
if neighbor >= 0:
|
88 |
+
filterq[3 * count + 2, neighbor] = 1
|
89 |
+
else:
|
90 |
+
filterq[3 * count + 2, count] += 1
|
91 |
+
|
92 |
+
filterq = filterq.tocsr()
|
93 |
+
return filterq.T.dot(filterq)
|
94 |
+
|
95 |
+
|
96 |
+
def _buildq3d_gpu(variable_indices: torch.Tensor, chunk_size=10000):
|
97 |
+
"""
|
98 |
+
Builds the filterq matrix for the given variables on GPU, using chunking to reduce memory usage.
|
99 |
+
"""
|
100 |
+
device = variable_indices.device
|
101 |
+
num_variables = variable_indices.max().item() + 1
|
102 |
+
|
103 |
+
# Pad variable_indices to simplify out-of-bounds accesses
|
104 |
+
variable_indices = torch.nn.functional.pad(
|
105 |
+
variable_indices, (0, 1, 0, 1, 0, 1), mode="constant", value=-1
|
106 |
+
)
|
107 |
+
|
108 |
+
coords = torch.nonzero(variable_indices >= 0)
|
109 |
+
i, j, k = coords[:, 0], coords[:, 1], coords[:, 2]
|
110 |
+
|
111 |
+
# Function to process a chunk of data
|
112 |
+
def process_chunk(start, end):
|
113 |
+
row_indices = []
|
114 |
+
col_indices = []
|
115 |
+
values = []
|
116 |
+
|
117 |
+
for axis in range(3):
|
118 |
+
row_indices.append(3 * torch.arange(start, end, device=device) + axis)
|
119 |
+
col_indices.append(
|
120 |
+
variable_indices[i[start:end], j[start:end], k[start:end]]
|
121 |
+
)
|
122 |
+
values.append(torch.full((end - start,), -2, device=device))
|
123 |
+
|
124 |
+
for offset in [-1, 1]:
|
125 |
+
if axis == 0:
|
126 |
+
neighbor = variable_indices[
|
127 |
+
i[start:end] + offset, j[start:end], k[start:end]
|
128 |
+
]
|
129 |
+
elif axis == 1:
|
130 |
+
neighbor = variable_indices[
|
131 |
+
i[start:end], j[start:end] + offset, k[start:end]
|
132 |
+
]
|
133 |
+
else:
|
134 |
+
neighbor = variable_indices[
|
135 |
+
i[start:end], j[start:end], k[start:end] + offset
|
136 |
+
]
|
137 |
+
|
138 |
+
mask = neighbor >= 0
|
139 |
+
row_indices.append(
|
140 |
+
3 * torch.arange(start, end, device=device)[mask] + axis
|
141 |
+
)
|
142 |
+
col_indices.append(neighbor[mask])
|
143 |
+
values.append(torch.ones(mask.sum(), device=device))
|
144 |
+
|
145 |
+
# Add 1 to the diagonal for out-of-bounds neighbors
|
146 |
+
row_indices.append(
|
147 |
+
3 * torch.arange(start, end, device=device)[~mask] + axis
|
148 |
+
)
|
149 |
+
col_indices.append(
|
150 |
+
variable_indices[i[start:end], j[start:end], k[start:end]][~mask]
|
151 |
+
)
|
152 |
+
values.append(torch.ones((~mask).sum(), device=device))
|
153 |
+
|
154 |
+
return torch.cat(row_indices), torch.cat(col_indices), torch.cat(values)
|
155 |
+
|
156 |
+
# Process data in chunks
|
157 |
+
all_row_indices = []
|
158 |
+
all_col_indices = []
|
159 |
+
all_values = []
|
160 |
+
|
161 |
+
for start in range(0, coords.shape[0], chunk_size):
|
162 |
+
end = min(start + chunk_size, coords.shape[0])
|
163 |
+
row_indices, col_indices, values = process_chunk(start, end)
|
164 |
+
all_row_indices.append(row_indices)
|
165 |
+
all_col_indices.append(col_indices)
|
166 |
+
all_values.append(values)
|
167 |
+
|
168 |
+
# Concatenate all chunks
|
169 |
+
row_indices = torch.cat(all_row_indices)
|
170 |
+
col_indices = torch.cat(all_col_indices)
|
171 |
+
values = torch.cat(all_values)
|
172 |
+
|
173 |
+
# Create sparse tensor
|
174 |
+
indices = torch.stack([row_indices, col_indices])
|
175 |
+
filterq = torch.sparse_coo_tensor(
|
176 |
+
indices, values, (3 * num_variables, num_variables)
|
177 |
+
)
|
178 |
+
|
179 |
+
# Compute filterq.T @ filterq
|
180 |
+
return torch.sparse.mm(filterq.t(), filterq)
|
181 |
+
|
182 |
+
|
183 |
+
# Usage example:
|
184 |
+
# variable_indices = torch.tensor(...).cuda() # Your input tensor on GPU
|
185 |
+
# result = _buildq3d_gpu(variable_indices)
|
186 |
+
|
187 |
+
|
188 |
+
def _buildq2d(variable_indices: np.ndarray):
|
189 |
+
"""
|
190 |
+
Builds the filterq matrix for the given variables.
|
191 |
+
|
192 |
+
Version for 2 dimensions.
|
193 |
+
"""
|
194 |
+
|
195 |
+
num_variables = variable_indices.max() + 1
|
196 |
+
filterq = sparse.lil_matrix((3 * num_variables, num_variables))
|
197 |
+
|
198 |
+
# Pad variable_indices to simplify out-of-bounds accesses
|
199 |
+
variable_indices = np.pad(
|
200 |
+
variable_indices, [(0, 1), (0, 1)], mode="constant", constant_values=-1
|
201 |
+
)
|
202 |
+
|
203 |
+
coords = np.nonzero(variable_indices >= 0)
|
204 |
+
for count, (i, j) in enumerate(zip(*coords)):
|
205 |
+
assert variable_indices[i, j] == count
|
206 |
+
|
207 |
+
filterq[2 * count, count] = -2
|
208 |
+
neighbor = variable_indices[i - 1, j]
|
209 |
+
if neighbor >= 0:
|
210 |
+
filterq[2 * count, neighbor] = 1
|
211 |
+
else:
|
212 |
+
filterq[2 * count, count] += 1
|
213 |
+
|
214 |
+
neighbor = variable_indices[i + 1, j]
|
215 |
+
if neighbor >= 0:
|
216 |
+
filterq[2 * count, neighbor] = 1
|
217 |
+
else:
|
218 |
+
filterq[2 * count, count] += 1
|
219 |
+
|
220 |
+
filterq[2 * count + 1, count] = -2
|
221 |
+
neighbor = variable_indices[i, j - 1]
|
222 |
+
if neighbor >= 0:
|
223 |
+
filterq[2 * count + 1, neighbor] = 1
|
224 |
+
else:
|
225 |
+
filterq[2 * count + 1, count] += 1
|
226 |
+
|
227 |
+
neighbor = variable_indices[i, j + 1]
|
228 |
+
if neighbor >= 0:
|
229 |
+
filterq[2 * count + 1, neighbor] = 1
|
230 |
+
else:
|
231 |
+
filterq[2 * count + 1, count] += 1
|
232 |
+
|
233 |
+
filterq = filterq.tocsr()
|
234 |
+
return filterq.T.dot(filterq)
|
235 |
+
|
236 |
+
|
237 |
+
def _jacobi(
|
238 |
+
filterq,
|
239 |
+
x0: np.ndarray,
|
240 |
+
lower_bound: np.ndarray,
|
241 |
+
upper_bound: np.ndarray,
|
242 |
+
max_iters: int = 10,
|
243 |
+
rel_tol: float = 1e-6,
|
244 |
+
weight: float = 0.5,
|
245 |
+
):
|
246 |
+
"""Jacobi method with constraints."""
|
247 |
+
|
248 |
+
jacobi_r = sparse.lil_matrix(filterq)
|
249 |
+
shp = jacobi_r.shape
|
250 |
+
jacobi_d = 1.0 / filterq.diagonal()
|
251 |
+
jacobi_r.setdiag((0,) * shp[0])
|
252 |
+
jacobi_r = jacobi_r.tocsr()
|
253 |
+
|
254 |
+
x = x0
|
255 |
+
|
256 |
+
# We check the stopping criterion each 10 iterations
|
257 |
+
check_each = 10
|
258 |
+
cum_rel_tol = 1 - (1 - rel_tol) ** check_each
|
259 |
+
|
260 |
+
energy_now = np.dot(x, filterq.dot(x)) / 2
|
261 |
+
logging.info("Energy at iter %d: %.6g", 0, energy_now)
|
262 |
+
for i in range(max_iters):
|
263 |
+
|
264 |
+
x_1 = -jacobi_d * jacobi_r.dot(x)
|
265 |
+
x = weight * x_1 + (1 - weight) * x
|
266 |
+
|
267 |
+
# Constraints.
|
268 |
+
x = np.maximum(x, lower_bound)
|
269 |
+
x = np.minimum(x, upper_bound)
|
270 |
+
|
271 |
+
# Stopping criterion
|
272 |
+
if (i + 1) % check_each == 0:
|
273 |
+
# Update energy
|
274 |
+
energy_before = energy_now
|
275 |
+
energy_now = np.dot(x, filterq.dot(x)) / 2
|
276 |
+
|
277 |
+
logging.info("Energy at iter %d: %.6g", i + 1, energy_now)
|
278 |
+
|
279 |
+
# Check stopping criterion
|
280 |
+
cum_rel_improvement = (energy_before - energy_now) / energy_before
|
281 |
+
if cum_rel_improvement < cum_rel_tol:
|
282 |
+
break
|
283 |
+
|
284 |
+
return x
|
285 |
+
|
286 |
+
|
287 |
+
def signed_distance_function(
|
288 |
+
levelset: np.ndarray, band_radius: int
|
289 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
290 |
+
"""
|
291 |
+
Return the distance to the 0.5 levelset of a function, the mask of the
|
292 |
+
border (i.e., the nearest cells to the 0.5 level-set) and the mask of the
|
293 |
+
band (i.e., the cells of the function whose distance to the 0.5 level-set
|
294 |
+
is less than or equal to `band_radius`).
|
295 |
+
"""
|
296 |
+
|
297 |
+
binary_array = np.where(levelset > 0, True, False)
|
298 |
+
|
299 |
+
# Compute the band and the border.
|
300 |
+
dist_func = ndi.distance_transform_edt
|
301 |
+
distance = np.where(
|
302 |
+
binary_array, dist_func(binary_array) - 0.5, -dist_func(~binary_array) + 0.5
|
303 |
+
)
|
304 |
+
border = np.abs(distance) < 1
|
305 |
+
band = np.abs(distance) <= band_radius
|
306 |
+
|
307 |
+
return distance, border, band
|
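# Tiny worked example for the function above: a 5**3 volume with a single occupied
# voxel. The occupied cell gets distance +0.5, its 6 face neighbours -0.5, and with
# band_radius=1 the band covers the centre plus its 6 face and 12 edge neighbours.
_vol = np.zeros((5, 5, 5), dtype=bool)
_vol[2, 2, 2] = True
_dist, _border, _band = signed_distance_function(_vol, band_radius=1)
assert _dist[2, 2, 2] == 0.5 and _dist[2, 2, 1] == -0.5
assert int(_band.sum()) == 19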
308 |
+
|
309 |
+
|
310 |
+
def signed_distance_function_iso0(
|
311 |
+
levelset: np.ndarray, band_radius: int
|
312 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
313 |
+
"""
|
314 |
+
Return the distance to the 0 levelset of a function, the mask of the
|
315 |
+
border (i.e., the nearest cells to the 0 level-set) and the mask of the
|
316 |
+
band (i.e., the cells of the function whose distance to the 0 level-set
|
317 |
+
is less than or equal to `band_radius`.
|
318 |
+
"""
|
319 |
+
|
320 |
+
binary_array = levelset > 0
|
321 |
+
|
322 |
+
# Compute the band and the border.
|
323 |
+
dist_func = ndi.distance_transform_edt
|
324 |
+
distance = np.where(
|
325 |
+
binary_array, dist_func(binary_array), -dist_func(~binary_array)
|
326 |
+
)
|
327 |
+
border = np.zeros_like(levelset, dtype=bool)
|
328 |
+
border[:-1, :, :] |= levelset[:-1, :, :] * levelset[1:, :, :] <= 0
|
329 |
+
border[:, :-1, :] |= levelset[:, :-1, :] * levelset[:, 1:, :] <= 0
|
330 |
+
border[:, :, :-1] |= levelset[:, :, :-1] * levelset[:, :, 1:] <= 0
|
331 |
+
band = np.abs(distance) <= band_radius
|
332 |
+
|
333 |
+
return distance, border, band
|
334 |
+
|
335 |
+
|
336 |
+
def signed_distance_function_gpu(levelset: torch.Tensor, band_radius: int):
|
337 |
+
binary_array = (levelset > 0).float()
|
338 |
+
|
339 |
+
# Compute distance transform
|
340 |
+
dist_pos = (
|
341 |
+
F.max_pool3d(
|
342 |
+
-binary_array.unsqueeze(0).unsqueeze(0), kernel_size=3, stride=1, padding=1
|
343 |
+
)
|
344 |
+
.squeeze(0)
|
345 |
+
.squeeze(0)
|
346 |
+
+ binary_array
|
347 |
+
)
|
348 |
+
dist_neg = F.max_pool3d(
|
349 |
+
(binary_array - 1).unsqueeze(0).unsqueeze(0), kernel_size=3, stride=1, padding=1
|
350 |
+
).squeeze(0).squeeze(0) + (1 - binary_array)
|
351 |
+
|
352 |
+
distance = torch.where(binary_array > 0, dist_pos - 0.5, -dist_neg + 0.5)
|
353 |
+
|
354 |
+
# breakpoint()
|
355 |
+
|
356 |
+
# Use levelset as distance directly
|
357 |
+
# distance = levelset
|
358 |
+
# print(distance.shape)
|
359 |
+
# Compute border and band
|
360 |
+
border = torch.abs(distance) < 1
|
361 |
+
band = torch.abs(distance) <= band_radius
|
362 |
+
|
363 |
+
return distance, border, band
|
364 |
+
|
365 |
+
|
366 |
+
def smooth_constrained(
|
367 |
+
binary_array: np.ndarray,
|
368 |
+
band_radius: int = 4,
|
369 |
+
max_iters: int = 250,
|
370 |
+
rel_tol: float = 1e-6,
|
371 |
+
) -> np.ndarray:
|
372 |
+
"""
|
373 |
+
Implementation of the smoothing method from
|
374 |
+
|
375 |
+
"Surface Extraction from Binary Volumes with Higher-Order Smoothness"
|
376 |
+
Victor Lempitsky, CVPR10
|
377 |
+
"""
|
378 |
+
|
379 |
+
# # Compute the distance map, the border and the band.
|
380 |
+
logging.info("Computing distance transform...")
|
381 |
+
# distance, _, band = signed_distance_function(binary_array, band_radius)
|
382 |
+
binary_array_gpu = torch.from_numpy(binary_array).cuda()
|
383 |
+
distance, _, band = signed_distance_function_gpu(binary_array_gpu, band_radius)
|
384 |
+
distance = distance.cpu().numpy()
|
385 |
+
band = band.cpu().numpy()
|
386 |
+
|
387 |
+
variable_indices = _build_variable_indices(band)
|
388 |
+
|
389 |
+
# Compute filterq.
|
390 |
+
logging.info("Building matrix filterq...")
|
391 |
+
if binary_array.ndim == 3:
|
392 |
+
filterq = _buildq3d(variable_indices)
|
393 |
+
# variable_indices_gpu = torch.from_numpy(variable_indices).cuda()
|
394 |
+
# filterq_gpu = _buildq3d_gpu(variable_indices_gpu)
|
395 |
+
# filterq = filterq_gpu.cpu().numpy()
|
396 |
+
elif binary_array.ndim == 2:
|
397 |
+
filterq = _buildq2d(variable_indices)
|
398 |
+
else:
|
399 |
+
raise ValueError("binary_array.ndim not in [2, 3]")
|
400 |
+
|
401 |
+
# Initialize the variables.
|
402 |
+
res = np.asarray(distance, dtype=np.double)
|
403 |
+
x = res[band]
|
404 |
+
upper_bound = np.where(x < 0, x, np.inf)
|
405 |
+
lower_bound = np.where(x > 0, x, -np.inf)
|
406 |
+
|
407 |
+
upper_bound[np.abs(upper_bound) < 1] = 0
|
408 |
+
lower_bound[np.abs(lower_bound) < 1] = 0
|
409 |
+
|
410 |
+
# Solve.
|
411 |
+
logging.info("Minimizing energy...")
|
412 |
+
x = _jacobi(
|
413 |
+
filterq=filterq,
|
414 |
+
x0=x,
|
415 |
+
lower_bound=lower_bound,
|
416 |
+
upper_bound=upper_bound,
|
417 |
+
max_iters=max_iters,
|
418 |
+
rel_tol=rel_tol,
|
419 |
+
)
|
420 |
+
|
421 |
+
res[band] = x
|
422 |
+
return res
|
423 |
+
|
424 |
+
|
425 |
+
def total_variation_denoising(x, weight=0.1, num_iterations=5, eps=1e-8):
|
426 |
+
diff_x = torch.diff(x, dim=0, prepend=x[:1])
|
427 |
+
diff_y = torch.diff(x, dim=1, prepend=x[:, :1])
|
428 |
+
diff_z = torch.diff(x, dim=2, prepend=x[:, :, :1])
|
429 |
+
|
430 |
+
norm = torch.sqrt(diff_x**2 + diff_y**2 + diff_z**2 + eps)
|
431 |
+
|
432 |
+
div_x = torch.diff(diff_x / norm, dim=0, append=diff_x[-1:] / norm[-1:])
|
433 |
+
div_y = torch.diff(diff_y / norm, dim=1, append=diff_y[:, -1:] / norm[:, -1:])
|
434 |
+
div_z = torch.diff(diff_z / norm, dim=2, append=diff_z[:, :, -1:] / norm[:, :, -1:])
|
435 |
+
|
436 |
+
return x - weight * (div_x + div_y + div_z)
|
437 |
+
|
438 |
+
|
439 |
+
def smooth_constrained_gpu(
|
440 |
+
binary_array: torch.Tensor,
|
441 |
+
band_radius: int = 4,
|
442 |
+
max_iters: int = 250,
|
443 |
+
rel_tol: float = 1e-4,
|
444 |
+
):
|
445 |
+
distance, _, band = signed_distance_function_gpu(binary_array, band_radius)
|
446 |
+
|
447 |
+
# Initialize variables
|
448 |
+
x = distance[band]
|
449 |
+
upper_bound = torch.where(x < 0, x, torch.tensor(float("inf"), device=x.device))
|
450 |
+
lower_bound = torch.where(x > 0, x, torch.tensor(float("-inf"), device=x.device))
|
451 |
+
|
452 |
+
upper_bound[torch.abs(upper_bound) < 1] = 0
|
453 |
+
lower_bound[torch.abs(lower_bound) < 1] = 0
|
454 |
+
|
455 |
+
# Define the 3D Laplacian kernel
|
456 |
+
laplacian_kernel = torch.tensor(
|
457 |
+
[
|
458 |
+
[
|
459 |
+
[
|
460 |
+
[[0, 1, 0], [1, -6, 1], [0, 1, 0]],
|
461 |
+
[[1, 0, 1], [0, 0, 0], [1, 0, 1]],
|
462 |
+
[[0, 1, 0], [1, 0, 1], [0, 1, 0]],
|
463 |
+
]
|
464 |
+
]
|
465 |
+
],
|
466 |
+
device=x.device,
|
467 |
+
).float()
|
468 |
+
|
469 |
+
laplacian_kernel = laplacian_kernel / laplacian_kernel.abs().sum()
|
470 |
+
|
471 |
+
breakpoint()
|
472 |
+
|
473 |
+
# Simplified Jacobi iteration
|
474 |
+
for i in range(max_iters):
|
475 |
+
# Reshape x to 5D tensor (batch, channel, depth, height, width)
|
476 |
+
x_5d = x.view(1, 1, *band.shape)
|
477 |
+
x_3d = x.view(*band.shape)
|
478 |
+
|
479 |
+
# Apply 3D convolution
|
480 |
+
laplacian = F.conv3d(x_5d, laplacian_kernel, padding=1)
|
481 |
+
|
482 |
+
# Reshape back to original dimensions
|
483 |
+
laplacian = laplacian.view(x.shape)
|
484 |
+
|
485 |
+
# Use a small relaxation factor to improve stability
|
486 |
+
relaxation_factor = 0.1
|
487 |
+
tv_weight = 0.1
|
488 |
+
# x_new = x + relaxation_factor * laplacian
|
489 |
+
x_new = total_variation_denoising(x_3d, weight=tv_weight)
|
490 |
+
# Print laplacian min and max
|
491 |
+
# print(f"Laplacian min: {laplacian.min().item():.4f}, max: {laplacian.max().item():.4f}")
|
492 |
+
|
493 |
+
# Apply constraints
|
494 |
+
# Reshape x_new to match the dimensions of lower_bound and upper_bound
|
495 |
+
x_new = x_new.view(x.shape)
|
496 |
+
x_new = torch.clamp(x_new, min=lower_bound, max=upper_bound)
|
497 |
+
|
498 |
+
# Check for convergence
|
499 |
+
diff_norm = torch.norm(x_new - x)
|
500 |
+
logging.debug("Constrained smoothing: update norm %.6g", diff_norm.item())
|
501 |
+
x_norm = torch.norm(x)
|
502 |
+
|
503 |
+
if x_norm > 1e-8: # Avoid division by very small numbers
|
504 |
+
relative_change = diff_norm / x_norm
|
505 |
+
if relative_change < rel_tol:
|
506 |
+
break
|
507 |
+
elif diff_norm < rel_tol: # If x_norm is very small, check absolute change
|
508 |
+
break
|
509 |
+
|
510 |
+
x = x_new
|
511 |
+
|
512 |
+
# Check for NaN and break if found, also check for inf
|
513 |
+
if torch.isnan(x).any() or torch.isinf(x).any():
|
514 |
+
print(f"NaN or Inf detected at iteration {i}")
|
515 |
+
|
516 |
+
break
|
517 |
+
|
518 |
+
result = distance.clone()
|
519 |
+
result[band] = x
|
520 |
+
return result
|
521 |
+
|
522 |
+
|
523 |
+
def smooth_gaussian(binary_array: np.ndarray, sigma: float = 3) -> np.ndarray:
|
524 |
+
vol = np.asarray(binary_array, dtype=np.float64) - 0.5
|
525 |
+
return ndi.gaussian_filter(vol, sigma=sigma)
|
526 |
+
|
527 |
+
|
528 |
+
def smooth_gaussian_gpu(binary_array: torch.Tensor, sigma: float = 3):
|
529 |
+
# vol = binary_array.float()
|
530 |
+
vol = binary_array
|
531 |
+
kernel_size = int(2 * sigma + 1)
|
532 |
+
kernel = torch.ones(
|
533 |
+
1,
|
534 |
+
1,
|
535 |
+
kernel_size,
|
536 |
+
kernel_size,
|
537 |
+
kernel_size,
|
538 |
+
device=binary_array.device,
|
539 |
+
dtype=vol.dtype,
|
540 |
+
) / (kernel_size**3)
|
541 |
+
return F.conv3d(
|
542 |
+
vol.unsqueeze(0).unsqueeze(0), kernel, padding=kernel_size // 2
|
543 |
+
).squeeze()
|
544 |
+
|
545 |
+
|
546 |
+
def smooth(binary_array: np.ndarray, method: str = "auto", **kwargs) -> np.ndarray:
|
547 |
+
"""
|
548 |
+
Smooths the 0.5 level-set of a binary array. Returns a floating-point
|
549 |
+
array with a smoothed version of the original level-set at the 0 isovalue.
|
550 |
+
|
551 |
+
This function can apply two different methods:
|
552 |
+
|
553 |
+
- A constrained smoothing method which preserves details and fine
|
554 |
+
structures, but it is slow and requires a large amount of memory. This
|
555 |
+
method is recommended when the input array is small (smaller than
|
556 |
+
(500, 500, 500)).
|
557 |
+
- A Gaussian filter applied over the binary array. This method is fast, but
|
558 |
+
not very precise, as it can destroy fine details. It is only recommended
|
559 |
+
when the input array is large and the 0.5 level-set does not contain
|
560 |
+
thin structures.
|
561 |
+
|
562 |
+
Parameters
|
563 |
+
----------
|
564 |
+
binary_array : ndarray
|
565 |
+
Input binary array with the 0.5 level-set to smooth.
|
566 |
+
method : str, one of ['auto', 'gaussian', 'constrained']
|
567 |
+
Smoothing method. If 'auto' is given, the method will be automatically
|
568 |
+
chosen based on the size of `binary_array`.
|
569 |
+
|
570 |
+
Parameters for 'gaussian'
|
571 |
+
-------------------------
|
572 |
+
sigma : float
|
573 |
+
Size of the Gaussian filter (default 3).
|
574 |
+
|
575 |
+
Parameters for 'constrained'
|
576 |
+
----------------------------
|
577 |
+
max_iters : positive integer
|
578 |
+
Number of iterations of the constrained optimization method
|
579 |
+
(default 250).
|
580 |
+
rel_tol: float
|
581 |
+
Relative tolerance as a stopping criterion (default 1e-6).
|
582 |
+
|
583 |
+
Output
|
584 |
+
------
|
585 |
+
res : ndarray
|
586 |
+
Floating-point array with a smoothed 0 level-set.
|
587 |
+
"""
|
588 |
+
|
589 |
+
binary_array = np.asarray(binary_array)
|
590 |
+
|
591 |
+
if method == "auto":
|
592 |
+
if binary_array.size > 512**3:
|
593 |
+
method = "gaussian"
|
594 |
+
else:
|
595 |
+
method = "constrained"
|
596 |
+
|
597 |
+
if method == "gaussian":
|
598 |
+
return smooth_gaussian(binary_array, **kwargs)
|
599 |
+
|
600 |
+
if method == "constrained":
|
601 |
+
return smooth_constrained(binary_array, **kwargs)
|
602 |
+
|
603 |
+
raise ValueError("Unknown method '{}'".format(method))
|
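# Usage sketch for `smooth`: smooth a small binary sphere and pull the 0 level-set out
# with marching cubes. scikit-image is a dependency of this example only, not of this
# module.
from skimage import measure

_zz, _yy, _xx = np.mgrid[:64, :64, :64]
_sphere = (_xx - 32) ** 2 + (_yy - 32) ** 2 + (_zz - 32) ** 2 < 20 ** 2
_smoothed = smooth(_sphere, method="gaussian", sigma=2)
_verts, _faces, _, _ = measure.marching_cubes(_smoothed, level=0)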
604 |
+
|
605 |
+
|
606 |
+
def smooth_gpu(binary_array: torch.Tensor, method: str = "auto", **kwargs):
|
607 |
+
if method == "auto":
|
608 |
+
method = "gaussian" if binary_array.numel() > 512**3 else "constrained"
|
609 |
+
|
610 |
+
if method == "gaussian":
|
611 |
+
return smooth_gaussian_gpu(binary_array, **kwargs)
|
612 |
+
elif method == "constrained":
|
613 |
+
return smooth_constrained_gpu(binary_array, **kwargs)
|
614 |
+
else:
|
615 |
+
raise ValueError(f"Unknown method '{method}'")
|
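# Matching sketch for the GPU path (requires a CUDA device). Note that
# `smooth_gaussian_gpu` box-filters the raw 0/1 volume without recentring it around
# zero, so its iso-surface sits at the 0.5 level rather than 0.
if torch.cuda.is_available():
    _vol_gpu = (torch.rand(64, 64, 64, device="cuda") > 0.5).float()
    _smoothed_gpu = smooth_gpu(_vol_gpu, method="gaussian", sigma=2)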