Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
a207590
0
Parent(s):
init
Browse files- .gitattributes +37 -0
- README.md +14 -0
- app.py +293 -0
- inference_ig2mv_sdxl.py +286 -0
- mvadapter/__init__.py +0 -0
- mvadapter/loaders/__init__.py +1 -0
- mvadapter/loaders/custom_adapter.py +98 -0
- mvadapter/models/__init__.py +0 -0
- mvadapter/models/attention_processor.py +743 -0
- mvadapter/pipelines/pipeline_mvadapter_i2mv_sd.py +777 -0
- mvadapter/pipelines/pipeline_mvadapter_i2mv_sdxl.py +962 -0
- mvadapter/pipelines/pipeline_mvadapter_t2mv_sd.py +634 -0
- mvadapter/pipelines/pipeline_mvadapter_t2mv_sdxl.py +801 -0
- mvadapter/schedulers/scheduler_utils.py +70 -0
- mvadapter/schedulers/scheduling_shift_snr.py +138 -0
- mvadapter/utils/__init__.py +3 -0
- mvadapter/utils/camera.py +211 -0
- mvadapter/utils/geometry.py +253 -0
- mvadapter/utils/logging.py +340 -0
- mvadapter/utils/render.py +499 -0
- mvadapter/utils/saving.py +88 -0
- requirements.txt +23 -0
.gitattributes
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.glb filter=lfs diff=lfs merge=lfs -text
|
37 |
+
*.so filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: MV Adapter Img2Texture
|
3 |
+
emoji: 🔮
|
4 |
+
colorFrom: purple
|
5 |
+
colorTo: yellow
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 5.23.1
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: mit
|
11 |
+
short_description: Generate 3D texture from image
|
12 |
+
---
|
13 |
+
|
14 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import shutil
|
4 |
+
import subprocess
|
5 |
+
from typing import List
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
import numpy as np
|
9 |
+
import spaces
|
10 |
+
import torch
|
11 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
12 |
+
from PIL import Image
|
13 |
+
from torchvision import transforms
|
14 |
+
from transformers import AutoModelForImageSegmentation
|
15 |
+
|
16 |
+
from inference_ig2mv_sdxl import (
|
17 |
+
prepare_pipeline,
|
18 |
+
preprocess_image,
|
19 |
+
remove_bg,
|
20 |
+
run_pipeline,
|
21 |
+
)
|
22 |
+
from mvadapter.utils import get_orthogonal_camera, make_image_grid, tensor_to_image
|
23 |
+
|
24 |
+
# install others
|
25 |
+
subprocess.run("pip install spandrel==0.4.1 --no-deps", shell=True, check=True)
|
26 |
+
|
27 |
+
|
28 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
29 |
+
DTYPE = torch.float16
|
30 |
+
MAX_SEED = np.iinfo(np.int32).max
|
31 |
+
NUM_VIEWS = 6
|
32 |
+
HEIGHT = 768
|
33 |
+
WIDTH = 768
|
34 |
+
|
35 |
+
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
|
36 |
+
os.makedirs(TMP_DIR, exist_ok=True)
|
37 |
+
|
38 |
+
|
39 |
+
HEADER = """
|
40 |
+
# 🔮 Image to Texture with [MV-Adapter](https://github.com/huanngzh/MV-Adapter)
|
41 |
+
## State-of-the-art Open Source Texture Generation Using Multi-View Diffusion Model
|
42 |
+
"""
|
43 |
+
|
44 |
+
EXAMPLES = [
|
45 |
+
["examples/001.jpeg", "examples/001.glb"],
|
46 |
+
["examples/002.jpeg", "examples/002.glb"],
|
47 |
+
]
|
48 |
+
|
49 |
+
# MV-Adapter
|
50 |
+
pipe = prepare_pipeline(
|
51 |
+
base_model="stabilityai/stable-diffusion-xl-base-1.0",
|
52 |
+
vae_model="madebyollin/sdxl-vae-fp16-fix",
|
53 |
+
unet_model=None,
|
54 |
+
lora_model=None,
|
55 |
+
adapter_path="huanngzh/mv-adapter",
|
56 |
+
scheduler=None,
|
57 |
+
num_views=NUM_VIEWS,
|
58 |
+
device=DEVICE,
|
59 |
+
dtype=DTYPE,
|
60 |
+
)
|
61 |
+
birefnet = AutoModelForImageSegmentation.from_pretrained(
|
62 |
+
"ZhengPeng7/BiRefNet", trust_remote_code=True
|
63 |
+
)
|
64 |
+
birefnet.to(DEVICE)
|
65 |
+
transform_image = transforms.Compose(
|
66 |
+
[
|
67 |
+
transforms.Resize((1024, 1024)),
|
68 |
+
transforms.ToTensor(),
|
69 |
+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
70 |
+
]
|
71 |
+
)
|
72 |
+
remove_bg_fn = lambda x: remove_bg(x, birefnet, transform_image, DEVICE)
|
73 |
+
|
74 |
+
if not os.path.exists("checkpoints/RealESRGAN_x2plus.pth"):
|
75 |
+
hf_hub_download(
|
76 |
+
"dtarnow/UPscaler", filename="RealESRGAN_x2plus.pth", local_dir="checkpoints"
|
77 |
+
)
|
78 |
+
if not os.path.exists("checkpoints/big-lama.pt"):
|
79 |
+
subprocess.run(
|
80 |
+
"wget -P checkpoints/ https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
81 |
+
shell=True,
|
82 |
+
check=True,
|
83 |
+
)
|
84 |
+
|
85 |
+
|
86 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
87 |
+
|
88 |
+
|
89 |
+
def start_session(req: gr.Request):
|
90 |
+
save_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
91 |
+
os.makedirs(save_dir, exist_ok=True)
|
92 |
+
print("start session, mkdir", save_dir)
|
93 |
+
|
94 |
+
|
95 |
+
def end_session(req: gr.Request):
|
96 |
+
save_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
97 |
+
shutil.rmtree(save_dir)
|
98 |
+
|
99 |
+
|
100 |
+
def get_random_hex():
|
101 |
+
random_bytes = os.urandom(8)
|
102 |
+
random_hex = random_bytes.hex()
|
103 |
+
return random_hex
|
104 |
+
|
105 |
+
|
106 |
+
def get_random_seed(randomize_seed, seed):
|
107 |
+
if randomize_seed:
|
108 |
+
seed = random.randint(0, MAX_SEED)
|
109 |
+
return seed
|
110 |
+
|
111 |
+
|
112 |
+
@spaces.GPU(duration=90)
|
113 |
+
@torch.no_grad()
|
114 |
+
def run_mvadapter(
|
115 |
+
mesh_path,
|
116 |
+
prompt,
|
117 |
+
image,
|
118 |
+
seed=42,
|
119 |
+
guidance_scale=3.0,
|
120 |
+
num_inference_steps=30,
|
121 |
+
reference_conditioning_scale=1.0,
|
122 |
+
negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
|
123 |
+
progress=gr.Progress(track_tqdm=True),
|
124 |
+
):
|
125 |
+
# pre-process the reference image
|
126 |
+
image = Image.open(image).convert("RGB") if isinstance(image, str) else image
|
127 |
+
image = remove_bg_fn(image)
|
128 |
+
image = preprocess_image(image, HEIGHT, WIDTH)
|
129 |
+
|
130 |
+
if isinstance(seed, str):
|
131 |
+
try:
|
132 |
+
seed = int(seed.strip())
|
133 |
+
except ValueError:
|
134 |
+
seed = 42
|
135 |
+
|
136 |
+
images, _, _, _ = run_pipeline(
|
137 |
+
pipe,
|
138 |
+
mesh_path=mesh_path,
|
139 |
+
num_views=NUM_VIEWS,
|
140 |
+
text=prompt,
|
141 |
+
image=image,
|
142 |
+
height=HEIGHT,
|
143 |
+
width=WIDTH,
|
144 |
+
num_inference_steps=num_inference_steps,
|
145 |
+
guidance_scale=guidance_scale,
|
146 |
+
seed=seed,
|
147 |
+
remove_bg_fn=None,
|
148 |
+
reference_conditioning_scale=reference_conditioning_scale,
|
149 |
+
negative_prompt=negative_prompt,
|
150 |
+
device=DEVICE,
|
151 |
+
)
|
152 |
+
|
153 |
+
torch.cuda.empty_cache()
|
154 |
+
|
155 |
+
return images, image
|
156 |
+
|
157 |
+
|
158 |
+
@spaces.GPU(duration=90)
|
159 |
+
@torch.no_grad()
|
160 |
+
def run_texturing(
|
161 |
+
mesh_path: str,
|
162 |
+
mv_images: List[Image.Image],
|
163 |
+
uv_unwarp: bool,
|
164 |
+
preprocess_mesh: bool,
|
165 |
+
uv_size: int,
|
166 |
+
req: gr.Request,
|
167 |
+
):
|
168 |
+
save_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
169 |
+
mv_image_path = os.path.join(save_dir, f"mv_adapter_{get_random_hex()}.png")
|
170 |
+
mv_images = [item[0] for item in mv_images]
|
171 |
+
make_image_grid(mv_images, rows=1).save(mv_image_path)
|
172 |
+
|
173 |
+
from texture import ModProcessConfig, TexturePipeline
|
174 |
+
|
175 |
+
texture_pipe = TexturePipeline(
|
176 |
+
upscaler_ckpt_path="checkpoints/RealESRGAN_x2plus.pth",
|
177 |
+
inpaint_ckpt_path="checkpoints/big-lama.pt",
|
178 |
+
device=DEVICE,
|
179 |
+
)
|
180 |
+
|
181 |
+
textured_glb_path = texture_pipe(
|
182 |
+
mesh_path=mesh_path,
|
183 |
+
save_dir=save_dir,
|
184 |
+
save_name=f"texture_mesh_{get_random_hex()}",
|
185 |
+
uv_unwarp=uv_unwarp,
|
186 |
+
preprocess_mesh=preprocess_mesh,
|
187 |
+
uv_size=uv_size,
|
188 |
+
rgb_path=mv_image_path,
|
189 |
+
rgb_process_config=ModProcessConfig(view_upscale=True, inpaint_mode="view"),
|
190 |
+
camera_azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]],
|
191 |
+
).shaded_model_save_path
|
192 |
+
|
193 |
+
torch.cuda.empty_cache()
|
194 |
+
|
195 |
+
return textured_glb_path, textured_glb_path
|
196 |
+
|
197 |
+
|
198 |
+
with gr.Blocks(title="MVAdapter") as demo:
|
199 |
+
gr.Markdown(HEADER)
|
200 |
+
|
201 |
+
with gr.Row():
|
202 |
+
with gr.Column():
|
203 |
+
with gr.Row():
|
204 |
+
input_mesh = gr.Model3D(label="Input 3D mesh")
|
205 |
+
image_prompt = gr.Image(label="Input Image", type="pil")
|
206 |
+
|
207 |
+
with gr.Accordion("Generation Settings", open=False):
|
208 |
+
prompt = gr.Textbox(
|
209 |
+
label="Prompt (Optional)",
|
210 |
+
placeholder="Enter your prompt",
|
211 |
+
value="high quality",
|
212 |
+
)
|
213 |
+
seed = gr.Slider(
|
214 |
+
label="Seed", minimum=0, maximum=MAX_SEED, step=0, value=0
|
215 |
+
)
|
216 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
217 |
+
num_inference_steps = gr.Slider(
|
218 |
+
label="Number of inference steps",
|
219 |
+
minimum=8,
|
220 |
+
maximum=50,
|
221 |
+
step=1,
|
222 |
+
value=25,
|
223 |
+
)
|
224 |
+
guidance_scale = gr.Slider(
|
225 |
+
label="CFG scale",
|
226 |
+
minimum=0.0,
|
227 |
+
maximum=20.0,
|
228 |
+
step=0.1,
|
229 |
+
value=3.0,
|
230 |
+
)
|
231 |
+
reference_conditioning_scale = gr.Slider(
|
232 |
+
label="Image conditioning scale",
|
233 |
+
minimum=0.0,
|
234 |
+
maximum=2.0,
|
235 |
+
step=0.1,
|
236 |
+
value=1.0,
|
237 |
+
)
|
238 |
+
|
239 |
+
with gr.Accordion("Texture Settings", open=False):
|
240 |
+
with gr.Row():
|
241 |
+
uv_unwarp = gr.Checkbox(label="Unwarp UV", value=True)
|
242 |
+
preprocess_mesh = gr.Checkbox(label="Preprocess Mesh", value=False)
|
243 |
+
uv_size = gr.Slider(
|
244 |
+
label="UV Size", minimum=1024, maximum=8192, step=512, value=4096
|
245 |
+
)
|
246 |
+
|
247 |
+
gen_button = gr.Button("Generate Texture", variant="primary")
|
248 |
+
|
249 |
+
examples = gr.Examples(
|
250 |
+
examples=EXAMPLES,
|
251 |
+
inputs=[image_prompt, input_mesh],
|
252 |
+
outputs=[image_prompt],
|
253 |
+
)
|
254 |
+
|
255 |
+
with gr.Column():
|
256 |
+
mv_result = gr.Gallery(
|
257 |
+
label="Multi-View Results",
|
258 |
+
show_label=False,
|
259 |
+
columns=[3],
|
260 |
+
rows=[2],
|
261 |
+
object_fit="contain",
|
262 |
+
height="auto",
|
263 |
+
type="pil",
|
264 |
+
)
|
265 |
+
textured_model_output = gr.Model3D(label="Textured GLB", interactive=False)
|
266 |
+
download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
|
267 |
+
|
268 |
+
gen_button.click(
|
269 |
+
get_random_seed, inputs=[randomize_seed, seed], outputs=[seed]
|
270 |
+
).then(
|
271 |
+
run_mvadapter,
|
272 |
+
inputs=[
|
273 |
+
input_mesh,
|
274 |
+
prompt,
|
275 |
+
image_prompt,
|
276 |
+
seed,
|
277 |
+
guidance_scale,
|
278 |
+
num_inference_steps,
|
279 |
+
reference_conditioning_scale,
|
280 |
+
],
|
281 |
+
outputs=[mv_result, image_prompt],
|
282 |
+
).then(
|
283 |
+
run_texturing,
|
284 |
+
inputs=[input_mesh, mv_result, uv_unwarp, preprocess_mesh, uv_size],
|
285 |
+
outputs=[textured_model_output, download_glb],
|
286 |
+
).then(
|
287 |
+
lambda: gr.Button(interactive=True), outputs=[download_glb]
|
288 |
+
)
|
289 |
+
|
290 |
+
demo.load(start_session)
|
291 |
+
demo.unload(end_session)
|
292 |
+
|
293 |
+
demo.launch()
|
inference_ig2mv_sdxl.py
ADDED
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from diffusers import AutoencoderKL, DDPMScheduler, LCMScheduler, UNet2DConditionModel
|
6 |
+
from PIL import Image
|
7 |
+
from torchvision import transforms
|
8 |
+
from tqdm import tqdm
|
9 |
+
from transformers import AutoModelForImageSegmentation
|
10 |
+
|
11 |
+
from mvadapter.models.attention_processor import DecoupledMVRowColSelfAttnProcessor2_0
|
12 |
+
from mvadapter.pipelines.pipeline_mvadapter_i2mv_sdxl import MVAdapterI2MVSDXLPipeline
|
13 |
+
from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler
|
14 |
+
from mvadapter.utils import get_orthogonal_camera, make_image_grid, tensor_to_image
|
15 |
+
from mvadapter.utils.render import NVDiffRastContextWrapper, load_mesh, render
|
16 |
+
|
17 |
+
|
18 |
+
def prepare_pipeline(
|
19 |
+
base_model,
|
20 |
+
vae_model,
|
21 |
+
unet_model,
|
22 |
+
lora_model,
|
23 |
+
adapter_path,
|
24 |
+
scheduler,
|
25 |
+
num_views,
|
26 |
+
device,
|
27 |
+
dtype,
|
28 |
+
):
|
29 |
+
# Load vae and unet if provided
|
30 |
+
pipe_kwargs = {}
|
31 |
+
if vae_model is not None:
|
32 |
+
pipe_kwargs["vae"] = AutoencoderKL.from_pretrained(vae_model)
|
33 |
+
if unet_model is not None:
|
34 |
+
pipe_kwargs["unet"] = UNet2DConditionModel.from_pretrained(unet_model)
|
35 |
+
|
36 |
+
# Prepare pipeline
|
37 |
+
pipe: MVAdapterI2MVSDXLPipeline
|
38 |
+
pipe = MVAdapterI2MVSDXLPipeline.from_pretrained(base_model, **pipe_kwargs)
|
39 |
+
|
40 |
+
# Load scheduler if provided
|
41 |
+
scheduler_class = None
|
42 |
+
if scheduler == "ddpm":
|
43 |
+
scheduler_class = DDPMScheduler
|
44 |
+
elif scheduler == "lcm":
|
45 |
+
scheduler_class = LCMScheduler
|
46 |
+
|
47 |
+
pipe.scheduler = ShiftSNRScheduler.from_scheduler(
|
48 |
+
pipe.scheduler,
|
49 |
+
shift_mode="interpolated",
|
50 |
+
shift_scale=8.0,
|
51 |
+
scheduler_class=scheduler_class,
|
52 |
+
)
|
53 |
+
pipe.init_custom_adapter(
|
54 |
+
num_views=num_views, self_attn_processor=DecoupledMVRowColSelfAttnProcessor2_0
|
55 |
+
)
|
56 |
+
pipe.load_custom_adapter(
|
57 |
+
adapter_path, weight_name="mvadapter_ig2mv_sdxl.safetensors"
|
58 |
+
)
|
59 |
+
|
60 |
+
pipe.to(device=device, dtype=dtype)
|
61 |
+
pipe.cond_encoder.to(device=device, dtype=dtype)
|
62 |
+
|
63 |
+
# load lora if provided
|
64 |
+
if lora_model is not None:
|
65 |
+
model_, name_ = lora_model.rsplit("/", 1)
|
66 |
+
pipe.load_lora_weights(model_, weight_name=name_)
|
67 |
+
|
68 |
+
return pipe
|
69 |
+
|
70 |
+
|
71 |
+
def remove_bg(image, net, transform, device):
|
72 |
+
image_size = image.size
|
73 |
+
input_images = transform(image).unsqueeze(0).to(device)
|
74 |
+
with torch.no_grad():
|
75 |
+
preds = net(input_images)[-1].sigmoid().cpu()
|
76 |
+
pred = preds[0].squeeze()
|
77 |
+
pred_pil = transforms.ToPILImage()(pred)
|
78 |
+
mask = pred_pil.resize(image_size)
|
79 |
+
image.putalpha(mask)
|
80 |
+
return image
|
81 |
+
|
82 |
+
|
83 |
+
def preprocess_image(image: Image.Image, height, width):
|
84 |
+
image = np.array(image)
|
85 |
+
alpha = image[..., 3] > 0
|
86 |
+
H, W = alpha.shape
|
87 |
+
# get the bounding box of alpha
|
88 |
+
y, x = np.where(alpha)
|
89 |
+
y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
|
90 |
+
x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
|
91 |
+
image_center = image[y0:y1, x0:x1]
|
92 |
+
# resize the longer side to H * 0.9
|
93 |
+
H, W, _ = image_center.shape
|
94 |
+
if H > W:
|
95 |
+
W = int(W * (height * 0.9) / H)
|
96 |
+
H = int(height * 0.9)
|
97 |
+
else:
|
98 |
+
H = int(H * (width * 0.9) / W)
|
99 |
+
W = int(width * 0.9)
|
100 |
+
image_center = np.array(Image.fromarray(image_center).resize((W, H)))
|
101 |
+
# pad to H, W
|
102 |
+
start_h = (height - H) // 2
|
103 |
+
start_w = (width - W) // 2
|
104 |
+
image = np.zeros((height, width, 4), dtype=np.uint8)
|
105 |
+
image[start_h : start_h + H, start_w : start_w + W] = image_center
|
106 |
+
image = image.astype(np.float32) / 255.0
|
107 |
+
image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
|
108 |
+
image = (image * 255).clip(0, 255).astype(np.uint8)
|
109 |
+
image = Image.fromarray(image)
|
110 |
+
|
111 |
+
return image
|
112 |
+
|
113 |
+
|
114 |
+
def run_pipeline(
|
115 |
+
pipe,
|
116 |
+
mesh_path,
|
117 |
+
num_views,
|
118 |
+
text,
|
119 |
+
image,
|
120 |
+
height,
|
121 |
+
width,
|
122 |
+
num_inference_steps,
|
123 |
+
guidance_scale,
|
124 |
+
seed,
|
125 |
+
remove_bg_fn=None,
|
126 |
+
reference_conditioning_scale=1.0,
|
127 |
+
negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
|
128 |
+
lora_scale=1.0,
|
129 |
+
device="cuda",
|
130 |
+
):
|
131 |
+
# Prepare cameras
|
132 |
+
cameras = get_orthogonal_camera(
|
133 |
+
elevation_deg=[0, 0, 0, 0, 89.99, -89.99],
|
134 |
+
distance=[1.8] * num_views,
|
135 |
+
left=-0.55,
|
136 |
+
right=0.55,
|
137 |
+
bottom=-0.55,
|
138 |
+
top=0.55,
|
139 |
+
azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]],
|
140 |
+
device=device,
|
141 |
+
)
|
142 |
+
ctx = NVDiffRastContextWrapper(device=device, context_type="cuda")
|
143 |
+
|
144 |
+
mesh = load_mesh(mesh_path, rescale=True, device=device)
|
145 |
+
render_out = render(
|
146 |
+
ctx,
|
147 |
+
mesh,
|
148 |
+
cameras,
|
149 |
+
height=height,
|
150 |
+
width=width,
|
151 |
+
render_attr=False,
|
152 |
+
normal_background=0.0,
|
153 |
+
)
|
154 |
+
pos_images = tensor_to_image((render_out.pos + 0.5).clamp(0, 1), batched=True)
|
155 |
+
normal_images = tensor_to_image(
|
156 |
+
(render_out.normal / 2 + 0.5).clamp(0, 1), batched=True
|
157 |
+
)
|
158 |
+
control_images = (
|
159 |
+
torch.cat(
|
160 |
+
[
|
161 |
+
(render_out.pos + 0.5).clamp(0, 1),
|
162 |
+
(render_out.normal / 2 + 0.5).clamp(0, 1),
|
163 |
+
],
|
164 |
+
dim=-1,
|
165 |
+
)
|
166 |
+
.permute(0, 3, 1, 2)
|
167 |
+
.to(device)
|
168 |
+
)
|
169 |
+
|
170 |
+
# Prepare image
|
171 |
+
reference_image = Image.open(image) if isinstance(image, str) else image
|
172 |
+
if remove_bg_fn is not None:
|
173 |
+
reference_image = remove_bg_fn(reference_image)
|
174 |
+
reference_image = preprocess_image(reference_image, height, width)
|
175 |
+
elif reference_image.mode == "RGBA":
|
176 |
+
reference_image = preprocess_image(reference_image, height, width)
|
177 |
+
|
178 |
+
pipe_kwargs = {}
|
179 |
+
if seed != -1 and isinstance(seed, int):
|
180 |
+
pipe_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed)
|
181 |
+
|
182 |
+
images = pipe(
|
183 |
+
text,
|
184 |
+
height=height,
|
185 |
+
width=width,
|
186 |
+
num_inference_steps=num_inference_steps,
|
187 |
+
guidance_scale=guidance_scale,
|
188 |
+
num_images_per_prompt=num_views,
|
189 |
+
control_image=control_images,
|
190 |
+
control_conditioning_scale=1.0,
|
191 |
+
reference_image=reference_image,
|
192 |
+
reference_conditioning_scale=reference_conditioning_scale,
|
193 |
+
negative_prompt=negative_prompt,
|
194 |
+
cross_attention_kwargs={"scale": lora_scale},
|
195 |
+
**pipe_kwargs,
|
196 |
+
).images
|
197 |
+
|
198 |
+
return images, pos_images, normal_images, reference_image
|
199 |
+
|
200 |
+
|
201 |
+
if __name__ == "__main__":
|
202 |
+
parser = argparse.ArgumentParser()
|
203 |
+
# Models
|
204 |
+
parser.add_argument(
|
205 |
+
"--base_model", type=str, default="stabilityai/stable-diffusion-xl-base-1.0"
|
206 |
+
)
|
207 |
+
parser.add_argument(
|
208 |
+
"--vae_model", type=str, default="madebyollin/sdxl-vae-fp16-fix"
|
209 |
+
)
|
210 |
+
parser.add_argument("--unet_model", type=str, default=None)
|
211 |
+
parser.add_argument("--scheduler", type=str, default=None)
|
212 |
+
parser.add_argument("--lora_model", type=str, default=None)
|
213 |
+
parser.add_argument("--adapter_path", type=str, default="huanngzh/mv-adapter")
|
214 |
+
parser.add_argument("--num_views", type=int, default=6)
|
215 |
+
# Device
|
216 |
+
parser.add_argument("--device", type=str, default="cuda")
|
217 |
+
# Inference
|
218 |
+
parser.add_argument("--mesh", type=str, required=True)
|
219 |
+
parser.add_argument("--image", type=str, required=True)
|
220 |
+
parser.add_argument("--text", type=str, required=False, default="high quality")
|
221 |
+
parser.add_argument("--num_inference_steps", type=int, default=50)
|
222 |
+
parser.add_argument("--guidance_scale", type=float, default=3.0)
|
223 |
+
parser.add_argument("--seed", type=int, default=-1)
|
224 |
+
parser.add_argument("--lora_scale", type=float, default=1.0)
|
225 |
+
parser.add_argument("--reference_conditioning_scale", type=float, default=1.0)
|
226 |
+
parser.add_argument(
|
227 |
+
"--negative_prompt",
|
228 |
+
type=str,
|
229 |
+
default="watermark, ugly, deformed, noisy, blurry, low contrast",
|
230 |
+
)
|
231 |
+
parser.add_argument("--output", type=str, default="output.png")
|
232 |
+
# Extra
|
233 |
+
parser.add_argument("--remove_bg", action="store_true", help="Remove background")
|
234 |
+
args = parser.parse_args()
|
235 |
+
|
236 |
+
pipe = prepare_pipeline(
|
237 |
+
base_model=args.base_model,
|
238 |
+
vae_model=args.vae_model,
|
239 |
+
unet_model=args.unet_model,
|
240 |
+
lora_model=args.lora_model,
|
241 |
+
adapter_path=args.adapter_path,
|
242 |
+
scheduler=args.scheduler,
|
243 |
+
num_views=args.num_views,
|
244 |
+
device=args.device,
|
245 |
+
dtype=torch.float16,
|
246 |
+
)
|
247 |
+
|
248 |
+
if args.remove_bg:
|
249 |
+
birefnet = AutoModelForImageSegmentation.from_pretrained(
|
250 |
+
"ZhengPeng7/BiRefNet", trust_remote_code=True
|
251 |
+
)
|
252 |
+
birefnet.to(args.device)
|
253 |
+
transform_image = transforms.Compose(
|
254 |
+
[
|
255 |
+
transforms.Resize((1024, 1024)),
|
256 |
+
transforms.ToTensor(),
|
257 |
+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
258 |
+
]
|
259 |
+
)
|
260 |
+
remove_bg_fn = lambda x: remove_bg(x, birefnet, transform_image, args.device)
|
261 |
+
else:
|
262 |
+
remove_bg_fn = None
|
263 |
+
|
264 |
+
images, pos_images, normal_images, reference_image = run_pipeline(
|
265 |
+
pipe,
|
266 |
+
mesh_path=args.mesh,
|
267 |
+
num_views=args.num_views,
|
268 |
+
text=args.text,
|
269 |
+
image=args.image,
|
270 |
+
height=768,
|
271 |
+
width=768,
|
272 |
+
num_inference_steps=args.num_inference_steps,
|
273 |
+
guidance_scale=args.guidance_scale,
|
274 |
+
seed=args.seed,
|
275 |
+
lora_scale=args.lora_scale,
|
276 |
+
reference_conditioning_scale=args.reference_conditioning_scale,
|
277 |
+
negative_prompt=args.negative_prompt,
|
278 |
+
device=args.device,
|
279 |
+
remove_bg_fn=remove_bg_fn,
|
280 |
+
)
|
281 |
+
make_image_grid(images, rows=1).save(args.output)
|
282 |
+
make_image_grid(pos_images, rows=1).save(args.output.rsplit(".", 1)[0] + "_pos.png")
|
283 |
+
make_image_grid(normal_images, rows=1).save(
|
284 |
+
args.output.rsplit(".", 1)[0] + "_nor.png"
|
285 |
+
)
|
286 |
+
reference_image.save(args.output.rsplit(".", 1)[0] + "_reference.png")
|
mvadapter/__init__.py
ADDED
File without changes
|
mvadapter/loaders/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .custom_adapter import CustomAdapterMixin
|
mvadapter/loaders/custom_adapter.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import Dict, Optional, Union
|
3 |
+
|
4 |
+
import safetensors
|
5 |
+
import torch
|
6 |
+
from diffusers.utils import _get_model_file, logging
|
7 |
+
from safetensors import safe_open
|
8 |
+
|
9 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
10 |
+
|
11 |
+
|
12 |
+
class CustomAdapterMixin:
|
13 |
+
def init_custom_adapter(self, *args, **kwargs):
|
14 |
+
self._init_custom_adapter(*args, **kwargs)
|
15 |
+
|
16 |
+
def _init_custom_adapter(self, *args, **kwargs):
|
17 |
+
raise NotImplementedError
|
18 |
+
|
19 |
+
def load_custom_adapter(
|
20 |
+
self,
|
21 |
+
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
22 |
+
weight_name: str,
|
23 |
+
subfolder: Optional[str] = None,
|
24 |
+
**kwargs,
|
25 |
+
):
|
26 |
+
# Load the main state dict first.
|
27 |
+
cache_dir = kwargs.pop("cache_dir", None)
|
28 |
+
force_download = kwargs.pop("force_download", False)
|
29 |
+
proxies = kwargs.pop("proxies", None)
|
30 |
+
local_files_only = kwargs.pop("local_files_only", None)
|
31 |
+
token = kwargs.pop("token", None)
|
32 |
+
revision = kwargs.pop("revision", None)
|
33 |
+
|
34 |
+
user_agent = {
|
35 |
+
"file_type": "attn_procs_weights",
|
36 |
+
"framework": "pytorch",
|
37 |
+
}
|
38 |
+
|
39 |
+
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
40 |
+
model_file = _get_model_file(
|
41 |
+
pretrained_model_name_or_path_or_dict,
|
42 |
+
weights_name=weight_name,
|
43 |
+
subfolder=subfolder,
|
44 |
+
cache_dir=cache_dir,
|
45 |
+
force_download=force_download,
|
46 |
+
proxies=proxies,
|
47 |
+
local_files_only=local_files_only,
|
48 |
+
token=token,
|
49 |
+
revision=revision,
|
50 |
+
user_agent=user_agent,
|
51 |
+
)
|
52 |
+
if weight_name.endswith(".safetensors"):
|
53 |
+
state_dict = {}
|
54 |
+
with safe_open(model_file, framework="pt", device="cpu") as f:
|
55 |
+
for key in f.keys():
|
56 |
+
state_dict[key] = f.get_tensor(key)
|
57 |
+
else:
|
58 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
59 |
+
else:
|
60 |
+
state_dict = pretrained_model_name_or_path_or_dict
|
61 |
+
|
62 |
+
self._load_custom_adapter(state_dict)
|
63 |
+
|
64 |
+
def _load_custom_adapter(self, state_dict):
|
65 |
+
raise NotImplementedError
|
66 |
+
|
67 |
+
def save_custom_adapter(
|
68 |
+
self,
|
69 |
+
save_directory: Union[str, os.PathLike],
|
70 |
+
weight_name: str,
|
71 |
+
safe_serialization: bool = False,
|
72 |
+
**kwargs,
|
73 |
+
):
|
74 |
+
if os.path.isfile(save_directory):
|
75 |
+
logger.error(
|
76 |
+
f"Provided path ({save_directory}) should be a directory, not a file"
|
77 |
+
)
|
78 |
+
return
|
79 |
+
|
80 |
+
if safe_serialization:
|
81 |
+
|
82 |
+
def save_function(weights, filename):
|
83 |
+
return safetensors.torch.save_file(
|
84 |
+
weights, filename, metadata={"format": "pt"}
|
85 |
+
)
|
86 |
+
|
87 |
+
else:
|
88 |
+
save_function = torch.save
|
89 |
+
|
90 |
+
# Save the model
|
91 |
+
state_dict = self._save_custom_adapter(**kwargs)
|
92 |
+
save_function(state_dict, os.path.join(save_directory, weight_name))
|
93 |
+
logger.info(
|
94 |
+
f"Custom adapter weights saved in {os.path.join(save_directory, weight_name)}"
|
95 |
+
)
|
96 |
+
|
97 |
+
def _save_custom_adapter(self):
|
98 |
+
raise NotImplementedError
|
mvadapter/models/__init__.py
ADDED
File without changes
|
mvadapter/models/attention_processor.py
ADDED
@@ -0,0 +1,743 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Callable, List, Optional, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from diffusers.models.attention_processor import Attention
|
7 |
+
from diffusers.models.unets import UNet2DConditionModel
|
8 |
+
from diffusers.utils import deprecate, logging
|
9 |
+
from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
|
10 |
+
from einops import rearrange, repeat
|
11 |
+
from torch import nn
|
12 |
+
|
13 |
+
|
14 |
+
def default_set_attn_proc_func(
|
15 |
+
name: str,
|
16 |
+
hidden_size: int,
|
17 |
+
cross_attention_dim: Optional[int],
|
18 |
+
ori_attn_proc: object,
|
19 |
+
) -> object:
|
20 |
+
return ori_attn_proc
|
21 |
+
|
22 |
+
|
23 |
+
def set_unet_2d_condition_attn_processor(
|
24 |
+
unet: UNet2DConditionModel,
|
25 |
+
set_self_attn_proc_func: Callable = default_set_attn_proc_func,
|
26 |
+
set_cross_attn_proc_func: Callable = default_set_attn_proc_func,
|
27 |
+
set_custom_attn_proc_func: Callable = default_set_attn_proc_func,
|
28 |
+
set_self_attn_module_names: Optional[List[str]] = None,
|
29 |
+
set_cross_attn_module_names: Optional[List[str]] = None,
|
30 |
+
set_custom_attn_module_names: Optional[List[str]] = None,
|
31 |
+
) -> None:
|
32 |
+
do_set_processor = lambda name, module_names: (
|
33 |
+
any([name.startswith(module_name) for module_name in module_names])
|
34 |
+
if module_names is not None
|
35 |
+
else True
|
36 |
+
) # prefix match
|
37 |
+
|
38 |
+
attn_procs = {}
|
39 |
+
for name, attn_processor in unet.attn_processors.items():
|
40 |
+
# set attn_processor by default, if module_names is None
|
41 |
+
set_self_attn_processor = do_set_processor(name, set_self_attn_module_names)
|
42 |
+
set_cross_attn_processor = do_set_processor(name, set_cross_attn_module_names)
|
43 |
+
set_custom_attn_processor = do_set_processor(name, set_custom_attn_module_names)
|
44 |
+
|
45 |
+
if name.startswith("mid_block"):
|
46 |
+
hidden_size = unet.config.block_out_channels[-1]
|
47 |
+
elif name.startswith("up_blocks"):
|
48 |
+
block_id = int(name[len("up_blocks.")])
|
49 |
+
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
50 |
+
elif name.startswith("down_blocks"):
|
51 |
+
block_id = int(name[len("down_blocks.")])
|
52 |
+
hidden_size = unet.config.block_out_channels[block_id]
|
53 |
+
|
54 |
+
is_custom = "attn_mid_blocks" in name or "attn_post_blocks" in name
|
55 |
+
if is_custom:
|
56 |
+
attn_procs[name] = (
|
57 |
+
set_custom_attn_proc_func(name, hidden_size, None, attn_processor)
|
58 |
+
if set_custom_attn_processor
|
59 |
+
else attn_processor
|
60 |
+
)
|
61 |
+
else:
|
62 |
+
cross_attention_dim = (
|
63 |
+
None
|
64 |
+
if name.endswith("attn1.processor")
|
65 |
+
else unet.config.cross_attention_dim
|
66 |
+
)
|
67 |
+
if cross_attention_dim is None or "motion_modules" in name:
|
68 |
+
# self attention
|
69 |
+
attn_procs[name] = (
|
70 |
+
set_self_attn_proc_func(
|
71 |
+
name, hidden_size, cross_attention_dim, attn_processor
|
72 |
+
)
|
73 |
+
if set_self_attn_processor
|
74 |
+
else attn_processor
|
75 |
+
)
|
76 |
+
else:
|
77 |
+
# cross attention
|
78 |
+
attn_procs[name] = (
|
79 |
+
set_cross_attn_proc_func(
|
80 |
+
name, hidden_size, cross_attention_dim, attn_processor
|
81 |
+
)
|
82 |
+
if set_cross_attn_processor
|
83 |
+
else attn_processor
|
84 |
+
)
|
85 |
+
|
86 |
+
unet.set_attn_processor(attn_procs)
|
87 |
+
|
88 |
+
|
89 |
+
class DecoupledMVRowSelfAttnProcessor2_0(torch.nn.Module):
|
90 |
+
r"""
|
91 |
+
Attention processor for Decoupled Row-wise Self-Attention and Image Cross-Attention for PyTorch 2.0.
|
92 |
+
"""
|
93 |
+
|
94 |
+
def __init__(
|
95 |
+
self,
|
96 |
+
query_dim: int,
|
97 |
+
inner_dim: int,
|
98 |
+
num_views: int = 1,
|
99 |
+
name: Optional[str] = None,
|
100 |
+
use_mv: bool = True,
|
101 |
+
use_ref: bool = False,
|
102 |
+
):
|
103 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
104 |
+
raise ImportError(
|
105 |
+
"DecoupledMVRowSelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
106 |
+
)
|
107 |
+
|
108 |
+
super().__init__()
|
109 |
+
|
110 |
+
self.num_views = num_views
|
111 |
+
self.name = name # NOTE: need for image cross-attention
|
112 |
+
self.use_mv = use_mv
|
113 |
+
self.use_ref = use_ref
|
114 |
+
|
115 |
+
if self.use_mv:
|
116 |
+
self.to_q_mv = nn.Linear(
|
117 |
+
in_features=query_dim, out_features=inner_dim, bias=False
|
118 |
+
)
|
119 |
+
self.to_k_mv = nn.Linear(
|
120 |
+
in_features=query_dim, out_features=inner_dim, bias=False
|
121 |
+
)
|
122 |
+
self.to_v_mv = nn.Linear(
|
123 |
+
in_features=query_dim, out_features=inner_dim, bias=False
|
124 |
+
)
|
125 |
+
self.to_out_mv = nn.ModuleList(
|
126 |
+
[
|
127 |
+
nn.Linear(in_features=inner_dim, out_features=query_dim, bias=True),
|
128 |
+
nn.Dropout(0.0),
|
129 |
+
]
|
130 |
+
)
|
131 |
+
|
132 |
+
if self.use_ref:
|
133 |
+
self.to_q_ref = nn.Linear(
|
134 |
+
in_features=query_dim, out_features=inner_dim, bias=False
|
135 |
+
)
|
136 |
+
self.to_k_ref = nn.Linear(
|
137 |
+
in_features=query_dim, out_features=inner_dim, bias=False
|
138 |
+
)
|
139 |
+
self.to_v_ref = nn.Linear(
|
140 |
+
in_features=query_dim, out_features=inner_dim, bias=False
|
141 |
+
)
|
142 |
+
self.to_out_ref = nn.ModuleList(
|
143 |
+
[
|
144 |
+
nn.Linear(in_features=inner_dim, out_features=query_dim, bias=True),
|
145 |
+
nn.Dropout(0.0),
|
146 |
+
]
|
147 |
+
)
|
148 |
+
|
149 |
+
def __call__(
|
150 |
+
self,
|
151 |
+
attn: Attention,
|
152 |
+
hidden_states: torch.FloatTensor,
|
153 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
154 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
155 |
+
temb: Optional[torch.FloatTensor] = None,
|
156 |
+
mv_scale: float = 1.0,
|
157 |
+
ref_hidden_states: Optional[torch.FloatTensor] = None,
|
158 |
+
ref_scale: float = 1.0,
|
159 |
+
cache_hidden_states: Optional[List[torch.FloatTensor]] = None,
|
160 |
+
use_mv: bool = True,
|
161 |
+
use_ref: bool = True,
|
162 |
+
num_views: Optional[int] = None,
|
163 |
+
*args,
|
164 |
+
**kwargs,
|
165 |
+
) -> torch.FloatTensor:
|
166 |
+
"""
|
167 |
+
New args:
|
168 |
+
mv_scale (float): scale for multi-view self-attention.
|
169 |
+
ref_hidden_states (torch.FloatTensor): reference encoder hidden states for image cross-attention.
|
170 |
+
ref_scale (float): scale for image cross-attention.
|
171 |
+
cache_hidden_states (List[torch.FloatTensor]): cache hidden states from reference unet.
|
172 |
+
|
173 |
+
"""
|
174 |
+
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
175 |
+
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
176 |
+
deprecate("scale", "1.0.0", deprecation_message)
|
177 |
+
|
178 |
+
if num_views is not None:
|
179 |
+
self.num_views = num_views
|
180 |
+
|
181 |
+
# NEW: cache hidden states for reference unet
|
182 |
+
if cache_hidden_states is not None:
|
183 |
+
cache_hidden_states[self.name] = hidden_states.clone()
|
184 |
+
|
185 |
+
# NEW: whether to use multi-view attention and image cross-attention
|
186 |
+
use_mv = self.use_mv and use_mv
|
187 |
+
use_ref = self.use_ref and use_ref
|
188 |
+
|
189 |
+
residual = hidden_states
|
190 |
+
if attn.spatial_norm is not None:
|
191 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
192 |
+
|
193 |
+
input_ndim = hidden_states.ndim
|
194 |
+
|
195 |
+
if input_ndim == 4:
|
196 |
+
batch_size, channel, height, width = hidden_states.shape
|
197 |
+
hidden_states = hidden_states.view(
|
198 |
+
batch_size, channel, height * width
|
199 |
+
).transpose(1, 2)
|
200 |
+
|
201 |
+
batch_size, sequence_length, _ = (
|
202 |
+
hidden_states.shape
|
203 |
+
if encoder_hidden_states is None
|
204 |
+
else encoder_hidden_states.shape
|
205 |
+
)
|
206 |
+
|
207 |
+
if attention_mask is not None:
|
208 |
+
attention_mask = attn.prepare_attention_mask(
|
209 |
+
attention_mask, sequence_length, batch_size
|
210 |
+
)
|
211 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
212 |
+
# (batch, heads, source_length, target_length)
|
213 |
+
attention_mask = attention_mask.view(
|
214 |
+
batch_size, attn.heads, -1, attention_mask.shape[-1]
|
215 |
+
)
|
216 |
+
|
217 |
+
if attn.group_norm is not None:
|
218 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
|
219 |
+
1, 2
|
220 |
+
)
|
221 |
+
|
222 |
+
query = attn.to_q(hidden_states)
|
223 |
+
|
224 |
+
# NEW: for decoupled multi-view attention
|
225 |
+
if use_mv:
|
226 |
+
query_mv = self.to_q_mv(hidden_states)
|
227 |
+
|
228 |
+
# NEW: for decoupled reference cross attention
|
229 |
+
if use_ref:
|
230 |
+
query_ref = self.to_q_ref(hidden_states)
|
231 |
+
|
232 |
+
if encoder_hidden_states is None:
|
233 |
+
encoder_hidden_states = hidden_states
|
234 |
+
elif attn.norm_cross:
|
235 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
236 |
+
encoder_hidden_states
|
237 |
+
)
|
238 |
+
|
239 |
+
key = attn.to_k(encoder_hidden_states)
|
240 |
+
value = attn.to_v(encoder_hidden_states)
|
241 |
+
|
242 |
+
inner_dim = key.shape[-1]
|
243 |
+
head_dim = inner_dim // attn.heads
|
244 |
+
|
245 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
246 |
+
|
247 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
248 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
249 |
+
|
250 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
251 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
252 |
+
hidden_states = F.scaled_dot_product_attention(
|
253 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
254 |
+
)
|
255 |
+
|
256 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(
|
257 |
+
batch_size, -1, attn.heads * head_dim
|
258 |
+
)
|
259 |
+
hidden_states = hidden_states.to(query.dtype)
|
260 |
+
|
261 |
+
####### Decoupled multi-view self-attention ########
|
262 |
+
if use_mv:
|
263 |
+
key_mv = self.to_k_mv(encoder_hidden_states)
|
264 |
+
value_mv = self.to_v_mv(encoder_hidden_states)
|
265 |
+
|
266 |
+
query_mv = query_mv.view(batch_size, -1, attn.heads, head_dim)
|
267 |
+
key_mv = key_mv.view(batch_size, -1, attn.heads, head_dim)
|
268 |
+
value_mv = value_mv.view(batch_size, -1, attn.heads, head_dim)
|
269 |
+
|
270 |
+
height = width = math.isqrt(sequence_length)
|
271 |
+
|
272 |
+
# row self-attention
|
273 |
+
query_mv = rearrange(
|
274 |
+
query_mv,
|
275 |
+
"(b nv) (ih iw) h c -> (b nv ih) iw h c",
|
276 |
+
nv=self.num_views,
|
277 |
+
ih=height,
|
278 |
+
iw=width,
|
279 |
+
).transpose(1, 2)
|
280 |
+
key_mv = rearrange(
|
281 |
+
key_mv,
|
282 |
+
"(b nv) (ih iw) h c -> b ih (nv iw) h c",
|
283 |
+
nv=self.num_views,
|
284 |
+
ih=height,
|
285 |
+
iw=width,
|
286 |
+
)
|
287 |
+
key_mv = (
|
288 |
+
key_mv.repeat_interleave(self.num_views, dim=0)
|
289 |
+
.view(batch_size * height, -1, attn.heads, head_dim)
|
290 |
+
.transpose(1, 2)
|
291 |
+
)
|
292 |
+
value_mv = rearrange(
|
293 |
+
value_mv,
|
294 |
+
"(b nv) (ih iw) h c -> b ih (nv iw) h c",
|
295 |
+
nv=self.num_views,
|
296 |
+
ih=height,
|
297 |
+
iw=width,
|
298 |
+
)
|
299 |
+
value_mv = (
|
300 |
+
value_mv.repeat_interleave(self.num_views, dim=0)
|
301 |
+
.view(batch_size * height, -1, attn.heads, head_dim)
|
302 |
+
.transpose(1, 2)
|
303 |
+
)
|
304 |
+
|
305 |
+
hidden_states_mv = F.scaled_dot_product_attention(
|
306 |
+
query_mv,
|
307 |
+
key_mv,
|
308 |
+
value_mv,
|
309 |
+
dropout_p=0.0,
|
310 |
+
is_causal=False,
|
311 |
+
)
|
312 |
+
hidden_states_mv = rearrange(
|
313 |
+
hidden_states_mv,
|
314 |
+
"(b nv ih) h iw c -> (b nv) (ih iw) (h c)",
|
315 |
+
nv=self.num_views,
|
316 |
+
ih=height,
|
317 |
+
)
|
318 |
+
hidden_states_mv = hidden_states_mv.to(query.dtype)
|
319 |
+
|
320 |
+
# linear proj
|
321 |
+
hidden_states_mv = self.to_out_mv[0](hidden_states_mv)
|
322 |
+
# dropout
|
323 |
+
hidden_states_mv = self.to_out_mv[1](hidden_states_mv)
|
324 |
+
|
325 |
+
if use_ref:
|
326 |
+
reference_hidden_states = ref_hidden_states[self.name]
|
327 |
+
|
328 |
+
key_ref = self.to_k_ref(reference_hidden_states)
|
329 |
+
value_ref = self.to_v_ref(reference_hidden_states)
|
330 |
+
|
331 |
+
query_ref = query_ref.view(batch_size, -1, attn.heads, head_dim).transpose(
|
332 |
+
1, 2
|
333 |
+
)
|
334 |
+
key_ref = key_ref.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
335 |
+
value_ref = value_ref.view(batch_size, -1, attn.heads, head_dim).transpose(
|
336 |
+
1, 2
|
337 |
+
)
|
338 |
+
|
339 |
+
hidden_states_ref = F.scaled_dot_product_attention(
|
340 |
+
query_ref, key_ref, value_ref, dropout_p=0.0, is_causal=False
|
341 |
+
)
|
342 |
+
|
343 |
+
hidden_states_ref = hidden_states_ref.transpose(1, 2).reshape(
|
344 |
+
batch_size, -1, attn.heads * head_dim
|
345 |
+
)
|
346 |
+
hidden_states_ref = hidden_states_ref.to(query.dtype)
|
347 |
+
|
348 |
+
# linear proj
|
349 |
+
hidden_states_ref = self.to_out_ref[0](hidden_states_ref)
|
350 |
+
# dropout
|
351 |
+
hidden_states_ref = self.to_out_ref[1](hidden_states_ref)
|
352 |
+
|
353 |
+
# linear proj
|
354 |
+
hidden_states = attn.to_out[0](hidden_states)
|
355 |
+
# dropout
|
356 |
+
hidden_states = attn.to_out[1](hidden_states)
|
357 |
+
|
358 |
+
if use_mv:
|
359 |
+
hidden_states = hidden_states + hidden_states_mv * mv_scale
|
360 |
+
|
361 |
+
if use_ref:
|
362 |
+
hidden_states = hidden_states + hidden_states_ref * ref_scale
|
363 |
+
|
364 |
+
if input_ndim == 4:
|
365 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
366 |
+
batch_size, channel, height, width
|
367 |
+
)
|
368 |
+
|
369 |
+
if attn.residual_connection:
|
370 |
+
hidden_states = hidden_states + residual
|
371 |
+
|
372 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
373 |
+
|
374 |
+
return hidden_states
|
375 |
+
|
376 |
+
def set_num_views(self, num_views: int) -> None:
|
377 |
+
self.num_views = num_views
|
378 |
+
|
379 |
+
|
380 |
+
class DecoupledMVRowColSelfAttnProcessor2_0(torch.nn.Module):
|
381 |
+
r"""
|
382 |
+
Attention processor for Decoupled Row-wise Self-Attention and Image Cross-Attention for PyTorch 2.0.
|
383 |
+
"""
|
384 |
+
|
385 |
+
def __init__(
|
386 |
+
self,
|
387 |
+
query_dim: int,
|
388 |
+
inner_dim: int,
|
389 |
+
num_views: int = 1,
|
390 |
+
name: Optional[str] = None,
|
391 |
+
use_mv: bool = True,
|
392 |
+
use_ref: bool = False,
|
393 |
+
):
|
394 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
395 |
+
raise ImportError(
|
396 |
+
"DecoupledMVRowSelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
397 |
+
)
|
398 |
+
|
399 |
+
super().__init__()
|
400 |
+
|
401 |
+
self.num_views = num_views
|
402 |
+
self.name = name # NOTE: need for image cross-attention
|
403 |
+
self.use_mv = use_mv
|
404 |
+
self.use_ref = use_ref
|
405 |
+
|
406 |
+
if self.use_mv:
|
407 |
+
self.to_q_mv = nn.Linear(
|
408 |
+
in_features=query_dim, out_features=inner_dim, bias=False
|
409 |
+
)
|
410 |
+
self.to_k_mv = nn.Linear(
|
411 |
+
in_features=query_dim, out_features=inner_dim, bias=False
|
412 |
+
)
|
413 |
+
self.to_v_mv = nn.Linear(
|
414 |
+
in_features=query_dim, out_features=inner_dim, bias=False
|
415 |
+
)
|
416 |
+
self.to_out_mv = nn.ModuleList(
|
417 |
+
[
|
418 |
+
nn.Linear(in_features=inner_dim, out_features=query_dim, bias=True),
|
419 |
+
nn.Dropout(0.0),
|
420 |
+
]
|
421 |
+
)
|
422 |
+
|
423 |
+
if self.use_ref:
|
424 |
+
self.to_q_ref = nn.Linear(
|
425 |
+
in_features=query_dim, out_features=inner_dim, bias=False
|
426 |
+
)
|
427 |
+
self.to_k_ref = nn.Linear(
|
428 |
+
in_features=query_dim, out_features=inner_dim, bias=False
|
429 |
+
)
|
430 |
+
self.to_v_ref = nn.Linear(
|
431 |
+
in_features=query_dim, out_features=inner_dim, bias=False
|
432 |
+
)
|
433 |
+
self.to_out_ref = nn.ModuleList(
|
434 |
+
[
|
435 |
+
nn.Linear(in_features=inner_dim, out_features=query_dim, bias=True),
|
436 |
+
nn.Dropout(0.0),
|
437 |
+
]
|
438 |
+
)
|
439 |
+
|
440 |
+
def __call__(
|
441 |
+
self,
|
442 |
+
attn: Attention,
|
443 |
+
hidden_states: torch.FloatTensor,
|
444 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
445 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
446 |
+
temb: Optional[torch.FloatTensor] = None,
|
447 |
+
mv_scale: float = 1.0,
|
448 |
+
ref_hidden_states: Optional[torch.FloatTensor] = None,
|
449 |
+
ref_scale: float = 1.0,
|
450 |
+
cache_hidden_states: Optional[List[torch.FloatTensor]] = None,
|
451 |
+
use_mv: bool = True,
|
452 |
+
use_ref: bool = True,
|
453 |
+
num_views: Optional[int] = None,
|
454 |
+
*args,
|
455 |
+
**kwargs,
|
456 |
+
) -> torch.FloatTensor:
|
457 |
+
"""
|
458 |
+
New args:
|
459 |
+
mv_scale (float): scale for multi-view self-attention.
|
460 |
+
ref_hidden_states (torch.FloatTensor): reference encoder hidden states for image cross-attention.
|
461 |
+
ref_scale (float): scale for image cross-attention.
|
462 |
+
cache_hidden_states (List[torch.FloatTensor]): cache hidden states from reference unet.
|
463 |
+
|
464 |
+
"""
|
465 |
+
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
466 |
+
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
467 |
+
deprecate("scale", "1.0.0", deprecation_message)
|
468 |
+
|
469 |
+
if num_views is not None:
|
470 |
+
self.num_views = num_views
|
471 |
+
|
472 |
+
# NEW: cache hidden states for reference unet
|
473 |
+
if cache_hidden_states is not None:
|
474 |
+
cache_hidden_states[self.name] = hidden_states.clone()
|
475 |
+
|
476 |
+
# NEW: whether to use multi-view attention and image cross-attention
|
477 |
+
use_mv = self.use_mv and use_mv
|
478 |
+
use_ref = self.use_ref and use_ref
|
479 |
+
|
480 |
+
residual = hidden_states
|
481 |
+
if attn.spatial_norm is not None:
|
482 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
483 |
+
|
484 |
+
input_ndim = hidden_states.ndim
|
485 |
+
|
486 |
+
if input_ndim == 4:
|
487 |
+
batch_size, channel, height, width = hidden_states.shape
|
488 |
+
hidden_states = hidden_states.view(
|
489 |
+
batch_size, channel, height * width
|
490 |
+
).transpose(1, 2)
|
491 |
+
|
492 |
+
batch_size, sequence_length, _ = (
|
493 |
+
hidden_states.shape
|
494 |
+
if encoder_hidden_states is None
|
495 |
+
else encoder_hidden_states.shape
|
496 |
+
)
|
497 |
+
|
498 |
+
if attention_mask is not None:
|
499 |
+
attention_mask = attn.prepare_attention_mask(
|
500 |
+
attention_mask, sequence_length, batch_size
|
501 |
+
)
|
502 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
503 |
+
# (batch, heads, source_length, target_length)
|
504 |
+
attention_mask = attention_mask.view(
|
505 |
+
batch_size, attn.heads, -1, attention_mask.shape[-1]
|
506 |
+
)
|
507 |
+
|
508 |
+
if attn.group_norm is not None:
|
509 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
|
510 |
+
1, 2
|
511 |
+
)
|
512 |
+
|
513 |
+
query = attn.to_q(hidden_states)
|
514 |
+
|
515 |
+
# NEW: for decoupled multi-view attention
|
516 |
+
if use_mv:
|
517 |
+
query_mv = self.to_q_mv(hidden_states)
|
518 |
+
|
519 |
+
# NEW: for decoupled reference cross attention
|
520 |
+
if use_ref:
|
521 |
+
query_ref = self.to_q_ref(hidden_states)
|
522 |
+
|
523 |
+
if encoder_hidden_states is None:
|
524 |
+
encoder_hidden_states = hidden_states
|
525 |
+
elif attn.norm_cross:
|
526 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
527 |
+
encoder_hidden_states
|
528 |
+
)
|
529 |
+
|
530 |
+
key = attn.to_k(encoder_hidden_states)
|
531 |
+
value = attn.to_v(encoder_hidden_states)
|
532 |
+
|
533 |
+
inner_dim = key.shape[-1]
|
534 |
+
head_dim = inner_dim // attn.heads
|
535 |
+
|
536 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
537 |
+
|
538 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
539 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
540 |
+
|
541 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
542 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
543 |
+
hidden_states = F.scaled_dot_product_attention(
|
544 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
545 |
+
)
|
546 |
+
|
547 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(
|
548 |
+
batch_size, -1, attn.heads * head_dim
|
549 |
+
)
|
550 |
+
hidden_states = hidden_states.to(query.dtype)
|
551 |
+
|
552 |
+
####### Decoupled multi-view self-attention ########
|
553 |
+
if use_mv:
|
554 |
+
key_mv = self.to_k_mv(encoder_hidden_states)
|
555 |
+
value_mv = self.to_v_mv(encoder_hidden_states)
|
556 |
+
|
557 |
+
query_mv = query_mv.view(batch_size, -1, attn.heads, head_dim)
|
558 |
+
key_mv = key_mv.view(batch_size, -1, attn.heads, head_dim)
|
559 |
+
value_mv = value_mv.view(batch_size, -1, attn.heads, head_dim)
|
560 |
+
|
561 |
+
height = width = math.isqrt(sequence_length)
|
562 |
+
|
563 |
+
query_mv = rearrange(
|
564 |
+
query_mv,
|
565 |
+
"(b nv) (ih iw) h c -> b nv ih iw h c",
|
566 |
+
nv=self.num_views,
|
567 |
+
ih=height,
|
568 |
+
iw=width,
|
569 |
+
)
|
570 |
+
key_mv = rearrange(
|
571 |
+
key_mv,
|
572 |
+
"(b nv) (ih iw) h c -> b nv ih iw h c",
|
573 |
+
nv=self.num_views,
|
574 |
+
ih=height,
|
575 |
+
iw=width,
|
576 |
+
)
|
577 |
+
value_mv = rearrange(
|
578 |
+
value_mv,
|
579 |
+
"(b nv) (ih iw) h c -> b nv ih iw h c",
|
580 |
+
nv=self.num_views,
|
581 |
+
ih=height,
|
582 |
+
iw=width,
|
583 |
+
)
|
584 |
+
|
585 |
+
# row-wise attention for view 0123 (front, right, back, left)
|
586 |
+
query_mv_0123 = rearrange(
|
587 |
+
query_mv[:, 0:4], "b nv ih iw h c -> (b ih) h (nv iw) c"
|
588 |
+
)
|
589 |
+
key_mv_0123 = rearrange(
|
590 |
+
key_mv[:, 0:4], "b nv ih iw h c -> (b ih) h (nv iw) c"
|
591 |
+
)
|
592 |
+
value_mv_0123 = rearrange(
|
593 |
+
value_mv[:, 0:4], "b nv ih iw h c -> (b ih) h (nv iw) c"
|
594 |
+
)
|
595 |
+
hidden_states_mv_0123 = F.scaled_dot_product_attention(
|
596 |
+
query_mv_0123,
|
597 |
+
key_mv_0123,
|
598 |
+
value_mv_0123,
|
599 |
+
dropout_p=0.0,
|
600 |
+
is_causal=False,
|
601 |
+
)
|
602 |
+
hidden_states_mv_0123 = rearrange(
|
603 |
+
hidden_states_mv_0123,
|
604 |
+
"(b ih) h (nv iw) c -> b nv (ih iw) (h c)",
|
605 |
+
ih=height,
|
606 |
+
iw=height,
|
607 |
+
)
|
608 |
+
|
609 |
+
# col-wise attention for view 0245 (front, back, top, bottom)
|
610 |
+
# flip first
|
611 |
+
query_mv_0245 = torch.cat(
|
612 |
+
[
|
613 |
+
torch.flip(query_mv[:, [0]], [3]), # horizontal flip
|
614 |
+
query_mv[:, [2, 4, 5]],
|
615 |
+
],
|
616 |
+
dim=1,
|
617 |
+
)
|
618 |
+
key_mv_0245 = torch.cat(
|
619 |
+
[
|
620 |
+
torch.flip(key_mv[:, [0]], [3]), # horizontal flip
|
621 |
+
key_mv[:, [2, 4, 5]],
|
622 |
+
],
|
623 |
+
dim=1,
|
624 |
+
)
|
625 |
+
value_mv_0245 = torch.cat(
|
626 |
+
[
|
627 |
+
torch.flip(value_mv[:, [0]], [3]), # horizontal flip
|
628 |
+
value_mv[:, [2, 4, 5]],
|
629 |
+
],
|
630 |
+
dim=1,
|
631 |
+
)
|
632 |
+
# attention
|
633 |
+
query_mv_0245 = rearrange(
|
634 |
+
query_mv_0245, "b nv ih iw h c -> (b iw) h (nv ih) c"
|
635 |
+
)
|
636 |
+
key_mv_0245 = rearrange(key_mv_0245, "b nv ih iw h c -> (b iw) h (nv ih) c")
|
637 |
+
value_mv_0245 = rearrange(
|
638 |
+
value_mv_0245, "b nv ih iw h c -> (b iw) h (nv ih) c"
|
639 |
+
)
|
640 |
+
hidden_states_mv_0245 = F.scaled_dot_product_attention(
|
641 |
+
query_mv_0245,
|
642 |
+
key_mv_0245,
|
643 |
+
value_mv_0245,
|
644 |
+
dropout_p=0.0,
|
645 |
+
is_causal=False,
|
646 |
+
)
|
647 |
+
# flip back
|
648 |
+
hidden_states_mv_0245 = rearrange(
|
649 |
+
hidden_states_mv_0245,
|
650 |
+
"(b iw) h (nv ih) c -> b nv ih iw (h c)",
|
651 |
+
ih=height,
|
652 |
+
iw=height,
|
653 |
+
)
|
654 |
+
hidden_states_mv_0245 = torch.cat(
|
655 |
+
[
|
656 |
+
torch.flip(hidden_states_mv_0245[:, [0]], [3]), # horizontal flip
|
657 |
+
hidden_states_mv_0245[:, [1, 2, 3]],
|
658 |
+
],
|
659 |
+
dim=1,
|
660 |
+
)
|
661 |
+
hidden_states_mv_0245 = hidden_states_mv_0245.view(
|
662 |
+
hidden_states_mv_0245.shape[0],
|
663 |
+
hidden_states_mv_0245.shape[1],
|
664 |
+
-1,
|
665 |
+
hidden_states_mv_0245.shape[-1],
|
666 |
+
)
|
667 |
+
|
668 |
+
# combine row and col
|
669 |
+
hidden_states_mv = torch.stack(
|
670 |
+
[
|
671 |
+
(hidden_states_mv_0123[:, 0] + hidden_states_mv_0245[:, 0]) / 2,
|
672 |
+
hidden_states_mv_0123[:, 1],
|
673 |
+
(hidden_states_mv_0123[:, 2] + hidden_states_mv_0245[:, 1]) / 2,
|
674 |
+
hidden_states_mv_0123[:, 3],
|
675 |
+
hidden_states_mv_0245[:, 2],
|
676 |
+
hidden_states_mv_0245[:, 3],
|
677 |
+
],
|
678 |
+
dim=1,
|
679 |
+
)
|
680 |
+
|
681 |
+
hidden_states_mv = hidden_states_mv.view(
|
682 |
+
-1, hidden_states_mv.shape[-2], hidden_states_mv.shape[-1]
|
683 |
+
)
|
684 |
+
hidden_states_mv = hidden_states_mv.to(query.dtype)
|
685 |
+
|
686 |
+
# linear proj
|
687 |
+
hidden_states_mv = self.to_out_mv[0](hidden_states_mv)
|
688 |
+
# dropout
|
689 |
+
hidden_states_mv = self.to_out_mv[1](hidden_states_mv)
|
690 |
+
|
691 |
+
if use_ref:
|
692 |
+
reference_hidden_states = ref_hidden_states[self.name]
|
693 |
+
|
694 |
+
key_ref = self.to_k_ref(reference_hidden_states)
|
695 |
+
value_ref = self.to_v_ref(reference_hidden_states)
|
696 |
+
|
697 |
+
query_ref = query_ref.view(batch_size, -1, attn.heads, head_dim).transpose(
|
698 |
+
1, 2
|
699 |
+
)
|
700 |
+
key_ref = key_ref.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
701 |
+
value_ref = value_ref.view(batch_size, -1, attn.heads, head_dim).transpose(
|
702 |
+
1, 2
|
703 |
+
)
|
704 |
+
|
705 |
+
hidden_states_ref = F.scaled_dot_product_attention(
|
706 |
+
query_ref, key_ref, value_ref, dropout_p=0.0, is_causal=False
|
707 |
+
)
|
708 |
+
|
709 |
+
hidden_states_ref = hidden_states_ref.transpose(1, 2).reshape(
|
710 |
+
batch_size, -1, attn.heads * head_dim
|
711 |
+
)
|
712 |
+
hidden_states_ref = hidden_states_ref.to(query.dtype)
|
713 |
+
|
714 |
+
# linear proj
|
715 |
+
hidden_states_ref = self.to_out_ref[0](hidden_states_ref)
|
716 |
+
# dropout
|
717 |
+
hidden_states_ref = self.to_out_ref[1](hidden_states_ref)
|
718 |
+
|
719 |
+
# linear proj
|
720 |
+
hidden_states = attn.to_out[0](hidden_states)
|
721 |
+
# dropout
|
722 |
+
hidden_states = attn.to_out[1](hidden_states)
|
723 |
+
|
724 |
+
if use_mv:
|
725 |
+
hidden_states = hidden_states + hidden_states_mv * mv_scale
|
726 |
+
|
727 |
+
if use_ref:
|
728 |
+
hidden_states = hidden_states + hidden_states_ref * ref_scale
|
729 |
+
|
730 |
+
if input_ndim == 4:
|
731 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
732 |
+
batch_size, channel, height, width
|
733 |
+
)
|
734 |
+
|
735 |
+
if attn.residual_connection:
|
736 |
+
hidden_states = hidden_states + residual
|
737 |
+
|
738 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
739 |
+
|
740 |
+
return hidden_states
|
741 |
+
|
742 |
+
def set_num_views(self, num_views: int) -> None:
|
743 |
+
self.num_views = num_views
|
mvadapter/pipelines/pipeline_mvadapter_i2mv_sd.py
ADDED
@@ -0,0 +1,777 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import inspect
|
15 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
16 |
+
|
17 |
+
import PIL
|
18 |
+
import torch
|
19 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
20 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
21 |
+
from diffusers.models import AutoencoderKL, T2IAdapter, UNet2DConditionModel
|
22 |
+
from diffusers.pipelines.stable_diffusion.pipeline_output import (
|
23 |
+
StableDiffusionPipelineOutput,
|
24 |
+
)
|
25 |
+
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
|
26 |
+
StableDiffusionPipeline,
|
27 |
+
rescale_noise_cfg,
|
28 |
+
retrieve_timesteps,
|
29 |
+
)
|
30 |
+
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
31 |
+
StableDiffusionSafetyChecker,
|
32 |
+
)
|
33 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
34 |
+
from diffusers.utils import deprecate, is_torch_xla_available, logging
|
35 |
+
from diffusers.utils.torch_utils import randn_tensor
|
36 |
+
from transformers import (
|
37 |
+
CLIPImageProcessor,
|
38 |
+
CLIPTextModel,
|
39 |
+
CLIPTokenizer,
|
40 |
+
CLIPVisionModelWithProjection,
|
41 |
+
)
|
42 |
+
|
43 |
+
from ..loaders import CustomAdapterMixin
|
44 |
+
from ..models.attention_processor import (
|
45 |
+
DecoupledMVRowSelfAttnProcessor2_0,
|
46 |
+
set_unet_2d_condition_attn_processor,
|
47 |
+
)
|
48 |
+
|
49 |
+
if is_torch_xla_available():
|
50 |
+
import torch_xla.core.xla_model as xm
|
51 |
+
|
52 |
+
XLA_AVAILABLE = True
|
53 |
+
else:
|
54 |
+
XLA_AVAILABLE = False
|
55 |
+
|
56 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
57 |
+
|
58 |
+
|
59 |
+
def retrieve_latents(
|
60 |
+
encoder_output: torch.Tensor,
|
61 |
+
generator: Optional[torch.Generator] = None,
|
62 |
+
sample_mode: str = "sample",
|
63 |
+
):
|
64 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
65 |
+
return encoder_output.latent_dist.sample(generator)
|
66 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
67 |
+
return encoder_output.latent_dist.mode()
|
68 |
+
elif hasattr(encoder_output, "latents"):
|
69 |
+
return encoder_output.latents
|
70 |
+
else:
|
71 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
72 |
+
|
73 |
+
|
74 |
+
class MVAdapterI2MVSDPipeline(StableDiffusionPipeline, CustomAdapterMixin):
|
75 |
+
def __init__(
|
76 |
+
self,
|
77 |
+
vae: AutoencoderKL,
|
78 |
+
text_encoder: CLIPTextModel,
|
79 |
+
tokenizer: CLIPTokenizer,
|
80 |
+
unet: UNet2DConditionModel,
|
81 |
+
scheduler: KarrasDiffusionSchedulers,
|
82 |
+
safety_checker: StableDiffusionSafetyChecker,
|
83 |
+
feature_extractor: CLIPImageProcessor,
|
84 |
+
image_encoder: CLIPVisionModelWithProjection = None,
|
85 |
+
requires_safety_checker: bool = False,
|
86 |
+
):
|
87 |
+
super().__init__(
|
88 |
+
vae=vae,
|
89 |
+
text_encoder=text_encoder,
|
90 |
+
tokenizer=tokenizer,
|
91 |
+
unet=unet,
|
92 |
+
scheduler=scheduler,
|
93 |
+
safety_checker=safety_checker,
|
94 |
+
feature_extractor=feature_extractor,
|
95 |
+
image_encoder=image_encoder,
|
96 |
+
requires_safety_checker=requires_safety_checker,
|
97 |
+
)
|
98 |
+
|
99 |
+
self.control_image_processor = VaeImageProcessor(
|
100 |
+
vae_scale_factor=self.vae_scale_factor,
|
101 |
+
do_convert_rgb=True,
|
102 |
+
do_normalize=False,
|
103 |
+
)
|
104 |
+
|
105 |
+
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.prepare_latents
|
106 |
+
def prepare_image_latents(
|
107 |
+
self,
|
108 |
+
image,
|
109 |
+
timestep,
|
110 |
+
batch_size,
|
111 |
+
num_images_per_prompt,
|
112 |
+
dtype,
|
113 |
+
device,
|
114 |
+
generator=None,
|
115 |
+
add_noise=True,
|
116 |
+
):
|
117 |
+
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
|
118 |
+
raise ValueError(
|
119 |
+
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
|
120 |
+
)
|
121 |
+
|
122 |
+
image = image.to(device=device, dtype=dtype)
|
123 |
+
|
124 |
+
batch_size = batch_size * num_images_per_prompt
|
125 |
+
|
126 |
+
if image.shape[1] == 4:
|
127 |
+
init_latents = image
|
128 |
+
|
129 |
+
else:
|
130 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
131 |
+
raise ValueError(
|
132 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
133 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
134 |
+
)
|
135 |
+
|
136 |
+
elif isinstance(generator, list):
|
137 |
+
init_latents = [
|
138 |
+
retrieve_latents(
|
139 |
+
self.vae.encode(image[i : i + 1]), generator=generator[i]
|
140 |
+
)
|
141 |
+
for i in range(batch_size)
|
142 |
+
]
|
143 |
+
init_latents = torch.cat(init_latents, dim=0)
|
144 |
+
else:
|
145 |
+
init_latents = retrieve_latents(
|
146 |
+
self.vae.encode(image), generator=generator
|
147 |
+
)
|
148 |
+
|
149 |
+
init_latents = self.vae.config.scaling_factor * init_latents
|
150 |
+
|
151 |
+
if (
|
152 |
+
batch_size > init_latents.shape[0]
|
153 |
+
and batch_size % init_latents.shape[0] == 0
|
154 |
+
):
|
155 |
+
# expand init_latents for batch_size
|
156 |
+
additional_image_per_prompt = batch_size // init_latents.shape[0]
|
157 |
+
init_latents = torch.cat(
|
158 |
+
[init_latents] * additional_image_per_prompt, dim=0
|
159 |
+
)
|
160 |
+
elif (
|
161 |
+
batch_size > init_latents.shape[0]
|
162 |
+
and batch_size % init_latents.shape[0] != 0
|
163 |
+
):
|
164 |
+
raise ValueError(
|
165 |
+
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
|
166 |
+
)
|
167 |
+
else:
|
168 |
+
init_latents = torch.cat([init_latents], dim=0)
|
169 |
+
|
170 |
+
if add_noise:
|
171 |
+
shape = init_latents.shape
|
172 |
+
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
173 |
+
# get latents
|
174 |
+
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
175 |
+
|
176 |
+
latents = init_latents
|
177 |
+
|
178 |
+
return latents
|
179 |
+
|
180 |
+
def prepare_control_image(
|
181 |
+
self,
|
182 |
+
image,
|
183 |
+
width,
|
184 |
+
height,
|
185 |
+
batch_size,
|
186 |
+
num_images_per_prompt,
|
187 |
+
device,
|
188 |
+
dtype,
|
189 |
+
do_classifier_free_guidance=False,
|
190 |
+
num_empty_images=0, # for concat in batch like ImageDream
|
191 |
+
):
|
192 |
+
assert hasattr(
|
193 |
+
self, "control_image_processor"
|
194 |
+
), "control_image_processor is not initialized"
|
195 |
+
|
196 |
+
image = self.control_image_processor.preprocess(
|
197 |
+
image, height=height, width=width
|
198 |
+
).to(dtype=torch.float32)
|
199 |
+
|
200 |
+
if num_empty_images > 0:
|
201 |
+
image = torch.cat(
|
202 |
+
[image, torch.zeros_like(image[:num_empty_images])], dim=0
|
203 |
+
)
|
204 |
+
|
205 |
+
image_batch_size = image.shape[0]
|
206 |
+
|
207 |
+
if image_batch_size == 1:
|
208 |
+
repeat_by = batch_size
|
209 |
+
else:
|
210 |
+
# image batch size is the same as prompt batch size
|
211 |
+
repeat_by = num_images_per_prompt # always 1 for control image
|
212 |
+
|
213 |
+
image = image.repeat_interleave(repeat_by, dim=0)
|
214 |
+
|
215 |
+
image = image.to(device=device, dtype=dtype)
|
216 |
+
|
217 |
+
if do_classifier_free_guidance:
|
218 |
+
image = torch.cat([image] * 2)
|
219 |
+
|
220 |
+
return image
|
221 |
+
|
222 |
+
@torch.no_grad()
|
223 |
+
def __call__(
|
224 |
+
self,
|
225 |
+
prompt: Union[str, List[str]] = None,
|
226 |
+
height: Optional[int] = None,
|
227 |
+
width: Optional[int] = None,
|
228 |
+
num_inference_steps: int = 50,
|
229 |
+
timesteps: List[int] = None,
|
230 |
+
sigmas: List[float] = None,
|
231 |
+
guidance_scale: float = 7.5,
|
232 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
233 |
+
num_images_per_prompt: Optional[int] = 1,
|
234 |
+
eta: float = 0.0,
|
235 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
236 |
+
latents: Optional[torch.Tensor] = None,
|
237 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
238 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
239 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
240 |
+
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
241 |
+
output_type: Optional[str] = "pil",
|
242 |
+
return_dict: bool = True,
|
243 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
244 |
+
guidance_rescale: float = 0.0,
|
245 |
+
clip_skip: Optional[int] = None,
|
246 |
+
callback_on_step_end: Optional[
|
247 |
+
Union[
|
248 |
+
Callable[[int, int, Dict], None],
|
249 |
+
PipelineCallback,
|
250 |
+
MultiPipelineCallbacks,
|
251 |
+
]
|
252 |
+
] = None,
|
253 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
254 |
+
# NEW
|
255 |
+
mv_scale: float = 1.0,
|
256 |
+
# Camera or geometry condition
|
257 |
+
control_image: Optional[PipelineImageInput] = None,
|
258 |
+
control_conditioning_scale: Optional[float] = 1.0,
|
259 |
+
control_conditioning_factor: float = 1.0,
|
260 |
+
# Image condition
|
261 |
+
reference_image: Optional[PipelineImageInput] = None,
|
262 |
+
reference_conditioning_scale: Optional[float] = 1.0,
|
263 |
+
# Optional. controlnet
|
264 |
+
controlnet_image: Optional[PipelineImageInput] = None,
|
265 |
+
controlnet_conditioning_scale: Optional[float] = 1.0,
|
266 |
+
**kwargs,
|
267 |
+
):
|
268 |
+
r"""
|
269 |
+
The call function to the pipeline for generation.
|
270 |
+
|
271 |
+
Args:
|
272 |
+
prompt (`str` or `List[str]`, *optional*):
|
273 |
+
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
274 |
+
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
275 |
+
The height in pixels of the generated image.
|
276 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
277 |
+
The width in pixels of the generated image.
|
278 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
279 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
280 |
+
expense of slower inference.
|
281 |
+
timesteps (`List[int]`, *optional*):
|
282 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
283 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
284 |
+
passed will be used. Must be in descending order.
|
285 |
+
sigmas (`List[float]`, *optional*):
|
286 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
287 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
288 |
+
will be used.
|
289 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
290 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
291 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
292 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
293 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
294 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
295 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
296 |
+
The number of images to generate per prompt.
|
297 |
+
eta (`float`, *optional*, defaults to 0.0):
|
298 |
+
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
299 |
+
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
300 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
301 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
302 |
+
generation deterministic.
|
303 |
+
latents (`torch.Tensor`, *optional*):
|
304 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
305 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
306 |
+
tensor is generated by sampling using the supplied random `generator`.
|
307 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
308 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
309 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
310 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
311 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
312 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
313 |
+
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
314 |
+
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
|
315 |
+
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
316 |
+
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
|
317 |
+
contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
|
318 |
+
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
319 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
320 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
321 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
322 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
323 |
+
plain tuple.
|
324 |
+
cross_attention_kwargs (`dict`, *optional*):
|
325 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
326 |
+
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
327 |
+
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
328 |
+
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
|
329 |
+
Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
|
330 |
+
using zero terminal SNR.
|
331 |
+
clip_skip (`int`, *optional*):
|
332 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
333 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
334 |
+
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
335 |
+
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
336 |
+
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
337 |
+
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
338 |
+
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
339 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
340 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
341 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
342 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
343 |
+
|
344 |
+
Examples:
|
345 |
+
|
346 |
+
Returns:
|
347 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
348 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
349 |
+
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
350 |
+
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
351 |
+
"not-safe-for-work" (nsfw) content.
|
352 |
+
"""
|
353 |
+
|
354 |
+
callback = kwargs.pop("callback", None)
|
355 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
356 |
+
|
357 |
+
if callback is not None:
|
358 |
+
deprecate(
|
359 |
+
"callback",
|
360 |
+
"1.0.0",
|
361 |
+
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
362 |
+
)
|
363 |
+
if callback_steps is not None:
|
364 |
+
deprecate(
|
365 |
+
"callback_steps",
|
366 |
+
"1.0.0",
|
367 |
+
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
368 |
+
)
|
369 |
+
|
370 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
371 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
372 |
+
|
373 |
+
# 0. Default height and width to unet
|
374 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
375 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
376 |
+
# to deal with lora scaling and other possible forward hooks
|
377 |
+
|
378 |
+
# 1. Check inputs. Raise error if not correct
|
379 |
+
self.check_inputs(
|
380 |
+
prompt,
|
381 |
+
height,
|
382 |
+
width,
|
383 |
+
callback_steps,
|
384 |
+
negative_prompt,
|
385 |
+
prompt_embeds,
|
386 |
+
negative_prompt_embeds,
|
387 |
+
ip_adapter_image,
|
388 |
+
ip_adapter_image_embeds,
|
389 |
+
callback_on_step_end_tensor_inputs,
|
390 |
+
)
|
391 |
+
|
392 |
+
self._guidance_scale = guidance_scale
|
393 |
+
self._guidance_rescale = guidance_rescale
|
394 |
+
self._clip_skip = clip_skip
|
395 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
396 |
+
self._interrupt = False
|
397 |
+
|
398 |
+
# 2. Define call parameters
|
399 |
+
if prompt is not None and isinstance(prompt, str):
|
400 |
+
batch_size = 1
|
401 |
+
elif prompt is not None and isinstance(prompt, list):
|
402 |
+
batch_size = len(prompt)
|
403 |
+
else:
|
404 |
+
batch_size = prompt_embeds.shape[0]
|
405 |
+
|
406 |
+
device = self._execution_device
|
407 |
+
|
408 |
+
# 3. Encode input prompt
|
409 |
+
lora_scale = (
|
410 |
+
self.cross_attention_kwargs.get("scale", None)
|
411 |
+
if self.cross_attention_kwargs is not None
|
412 |
+
else None
|
413 |
+
)
|
414 |
+
|
415 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
416 |
+
prompt,
|
417 |
+
device,
|
418 |
+
num_images_per_prompt,
|
419 |
+
self.do_classifier_free_guidance,
|
420 |
+
negative_prompt,
|
421 |
+
prompt_embeds=prompt_embeds,
|
422 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
423 |
+
lora_scale=lora_scale,
|
424 |
+
clip_skip=self.clip_skip,
|
425 |
+
)
|
426 |
+
|
427 |
+
# For classifier free guidance, we need to do two forward passes.
|
428 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
429 |
+
# to avoid doing two forward passes
|
430 |
+
if self.do_classifier_free_guidance:
|
431 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
432 |
+
|
433 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
434 |
+
image_embeds = self.prepare_ip_adapter_image_embeds(
|
435 |
+
ip_adapter_image,
|
436 |
+
ip_adapter_image_embeds,
|
437 |
+
device,
|
438 |
+
batch_size * num_images_per_prompt,
|
439 |
+
self.do_classifier_free_guidance,
|
440 |
+
)
|
441 |
+
|
442 |
+
# 4. Prepare timesteps
|
443 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
444 |
+
self.scheduler, num_inference_steps, device, timesteps, sigmas
|
445 |
+
)
|
446 |
+
|
447 |
+
# 5. Prepare latent variables
|
448 |
+
num_channels_latents = self.unet.config.in_channels
|
449 |
+
latents = self.prepare_latents(
|
450 |
+
batch_size * num_images_per_prompt,
|
451 |
+
num_channels_latents,
|
452 |
+
height,
|
453 |
+
width,
|
454 |
+
prompt_embeds.dtype,
|
455 |
+
device,
|
456 |
+
generator,
|
457 |
+
latents,
|
458 |
+
)
|
459 |
+
|
460 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
461 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
462 |
+
|
463 |
+
# 6.1 Add image embeds for IP-Adapter
|
464 |
+
added_cond_kwargs = (
|
465 |
+
{"image_embeds": image_embeds}
|
466 |
+
if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
|
467 |
+
else None
|
468 |
+
)
|
469 |
+
|
470 |
+
# 6.2 Optionally get Guidance Scale Embedding
|
471 |
+
timestep_cond = None
|
472 |
+
if self.unet.config.time_cond_proj_dim is not None:
|
473 |
+
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
|
474 |
+
batch_size * num_images_per_prompt
|
475 |
+
)
|
476 |
+
timestep_cond = self.get_guidance_scale_embedding(
|
477 |
+
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
478 |
+
).to(device=device, dtype=latents.dtype)
|
479 |
+
|
480 |
+
# Preprocess reference image
|
481 |
+
reference_image = self.image_processor.preprocess(reference_image)
|
482 |
+
reference_latents = self.prepare_image_latents(
|
483 |
+
reference_image,
|
484 |
+
timesteps[:1].repeat(batch_size * num_images_per_prompt), # no use
|
485 |
+
batch_size,
|
486 |
+
1,
|
487 |
+
prompt_embeds.dtype,
|
488 |
+
device,
|
489 |
+
generator,
|
490 |
+
add_noise=False,
|
491 |
+
)
|
492 |
+
|
493 |
+
ref_timesteps = torch.zeros_like(timesteps[0])
|
494 |
+
ref_hidden_states = {}
|
495 |
+
with torch.no_grad():
|
496 |
+
self.unet(
|
497 |
+
reference_latents,
|
498 |
+
ref_timesteps,
|
499 |
+
encoder_hidden_states=prompt_embeds[-1:],
|
500 |
+
cross_attention_kwargs={
|
501 |
+
"cache_hidden_states": ref_hidden_states,
|
502 |
+
"use_mv": False,
|
503 |
+
"use_ref": False,
|
504 |
+
},
|
505 |
+
return_dict=False,
|
506 |
+
)
|
507 |
+
ref_hidden_states = {
|
508 |
+
k: v.repeat_interleave(num_images_per_prompt, dim=0)
|
509 |
+
for k, v in ref_hidden_states.items()
|
510 |
+
}
|
511 |
+
if self.do_classifier_free_guidance:
|
512 |
+
ref_hidden_states = {
|
513 |
+
k: torch.cat([torch.zeros_like(v), v], dim=0)
|
514 |
+
for k, v in ref_hidden_states.items()
|
515 |
+
}
|
516 |
+
|
517 |
+
cross_attention_kwargs = {
|
518 |
+
"num_views": num_images_per_prompt,
|
519 |
+
"mv_scale": mv_scale,
|
520 |
+
"ref_hidden_states": {k: v.clone() for k, v in ref_hidden_states.items()},
|
521 |
+
"ref_scale": reference_conditioning_scale,
|
522 |
+
**(self.cross_attention_kwargs or {}),
|
523 |
+
}
|
524 |
+
|
525 |
+
# Preprocess control image
|
526 |
+
control_image_feature = self.prepare_control_image(
|
527 |
+
image=control_image,
|
528 |
+
width=width,
|
529 |
+
height=height,
|
530 |
+
batch_size=batch_size * num_images_per_prompt,
|
531 |
+
num_images_per_prompt=1, # NOTE: always 1 for control images
|
532 |
+
device=device,
|
533 |
+
dtype=latents.dtype,
|
534 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
535 |
+
)
|
536 |
+
control_image_feature = control_image_feature.to(
|
537 |
+
device=device, dtype=latents.dtype
|
538 |
+
)
|
539 |
+
|
540 |
+
adapter_state = self.cond_encoder(control_image_feature)
|
541 |
+
for i, state in enumerate(adapter_state):
|
542 |
+
adapter_state[i] = state * control_conditioning_scale
|
543 |
+
|
544 |
+
# Preprocess controlnet image if provided
|
545 |
+
do_controlnet = controlnet_image is not None and hasattr(self, "controlnet")
|
546 |
+
if do_controlnet:
|
547 |
+
controlnet_image = self.prepare_control_image(
|
548 |
+
image=controlnet_image,
|
549 |
+
width=width,
|
550 |
+
height=height,
|
551 |
+
batch_size=batch_size * num_images_per_prompt,
|
552 |
+
num_images_per_prompt=1, # NOTE: always 1 for control images
|
553 |
+
device=device,
|
554 |
+
dtype=latents.dtype,
|
555 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
556 |
+
)
|
557 |
+
controlnet_image = controlnet_image.to(device=device, dtype=latents.dtype)
|
558 |
+
|
559 |
+
# 7. Denoising loop
|
560 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
561 |
+
self._num_timesteps = len(timesteps)
|
562 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
563 |
+
for i, t in enumerate(timesteps):
|
564 |
+
if self.interrupt:
|
565 |
+
continue
|
566 |
+
|
567 |
+
# expand the latents if we are doing classifier free guidance
|
568 |
+
latent_model_input = (
|
569 |
+
torch.cat([latents] * 2)
|
570 |
+
if self.do_classifier_free_guidance
|
571 |
+
else latents
|
572 |
+
)
|
573 |
+
latent_model_input = self.scheduler.scale_model_input(
|
574 |
+
latent_model_input, t
|
575 |
+
)
|
576 |
+
|
577 |
+
if i < int(num_inference_steps * control_conditioning_factor):
|
578 |
+
down_intrablock_additional_residuals = [
|
579 |
+
state.clone() for state in adapter_state
|
580 |
+
]
|
581 |
+
else:
|
582 |
+
down_intrablock_additional_residuals = None
|
583 |
+
|
584 |
+
unet_add_kwargs = {}
|
585 |
+
|
586 |
+
# Do controlnet if provided
|
587 |
+
if do_controlnet:
|
588 |
+
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
589 |
+
latent_model_input,
|
590 |
+
t,
|
591 |
+
encoder_hidden_states=prompt_embeds,
|
592 |
+
controlnet_cond=controlnet_image,
|
593 |
+
conditioning_scale=controlnet_conditioning_scale,
|
594 |
+
guess_mode=False,
|
595 |
+
added_cond_kwargs=added_cond_kwargs,
|
596 |
+
return_dict=False,
|
597 |
+
)
|
598 |
+
unet_add_kwargs.update(
|
599 |
+
{
|
600 |
+
"down_block_additional_residuals": down_block_res_samples,
|
601 |
+
"mid_block_additional_residual": mid_block_res_sample,
|
602 |
+
}
|
603 |
+
)
|
604 |
+
|
605 |
+
# predict the noise residual
|
606 |
+
noise_pred = self.unet(
|
607 |
+
latent_model_input,
|
608 |
+
t,
|
609 |
+
encoder_hidden_states=prompt_embeds,
|
610 |
+
timestep_cond=timestep_cond,
|
611 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
612 |
+
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
|
613 |
+
added_cond_kwargs=added_cond_kwargs,
|
614 |
+
return_dict=False,
|
615 |
+
**unet_add_kwargs,
|
616 |
+
)[0]
|
617 |
+
|
618 |
+
# perform guidance
|
619 |
+
if self.do_classifier_free_guidance:
|
620 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
621 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (
|
622 |
+
noise_pred_text - noise_pred_uncond
|
623 |
+
)
|
624 |
+
|
625 |
+
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
|
626 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
627 |
+
noise_pred = rescale_noise_cfg(
|
628 |
+
noise_pred,
|
629 |
+
noise_pred_text,
|
630 |
+
guidance_rescale=self.guidance_rescale,
|
631 |
+
)
|
632 |
+
|
633 |
+
# compute the previous noisy sample x_t -> x_t-1
|
634 |
+
latents = self.scheduler.step(
|
635 |
+
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
|
636 |
+
)[0]
|
637 |
+
|
638 |
+
if callback_on_step_end is not None:
|
639 |
+
callback_kwargs = {}
|
640 |
+
for k in callback_on_step_end_tensor_inputs:
|
641 |
+
callback_kwargs[k] = locals()[k]
|
642 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
643 |
+
|
644 |
+
latents = callback_outputs.pop("latents", latents)
|
645 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
646 |
+
negative_prompt_embeds = callback_outputs.pop(
|
647 |
+
"negative_prompt_embeds", negative_prompt_embeds
|
648 |
+
)
|
649 |
+
|
650 |
+
# call the callback, if provided
|
651 |
+
if i == len(timesteps) - 1 or (
|
652 |
+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
653 |
+
):
|
654 |
+
progress_bar.update()
|
655 |
+
if callback is not None and i % callback_steps == 0:
|
656 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
657 |
+
callback(step_idx, t, latents)
|
658 |
+
|
659 |
+
if XLA_AVAILABLE:
|
660 |
+
xm.mark_step()
|
661 |
+
|
662 |
+
if not output_type == "latent":
|
663 |
+
image = self.vae.decode(
|
664 |
+
latents / self.vae.config.scaling_factor,
|
665 |
+
return_dict=False,
|
666 |
+
generator=generator,
|
667 |
+
)[0]
|
668 |
+
image, has_nsfw_concept = self.run_safety_checker(
|
669 |
+
image, device, prompt_embeds.dtype
|
670 |
+
)
|
671 |
+
else:
|
672 |
+
image = latents
|
673 |
+
has_nsfw_concept = None
|
674 |
+
|
675 |
+
if has_nsfw_concept is None:
|
676 |
+
do_denormalize = [True] * image.shape[0]
|
677 |
+
else:
|
678 |
+
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
679 |
+
image = self.image_processor.postprocess(
|
680 |
+
image, output_type=output_type, do_denormalize=do_denormalize
|
681 |
+
)
|
682 |
+
|
683 |
+
# Offload all models
|
684 |
+
self.maybe_free_model_hooks()
|
685 |
+
|
686 |
+
if not return_dict:
|
687 |
+
return (image, has_nsfw_concept)
|
688 |
+
|
689 |
+
return StableDiffusionPipelineOutput(
|
690 |
+
images=image, nsfw_content_detected=has_nsfw_concept
|
691 |
+
)
|
692 |
+
|
693 |
+
### NEW: adapters ###
|
694 |
+
def _init_custom_adapter(
|
695 |
+
self,
|
696 |
+
# Multi-view adapter
|
697 |
+
num_views: int = 1,
|
698 |
+
self_attn_processor: Any = DecoupledMVRowSelfAttnProcessor2_0,
|
699 |
+
# Condition encoder
|
700 |
+
cond_in_channels: int = 6,
|
701 |
+
# For training
|
702 |
+
copy_attn_weights: bool = True,
|
703 |
+
zero_init_module_keys: List[str] = [],
|
704 |
+
):
|
705 |
+
# Condition encoder
|
706 |
+
self.cond_encoder = T2IAdapter(
|
707 |
+
in_channels=cond_in_channels,
|
708 |
+
channels=self.unet.config.block_out_channels,
|
709 |
+
num_res_blocks=self.unet.config.layers_per_block,
|
710 |
+
downscale_factor=8,
|
711 |
+
)
|
712 |
+
|
713 |
+
# set custom attn processor for multi-view attention
|
714 |
+
self.unet: UNet2DConditionModel
|
715 |
+
set_unet_2d_condition_attn_processor(
|
716 |
+
self.unet,
|
717 |
+
set_self_attn_proc_func=lambda name, hs, cad, ap: self_attn_processor(
|
718 |
+
query_dim=hs,
|
719 |
+
inner_dim=hs,
|
720 |
+
num_views=num_views,
|
721 |
+
name=name,
|
722 |
+
use_mv=True,
|
723 |
+
use_ref=True,
|
724 |
+
),
|
725 |
+
set_cross_attn_proc_func=lambda name, hs, cad, ap: self_attn_processor(
|
726 |
+
query_dim=hs,
|
727 |
+
inner_dim=hs,
|
728 |
+
num_views=num_views,
|
729 |
+
name=name,
|
730 |
+
use_mv=False,
|
731 |
+
use_ref=False,
|
732 |
+
),
|
733 |
+
)
|
734 |
+
|
735 |
+
# copy decoupled attention weights from original unet
|
736 |
+
if copy_attn_weights:
|
737 |
+
state_dict = self.unet.state_dict()
|
738 |
+
for key in state_dict.keys():
|
739 |
+
if "_mv" in key:
|
740 |
+
compatible_key = key.replace("_mv", "").replace("processor.", "")
|
741 |
+
elif "_ref" in key:
|
742 |
+
compatible_key = key.replace("_ref", "").replace("processor.", "")
|
743 |
+
else:
|
744 |
+
compatible_key = key
|
745 |
+
|
746 |
+
is_zero_init_key = any([k in key for k in zero_init_module_keys])
|
747 |
+
if is_zero_init_key:
|
748 |
+
state_dict[key] = torch.zeros_like(state_dict[compatible_key])
|
749 |
+
else:
|
750 |
+
state_dict[key] = state_dict[compatible_key].clone()
|
751 |
+
self.unet.load_state_dict(state_dict)
|
752 |
+
|
753 |
+
def _load_custom_adapter(self, state_dict):
|
754 |
+
self.unet.load_state_dict(state_dict, strict=False)
|
755 |
+
self.cond_encoder.load_state_dict(state_dict, strict=False)
|
756 |
+
|
757 |
+
def _save_custom_adapter(
|
758 |
+
self,
|
759 |
+
include_keys: Optional[List[str]] = None,
|
760 |
+
exclude_keys: Optional[List[str]] = None,
|
761 |
+
):
|
762 |
+
def include_fn(k):
|
763 |
+
is_included = False
|
764 |
+
|
765 |
+
if include_keys is not None:
|
766 |
+
is_included = is_included or any([key in k for key in include_keys])
|
767 |
+
if exclude_keys is not None:
|
768 |
+
is_included = is_included and not any(
|
769 |
+
[key in k for key in exclude_keys]
|
770 |
+
)
|
771 |
+
|
772 |
+
return is_included
|
773 |
+
|
774 |
+
state_dict = {k: v for k, v in self.unet.state_dict().items() if include_fn(k)}
|
775 |
+
state_dict.update(self.cond_encoder.state_dict())
|
776 |
+
|
777 |
+
return state_dict
|
mvadapter/pipelines/pipeline_mvadapter_i2mv_sdxl.py
ADDED
@@ -0,0 +1,962 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import inspect
|
16 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import PIL
|
20 |
+
import torch
|
21 |
+
import torch.nn as nn
|
22 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
23 |
+
from diffusers.models import (
|
24 |
+
AutoencoderKL,
|
25 |
+
ImageProjection,
|
26 |
+
T2IAdapter,
|
27 |
+
UNet2DConditionModel,
|
28 |
+
)
|
29 |
+
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import (
|
30 |
+
StableDiffusionXLPipelineOutput,
|
31 |
+
)
|
32 |
+
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import (
|
33 |
+
StableDiffusionXLPipeline,
|
34 |
+
rescale_noise_cfg,
|
35 |
+
retrieve_timesteps,
|
36 |
+
)
|
37 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
38 |
+
from diffusers.utils import deprecate, logging
|
39 |
+
from diffusers.utils.torch_utils import randn_tensor
|
40 |
+
from einops import rearrange
|
41 |
+
from transformers import (
|
42 |
+
CLIPImageProcessor,
|
43 |
+
CLIPTextModel,
|
44 |
+
CLIPTextModelWithProjection,
|
45 |
+
CLIPTokenizer,
|
46 |
+
CLIPVisionModelWithProjection,
|
47 |
+
)
|
48 |
+
|
49 |
+
from ..loaders import CustomAdapterMixin
|
50 |
+
from ..models.attention_processor import (
|
51 |
+
DecoupledMVRowSelfAttnProcessor2_0,
|
52 |
+
set_unet_2d_condition_attn_processor,
|
53 |
+
)
|
54 |
+
|
55 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
56 |
+
|
57 |
+
|
58 |
+
def retrieve_latents(
|
59 |
+
encoder_output: torch.Tensor,
|
60 |
+
generator: Optional[torch.Generator] = None,
|
61 |
+
sample_mode: str = "sample",
|
62 |
+
):
|
63 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
64 |
+
return encoder_output.latent_dist.sample(generator)
|
65 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
66 |
+
return encoder_output.latent_dist.mode()
|
67 |
+
elif hasattr(encoder_output, "latents"):
|
68 |
+
return encoder_output.latents
|
69 |
+
else:
|
70 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
71 |
+
|
72 |
+
|
73 |
+
class MVAdapterI2MVSDXLPipeline(StableDiffusionXLPipeline, CustomAdapterMixin):
|
74 |
+
def __init__(
|
75 |
+
self,
|
76 |
+
vae: AutoencoderKL,
|
77 |
+
text_encoder: CLIPTextModel,
|
78 |
+
text_encoder_2: CLIPTextModelWithProjection,
|
79 |
+
tokenizer: CLIPTokenizer,
|
80 |
+
tokenizer_2: CLIPTokenizer,
|
81 |
+
unet: UNet2DConditionModel,
|
82 |
+
scheduler: KarrasDiffusionSchedulers,
|
83 |
+
image_encoder: CLIPVisionModelWithProjection = None,
|
84 |
+
feature_extractor: CLIPImageProcessor = None,
|
85 |
+
force_zeros_for_empty_prompt: bool = True,
|
86 |
+
add_watermarker: Optional[bool] = None,
|
87 |
+
):
|
88 |
+
super().__init__(
|
89 |
+
vae=vae,
|
90 |
+
text_encoder=text_encoder,
|
91 |
+
text_encoder_2=text_encoder_2,
|
92 |
+
tokenizer=tokenizer,
|
93 |
+
tokenizer_2=tokenizer_2,
|
94 |
+
unet=unet,
|
95 |
+
scheduler=scheduler,
|
96 |
+
image_encoder=image_encoder,
|
97 |
+
feature_extractor=feature_extractor,
|
98 |
+
force_zeros_for_empty_prompt=force_zeros_for_empty_prompt,
|
99 |
+
add_watermarker=add_watermarker,
|
100 |
+
)
|
101 |
+
|
102 |
+
self.control_image_processor = VaeImageProcessor(
|
103 |
+
vae_scale_factor=self.vae_scale_factor,
|
104 |
+
do_convert_rgb=True,
|
105 |
+
do_normalize=False,
|
106 |
+
)
|
107 |
+
|
108 |
+
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.prepare_latents
|
109 |
+
def prepare_image_latents(
|
110 |
+
self,
|
111 |
+
image,
|
112 |
+
timestep,
|
113 |
+
batch_size,
|
114 |
+
num_images_per_prompt,
|
115 |
+
dtype,
|
116 |
+
device,
|
117 |
+
generator=None,
|
118 |
+
add_noise=True,
|
119 |
+
):
|
120 |
+
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
|
121 |
+
raise ValueError(
|
122 |
+
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
|
123 |
+
)
|
124 |
+
|
125 |
+
latents_mean = latents_std = None
|
126 |
+
if (
|
127 |
+
hasattr(self.vae.config, "latents_mean")
|
128 |
+
and self.vae.config.latents_mean is not None
|
129 |
+
):
|
130 |
+
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
|
131 |
+
if (
|
132 |
+
hasattr(self.vae.config, "latents_std")
|
133 |
+
and self.vae.config.latents_std is not None
|
134 |
+
):
|
135 |
+
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
|
136 |
+
|
137 |
+
# Offload text encoder if `enable_model_cpu_offload` was enabled
|
138 |
+
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
139 |
+
self.text_encoder_2.to("cpu")
|
140 |
+
torch.cuda.empty_cache()
|
141 |
+
|
142 |
+
image = image.to(device=device, dtype=dtype)
|
143 |
+
|
144 |
+
batch_size = batch_size * num_images_per_prompt
|
145 |
+
|
146 |
+
if image.shape[1] == 4:
|
147 |
+
init_latents = image
|
148 |
+
|
149 |
+
else:
|
150 |
+
# make sure the VAE is in float32 mode, as it overflows in float16
|
151 |
+
if self.vae.config.force_upcast:
|
152 |
+
image = image.float()
|
153 |
+
self.vae.to(dtype=torch.float32)
|
154 |
+
|
155 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
156 |
+
raise ValueError(
|
157 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
158 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
159 |
+
)
|
160 |
+
|
161 |
+
elif isinstance(generator, list):
|
162 |
+
if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
|
163 |
+
image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
|
164 |
+
elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
|
165 |
+
raise ValueError(
|
166 |
+
f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
|
167 |
+
)
|
168 |
+
|
169 |
+
init_latents = [
|
170 |
+
retrieve_latents(
|
171 |
+
self.vae.encode(image[i : i + 1]), generator=generator[i]
|
172 |
+
)
|
173 |
+
for i in range(batch_size)
|
174 |
+
]
|
175 |
+
init_latents = torch.cat(init_latents, dim=0)
|
176 |
+
else:
|
177 |
+
init_latents = retrieve_latents(
|
178 |
+
self.vae.encode(image), generator=generator
|
179 |
+
)
|
180 |
+
|
181 |
+
if self.vae.config.force_upcast:
|
182 |
+
self.vae.to(dtype)
|
183 |
+
|
184 |
+
init_latents = init_latents.to(dtype)
|
185 |
+
if latents_mean is not None and latents_std is not None:
|
186 |
+
latents_mean = latents_mean.to(device=device, dtype=dtype)
|
187 |
+
latents_std = latents_std.to(device=device, dtype=dtype)
|
188 |
+
init_latents = (
|
189 |
+
(init_latents - latents_mean)
|
190 |
+
* self.vae.config.scaling_factor
|
191 |
+
/ latents_std
|
192 |
+
)
|
193 |
+
else:
|
194 |
+
init_latents = self.vae.config.scaling_factor * init_latents
|
195 |
+
|
196 |
+
if (
|
197 |
+
batch_size > init_latents.shape[0]
|
198 |
+
and batch_size % init_latents.shape[0] == 0
|
199 |
+
):
|
200 |
+
# expand init_latents for batch_size
|
201 |
+
additional_image_per_prompt = batch_size // init_latents.shape[0]
|
202 |
+
init_latents = torch.cat(
|
203 |
+
[init_latents] * additional_image_per_prompt, dim=0
|
204 |
+
)
|
205 |
+
elif (
|
206 |
+
batch_size > init_latents.shape[0]
|
207 |
+
and batch_size % init_latents.shape[0] != 0
|
208 |
+
):
|
209 |
+
raise ValueError(
|
210 |
+
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
|
211 |
+
)
|
212 |
+
else:
|
213 |
+
init_latents = torch.cat([init_latents], dim=0)
|
214 |
+
|
215 |
+
if add_noise:
|
216 |
+
shape = init_latents.shape
|
217 |
+
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
218 |
+
# get latents
|
219 |
+
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
220 |
+
|
221 |
+
latents = init_latents
|
222 |
+
|
223 |
+
return latents
|
224 |
+
|
225 |
+
def prepare_control_image(
|
226 |
+
self,
|
227 |
+
image,
|
228 |
+
width,
|
229 |
+
height,
|
230 |
+
batch_size,
|
231 |
+
num_images_per_prompt,
|
232 |
+
device,
|
233 |
+
dtype,
|
234 |
+
do_classifier_free_guidance=False,
|
235 |
+
num_empty_images=0, # for concat in batch like ImageDream
|
236 |
+
):
|
237 |
+
assert hasattr(
|
238 |
+
self, "control_image_processor"
|
239 |
+
), "control_image_processor is not initialized"
|
240 |
+
|
241 |
+
image = self.control_image_processor.preprocess(
|
242 |
+
image, height=height, width=width
|
243 |
+
).to(dtype=torch.float32)
|
244 |
+
|
245 |
+
if num_empty_images > 0:
|
246 |
+
image = torch.cat(
|
247 |
+
[image, torch.zeros_like(image[:num_empty_images])], dim=0
|
248 |
+
)
|
249 |
+
|
250 |
+
image_batch_size = image.shape[0]
|
251 |
+
|
252 |
+
if image_batch_size == 1:
|
253 |
+
repeat_by = batch_size
|
254 |
+
else:
|
255 |
+
# image batch size is the same as prompt batch size
|
256 |
+
repeat_by = num_images_per_prompt # always 1 for control image
|
257 |
+
|
258 |
+
image = image.repeat_interleave(repeat_by, dim=0)
|
259 |
+
|
260 |
+
image = image.to(device=device, dtype=dtype)
|
261 |
+
|
262 |
+
if do_classifier_free_guidance:
|
263 |
+
image = torch.cat([image] * 2)
|
264 |
+
|
265 |
+
return image
|
266 |
+
|
267 |
+
@torch.no_grad()
|
268 |
+
def __call__(
|
269 |
+
self,
|
270 |
+
prompt: Union[str, List[str]] = None,
|
271 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
272 |
+
height: Optional[int] = None,
|
273 |
+
width: Optional[int] = None,
|
274 |
+
num_inference_steps: int = 50,
|
275 |
+
timesteps: List[int] = None,
|
276 |
+
denoising_end: Optional[float] = None,
|
277 |
+
guidance_scale: float = 5.0,
|
278 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
279 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
280 |
+
num_images_per_prompt: Optional[int] = 1,
|
281 |
+
eta: float = 0.0,
|
282 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
283 |
+
latents: Optional[torch.FloatTensor] = None,
|
284 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
285 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
286 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
287 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
288 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
289 |
+
ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
|
290 |
+
output_type: Optional[str] = "pil",
|
291 |
+
return_dict: bool = True,
|
292 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
293 |
+
guidance_rescale: float = 0.0,
|
294 |
+
original_size: Optional[Tuple[int, int]] = None,
|
295 |
+
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
296 |
+
target_size: Optional[Tuple[int, int]] = None,
|
297 |
+
negative_original_size: Optional[Tuple[int, int]] = None,
|
298 |
+
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
299 |
+
negative_target_size: Optional[Tuple[int, int]] = None,
|
300 |
+
clip_skip: Optional[int] = None,
|
301 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
302 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
303 |
+
# NEW
|
304 |
+
mv_scale: float = 1.0,
|
305 |
+
# Camera or geometry condition
|
306 |
+
control_image: Optional[PipelineImageInput] = None,
|
307 |
+
control_conditioning_scale: Optional[float] = 1.0,
|
308 |
+
control_conditioning_factor: float = 1.0,
|
309 |
+
# Image condition
|
310 |
+
reference_image: Optional[PipelineImageInput] = None,
|
311 |
+
reference_conditioning_scale: Optional[float] = 1.0,
|
312 |
+
**kwargs,
|
313 |
+
):
|
314 |
+
r"""
|
315 |
+
Function invoked when calling the pipeline for generation.
|
316 |
+
|
317 |
+
Args:
|
318 |
+
prompt (`str` or `List[str]`, *optional*):
|
319 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
320 |
+
instead.
|
321 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
322 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
323 |
+
used in both text-encoders
|
324 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
325 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
326 |
+
Anything below 512 pixels won't work well for
|
327 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
328 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
329 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
330 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
331 |
+
Anything below 512 pixels won't work well for
|
332 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
333 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
334 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
335 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
336 |
+
expense of slower inference.
|
337 |
+
timesteps (`List[int]`, *optional*):
|
338 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
339 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
340 |
+
passed will be used. Must be in descending order.
|
341 |
+
denoising_end (`float`, *optional*):
|
342 |
+
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
|
343 |
+
completed before it is intentionally prematurely terminated. As a result, the returned sample will
|
344 |
+
still retain a substantial amount of noise as determined by the discrete timesteps selected by the
|
345 |
+
scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
|
346 |
+
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
|
347 |
+
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
|
348 |
+
guidance_scale (`float`, *optional*, defaults to 5.0):
|
349 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
350 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
351 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
352 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
353 |
+
usually at the expense of lower image quality.
|
354 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
355 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
356 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
357 |
+
less than `1`).
|
358 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
359 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
360 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
|
361 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
362 |
+
The number of images to generate per prompt.
|
363 |
+
eta (`float`, *optional*, defaults to 0.0):
|
364 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
365 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
366 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
367 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
368 |
+
to make generation deterministic.
|
369 |
+
latents (`torch.FloatTensor`, *optional*):
|
370 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
371 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
372 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
373 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
374 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
375 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
376 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
377 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
378 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
379 |
+
argument.
|
380 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
381 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
382 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
383 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
384 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
385 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
386 |
+
input argument.
|
387 |
+
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
388 |
+
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
|
389 |
+
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
390 |
+
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
|
391 |
+
contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
|
392 |
+
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
393 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
394 |
+
The output format of the generate image. Choose between
|
395 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
396 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
397 |
+
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
398 |
+
of a plain tuple.
|
399 |
+
cross_attention_kwargs (`dict`, *optional*):
|
400 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
401 |
+
`self.processor` in
|
402 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
403 |
+
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
404 |
+
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
|
405 |
+
Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
|
406 |
+
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
407 |
+
Guidance rescale factor should fix overexposure when using zero terminal SNR.
|
408 |
+
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
409 |
+
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
410 |
+
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
|
411 |
+
explained in section 2.2 of
|
412 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
413 |
+
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
414 |
+
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
|
415 |
+
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
|
416 |
+
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
|
417 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
418 |
+
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
419 |
+
For most cases, `target_size` should be set to the desired height and width of the generated image. If
|
420 |
+
not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
|
421 |
+
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
422 |
+
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
423 |
+
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
|
424 |
+
micro-conditioning as explained in section 2.2 of
|
425 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
426 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
427 |
+
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
428 |
+
To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
|
429 |
+
micro-conditioning as explained in section 2.2 of
|
430 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
431 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
432 |
+
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
433 |
+
To negatively condition the generation process based on a target image resolution. It should be as same
|
434 |
+
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
|
435 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
436 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
437 |
+
callback_on_step_end (`Callable`, *optional*):
|
438 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
439 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
440 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
441 |
+
`callback_on_step_end_tensor_inputs`.
|
442 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
443 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
444 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
445 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
446 |
+
|
447 |
+
Examples:
|
448 |
+
|
449 |
+
Returns:
|
450 |
+
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
|
451 |
+
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
|
452 |
+
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
453 |
+
"""
|
454 |
+
|
455 |
+
callback = kwargs.pop("callback", None)
|
456 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
457 |
+
|
458 |
+
if callback is not None:
|
459 |
+
deprecate(
|
460 |
+
"callback",
|
461 |
+
"1.0.0",
|
462 |
+
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
463 |
+
)
|
464 |
+
if callback_steps is not None:
|
465 |
+
deprecate(
|
466 |
+
"callback_steps",
|
467 |
+
"1.0.0",
|
468 |
+
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
469 |
+
)
|
470 |
+
|
471 |
+
# 0. Default height and width to unet
|
472 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
473 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
474 |
+
|
475 |
+
original_size = original_size or (height, width)
|
476 |
+
target_size = target_size or (height, width)
|
477 |
+
|
478 |
+
# 1. Check inputs. Raise error if not correct
|
479 |
+
self.check_inputs(
|
480 |
+
prompt,
|
481 |
+
prompt_2,
|
482 |
+
height,
|
483 |
+
width,
|
484 |
+
callback_steps,
|
485 |
+
negative_prompt,
|
486 |
+
negative_prompt_2,
|
487 |
+
prompt_embeds,
|
488 |
+
negative_prompt_embeds,
|
489 |
+
pooled_prompt_embeds,
|
490 |
+
negative_pooled_prompt_embeds,
|
491 |
+
ip_adapter_image,
|
492 |
+
ip_adapter_image_embeds,
|
493 |
+
callback_on_step_end_tensor_inputs,
|
494 |
+
)
|
495 |
+
|
496 |
+
self._guidance_scale = guidance_scale
|
497 |
+
self._guidance_rescale = guidance_rescale
|
498 |
+
self._clip_skip = clip_skip
|
499 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
500 |
+
self._denoising_end = denoising_end
|
501 |
+
self._interrupt = False
|
502 |
+
|
503 |
+
# 2. Define call parameters
|
504 |
+
if prompt is not None and isinstance(prompt, str):
|
505 |
+
batch_size = 1
|
506 |
+
elif prompt is not None and isinstance(prompt, list):
|
507 |
+
batch_size = len(prompt)
|
508 |
+
else:
|
509 |
+
batch_size = prompt_embeds.shape[0]
|
510 |
+
|
511 |
+
device = self._execution_device
|
512 |
+
|
513 |
+
# 3. Encode input prompt
|
514 |
+
lora_scale = (
|
515 |
+
self.cross_attention_kwargs.get("scale", None)
|
516 |
+
if self.cross_attention_kwargs is not None
|
517 |
+
else None
|
518 |
+
)
|
519 |
+
|
520 |
+
(
|
521 |
+
prompt_embeds,
|
522 |
+
negative_prompt_embeds,
|
523 |
+
pooled_prompt_embeds,
|
524 |
+
negative_pooled_prompt_embeds,
|
525 |
+
) = self.encode_prompt(
|
526 |
+
prompt=prompt,
|
527 |
+
prompt_2=prompt_2,
|
528 |
+
device=device,
|
529 |
+
num_images_per_prompt=num_images_per_prompt,
|
530 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
531 |
+
negative_prompt=negative_prompt,
|
532 |
+
negative_prompt_2=negative_prompt_2,
|
533 |
+
prompt_embeds=prompt_embeds,
|
534 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
535 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
536 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
537 |
+
lora_scale=lora_scale,
|
538 |
+
clip_skip=self.clip_skip,
|
539 |
+
)
|
540 |
+
|
541 |
+
# 4. Prepare timesteps
|
542 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
543 |
+
self.scheduler, num_inference_steps, device, timesteps
|
544 |
+
)
|
545 |
+
|
546 |
+
# 5. Prepare latent variables
|
547 |
+
num_channels_latents = self.unet.config.in_channels
|
548 |
+
latents = self.prepare_latents(
|
549 |
+
batch_size * num_images_per_prompt,
|
550 |
+
num_channels_latents,
|
551 |
+
height,
|
552 |
+
width,
|
553 |
+
prompt_embeds.dtype,
|
554 |
+
device,
|
555 |
+
generator,
|
556 |
+
latents,
|
557 |
+
)
|
558 |
+
|
559 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
560 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
561 |
+
|
562 |
+
# 7. Prepare added time ids & embeddings
|
563 |
+
add_text_embeds = pooled_prompt_embeds
|
564 |
+
if self.text_encoder_2 is None:
|
565 |
+
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
566 |
+
else:
|
567 |
+
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
568 |
+
|
569 |
+
add_time_ids = self._get_add_time_ids(
|
570 |
+
original_size,
|
571 |
+
crops_coords_top_left,
|
572 |
+
target_size,
|
573 |
+
dtype=prompt_embeds.dtype,
|
574 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
575 |
+
)
|
576 |
+
if negative_original_size is not None and negative_target_size is not None:
|
577 |
+
negative_add_time_ids = self._get_add_time_ids(
|
578 |
+
negative_original_size,
|
579 |
+
negative_crops_coords_top_left,
|
580 |
+
negative_target_size,
|
581 |
+
dtype=prompt_embeds.dtype,
|
582 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
583 |
+
)
|
584 |
+
else:
|
585 |
+
negative_add_time_ids = add_time_ids
|
586 |
+
|
587 |
+
if self.do_classifier_free_guidance:
|
588 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
589 |
+
add_text_embeds = torch.cat(
|
590 |
+
[negative_pooled_prompt_embeds, add_text_embeds], dim=0
|
591 |
+
)
|
592 |
+
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
593 |
+
|
594 |
+
prompt_embeds = prompt_embeds.to(device)
|
595 |
+
add_text_embeds = add_text_embeds.to(device)
|
596 |
+
add_time_ids = add_time_ids.to(device).repeat(
|
597 |
+
batch_size * num_images_per_prompt, 1
|
598 |
+
)
|
599 |
+
|
600 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
601 |
+
image_embeds = self.prepare_ip_adapter_image_embeds(
|
602 |
+
ip_adapter_image,
|
603 |
+
ip_adapter_image_embeds,
|
604 |
+
device,
|
605 |
+
batch_size * num_images_per_prompt,
|
606 |
+
self.do_classifier_free_guidance,
|
607 |
+
)
|
608 |
+
|
609 |
+
# Preprocess reference image
|
610 |
+
reference_image = self.image_processor.preprocess(reference_image)
|
611 |
+
reference_latents = self.prepare_image_latents(
|
612 |
+
reference_image,
|
613 |
+
timesteps[:1].repeat(batch_size * num_images_per_prompt), # no use
|
614 |
+
batch_size,
|
615 |
+
1,
|
616 |
+
prompt_embeds.dtype,
|
617 |
+
device,
|
618 |
+
generator,
|
619 |
+
add_noise=False,
|
620 |
+
)
|
621 |
+
|
622 |
+
with torch.no_grad():
|
623 |
+
ref_timesteps = torch.zeros_like(timesteps[0])
|
624 |
+
ref_hidden_states = {}
|
625 |
+
|
626 |
+
self.unet(
|
627 |
+
reference_latents,
|
628 |
+
ref_timesteps,
|
629 |
+
encoder_hidden_states=prompt_embeds[-1:],
|
630 |
+
added_cond_kwargs={
|
631 |
+
"text_embeds": add_text_embeds[-1:],
|
632 |
+
"time_ids": add_time_ids[-1:],
|
633 |
+
},
|
634 |
+
cross_attention_kwargs={
|
635 |
+
"cache_hidden_states": ref_hidden_states,
|
636 |
+
"use_mv": False,
|
637 |
+
"use_ref": False,
|
638 |
+
},
|
639 |
+
return_dict=False,
|
640 |
+
)
|
641 |
+
ref_hidden_states = {
|
642 |
+
k: v.repeat_interleave(num_images_per_prompt, dim=0)
|
643 |
+
for k, v in ref_hidden_states.items()
|
644 |
+
}
|
645 |
+
if self.do_classifier_free_guidance:
|
646 |
+
ref_hidden_states = {
|
647 |
+
k: torch.cat([torch.zeros_like(v), v], dim=0)
|
648 |
+
for k, v in ref_hidden_states.items()
|
649 |
+
}
|
650 |
+
|
651 |
+
cross_attention_kwargs = {
|
652 |
+
"mv_scale": mv_scale,
|
653 |
+
"ref_hidden_states": {k: v.clone() for k, v in ref_hidden_states.items()},
|
654 |
+
"ref_scale": reference_conditioning_scale,
|
655 |
+
"num_views": num_images_per_prompt,
|
656 |
+
**(self.cross_attention_kwargs or {}),
|
657 |
+
}
|
658 |
+
|
659 |
+
# Preprocess control image
|
660 |
+
control_image_feature = self.prepare_control_image(
|
661 |
+
image=control_image,
|
662 |
+
width=width,
|
663 |
+
height=height,
|
664 |
+
batch_size=batch_size * num_images_per_prompt,
|
665 |
+
num_images_per_prompt=1, # NOTE: always 1 for control images
|
666 |
+
device=device,
|
667 |
+
dtype=latents.dtype,
|
668 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
669 |
+
)
|
670 |
+
control_image_feature = control_image_feature.to(
|
671 |
+
device=device, dtype=latents.dtype
|
672 |
+
)
|
673 |
+
|
674 |
+
adapter_state = self.cond_encoder(control_image_feature)
|
675 |
+
for i, state in enumerate(adapter_state):
|
676 |
+
adapter_state[i] = state * control_conditioning_scale
|
677 |
+
|
678 |
+
# 8. Denoising loop
|
679 |
+
num_warmup_steps = max(
|
680 |
+
len(timesteps) - num_inference_steps * self.scheduler.order, 0
|
681 |
+
)
|
682 |
+
|
683 |
+
# 8.1 Apply denoising_end
|
684 |
+
if (
|
685 |
+
self.denoising_end is not None
|
686 |
+
and isinstance(self.denoising_end, float)
|
687 |
+
and self.denoising_end > 0
|
688 |
+
and self.denoising_end < 1
|
689 |
+
):
|
690 |
+
discrete_timestep_cutoff = int(
|
691 |
+
round(
|
692 |
+
self.scheduler.config.num_train_timesteps
|
693 |
+
- (self.denoising_end * self.scheduler.config.num_train_timesteps)
|
694 |
+
)
|
695 |
+
)
|
696 |
+
num_inference_steps = len(
|
697 |
+
list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))
|
698 |
+
)
|
699 |
+
timesteps = timesteps[:num_inference_steps]
|
700 |
+
|
701 |
+
# 9. Optionally get Guidance Scale Embedding
|
702 |
+
timestep_cond = None
|
703 |
+
if self.unet.config.time_cond_proj_dim is not None:
|
704 |
+
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
|
705 |
+
batch_size * num_images_per_prompt
|
706 |
+
)
|
707 |
+
timestep_cond = self.get_guidance_scale_embedding(
|
708 |
+
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
709 |
+
).to(device=device, dtype=latents.dtype)
|
710 |
+
|
711 |
+
self._num_timesteps = len(timesteps)
|
712 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
713 |
+
for i, t in enumerate(timesteps):
|
714 |
+
if self.interrupt:
|
715 |
+
continue
|
716 |
+
|
717 |
+
# expand the latents if we are doing classifier free guidance
|
718 |
+
latent_model_input = (
|
719 |
+
torch.cat([latents] * 2)
|
720 |
+
if self.do_classifier_free_guidance
|
721 |
+
else latents
|
722 |
+
)
|
723 |
+
|
724 |
+
latent_model_input = self.scheduler.scale_model_input(
|
725 |
+
latent_model_input, t
|
726 |
+
)
|
727 |
+
|
728 |
+
added_cond_kwargs = {
|
729 |
+
"text_embeds": add_text_embeds,
|
730 |
+
"time_ids": add_time_ids,
|
731 |
+
}
|
732 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
733 |
+
added_cond_kwargs["image_embeds"] = image_embeds
|
734 |
+
|
735 |
+
if i < int(num_inference_steps * control_conditioning_factor):
|
736 |
+
down_intrablock_additional_residuals = [
|
737 |
+
state.clone() for state in adapter_state
|
738 |
+
]
|
739 |
+
else:
|
740 |
+
down_intrablock_additional_residuals = None
|
741 |
+
|
742 |
+
# predict the noise residual
|
743 |
+
noise_pred = self.unet(
|
744 |
+
latent_model_input,
|
745 |
+
t,
|
746 |
+
encoder_hidden_states=prompt_embeds,
|
747 |
+
timestep_cond=timestep_cond,
|
748 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
749 |
+
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
|
750 |
+
added_cond_kwargs=added_cond_kwargs,
|
751 |
+
return_dict=False,
|
752 |
+
)[0]
|
753 |
+
|
754 |
+
# perform guidance
|
755 |
+
if self.do_classifier_free_guidance:
|
756 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
757 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (
|
758 |
+
noise_pred_text - noise_pred_uncond
|
759 |
+
)
|
760 |
+
|
761 |
+
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
|
762 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
763 |
+
noise_pred = rescale_noise_cfg(
|
764 |
+
noise_pred,
|
765 |
+
noise_pred_text,
|
766 |
+
guidance_rescale=self.guidance_rescale,
|
767 |
+
)
|
768 |
+
|
769 |
+
# compute the previous noisy sample x_t -> x_t-1
|
770 |
+
latents_dtype = latents.dtype
|
771 |
+
latents = self.scheduler.step(
|
772 |
+
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
|
773 |
+
)[0]
|
774 |
+
if latents.dtype != latents_dtype:
|
775 |
+
if torch.backends.mps.is_available():
|
776 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
777 |
+
latents = latents.to(latents_dtype)
|
778 |
+
|
779 |
+
if callback_on_step_end is not None:
|
780 |
+
callback_kwargs = {}
|
781 |
+
for k in callback_on_step_end_tensor_inputs:
|
782 |
+
callback_kwargs[k] = locals()[k]
|
783 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
784 |
+
|
785 |
+
latents = callback_outputs.pop("latents", latents)
|
786 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
787 |
+
negative_prompt_embeds = callback_outputs.pop(
|
788 |
+
"negative_prompt_embeds", negative_prompt_embeds
|
789 |
+
)
|
790 |
+
add_text_embeds = callback_outputs.pop(
|
791 |
+
"add_text_embeds", add_text_embeds
|
792 |
+
)
|
793 |
+
negative_pooled_prompt_embeds = callback_outputs.pop(
|
794 |
+
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
|
795 |
+
)
|
796 |
+
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
|
797 |
+
negative_add_time_ids = callback_outputs.pop(
|
798 |
+
"negative_add_time_ids", negative_add_time_ids
|
799 |
+
)
|
800 |
+
|
801 |
+
# call the callback, if provided
|
802 |
+
if i == len(timesteps) - 1 or (
|
803 |
+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
804 |
+
):
|
805 |
+
progress_bar.update()
|
806 |
+
if callback is not None and i % callback_steps == 0:
|
807 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
808 |
+
callback(step_idx, t, latents)
|
809 |
+
|
810 |
+
if not output_type == "latent":
|
811 |
+
# make sure the VAE is in float32 mode, as it overflows in float16
|
812 |
+
needs_upcasting = (
|
813 |
+
self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
814 |
+
)
|
815 |
+
|
816 |
+
if needs_upcasting:
|
817 |
+
self.upcast_vae()
|
818 |
+
latents = latents.to(
|
819 |
+
next(iter(self.vae.post_quant_conv.parameters())).dtype
|
820 |
+
)
|
821 |
+
elif latents.dtype != self.vae.dtype:
|
822 |
+
if torch.backends.mps.is_available():
|
823 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
824 |
+
self.vae = self.vae.to(latents.dtype)
|
825 |
+
|
826 |
+
# unscale/denormalize the latents
|
827 |
+
# denormalize with the mean and std if available and not None
|
828 |
+
has_latents_mean = (
|
829 |
+
hasattr(self.vae.config, "latents_mean")
|
830 |
+
and self.vae.config.latents_mean is not None
|
831 |
+
)
|
832 |
+
has_latents_std = (
|
833 |
+
hasattr(self.vae.config, "latents_std")
|
834 |
+
and self.vae.config.latents_std is not None
|
835 |
+
)
|
836 |
+
if has_latents_mean and has_latents_std:
|
837 |
+
latents_mean = (
|
838 |
+
torch.tensor(self.vae.config.latents_mean)
|
839 |
+
.view(1, 4, 1, 1)
|
840 |
+
.to(latents.device, latents.dtype)
|
841 |
+
)
|
842 |
+
latents_std = (
|
843 |
+
torch.tensor(self.vae.config.latents_std)
|
844 |
+
.view(1, 4, 1, 1)
|
845 |
+
.to(latents.device, latents.dtype)
|
846 |
+
)
|
847 |
+
latents = (
|
848 |
+
latents * latents_std / self.vae.config.scaling_factor
|
849 |
+
+ latents_mean
|
850 |
+
)
|
851 |
+
else:
|
852 |
+
latents = latents / self.vae.config.scaling_factor
|
853 |
+
|
854 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
855 |
+
|
856 |
+
# cast back to fp16 if needed
|
857 |
+
if needs_upcasting:
|
858 |
+
self.vae.to(dtype=torch.float16)
|
859 |
+
else:
|
860 |
+
image = latents
|
861 |
+
|
862 |
+
if not output_type == "latent":
|
863 |
+
# apply watermark if available
|
864 |
+
if self.watermark is not None:
|
865 |
+
image = self.watermark.apply_watermark(image)
|
866 |
+
|
867 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
868 |
+
|
869 |
+
# Offload all models
|
870 |
+
self.maybe_free_model_hooks()
|
871 |
+
|
872 |
+
if not return_dict:
|
873 |
+
return (image,)
|
874 |
+
|
875 |
+
return StableDiffusionXLPipelineOutput(images=image)
|
876 |
+
|
877 |
+
### NEW: adapters ###
|
878 |
+
def _init_custom_adapter(
|
879 |
+
self,
|
880 |
+
# Multi-view adapter
|
881 |
+
num_views: int = 1,
|
882 |
+
self_attn_processor: Any = DecoupledMVRowSelfAttnProcessor2_0,
|
883 |
+
# Condition encoder
|
884 |
+
cond_in_channels: int = 6,
|
885 |
+
# For training
|
886 |
+
copy_attn_weights: bool = True,
|
887 |
+
zero_init_module_keys: List[str] = [],
|
888 |
+
):
|
889 |
+
# Condition encoder
|
890 |
+
self.cond_encoder = T2IAdapter(
|
891 |
+
in_channels=cond_in_channels,
|
892 |
+
channels=(320, 640, 1280, 1280),
|
893 |
+
num_res_blocks=2,
|
894 |
+
downscale_factor=16,
|
895 |
+
adapter_type="full_adapter_xl",
|
896 |
+
)
|
897 |
+
|
898 |
+
# set custom attn processor for multi-view attention and image cross-attention
|
899 |
+
self.unet: UNet2DConditionModel
|
900 |
+
set_unet_2d_condition_attn_processor(
|
901 |
+
self.unet,
|
902 |
+
set_self_attn_proc_func=lambda name, hs, cad, ap: self_attn_processor(
|
903 |
+
query_dim=hs,
|
904 |
+
inner_dim=hs,
|
905 |
+
num_views=num_views,
|
906 |
+
name=name,
|
907 |
+
use_mv=True,
|
908 |
+
use_ref=True,
|
909 |
+
),
|
910 |
+
set_cross_attn_proc_func=lambda name, hs, cad, ap: self_attn_processor(
|
911 |
+
query_dim=hs,
|
912 |
+
inner_dim=hs,
|
913 |
+
num_views=num_views,
|
914 |
+
name=name,
|
915 |
+
use_mv=False,
|
916 |
+
use_ref=False,
|
917 |
+
),
|
918 |
+
)
|
919 |
+
|
920 |
+
# copy decoupled attention weights from original unet
|
921 |
+
if copy_attn_weights:
|
922 |
+
state_dict = self.unet.state_dict()
|
923 |
+
for key in state_dict.keys():
|
924 |
+
if "_mv" in key:
|
925 |
+
compatible_key = key.replace("_mv", "").replace("processor.", "")
|
926 |
+
elif "_ref" in key:
|
927 |
+
compatible_key = key.replace("_ref", "").replace("processor.", "")
|
928 |
+
else:
|
929 |
+
compatible_key = key
|
930 |
+
|
931 |
+
is_zero_init_key = any([k in key for k in zero_init_module_keys])
|
932 |
+
if is_zero_init_key:
|
933 |
+
state_dict[key] = torch.zeros_like(state_dict[compatible_key])
|
934 |
+
else:
|
935 |
+
state_dict[key] = state_dict[compatible_key].clone()
|
936 |
+
self.unet.load_state_dict(state_dict)
|
937 |
+
|
938 |
+
def _load_custom_adapter(self, state_dict):
|
939 |
+
self.unet.load_state_dict(state_dict, strict=False)
|
940 |
+
self.cond_encoder.load_state_dict(state_dict, strict=False)
|
941 |
+
|
942 |
+
def _save_custom_adapter(
|
943 |
+
self,
|
944 |
+
include_keys: Optional[List[str]] = None,
|
945 |
+
exclude_keys: Optional[List[str]] = None,
|
946 |
+
):
|
947 |
+
def include_fn(k):
|
948 |
+
is_included = False
|
949 |
+
|
950 |
+
if include_keys is not None:
|
951 |
+
is_included = is_included or any([key in k for key in include_keys])
|
952 |
+
if exclude_keys is not None:
|
953 |
+
is_included = is_included and not any(
|
954 |
+
[key in k for key in exclude_keys]
|
955 |
+
)
|
956 |
+
|
957 |
+
return is_included
|
958 |
+
|
959 |
+
state_dict = {k: v for k, v in self.unet.state_dict().items() if include_fn(k)}
|
960 |
+
state_dict.update(self.cond_encoder.state_dict())
|
961 |
+
|
962 |
+
return state_dict
|
mvadapter/pipelines/pipeline_mvadapter_t2mv_sd.py
ADDED
@@ -0,0 +1,634 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import inspect
|
15 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
19 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
20 |
+
from diffusers.models import AutoencoderKL, T2IAdapter, UNet2DConditionModel
|
21 |
+
from diffusers.pipelines.stable_diffusion.pipeline_output import (
|
22 |
+
StableDiffusionPipelineOutput,
|
23 |
+
)
|
24 |
+
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
|
25 |
+
StableDiffusionPipeline,
|
26 |
+
rescale_noise_cfg,
|
27 |
+
retrieve_timesteps,
|
28 |
+
)
|
29 |
+
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
30 |
+
StableDiffusionSafetyChecker,
|
31 |
+
)
|
32 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
33 |
+
from diffusers.utils import deprecate, is_torch_xla_available, logging
|
34 |
+
from diffusers.utils.torch_utils import randn_tensor
|
35 |
+
from packaging import version
|
36 |
+
from transformers import (
|
37 |
+
CLIPImageProcessor,
|
38 |
+
CLIPTextModel,
|
39 |
+
CLIPTokenizer,
|
40 |
+
CLIPVisionModelWithProjection,
|
41 |
+
)
|
42 |
+
|
43 |
+
from ..loaders import CustomAdapterMixin
|
44 |
+
from ..models.attention_processor import (
|
45 |
+
DecoupledMVRowSelfAttnProcessor2_0,
|
46 |
+
set_unet_2d_condition_attn_processor,
|
47 |
+
)
|
48 |
+
|
49 |
+
if is_torch_xla_available():
|
50 |
+
import torch_xla.core.xla_model as xm
|
51 |
+
|
52 |
+
XLA_AVAILABLE = True
|
53 |
+
else:
|
54 |
+
XLA_AVAILABLE = False
|
55 |
+
|
56 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
57 |
+
|
58 |
+
|
59 |
+
class MVAdapterT2MVSDPipeline(StableDiffusionPipeline, CustomAdapterMixin):
|
60 |
+
def __init__(
|
61 |
+
self,
|
62 |
+
vae: AutoencoderKL,
|
63 |
+
text_encoder: CLIPTextModel,
|
64 |
+
tokenizer: CLIPTokenizer,
|
65 |
+
unet: UNet2DConditionModel,
|
66 |
+
scheduler: KarrasDiffusionSchedulers,
|
67 |
+
safety_checker: StableDiffusionSafetyChecker,
|
68 |
+
feature_extractor: CLIPImageProcessor,
|
69 |
+
image_encoder: CLIPVisionModelWithProjection = None,
|
70 |
+
requires_safety_checker: bool = True,
|
71 |
+
):
|
72 |
+
super().__init__(
|
73 |
+
vae=vae,
|
74 |
+
text_encoder=text_encoder,
|
75 |
+
tokenizer=tokenizer,
|
76 |
+
unet=unet,
|
77 |
+
scheduler=scheduler,
|
78 |
+
safety_checker=safety_checker,
|
79 |
+
feature_extractor=feature_extractor,
|
80 |
+
image_encoder=image_encoder,
|
81 |
+
requires_safety_checker=requires_safety_checker,
|
82 |
+
)
|
83 |
+
|
84 |
+
self.control_image_processor = VaeImageProcessor(
|
85 |
+
vae_scale_factor=self.vae_scale_factor,
|
86 |
+
do_convert_rgb=True,
|
87 |
+
do_normalize=False,
|
88 |
+
)
|
89 |
+
|
90 |
+
def prepare_control_image(
|
91 |
+
self,
|
92 |
+
image,
|
93 |
+
width,
|
94 |
+
height,
|
95 |
+
batch_size,
|
96 |
+
num_images_per_prompt,
|
97 |
+
device,
|
98 |
+
dtype,
|
99 |
+
do_classifier_free_guidance=False,
|
100 |
+
):
|
101 |
+
assert hasattr(
|
102 |
+
self, "control_image_processor"
|
103 |
+
), "control_image_processor is not initialized"
|
104 |
+
|
105 |
+
image = self.control_image_processor.preprocess(
|
106 |
+
image, height=height, width=width
|
107 |
+
).to(dtype=torch.float32)
|
108 |
+
image_batch_size = image.shape[0]
|
109 |
+
|
110 |
+
if image_batch_size == 1:
|
111 |
+
repeat_by = batch_size
|
112 |
+
else:
|
113 |
+
# image batch size is the same as prompt batch size
|
114 |
+
repeat_by = num_images_per_prompt # always 1 for control image
|
115 |
+
|
116 |
+
image = image.repeat_interleave(repeat_by, dim=0)
|
117 |
+
|
118 |
+
image = image.to(device=device, dtype=dtype)
|
119 |
+
|
120 |
+
if do_classifier_free_guidance:
|
121 |
+
image = torch.cat([image] * 2)
|
122 |
+
|
123 |
+
return image
|
124 |
+
|
125 |
+
@torch.no_grad()
|
126 |
+
def __call__(
|
127 |
+
self,
|
128 |
+
prompt: Union[str, List[str]] = None,
|
129 |
+
height: Optional[int] = None,
|
130 |
+
width: Optional[int] = None,
|
131 |
+
num_inference_steps: int = 50,
|
132 |
+
timesteps: List[int] = None,
|
133 |
+
sigmas: List[float] = None,
|
134 |
+
guidance_scale: float = 7.5,
|
135 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
136 |
+
num_images_per_prompt: Optional[int] = 1,
|
137 |
+
eta: float = 0.0,
|
138 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
139 |
+
latents: Optional[torch.Tensor] = None,
|
140 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
141 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
142 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
143 |
+
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
144 |
+
output_type: Optional[str] = "pil",
|
145 |
+
return_dict: bool = True,
|
146 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
147 |
+
guidance_rescale: float = 0.0,
|
148 |
+
clip_skip: Optional[int] = None,
|
149 |
+
callback_on_step_end: Optional[
|
150 |
+
Union[
|
151 |
+
Callable[[int, int, Dict], None],
|
152 |
+
PipelineCallback,
|
153 |
+
MultiPipelineCallbacks,
|
154 |
+
]
|
155 |
+
] = None,
|
156 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
157 |
+
# NEW
|
158 |
+
mv_scale: float = 1.0,
|
159 |
+
# Camera or geometry condition
|
160 |
+
control_image: Optional[PipelineImageInput] = None,
|
161 |
+
control_conditioning_scale: Optional[float] = 1.0,
|
162 |
+
control_conditioning_factor: float = 1.0,
|
163 |
+
# Optional. controlnet
|
164 |
+
controlnet_image: Optional[PipelineImageInput] = None,
|
165 |
+
controlnet_conditioning_scale: Optional[float] = 1.0,
|
166 |
+
**kwargs,
|
167 |
+
):
|
168 |
+
r"""
|
169 |
+
The call function to the pipeline for generation.
|
170 |
+
|
171 |
+
Args:
|
172 |
+
prompt (`str` or `List[str]`, *optional*):
|
173 |
+
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
174 |
+
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
175 |
+
The height in pixels of the generated image.
|
176 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
177 |
+
The width in pixels of the generated image.
|
178 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
179 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
180 |
+
expense of slower inference.
|
181 |
+
timesteps (`List[int]`, *optional*):
|
182 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
183 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
184 |
+
passed will be used. Must be in descending order.
|
185 |
+
sigmas (`List[float]`, *optional*):
|
186 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
187 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
188 |
+
will be used.
|
189 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
190 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
191 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
192 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
193 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
194 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
195 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
196 |
+
The number of images to generate per prompt.
|
197 |
+
eta (`float`, *optional*, defaults to 0.0):
|
198 |
+
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
199 |
+
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
200 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
201 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
202 |
+
generation deterministic.
|
203 |
+
latents (`torch.Tensor`, *optional*):
|
204 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
205 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
206 |
+
tensor is generated by sampling using the supplied random `generator`.
|
207 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
208 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
209 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
210 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
211 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
212 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
213 |
+
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
214 |
+
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
|
215 |
+
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
216 |
+
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
|
217 |
+
contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
|
218 |
+
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
219 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
220 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
221 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
222 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
223 |
+
plain tuple.
|
224 |
+
cross_attention_kwargs (`dict`, *optional*):
|
225 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
226 |
+
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
227 |
+
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
228 |
+
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
|
229 |
+
Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
|
230 |
+
using zero terminal SNR.
|
231 |
+
clip_skip (`int`, *optional*):
|
232 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
233 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
234 |
+
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
235 |
+
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
236 |
+
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
237 |
+
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
238 |
+
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
239 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
240 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
241 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
242 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
243 |
+
|
244 |
+
Examples:
|
245 |
+
|
246 |
+
Returns:
|
247 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
248 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
249 |
+
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
250 |
+
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
251 |
+
"not-safe-for-work" (nsfw) content.
|
252 |
+
"""
|
253 |
+
|
254 |
+
callback = kwargs.pop("callback", None)
|
255 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
256 |
+
|
257 |
+
if callback is not None:
|
258 |
+
deprecate(
|
259 |
+
"callback",
|
260 |
+
"1.0.0",
|
261 |
+
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
262 |
+
)
|
263 |
+
if callback_steps is not None:
|
264 |
+
deprecate(
|
265 |
+
"callback_steps",
|
266 |
+
"1.0.0",
|
267 |
+
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
268 |
+
)
|
269 |
+
|
270 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
271 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
272 |
+
|
273 |
+
# 0. Default height and width to unet
|
274 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
275 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
276 |
+
# to deal with lora scaling and other possible forward hooks
|
277 |
+
|
278 |
+
# 1. Check inputs. Raise error if not correct
|
279 |
+
self.check_inputs(
|
280 |
+
prompt,
|
281 |
+
height,
|
282 |
+
width,
|
283 |
+
callback_steps,
|
284 |
+
negative_prompt,
|
285 |
+
prompt_embeds,
|
286 |
+
negative_prompt_embeds,
|
287 |
+
ip_adapter_image,
|
288 |
+
ip_adapter_image_embeds,
|
289 |
+
callback_on_step_end_tensor_inputs,
|
290 |
+
)
|
291 |
+
|
292 |
+
self._guidance_scale = guidance_scale
|
293 |
+
self._guidance_rescale = guidance_rescale
|
294 |
+
self._clip_skip = clip_skip
|
295 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
296 |
+
self._interrupt = False
|
297 |
+
|
298 |
+
# 2. Define call parameters
|
299 |
+
if prompt is not None and isinstance(prompt, str):
|
300 |
+
batch_size = 1
|
301 |
+
elif prompt is not None and isinstance(prompt, list):
|
302 |
+
batch_size = len(prompt)
|
303 |
+
else:
|
304 |
+
batch_size = prompt_embeds.shape[0]
|
305 |
+
|
306 |
+
device = self._execution_device
|
307 |
+
|
308 |
+
# 3. Encode input prompt
|
309 |
+
lora_scale = (
|
310 |
+
self.cross_attention_kwargs.get("scale", None)
|
311 |
+
if self.cross_attention_kwargs is not None
|
312 |
+
else None
|
313 |
+
)
|
314 |
+
|
315 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
316 |
+
prompt,
|
317 |
+
device,
|
318 |
+
num_images_per_prompt,
|
319 |
+
self.do_classifier_free_guidance,
|
320 |
+
negative_prompt,
|
321 |
+
prompt_embeds=prompt_embeds,
|
322 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
323 |
+
lora_scale=lora_scale,
|
324 |
+
clip_skip=self.clip_skip,
|
325 |
+
)
|
326 |
+
|
327 |
+
# For classifier free guidance, we need to do two forward passes.
|
328 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
329 |
+
# to avoid doing two forward passes
|
330 |
+
if self.do_classifier_free_guidance:
|
331 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
332 |
+
|
333 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
334 |
+
image_embeds = self.prepare_ip_adapter_image_embeds(
|
335 |
+
ip_adapter_image,
|
336 |
+
ip_adapter_image_embeds,
|
337 |
+
device,
|
338 |
+
batch_size * num_images_per_prompt,
|
339 |
+
self.do_classifier_free_guidance,
|
340 |
+
)
|
341 |
+
|
342 |
+
# 4. Prepare timesteps
|
343 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
344 |
+
self.scheduler, num_inference_steps, device, timesteps, sigmas
|
345 |
+
)
|
346 |
+
|
347 |
+
# 5. Prepare latent variables
|
348 |
+
num_channels_latents = self.unet.config.in_channels
|
349 |
+
latents = self.prepare_latents(
|
350 |
+
batch_size * num_images_per_prompt,
|
351 |
+
num_channels_latents,
|
352 |
+
height,
|
353 |
+
width,
|
354 |
+
prompt_embeds.dtype,
|
355 |
+
device,
|
356 |
+
generator,
|
357 |
+
latents,
|
358 |
+
)
|
359 |
+
|
360 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
361 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
362 |
+
|
363 |
+
# 6.1 Add image embeds for IP-Adapter
|
364 |
+
added_cond_kwargs = (
|
365 |
+
{"image_embeds": image_embeds}
|
366 |
+
if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
|
367 |
+
else None
|
368 |
+
)
|
369 |
+
|
370 |
+
# 6.2 Optionally get Guidance Scale Embedding
|
371 |
+
timestep_cond = None
|
372 |
+
if self.unet.config.time_cond_proj_dim is not None:
|
373 |
+
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
|
374 |
+
batch_size * num_images_per_prompt
|
375 |
+
)
|
376 |
+
timestep_cond = self.get_guidance_scale_embedding(
|
377 |
+
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
378 |
+
).to(device=device, dtype=latents.dtype)
|
379 |
+
|
380 |
+
# Preprocess control image
|
381 |
+
control_image_feature = self.prepare_control_image(
|
382 |
+
image=control_image,
|
383 |
+
width=width,
|
384 |
+
height=height,
|
385 |
+
batch_size=batch_size * num_images_per_prompt,
|
386 |
+
num_images_per_prompt=1, # NOTE: always 1 for control images
|
387 |
+
device=device,
|
388 |
+
dtype=latents.dtype,
|
389 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
390 |
+
)
|
391 |
+
control_image_feature = control_image_feature.to(
|
392 |
+
device=device, dtype=latents.dtype
|
393 |
+
)
|
394 |
+
|
395 |
+
adapter_state = self.cond_encoder(control_image_feature)
|
396 |
+
for i, state in enumerate(adapter_state):
|
397 |
+
adapter_state[i] = state * control_conditioning_scale
|
398 |
+
|
399 |
+
# Preprocess controlnet image if provided
|
400 |
+
do_controlnet = controlnet_image is not None and hasattr(self, "controlnet")
|
401 |
+
if do_controlnet:
|
402 |
+
controlnet_image = self.prepare_control_image(
|
403 |
+
image=controlnet_image,
|
404 |
+
width=width,
|
405 |
+
height=height,
|
406 |
+
batch_size=batch_size * num_images_per_prompt,
|
407 |
+
num_images_per_prompt=1, # NOTE: always 1 for control images
|
408 |
+
device=device,
|
409 |
+
dtype=latents.dtype,
|
410 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
411 |
+
)
|
412 |
+
controlnet_image = controlnet_image.to(device=device, dtype=latents.dtype)
|
413 |
+
|
414 |
+
# 7. Denoising loop
|
415 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
416 |
+
self._num_timesteps = len(timesteps)
|
417 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
418 |
+
for i, t in enumerate(timesteps):
|
419 |
+
if self.interrupt:
|
420 |
+
continue
|
421 |
+
|
422 |
+
# expand the latents if we are doing classifier free guidance
|
423 |
+
latent_model_input = (
|
424 |
+
torch.cat([latents] * 2)
|
425 |
+
if self.do_classifier_free_guidance
|
426 |
+
else latents
|
427 |
+
)
|
428 |
+
latent_model_input = self.scheduler.scale_model_input(
|
429 |
+
latent_model_input, t
|
430 |
+
)
|
431 |
+
|
432 |
+
if i < int(num_inference_steps * control_conditioning_factor):
|
433 |
+
down_intrablock_additional_residuals = [
|
434 |
+
state.clone() for state in adapter_state
|
435 |
+
]
|
436 |
+
else:
|
437 |
+
down_intrablock_additional_residuals = None
|
438 |
+
|
439 |
+
unet_add_kwargs = {}
|
440 |
+
|
441 |
+
# Do controlnet if provided
|
442 |
+
if do_controlnet:
|
443 |
+
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
444 |
+
latent_model_input,
|
445 |
+
t,
|
446 |
+
encoder_hidden_states=prompt_embeds,
|
447 |
+
controlnet_cond=controlnet_image,
|
448 |
+
conditioning_scale=controlnet_conditioning_scale,
|
449 |
+
guess_mode=False,
|
450 |
+
added_cond_kwargs=added_cond_kwargs,
|
451 |
+
return_dict=False,
|
452 |
+
)
|
453 |
+
unet_add_kwargs.update(
|
454 |
+
{
|
455 |
+
"down_block_additional_residuals": down_block_res_samples,
|
456 |
+
"mid_block_additional_residual": mid_block_res_sample,
|
457 |
+
}
|
458 |
+
)
|
459 |
+
|
460 |
+
# predict the noise residual
|
461 |
+
noise_pred = self.unet(
|
462 |
+
latent_model_input,
|
463 |
+
t,
|
464 |
+
encoder_hidden_states=prompt_embeds,
|
465 |
+
timestep_cond=timestep_cond,
|
466 |
+
cross_attention_kwargs={
|
467 |
+
"mv_scale": mv_scale,
|
468 |
+
"num_views": num_images_per_prompt,
|
469 |
+
**(self.cross_attention_kwargs or {}),
|
470 |
+
},
|
471 |
+
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
|
472 |
+
added_cond_kwargs=added_cond_kwargs,
|
473 |
+
return_dict=False,
|
474 |
+
**unet_add_kwargs,
|
475 |
+
)[0]
|
476 |
+
|
477 |
+
# perform guidance
|
478 |
+
if self.do_classifier_free_guidance:
|
479 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
480 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (
|
481 |
+
noise_pred_text - noise_pred_uncond
|
482 |
+
)
|
483 |
+
|
484 |
+
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
|
485 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
486 |
+
noise_pred = rescale_noise_cfg(
|
487 |
+
noise_pred,
|
488 |
+
noise_pred_text,
|
489 |
+
guidance_rescale=self.guidance_rescale,
|
490 |
+
)
|
491 |
+
|
492 |
+
# compute the previous noisy sample x_t -> x_t-1
|
493 |
+
latents = self.scheduler.step(
|
494 |
+
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
|
495 |
+
)[0]
|
496 |
+
|
497 |
+
if callback_on_step_end is not None:
|
498 |
+
callback_kwargs = {}
|
499 |
+
for k in callback_on_step_end_tensor_inputs:
|
500 |
+
callback_kwargs[k] = locals()[k]
|
501 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
502 |
+
|
503 |
+
latents = callback_outputs.pop("latents", latents)
|
504 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
505 |
+
negative_prompt_embeds = callback_outputs.pop(
|
506 |
+
"negative_prompt_embeds", negative_prompt_embeds
|
507 |
+
)
|
508 |
+
|
509 |
+
# call the callback, if provided
|
510 |
+
if i == len(timesteps) - 1 or (
|
511 |
+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
512 |
+
):
|
513 |
+
progress_bar.update()
|
514 |
+
if callback is not None and i % callback_steps == 0:
|
515 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
516 |
+
callback(step_idx, t, latents)
|
517 |
+
|
518 |
+
if XLA_AVAILABLE:
|
519 |
+
xm.mark_step()
|
520 |
+
|
521 |
+
if not output_type == "latent":
|
522 |
+
image = self.vae.decode(
|
523 |
+
latents / self.vae.config.scaling_factor,
|
524 |
+
return_dict=False,
|
525 |
+
generator=generator,
|
526 |
+
)[0]
|
527 |
+
image, has_nsfw_concept = self.run_safety_checker(
|
528 |
+
image, device, prompt_embeds.dtype
|
529 |
+
)
|
530 |
+
else:
|
531 |
+
image = latents
|
532 |
+
has_nsfw_concept = None
|
533 |
+
|
534 |
+
if has_nsfw_concept is None:
|
535 |
+
do_denormalize = [True] * image.shape[0]
|
536 |
+
else:
|
537 |
+
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
538 |
+
image = self.image_processor.postprocess(
|
539 |
+
image, output_type=output_type, do_denormalize=do_denormalize
|
540 |
+
)
|
541 |
+
|
542 |
+
# Offload all models
|
543 |
+
self.maybe_free_model_hooks()
|
544 |
+
|
545 |
+
if not return_dict:
|
546 |
+
return (image, has_nsfw_concept)
|
547 |
+
|
548 |
+
return StableDiffusionPipelineOutput(
|
549 |
+
images=image, nsfw_content_detected=has_nsfw_concept
|
550 |
+
)
|
551 |
+
|
552 |
+
### NEW: adapters ###
|
553 |
+
def _init_custom_adapter(
|
554 |
+
self,
|
555 |
+
# Multi-view adapter
|
556 |
+
num_views: int = 1,
|
557 |
+
self_attn_processor: Any = DecoupledMVRowSelfAttnProcessor2_0,
|
558 |
+
# Condition encoder
|
559 |
+
cond_in_channels: int = 6,
|
560 |
+
# For training
|
561 |
+
copy_attn_weights: bool = True,
|
562 |
+
zero_init_module_keys: List[str] = [],
|
563 |
+
):
|
564 |
+
# Condition encoder
|
565 |
+
self.cond_encoder = T2IAdapter(
|
566 |
+
in_channels=cond_in_channels,
|
567 |
+
channels=self.unet.config.block_out_channels,
|
568 |
+
num_res_blocks=self.unet.config.layers_per_block,
|
569 |
+
downscale_factor=8,
|
570 |
+
)
|
571 |
+
|
572 |
+
# set custom attn processor for multi-view attention
|
573 |
+
self.unet: UNet2DConditionModel
|
574 |
+
set_unet_2d_condition_attn_processor(
|
575 |
+
self.unet,
|
576 |
+
set_self_attn_proc_func=lambda name, hs, cad, ap: self_attn_processor(
|
577 |
+
query_dim=hs,
|
578 |
+
inner_dim=hs,
|
579 |
+
num_views=num_views,
|
580 |
+
name=name,
|
581 |
+
use_mv=True,
|
582 |
+
use_ref=False,
|
583 |
+
),
|
584 |
+
set_cross_attn_proc_func=lambda name, hs, cad, ap: self_attn_processor(
|
585 |
+
query_dim=hs,
|
586 |
+
inner_dim=hs,
|
587 |
+
num_views=num_views,
|
588 |
+
name=name,
|
589 |
+
use_mv=False,
|
590 |
+
use_ref=False,
|
591 |
+
),
|
592 |
+
)
|
593 |
+
|
594 |
+
# copy decoupled attention weights from original unet
|
595 |
+
if copy_attn_weights:
|
596 |
+
state_dict = self.unet.state_dict()
|
597 |
+
for key in state_dict.keys():
|
598 |
+
if "_mv" in key:
|
599 |
+
compatible_key = key.replace("_mv", "").replace("processor.", "")
|
600 |
+
else:
|
601 |
+
compatible_key = key
|
602 |
+
|
603 |
+
is_zero_init_key = any([k in key for k in zero_init_module_keys])
|
604 |
+
if is_zero_init_key:
|
605 |
+
state_dict[key] = torch.zeros_like(state_dict[compatible_key])
|
606 |
+
else:
|
607 |
+
state_dict[key] = state_dict[compatible_key].clone()
|
608 |
+
self.unet.load_state_dict(state_dict)
|
609 |
+
|
610 |
+
def _load_custom_adapter(self, state_dict):
|
611 |
+
self.unet.load_state_dict(state_dict, strict=False)
|
612 |
+
self.cond_encoder.load_state_dict(state_dict, strict=False)
|
613 |
+
|
614 |
+
def _save_custom_adapter(
|
615 |
+
self,
|
616 |
+
include_keys: Optional[List[str]] = None,
|
617 |
+
exclude_keys: Optional[List[str]] = None,
|
618 |
+
):
|
619 |
+
def include_fn(k):
|
620 |
+
is_included = False
|
621 |
+
|
622 |
+
if include_keys is not None:
|
623 |
+
is_included = is_included or any([key in k for key in include_keys])
|
624 |
+
if exclude_keys is not None:
|
625 |
+
is_included = is_included and not any(
|
626 |
+
[key in k for key in exclude_keys]
|
627 |
+
)
|
628 |
+
|
629 |
+
return is_included
|
630 |
+
|
631 |
+
state_dict = {k: v for k, v in self.unet.state_dict().items() if include_fn(k)}
|
632 |
+
state_dict.update(self.cond_encoder.state_dict())
|
633 |
+
|
634 |
+
return state_dict
|
mvadapter/pipelines/pipeline_mvadapter_t2mv_sdxl.py
ADDED
@@ -0,0 +1,801 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
19 |
+
from diffusers.models import AutoencoderKL, T2IAdapter, UNet2DConditionModel
|
20 |
+
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import (
|
21 |
+
StableDiffusionXLPipelineOutput,
|
22 |
+
)
|
23 |
+
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import (
|
24 |
+
StableDiffusionXLPipeline,
|
25 |
+
rescale_noise_cfg,
|
26 |
+
retrieve_timesteps,
|
27 |
+
)
|
28 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
29 |
+
from diffusers.utils import deprecate, logging
|
30 |
+
from transformers import (
|
31 |
+
CLIPImageProcessor,
|
32 |
+
CLIPTextModel,
|
33 |
+
CLIPTextModelWithProjection,
|
34 |
+
CLIPTokenizer,
|
35 |
+
CLIPVisionModelWithProjection,
|
36 |
+
)
|
37 |
+
|
38 |
+
from ..loaders import CustomAdapterMixin
|
39 |
+
from ..models.attention_processor import (
|
40 |
+
DecoupledMVRowSelfAttnProcessor2_0,
|
41 |
+
set_unet_2d_condition_attn_processor,
|
42 |
+
)
|
43 |
+
|
44 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
45 |
+
|
46 |
+
|
47 |
+
class MVAdapterT2MVSDXLPipeline(StableDiffusionXLPipeline, CustomAdapterMixin):
|
48 |
+
def __init__(
|
49 |
+
self,
|
50 |
+
vae: AutoencoderKL,
|
51 |
+
text_encoder: CLIPTextModel,
|
52 |
+
text_encoder_2: CLIPTextModelWithProjection,
|
53 |
+
tokenizer: CLIPTokenizer,
|
54 |
+
tokenizer_2: CLIPTokenizer,
|
55 |
+
unet: UNet2DConditionModel,
|
56 |
+
scheduler: KarrasDiffusionSchedulers,
|
57 |
+
image_encoder: CLIPVisionModelWithProjection = None,
|
58 |
+
feature_extractor: CLIPImageProcessor = None,
|
59 |
+
force_zeros_for_empty_prompt: bool = True,
|
60 |
+
add_watermarker: Optional[bool] = None,
|
61 |
+
):
|
62 |
+
super().__init__(
|
63 |
+
vae=vae,
|
64 |
+
text_encoder=text_encoder,
|
65 |
+
text_encoder_2=text_encoder_2,
|
66 |
+
tokenizer=tokenizer,
|
67 |
+
tokenizer_2=tokenizer_2,
|
68 |
+
unet=unet,
|
69 |
+
scheduler=scheduler,
|
70 |
+
image_encoder=image_encoder,
|
71 |
+
feature_extractor=feature_extractor,
|
72 |
+
force_zeros_for_empty_prompt=force_zeros_for_empty_prompt,
|
73 |
+
add_watermarker=add_watermarker,
|
74 |
+
)
|
75 |
+
|
76 |
+
self.control_image_processor = VaeImageProcessor(
|
77 |
+
vae_scale_factor=self.vae_scale_factor,
|
78 |
+
do_convert_rgb=True,
|
79 |
+
do_normalize=False,
|
80 |
+
)
|
81 |
+
|
82 |
+
def prepare_control_image(
|
83 |
+
self,
|
84 |
+
image,
|
85 |
+
width,
|
86 |
+
height,
|
87 |
+
batch_size,
|
88 |
+
num_images_per_prompt,
|
89 |
+
device,
|
90 |
+
dtype,
|
91 |
+
do_classifier_free_guidance=False,
|
92 |
+
):
|
93 |
+
assert hasattr(
|
94 |
+
self, "control_image_processor"
|
95 |
+
), "control_image_processor is not initialized"
|
96 |
+
|
97 |
+
image = self.control_image_processor.preprocess(
|
98 |
+
image, height=height, width=width
|
99 |
+
).to(dtype=torch.float32)
|
100 |
+
image_batch_size = image.shape[0]
|
101 |
+
|
102 |
+
if image_batch_size == 1:
|
103 |
+
repeat_by = batch_size
|
104 |
+
else:
|
105 |
+
# image batch size is the same as prompt batch size
|
106 |
+
repeat_by = num_images_per_prompt # always 1 for control image
|
107 |
+
|
108 |
+
image = image.repeat_interleave(repeat_by, dim=0)
|
109 |
+
|
110 |
+
image = image.to(device=device, dtype=dtype)
|
111 |
+
|
112 |
+
if do_classifier_free_guidance:
|
113 |
+
image = torch.cat([image] * 2)
|
114 |
+
|
115 |
+
return image
|
116 |
+
|
117 |
+
@torch.no_grad()
|
118 |
+
def __call__(
|
119 |
+
self,
|
120 |
+
prompt: Union[str, List[str]] = None,
|
121 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
122 |
+
height: Optional[int] = None,
|
123 |
+
width: Optional[int] = None,
|
124 |
+
num_inference_steps: int = 50,
|
125 |
+
timesteps: List[int] = None,
|
126 |
+
denoising_end: Optional[float] = None,
|
127 |
+
guidance_scale: float = 5.0,
|
128 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
129 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
130 |
+
num_images_per_prompt: Optional[int] = 1,
|
131 |
+
eta: float = 0.0,
|
132 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
133 |
+
latents: Optional[torch.FloatTensor] = None,
|
134 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
135 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
136 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
137 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
138 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
139 |
+
ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
|
140 |
+
output_type: Optional[str] = "pil",
|
141 |
+
return_dict: bool = True,
|
142 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
143 |
+
guidance_rescale: float = 0.0,
|
144 |
+
original_size: Optional[Tuple[int, int]] = None,
|
145 |
+
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
146 |
+
target_size: Optional[Tuple[int, int]] = None,
|
147 |
+
negative_original_size: Optional[Tuple[int, int]] = None,
|
148 |
+
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
149 |
+
negative_target_size: Optional[Tuple[int, int]] = None,
|
150 |
+
clip_skip: Optional[int] = None,
|
151 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
152 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
153 |
+
# NEW
|
154 |
+
mv_scale: float = 1.0,
|
155 |
+
# Camera or geometry condition
|
156 |
+
control_image: Optional[PipelineImageInput] = None,
|
157 |
+
control_conditioning_scale: Optional[float] = 1.0,
|
158 |
+
control_conditioning_factor: float = 1.0,
|
159 |
+
# Optional. controlnet
|
160 |
+
controlnet_image: Optional[PipelineImageInput] = None,
|
161 |
+
controlnet_conditioning_scale: Optional[float] = 1.0,
|
162 |
+
**kwargs,
|
163 |
+
):
|
164 |
+
r"""
|
165 |
+
Function invoked when calling the pipeline for generation.
|
166 |
+
|
167 |
+
Args:
|
168 |
+
prompt (`str` or `List[str]`, *optional*):
|
169 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
170 |
+
instead.
|
171 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
172 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
173 |
+
used in both text-encoders
|
174 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
175 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
176 |
+
Anything below 512 pixels won't work well for
|
177 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
178 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
179 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
180 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
181 |
+
Anything below 512 pixels won't work well for
|
182 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
183 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
184 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
185 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
186 |
+
expense of slower inference.
|
187 |
+
timesteps (`List[int]`, *optional*):
|
188 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
189 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
190 |
+
passed will be used. Must be in descending order.
|
191 |
+
denoising_end (`float`, *optional*):
|
192 |
+
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
|
193 |
+
completed before it is intentionally prematurely terminated. As a result, the returned sample will
|
194 |
+
still retain a substantial amount of noise as determined by the discrete timesteps selected by the
|
195 |
+
scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
|
196 |
+
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
|
197 |
+
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
|
198 |
+
guidance_scale (`float`, *optional*, defaults to 5.0):
|
199 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
200 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
201 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
202 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
203 |
+
usually at the expense of lower image quality.
|
204 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
205 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
206 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
207 |
+
less than `1`).
|
208 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
209 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
210 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
|
211 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
212 |
+
The number of images to generate per prompt.
|
213 |
+
eta (`float`, *optional*, defaults to 0.0):
|
214 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
215 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
216 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
217 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
218 |
+
to make generation deterministic.
|
219 |
+
latents (`torch.FloatTensor`, *optional*):
|
220 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
221 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
222 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
223 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
224 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
225 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
226 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
227 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
228 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
229 |
+
argument.
|
230 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
231 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
232 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
233 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
234 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
235 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
236 |
+
input argument.
|
237 |
+
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
238 |
+
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
|
239 |
+
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
240 |
+
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
|
241 |
+
contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
|
242 |
+
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
243 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
244 |
+
The output format of the generate image. Choose between
|
245 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
246 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
247 |
+
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
248 |
+
of a plain tuple.
|
249 |
+
cross_attention_kwargs (`dict`, *optional*):
|
250 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
251 |
+
`self.processor` in
|
252 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
253 |
+
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
254 |
+
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
|
255 |
+
Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
|
256 |
+
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
257 |
+
Guidance rescale factor should fix overexposure when using zero terminal SNR.
|
258 |
+
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
259 |
+
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
260 |
+
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
|
261 |
+
explained in section 2.2 of
|
262 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
263 |
+
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
264 |
+
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
|
265 |
+
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
|
266 |
+
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
|
267 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
268 |
+
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
269 |
+
For most cases, `target_size` should be set to the desired height and width of the generated image. If
|
270 |
+
not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
|
271 |
+
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
272 |
+
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
273 |
+
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
|
274 |
+
micro-conditioning as explained in section 2.2 of
|
275 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
276 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
277 |
+
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
278 |
+
To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
|
279 |
+
micro-conditioning as explained in section 2.2 of
|
280 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
281 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
282 |
+
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
283 |
+
To negatively condition the generation process based on a target image resolution. It should be as same
|
284 |
+
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
|
285 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
286 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
287 |
+
callback_on_step_end (`Callable`, *optional*):
|
288 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
289 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
290 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
291 |
+
`callback_on_step_end_tensor_inputs`.
|
292 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
293 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
294 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
295 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
296 |
+
|
297 |
+
Examples:
|
298 |
+
|
299 |
+
Returns:
|
300 |
+
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
|
301 |
+
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
|
302 |
+
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
303 |
+
"""
|
304 |
+
|
305 |
+
callback = kwargs.pop("callback", None)
|
306 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
307 |
+
|
308 |
+
if callback is not None:
|
309 |
+
deprecate(
|
310 |
+
"callback",
|
311 |
+
"1.0.0",
|
312 |
+
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
313 |
+
)
|
314 |
+
if callback_steps is not None:
|
315 |
+
deprecate(
|
316 |
+
"callback_steps",
|
317 |
+
"1.0.0",
|
318 |
+
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
319 |
+
)
|
320 |
+
|
321 |
+
# 0. Default height and width to unet
|
322 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
323 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
324 |
+
|
325 |
+
original_size = original_size or (height, width)
|
326 |
+
target_size = target_size or (height, width)
|
327 |
+
|
328 |
+
# 1. Check inputs. Raise error if not correct
|
329 |
+
self.check_inputs(
|
330 |
+
prompt,
|
331 |
+
prompt_2,
|
332 |
+
height,
|
333 |
+
width,
|
334 |
+
callback_steps,
|
335 |
+
negative_prompt,
|
336 |
+
negative_prompt_2,
|
337 |
+
prompt_embeds,
|
338 |
+
negative_prompt_embeds,
|
339 |
+
pooled_prompt_embeds,
|
340 |
+
negative_pooled_prompt_embeds,
|
341 |
+
ip_adapter_image,
|
342 |
+
ip_adapter_image_embeds,
|
343 |
+
callback_on_step_end_tensor_inputs,
|
344 |
+
)
|
345 |
+
|
346 |
+
self._guidance_scale = guidance_scale
|
347 |
+
self._guidance_rescale = guidance_rescale
|
348 |
+
self._clip_skip = clip_skip
|
349 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
350 |
+
self._denoising_end = denoising_end
|
351 |
+
self._interrupt = False
|
352 |
+
|
353 |
+
# 2. Define call parameters
|
354 |
+
if prompt is not None and isinstance(prompt, str):
|
355 |
+
batch_size = 1
|
356 |
+
elif prompt is not None and isinstance(prompt, list):
|
357 |
+
batch_size = len(prompt)
|
358 |
+
else:
|
359 |
+
batch_size = prompt_embeds.shape[0]
|
360 |
+
|
361 |
+
device = self._execution_device
|
362 |
+
|
363 |
+
# 3. Encode input prompt
|
364 |
+
lora_scale = (
|
365 |
+
self.cross_attention_kwargs.get("scale", None)
|
366 |
+
if self.cross_attention_kwargs is not None
|
367 |
+
else None
|
368 |
+
)
|
369 |
+
|
370 |
+
(
|
371 |
+
prompt_embeds,
|
372 |
+
negative_prompt_embeds,
|
373 |
+
pooled_prompt_embeds,
|
374 |
+
negative_pooled_prompt_embeds,
|
375 |
+
) = self.encode_prompt(
|
376 |
+
prompt=prompt,
|
377 |
+
prompt_2=prompt_2,
|
378 |
+
device=device,
|
379 |
+
num_images_per_prompt=num_images_per_prompt,
|
380 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
381 |
+
negative_prompt=negative_prompt,
|
382 |
+
negative_prompt_2=negative_prompt_2,
|
383 |
+
prompt_embeds=prompt_embeds,
|
384 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
385 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
386 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
387 |
+
lora_scale=lora_scale,
|
388 |
+
clip_skip=self.clip_skip,
|
389 |
+
)
|
390 |
+
|
391 |
+
# 4. Prepare timesteps
|
392 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
393 |
+
self.scheduler, num_inference_steps, device, timesteps
|
394 |
+
)
|
395 |
+
|
396 |
+
# 5. Prepare latent variables
|
397 |
+
num_channels_latents = self.unet.config.in_channels
|
398 |
+
latents = self.prepare_latents(
|
399 |
+
batch_size * num_images_per_prompt,
|
400 |
+
num_channels_latents,
|
401 |
+
height,
|
402 |
+
width,
|
403 |
+
prompt_embeds.dtype,
|
404 |
+
device,
|
405 |
+
generator,
|
406 |
+
latents,
|
407 |
+
)
|
408 |
+
|
409 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
410 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
411 |
+
|
412 |
+
# 7. Prepare added time ids & embeddings
|
413 |
+
add_text_embeds = pooled_prompt_embeds
|
414 |
+
if self.text_encoder_2 is None:
|
415 |
+
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
416 |
+
else:
|
417 |
+
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
418 |
+
|
419 |
+
add_time_ids = self._get_add_time_ids(
|
420 |
+
original_size,
|
421 |
+
crops_coords_top_left,
|
422 |
+
target_size,
|
423 |
+
dtype=prompt_embeds.dtype,
|
424 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
425 |
+
)
|
426 |
+
if negative_original_size is not None and negative_target_size is not None:
|
427 |
+
negative_add_time_ids = self._get_add_time_ids(
|
428 |
+
negative_original_size,
|
429 |
+
negative_crops_coords_top_left,
|
430 |
+
negative_target_size,
|
431 |
+
dtype=prompt_embeds.dtype,
|
432 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
433 |
+
)
|
434 |
+
else:
|
435 |
+
negative_add_time_ids = add_time_ids
|
436 |
+
|
437 |
+
if self.do_classifier_free_guidance:
|
438 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
439 |
+
add_text_embeds = torch.cat(
|
440 |
+
[negative_pooled_prompt_embeds, add_text_embeds], dim=0
|
441 |
+
)
|
442 |
+
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
443 |
+
|
444 |
+
prompt_embeds = prompt_embeds.to(device)
|
445 |
+
add_text_embeds = add_text_embeds.to(device)
|
446 |
+
add_time_ids = add_time_ids.to(device).repeat(
|
447 |
+
batch_size * num_images_per_prompt, 1
|
448 |
+
)
|
449 |
+
|
450 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
451 |
+
image_embeds = self.prepare_ip_adapter_image_embeds(
|
452 |
+
ip_adapter_image,
|
453 |
+
ip_adapter_image_embeds,
|
454 |
+
device,
|
455 |
+
batch_size * num_images_per_prompt,
|
456 |
+
self.do_classifier_free_guidance,
|
457 |
+
)
|
458 |
+
|
459 |
+
# Preprocess control image
|
460 |
+
control_image_feature = self.prepare_control_image(
|
461 |
+
image=control_image,
|
462 |
+
width=width,
|
463 |
+
height=height,
|
464 |
+
batch_size=batch_size * num_images_per_prompt,
|
465 |
+
num_images_per_prompt=1, # NOTE: always 1 for control images
|
466 |
+
device=device,
|
467 |
+
dtype=latents.dtype,
|
468 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
469 |
+
)
|
470 |
+
control_image_feature = control_image_feature.to(
|
471 |
+
device=device, dtype=latents.dtype
|
472 |
+
)
|
473 |
+
|
474 |
+
adapter_state = self.cond_encoder(control_image_feature)
|
475 |
+
for i, state in enumerate(adapter_state):
|
476 |
+
adapter_state[i] = state * control_conditioning_scale
|
477 |
+
|
478 |
+
# Preprocess controlnet image if provided
|
479 |
+
do_controlnet = controlnet_image is not None and hasattr(self, "controlnet")
|
480 |
+
if do_controlnet:
|
481 |
+
controlnet_image = self.prepare_control_image(
|
482 |
+
image=controlnet_image,
|
483 |
+
width=width,
|
484 |
+
height=height,
|
485 |
+
batch_size=batch_size * num_images_per_prompt,
|
486 |
+
num_images_per_prompt=1, # NOTE: always 1 for control images
|
487 |
+
device=device,
|
488 |
+
dtype=latents.dtype,
|
489 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
490 |
+
)
|
491 |
+
controlnet_image = controlnet_image.to(device=device, dtype=latents.dtype)
|
492 |
+
|
493 |
+
# 8. Denoising loop
|
494 |
+
num_warmup_steps = max(
|
495 |
+
len(timesteps) - num_inference_steps * self.scheduler.order, 0
|
496 |
+
)
|
497 |
+
|
498 |
+
# 8.1 Apply denoising_end
|
499 |
+
if (
|
500 |
+
self.denoising_end is not None
|
501 |
+
and isinstance(self.denoising_end, float)
|
502 |
+
and self.denoising_end > 0
|
503 |
+
and self.denoising_end < 1
|
504 |
+
):
|
505 |
+
discrete_timestep_cutoff = int(
|
506 |
+
round(
|
507 |
+
self.scheduler.config.num_train_timesteps
|
508 |
+
- (self.denoising_end * self.scheduler.config.num_train_timesteps)
|
509 |
+
)
|
510 |
+
)
|
511 |
+
num_inference_steps = len(
|
512 |
+
list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))
|
513 |
+
)
|
514 |
+
timesteps = timesteps[:num_inference_steps]
|
515 |
+
|
516 |
+
# 9. Optionally get Guidance Scale Embedding
|
517 |
+
timestep_cond = None
|
518 |
+
if self.unet.config.time_cond_proj_dim is not None:
|
519 |
+
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
|
520 |
+
batch_size * num_images_per_prompt
|
521 |
+
)
|
522 |
+
timestep_cond = self.get_guidance_scale_embedding(
|
523 |
+
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
524 |
+
).to(device=device, dtype=latents.dtype)
|
525 |
+
|
526 |
+
self._num_timesteps = len(timesteps)
|
527 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
528 |
+
for i, t in enumerate(timesteps):
|
529 |
+
if self.interrupt:
|
530 |
+
continue
|
531 |
+
|
532 |
+
# expand the latents if we are doing classifier free guidance
|
533 |
+
latent_model_input = (
|
534 |
+
torch.cat([latents] * 2)
|
535 |
+
if self.do_classifier_free_guidance
|
536 |
+
else latents
|
537 |
+
)
|
538 |
+
|
539 |
+
latent_model_input = self.scheduler.scale_model_input(
|
540 |
+
latent_model_input, t
|
541 |
+
)
|
542 |
+
|
543 |
+
added_cond_kwargs = {
|
544 |
+
"text_embeds": add_text_embeds,
|
545 |
+
"time_ids": add_time_ids,
|
546 |
+
}
|
547 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
548 |
+
added_cond_kwargs["image_embeds"] = image_embeds
|
549 |
+
|
550 |
+
if i < int(num_inference_steps * control_conditioning_factor):
|
551 |
+
down_intrablock_additional_residuals = [
|
552 |
+
state.clone() for state in adapter_state
|
553 |
+
]
|
554 |
+
else:
|
555 |
+
down_intrablock_additional_residuals = None
|
556 |
+
|
557 |
+
unet_add_kwargs = {}
|
558 |
+
|
559 |
+
# Do controlnet if provided
|
560 |
+
if do_controlnet:
|
561 |
+
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
562 |
+
latent_model_input,
|
563 |
+
t,
|
564 |
+
encoder_hidden_states=prompt_embeds,
|
565 |
+
controlnet_cond=controlnet_image,
|
566 |
+
conditioning_scale=controlnet_conditioning_scale,
|
567 |
+
guess_mode=False,
|
568 |
+
added_cond_kwargs=added_cond_kwargs,
|
569 |
+
return_dict=False,
|
570 |
+
)
|
571 |
+
unet_add_kwargs.update(
|
572 |
+
{
|
573 |
+
"down_block_additional_residuals": down_block_res_samples,
|
574 |
+
"mid_block_additional_residual": mid_block_res_sample,
|
575 |
+
}
|
576 |
+
)
|
577 |
+
|
578 |
+
# predict the noise residual
|
579 |
+
noise_pred = self.unet(
|
580 |
+
latent_model_input,
|
581 |
+
t,
|
582 |
+
encoder_hidden_states=prompt_embeds,
|
583 |
+
timestep_cond=timestep_cond,
|
584 |
+
cross_attention_kwargs={
|
585 |
+
"mv_scale": mv_scale,
|
586 |
+
"num_views": num_images_per_prompt,
|
587 |
+
**(self.cross_attention_kwargs or {}),
|
588 |
+
},
|
589 |
+
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
|
590 |
+
added_cond_kwargs=added_cond_kwargs,
|
591 |
+
return_dict=False,
|
592 |
+
**unet_add_kwargs,
|
593 |
+
)[0]
|
594 |
+
|
595 |
+
# perform guidance
|
596 |
+
if self.do_classifier_free_guidance:
|
597 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
598 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (
|
599 |
+
noise_pred_text - noise_pred_uncond
|
600 |
+
)
|
601 |
+
|
602 |
+
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
|
603 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
604 |
+
noise_pred = rescale_noise_cfg(
|
605 |
+
noise_pred,
|
606 |
+
noise_pred_text,
|
607 |
+
guidance_rescale=self.guidance_rescale,
|
608 |
+
)
|
609 |
+
|
610 |
+
# compute the previous noisy sample x_t -> x_t-1
|
611 |
+
latents_dtype = latents.dtype
|
612 |
+
latents = self.scheduler.step(
|
613 |
+
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
|
614 |
+
)[0]
|
615 |
+
if latents.dtype != latents_dtype:
|
616 |
+
if torch.backends.mps.is_available():
|
617 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
618 |
+
latents = latents.to(latents_dtype)
|
619 |
+
|
620 |
+
if callback_on_step_end is not None:
|
621 |
+
callback_kwargs = {}
|
622 |
+
for k in callback_on_step_end_tensor_inputs:
|
623 |
+
callback_kwargs[k] = locals()[k]
|
624 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
625 |
+
|
626 |
+
latents = callback_outputs.pop("latents", latents)
|
627 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
628 |
+
negative_prompt_embeds = callback_outputs.pop(
|
629 |
+
"negative_prompt_embeds", negative_prompt_embeds
|
630 |
+
)
|
631 |
+
add_text_embeds = callback_outputs.pop(
|
632 |
+
"add_text_embeds", add_text_embeds
|
633 |
+
)
|
634 |
+
negative_pooled_prompt_embeds = callback_outputs.pop(
|
635 |
+
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
|
636 |
+
)
|
637 |
+
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
|
638 |
+
negative_add_time_ids = callback_outputs.pop(
|
639 |
+
"negative_add_time_ids", negative_add_time_ids
|
640 |
+
)
|
641 |
+
|
642 |
+
# call the callback, if provided
|
643 |
+
if i == len(timesteps) - 1 or (
|
644 |
+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
645 |
+
):
|
646 |
+
progress_bar.update()
|
647 |
+
if callback is not None and i % callback_steps == 0:
|
648 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
649 |
+
callback(step_idx, t, latents)
|
650 |
+
|
651 |
+
if not output_type == "latent":
|
652 |
+
# make sure the VAE is in float32 mode, as it overflows in float16
|
653 |
+
needs_upcasting = (
|
654 |
+
self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
655 |
+
)
|
656 |
+
|
657 |
+
if needs_upcasting:
|
658 |
+
self.upcast_vae()
|
659 |
+
latents = latents.to(
|
660 |
+
next(iter(self.vae.post_quant_conv.parameters())).dtype
|
661 |
+
)
|
662 |
+
elif latents.dtype != self.vae.dtype:
|
663 |
+
if torch.backends.mps.is_available():
|
664 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
665 |
+
self.vae = self.vae.to(latents.dtype)
|
666 |
+
|
667 |
+
# unscale/denormalize the latents
|
668 |
+
# denormalize with the mean and std if available and not None
|
669 |
+
has_latents_mean = (
|
670 |
+
hasattr(self.vae.config, "latents_mean")
|
671 |
+
and self.vae.config.latents_mean is not None
|
672 |
+
)
|
673 |
+
has_latents_std = (
|
674 |
+
hasattr(self.vae.config, "latents_std")
|
675 |
+
and self.vae.config.latents_std is not None
|
676 |
+
)
|
677 |
+
if has_latents_mean and has_latents_std:
|
678 |
+
latents_mean = (
|
679 |
+
torch.tensor(self.vae.config.latents_mean)
|
680 |
+
.view(1, 4, 1, 1)
|
681 |
+
.to(latents.device, latents.dtype)
|
682 |
+
)
|
683 |
+
latents_std = (
|
684 |
+
torch.tensor(self.vae.config.latents_std)
|
685 |
+
.view(1, 4, 1, 1)
|
686 |
+
.to(latents.device, latents.dtype)
|
687 |
+
)
|
688 |
+
latents = (
|
689 |
+
latents * latents_std / self.vae.config.scaling_factor
|
690 |
+
+ latents_mean
|
691 |
+
)
|
692 |
+
else:
|
693 |
+
latents = latents / self.vae.config.scaling_factor
|
694 |
+
|
695 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
696 |
+
|
697 |
+
# cast back to fp16 if needed
|
698 |
+
if needs_upcasting:
|
699 |
+
self.vae.to(dtype=torch.float16)
|
700 |
+
else:
|
701 |
+
image = latents
|
702 |
+
|
703 |
+
if not output_type == "latent":
|
704 |
+
# apply watermark if available
|
705 |
+
if self.watermark is not None:
|
706 |
+
image = self.watermark.apply_watermark(image)
|
707 |
+
|
708 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
709 |
+
|
710 |
+
# Offload all models
|
711 |
+
self.maybe_free_model_hooks()
|
712 |
+
|
713 |
+
if not return_dict:
|
714 |
+
return (image,)
|
715 |
+
|
716 |
+
return StableDiffusionXLPipelineOutput(images=image)
|
717 |
+
|
718 |
+
### NEW: adapters ###
|
719 |
+
def _init_custom_adapter(
|
720 |
+
self,
|
721 |
+
# Multi-view adapter
|
722 |
+
num_views: int = 1,
|
723 |
+
self_attn_processor: Any = DecoupledMVRowSelfAttnProcessor2_0,
|
724 |
+
# Condition encoder
|
725 |
+
cond_in_channels: int = 6,
|
726 |
+
# For training
|
727 |
+
copy_attn_weights: bool = True,
|
728 |
+
zero_init_module_keys: List[str] = [],
|
729 |
+
):
|
730 |
+
# Condition encoder
|
731 |
+
self.cond_encoder = T2IAdapter(
|
732 |
+
in_channels=cond_in_channels,
|
733 |
+
channels=(320, 640, 1280, 1280),
|
734 |
+
num_res_blocks=2,
|
735 |
+
downscale_factor=16,
|
736 |
+
adapter_type="full_adapter_xl",
|
737 |
+
)
|
738 |
+
|
739 |
+
# set custom attn processor for multi-view attention
|
740 |
+
self.unet: UNet2DConditionModel
|
741 |
+
set_unet_2d_condition_attn_processor(
|
742 |
+
self.unet,
|
743 |
+
set_self_attn_proc_func=lambda name, hs, cad, ap: self_attn_processor(
|
744 |
+
query_dim=hs,
|
745 |
+
inner_dim=hs,
|
746 |
+
num_views=num_views,
|
747 |
+
name=name,
|
748 |
+
use_mv=True,
|
749 |
+
use_ref=False,
|
750 |
+
),
|
751 |
+
set_cross_attn_proc_func=lambda name, hs, cad, ap: self_attn_processor(
|
752 |
+
query_dim=hs,
|
753 |
+
inner_dim=hs,
|
754 |
+
num_views=num_views,
|
755 |
+
name=name,
|
756 |
+
use_mv=False,
|
757 |
+
use_ref=False,
|
758 |
+
),
|
759 |
+
)
|
760 |
+
|
761 |
+
# copy decoupled attention weights from original unet
|
762 |
+
if copy_attn_weights:
|
763 |
+
state_dict = self.unet.state_dict()
|
764 |
+
for key in state_dict.keys():
|
765 |
+
if "_mv" in key:
|
766 |
+
compatible_key = key.replace("_mv", "").replace("processor.", "")
|
767 |
+
else:
|
768 |
+
compatible_key = key
|
769 |
+
|
770 |
+
is_zero_init_key = any([k in key for k in zero_init_module_keys])
|
771 |
+
if is_zero_init_key:
|
772 |
+
state_dict[key] = torch.zeros_like(state_dict[compatible_key])
|
773 |
+
else:
|
774 |
+
state_dict[key] = state_dict[compatible_key].clone()
|
775 |
+
self.unet.load_state_dict(state_dict)
|
776 |
+
|
777 |
+
def _load_custom_adapter(self, state_dict):
|
778 |
+
self.unet.load_state_dict(state_dict, strict=False)
|
779 |
+
self.cond_encoder.load_state_dict(state_dict, strict=False)
|
780 |
+
|
781 |
+
def _save_custom_adapter(
|
782 |
+
self,
|
783 |
+
include_keys: Optional[List[str]] = None,
|
784 |
+
exclude_keys: Optional[List[str]] = None,
|
785 |
+
):
|
786 |
+
def include_fn(k):
|
787 |
+
is_included = False
|
788 |
+
|
789 |
+
if include_keys is not None:
|
790 |
+
is_included = is_included or any([key in k for key in include_keys])
|
791 |
+
if exclude_keys is not None:
|
792 |
+
is_included = is_included and not any(
|
793 |
+
[key in k for key in exclude_keys]
|
794 |
+
)
|
795 |
+
|
796 |
+
return is_included
|
797 |
+
|
798 |
+
state_dict = {k: v for k, v in self.unet.state_dict().items() if include_fn(k)}
|
799 |
+
state_dict.update(self.cond_encoder.state_dict())
|
800 |
+
|
801 |
+
return state_dict
|
mvadapter/schedulers/scheduler_utils.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def get_sigmas(noise_scheduler, timesteps, n_dim=4, dtype=torch.float32, device=None):
|
5 |
+
sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
|
6 |
+
schedule_timesteps = noise_scheduler.timesteps.to(device)
|
7 |
+
timesteps = timesteps.to(device)
|
8 |
+
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
9 |
+
sigma = sigmas[step_indices].flatten()
|
10 |
+
while len(sigma.shape) < n_dim:
|
11 |
+
sigma = sigma.unsqueeze(-1)
|
12 |
+
return sigma
|
13 |
+
|
14 |
+
|
15 |
+
def SNR_to_betas(snr):
|
16 |
+
"""
|
17 |
+
Converts SNR to betas
|
18 |
+
"""
|
19 |
+
# alphas_cumprod = pass
|
20 |
+
# snr = (alpha / ) ** 2
|
21 |
+
# alpha_t^2 / (1 - alpha_t^2) = snr
|
22 |
+
alpha_t = (snr / (1 + snr)) ** 0.5
|
23 |
+
alphas_cumprod = alpha_t**2
|
24 |
+
alphas = alphas_cumprod / torch.cat(
|
25 |
+
[torch.ones(1, device=snr.device), alphas_cumprod[:-1]]
|
26 |
+
)
|
27 |
+
betas = 1 - alphas
|
28 |
+
return betas
|
29 |
+
|
30 |
+
|
31 |
+
def compute_snr(timesteps, noise_scheduler):
|
32 |
+
"""
|
33 |
+
Computes SNR as per Min-SNR-Diffusion-Training/guided_diffusion/gaussian_diffusion.py at 521b624bd70c67cee4bdf49225915f5
|
34 |
+
"""
|
35 |
+
alphas_cumprod = noise_scheduler.alphas_cumprod
|
36 |
+
sqrt_alphas_cumprod = alphas_cumprod**0.5
|
37 |
+
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
|
38 |
+
|
39 |
+
# Expand the tensors.
|
40 |
+
# Adapted from Min-SNR-Diffusion-Training/guided_diffusion/gaussian_diffusion.py at 521b624bd70c67cee4bdf49225915f5
|
41 |
+
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
|
42 |
+
timesteps
|
43 |
+
].float()
|
44 |
+
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
|
45 |
+
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
|
46 |
+
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
|
47 |
+
|
48 |
+
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
|
49 |
+
device=timesteps.device
|
50 |
+
)[timesteps].float()
|
51 |
+
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
|
52 |
+
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
|
53 |
+
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
|
54 |
+
|
55 |
+
# Compute SNR.
|
56 |
+
snr = (alpha / sigma) ** 2
|
57 |
+
return snr
|
58 |
+
|
59 |
+
|
60 |
+
def compute_alpha(timesteps, noise_scheduler):
|
61 |
+
alphas_cumprod = noise_scheduler.alphas_cumprod
|
62 |
+
sqrt_alphas_cumprod = alphas_cumprod**0.5
|
63 |
+
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
|
64 |
+
timesteps
|
65 |
+
].float()
|
66 |
+
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
|
67 |
+
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
|
68 |
+
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
|
69 |
+
|
70 |
+
return alpha
|
mvadapter/schedulers/scheduling_shift_snr.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from .scheduler_utils import SNR_to_betas, compute_snr
|
6 |
+
|
7 |
+
|
8 |
+
class ShiftSNRScheduler:
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
noise_scheduler: Any,
|
12 |
+
timesteps: Any,
|
13 |
+
shift_scale: float,
|
14 |
+
scheduler_class: Any,
|
15 |
+
):
|
16 |
+
self.noise_scheduler = noise_scheduler
|
17 |
+
self.timesteps = timesteps
|
18 |
+
self.shift_scale = shift_scale
|
19 |
+
self.scheduler_class = scheduler_class
|
20 |
+
|
21 |
+
def _get_shift_scheduler(self):
|
22 |
+
"""
|
23 |
+
Prepare scheduler for shifted betas.
|
24 |
+
|
25 |
+
:return: A scheduler object configured with shifted betas
|
26 |
+
"""
|
27 |
+
snr = compute_snr(self.timesteps, self.noise_scheduler)
|
28 |
+
shifted_betas = SNR_to_betas(snr / self.shift_scale)
|
29 |
+
|
30 |
+
return self.scheduler_class.from_config(
|
31 |
+
self.noise_scheduler.config, trained_betas=shifted_betas.numpy()
|
32 |
+
)
|
33 |
+
|
34 |
+
def _get_interpolated_shift_scheduler(self):
|
35 |
+
"""
|
36 |
+
Prepare scheduler for shifted betas and interpolate with the original betas in log space.
|
37 |
+
|
38 |
+
:return: A scheduler object configured with interpolated shifted betas
|
39 |
+
"""
|
40 |
+
snr = compute_snr(self.timesteps, self.noise_scheduler)
|
41 |
+
shifted_snr = snr / self.shift_scale
|
42 |
+
|
43 |
+
weighting = self.timesteps.float() / (
|
44 |
+
self.noise_scheduler.config.num_train_timesteps - 1
|
45 |
+
)
|
46 |
+
interpolated_snr = torch.exp(
|
47 |
+
torch.log(snr) * (1 - weighting) + torch.log(shifted_snr) * weighting
|
48 |
+
)
|
49 |
+
|
50 |
+
shifted_betas = SNR_to_betas(interpolated_snr)
|
51 |
+
|
52 |
+
return self.scheduler_class.from_config(
|
53 |
+
self.noise_scheduler.config, trained_betas=shifted_betas.numpy()
|
54 |
+
)
|
55 |
+
|
56 |
+
@classmethod
|
57 |
+
def from_scheduler(
|
58 |
+
cls,
|
59 |
+
noise_scheduler: Any,
|
60 |
+
shift_mode: str = "default",
|
61 |
+
timesteps: Any = None,
|
62 |
+
shift_scale: float = 1.0,
|
63 |
+
scheduler_class: Any = None,
|
64 |
+
):
|
65 |
+
# Check input
|
66 |
+
if timesteps is None:
|
67 |
+
timesteps = torch.arange(0, noise_scheduler.config.num_train_timesteps)
|
68 |
+
if scheduler_class is None:
|
69 |
+
scheduler_class = noise_scheduler.__class__
|
70 |
+
|
71 |
+
# Create scheduler
|
72 |
+
shift_scheduler = cls(
|
73 |
+
noise_scheduler=noise_scheduler,
|
74 |
+
timesteps=timesteps,
|
75 |
+
shift_scale=shift_scale,
|
76 |
+
scheduler_class=scheduler_class,
|
77 |
+
)
|
78 |
+
|
79 |
+
if shift_mode == "default":
|
80 |
+
return shift_scheduler._get_shift_scheduler()
|
81 |
+
elif shift_mode == "interpolated":
|
82 |
+
return shift_scheduler._get_interpolated_shift_scheduler()
|
83 |
+
else:
|
84 |
+
raise ValueError(f"Unknown shift_mode: {shift_mode}")
|
85 |
+
|
86 |
+
|
87 |
+
if __name__ == "__main__":
|
88 |
+
"""
|
89 |
+
Compare the alpha values for different noise schedulers.
|
90 |
+
"""
|
91 |
+
import matplotlib.pyplot as plt
|
92 |
+
from diffusers import DDPMScheduler
|
93 |
+
|
94 |
+
from .scheduler_utils import compute_alpha
|
95 |
+
|
96 |
+
# Base
|
97 |
+
timesteps = torch.arange(0, 1000)
|
98 |
+
noise_scheduler_base = DDPMScheduler.from_pretrained(
|
99 |
+
"runwayml/stable-diffusion-v1-5", subfolder="scheduler"
|
100 |
+
)
|
101 |
+
alpha = compute_alpha(timesteps, noise_scheduler_base)
|
102 |
+
plt.plot(timesteps.numpy(), alpha.numpy(), label="Base")
|
103 |
+
|
104 |
+
# Kolors
|
105 |
+
num_train_timesteps_ = 1100
|
106 |
+
timesteps_ = torch.arange(0, num_train_timesteps_)
|
107 |
+
noise_kwargs = {"beta_end": 0.014, "num_train_timesteps": num_train_timesteps_}
|
108 |
+
noise_scheduler_kolors = DDPMScheduler.from_config(
|
109 |
+
noise_scheduler_base.config, **noise_kwargs
|
110 |
+
)
|
111 |
+
alpha = compute_alpha(timesteps_, noise_scheduler_kolors)
|
112 |
+
plt.plot(timesteps_.numpy(), alpha.numpy(), label="Kolors")
|
113 |
+
|
114 |
+
# Shift betas
|
115 |
+
shift_scale = 8.0
|
116 |
+
noise_scheduler_shift = ShiftSNRScheduler.from_scheduler(
|
117 |
+
noise_scheduler_base, shift_mode="default", shift_scale=shift_scale
|
118 |
+
)
|
119 |
+
alpha = compute_alpha(timesteps, noise_scheduler_shift)
|
120 |
+
plt.plot(timesteps.numpy(), alpha.numpy(), label="Shift Noise (scale 8.0)")
|
121 |
+
|
122 |
+
# Shift betas (interpolated)
|
123 |
+
noise_scheduler_inter = ShiftSNRScheduler.from_scheduler(
|
124 |
+
noise_scheduler_base, shift_mode="interpolated", shift_scale=shift_scale
|
125 |
+
)
|
126 |
+
alpha = compute_alpha(timesteps, noise_scheduler_inter)
|
127 |
+
plt.plot(timesteps.numpy(), alpha.numpy(), label="Interpolated (scale 8.0)")
|
128 |
+
|
129 |
+
# ZeroSNR
|
130 |
+
noise_scheduler = DDPMScheduler.from_config(
|
131 |
+
noise_scheduler_base.config, rescale_betas_zero_snr=True
|
132 |
+
)
|
133 |
+
alpha = compute_alpha(timesteps, noise_scheduler)
|
134 |
+
plt.plot(timesteps.numpy(), alpha.numpy(), label="ZeroSNR")
|
135 |
+
|
136 |
+
plt.legend()
|
137 |
+
plt.grid()
|
138 |
+
plt.savefig("check_alpha.png")
|
mvadapter/utils/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .camera import get_camera, get_orthogonal_camera
|
2 |
+
from .geometry import get_plucker_embeds_from_cameras_ortho
|
3 |
+
from .saving import make_image_grid, tensor_to_image
|
mvadapter/utils/camera.py
ADDED
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from typing import List, Optional, Union
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import trimesh
|
9 |
+
from PIL import Image
|
10 |
+
from torch import BoolTensor, FloatTensor
|
11 |
+
|
12 |
+
LIST_TYPE = Union[list, np.ndarray, torch.Tensor]
|
13 |
+
|
14 |
+
|
15 |
+
def list_to_pt(
|
16 |
+
x: LIST_TYPE, dtype: Optional[torch.dtype] = None, device: Optional[str] = None
|
17 |
+
) -> torch.Tensor:
|
18 |
+
if isinstance(x, list) or isinstance(x, np.ndarray):
|
19 |
+
return torch.tensor(x, dtype=dtype, device=device)
|
20 |
+
return x.to(dtype=dtype)
|
21 |
+
|
22 |
+
|
23 |
+
def get_c2w(
|
24 |
+
elevation_deg: LIST_TYPE,
|
25 |
+
distance: LIST_TYPE,
|
26 |
+
azimuth_deg: Optional[LIST_TYPE],
|
27 |
+
num_views: Optional[int] = 1,
|
28 |
+
device: Optional[str] = None,
|
29 |
+
) -> torch.FloatTensor:
|
30 |
+
if azimuth_deg is None:
|
31 |
+
assert (
|
32 |
+
num_views is not None
|
33 |
+
), "num_views must be provided if azimuth_deg is None."
|
34 |
+
azimuth_deg = torch.linspace(
|
35 |
+
0, 360, num_views + 1, dtype=torch.float32, device=device
|
36 |
+
)[:-1]
|
37 |
+
else:
|
38 |
+
num_views = len(azimuth_deg)
|
39 |
+
azimuth_deg = list_to_pt(azimuth_deg, dtype=torch.float32, device=device)
|
40 |
+
elevation_deg = list_to_pt(elevation_deg, dtype=torch.float32, device=device)
|
41 |
+
camera_distances = list_to_pt(distance, dtype=torch.float32, device=device)
|
42 |
+
elevation = elevation_deg * math.pi / 180
|
43 |
+
azimuth = azimuth_deg * math.pi / 180
|
44 |
+
camera_positions = torch.stack(
|
45 |
+
[
|
46 |
+
camera_distances * torch.cos(elevation) * torch.cos(azimuth),
|
47 |
+
camera_distances * torch.cos(elevation) * torch.sin(azimuth),
|
48 |
+
camera_distances * torch.sin(elevation),
|
49 |
+
],
|
50 |
+
dim=-1,
|
51 |
+
)
|
52 |
+
center = torch.zeros_like(camera_positions)
|
53 |
+
up = torch.tensor([0, 0, 1], dtype=torch.float32, device=device)[None, :].repeat(
|
54 |
+
num_views, 1
|
55 |
+
)
|
56 |
+
lookat = F.normalize(center - camera_positions, dim=-1)
|
57 |
+
right = F.normalize(torch.cross(lookat, up, dim=-1), dim=-1)
|
58 |
+
up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)
|
59 |
+
c2w3x4 = torch.cat(
|
60 |
+
[torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
|
61 |
+
dim=-1,
|
62 |
+
)
|
63 |
+
c2w = torch.cat([c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1)
|
64 |
+
c2w[:, 3, 3] = 1.0
|
65 |
+
return c2w
|
66 |
+
|
67 |
+
|
68 |
+
def get_projection_matrix(
|
69 |
+
fovy_deg: LIST_TYPE,
|
70 |
+
aspect_wh: float = 1.0,
|
71 |
+
near: float = 0.1,
|
72 |
+
far: float = 100.0,
|
73 |
+
device: Optional[str] = None,
|
74 |
+
) -> torch.FloatTensor:
|
75 |
+
fovy_deg = list_to_pt(fovy_deg, dtype=torch.float32, device=device)
|
76 |
+
batch_size = fovy_deg.shape[0]
|
77 |
+
fovy = fovy_deg * math.pi / 180
|
78 |
+
tan_half_fovy = torch.tan(fovy / 2)
|
79 |
+
projection_matrix = torch.zeros(
|
80 |
+
batch_size, 4, 4, dtype=torch.float32, device=device
|
81 |
+
)
|
82 |
+
projection_matrix[:, 0, 0] = 1 / (aspect_wh * tan_half_fovy)
|
83 |
+
projection_matrix[:, 1, 1] = -1 / tan_half_fovy
|
84 |
+
projection_matrix[:, 2, 2] = -(far + near) / (far - near)
|
85 |
+
projection_matrix[:, 2, 3] = -2 * far * near / (far - near)
|
86 |
+
projection_matrix[:, 3, 2] = -1
|
87 |
+
return projection_matrix
|
88 |
+
|
89 |
+
|
90 |
+
def get_orthogonal_projection_matrix(
|
91 |
+
batch_size: int,
|
92 |
+
left: float,
|
93 |
+
right: float,
|
94 |
+
bottom: float,
|
95 |
+
top: float,
|
96 |
+
near: float = 0.1,
|
97 |
+
far: float = 100.0,
|
98 |
+
device: Optional[str] = None,
|
99 |
+
) -> torch.FloatTensor:
|
100 |
+
projection_matrix = torch.zeros(
|
101 |
+
batch_size, 4, 4, dtype=torch.float32, device=device
|
102 |
+
)
|
103 |
+
projection_matrix[:, 0, 0] = 2 / (right - left)
|
104 |
+
projection_matrix[:, 1, 1] = -2 / (top - bottom)
|
105 |
+
projection_matrix[:, 2, 2] = -2 / (far - near)
|
106 |
+
projection_matrix[:, 0, 3] = -(right + left) / (right - left)
|
107 |
+
projection_matrix[:, 1, 3] = -(top + bottom) / (top - bottom)
|
108 |
+
projection_matrix[:, 2, 3] = -(far + near) / (far - near)
|
109 |
+
projection_matrix[:, 3, 3] = 1
|
110 |
+
return projection_matrix
|
111 |
+
|
112 |
+
|
113 |
+
@dataclass
|
114 |
+
class Camera:
|
115 |
+
c2w: Optional[torch.FloatTensor]
|
116 |
+
w2c: torch.FloatTensor
|
117 |
+
proj_mtx: torch.FloatTensor
|
118 |
+
mvp_mtx: torch.FloatTensor
|
119 |
+
cam_pos: Optional[torch.FloatTensor]
|
120 |
+
|
121 |
+
def __getitem__(self, index):
|
122 |
+
if isinstance(index, int):
|
123 |
+
sl = slice(index, index + 1)
|
124 |
+
elif isinstance(index, slice):
|
125 |
+
sl = index
|
126 |
+
else:
|
127 |
+
raise NotImplementedError
|
128 |
+
|
129 |
+
return Camera(
|
130 |
+
c2w=self.c2w[sl] if self.c2w is not None else None,
|
131 |
+
w2c=self.w2c[sl],
|
132 |
+
proj_mtx=self.proj_mtx[sl],
|
133 |
+
mvp_mtx=self.mvp_mtx[sl],
|
134 |
+
cam_pos=self.cam_pos[sl] if self.cam_pos is not None else None,
|
135 |
+
)
|
136 |
+
|
137 |
+
def to(self, device: Optional[str] = None):
|
138 |
+
if self.c2w is not None:
|
139 |
+
self.c2w = self.c2w.to(device)
|
140 |
+
self.w2c = self.w2c.to(device)
|
141 |
+
self.proj_mtx = self.proj_mtx.to(device)
|
142 |
+
self.mvp_mtx = self.mvp_mtx.to(device)
|
143 |
+
if self.cam_pos is not None:
|
144 |
+
self.cam_pos = self.cam_pos.to(device)
|
145 |
+
|
146 |
+
def __len__(self):
|
147 |
+
return self.c2w.shape[0]
|
148 |
+
|
149 |
+
|
150 |
+
def get_camera(
|
151 |
+
elevation_deg: Optional[LIST_TYPE] = None,
|
152 |
+
distance: Optional[LIST_TYPE] = None,
|
153 |
+
fovy_deg: Optional[LIST_TYPE] = None,
|
154 |
+
azimuth_deg: Optional[LIST_TYPE] = None,
|
155 |
+
num_views: Optional[int] = 1,
|
156 |
+
c2w: Optional[torch.FloatTensor] = None,
|
157 |
+
w2c: Optional[torch.FloatTensor] = None,
|
158 |
+
proj_mtx: Optional[torch.FloatTensor] = None,
|
159 |
+
aspect_wh: float = 1.0,
|
160 |
+
near: float = 0.1,
|
161 |
+
far: float = 100.0,
|
162 |
+
device: Optional[str] = None,
|
163 |
+
):
|
164 |
+
if w2c is None:
|
165 |
+
if c2w is None:
|
166 |
+
c2w = get_c2w(elevation_deg, distance, azimuth_deg, num_views, device)
|
167 |
+
camera_positions = c2w[:, :3, 3]
|
168 |
+
w2c = torch.linalg.inv(c2w)
|
169 |
+
else:
|
170 |
+
camera_positions = None
|
171 |
+
c2w = None
|
172 |
+
if proj_mtx is None:
|
173 |
+
proj_mtx = get_projection_matrix(
|
174 |
+
fovy_deg, aspect_wh=aspect_wh, near=near, far=far, device=device
|
175 |
+
)
|
176 |
+
mvp_mtx = proj_mtx @ w2c
|
177 |
+
return Camera(
|
178 |
+
c2w=c2w, w2c=w2c, proj_mtx=proj_mtx, mvp_mtx=mvp_mtx, cam_pos=camera_positions
|
179 |
+
)
|
180 |
+
|
181 |
+
|
182 |
+
def get_orthogonal_camera(
|
183 |
+
elevation_deg: LIST_TYPE,
|
184 |
+
distance: LIST_TYPE,
|
185 |
+
left: float,
|
186 |
+
right: float,
|
187 |
+
bottom: float,
|
188 |
+
top: float,
|
189 |
+
azimuth_deg: Optional[LIST_TYPE] = None,
|
190 |
+
num_views: Optional[int] = 1,
|
191 |
+
near: float = 0.1,
|
192 |
+
far: float = 100.0,
|
193 |
+
device: Optional[str] = None,
|
194 |
+
):
|
195 |
+
c2w = get_c2w(elevation_deg, distance, azimuth_deg, num_views, device)
|
196 |
+
camera_positions = c2w[:, :3, 3]
|
197 |
+
w2c = torch.linalg.inv(c2w)
|
198 |
+
proj_mtx = get_orthogonal_projection_matrix(
|
199 |
+
batch_size=c2w.shape[0],
|
200 |
+
left=left,
|
201 |
+
right=right,
|
202 |
+
bottom=bottom,
|
203 |
+
top=top,
|
204 |
+
near=near,
|
205 |
+
far=far,
|
206 |
+
device=device,
|
207 |
+
)
|
208 |
+
mvp_mtx = proj_mtx @ w2c
|
209 |
+
return Camera(
|
210 |
+
c2w=c2w, w2c=w2c, proj_mtx=proj_mtx, mvp_mtx=mvp_mtx, cam_pos=camera_positions
|
211 |
+
)
|
mvadapter/utils/geometry.py
ADDED
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional, Tuple
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from torch.nn import functional as F
|
6 |
+
|
7 |
+
|
8 |
+
def get_position_map_from_depth(depth, mask, intrinsics, extrinsics, image_wh=None):
|
9 |
+
"""Compute the position map from the depth map and the camera parameters for a batch of views.
|
10 |
+
|
11 |
+
Args:
|
12 |
+
depth (torch.Tensor): The depth maps with the shape (B, H, W, 1).
|
13 |
+
mask (torch.Tensor): The masks with the shape (B, H, W, 1).
|
14 |
+
intrinsics (torch.Tensor): The camera intrinsics matrices with the shape (B, 3, 3).
|
15 |
+
extrinsics (torch.Tensor): The camera extrinsics matrices with the shape (B, 4, 4).
|
16 |
+
image_wh (Tuple[int, int]): The image width and height.
|
17 |
+
|
18 |
+
Returns:
|
19 |
+
torch.Tensor: The position maps with the shape (B, H, W, 3).
|
20 |
+
"""
|
21 |
+
if image_wh is None:
|
22 |
+
image_wh = depth.shape[2], depth.shape[1]
|
23 |
+
|
24 |
+
B, H, W, _ = depth.shape
|
25 |
+
depth = depth.squeeze(-1)
|
26 |
+
|
27 |
+
u_coord, v_coord = torch.meshgrid(
|
28 |
+
torch.arange(image_wh[0]), torch.arange(image_wh[1]), indexing="xy"
|
29 |
+
)
|
30 |
+
u_coord = u_coord.type_as(depth).unsqueeze(0).expand(B, -1, -1)
|
31 |
+
v_coord = v_coord.type_as(depth).unsqueeze(0).expand(B, -1, -1)
|
32 |
+
|
33 |
+
# Compute the position map by back-projecting depth pixels to 3D space
|
34 |
+
x = (
|
35 |
+
(u_coord - intrinsics[:, 0, 2].unsqueeze(-1).unsqueeze(-1))
|
36 |
+
* depth
|
37 |
+
/ intrinsics[:, 0, 0].unsqueeze(-1).unsqueeze(-1)
|
38 |
+
)
|
39 |
+
y = (
|
40 |
+
(v_coord - intrinsics[:, 1, 2].unsqueeze(-1).unsqueeze(-1))
|
41 |
+
* depth
|
42 |
+
/ intrinsics[:, 1, 1].unsqueeze(-1).unsqueeze(-1)
|
43 |
+
)
|
44 |
+
z = depth
|
45 |
+
|
46 |
+
# Concatenate to form the 3D coordinates in the camera frame
|
47 |
+
camera_coords = torch.stack([x, y, z], dim=-1)
|
48 |
+
|
49 |
+
# Apply the extrinsic matrix to get coordinates in the world frame
|
50 |
+
coords_homogeneous = torch.nn.functional.pad(
|
51 |
+
camera_coords, (0, 1), "constant", 1.0
|
52 |
+
) # Add a homogeneous coordinate
|
53 |
+
world_coords = torch.matmul(
|
54 |
+
coords_homogeneous.view(B, -1, 4), extrinsics.transpose(1, 2)
|
55 |
+
).view(B, H, W, 4)
|
56 |
+
|
57 |
+
# Apply the mask to the position map
|
58 |
+
position_map = world_coords[..., :3] * mask
|
59 |
+
|
60 |
+
return position_map
|
61 |
+
|
62 |
+
|
63 |
+
def get_position_map_from_depth_ortho(
|
64 |
+
depth, mask, extrinsics, ortho_scale, image_wh=None
|
65 |
+
):
|
66 |
+
"""Compute the position map from the depth map and the camera parameters for a batch of views
|
67 |
+
using orthographic projection with a given ortho_scale.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
depth (torch.Tensor): The depth maps with the shape (B, H, W, 1).
|
71 |
+
mask (torch.Tensor): The masks with the shape (B, H, W, 1).
|
72 |
+
extrinsics (torch.Tensor): The camera extrinsics matrices with the shape (B, 4, 4).
|
73 |
+
ortho_scale (torch.Tensor): The scaling factor for the orthographic projection with the shape (B, 1, 1, 1).
|
74 |
+
image_wh (Tuple[int, int]): Optional. The image width and height.
|
75 |
+
|
76 |
+
Returns:
|
77 |
+
torch.Tensor: The position maps with the shape (B, H, W, 3).
|
78 |
+
"""
|
79 |
+
if image_wh is None:
|
80 |
+
image_wh = depth.shape[2], depth.shape[1]
|
81 |
+
|
82 |
+
B, H, W, _ = depth.shape
|
83 |
+
depth = depth.squeeze(-1)
|
84 |
+
|
85 |
+
# Generating grid of coordinates in the image space
|
86 |
+
u_coord, v_coord = torch.meshgrid(
|
87 |
+
torch.arange(0, image_wh[0]), torch.arange(0, image_wh[1]), indexing="xy"
|
88 |
+
)
|
89 |
+
u_coord = u_coord.type_as(depth).unsqueeze(0).expand(B, -1, -1)
|
90 |
+
v_coord = v_coord.type_as(depth).unsqueeze(0).expand(B, -1, -1)
|
91 |
+
|
92 |
+
# Compute the position map using orthographic projection with ortho_scale
|
93 |
+
x = (u_coord - image_wh[0] / 2) * ortho_scale / image_wh[0]
|
94 |
+
y = (v_coord - image_wh[1] / 2) * ortho_scale / image_wh[1]
|
95 |
+
z = depth
|
96 |
+
|
97 |
+
# Concatenate to form the 3D coordinates in the camera frame
|
98 |
+
camera_coords = torch.stack([x, y, z], dim=-1)
|
99 |
+
|
100 |
+
# Apply the extrinsic matrix to get coordinates in the world frame
|
101 |
+
coords_homogeneous = torch.nn.functional.pad(
|
102 |
+
camera_coords, (0, 1), "constant", 1.0
|
103 |
+
) # Add a homogeneous coordinate
|
104 |
+
world_coords = torch.matmul(
|
105 |
+
coords_homogeneous.view(B, -1, 4), extrinsics.transpose(1, 2)
|
106 |
+
).view(B, H, W, 4)
|
107 |
+
|
108 |
+
# Apply the mask to the position map
|
109 |
+
position_map = world_coords[..., :3] * mask
|
110 |
+
|
111 |
+
return position_map
|
112 |
+
|
113 |
+
|
114 |
+
def get_opencv_from_blender(matrix_world, fov=None, image_size=None):
|
115 |
+
# convert matrix_world to opencv format extrinsics
|
116 |
+
opencv_world_to_cam = matrix_world.inverse()
|
117 |
+
opencv_world_to_cam[1, :] *= -1
|
118 |
+
opencv_world_to_cam[2, :] *= -1
|
119 |
+
R, T = opencv_world_to_cam[:3, :3], opencv_world_to_cam[:3, 3]
|
120 |
+
|
121 |
+
if fov is None: # orthographic camera
|
122 |
+
return R, T
|
123 |
+
|
124 |
+
R, T = R.unsqueeze(0), T.unsqueeze(0)
|
125 |
+
# convert fov to opencv format intrinsics
|
126 |
+
focal = 1 / np.tan(fov / 2)
|
127 |
+
intrinsics = np.diag(np.array([focal, focal, 1])).astype(np.float32)
|
128 |
+
opencv_cam_matrix = (
|
129 |
+
torch.from_numpy(intrinsics).unsqueeze(0).float().to(matrix_world.device)
|
130 |
+
)
|
131 |
+
opencv_cam_matrix[:, :2, -1] += torch.tensor([image_size / 2, image_size / 2]).to(
|
132 |
+
matrix_world.device
|
133 |
+
)
|
134 |
+
opencv_cam_matrix[:, [0, 1], [0, 1]] *= image_size / 2
|
135 |
+
|
136 |
+
return R, T, opencv_cam_matrix
|
137 |
+
|
138 |
+
|
139 |
+
def get_ray_directions(
|
140 |
+
H: int,
|
141 |
+
W: int,
|
142 |
+
focal: float,
|
143 |
+
principal: Optional[Tuple[float, float]] = None,
|
144 |
+
use_pixel_centers: bool = True,
|
145 |
+
) -> torch.Tensor:
|
146 |
+
"""
|
147 |
+
Get ray directions for all pixels in camera coordinate.
|
148 |
+
Args:
|
149 |
+
H, W, focal, principal, use_pixel_centers: image height, width, focal length, principal point and whether use pixel centers
|
150 |
+
Outputs:
|
151 |
+
directions: (H, W, 3), the direction of the rays in camera coordinate
|
152 |
+
"""
|
153 |
+
pixel_center = 0.5 if use_pixel_centers else 0
|
154 |
+
cx, cy = W / 2, H / 2 if principal is None else principal
|
155 |
+
i, j = torch.meshgrid(
|
156 |
+
torch.arange(W, dtype=torch.float32) + pixel_center,
|
157 |
+
torch.arange(H, dtype=torch.float32) + pixel_center,
|
158 |
+
indexing="xy",
|
159 |
+
)
|
160 |
+
directions = torch.stack(
|
161 |
+
[(i - cx) / focal, -(j - cy) / focal, -torch.ones_like(i)], -1
|
162 |
+
)
|
163 |
+
return F.normalize(directions, dim=-1)
|
164 |
+
|
165 |
+
|
166 |
+
def get_rays(
|
167 |
+
directions: torch.Tensor, c2w: torch.Tensor
|
168 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
169 |
+
"""
|
170 |
+
Get ray origins and directions from camera coordinates to world coordinates
|
171 |
+
Args:
|
172 |
+
directions: (H, W, 3) ray directions in camera coordinates
|
173 |
+
c2w: (4, 4) camera-to-world transformation matrix
|
174 |
+
Outputs:
|
175 |
+
rays_o, rays_d: (H, W, 3) ray origins and directions in world coordinates
|
176 |
+
"""
|
177 |
+
# Rotate ray directions from camera coordinate to the world coordinate
|
178 |
+
rays_d = directions @ c2w[:3, :3].T
|
179 |
+
rays_o = c2w[:3, 3].expand(rays_d.shape)
|
180 |
+
return rays_o, rays_d
|
181 |
+
|
182 |
+
|
183 |
+
def compute_plucker_embed(
|
184 |
+
c2w: torch.Tensor, image_width: int, image_height: int, focal: float
|
185 |
+
) -> torch.Tensor:
|
186 |
+
"""
|
187 |
+
Computes Plucker coordinates for a camera.
|
188 |
+
Args:
|
189 |
+
c2w: (4, 4) camera-to-world transformation matrix
|
190 |
+
image_width: Image width
|
191 |
+
image_height: Image height
|
192 |
+
focal: Focal length of the camera
|
193 |
+
Returns:
|
194 |
+
plucker: (6, H, W) Plucker embedding
|
195 |
+
"""
|
196 |
+
directions = get_ray_directions(image_height, image_width, focal)
|
197 |
+
rays_o, rays_d = get_rays(directions, c2w)
|
198 |
+
# Cross product to get Plucker coordinates
|
199 |
+
cross = torch.cross(rays_o, rays_d, dim=-1)
|
200 |
+
plucker = torch.cat((rays_d, cross), dim=-1)
|
201 |
+
return plucker.permute(2, 0, 1)
|
202 |
+
|
203 |
+
|
204 |
+
def get_plucker_embeds_from_cameras(
|
205 |
+
c2w: List[torch.Tensor], fov: List[float], image_size: int
|
206 |
+
) -> torch.Tensor:
|
207 |
+
"""
|
208 |
+
Given lists of camera transformations and fov, returns the batched plucker embeddings.
|
209 |
+
Args:
|
210 |
+
c2w: list of camera-to-world transformation matrices
|
211 |
+
fov: list of field of view values
|
212 |
+
image_size: size of the image
|
213 |
+
Returns:
|
214 |
+
plucker_embeds: (B, 6, H, W) batched plucker embeddings
|
215 |
+
"""
|
216 |
+
plucker_embeds = []
|
217 |
+
for cam_matrix, cam_fov in zip(c2w, fov):
|
218 |
+
focal = 0.5 * image_size / np.tan(0.5 * cam_fov)
|
219 |
+
plucker = compute_plucker_embed(cam_matrix, image_size, image_size, focal)
|
220 |
+
plucker_embeds.append(plucker)
|
221 |
+
return torch.stack(plucker_embeds)
|
222 |
+
|
223 |
+
|
224 |
+
def get_plucker_embeds_from_cameras_ortho(
|
225 |
+
c2w: List[torch.Tensor], ortho_scale: List[float], image_size: int
|
226 |
+
):
|
227 |
+
"""
|
228 |
+
Given lists of camera transformations and fov, returns the batched plucker embeddings.
|
229 |
+
|
230 |
+
Parameters:
|
231 |
+
c2w: list of camera-to-world transformation matrices
|
232 |
+
fov: list of field of view values
|
233 |
+
image_size: size of the image
|
234 |
+
|
235 |
+
Returns:
|
236 |
+
plucker_embeds: plucker embeddings (B, 6, H, W)
|
237 |
+
"""
|
238 |
+
plucker_embeds = []
|
239 |
+
# compute pairwise mask and plucker embeddings
|
240 |
+
for cam_matrix, scale in zip(c2w, ortho_scale):
|
241 |
+
# blender to opencv to pytorch3d
|
242 |
+
R, T = get_opencv_from_blender(cam_matrix)
|
243 |
+
cam_pos = -R.T @ T
|
244 |
+
view_dir = R.T @ torch.tensor([0, 0, 1]).float().to(cam_matrix.device)
|
245 |
+
# normalize camera position
|
246 |
+
cam_pos = F.normalize(cam_pos, dim=0)
|
247 |
+
plucker = torch.concat([view_dir, cam_pos])
|
248 |
+
plucker = plucker.unsqueeze(-1).unsqueeze(-1).repeat(1, image_size, image_size)
|
249 |
+
plucker_embeds.append(plucker)
|
250 |
+
|
251 |
+
plucker_embeds = torch.stack(plucker_embeds)
|
252 |
+
|
253 |
+
return plucker_embeds
|
mvadapter/utils/logging.py
ADDED
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 Optuna, Hugging Face
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" Logging utilities."""
|
16 |
+
|
17 |
+
import logging
|
18 |
+
import os
|
19 |
+
import sys
|
20 |
+
import threading
|
21 |
+
from logging import CRITICAL # NOQA
|
22 |
+
from logging import DEBUG # NOQA
|
23 |
+
from logging import ERROR # NOQA
|
24 |
+
from logging import FATAL # NOQA
|
25 |
+
from logging import INFO # NOQA
|
26 |
+
from logging import NOTSET # NOQA
|
27 |
+
from logging import WARN # NOQA
|
28 |
+
from logging import WARNING # NOQA
|
29 |
+
from typing import Dict, Optional
|
30 |
+
|
31 |
+
from tqdm import auto as tqdm_lib
|
32 |
+
|
33 |
+
_lock = threading.Lock()
|
34 |
+
_default_handler: Optional[logging.Handler] = None
|
35 |
+
|
36 |
+
log_levels = {
|
37 |
+
"debug": logging.DEBUG,
|
38 |
+
"info": logging.INFO,
|
39 |
+
"warning": logging.WARNING,
|
40 |
+
"error": logging.ERROR,
|
41 |
+
"critical": logging.CRITICAL,
|
42 |
+
}
|
43 |
+
|
44 |
+
_default_log_level = logging.INFO
|
45 |
+
|
46 |
+
_tqdm_active = True
|
47 |
+
|
48 |
+
|
49 |
+
def _get_default_logging_level() -> int:
|
50 |
+
"""
|
51 |
+
If LATEXTURE_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
|
52 |
+
not - fall back to `_default_log_level`
|
53 |
+
"""
|
54 |
+
env_level_str = os.getenv("LATEXTURE_VERBOSITY", None)
|
55 |
+
if env_level_str:
|
56 |
+
if env_level_str in log_levels:
|
57 |
+
return log_levels[env_level_str]
|
58 |
+
else:
|
59 |
+
logging.getLogger().warning(
|
60 |
+
f"Unknown option LATEXTURE_VERBOSITY={env_level_str}, "
|
61 |
+
f"has to be one of: { ', '.join(log_levels.keys()) }"
|
62 |
+
)
|
63 |
+
return _default_log_level
|
64 |
+
|
65 |
+
|
66 |
+
def _get_library_name() -> str:
|
67 |
+
return __name__.split(".")[0]
|
68 |
+
|
69 |
+
|
70 |
+
def _get_library_root_logger() -> logging.Logger:
|
71 |
+
return logging.getLogger(_get_library_name())
|
72 |
+
|
73 |
+
|
74 |
+
def _configure_library_root_logger() -> None:
|
75 |
+
global _default_handler
|
76 |
+
|
77 |
+
with _lock:
|
78 |
+
if _default_handler:
|
79 |
+
# This library has already configured the library root logger.
|
80 |
+
return
|
81 |
+
_default_handler = logging.StreamHandler() # Set sys.stderr as stream.
|
82 |
+
_default_handler.flush = sys.stderr.flush
|
83 |
+
|
84 |
+
# Apply our default configuration to the library root logger.
|
85 |
+
library_root_logger = _get_library_root_logger()
|
86 |
+
library_root_logger.addHandler(_default_handler)
|
87 |
+
library_root_logger.setLevel(_get_default_logging_level())
|
88 |
+
library_root_logger.propagate = False
|
89 |
+
|
90 |
+
enable_explicit_format()
|
91 |
+
|
92 |
+
|
93 |
+
def _reset_library_root_logger() -> None:
|
94 |
+
global _default_handler
|
95 |
+
|
96 |
+
with _lock:
|
97 |
+
if not _default_handler:
|
98 |
+
return
|
99 |
+
|
100 |
+
library_root_logger = _get_library_root_logger()
|
101 |
+
library_root_logger.removeHandler(_default_handler)
|
102 |
+
library_root_logger.setLevel(logging.NOTSET)
|
103 |
+
_default_handler = None
|
104 |
+
|
105 |
+
|
106 |
+
def get_log_levels_dict() -> Dict[str, int]:
|
107 |
+
return log_levels
|
108 |
+
|
109 |
+
|
110 |
+
def get_logger(name: Optional[str] = None) -> logging.Logger:
|
111 |
+
"""
|
112 |
+
Return a logger with the specified name.
|
113 |
+
|
114 |
+
This function is not supposed to be directly accessed unless you are writing a custom diffusers module.
|
115 |
+
"""
|
116 |
+
|
117 |
+
if name is None:
|
118 |
+
name = _get_library_name()
|
119 |
+
|
120 |
+
_configure_library_root_logger()
|
121 |
+
return logging.getLogger(name)
|
122 |
+
|
123 |
+
|
124 |
+
def get_verbosity() -> int:
|
125 |
+
"""
|
126 |
+
Return the current level for the 🤗 Diffusers' root logger as an `int`.
|
127 |
+
|
128 |
+
Returns:
|
129 |
+
`int`:
|
130 |
+
Logging level integers which can be one of:
|
131 |
+
|
132 |
+
- `50`: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
|
133 |
+
- `40`: `diffusers.logging.ERROR`
|
134 |
+
- `30`: `diffusers.logging.WARNING` or `diffusers.logging.WARN`
|
135 |
+
- `20`: `diffusers.logging.INFO`
|
136 |
+
- `10`: `diffusers.logging.DEBUG`
|
137 |
+
|
138 |
+
"""
|
139 |
+
|
140 |
+
_configure_library_root_logger()
|
141 |
+
return _get_library_root_logger().getEffectiveLevel()
|
142 |
+
|
143 |
+
|
144 |
+
def set_verbosity(verbosity: int) -> None:
|
145 |
+
"""
|
146 |
+
Set the verbosity level for the 🤗 Diffusers' root logger.
|
147 |
+
|
148 |
+
Args:
|
149 |
+
verbosity (`int`):
|
150 |
+
Logging level which can be one of:
|
151 |
+
|
152 |
+
- `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
|
153 |
+
- `diffusers.logging.ERROR`
|
154 |
+
- `diffusers.logging.WARNING` or `diffusers.logging.WARN`
|
155 |
+
- `diffusers.logging.INFO`
|
156 |
+
- `diffusers.logging.DEBUG`
|
157 |
+
"""
|
158 |
+
|
159 |
+
_configure_library_root_logger()
|
160 |
+
_get_library_root_logger().setLevel(verbosity)
|
161 |
+
|
162 |
+
|
163 |
+
def set_verbosity_info() -> None:
|
164 |
+
"""Set the verbosity to the `INFO` level."""
|
165 |
+
return set_verbosity(INFO)
|
166 |
+
|
167 |
+
|
168 |
+
def set_verbosity_warning() -> None:
|
169 |
+
"""Set the verbosity to the `WARNING` level."""
|
170 |
+
return set_verbosity(WARNING)
|
171 |
+
|
172 |
+
|
173 |
+
def set_verbosity_debug() -> None:
|
174 |
+
"""Set the verbosity to the `DEBUG` level."""
|
175 |
+
return set_verbosity(DEBUG)
|
176 |
+
|
177 |
+
|
178 |
+
def set_verbosity_error() -> None:
|
179 |
+
"""Set the verbosity to the `ERROR` level."""
|
180 |
+
return set_verbosity(ERROR)
|
181 |
+
|
182 |
+
|
183 |
+
def disable_default_handler() -> None:
|
184 |
+
"""Disable the default handler of the 🤗 Diffusers' root logger."""
|
185 |
+
|
186 |
+
_configure_library_root_logger()
|
187 |
+
|
188 |
+
assert _default_handler is not None
|
189 |
+
_get_library_root_logger().removeHandler(_default_handler)
|
190 |
+
|
191 |
+
|
192 |
+
def enable_default_handler() -> None:
|
193 |
+
"""Enable the default handler of the 🤗 Diffusers' root logger."""
|
194 |
+
|
195 |
+
_configure_library_root_logger()
|
196 |
+
|
197 |
+
assert _default_handler is not None
|
198 |
+
_get_library_root_logger().addHandler(_default_handler)
|
199 |
+
|
200 |
+
|
201 |
+
def add_handler(handler: logging.Handler) -> None:
|
202 |
+
"""adds a handler to the HuggingFace Diffusers' root logger."""
|
203 |
+
|
204 |
+
_configure_library_root_logger()
|
205 |
+
|
206 |
+
assert handler is not None
|
207 |
+
_get_library_root_logger().addHandler(handler)
|
208 |
+
|
209 |
+
|
210 |
+
def remove_handler(handler: logging.Handler) -> None:
|
211 |
+
"""removes given handler from the HuggingFace Diffusers' root logger."""
|
212 |
+
|
213 |
+
_configure_library_root_logger()
|
214 |
+
|
215 |
+
assert handler is not None and handler in _get_library_root_logger().handlers
|
216 |
+
_get_library_root_logger().removeHandler(handler)
|
217 |
+
|
218 |
+
|
219 |
+
def disable_propagation() -> None:
|
220 |
+
"""
|
221 |
+
Disable propagation of the library log outputs. Note that log propagation is disabled by default.
|
222 |
+
"""
|
223 |
+
|
224 |
+
_configure_library_root_logger()
|
225 |
+
_get_library_root_logger().propagate = False
|
226 |
+
|
227 |
+
|
228 |
+
def enable_propagation() -> None:
|
229 |
+
"""
|
230 |
+
Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to prevent
|
231 |
+
double logging if the root logger has been configured.
|
232 |
+
"""
|
233 |
+
|
234 |
+
_configure_library_root_logger()
|
235 |
+
_get_library_root_logger().propagate = True
|
236 |
+
|
237 |
+
|
238 |
+
def enable_explicit_format() -> None:
|
239 |
+
"""
|
240 |
+
Enable explicit formatting for every 🤗 Diffusers' logger. The explicit formatter is as follows:
|
241 |
+
```
|
242 |
+
[LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
|
243 |
+
```
|
244 |
+
All handlers currently bound to the root logger are affected by this method.
|
245 |
+
"""
|
246 |
+
handlers = _get_library_root_logger().handlers
|
247 |
+
|
248 |
+
for handler in handlers:
|
249 |
+
formatter = logging.Formatter(
|
250 |
+
"[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s"
|
251 |
+
)
|
252 |
+
handler.setFormatter(formatter)
|
253 |
+
|
254 |
+
|
255 |
+
def reset_format() -> None:
|
256 |
+
"""
|
257 |
+
Resets the formatting for 🤗 Diffusers' loggers.
|
258 |
+
|
259 |
+
All handlers currently bound to the root logger are affected by this method.
|
260 |
+
"""
|
261 |
+
handlers = _get_library_root_logger().handlers
|
262 |
+
|
263 |
+
for handler in handlers:
|
264 |
+
handler.setFormatter(None)
|
265 |
+
|
266 |
+
|
267 |
+
def warning_advice(self, *args, **kwargs) -> None:
|
268 |
+
"""
|
269 |
+
This method is identical to `logger.warning()`, but if env var LATEXTURE_NO_ADVISORY_WARNINGS=1 is set, this
|
270 |
+
warning will not be printed
|
271 |
+
"""
|
272 |
+
no_advisory_warnings = os.getenv("LATEXTURE_NO_ADVISORY_WARNINGS", False)
|
273 |
+
if no_advisory_warnings:
|
274 |
+
return
|
275 |
+
self.warning(*args, **kwargs)
|
276 |
+
|
277 |
+
|
278 |
+
logging.Logger.warning_advice = warning_advice
|
279 |
+
|
280 |
+
|
281 |
+
class EmptyTqdm:
|
282 |
+
"""Dummy tqdm which doesn't do anything."""
|
283 |
+
|
284 |
+
def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
|
285 |
+
self._iterator = args[0] if args else None
|
286 |
+
|
287 |
+
def __iter__(self):
|
288 |
+
return iter(self._iterator)
|
289 |
+
|
290 |
+
def __getattr__(self, _):
|
291 |
+
"""Return empty function."""
|
292 |
+
|
293 |
+
def empty_fn(*args, **kwargs): # pylint: disable=unused-argument
|
294 |
+
return
|
295 |
+
|
296 |
+
return empty_fn
|
297 |
+
|
298 |
+
def __enter__(self):
|
299 |
+
return self
|
300 |
+
|
301 |
+
def __exit__(self, type_, value, traceback):
|
302 |
+
return
|
303 |
+
|
304 |
+
|
305 |
+
class _tqdm_cls:
|
306 |
+
def __call__(self, *args, **kwargs):
|
307 |
+
if _tqdm_active:
|
308 |
+
return tqdm_lib.tqdm(*args, **kwargs)
|
309 |
+
else:
|
310 |
+
return EmptyTqdm(*args, **kwargs)
|
311 |
+
|
312 |
+
def set_lock(self, *args, **kwargs):
|
313 |
+
self._lock = None
|
314 |
+
if _tqdm_active:
|
315 |
+
return tqdm_lib.tqdm.set_lock(*args, **kwargs)
|
316 |
+
|
317 |
+
def get_lock(self):
|
318 |
+
if _tqdm_active:
|
319 |
+
return tqdm_lib.tqdm.get_lock()
|
320 |
+
|
321 |
+
|
322 |
+
tqdm = _tqdm_cls()
|
323 |
+
|
324 |
+
|
325 |
+
def is_progress_bar_enabled() -> bool:
|
326 |
+
"""Return a boolean indicating whether tqdm progress bars are enabled."""
|
327 |
+
global _tqdm_active
|
328 |
+
return bool(_tqdm_active)
|
329 |
+
|
330 |
+
|
331 |
+
def enable_progress_bar() -> None:
|
332 |
+
"""Enable tqdm progress bar."""
|
333 |
+
global _tqdm_active
|
334 |
+
_tqdm_active = True
|
335 |
+
|
336 |
+
|
337 |
+
def disable_progress_bar() -> None:
|
338 |
+
"""Disable tqdm progress bar."""
|
339 |
+
global _tqdm_active
|
340 |
+
_tqdm_active = False
|
mvadapter/utils/render.py
ADDED
@@ -0,0 +1,499 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from abc import ABC, abstractmethod
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from datetime import datetime
|
5 |
+
from typing import List, Optional, Union
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import nvdiffrast.torch as dr
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
import trimesh
|
12 |
+
from PIL import Image
|
13 |
+
from torch import BoolTensor, FloatTensor
|
14 |
+
|
15 |
+
from . import logging
|
16 |
+
from .camera import Camera
|
17 |
+
|
18 |
+
logger = logging.get_logger(__name__)
|
19 |
+
|
20 |
+
|
21 |
+
def dot(x: torch.FloatTensor, y: torch.FloatTensor) -> torch.FloatTensor:
|
22 |
+
return torch.sum(x * y, -1, keepdim=True)
|
23 |
+
|
24 |
+
|
25 |
+
@dataclass
|
26 |
+
class TexturedMesh:
|
27 |
+
v_pos: torch.FloatTensor
|
28 |
+
t_pos_idx: torch.LongTensor
|
29 |
+
|
30 |
+
# texture coordinates
|
31 |
+
v_tex: Optional[torch.FloatTensor] = None
|
32 |
+
t_tex_idx: Optional[torch.LongTensor] = None
|
33 |
+
|
34 |
+
# texture map
|
35 |
+
texture: Optional[torch.FloatTensor] = None
|
36 |
+
|
37 |
+
# vertices, faces after vertex merging
|
38 |
+
_stitched_v_pos: Optional[torch.FloatTensor] = None
|
39 |
+
_stitched_t_pos_idx: Optional[torch.LongTensor] = None
|
40 |
+
|
41 |
+
_v_nrm: Optional[torch.FloatTensor] = None
|
42 |
+
|
43 |
+
@property
|
44 |
+
def v_nrm(self) -> torch.FloatTensor:
|
45 |
+
if self._v_nrm is None:
|
46 |
+
self._v_nrm = self._compute_vertex_normal()
|
47 |
+
return self._v_nrm
|
48 |
+
|
49 |
+
def set_stitched_mesh(
|
50 |
+
self, v_pos: torch.FloatTensor, t_pos_idx: torch.LongTensor
|
51 |
+
) -> None:
|
52 |
+
self._stitched_v_pos = v_pos
|
53 |
+
self._stitched_t_pos_idx = t_pos_idx
|
54 |
+
|
55 |
+
@property
|
56 |
+
def stitched_v_pos(self) -> torch.FloatTensor:
|
57 |
+
if self._stitched_v_pos is None:
|
58 |
+
logger.warning("Stitched vertices not available, using original vertices!")
|
59 |
+
return self.v_pos
|
60 |
+
return self._stitched_v_pos
|
61 |
+
|
62 |
+
@property
|
63 |
+
def stitched_t_pos_idx(self) -> torch.LongTensor:
|
64 |
+
if self._stitched_t_pos_idx is None:
|
65 |
+
logger.warning("Stitched faces not available, using original faces!")
|
66 |
+
return self.t_pos_idx
|
67 |
+
return self._stitched_t_pos_idx
|
68 |
+
|
69 |
+
def _compute_vertex_normal(self) -> torch.FloatTensor:
|
70 |
+
if self._stitched_v_pos is None or self._stitched_t_pos_idx is None:
|
71 |
+
logger.warning(
|
72 |
+
"Stitched vertices and faces not available, computing vertex normals on original mesh, which can be erroneous!"
|
73 |
+
)
|
74 |
+
v_pos, t_pos_idx = self.v_pos, self.t_pos_idx
|
75 |
+
else:
|
76 |
+
v_pos, t_pos_idx = self._stitched_v_pos, self._stitched_t_pos_idx
|
77 |
+
|
78 |
+
i0 = t_pos_idx[:, 0]
|
79 |
+
i1 = t_pos_idx[:, 1]
|
80 |
+
i2 = t_pos_idx[:, 2]
|
81 |
+
|
82 |
+
v0 = v_pos[i0, :]
|
83 |
+
v1 = v_pos[i1, :]
|
84 |
+
v2 = v_pos[i2, :]
|
85 |
+
|
86 |
+
face_normals = torch.cross(v1 - v0, v2 - v0)
|
87 |
+
|
88 |
+
# Splat face normals to vertices
|
89 |
+
v_nrm = torch.zeros_like(v_pos)
|
90 |
+
v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
|
91 |
+
v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
|
92 |
+
v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
|
93 |
+
|
94 |
+
# Normalize, replace zero (degenerated) normals with some default value
|
95 |
+
v_nrm = torch.where(
|
96 |
+
dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
|
97 |
+
)
|
98 |
+
v_nrm = F.normalize(v_nrm, dim=1)
|
99 |
+
|
100 |
+
if torch.is_anomaly_enabled():
|
101 |
+
assert torch.all(torch.isfinite(v_nrm))
|
102 |
+
|
103 |
+
return v_nrm
|
104 |
+
|
105 |
+
def to(self, device: Optional[str] = None):
|
106 |
+
self.v_pos = self.v_pos.to(device)
|
107 |
+
self.t_pos_idx = self.t_pos_idx.to(device)
|
108 |
+
if self.v_tex is not None:
|
109 |
+
self.v_tex = self.v_tex.to(device)
|
110 |
+
if self.t_tex_idx is not None:
|
111 |
+
self.t_tex_idx = self.t_tex_idx.to(device)
|
112 |
+
if self.texture is not None:
|
113 |
+
self.texture = self.texture.to(device)
|
114 |
+
if self._stitched_v_pos is not None:
|
115 |
+
self._stitched_v_pos = self._stitched_v_pos.to(device)
|
116 |
+
if self._stitched_t_pos_idx is not None:
|
117 |
+
self._stitched_t_pos_idx = self._stitched_t_pos_idx.to(device)
|
118 |
+
if self._v_nrm is not None:
|
119 |
+
self._v_nrm = self._v_nrm.to(device)
|
120 |
+
|
121 |
+
|
122 |
+
def load_mesh(
|
123 |
+
mesh_path: str,
|
124 |
+
rescale: bool = False,
|
125 |
+
move_to_center: bool = False,
|
126 |
+
scale: float = 0.5,
|
127 |
+
flip_uv: bool = True,
|
128 |
+
merge_vertices: bool = True,
|
129 |
+
default_uv_size: int = 2048,
|
130 |
+
shape_init_mesh_up: str = "+y",
|
131 |
+
shape_init_mesh_front: str = "+x",
|
132 |
+
front_x_to_y: bool = False,
|
133 |
+
device: Optional[str] = None,
|
134 |
+
return_transform: bool = False,
|
135 |
+
) -> TexturedMesh:
|
136 |
+
scene = trimesh.load(mesh_path, force="mesh", process=False)
|
137 |
+
if isinstance(scene, trimesh.Trimesh):
|
138 |
+
mesh = scene
|
139 |
+
elif isinstance(scene, trimesh.scene.Scene):
|
140 |
+
mesh = trimesh.Trimesh()
|
141 |
+
for obj in scene.geometry.values():
|
142 |
+
mesh = trimesh.util.concatenate([mesh, obj])
|
143 |
+
else:
|
144 |
+
raise ValueError(f"Unknown mesh type at {mesh_path}.")
|
145 |
+
|
146 |
+
# move to center
|
147 |
+
if move_to_center:
|
148 |
+
centroid = mesh.vertices.mean(0)
|
149 |
+
mesh.vertices = mesh.vertices - centroid
|
150 |
+
|
151 |
+
# rescale
|
152 |
+
if rescale:
|
153 |
+
max_scale = np.abs(mesh.vertices).max()
|
154 |
+
mesh.vertices = mesh.vertices / max_scale * scale
|
155 |
+
|
156 |
+
dirs = ["+x", "+y", "+z", "-x", "-y", "-z"]
|
157 |
+
dir2vec = {
|
158 |
+
"+x": np.array([1, 0, 0]),
|
159 |
+
"+y": np.array([0, 1, 0]),
|
160 |
+
"+z": np.array([0, 0, 1]),
|
161 |
+
"-x": np.array([-1, 0, 0]),
|
162 |
+
"-y": np.array([0, -1, 0]),
|
163 |
+
"-z": np.array([0, 0, -1]),
|
164 |
+
}
|
165 |
+
if shape_init_mesh_up not in dirs or shape_init_mesh_front not in dirs:
|
166 |
+
raise ValueError(
|
167 |
+
f"shape_init_mesh_up and shape_init_mesh_front must be one of {dirs}."
|
168 |
+
)
|
169 |
+
if shape_init_mesh_up[1] == shape_init_mesh_front[1]:
|
170 |
+
raise ValueError(
|
171 |
+
"shape_init_mesh_up and shape_init_mesh_front must be orthogonal."
|
172 |
+
)
|
173 |
+
z_, x_ = (
|
174 |
+
dir2vec[shape_init_mesh_up],
|
175 |
+
dir2vec[shape_init_mesh_front],
|
176 |
+
)
|
177 |
+
y_ = np.cross(z_, x_)
|
178 |
+
std2mesh = np.stack([x_, y_, z_], axis=0).T
|
179 |
+
mesh2std = np.linalg.inv(std2mesh)
|
180 |
+
mesh.vertices = np.dot(mesh2std, mesh.vertices.T).T
|
181 |
+
if front_x_to_y:
|
182 |
+
x = mesh.vertices[:, 1].copy()
|
183 |
+
y = -mesh.vertices[:, 0].copy()
|
184 |
+
mesh.vertices[:, 0] = x
|
185 |
+
mesh.vertices[:, 1] = y
|
186 |
+
|
187 |
+
v_pos = torch.tensor(mesh.vertices, dtype=torch.float32)
|
188 |
+
t_pos_idx = torch.tensor(mesh.faces, dtype=torch.int64)
|
189 |
+
|
190 |
+
if hasattr(mesh, "visual") and hasattr(mesh.visual, "uv"):
|
191 |
+
v_tex = torch.tensor(mesh.visual.uv, dtype=torch.float32)
|
192 |
+
if flip_uv:
|
193 |
+
v_tex[:, 1] = 1.0 - v_tex[:, 1]
|
194 |
+
t_tex_idx = t_pos_idx.clone()
|
195 |
+
if (
|
196 |
+
hasattr(mesh.visual.material, "baseColorTexture")
|
197 |
+
and mesh.visual.material.baseColorTexture
|
198 |
+
):
|
199 |
+
texture = torch.tensor(
|
200 |
+
np.array(mesh.visual.material.baseColorTexture) / 255.0,
|
201 |
+
dtype=torch.float32,
|
202 |
+
)[..., :3]
|
203 |
+
else:
|
204 |
+
texture = torch.zeros(
|
205 |
+
(default_uv_size, default_uv_size, 3), dtype=torch.float32
|
206 |
+
)
|
207 |
+
else:
|
208 |
+
v_tex = None
|
209 |
+
t_tex_idx = None
|
210 |
+
texture = None
|
211 |
+
|
212 |
+
textured_mesh = TexturedMesh(
|
213 |
+
v_pos=v_pos,
|
214 |
+
t_pos_idx=t_pos_idx,
|
215 |
+
v_tex=v_tex,
|
216 |
+
t_tex_idx=t_tex_idx,
|
217 |
+
texture=texture,
|
218 |
+
)
|
219 |
+
|
220 |
+
if merge_vertices:
|
221 |
+
mesh.merge_vertices(merge_tex=True)
|
222 |
+
textured_mesh.set_stitched_mesh(
|
223 |
+
torch.tensor(mesh.vertices, dtype=torch.float32),
|
224 |
+
torch.tensor(mesh.faces, dtype=torch.int64),
|
225 |
+
)
|
226 |
+
|
227 |
+
textured_mesh.to(device)
|
228 |
+
|
229 |
+
if return_transform:
|
230 |
+
return textured_mesh, np.array(centroid), max_scale / scale
|
231 |
+
|
232 |
+
return textured_mesh
|
233 |
+
|
234 |
+
|
235 |
+
@dataclass
|
236 |
+
class RenderOutput:
|
237 |
+
attr: Optional[torch.FloatTensor] = None
|
238 |
+
mask: Optional[torch.BoolTensor] = None
|
239 |
+
depth: Optional[torch.FloatTensor] = None
|
240 |
+
normal: Optional[torch.FloatTensor] = None
|
241 |
+
pos: Optional[torch.FloatTensor] = None
|
242 |
+
|
243 |
+
|
244 |
+
class NVDiffRastContextWrapper:
|
245 |
+
def __init__(self, device: str, context_type: str = "gl"):
|
246 |
+
if context_type == "gl":
|
247 |
+
self.ctx = dr.RasterizeGLContext(device=device)
|
248 |
+
elif context_type == "cuda":
|
249 |
+
self.ctx = dr.RasterizeCudaContext(device=device)
|
250 |
+
else:
|
251 |
+
raise NotImplementedError
|
252 |
+
|
253 |
+
def rasterize(self, pos, tri, resolution, ranges=None, grad_db=True):
|
254 |
+
"""
|
255 |
+
Rasterize triangles.
|
256 |
+
|
257 |
+
All input tensors must be contiguous and reside in GPU memory except for the ranges tensor that, if specified, has to reside in CPU memory. The output tensors will be contiguous and reside in GPU memory.
|
258 |
+
|
259 |
+
Arguments:
|
260 |
+
glctx Rasterizer context of type RasterizeGLContext or RasterizeCudaContext.
|
261 |
+
pos Vertex position tensor with dtype torch.float32. To enable range mode, this tensor should have a 2D shape [num_vertices, 4]. To enable instanced mode, use a 3D shape [minibatch_size, num_vertices, 4].
|
262 |
+
tri Triangle tensor with shape [num_triangles, 3] and dtype torch.int32.
|
263 |
+
resolution Output resolution as integer tuple (height, width).
|
264 |
+
ranges In range mode, tensor with shape [minibatch_size, 2] and dtype torch.int32, specifying start indices and counts into tri. Ignored in instanced mode.
|
265 |
+
grad_db Propagate gradients of image-space derivatives of barycentrics into pos in backward pass. Ignored if using an OpenGL context that was not configured to output image-space derivatives.
|
266 |
+
Returns:
|
267 |
+
A tuple of two tensors. The first output tensor has shape [minibatch_size, height, width, 4] and contains the main rasterizer output in order (u, v, z/w, triangle_id). If the OpenGL context was configured to output image-space derivatives of barycentrics, the second output tensor will also have shape [minibatch_size, height, width, 4] and contain said derivatives in order (du/dX, du/dY, dv/dX, dv/dY). Otherwise it will be an empty tensor with shape [minibatch_size, height, width, 0].
|
268 |
+
"""
|
269 |
+
return dr.rasterize(
|
270 |
+
self.ctx, pos.float(), tri.int(), resolution, ranges, grad_db
|
271 |
+
)
|
272 |
+
|
273 |
+
def interpolate(self, attr, rast, tri, rast_db=None, diff_attrs=None):
|
274 |
+
"""
|
275 |
+
Interpolate vertex attributes.
|
276 |
+
|
277 |
+
All input tensors must be contiguous and reside in GPU memory. The output tensors will be contiguous and reside in GPU memory.
|
278 |
+
|
279 |
+
Arguments:
|
280 |
+
attr Attribute tensor with dtype torch.float32. Shape is [num_vertices, num_attributes] in range mode, or [minibatch_size, num_vertices, num_attributes] in instanced mode. Broadcasting is supported along the minibatch axis.
|
281 |
+
rast Main output tensor from rasterize().
|
282 |
+
tri Triangle tensor with shape [num_triangles, 3] and dtype torch.int32.
|
283 |
+
rast_db (Optional) Tensor containing image-space derivatives of barycentrics, i.e., the second output tensor from rasterize(). Enables computing image-space derivatives of attributes.
|
284 |
+
diff_attrs (Optional) List of attribute indices for which image-space derivatives are to be computed. Special value 'all' is equivalent to list [0, 1, ..., num_attributes - 1].
|
285 |
+
Returns:
|
286 |
+
A tuple of two tensors. The first output tensor contains interpolated attributes and has shape [minibatch_size, height, width, num_attributes]. If rast_db and diff_attrs were specified, the second output tensor contains the image-space derivatives of the selected attributes and has shape [minibatch_size, height, width, 2 * len(diff_attrs)]. The derivatives of the first selected attribute A will be on channels 0 and 1 as (dA/dX, dA/dY), etc. Otherwise, the second output tensor will be an empty tensor with shape [minibatch_size, height, width, 0].
|
287 |
+
"""
|
288 |
+
return dr.interpolate(attr.float(), rast, tri.int(), rast_db, diff_attrs)
|
289 |
+
|
290 |
+
def texture(
|
291 |
+
self,
|
292 |
+
tex,
|
293 |
+
uv,
|
294 |
+
uv_da=None,
|
295 |
+
mip_level_bias=None,
|
296 |
+
mip=None,
|
297 |
+
filter_mode="auto",
|
298 |
+
boundary_mode="wrap",
|
299 |
+
max_mip_level=None,
|
300 |
+
):
|
301 |
+
"""
|
302 |
+
Perform texture sampling.
|
303 |
+
|
304 |
+
All input tensors must be contiguous and reside in GPU memory. The output tensor will be contiguous and reside in GPU memory.
|
305 |
+
|
306 |
+
Arguments:
|
307 |
+
tex Texture tensor with dtype torch.float32. For 2D textures, must have shape [minibatch_size, tex_height, tex_width, tex_channels]. For cube map textures, must have shape [minibatch_size, 6, tex_height, tex_width, tex_channels] where tex_width and tex_height are equal. Note that boundary_mode must also be set to 'cube' to enable cube map mode. Broadcasting is supported along the minibatch axis.
|
308 |
+
uv Tensor containing per-pixel texture coordinates. When sampling a 2D texture, must have shape [minibatch_size, height, width, 2]. When sampling a cube map texture, must have shape [minibatch_size, height, width, 3].
|
309 |
+
uv_da (Optional) Tensor containing image-space derivatives of texture coordinates. Must have same shape as uv except for the last dimension that is to be twice as long.
|
310 |
+
mip_level_bias (Optional) Per-pixel bias for mip level selection. If uv_da is omitted, determines mip level directly. Must have shape [minibatch_size, height, width].
|
311 |
+
mip (Optional) Preconstructed mipmap stack from a texture_construct_mip() call, or a list of tensors specifying a custom mipmap stack. When specifying a custom mipmap stack, the tensors in the list must follow the same format as tex except for width and height that must follow the usual rules for mipmap sizes. The base level texture is still supplied in tex and must not be included in the list. Gradients of a custom mipmap stack are not automatically propagated to base texture but the mipmap tensors will receive gradients of their own. If a mipmap stack is not specified but the chosen filter mode requires it, the mipmap stack is constructed internally and discarded afterwards.
|
312 |
+
filter_mode Texture filtering mode to be used. Valid values are 'auto', 'nearest', 'linear', 'linear-mipmap-nearest', and 'linear-mipmap-linear'. Mode 'auto' selects 'linear' if neither uv_da or mip_level_bias is specified, and 'linear-mipmap-linear' when at least one of them is specified, these being the highest-quality modes possible depending on the availability of the image-space derivatives of the texture coordinates or direct mip level information.
|
313 |
+
boundary_mode Valid values are 'wrap', 'clamp', 'zero', and 'cube'. If tex defines a cube map, this must be set to 'cube'. The default mode 'wrap' takes fractional part of texture coordinates. Mode 'clamp' clamps texture coordinates to the centers of the boundary texels. Mode 'zero' virtually extends the texture with all-zero values in all directions.
|
314 |
+
max_mip_level If specified, limits the number of mipmaps constructed and used in mipmap-based filter modes.
|
315 |
+
Returns:
|
316 |
+
A tensor containing the results of the texture sampling with shape [minibatch_size, height, width, tex_channels]. Cube map fetches with invalid uv coordinates (e.g., zero vectors) output all zeros and do not propagate gradients.
|
317 |
+
"""
|
318 |
+
return dr.texture(
|
319 |
+
tex.float(),
|
320 |
+
uv.float(),
|
321 |
+
uv_da,
|
322 |
+
mip_level_bias,
|
323 |
+
mip,
|
324 |
+
filter_mode,
|
325 |
+
boundary_mode,
|
326 |
+
max_mip_level,
|
327 |
+
)
|
328 |
+
|
329 |
+
def antialias(
|
330 |
+
self, color, rast, pos, tri, topology_hash=None, pos_gradient_boost=1.0
|
331 |
+
):
|
332 |
+
"""
|
333 |
+
Perform antialiasing.
|
334 |
+
|
335 |
+
All input tensors must be contiguous and reside in GPU memory. The output tensor will be contiguous and reside in GPU memory.
|
336 |
+
|
337 |
+
Note that silhouette edge determination is based on vertex indices in the triangle tensor. For it to work properly, a vertex belonging to multiple triangles must be referred to using the same vertex index in each triangle. Otherwise, nvdiffrast will always classify the adjacent edges as silhouette edges, which leads to bad performance and potentially incorrect gradients. If you are unsure whether your data is good, check which pixels are modified by the antialias operation and compare to the example in the documentation.
|
338 |
+
|
339 |
+
Arguments:
|
340 |
+
color Input image to antialias with shape [minibatch_size, height, width, num_channels].
|
341 |
+
rast Main output tensor from rasterize().
|
342 |
+
pos Vertex position tensor used in the rasterization operation.
|
343 |
+
tri Triangle tensor used in the rasterization operation.
|
344 |
+
topology_hash (Optional) Preconstructed topology hash for the triangle tensor. If not specified, the topology hash is constructed internally and discarded afterwards.
|
345 |
+
pos_gradient_boost (Optional) Multiplier for gradients propagated to pos.
|
346 |
+
Returns:
|
347 |
+
A tensor containing the antialiased image with the same shape as color input tensor.
|
348 |
+
"""
|
349 |
+
return dr.antialias(
|
350 |
+
color.float(),
|
351 |
+
rast,
|
352 |
+
pos.float(),
|
353 |
+
tri.int(),
|
354 |
+
topology_hash,
|
355 |
+
pos_gradient_boost,
|
356 |
+
)
|
357 |
+
|
358 |
+
|
359 |
+
def get_clip_space_position(pos: torch.FloatTensor, mvp_mtx: torch.FloatTensor):
|
360 |
+
pos_homo = torch.cat([pos, torch.ones([pos.shape[0], 1]).to(pos)], dim=-1)
|
361 |
+
return torch.matmul(pos_homo, mvp_mtx.permute(0, 2, 1))
|
362 |
+
|
363 |
+
|
364 |
+
def transform_points_homo(pos: torch.FloatTensor, mtx: torch.FloatTensor):
|
365 |
+
batch_size = pos.shape[0]
|
366 |
+
pos_shape = pos.shape[1:-1]
|
367 |
+
pos = pos.reshape(batch_size, -1, 3)
|
368 |
+
pos_homo = torch.cat([pos, torch.ones_like(pos[..., 0:1])], dim=-1)
|
369 |
+
pos = (pos_homo.unsqueeze(2) * mtx.unsqueeze(1)).sum(-1)[..., :3]
|
370 |
+
pos = pos.reshape(batch_size, *pos_shape, 3)
|
371 |
+
return pos
|
372 |
+
|
373 |
+
|
374 |
+
class DepthNormalizationStrategy(ABC):
|
375 |
+
@abstractmethod
|
376 |
+
def __init__(self, *args, **kwargs):
|
377 |
+
pass
|
378 |
+
|
379 |
+
@abstractmethod
|
380 |
+
def __call__(
|
381 |
+
self, depth: torch.FloatTensor, mask: torch.BoolTensor
|
382 |
+
) -> torch.FloatTensor:
|
383 |
+
pass
|
384 |
+
|
385 |
+
|
386 |
+
class DepthControlNetNormalization(DepthNormalizationStrategy):
|
387 |
+
def __init__(
|
388 |
+
self, far_clip: float = 0.25, near_clip: float = 1.0, bg_value: float = 0.0
|
389 |
+
):
|
390 |
+
self.far_clip = far_clip
|
391 |
+
self.near_clip = near_clip
|
392 |
+
self.bg_value = bg_value
|
393 |
+
|
394 |
+
def __call__(
|
395 |
+
self, depth: torch.FloatTensor, mask: torch.BoolTensor
|
396 |
+
) -> torch.FloatTensor:
|
397 |
+
batch_size = depth.shape[0]
|
398 |
+
min_depth = depth.view(batch_size, -1).min(dim=-1)[0][:, None, None]
|
399 |
+
max_depth = depth.view(batch_size, -1).max(dim=-1)[0][:, None, None]
|
400 |
+
depth = 1.0 - ((depth - min_depth) / (max_depth - min_depth + 1e-5)).clamp(
|
401 |
+
0.0, 1.0
|
402 |
+
)
|
403 |
+
depth = depth * (self.near_clip - self.far_clip) + self.far_clip
|
404 |
+
depth[~mask] = self.bg_value
|
405 |
+
return depth
|
406 |
+
|
407 |
+
|
408 |
+
class Zero123PlusPlusNormalization(DepthNormalizationStrategy):
|
409 |
+
def __init__(self, bg_value: float = 0.8):
|
410 |
+
self.bg_value = bg_value
|
411 |
+
|
412 |
+
def __call__(self, depth: FloatTensor, mask: BoolTensor) -> FloatTensor:
|
413 |
+
batch_size = depth.shape[0]
|
414 |
+
min_depth = depth.view(batch_size, -1).min(dim=-1)[0][:, None, None]
|
415 |
+
max_depth = depth.view(batch_size, -1).max(dim=-1)[0][:, None, None]
|
416 |
+
depth = ((depth - min_depth) / (max_depth - min_depth + 1e-5)).clamp(0.0, 1.0)
|
417 |
+
depth[~mask] = self.bg_value
|
418 |
+
return depth
|
419 |
+
|
420 |
+
|
421 |
+
class SimpleNormalization(DepthNormalizationStrategy):
|
422 |
+
def __init__(
|
423 |
+
self,
|
424 |
+
scale: float = 1.0,
|
425 |
+
offset: float = -1.0,
|
426 |
+
clamp: bool = True,
|
427 |
+
bg_value: float = 1.0,
|
428 |
+
):
|
429 |
+
self.scale = scale
|
430 |
+
self.offset = offset
|
431 |
+
self.clamp = clamp
|
432 |
+
self.bg_value = bg_value
|
433 |
+
|
434 |
+
def __call__(self, depth: FloatTensor, mask: BoolTensor) -> FloatTensor:
|
435 |
+
depth = depth * self.scale + self.offset
|
436 |
+
if self.clamp:
|
437 |
+
depth = depth.clamp(0.0, 1.0)
|
438 |
+
depth[~mask] = self.bg_value
|
439 |
+
return depth
|
440 |
+
|
441 |
+
|
442 |
+
def render(
|
443 |
+
ctx: NVDiffRastContextWrapper,
|
444 |
+
mesh: TexturedMesh,
|
445 |
+
cam: Camera,
|
446 |
+
height: int,
|
447 |
+
width: int,
|
448 |
+
render_attr: bool = True,
|
449 |
+
render_depth: bool = True,
|
450 |
+
render_normal: bool = True,
|
451 |
+
depth_normalization_strategy: DepthNormalizationStrategy = DepthControlNetNormalization(),
|
452 |
+
attr_background: Union[float, torch.FloatTensor] = 0.5,
|
453 |
+
antialias_attr=False,
|
454 |
+
normal_background: Union[float, torch.FloatTensor] = 0.5,
|
455 |
+
texture_override=None,
|
456 |
+
texture_filter_mode: str = "linear",
|
457 |
+
) -> RenderOutput:
|
458 |
+
output_dict = {}
|
459 |
+
|
460 |
+
v_pos_clip = get_clip_space_position(mesh.v_pos, cam.mvp_mtx)
|
461 |
+
rast, _ = ctx.rasterize(v_pos_clip, mesh.t_pos_idx, (height, width), grad_db=True)
|
462 |
+
mask = rast[..., 3] > 0
|
463 |
+
|
464 |
+
gb_pos, _ = ctx.interpolate(mesh.v_pos[None], rast, mesh.t_pos_idx)
|
465 |
+
output_dict.update({"mask": mask, "pos": gb_pos})
|
466 |
+
|
467 |
+
if render_depth:
|
468 |
+
gb_pos_vs = transform_points_homo(gb_pos, cam.w2c)
|
469 |
+
gb_depth = -gb_pos_vs[..., 2]
|
470 |
+
# set background pixels to min depth value for correct min/max calculation
|
471 |
+
gb_depth = torch.where(
|
472 |
+
mask,
|
473 |
+
gb_depth,
|
474 |
+
gb_depth.view(gb_depth.shape[0], -1).min(dim=-1)[0][:, None, None],
|
475 |
+
)
|
476 |
+
gb_depth = depth_normalization_strategy(gb_depth, mask)
|
477 |
+
output_dict["depth"] = gb_depth
|
478 |
+
|
479 |
+
if render_attr:
|
480 |
+
tex_c, _ = ctx.interpolate(mesh.v_tex[None], rast, mesh.t_tex_idx)
|
481 |
+
texture = (
|
482 |
+
texture_override[None]
|
483 |
+
if texture_override is not None
|
484 |
+
else mesh.texture[None]
|
485 |
+
)
|
486 |
+
gb_rgb_fg = ctx.texture(texture, tex_c, filter_mode=texture_filter_mode)
|
487 |
+
gb_rgb_bg = torch.ones_like(gb_rgb_fg) * attr_background
|
488 |
+
gb_rgb = torch.where(mask[..., None], gb_rgb_fg, gb_rgb_bg)
|
489 |
+
if antialias_attr:
|
490 |
+
gb_rgb = ctx.antialias(gb_rgb, rast, v_pos_clip, mesh.t_pos_idx)
|
491 |
+
output_dict["attr"] = gb_rgb
|
492 |
+
|
493 |
+
if render_normal:
|
494 |
+
gb_nrm, _ = ctx.interpolate(mesh.v_nrm[None], rast, mesh.stitched_t_pos_idx)
|
495 |
+
gb_nrm = F.normalize(gb_nrm, dim=-1, p=2)
|
496 |
+
gb_nrm[~mask] = normal_background
|
497 |
+
output_dict["normal"] = gb_nrm
|
498 |
+
|
499 |
+
return RenderOutput(**output_dict)
|
mvadapter/utils/saving.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import List, Optional, Union
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
|
9 |
+
def tensor_to_image(
|
10 |
+
data: Union[Image.Image, torch.Tensor, np.ndarray],
|
11 |
+
batched: bool = False,
|
12 |
+
format: str = "HWC",
|
13 |
+
) -> Union[Image.Image, List[Image.Image]]:
|
14 |
+
if isinstance(data, Image.Image):
|
15 |
+
return data
|
16 |
+
if isinstance(data, torch.Tensor):
|
17 |
+
data = data.detach().cpu().numpy()
|
18 |
+
if data.dtype == np.float32 or data.dtype == np.float16:
|
19 |
+
data = (data * 255).astype(np.uint8)
|
20 |
+
elif data.dtype == np.bool_:
|
21 |
+
data = data.astype(np.uint8) * 255
|
22 |
+
assert data.dtype == np.uint8
|
23 |
+
if format == "CHW":
|
24 |
+
if batched and data.ndim == 4:
|
25 |
+
data = data.transpose((0, 2, 3, 1))
|
26 |
+
elif not batched and data.ndim == 3:
|
27 |
+
data = data.transpose((1, 2, 0))
|
28 |
+
|
29 |
+
if batched:
|
30 |
+
return [Image.fromarray(d) for d in data]
|
31 |
+
return Image.fromarray(data)
|
32 |
+
|
33 |
+
|
34 |
+
def largest_factor_near_sqrt(n: int) -> int:
|
35 |
+
"""
|
36 |
+
Finds the largest factor of n that is closest to the square root of n.
|
37 |
+
|
38 |
+
Args:
|
39 |
+
n (int): The integer for which to find the largest factor near its square root.
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
int: The largest factor of n that is closest to the square root of n.
|
43 |
+
"""
|
44 |
+
sqrt_n = int(math.sqrt(n)) # Get the integer part of the square root
|
45 |
+
|
46 |
+
# First, check if the square root itself is a factor
|
47 |
+
if sqrt_n * sqrt_n == n:
|
48 |
+
return sqrt_n
|
49 |
+
|
50 |
+
# Otherwise, find the largest factor by iterating from sqrt_n downwards
|
51 |
+
for i in range(sqrt_n, 0, -1):
|
52 |
+
if n % i == 0:
|
53 |
+
return i
|
54 |
+
|
55 |
+
# If n is 1, return 1
|
56 |
+
return 1
|
57 |
+
|
58 |
+
|
59 |
+
def make_image_grid(
|
60 |
+
images: List[Image.Image],
|
61 |
+
rows: Optional[int] = None,
|
62 |
+
cols: Optional[int] = None,
|
63 |
+
resize: Optional[int] = None,
|
64 |
+
) -> Image.Image:
|
65 |
+
"""
|
66 |
+
Prepares a single grid of images. Useful for visualization purposes.
|
67 |
+
"""
|
68 |
+
if rows is None and cols is not None:
|
69 |
+
assert len(images) % cols == 0
|
70 |
+
rows = len(images) // cols
|
71 |
+
elif cols is None and rows is not None:
|
72 |
+
assert len(images) % rows == 0
|
73 |
+
cols = len(images) // rows
|
74 |
+
elif rows is None and cols is None:
|
75 |
+
rows = largest_factor_near_sqrt(len(images))
|
76 |
+
cols = len(images) // rows
|
77 |
+
|
78 |
+
assert len(images) == rows * cols
|
79 |
+
|
80 |
+
if resize is not None:
|
81 |
+
images = [img.resize((resize, resize)) for img in images]
|
82 |
+
|
83 |
+
w, h = images[0].size
|
84 |
+
grid = Image.new("RGB", size=(cols * w, rows * h))
|
85 |
+
|
86 |
+
for i, img in enumerate(images):
|
87 |
+
grid.paste(img, box=(i % cols * w, i // cols * h))
|
88 |
+
return grid
|
requirements.txt
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torchvision
|
2 |
+
diffusers
|
3 |
+
transformers==4.49.0
|
4 |
+
einops
|
5 |
+
huggingface_hub
|
6 |
+
opencv-python
|
7 |
+
trimesh==4.5.3
|
8 |
+
omegaconf
|
9 |
+
scikit-image
|
10 |
+
numpy
|
11 |
+
peft
|
12 |
+
scipy==1.11.4
|
13 |
+
jaxtyping
|
14 |
+
typeguard
|
15 |
+
pymeshlab==2022.2.post4
|
16 |
+
open3d
|
17 |
+
timm
|
18 |
+
kornia
|
19 |
+
ninja
|
20 |
+
https://huggingface.co/spaces/JeffreyXiang/TRELLIS/resolve/main/wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl?download=true
|
21 |
+
cvcuda_cu12
|
22 |
+
gltflib
|
23 |
+
torch-cluster
|