huanngzh committed · Commit a207590 · 0 Parent(s)
.gitattributes ADDED
@@ -0,0 +1,37 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.glb filter=lfs diff=lfs merge=lfs -text
37
+ *.so filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
1
+ ---
2
+ title: MV Adapter Img2Texture
3
+ emoji: 🔮
4
+ colorFrom: purple
5
+ colorTo: yellow
6
+ sdk: gradio
7
+ sdk_version: 5.23.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ short_description: Generate 3D texture from image
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,293 @@
1
+ import os
2
+ import random
3
+ import shutil
4
+ import subprocess
5
+ from typing import List
6
+
7
+ import gradio as gr
8
+ import numpy as np
9
+ import spaces
10
+ import torch
11
+ from huggingface_hub import hf_hub_download, snapshot_download
12
+ from PIL import Image
13
+ from torchvision import transforms
14
+ from transformers import AutoModelForImageSegmentation
15
+
16
+ from inference_ig2mv_sdxl import (
17
+ prepare_pipeline,
18
+ preprocess_image,
19
+ remove_bg,
20
+ run_pipeline,
21
+ )
22
+ from mvadapter.utils import get_orthogonal_camera, make_image_grid, tensor_to_image
23
+
24
+ # install others
25
+ subprocess.run("pip install spandrel==0.4.1 --no-deps", shell=True, check=True)
26
+
27
+
28
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
29
+ DTYPE = torch.float16
30
+ MAX_SEED = np.iinfo(np.int32).max
31
+ NUM_VIEWS = 6
32
+ HEIGHT = 768
33
+ WIDTH = 768
34
+
35
+ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
36
+ os.makedirs(TMP_DIR, exist_ok=True)
37
+
38
+
39
+ HEADER = """
40
+ # 🔮 Image to Texture with [MV-Adapter](https://github.com/huanngzh/MV-Adapter)
41
+ ## State-of-the-art Open Source Texture Generation Using a Multi-View Diffusion Model
42
+ """
43
+
44
+ EXAMPLES = [
45
+ ["examples/001.jpeg", "examples/001.glb"],
46
+ ["examples/002.jpeg", "examples/002.glb"],
47
+ ]
48
+
49
+ # MV-Adapter
50
+ pipe = prepare_pipeline(
51
+ base_model="stabilityai/stable-diffusion-xl-base-1.0",
52
+ vae_model="madebyollin/sdxl-vae-fp16-fix",
53
+ unet_model=None,
54
+ lora_model=None,
55
+ adapter_path="huanngzh/mv-adapter",
56
+ scheduler=None,
57
+ num_views=NUM_VIEWS,
58
+ device=DEVICE,
59
+ dtype=DTYPE,
60
+ )
61
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
62
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
63
+ )
64
+ birefnet.to(DEVICE)
65
+ transform_image = transforms.Compose(
66
+ [
67
+ transforms.Resize((1024, 1024)),
68
+ transforms.ToTensor(),
69
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
70
+ ]
71
+ )
72
+ remove_bg_fn = lambda x: remove_bg(x, birefnet, transform_image, DEVICE)
73
+
74
+ if not os.path.exists("checkpoints/RealESRGAN_x2plus.pth"):
75
+ hf_hub_download(
76
+ "dtarnow/UPscaler", filename="RealESRGAN_x2plus.pth", local_dir="checkpoints"
77
+ )
78
+ if not os.path.exists("checkpoints/big-lama.pt"):
79
+ subprocess.run(
80
+ "wget -P checkpoints/ https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
81
+ shell=True,
82
+ check=True,
83
+ )
84
+
85
+
86
+ device = "cuda" if torch.cuda.is_available() else "cpu"
87
+
88
+
89
+ def start_session(req: gr.Request):
90
+ save_dir = os.path.join(TMP_DIR, str(req.session_hash))
91
+ os.makedirs(save_dir, exist_ok=True)
92
+ print("start session, mkdir", save_dir)
93
+
94
+
95
+ def end_session(req: gr.Request):
96
+ save_dir = os.path.join(TMP_DIR, str(req.session_hash))
97
+ shutil.rmtree(save_dir)
98
+
99
+
100
+ def get_random_hex():
101
+ random_bytes = os.urandom(8)
102
+ random_hex = random_bytes.hex()
103
+ return random_hex
104
+
105
+
106
+ def get_random_seed(randomize_seed, seed):
107
+ if randomize_seed:
108
+ seed = random.randint(0, MAX_SEED)
109
+ return seed
110
+
111
+
112
+ @spaces.GPU(duration=90)
113
+ @torch.no_grad()
114
+ def run_mvadapter(
115
+ mesh_path,
116
+ prompt,
117
+ image,
118
+ seed=42,
119
+ guidance_scale=3.0,
120
+ num_inference_steps=30,
121
+ reference_conditioning_scale=1.0,
122
+ negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
123
+ progress=gr.Progress(track_tqdm=True),
124
+ ):
125
+ # pre-process the reference image
126
+ image = Image.open(image).convert("RGB") if isinstance(image, str) else image
127
+ image = remove_bg_fn(image)
128
+ image = preprocess_image(image, HEIGHT, WIDTH)
129
+
130
+ if isinstance(seed, str):
131
+ try:
132
+ seed = int(seed.strip())
133
+ except ValueError:
134
+ seed = 42
135
+
136
+ images, _, _, _ = run_pipeline(
137
+ pipe,
138
+ mesh_path=mesh_path,
139
+ num_views=NUM_VIEWS,
140
+ text=prompt,
141
+ image=image,
142
+ height=HEIGHT,
143
+ width=WIDTH,
144
+ num_inference_steps=num_inference_steps,
145
+ guidance_scale=guidance_scale,
146
+ seed=seed,
147
+ remove_bg_fn=None,
148
+ reference_conditioning_scale=reference_conditioning_scale,
149
+ negative_prompt=negative_prompt,
150
+ device=DEVICE,
151
+ )
152
+
153
+ torch.cuda.empty_cache()
154
+
155
+ return images, image
156
+
157
+
158
+ @spaces.GPU(duration=90)
159
+ @torch.no_grad()
160
+ def run_texturing(
161
+ mesh_path: str,
162
+ mv_images: List[Image.Image],
163
+ uv_unwarp: bool,
164
+ preprocess_mesh: bool,
165
+ uv_size: int,
166
+ req: gr.Request,
167
+ ):
168
+ save_dir = os.path.join(TMP_DIR, str(req.session_hash))
169
+ mv_image_path = os.path.join(save_dir, f"mv_adapter_{get_random_hex()}.png")
170
+ mv_images = [item[0] for item in mv_images]
171
+ make_image_grid(mv_images, rows=1).save(mv_image_path)
172
+
173
+ from texture import ModProcessConfig, TexturePipeline
174
+
175
+ texture_pipe = TexturePipeline(
176
+ upscaler_ckpt_path="checkpoints/RealESRGAN_x2plus.pth",
177
+ inpaint_ckpt_path="checkpoints/big-lama.pt",
178
+ device=DEVICE,
179
+ )
180
+
181
+ textured_glb_path = texture_pipe(
182
+ mesh_path=mesh_path,
183
+ save_dir=save_dir,
184
+ save_name=f"texture_mesh_{get_random_hex()}",
185
+ uv_unwarp=uv_unwarp,
186
+ preprocess_mesh=preprocess_mesh,
187
+ uv_size=uv_size,
188
+ rgb_path=mv_image_path,
189
+ rgb_process_config=ModProcessConfig(view_upscale=True, inpaint_mode="view"),
190
+ camera_azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]],
191
+ ).shaded_model_save_path
192
+
193
+ torch.cuda.empty_cache()
194
+
195
+ return textured_glb_path, textured_glb_path
196
+
197
+
198
+ with gr.Blocks(title="MVAdapter") as demo:
199
+ gr.Markdown(HEADER)
200
+
201
+ with gr.Row():
202
+ with gr.Column():
203
+ with gr.Row():
204
+ input_mesh = gr.Model3D(label="Input 3D mesh")
205
+ image_prompt = gr.Image(label="Input Image", type="pil")
206
+
207
+ with gr.Accordion("Generation Settings", open=False):
208
+ prompt = gr.Textbox(
209
+ label="Prompt (Optional)",
210
+ placeholder="Enter your prompt",
211
+ value="high quality",
212
+ )
213
+ seed = gr.Slider(
214
+ label="Seed", minimum=0, maximum=MAX_SEED, step=0, value=0
215
+ )
216
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
217
+ num_inference_steps = gr.Slider(
218
+ label="Number of inference steps",
219
+ minimum=8,
220
+ maximum=50,
221
+ step=1,
222
+ value=25,
223
+ )
224
+ guidance_scale = gr.Slider(
225
+ label="CFG scale",
226
+ minimum=0.0,
227
+ maximum=20.0,
228
+ step=0.1,
229
+ value=3.0,
230
+ )
231
+ reference_conditioning_scale = gr.Slider(
232
+ label="Image conditioning scale",
233
+ minimum=0.0,
234
+ maximum=2.0,
235
+ step=0.1,
236
+ value=1.0,
237
+ )
238
+
239
+ with gr.Accordion("Texture Settings", open=False):
240
+ with gr.Row():
241
+ uv_unwarp = gr.Checkbox(label="Unwrap UV", value=True)
242
+ preprocess_mesh = gr.Checkbox(label="Preprocess Mesh", value=False)
243
+ uv_size = gr.Slider(
244
+ label="UV Size", minimum=1024, maximum=8192, step=512, value=4096
245
+ )
246
+
247
+ gen_button = gr.Button("Generate Texture", variant="primary")
248
+
249
+ examples = gr.Examples(
250
+ examples=EXAMPLES,
251
+ inputs=[image_prompt, input_mesh],
252
+ outputs=[image_prompt],
253
+ )
254
+
255
+ with gr.Column():
256
+ mv_result = gr.Gallery(
257
+ label="Multi-View Results",
258
+ show_label=False,
259
+ columns=[3],
260
+ rows=[2],
261
+ object_fit="contain",
262
+ height="auto",
263
+ type="pil",
264
+ )
265
+ textured_model_output = gr.Model3D(label="Textured GLB", interactive=False)
266
+ download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
267
+
268
+ gen_button.click(
269
+ get_random_seed, inputs=[randomize_seed, seed], outputs=[seed]
270
+ ).then(
271
+ run_mvadapter,
272
+ inputs=[
273
+ input_mesh,
274
+ prompt,
275
+ image_prompt,
276
+ seed,
277
+ guidance_scale,
278
+ num_inference_steps,
279
+ reference_conditioning_scale,
280
+ ],
281
+ outputs=[mv_result, image_prompt],
282
+ ).then(
283
+ run_texturing,
284
+ inputs=[input_mesh, mv_result, uv_unwarp, preprocess_mesh, uv_size],
285
+ outputs=[textured_model_output, download_glb],
286
+ ).then(
287
+ lambda: gr.Button(interactive=True), outputs=[download_glb]
288
+ )
289
+
290
+ demo.load(start_session)
291
+ demo.unload(end_session)
292
+
293
+ demo.launch()
inference_ig2mv_sdxl.py ADDED
@@ -0,0 +1,286 @@
1
+ import argparse
2
+
3
+ import numpy as np
4
+ import torch
5
+ from diffusers import AutoencoderKL, DDPMScheduler, LCMScheduler, UNet2DConditionModel
6
+ from PIL import Image
7
+ from torchvision import transforms
8
+ from tqdm import tqdm
9
+ from transformers import AutoModelForImageSegmentation
10
+
11
+ from mvadapter.models.attention_processor import DecoupledMVRowColSelfAttnProcessor2_0
12
+ from mvadapter.pipelines.pipeline_mvadapter_i2mv_sdxl import MVAdapterI2MVSDXLPipeline
13
+ from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler
14
+ from mvadapter.utils import get_orthogonal_camera, make_image_grid, tensor_to_image
15
+ from mvadapter.utils.render import NVDiffRastContextWrapper, load_mesh, render
16
+
17
+
18
+ def prepare_pipeline(
19
+ base_model,
20
+ vae_model,
21
+ unet_model,
22
+ lora_model,
23
+ adapter_path,
24
+ scheduler,
25
+ num_views,
26
+ device,
27
+ dtype,
28
+ ):
29
+ # Load vae and unet if provided
30
+ pipe_kwargs = {}
31
+ if vae_model is not None:
32
+ pipe_kwargs["vae"] = AutoencoderKL.from_pretrained(vae_model)
33
+ if unet_model is not None:
34
+ pipe_kwargs["unet"] = UNet2DConditionModel.from_pretrained(unet_model)
35
+
36
+ # Prepare pipeline
37
+ pipe: MVAdapterI2MVSDXLPipeline
38
+ pipe = MVAdapterI2MVSDXLPipeline.from_pretrained(base_model, **pipe_kwargs)
39
+
40
+ # Load scheduler if provided
41
+ scheduler_class = None
42
+ if scheduler == "ddpm":
43
+ scheduler_class = DDPMScheduler
44
+ elif scheduler == "lcm":
45
+ scheduler_class = LCMScheduler
46
+
47
+ pipe.scheduler = ShiftSNRScheduler.from_scheduler(
48
+ pipe.scheduler,
49
+ shift_mode="interpolated",
50
+ shift_scale=8.0,
51
+ scheduler_class=scheduler_class,
52
+ )
53
+ pipe.init_custom_adapter(
54
+ num_views=num_views, self_attn_processor=DecoupledMVRowColSelfAttnProcessor2_0
55
+ )
56
+ pipe.load_custom_adapter(
57
+ adapter_path, weight_name="mvadapter_ig2mv_sdxl.safetensors"
58
+ )
59
+
60
+ pipe.to(device=device, dtype=dtype)
61
+ pipe.cond_encoder.to(device=device, dtype=dtype)
62
+
63
+ # load lora if provided
64
+ if lora_model is not None:
65
+ model_, name_ = lora_model.rsplit("/", 1)
66
+ pipe.load_lora_weights(model_, weight_name=name_)
67
+
68
+ return pipe
69
+
70
+
71
+ def remove_bg(image, net, transform, device):
72
+ image_size = image.size
73
+ input_images = transform(image).unsqueeze(0).to(device)
74
+ with torch.no_grad():
75
+ preds = net(input_images)[-1].sigmoid().cpu()
76
+ pred = preds[0].squeeze()
77
+ pred_pil = transforms.ToPILImage()(pred)
78
+ mask = pred_pil.resize(image_size)
79
+ image.putalpha(mask)
80
+ return image
81
+
82
+
83
+ def preprocess_image(image: Image.Image, height, width):
84
+ image = np.array(image)
85
+ alpha = image[..., 3] > 0
86
+ H, W = alpha.shape
87
+ # get the bounding box of alpha
88
+ y, x = np.where(alpha)
89
+ y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
90
+ x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
91
+ image_center = image[y0:y1, x0:x1]
92
+ # resize the longer side to H * 0.9
93
+ H, W, _ = image_center.shape
94
+ if H > W:
95
+ W = int(W * (height * 0.9) / H)
96
+ H = int(height * 0.9)
97
+ else:
98
+ H = int(H * (width * 0.9) / W)
99
+ W = int(width * 0.9)
100
+ image_center = np.array(Image.fromarray(image_center).resize((W, H)))
101
+ # pad to H, W
102
+ start_h = (height - H) // 2
103
+ start_w = (width - W) // 2
104
+ image = np.zeros((height, width, 4), dtype=np.uint8)
105
+ image[start_h : start_h + H, start_w : start_w + W] = image_center
106
+ image = image.astype(np.float32) / 255.0
107
+ image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
108
+ image = (image * 255).clip(0, 255).astype(np.uint8)
109
+ image = Image.fromarray(image)
110
+
111
+ return image
112
+
113
+
114
+ def run_pipeline(
115
+ pipe,
116
+ mesh_path,
117
+ num_views,
118
+ text,
119
+ image,
120
+ height,
121
+ width,
122
+ num_inference_steps,
123
+ guidance_scale,
124
+ seed,
125
+ remove_bg_fn=None,
126
+ reference_conditioning_scale=1.0,
127
+ negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
128
+ lora_scale=1.0,
129
+ device="cuda",
130
+ ):
131
+ # Prepare cameras
132
+ cameras = get_orthogonal_camera(
133
+ elevation_deg=[0, 0, 0, 0, 89.99, -89.99],
134
+ distance=[1.8] * num_views,
135
+ left=-0.55,
136
+ right=0.55,
137
+ bottom=-0.55,
138
+ top=0.55,
139
+ azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]],
140
+ device=device,
141
+ )
142
+ ctx = NVDiffRastContextWrapper(device=device, context_type="cuda")
143
+
144
+ mesh = load_mesh(mesh_path, rescale=True, device=device)
145
+ render_out = render(
146
+ ctx,
147
+ mesh,
148
+ cameras,
149
+ height=height,
150
+ width=width,
151
+ render_attr=False,
152
+ normal_background=0.0,
153
+ )
154
+ pos_images = tensor_to_image((render_out.pos + 0.5).clamp(0, 1), batched=True)
155
+ normal_images = tensor_to_image(
156
+ (render_out.normal / 2 + 0.5).clamp(0, 1), batched=True
157
+ )
158
+ control_images = (
159
+ torch.cat(
160
+ [
161
+ (render_out.pos + 0.5).clamp(0, 1),
162
+ (render_out.normal / 2 + 0.5).clamp(0, 1),
163
+ ],
164
+ dim=-1,
165
+ )
166
+ .permute(0, 3, 1, 2)
167
+ .to(device)
168
+ )
169
+
170
+ # Prepare image
171
+ reference_image = Image.open(image) if isinstance(image, str) else image
172
+ if remove_bg_fn is not None:
173
+ reference_image = remove_bg_fn(reference_image)
174
+ reference_image = preprocess_image(reference_image, height, width)
175
+ elif reference_image.mode == "RGBA":
176
+ reference_image = preprocess_image(reference_image, height, width)
177
+
178
+ pipe_kwargs = {}
179
+ if seed != -1 and isinstance(seed, int):
180
+ pipe_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed)
181
+
182
+ images = pipe(
183
+ text,
184
+ height=height,
185
+ width=width,
186
+ num_inference_steps=num_inference_steps,
187
+ guidance_scale=guidance_scale,
188
+ num_images_per_prompt=num_views,
189
+ control_image=control_images,
190
+ control_conditioning_scale=1.0,
191
+ reference_image=reference_image,
192
+ reference_conditioning_scale=reference_conditioning_scale,
193
+ negative_prompt=negative_prompt,
194
+ cross_attention_kwargs={"scale": lora_scale},
195
+ **pipe_kwargs,
196
+ ).images
197
+
198
+ return images, pos_images, normal_images, reference_image
199
+
200
+
201
+ if __name__ == "__main__":
202
+ parser = argparse.ArgumentParser()
203
+ # Models
204
+ parser.add_argument(
205
+ "--base_model", type=str, default="stabilityai/stable-diffusion-xl-base-1.0"
206
+ )
207
+ parser.add_argument(
208
+ "--vae_model", type=str, default="madebyollin/sdxl-vae-fp16-fix"
209
+ )
210
+ parser.add_argument("--unet_model", type=str, default=None)
211
+ parser.add_argument("--scheduler", type=str, default=None)
212
+ parser.add_argument("--lora_model", type=str, default=None)
213
+ parser.add_argument("--adapter_path", type=str, default="huanngzh/mv-adapter")
214
+ parser.add_argument("--num_views", type=int, default=6)
215
+ # Device
216
+ parser.add_argument("--device", type=str, default="cuda")
217
+ # Inference
218
+ parser.add_argument("--mesh", type=str, required=True)
219
+ parser.add_argument("--image", type=str, required=True)
220
+ parser.add_argument("--text", type=str, required=False, default="high quality")
221
+ parser.add_argument("--num_inference_steps", type=int, default=50)
222
+ parser.add_argument("--guidance_scale", type=float, default=3.0)
223
+ parser.add_argument("--seed", type=int, default=-1)
224
+ parser.add_argument("--lora_scale", type=float, default=1.0)
225
+ parser.add_argument("--reference_conditioning_scale", type=float, default=1.0)
226
+ parser.add_argument(
227
+ "--negative_prompt",
228
+ type=str,
229
+ default="watermark, ugly, deformed, noisy, blurry, low contrast",
230
+ )
231
+ parser.add_argument("--output", type=str, default="output.png")
232
+ # Extra
233
+ parser.add_argument("--remove_bg", action="store_true", help="Remove background")
234
+ args = parser.parse_args()
235
+
236
+ pipe = prepare_pipeline(
237
+ base_model=args.base_model,
238
+ vae_model=args.vae_model,
239
+ unet_model=args.unet_model,
240
+ lora_model=args.lora_model,
241
+ adapter_path=args.adapter_path,
242
+ scheduler=args.scheduler,
243
+ num_views=args.num_views,
244
+ device=args.device,
245
+ dtype=torch.float16,
246
+ )
247
+
248
+ if args.remove_bg:
249
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
250
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
251
+ )
252
+ birefnet.to(args.device)
253
+ transform_image = transforms.Compose(
254
+ [
255
+ transforms.Resize((1024, 1024)),
256
+ transforms.ToTensor(),
257
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
258
+ ]
259
+ )
260
+ remove_bg_fn = lambda x: remove_bg(x, birefnet, transform_image, args.device)
261
+ else:
262
+ remove_bg_fn = None
263
+
264
+ images, pos_images, normal_images, reference_image = run_pipeline(
265
+ pipe,
266
+ mesh_path=args.mesh,
267
+ num_views=args.num_views,
268
+ text=args.text,
269
+ image=args.image,
270
+ height=768,
271
+ width=768,
272
+ num_inference_steps=args.num_inference_steps,
273
+ guidance_scale=args.guidance_scale,
274
+ seed=args.seed,
275
+ lora_scale=args.lora_scale,
276
+ reference_conditioning_scale=args.reference_conditioning_scale,
277
+ negative_prompt=args.negative_prompt,
278
+ device=args.device,
279
+ remove_bg_fn=remove_bg_fn,
280
+ )
281
+ make_image_grid(images, rows=1).save(args.output)
282
+ make_image_grid(pos_images, rows=1).save(args.output.rsplit(".", 1)[0] + "_pos.png")
283
+ make_image_grid(normal_images, rows=1).save(
284
+ args.output.rsplit(".", 1)[0] + "_nor.png"
285
+ )
286
+ reference_image.save(args.output.rsplit(".", 1)[0] + "_reference.png")
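For reference, a minimal programmatic sketch of the two entry points defined above (`prepare_pipeline` and `run_pipeline`), used outside the argparse CLI. This is not part of the commit: the mesh and image paths are placeholders, the model defaults mirror the `__main__` block, and a CUDA device is required because rendering uses an nvdiffrast CUDA context.

```python
# Hedged usage sketch, not part of the commit. Paths below are placeholders.
import torch

from inference_ig2mv_sdxl import prepare_pipeline, run_pipeline
from mvadapter.utils import make_image_grid

pipe = prepare_pipeline(
    base_model="stabilityai/stable-diffusion-xl-base-1.0",
    vae_model="madebyollin/sdxl-vae-fp16-fix",
    unet_model=None,
    lora_model=None,
    adapter_path="huanngzh/mv-adapter",
    scheduler=None,
    num_views=6,
    device="cuda",
    dtype=torch.float16,
)

images, _, _, _ = run_pipeline(
    pipe,
    mesh_path="example.glb",  # placeholder input mesh
    num_views=6,
    text="high quality",
    image="example.png",  # placeholder reference image
    height=768,
    width=768,
    num_inference_steps=30,
    guidance_scale=3.0,
    seed=42,
    device="cuda",
)
make_image_grid(images, rows=1).save("mv_output.png")
```

The `__main__` block above exposes the same flow on the command line via `--mesh`, `--image`, and `--output`.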
mvadapter/__init__.py ADDED
File without changes
mvadapter/loaders/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .custom_adapter import CustomAdapterMixin
mvadapter/loaders/custom_adapter.py ADDED
@@ -0,0 +1,98 @@
1
+ import os
2
+ from typing import Dict, Optional, Union
3
+
4
+ import safetensors
5
+ import torch
6
+ from diffusers.utils import _get_model_file, logging
7
+ from safetensors import safe_open
8
+
9
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
10
+
11
+
12
+ class CustomAdapterMixin:
13
+ def init_custom_adapter(self, *args, **kwargs):
14
+ self._init_custom_adapter(*args, **kwargs)
15
+
16
+ def _init_custom_adapter(self, *args, **kwargs):
17
+ raise NotImplementedError
18
+
19
+ def load_custom_adapter(
20
+ self,
21
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
22
+ weight_name: str,
23
+ subfolder: Optional[str] = None,
24
+ **kwargs,
25
+ ):
26
+ # Load the main state dict first.
27
+ cache_dir = kwargs.pop("cache_dir", None)
28
+ force_download = kwargs.pop("force_download", False)
29
+ proxies = kwargs.pop("proxies", None)
30
+ local_files_only = kwargs.pop("local_files_only", None)
31
+ token = kwargs.pop("token", None)
32
+ revision = kwargs.pop("revision", None)
33
+
34
+ user_agent = {
35
+ "file_type": "attn_procs_weights",
36
+ "framework": "pytorch",
37
+ }
38
+
39
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
40
+ model_file = _get_model_file(
41
+ pretrained_model_name_or_path_or_dict,
42
+ weights_name=weight_name,
43
+ subfolder=subfolder,
44
+ cache_dir=cache_dir,
45
+ force_download=force_download,
46
+ proxies=proxies,
47
+ local_files_only=local_files_only,
48
+ token=token,
49
+ revision=revision,
50
+ user_agent=user_agent,
51
+ )
52
+ if weight_name.endswith(".safetensors"):
53
+ state_dict = {}
54
+ with safe_open(model_file, framework="pt", device="cpu") as f:
55
+ for key in f.keys():
56
+ state_dict[key] = f.get_tensor(key)
57
+ else:
58
+ state_dict = torch.load(model_file, map_location="cpu")
59
+ else:
60
+ state_dict = pretrained_model_name_or_path_or_dict
61
+
62
+ self._load_custom_adapter(state_dict)
63
+
64
+ def _load_custom_adapter(self, state_dict):
65
+ raise NotImplementedError
66
+
67
+ def save_custom_adapter(
68
+ self,
69
+ save_directory: Union[str, os.PathLike],
70
+ weight_name: str,
71
+ safe_serialization: bool = False,
72
+ **kwargs,
73
+ ):
74
+ if os.path.isfile(save_directory):
75
+ logger.error(
76
+ f"Provided path ({save_directory}) should be a directory, not a file"
77
+ )
78
+ return
79
+
80
+ if safe_serialization:
81
+
82
+ def save_function(weights, filename):
83
+ return safetensors.torch.save_file(
84
+ weights, filename, metadata={"format": "pt"}
85
+ )
86
+
87
+ else:
88
+ save_function = torch.save
89
+
90
+ # Save the model
91
+ state_dict = self._save_custom_adapter(**kwargs)
92
+ save_function(state_dict, os.path.join(save_directory, weight_name))
93
+ logger.info(
94
+ f"Custom adapter weights saved in {os.path.join(save_directory, weight_name)}"
95
+ )
96
+
97
+ def _save_custom_adapter(self):
98
+ raise NotImplementedError
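To illustrate the contract of `CustomAdapterMixin`, a hypothetical minimal subclass is sketched below. The concrete implementations in this repo live in the pipeline classes (e.g. `MVAdapterI2MVSDXLPipeline`); the class and attribute names here are invented purely for illustration.

```python
# Hypothetical sketch, not part of the commit: the three hooks the mixin calls.
from mvadapter.loaders import CustomAdapterMixin


class ToyAdapterPipeline(CustomAdapterMixin):
    def _init_custom_adapter(self, num_views: int = 6, **kwargs):
        # build the extra adapter modules (e.g. decoupled attention processors)
        self.adapter_state = {}

    def _load_custom_adapter(self, state_dict):
        # copy loaded tensors into the adapter modules created above
        self.adapter_state = dict(state_dict)

    def _save_custom_adapter(self):
        # return exactly what _load_custom_adapter expects to receive
        return self.adapter_state
```

With these hooks in place, the mixin's `init_custom_adapter`, `load_custom_adapter`, and `save_custom_adapter` handle the download, (safetensors) deserialization, and serialization of the adapter weights.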
mvadapter/models/__init__.py ADDED
File without changes
mvadapter/models/attention_processor.py ADDED
@@ -0,0 +1,743 @@
1
+ import math
2
+ from typing import Callable, List, Optional, Union
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from diffusers.models.attention_processor import Attention
7
+ from diffusers.models.unets import UNet2DConditionModel
8
+ from diffusers.utils import deprecate, logging
9
+ from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
10
+ from einops import rearrange, repeat
11
+ from torch import nn
12
+
13
+
14
+ def default_set_attn_proc_func(
15
+ name: str,
16
+ hidden_size: int,
17
+ cross_attention_dim: Optional[int],
18
+ ori_attn_proc: object,
19
+ ) -> object:
20
+ return ori_attn_proc
21
+
22
+
23
+ def set_unet_2d_condition_attn_processor(
24
+ unet: UNet2DConditionModel,
25
+ set_self_attn_proc_func: Callable = default_set_attn_proc_func,
26
+ set_cross_attn_proc_func: Callable = default_set_attn_proc_func,
27
+ set_custom_attn_proc_func: Callable = default_set_attn_proc_func,
28
+ set_self_attn_module_names: Optional[List[str]] = None,
29
+ set_cross_attn_module_names: Optional[List[str]] = None,
30
+ set_custom_attn_module_names: Optional[List[str]] = None,
31
+ ) -> None:
32
+ do_set_processor = lambda name, module_names: (
33
+ any([name.startswith(module_name) for module_name in module_names])
34
+ if module_names is not None
35
+ else True
36
+ ) # prefix match
37
+
38
+ attn_procs = {}
39
+ for name, attn_processor in unet.attn_processors.items():
40
+ # set attn_processor by default, if module_names is None
41
+ set_self_attn_processor = do_set_processor(name, set_self_attn_module_names)
42
+ set_cross_attn_processor = do_set_processor(name, set_cross_attn_module_names)
43
+ set_custom_attn_processor = do_set_processor(name, set_custom_attn_module_names)
44
+
45
+ if name.startswith("mid_block"):
46
+ hidden_size = unet.config.block_out_channels[-1]
47
+ elif name.startswith("up_blocks"):
48
+ block_id = int(name[len("up_blocks.")])
49
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
50
+ elif name.startswith("down_blocks"):
51
+ block_id = int(name[len("down_blocks.")])
52
+ hidden_size = unet.config.block_out_channels[block_id]
53
+
54
+ is_custom = "attn_mid_blocks" in name or "attn_post_blocks" in name
55
+ if is_custom:
56
+ attn_procs[name] = (
57
+ set_custom_attn_proc_func(name, hidden_size, None, attn_processor)
58
+ if set_custom_attn_processor
59
+ else attn_processor
60
+ )
61
+ else:
62
+ cross_attention_dim = (
63
+ None
64
+ if name.endswith("attn1.processor")
65
+ else unet.config.cross_attention_dim
66
+ )
67
+ if cross_attention_dim is None or "motion_modules" in name:
68
+ # self attention
69
+ attn_procs[name] = (
70
+ set_self_attn_proc_func(
71
+ name, hidden_size, cross_attention_dim, attn_processor
72
+ )
73
+ if set_self_attn_processor
74
+ else attn_processor
75
+ )
76
+ else:
77
+ # cross attention
78
+ attn_procs[name] = (
79
+ set_cross_attn_proc_func(
80
+ name, hidden_size, cross_attention_dim, attn_processor
81
+ )
82
+ if set_cross_attn_processor
83
+ else attn_processor
84
+ )
85
+
86
+ unet.set_attn_processor(attn_procs)
87
+
88
+
89
+ class DecoupledMVRowSelfAttnProcessor2_0(torch.nn.Module):
90
+ r"""
91
+ Attention processor for Decoupled Row-wise Self-Attention and Image Cross-Attention for PyTorch 2.0.
92
+ """
93
+
94
+ def __init__(
95
+ self,
96
+ query_dim: int,
97
+ inner_dim: int,
98
+ num_views: int = 1,
99
+ name: Optional[str] = None,
100
+ use_mv: bool = True,
101
+ use_ref: bool = False,
102
+ ):
103
+ if not hasattr(F, "scaled_dot_product_attention"):
104
+ raise ImportError(
105
+ "DecoupledMVRowSelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
106
+ )
107
+
108
+ super().__init__()
109
+
110
+ self.num_views = num_views
111
+ self.name = name # NOTE: need for image cross-attention
112
+ self.use_mv = use_mv
113
+ self.use_ref = use_ref
114
+
115
+ if self.use_mv:
116
+ self.to_q_mv = nn.Linear(
117
+ in_features=query_dim, out_features=inner_dim, bias=False
118
+ )
119
+ self.to_k_mv = nn.Linear(
120
+ in_features=query_dim, out_features=inner_dim, bias=False
121
+ )
122
+ self.to_v_mv = nn.Linear(
123
+ in_features=query_dim, out_features=inner_dim, bias=False
124
+ )
125
+ self.to_out_mv = nn.ModuleList(
126
+ [
127
+ nn.Linear(in_features=inner_dim, out_features=query_dim, bias=True),
128
+ nn.Dropout(0.0),
129
+ ]
130
+ )
131
+
132
+ if self.use_ref:
133
+ self.to_q_ref = nn.Linear(
134
+ in_features=query_dim, out_features=inner_dim, bias=False
135
+ )
136
+ self.to_k_ref = nn.Linear(
137
+ in_features=query_dim, out_features=inner_dim, bias=False
138
+ )
139
+ self.to_v_ref = nn.Linear(
140
+ in_features=query_dim, out_features=inner_dim, bias=False
141
+ )
142
+ self.to_out_ref = nn.ModuleList(
143
+ [
144
+ nn.Linear(in_features=inner_dim, out_features=query_dim, bias=True),
145
+ nn.Dropout(0.0),
146
+ ]
147
+ )
148
+
149
+ def __call__(
150
+ self,
151
+ attn: Attention,
152
+ hidden_states: torch.FloatTensor,
153
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
154
+ attention_mask: Optional[torch.FloatTensor] = None,
155
+ temb: Optional[torch.FloatTensor] = None,
156
+ mv_scale: float = 1.0,
157
+ ref_hidden_states: Optional[torch.FloatTensor] = None,
158
+ ref_scale: float = 1.0,
159
+ cache_hidden_states: Optional[List[torch.FloatTensor]] = None,
160
+ use_mv: bool = True,
161
+ use_ref: bool = True,
162
+ num_views: Optional[int] = None,
163
+ *args,
164
+ **kwargs,
165
+ ) -> torch.FloatTensor:
166
+ """
167
+ New args:
168
+ mv_scale (float): scale for multi-view self-attention.
169
+ ref_hidden_states (torch.FloatTensor): reference encoder hidden states for image cross-attention.
170
+ ref_scale (float): scale for image cross-attention.
171
+ cache_hidden_states (List[torch.FloatTensor]): cache hidden states from reference unet.
172
+
173
+ """
174
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
175
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
176
+ deprecate("scale", "1.0.0", deprecation_message)
177
+
178
+ if num_views is not None:
179
+ self.num_views = num_views
180
+
181
+ # NEW: cache hidden states for reference unet
182
+ if cache_hidden_states is not None:
183
+ cache_hidden_states[self.name] = hidden_states.clone()
184
+
185
+ # NEW: whether to use multi-view attention and image cross-attention
186
+ use_mv = self.use_mv and use_mv
187
+ use_ref = self.use_ref and use_ref
188
+
189
+ residual = hidden_states
190
+ if attn.spatial_norm is not None:
191
+ hidden_states = attn.spatial_norm(hidden_states, temb)
192
+
193
+ input_ndim = hidden_states.ndim
194
+
195
+ if input_ndim == 4:
196
+ batch_size, channel, height, width = hidden_states.shape
197
+ hidden_states = hidden_states.view(
198
+ batch_size, channel, height * width
199
+ ).transpose(1, 2)
200
+
201
+ batch_size, sequence_length, _ = (
202
+ hidden_states.shape
203
+ if encoder_hidden_states is None
204
+ else encoder_hidden_states.shape
205
+ )
206
+
207
+ if attention_mask is not None:
208
+ attention_mask = attn.prepare_attention_mask(
209
+ attention_mask, sequence_length, batch_size
210
+ )
211
+ # scaled_dot_product_attention expects attention_mask shape to be
212
+ # (batch, heads, source_length, target_length)
213
+ attention_mask = attention_mask.view(
214
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
215
+ )
216
+
217
+ if attn.group_norm is not None:
218
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
219
+ 1, 2
220
+ )
221
+
222
+ query = attn.to_q(hidden_states)
223
+
224
+ # NEW: for decoupled multi-view attention
225
+ if use_mv:
226
+ query_mv = self.to_q_mv(hidden_states)
227
+
228
+ # NEW: for decoupled reference cross attention
229
+ if use_ref:
230
+ query_ref = self.to_q_ref(hidden_states)
231
+
232
+ if encoder_hidden_states is None:
233
+ encoder_hidden_states = hidden_states
234
+ elif attn.norm_cross:
235
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
236
+ encoder_hidden_states
237
+ )
238
+
239
+ key = attn.to_k(encoder_hidden_states)
240
+ value = attn.to_v(encoder_hidden_states)
241
+
242
+ inner_dim = key.shape[-1]
243
+ head_dim = inner_dim // attn.heads
244
+
245
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
246
+
247
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
248
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
249
+
250
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
251
+ # TODO: add support for attn.scale when we move to Torch 2.1
252
+ hidden_states = F.scaled_dot_product_attention(
253
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
254
+ )
255
+
256
+ hidden_states = hidden_states.transpose(1, 2).reshape(
257
+ batch_size, -1, attn.heads * head_dim
258
+ )
259
+ hidden_states = hidden_states.to(query.dtype)
260
+
261
+ ####### Decoupled multi-view self-attention ########
262
+ if use_mv:
263
+ key_mv = self.to_k_mv(encoder_hidden_states)
264
+ value_mv = self.to_v_mv(encoder_hidden_states)
265
+
266
+ query_mv = query_mv.view(batch_size, -1, attn.heads, head_dim)
267
+ key_mv = key_mv.view(batch_size, -1, attn.heads, head_dim)
268
+ value_mv = value_mv.view(batch_size, -1, attn.heads, head_dim)
269
+
270
+ height = width = math.isqrt(sequence_length)
271
+
272
+ # row self-attention
273
+ query_mv = rearrange(
274
+ query_mv,
275
+ "(b nv) (ih iw) h c -> (b nv ih) iw h c",
276
+ nv=self.num_views,
277
+ ih=height,
278
+ iw=width,
279
+ ).transpose(1, 2)
280
+ key_mv = rearrange(
281
+ key_mv,
282
+ "(b nv) (ih iw) h c -> b ih (nv iw) h c",
283
+ nv=self.num_views,
284
+ ih=height,
285
+ iw=width,
286
+ )
287
+ key_mv = (
288
+ key_mv.repeat_interleave(self.num_views, dim=0)
289
+ .view(batch_size * height, -1, attn.heads, head_dim)
290
+ .transpose(1, 2)
291
+ )
292
+ value_mv = rearrange(
293
+ value_mv,
294
+ "(b nv) (ih iw) h c -> b ih (nv iw) h c",
295
+ nv=self.num_views,
296
+ ih=height,
297
+ iw=width,
298
+ )
299
+ value_mv = (
300
+ value_mv.repeat_interleave(self.num_views, dim=0)
301
+ .view(batch_size * height, -1, attn.heads, head_dim)
302
+ .transpose(1, 2)
303
+ )
304
+
305
+ hidden_states_mv = F.scaled_dot_product_attention(
306
+ query_mv,
307
+ key_mv,
308
+ value_mv,
309
+ dropout_p=0.0,
310
+ is_causal=False,
311
+ )
312
+ hidden_states_mv = rearrange(
313
+ hidden_states_mv,
314
+ "(b nv ih) h iw c -> (b nv) (ih iw) (h c)",
315
+ nv=self.num_views,
316
+ ih=height,
317
+ )
318
+ hidden_states_mv = hidden_states_mv.to(query.dtype)
319
+
320
+ # linear proj
321
+ hidden_states_mv = self.to_out_mv[0](hidden_states_mv)
322
+ # dropout
323
+ hidden_states_mv = self.to_out_mv[1](hidden_states_mv)
324
+
325
+ if use_ref:
326
+ reference_hidden_states = ref_hidden_states[self.name]
327
+
328
+ key_ref = self.to_k_ref(reference_hidden_states)
329
+ value_ref = self.to_v_ref(reference_hidden_states)
330
+
331
+ query_ref = query_ref.view(batch_size, -1, attn.heads, head_dim).transpose(
332
+ 1, 2
333
+ )
334
+ key_ref = key_ref.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
335
+ value_ref = value_ref.view(batch_size, -1, attn.heads, head_dim).transpose(
336
+ 1, 2
337
+ )
338
+
339
+ hidden_states_ref = F.scaled_dot_product_attention(
340
+ query_ref, key_ref, value_ref, dropout_p=0.0, is_causal=False
341
+ )
342
+
343
+ hidden_states_ref = hidden_states_ref.transpose(1, 2).reshape(
344
+ batch_size, -1, attn.heads * head_dim
345
+ )
346
+ hidden_states_ref = hidden_states_ref.to(query.dtype)
347
+
348
+ # linear proj
349
+ hidden_states_ref = self.to_out_ref[0](hidden_states_ref)
350
+ # dropout
351
+ hidden_states_ref = self.to_out_ref[1](hidden_states_ref)
352
+
353
+ # linear proj
354
+ hidden_states = attn.to_out[0](hidden_states)
355
+ # dropout
356
+ hidden_states = attn.to_out[1](hidden_states)
357
+
358
+ if use_mv:
359
+ hidden_states = hidden_states + hidden_states_mv * mv_scale
360
+
361
+ if use_ref:
362
+ hidden_states = hidden_states + hidden_states_ref * ref_scale
363
+
364
+ if input_ndim == 4:
365
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
366
+ batch_size, channel, height, width
367
+ )
368
+
369
+ if attn.residual_connection:
370
+ hidden_states = hidden_states + residual
371
+
372
+ hidden_states = hidden_states / attn.rescale_output_factor
373
+
374
+ return hidden_states
375
+
376
+ def set_num_views(self, num_views: int) -> None:
377
+ self.num_views = num_views
378
+
379
+
380
+ class DecoupledMVRowColSelfAttnProcessor2_0(torch.nn.Module):
381
+ r"""
382
+ Attention processor for Decoupled Row-wise and Column-wise Self-Attention and Image Cross-Attention for PyTorch 2.0.
383
+ """
384
+
385
+ def __init__(
386
+ self,
387
+ query_dim: int,
388
+ inner_dim: int,
389
+ num_views: int = 1,
390
+ name: Optional[str] = None,
391
+ use_mv: bool = True,
392
+ use_ref: bool = False,
393
+ ):
394
+ if not hasattr(F, "scaled_dot_product_attention"):
395
+ raise ImportError(
396
+ "DecoupledMVRowSelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
397
+ )
398
+
399
+ super().__init__()
400
+
401
+ self.num_views = num_views
402
+ self.name = name # NOTE: need for image cross-attention
403
+ self.use_mv = use_mv
404
+ self.use_ref = use_ref
405
+
406
+ if self.use_mv:
407
+ self.to_q_mv = nn.Linear(
408
+ in_features=query_dim, out_features=inner_dim, bias=False
409
+ )
410
+ self.to_k_mv = nn.Linear(
411
+ in_features=query_dim, out_features=inner_dim, bias=False
412
+ )
413
+ self.to_v_mv = nn.Linear(
414
+ in_features=query_dim, out_features=inner_dim, bias=False
415
+ )
416
+ self.to_out_mv = nn.ModuleList(
417
+ [
418
+ nn.Linear(in_features=inner_dim, out_features=query_dim, bias=True),
419
+ nn.Dropout(0.0),
420
+ ]
421
+ )
422
+
423
+ if self.use_ref:
424
+ self.to_q_ref = nn.Linear(
425
+ in_features=query_dim, out_features=inner_dim, bias=False
426
+ )
427
+ self.to_k_ref = nn.Linear(
428
+ in_features=query_dim, out_features=inner_dim, bias=False
429
+ )
430
+ self.to_v_ref = nn.Linear(
431
+ in_features=query_dim, out_features=inner_dim, bias=False
432
+ )
433
+ self.to_out_ref = nn.ModuleList(
434
+ [
435
+ nn.Linear(in_features=inner_dim, out_features=query_dim, bias=True),
436
+ nn.Dropout(0.0),
437
+ ]
438
+ )
439
+
440
+ def __call__(
441
+ self,
442
+ attn: Attention,
443
+ hidden_states: torch.FloatTensor,
444
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
445
+ attention_mask: Optional[torch.FloatTensor] = None,
446
+ temb: Optional[torch.FloatTensor] = None,
447
+ mv_scale: float = 1.0,
448
+ ref_hidden_states: Optional[torch.FloatTensor] = None,
449
+ ref_scale: float = 1.0,
450
+ cache_hidden_states: Optional[List[torch.FloatTensor]] = None,
451
+ use_mv: bool = True,
452
+ use_ref: bool = True,
453
+ num_views: Optional[int] = None,
454
+ *args,
455
+ **kwargs,
456
+ ) -> torch.FloatTensor:
457
+ """
458
+ New args:
459
+ mv_scale (float): scale for multi-view self-attention.
460
+ ref_hidden_states (torch.FloatTensor): reference encoder hidden states for image cross-attention.
461
+ ref_scale (float): scale for image cross-attention.
462
+ cache_hidden_states (List[torch.FloatTensor]): cache hidden states from reference unet.
463
+
464
+ """
465
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
466
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
467
+ deprecate("scale", "1.0.0", deprecation_message)
468
+
469
+ if num_views is not None:
470
+ self.num_views = num_views
471
+
472
+ # NEW: cache hidden states for reference unet
473
+ if cache_hidden_states is not None:
474
+ cache_hidden_states[self.name] = hidden_states.clone()
475
+
476
+ # NEW: whether to use multi-view attention and image cross-attention
477
+ use_mv = self.use_mv and use_mv
478
+ use_ref = self.use_ref and use_ref
479
+
480
+ residual = hidden_states
481
+ if attn.spatial_norm is not None:
482
+ hidden_states = attn.spatial_norm(hidden_states, temb)
483
+
484
+ input_ndim = hidden_states.ndim
485
+
486
+ if input_ndim == 4:
487
+ batch_size, channel, height, width = hidden_states.shape
488
+ hidden_states = hidden_states.view(
489
+ batch_size, channel, height * width
490
+ ).transpose(1, 2)
491
+
492
+ batch_size, sequence_length, _ = (
493
+ hidden_states.shape
494
+ if encoder_hidden_states is None
495
+ else encoder_hidden_states.shape
496
+ )
497
+
498
+ if attention_mask is not None:
499
+ attention_mask = attn.prepare_attention_mask(
500
+ attention_mask, sequence_length, batch_size
501
+ )
502
+ # scaled_dot_product_attention expects attention_mask shape to be
503
+ # (batch, heads, source_length, target_length)
504
+ attention_mask = attention_mask.view(
505
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
506
+ )
507
+
508
+ if attn.group_norm is not None:
509
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
510
+ 1, 2
511
+ )
512
+
513
+ query = attn.to_q(hidden_states)
514
+
515
+ # NEW: for decoupled multi-view attention
516
+ if use_mv:
517
+ query_mv = self.to_q_mv(hidden_states)
518
+
519
+ # NEW: for decoupled reference cross attention
520
+ if use_ref:
521
+ query_ref = self.to_q_ref(hidden_states)
522
+
523
+ if encoder_hidden_states is None:
524
+ encoder_hidden_states = hidden_states
525
+ elif attn.norm_cross:
526
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
527
+ encoder_hidden_states
528
+ )
529
+
530
+ key = attn.to_k(encoder_hidden_states)
531
+ value = attn.to_v(encoder_hidden_states)
532
+
533
+ inner_dim = key.shape[-1]
534
+ head_dim = inner_dim // attn.heads
535
+
536
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
537
+
538
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
539
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
540
+
541
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
542
+ # TODO: add support for attn.scale when we move to Torch 2.1
543
+ hidden_states = F.scaled_dot_product_attention(
544
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
545
+ )
546
+
547
+ hidden_states = hidden_states.transpose(1, 2).reshape(
548
+ batch_size, -1, attn.heads * head_dim
549
+ )
550
+ hidden_states = hidden_states.to(query.dtype)
551
+
552
+ ####### Decoupled multi-view self-attention ########
553
+ if use_mv:
554
+ key_mv = self.to_k_mv(encoder_hidden_states)
555
+ value_mv = self.to_v_mv(encoder_hidden_states)
556
+
557
+ query_mv = query_mv.view(batch_size, -1, attn.heads, head_dim)
558
+ key_mv = key_mv.view(batch_size, -1, attn.heads, head_dim)
559
+ value_mv = value_mv.view(batch_size, -1, attn.heads, head_dim)
560
+
561
+ height = width = math.isqrt(sequence_length)
562
+
563
+ query_mv = rearrange(
564
+ query_mv,
565
+ "(b nv) (ih iw) h c -> b nv ih iw h c",
566
+ nv=self.num_views,
567
+ ih=height,
568
+ iw=width,
569
+ )
570
+ key_mv = rearrange(
571
+ key_mv,
572
+ "(b nv) (ih iw) h c -> b nv ih iw h c",
573
+ nv=self.num_views,
574
+ ih=height,
575
+ iw=width,
576
+ )
577
+ value_mv = rearrange(
578
+ value_mv,
579
+ "(b nv) (ih iw) h c -> b nv ih iw h c",
580
+ nv=self.num_views,
581
+ ih=height,
582
+ iw=width,
583
+ )
584
+
585
+ # row-wise attention for view 0123 (front, right, back, left)
586
+ query_mv_0123 = rearrange(
587
+ query_mv[:, 0:4], "b nv ih iw h c -> (b ih) h (nv iw) c"
588
+ )
589
+ key_mv_0123 = rearrange(
590
+ key_mv[:, 0:4], "b nv ih iw h c -> (b ih) h (nv iw) c"
591
+ )
592
+ value_mv_0123 = rearrange(
593
+ value_mv[:, 0:4], "b nv ih iw h c -> (b ih) h (nv iw) c"
594
+ )
595
+ hidden_states_mv_0123 = F.scaled_dot_product_attention(
596
+ query_mv_0123,
597
+ key_mv_0123,
598
+ value_mv_0123,
599
+ dropout_p=0.0,
600
+ is_causal=False,
601
+ )
602
+ hidden_states_mv_0123 = rearrange(
603
+ hidden_states_mv_0123,
604
+ "(b ih) h (nv iw) c -> b nv (ih iw) (h c)",
605
+ ih=height,
606
+ iw=height,
607
+ )
608
+
609
+ # col-wise attention for view 0245 (front, back, top, bottom)
610
+ # flip first
611
+ query_mv_0245 = torch.cat(
612
+ [
613
+ torch.flip(query_mv[:, [0]], [3]), # horizontal flip
614
+ query_mv[:, [2, 4, 5]],
615
+ ],
616
+ dim=1,
617
+ )
618
+ key_mv_0245 = torch.cat(
619
+ [
620
+ torch.flip(key_mv[:, [0]], [3]), # horizontal flip
621
+ key_mv[:, [2, 4, 5]],
622
+ ],
623
+ dim=1,
624
+ )
625
+ value_mv_0245 = torch.cat(
626
+ [
627
+ torch.flip(value_mv[:, [0]], [3]), # horizontal flip
628
+ value_mv[:, [2, 4, 5]],
629
+ ],
630
+ dim=1,
631
+ )
632
+ # attention
633
+ query_mv_0245 = rearrange(
634
+ query_mv_0245, "b nv ih iw h c -> (b iw) h (nv ih) c"
635
+ )
636
+ key_mv_0245 = rearrange(key_mv_0245, "b nv ih iw h c -> (b iw) h (nv ih) c")
637
+ value_mv_0245 = rearrange(
638
+ value_mv_0245, "b nv ih iw h c -> (b iw) h (nv ih) c"
639
+ )
640
+ hidden_states_mv_0245 = F.scaled_dot_product_attention(
641
+ query_mv_0245,
642
+ key_mv_0245,
643
+ value_mv_0245,
644
+ dropout_p=0.0,
645
+ is_causal=False,
646
+ )
647
+ # flip back
648
+ hidden_states_mv_0245 = rearrange(
649
+ hidden_states_mv_0245,
650
+ "(b iw) h (nv ih) c -> b nv ih iw (h c)",
651
+ ih=height,
652
+ iw=height,
653
+ )
654
+ hidden_states_mv_0245 = torch.cat(
655
+ [
656
+ torch.flip(hidden_states_mv_0245[:, [0]], [3]), # horizontal flip
657
+ hidden_states_mv_0245[:, [1, 2, 3]],
658
+ ],
659
+ dim=1,
660
+ )
661
+ hidden_states_mv_0245 = hidden_states_mv_0245.view(
662
+ hidden_states_mv_0245.shape[0],
663
+ hidden_states_mv_0245.shape[1],
664
+ -1,
665
+ hidden_states_mv_0245.shape[-1],
666
+ )
667
+
668
+ # combine row and col
669
+ hidden_states_mv = torch.stack(
670
+ [
671
+ (hidden_states_mv_0123[:, 0] + hidden_states_mv_0245[:, 0]) / 2,
672
+ hidden_states_mv_0123[:, 1],
673
+ (hidden_states_mv_0123[:, 2] + hidden_states_mv_0245[:, 1]) / 2,
674
+ hidden_states_mv_0123[:, 3],
675
+ hidden_states_mv_0245[:, 2],
676
+ hidden_states_mv_0245[:, 3],
677
+ ],
678
+ dim=1,
679
+ )
680
+
681
+ hidden_states_mv = hidden_states_mv.view(
682
+ -1, hidden_states_mv.shape[-2], hidden_states_mv.shape[-1]
683
+ )
684
+ hidden_states_mv = hidden_states_mv.to(query.dtype)
685
+
686
+ # linear proj
687
+ hidden_states_mv = self.to_out_mv[0](hidden_states_mv)
688
+ # dropout
689
+ hidden_states_mv = self.to_out_mv[1](hidden_states_mv)
690
+
691
+ if use_ref:
692
+ reference_hidden_states = ref_hidden_states[self.name]
693
+
694
+ key_ref = self.to_k_ref(reference_hidden_states)
695
+ value_ref = self.to_v_ref(reference_hidden_states)
696
+
697
+ query_ref = query_ref.view(batch_size, -1, attn.heads, head_dim).transpose(
698
+ 1, 2
699
+ )
700
+ key_ref = key_ref.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
701
+ value_ref = value_ref.view(batch_size, -1, attn.heads, head_dim).transpose(
702
+ 1, 2
703
+ )
704
+
705
+ hidden_states_ref = F.scaled_dot_product_attention(
706
+ query_ref, key_ref, value_ref, dropout_p=0.0, is_causal=False
707
+ )
708
+
709
+ hidden_states_ref = hidden_states_ref.transpose(1, 2).reshape(
710
+ batch_size, -1, attn.heads * head_dim
711
+ )
712
+ hidden_states_ref = hidden_states_ref.to(query.dtype)
713
+
714
+ # linear proj
715
+ hidden_states_ref = self.to_out_ref[0](hidden_states_ref)
716
+ # dropout
717
+ hidden_states_ref = self.to_out_ref[1](hidden_states_ref)
718
+
719
+ # linear proj
720
+ hidden_states = attn.to_out[0](hidden_states)
721
+ # dropout
722
+ hidden_states = attn.to_out[1](hidden_states)
723
+
724
+ if use_mv:
725
+ hidden_states = hidden_states + hidden_states_mv * mv_scale
726
+
727
+ if use_ref:
728
+ hidden_states = hidden_states + hidden_states_ref * ref_scale
729
+
730
+ if input_ndim == 4:
731
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
732
+ batch_size, channel, height, width
733
+ )
734
+
735
+ if attn.residual_connection:
736
+ hidden_states = hidden_states + residual
737
+
738
+ hidden_states = hidden_states / attn.rescale_output_factor
739
+
740
+ return hidden_states
741
+
742
+ def set_num_views(self, num_views: int) -> None:
743
+ self.num_views = num_views
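A hedged sketch of how the helper above can install the decoupled processors on a UNet's self-attention layers. The exact wiring used by the pipelines happens inside their `init_custom_adapter`, so treat the `query_dim`/`inner_dim` choice and `num_views=6` here as assumptions for illustration; the new projections remain randomly initialized until adapter weights are loaded.

```python
# Hedged sketch, not part of the commit.
from diffusers import UNet2DConditionModel

from mvadapter.models.attention_processor import (
    DecoupledMVRowColSelfAttnProcessor2_0,
    set_unet_2d_condition_attn_processor,
)

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)


def make_mv_processor(name, hidden_size, cross_attention_dim, ori_attn_proc):
    # replace plain self-attention with decoupled row/col multi-view attention
    return DecoupledMVRowColSelfAttnProcessor2_0(
        query_dim=hidden_size,
        inner_dim=hidden_size,
        num_views=6,  # the row/col scheme assumes the 6-view layout used above
        name=name,
        use_mv=True,
        use_ref=True,
    )


set_unet_2d_condition_attn_processor(
    unet, set_self_attn_proc_func=make_mv_processor
)
```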
mvadapter/pipelines/pipeline_mvadapter_i2mv_sd.py ADDED
@@ -0,0 +1,777 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import inspect
15
+ from typing import Any, Callable, Dict, List, Optional, Union
16
+
17
+ import PIL
18
+ import torch
19
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
20
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
21
+ from diffusers.models import AutoencoderKL, T2IAdapter, UNet2DConditionModel
22
+ from diffusers.pipelines.stable_diffusion.pipeline_output import (
23
+ StableDiffusionPipelineOutput,
24
+ )
25
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
26
+ StableDiffusionPipeline,
27
+ rescale_noise_cfg,
28
+ retrieve_timesteps,
29
+ )
30
+ from diffusers.pipelines.stable_diffusion.safety_checker import (
31
+ StableDiffusionSafetyChecker,
32
+ )
33
+ from diffusers.schedulers import KarrasDiffusionSchedulers
34
+ from diffusers.utils import deprecate, is_torch_xla_available, logging
35
+ from diffusers.utils.torch_utils import randn_tensor
36
+ from transformers import (
37
+ CLIPImageProcessor,
38
+ CLIPTextModel,
39
+ CLIPTokenizer,
40
+ CLIPVisionModelWithProjection,
41
+ )
42
+
43
+ from ..loaders import CustomAdapterMixin
44
+ from ..models.attention_processor import (
45
+ DecoupledMVRowSelfAttnProcessor2_0,
46
+ set_unet_2d_condition_attn_processor,
47
+ )
48
+
49
+ if is_torch_xla_available():
50
+ import torch_xla.core.xla_model as xm
51
+
52
+ XLA_AVAILABLE = True
53
+ else:
54
+ XLA_AVAILABLE = False
55
+
56
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
57
+
58
+
59
+ def retrieve_latents(
60
+ encoder_output: torch.Tensor,
61
+ generator: Optional[torch.Generator] = None,
62
+ sample_mode: str = "sample",
63
+ ):
64
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
65
+ return encoder_output.latent_dist.sample(generator)
66
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
67
+ return encoder_output.latent_dist.mode()
68
+ elif hasattr(encoder_output, "latents"):
69
+ return encoder_output.latents
70
+ else:
71
+ raise AttributeError("Could not access latents of provided encoder_output")
72
+
73
+
74
+ class MVAdapterI2MVSDPipeline(StableDiffusionPipeline, CustomAdapterMixin):
75
+ def __init__(
76
+ self,
77
+ vae: AutoencoderKL,
78
+ text_encoder: CLIPTextModel,
79
+ tokenizer: CLIPTokenizer,
80
+ unet: UNet2DConditionModel,
81
+ scheduler: KarrasDiffusionSchedulers,
82
+ safety_checker: StableDiffusionSafetyChecker,
83
+ feature_extractor: CLIPImageProcessor,
84
+ image_encoder: CLIPVisionModelWithProjection = None,
85
+ requires_safety_checker: bool = False,
86
+ ):
87
+ super().__init__(
88
+ vae=vae,
89
+ text_encoder=text_encoder,
90
+ tokenizer=tokenizer,
91
+ unet=unet,
92
+ scheduler=scheduler,
93
+ safety_checker=safety_checker,
94
+ feature_extractor=feature_extractor,
95
+ image_encoder=image_encoder,
96
+ requires_safety_checker=requires_safety_checker,
97
+ )
98
+
99
+ self.control_image_processor = VaeImageProcessor(
100
+ vae_scale_factor=self.vae_scale_factor,
101
+ do_convert_rgb=True,
102
+ do_normalize=False,
103
+ )
104
+
105
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.prepare_latents
106
+ def prepare_image_latents(
107
+ self,
108
+ image,
109
+ timestep,
110
+ batch_size,
111
+ num_images_per_prompt,
112
+ dtype,
113
+ device,
114
+ generator=None,
115
+ add_noise=True,
116
+ ):
117
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
118
+ raise ValueError(
119
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
120
+ )
121
+
122
+ image = image.to(device=device, dtype=dtype)
123
+
124
+ batch_size = batch_size * num_images_per_prompt
125
+
126
+ if image.shape[1] == 4:
127
+ init_latents = image
128
+
129
+ else:
130
+ if isinstance(generator, list) and len(generator) != batch_size:
131
+ raise ValueError(
132
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
133
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
134
+ )
135
+
136
+ elif isinstance(generator, list):
137
+ init_latents = [
138
+ retrieve_latents(
139
+ self.vae.encode(image[i : i + 1]), generator=generator[i]
140
+ )
141
+ for i in range(batch_size)
142
+ ]
143
+ init_latents = torch.cat(init_latents, dim=0)
144
+ else:
145
+ init_latents = retrieve_latents(
146
+ self.vae.encode(image), generator=generator
147
+ )
148
+
149
+ init_latents = self.vae.config.scaling_factor * init_latents
150
+
151
+ if (
152
+ batch_size > init_latents.shape[0]
153
+ and batch_size % init_latents.shape[0] == 0
154
+ ):
155
+ # expand init_latents for batch_size
156
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
157
+ init_latents = torch.cat(
158
+ [init_latents] * additional_image_per_prompt, dim=0
159
+ )
160
+ elif (
161
+ batch_size > init_latents.shape[0]
162
+ and batch_size % init_latents.shape[0] != 0
163
+ ):
164
+ raise ValueError(
165
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
166
+ )
167
+ else:
168
+ init_latents = torch.cat([init_latents], dim=0)
169
+
170
+ if add_noise:
171
+ shape = init_latents.shape
172
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
173
+ # get latents
174
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
175
+
176
+ latents = init_latents
177
+
178
+ return latents
179
+
180
+ def prepare_control_image(
181
+ self,
182
+ image,
183
+ width,
184
+ height,
185
+ batch_size,
186
+ num_images_per_prompt,
187
+ device,
188
+ dtype,
189
+ do_classifier_free_guidance=False,
190
+ num_empty_images=0, # for concat in batch like ImageDream
191
+ ):
192
+ assert hasattr(
193
+ self, "control_image_processor"
194
+ ), "control_image_processor is not initialized"
195
+
196
+ image = self.control_image_processor.preprocess(
197
+ image, height=height, width=width
198
+ ).to(dtype=torch.float32)
199
+
200
+ if num_empty_images > 0:
201
+ image = torch.cat(
202
+ [image, torch.zeros_like(image[:num_empty_images])], dim=0
203
+ )
204
+
205
+ image_batch_size = image.shape[0]
206
+
207
+ if image_batch_size == 1:
208
+ repeat_by = batch_size
209
+ else:
210
+ # image batch size is the same as prompt batch size
211
+ repeat_by = num_images_per_prompt # always 1 for control image
212
+
213
+ image = image.repeat_interleave(repeat_by, dim=0)
214
+
215
+ image = image.to(device=device, dtype=dtype)
216
+
217
+ if do_classifier_free_guidance:
218
+ image = torch.cat([image] * 2)
219
+
220
+ return image
221
+
222
+ @torch.no_grad()
223
+ def __call__(
224
+ self,
225
+ prompt: Union[str, List[str]] = None,
226
+ height: Optional[int] = None,
227
+ width: Optional[int] = None,
228
+ num_inference_steps: int = 50,
229
+ timesteps: List[int] = None,
230
+ sigmas: List[float] = None,
231
+ guidance_scale: float = 7.5,
232
+ negative_prompt: Optional[Union[str, List[str]]] = None,
233
+ num_images_per_prompt: Optional[int] = 1,
234
+ eta: float = 0.0,
235
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
236
+ latents: Optional[torch.Tensor] = None,
237
+ prompt_embeds: Optional[torch.Tensor] = None,
238
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
239
+ ip_adapter_image: Optional[PipelineImageInput] = None,
240
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
241
+ output_type: Optional[str] = "pil",
242
+ return_dict: bool = True,
243
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
244
+ guidance_rescale: float = 0.0,
245
+ clip_skip: Optional[int] = None,
246
+ callback_on_step_end: Optional[
247
+ Union[
248
+ Callable[[int, int, Dict], None],
249
+ PipelineCallback,
250
+ MultiPipelineCallbacks,
251
+ ]
252
+ ] = None,
253
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
254
+ # NEW
255
+ mv_scale: float = 1.0,
256
+ # Camera or geometry condition
257
+ control_image: Optional[PipelineImageInput] = None,
258
+ control_conditioning_scale: Optional[float] = 1.0,
259
+ control_conditioning_factor: float = 1.0,
260
+ # Image condition
261
+ reference_image: Optional[PipelineImageInput] = None,
262
+ reference_conditioning_scale: Optional[float] = 1.0,
263
+ # Optional. controlnet
264
+ controlnet_image: Optional[PipelineImageInput] = None,
265
+ controlnet_conditioning_scale: Optional[float] = 1.0,
266
+ **kwargs,
267
+ ):
268
+ r"""
269
+ The call function to the pipeline for generation.
270
+
271
+ Args:
272
+ prompt (`str` or `List[str]`, *optional*):
273
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
274
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
275
+ The height in pixels of the generated image.
276
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
277
+ The width in pixels of the generated image.
278
+ num_inference_steps (`int`, *optional*, defaults to 50):
279
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
280
+ expense of slower inference.
281
+ timesteps (`List[int]`, *optional*):
282
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
283
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
284
+ passed will be used. Must be in descending order.
285
+ sigmas (`List[float]`, *optional*):
286
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
287
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
288
+ will be used.
289
+ guidance_scale (`float`, *optional*, defaults to 7.5):
290
+ A higher guidance scale value encourages the model to generate images closely linked to the text
291
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
292
+ negative_prompt (`str` or `List[str]`, *optional*):
293
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
294
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
295
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
296
+ The number of images to generate per prompt.
297
+ eta (`float`, *optional*, defaults to 0.0):
298
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
299
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
300
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
301
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
302
+ generation deterministic.
303
+ latents (`torch.Tensor`, *optional*):
304
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
305
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
306
+ tensor is generated by sampling using the supplied random `generator`.
307
+ prompt_embeds (`torch.Tensor`, *optional*):
308
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
309
+ provided, text embeddings are generated from the `prompt` input argument.
310
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
311
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
312
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
313
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
314
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
315
+ Pre-generated image embeddings for IP-Adapter. It should be a list whose length equals the number of
316
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
317
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
318
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
319
+ output_type (`str`, *optional*, defaults to `"pil"`):
320
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
321
+ return_dict (`bool`, *optional*, defaults to `True`):
322
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
323
+ plain tuple.
324
+ cross_attention_kwargs (`dict`, *optional*):
325
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
326
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
327
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
328
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
329
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
330
+ using zero terminal SNR.
331
+ clip_skip (`int`, *optional*):
332
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
333
+ the output of the pre-final layer will be used for computing the prompt embeddings.
334
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
335
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
336
+ each denoising step during inference, with the following arguments: `callback_on_step_end(self:
337
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
338
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
339
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
340
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
341
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
342
+ `._callback_tensor_inputs` attribute of your pipeline class.
343
+
344
+ Examples:
345
+
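+ A minimal illustrative sketch (all names and values here are assumptions): `pipe` stands for an
+ already loaded `MVAdapterI2MVSDPipeline` whose custom adapter has been initialized, `control_images`
+ for the per-view camera/geometry maps, and `reference_image` for a single reference PIL image.
+ 
+ >>> num_views = 6
+ >>> output = pipe(
+ ...     prompt="high quality",
+ ...     num_images_per_prompt=num_views,  # one generated image per view
+ ...     control_image=control_images,
+ ...     reference_image=reference_image,
+ ...     reference_conditioning_scale=1.0,
+ ...     mv_scale=1.0,  # strength of the multi-view row attention
+ ...     num_inference_steps=30,
+ ...     guidance_scale=3.0,
+ ... )
+ >>> images = output.images  # list of `num_views` PIL images
+ 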
346
+ Returns:
347
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
348
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
349
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
350
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
351
+ "not-safe-for-work" (nsfw) content.
352
+ """
353
+
354
+ callback = kwargs.pop("callback", None)
355
+ callback_steps = kwargs.pop("callback_steps", None)
356
+
357
+ if callback is not None:
358
+ deprecate(
359
+ "callback",
360
+ "1.0.0",
361
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
362
+ )
363
+ if callback_steps is not None:
364
+ deprecate(
365
+ "callback_steps",
366
+ "1.0.0",
367
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
368
+ )
369
+
370
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
371
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
372
+
373
+ # 0. Default height and width to unet
374
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
375
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
376
+ # to deal with lora scaling and other possible forward hooks
377
+
378
+ # 1. Check inputs. Raise error if not correct
379
+ self.check_inputs(
380
+ prompt,
381
+ height,
382
+ width,
383
+ callback_steps,
384
+ negative_prompt,
385
+ prompt_embeds,
386
+ negative_prompt_embeds,
387
+ ip_adapter_image,
388
+ ip_adapter_image_embeds,
389
+ callback_on_step_end_tensor_inputs,
390
+ )
391
+
392
+ self._guidance_scale = guidance_scale
393
+ self._guidance_rescale = guidance_rescale
394
+ self._clip_skip = clip_skip
395
+ self._cross_attention_kwargs = cross_attention_kwargs
396
+ self._interrupt = False
397
+
398
+ # 2. Define call parameters
399
+ if prompt is not None and isinstance(prompt, str):
400
+ batch_size = 1
401
+ elif prompt is not None and isinstance(prompt, list):
402
+ batch_size = len(prompt)
403
+ else:
404
+ batch_size = prompt_embeds.shape[0]
405
+
406
+ device = self._execution_device
407
+
408
+ # 3. Encode input prompt
409
+ lora_scale = (
410
+ self.cross_attention_kwargs.get("scale", None)
411
+ if self.cross_attention_kwargs is not None
412
+ else None
413
+ )
414
+
415
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
416
+ prompt,
417
+ device,
418
+ num_images_per_prompt,
419
+ self.do_classifier_free_guidance,
420
+ negative_prompt,
421
+ prompt_embeds=prompt_embeds,
422
+ negative_prompt_embeds=negative_prompt_embeds,
423
+ lora_scale=lora_scale,
424
+ clip_skip=self.clip_skip,
425
+ )
426
+
427
+ # For classifier free guidance, we need to do two forward passes.
428
+ # Here we concatenate the unconditional and text embeddings into a single batch
429
+ # to avoid doing two forward passes
430
+ if self.do_classifier_free_guidance:
431
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
432
+
433
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
434
+ image_embeds = self.prepare_ip_adapter_image_embeds(
435
+ ip_adapter_image,
436
+ ip_adapter_image_embeds,
437
+ device,
438
+ batch_size * num_images_per_prompt,
439
+ self.do_classifier_free_guidance,
440
+ )
441
+
442
+ # 4. Prepare timesteps
443
+ timesteps, num_inference_steps = retrieve_timesteps(
444
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
445
+ )
446
+
447
+ # 5. Prepare latent variables
448
+ num_channels_latents = self.unet.config.in_channels
449
+ latents = self.prepare_latents(
450
+ batch_size * num_images_per_prompt,
451
+ num_channels_latents,
452
+ height,
453
+ width,
454
+ prompt_embeds.dtype,
455
+ device,
456
+ generator,
457
+ latents,
458
+ )
459
+
460
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
461
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
462
+
463
+ # 6.1 Add image embeds for IP-Adapter
464
+ added_cond_kwargs = (
465
+ {"image_embeds": image_embeds}
466
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
467
+ else None
468
+ )
469
+
470
+ # 6.2 Optionally get Guidance Scale Embedding
471
+ timestep_cond = None
472
+ if self.unet.config.time_cond_proj_dim is not None:
473
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
474
+ batch_size * num_images_per_prompt
475
+ )
476
+ timestep_cond = self.get_guidance_scale_embedding(
477
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
478
+ ).to(device=device, dtype=latents.dtype)
479
+
480
+ # Preprocess reference image
481
+ reference_image = self.image_processor.preprocess(reference_image)
482
+ reference_latents = self.prepare_image_latents(
483
+ reference_image,
484
+ timesteps[:1].repeat(batch_size * num_images_per_prompt), # unused since add_noise=False
485
+ batch_size,
486
+ 1,
487
+ prompt_embeds.dtype,
488
+ device,
489
+ generator,
490
+ add_noise=False,
491
+ )
492
+
493
+ ref_timesteps = torch.zeros_like(timesteps[0])
494
+ ref_hidden_states = {}
495
+ with torch.no_grad():
496
+ self.unet(
497
+ reference_latents,
498
+ ref_timesteps,
499
+ encoder_hidden_states=prompt_embeds[-1:],
500
+ cross_attention_kwargs={
501
+ "cache_hidden_states": ref_hidden_states,
502
+ "use_mv": False,
503
+ "use_ref": False,
504
+ },
505
+ return_dict=False,
506
+ )
507
+ ref_hidden_states = {
508
+ k: v.repeat_interleave(num_images_per_prompt, dim=0)
509
+ for k, v in ref_hidden_states.items()
510
+ }
511
+ if self.do_classifier_free_guidance:
512
+ ref_hidden_states = {
513
+ k: torch.cat([torch.zeros_like(v), v], dim=0)
514
+ for k, v in ref_hidden_states.items()
515
+ }
516
+
517
+ cross_attention_kwargs = {
518
+ "num_views": num_images_per_prompt,
519
+ "mv_scale": mv_scale,
520
+ "ref_hidden_states": {k: v.clone() for k, v in ref_hidden_states.items()},
521
+ "ref_scale": reference_conditioning_scale,
522
+ **(self.cross_attention_kwargs or {}),
523
+ }
524
+
525
+ # Preprocess control image
526
+ control_image_feature = self.prepare_control_image(
527
+ image=control_image,
528
+ width=width,
529
+ height=height,
530
+ batch_size=batch_size * num_images_per_prompt,
531
+ num_images_per_prompt=1, # NOTE: always 1 for control images
532
+ device=device,
533
+ dtype=latents.dtype,
534
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
535
+ )
536
+ control_image_feature = control_image_feature.to(
537
+ device=device, dtype=latents.dtype
538
+ )
539
+
540
+ adapter_state = self.cond_encoder(control_image_feature)
541
+ for i, state in enumerate(adapter_state):
542
+ adapter_state[i] = state * control_conditioning_scale
543
+
544
+ # Preprocess controlnet image if provided
545
+ do_controlnet = controlnet_image is not None and hasattr(self, "controlnet")
546
+ if do_controlnet:
547
+ controlnet_image = self.prepare_control_image(
548
+ image=controlnet_image,
549
+ width=width,
550
+ height=height,
551
+ batch_size=batch_size * num_images_per_prompt,
552
+ num_images_per_prompt=1, # NOTE: always 1 for control images
553
+ device=device,
554
+ dtype=latents.dtype,
555
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
556
+ )
557
+ controlnet_image = controlnet_image.to(device=device, dtype=latents.dtype)
558
+
559
+ # 7. Denoising loop
560
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
561
+ self._num_timesteps = len(timesteps)
562
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
563
+ for i, t in enumerate(timesteps):
564
+ if self.interrupt:
565
+ continue
566
+
567
+ # expand the latents if we are doing classifier free guidance
568
+ latent_model_input = (
569
+ torch.cat([latents] * 2)
570
+ if self.do_classifier_free_guidance
571
+ else latents
572
+ )
573
+ latent_model_input = self.scheduler.scale_model_input(
574
+ latent_model_input, t
575
+ )
576
+
577
+ if i < int(num_inference_steps * control_conditioning_factor):
578
+ down_intrablock_additional_residuals = [
579
+ state.clone() for state in adapter_state
580
+ ]
581
+ else:
582
+ down_intrablock_additional_residuals = None
583
+
584
+ unet_add_kwargs = {}
585
+
586
+ # Do controlnet if provided
587
+ if do_controlnet:
588
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
589
+ latent_model_input,
590
+ t,
591
+ encoder_hidden_states=prompt_embeds,
592
+ controlnet_cond=controlnet_image,
593
+ conditioning_scale=controlnet_conditioning_scale,
594
+ guess_mode=False,
595
+ added_cond_kwargs=added_cond_kwargs,
596
+ return_dict=False,
597
+ )
598
+ unet_add_kwargs.update(
599
+ {
600
+ "down_block_additional_residuals": down_block_res_samples,
601
+ "mid_block_additional_residual": mid_block_res_sample,
602
+ }
603
+ )
604
+
605
+ # predict the noise residual
606
+ noise_pred = self.unet(
607
+ latent_model_input,
608
+ t,
609
+ encoder_hidden_states=prompt_embeds,
610
+ timestep_cond=timestep_cond,
611
+ cross_attention_kwargs=cross_attention_kwargs,
612
+ down_intrablock_additional_residuals=down_intrablock_additional_residuals,
613
+ added_cond_kwargs=added_cond_kwargs,
614
+ return_dict=False,
615
+ **unet_add_kwargs,
616
+ )[0]
617
+
618
+ # perform guidance
619
+ if self.do_classifier_free_guidance:
620
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
621
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
622
+ noise_pred_text - noise_pred_uncond
623
+ )
624
+
625
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
626
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
627
+ noise_pred = rescale_noise_cfg(
628
+ noise_pred,
629
+ noise_pred_text,
630
+ guidance_rescale=self.guidance_rescale,
631
+ )
632
+
633
+ # compute the previous noisy sample x_t -> x_t-1
634
+ latents = self.scheduler.step(
635
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
636
+ )[0]
637
+
638
+ if callback_on_step_end is not None:
639
+ callback_kwargs = {}
640
+ for k in callback_on_step_end_tensor_inputs:
641
+ callback_kwargs[k] = locals()[k]
642
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
643
+
644
+ latents = callback_outputs.pop("latents", latents)
645
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
646
+ negative_prompt_embeds = callback_outputs.pop(
647
+ "negative_prompt_embeds", negative_prompt_embeds
648
+ )
649
+
650
+ # call the callback, if provided
651
+ if i == len(timesteps) - 1 or (
652
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
653
+ ):
654
+ progress_bar.update()
655
+ if callback is not None and i % callback_steps == 0:
656
+ step_idx = i // getattr(self.scheduler, "order", 1)
657
+ callback(step_idx, t, latents)
658
+
659
+ if XLA_AVAILABLE:
660
+ xm.mark_step()
661
+
662
+ if not output_type == "latent":
663
+ image = self.vae.decode(
664
+ latents / self.vae.config.scaling_factor,
665
+ return_dict=False,
666
+ generator=generator,
667
+ )[0]
668
+ image, has_nsfw_concept = self.run_safety_checker(
669
+ image, device, prompt_embeds.dtype
670
+ )
671
+ else:
672
+ image = latents
673
+ has_nsfw_concept = None
674
+
675
+ if has_nsfw_concept is None:
676
+ do_denormalize = [True] * image.shape[0]
677
+ else:
678
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
679
+ image = self.image_processor.postprocess(
680
+ image, output_type=output_type, do_denormalize=do_denormalize
681
+ )
682
+
683
+ # Offload all models
684
+ self.maybe_free_model_hooks()
685
+
686
+ if not return_dict:
687
+ return (image, has_nsfw_concept)
688
+
689
+ return StableDiffusionPipelineOutput(
690
+ images=image, nsfw_content_detected=has_nsfw_concept
691
+ )
692
+
693
+ ### NEW: adapters ###
694
+ def _init_custom_adapter(
695
+ self,
696
+ # Multi-view adapter
697
+ num_views: int = 1,
698
+ self_attn_processor: Any = DecoupledMVRowSelfAttnProcessor2_0,
699
+ # Condition encoder
700
+ cond_in_channels: int = 6,
701
+ # For training
702
+ copy_attn_weights: bool = True,
703
+ zero_init_module_keys: List[str] = [],
704
+ ):
705
+ # Condition encoder
706
+ self.cond_encoder = T2IAdapter(
707
+ in_channels=cond_in_channels,
708
+ channels=self.unet.config.block_out_channels,
709
+ num_res_blocks=self.unet.config.layers_per_block,
710
+ downscale_factor=8,
711
+ )
712
+
713
+ # set custom attn processor for multi-view attention
714
+ self.unet: UNet2DConditionModel
715
+ set_unet_2d_condition_attn_processor(
716
+ self.unet,
717
+ set_self_attn_proc_func=lambda name, hs, cad, ap: self_attn_processor(
718
+ query_dim=hs,
719
+ inner_dim=hs,
720
+ num_views=num_views,
721
+ name=name,
722
+ use_mv=True,
723
+ use_ref=True,
724
+ ),
725
+ set_cross_attn_proc_func=lambda name, hs, cad, ap: self_attn_processor(
726
+ query_dim=hs,
727
+ inner_dim=hs,
728
+ num_views=num_views,
729
+ name=name,
730
+ use_mv=False,
731
+ use_ref=False,
732
+ ),
733
+ )
734
+
735
+ # copy decoupled attention weights from original unet
736
+ if copy_attn_weights:
737
+ state_dict = self.unet.state_dict()
738
+ for key in state_dict.keys():
739
+ if "_mv" in key:
740
+ compatible_key = key.replace("_mv", "").replace("processor.", "")
741
+ elif "_ref" in key:
742
+ compatible_key = key.replace("_ref", "").replace("processor.", "")
743
+ else:
744
+ compatible_key = key
745
+
746
+ is_zero_init_key = any([k in key for k in zero_init_module_keys])
747
+ if is_zero_init_key:
748
+ state_dict[key] = torch.zeros_like(state_dict[compatible_key])
749
+ else:
750
+ state_dict[key] = state_dict[compatible_key].clone()
751
+ self.unet.load_state_dict(state_dict)
752
+
753
+ def _load_custom_adapter(self, state_dict):
754
+ self.unet.load_state_dict(state_dict, strict=False)
755
+ self.cond_encoder.load_state_dict(state_dict, strict=False)
756
+
757
+ def _save_custom_adapter(
758
+ self,
759
+ include_keys: Optional[List[str]] = None,
760
+ exclude_keys: Optional[List[str]] = None,
761
+ ):
762
+ def include_fn(k):
763
+ is_included = False
764
+
765
+ if include_keys is not None:
766
+ is_included = is_included or any([key in k for key in include_keys])
767
+ if exclude_keys is not None:
768
+ is_included = is_included and not any(
769
+ [key in k for key in exclude_keys]
770
+ )
771
+
772
+ return is_included
773
+
774
+ state_dict = {k: v for k, v in self.unet.state_dict().items() if include_fn(k)}
775
+ state_dict.update(self.cond_encoder.state_dict())
776
+
777
+ return state_dict
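+ 
+ # A rough usage sketch for the three adapter hooks above (all assumptions: `pipe` is an already
+ # constructed MVAdapterI2MVSDPipeline; the public entry points of `CustomAdapterMixin` are not
+ # shown here, so the underscore methods are called directly for illustration):
+ #
+ #   pipe._init_custom_adapter(num_views=6)  # install MV/ref attention processors + cond encoder
+ #   adapter_state = pipe._save_custom_adapter(include_keys=["_mv", "_ref"])
+ #   # `_save_custom_adapter` always adds the condition-encoder weights on top of the filtered keys
+ #   pipe._load_custom_adapter(adapter_state)  # non-strict load back into unet and cond_encoder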
mvadapter/pipelines/pipeline_mvadapter_i2mv_sdxl.py ADDED
@@ -0,0 +1,962 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import PIL
20
+ import torch
21
+ import torch.nn as nn
22
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
23
+ from diffusers.models import (
24
+ AutoencoderKL,
25
+ ImageProjection,
26
+ T2IAdapter,
27
+ UNet2DConditionModel,
28
+ )
29
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_output import (
30
+ StableDiffusionXLPipelineOutput,
31
+ )
32
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import (
33
+ StableDiffusionXLPipeline,
34
+ rescale_noise_cfg,
35
+ retrieve_timesteps,
36
+ )
37
+ from diffusers.schedulers import KarrasDiffusionSchedulers
38
+ from diffusers.utils import deprecate, logging
39
+ from diffusers.utils.torch_utils import randn_tensor
40
+ from einops import rearrange
41
+ from transformers import (
42
+ CLIPImageProcessor,
43
+ CLIPTextModel,
44
+ CLIPTextModelWithProjection,
45
+ CLIPTokenizer,
46
+ CLIPVisionModelWithProjection,
47
+ )
48
+
49
+ from ..loaders import CustomAdapterMixin
50
+ from ..models.attention_processor import (
51
+ DecoupledMVRowSelfAttnProcessor2_0,
52
+ set_unet_2d_condition_attn_processor,
53
+ )
54
+
55
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
56
+
57
+
58
+ def retrieve_latents(
59
+ encoder_output: torch.Tensor,
60
+ generator: Optional[torch.Generator] = None,
61
+ sample_mode: str = "sample",
62
+ ):
63
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
64
+ return encoder_output.latent_dist.sample(generator)
65
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
66
+ return encoder_output.latent_dist.mode()
67
+ elif hasattr(encoder_output, "latents"):
68
+ return encoder_output.latents
69
+ else:
70
+ raise AttributeError("Could not access latents of provided encoder_output")
71
+
72
+
73
+ class MVAdapterI2MVSDXLPipeline(StableDiffusionXLPipeline, CustomAdapterMixin):
74
+ def __init__(
75
+ self,
76
+ vae: AutoencoderKL,
77
+ text_encoder: CLIPTextModel,
78
+ text_encoder_2: CLIPTextModelWithProjection,
79
+ tokenizer: CLIPTokenizer,
80
+ tokenizer_2: CLIPTokenizer,
81
+ unet: UNet2DConditionModel,
82
+ scheduler: KarrasDiffusionSchedulers,
83
+ image_encoder: CLIPVisionModelWithProjection = None,
84
+ feature_extractor: CLIPImageProcessor = None,
85
+ force_zeros_for_empty_prompt: bool = True,
86
+ add_watermarker: Optional[bool] = None,
87
+ ):
88
+ super().__init__(
89
+ vae=vae,
90
+ text_encoder=text_encoder,
91
+ text_encoder_2=text_encoder_2,
92
+ tokenizer=tokenizer,
93
+ tokenizer_2=tokenizer_2,
94
+ unet=unet,
95
+ scheduler=scheduler,
96
+ image_encoder=image_encoder,
97
+ feature_extractor=feature_extractor,
98
+ force_zeros_for_empty_prompt=force_zeros_for_empty_prompt,
99
+ add_watermarker=add_watermarker,
100
+ )
101
+
102
+ self.control_image_processor = VaeImageProcessor(
103
+ vae_scale_factor=self.vae_scale_factor,
104
+ do_convert_rgb=True,
105
+ do_normalize=False,
106
+ )
107
+
108
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.prepare_latents
109
+ def prepare_image_latents(
110
+ self,
111
+ image,
112
+ timestep,
113
+ batch_size,
114
+ num_images_per_prompt,
115
+ dtype,
116
+ device,
117
+ generator=None,
118
+ add_noise=True,
119
+ ):
120
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
121
+ raise ValueError(
122
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
123
+ )
124
+
125
+ latents_mean = latents_std = None
126
+ if (
127
+ hasattr(self.vae.config, "latents_mean")
128
+ and self.vae.config.latents_mean is not None
129
+ ):
130
+ latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
131
+ if (
132
+ hasattr(self.vae.config, "latents_std")
133
+ and self.vae.config.latents_std is not None
134
+ ):
135
+ latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
136
+
137
+ # Offload text encoder if `enable_model_cpu_offload` was enabled
138
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
139
+ self.text_encoder_2.to("cpu")
140
+ torch.cuda.empty_cache()
141
+
142
+ image = image.to(device=device, dtype=dtype)
143
+
144
+ batch_size = batch_size * num_images_per_prompt
145
+
146
+ if image.shape[1] == 4:
147
+ init_latents = image
148
+
149
+ else:
150
+ # make sure the VAE is in float32 mode, as it overflows in float16
151
+ if self.vae.config.force_upcast:
152
+ image = image.float()
153
+ self.vae.to(dtype=torch.float32)
154
+
155
+ if isinstance(generator, list) and len(generator) != batch_size:
156
+ raise ValueError(
157
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
158
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
159
+ )
160
+
161
+ elif isinstance(generator, list):
162
+ if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
163
+ image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
164
+ elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
165
+ raise ValueError(
166
+ f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
167
+ )
168
+
169
+ init_latents = [
170
+ retrieve_latents(
171
+ self.vae.encode(image[i : i + 1]), generator=generator[i]
172
+ )
173
+ for i in range(batch_size)
174
+ ]
175
+ init_latents = torch.cat(init_latents, dim=0)
176
+ else:
177
+ init_latents = retrieve_latents(
178
+ self.vae.encode(image), generator=generator
179
+ )
180
+
181
+ if self.vae.config.force_upcast:
182
+ self.vae.to(dtype)
183
+
184
+ init_latents = init_latents.to(dtype)
185
+ if latents_mean is not None and latents_std is not None:
186
+ latents_mean = latents_mean.to(device=device, dtype=dtype)
187
+ latents_std = latents_std.to(device=device, dtype=dtype)
188
+ init_latents = (
189
+ (init_latents - latents_mean)
190
+ * self.vae.config.scaling_factor
191
+ / latents_std
192
+ )
193
+ else:
194
+ init_latents = self.vae.config.scaling_factor * init_latents
195
+
196
+ if (
197
+ batch_size > init_latents.shape[0]
198
+ and batch_size % init_latents.shape[0] == 0
199
+ ):
200
+ # expand init_latents for batch_size
201
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
202
+ init_latents = torch.cat(
203
+ [init_latents] * additional_image_per_prompt, dim=0
204
+ )
205
+ elif (
206
+ batch_size > init_latents.shape[0]
207
+ and batch_size % init_latents.shape[0] != 0
208
+ ):
209
+ raise ValueError(
210
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
211
+ )
212
+ else:
213
+ init_latents = torch.cat([init_latents], dim=0)
214
+
215
+ if add_noise:
216
+ shape = init_latents.shape
217
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
218
+ # get latents
219
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
220
+
221
+ latents = init_latents
222
+
223
+ return latents
224
+
225
+ def prepare_control_image(
226
+ self,
227
+ image,
228
+ width,
229
+ height,
230
+ batch_size,
231
+ num_images_per_prompt,
232
+ device,
233
+ dtype,
234
+ do_classifier_free_guidance=False,
235
+ num_empty_images=0, # for concat in batch like ImageDream
236
+ ):
237
+ assert hasattr(
238
+ self, "control_image_processor"
239
+ ), "control_image_processor is not initialized"
240
+
241
+ image = self.control_image_processor.preprocess(
242
+ image, height=height, width=width
243
+ ).to(dtype=torch.float32)
244
+
245
+ if num_empty_images > 0:
246
+ image = torch.cat(
247
+ [image, torch.zeros_like(image[:num_empty_images])], dim=0
248
+ )
249
+
250
+ image_batch_size = image.shape[0]
251
+
252
+ if image_batch_size == 1:
253
+ repeat_by = batch_size
254
+ else:
255
+ # image batch size is the same as prompt batch size
256
+ repeat_by = num_images_per_prompt # always 1 for control image
257
+
258
+ image = image.repeat_interleave(repeat_by, dim=0)
259
+
260
+ image = image.to(device=device, dtype=dtype)
261
+
262
+ if do_classifier_free_guidance:
263
+ image = torch.cat([image] * 2)
264
+
265
+ return image
266
+
267
+ @torch.no_grad()
268
+ def __call__(
269
+ self,
270
+ prompt: Union[str, List[str]] = None,
271
+ prompt_2: Optional[Union[str, List[str]]] = None,
272
+ height: Optional[int] = None,
273
+ width: Optional[int] = None,
274
+ num_inference_steps: int = 50,
275
+ timesteps: List[int] = None,
276
+ denoising_end: Optional[float] = None,
277
+ guidance_scale: float = 5.0,
278
+ negative_prompt: Optional[Union[str, List[str]]] = None,
279
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
280
+ num_images_per_prompt: Optional[int] = 1,
281
+ eta: float = 0.0,
282
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
283
+ latents: Optional[torch.FloatTensor] = None,
284
+ prompt_embeds: Optional[torch.FloatTensor] = None,
285
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
286
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
287
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
288
+ ip_adapter_image: Optional[PipelineImageInput] = None,
289
+ ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
290
+ output_type: Optional[str] = "pil",
291
+ return_dict: bool = True,
292
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
293
+ guidance_rescale: float = 0.0,
294
+ original_size: Optional[Tuple[int, int]] = None,
295
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
296
+ target_size: Optional[Tuple[int, int]] = None,
297
+ negative_original_size: Optional[Tuple[int, int]] = None,
298
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
299
+ negative_target_size: Optional[Tuple[int, int]] = None,
300
+ clip_skip: Optional[int] = None,
301
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
302
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
303
+ # NEW
304
+ mv_scale: float = 1.0,
305
+ # Camera or geometry condition
306
+ control_image: Optional[PipelineImageInput] = None,
307
+ control_conditioning_scale: Optional[float] = 1.0,
308
+ control_conditioning_factor: float = 1.0,
309
+ # Image condition
310
+ reference_image: Optional[PipelineImageInput] = None,
311
+ reference_conditioning_scale: Optional[float] = 1.0,
312
+ **kwargs,
313
+ ):
314
+ r"""
315
+ Function invoked when calling the pipeline for generation.
316
+
317
+ Args:
318
+ prompt (`str` or `List[str]`, *optional*):
319
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
320
+ instead.
321
+ prompt_2 (`str` or `List[str]`, *optional*):
322
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
323
+ used in both text-encoders.
324
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
325
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
326
+ Anything below 512 pixels won't work well for
327
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
328
+ and checkpoints that are not specifically fine-tuned on low resolutions.
329
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
330
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
331
+ Anything below 512 pixels won't work well for
332
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
333
+ and checkpoints that are not specifically fine-tuned on low resolutions.
334
+ num_inference_steps (`int`, *optional*, defaults to 50):
335
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
336
+ expense of slower inference.
337
+ timesteps (`List[int]`, *optional*):
338
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
339
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
340
+ passed will be used. Must be in descending order.
341
+ denoising_end (`float`, *optional*):
342
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
343
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
344
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
345
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
346
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
347
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
348
+ guidance_scale (`float`, *optional*, defaults to 5.0):
349
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
350
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
351
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
352
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
353
+ usually at the expense of lower image quality.
354
+ negative_prompt (`str` or `List[str]`, *optional*):
355
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
356
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
357
+ less than `1`).
358
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
359
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
360
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
361
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
362
+ The number of images to generate per prompt.
363
+ eta (`float`, *optional*, defaults to 0.0):
364
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
365
+ [`schedulers.DDIMScheduler`], will be ignored for others.
366
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
367
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
368
+ to make generation deterministic.
369
+ latents (`torch.FloatTensor`, *optional*):
370
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
371
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
372
+ tensor will be generated by sampling using the supplied random `generator`.
373
+ prompt_embeds (`torch.FloatTensor`, *optional*):
374
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
375
+ provided, text embeddings will be generated from `prompt` input argument.
376
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
377
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
378
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
379
+ argument.
380
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
381
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
382
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
383
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
384
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
385
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
386
+ input argument.
387
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
388
+ ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
389
+ Pre-generated image embeddings for IP-Adapter. It should be a list whose length equals the number of
390
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
391
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
392
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
393
+ output_type (`str`, *optional*, defaults to `"pil"`):
394
+ The output format of the generated image. Choose between
395
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
396
+ return_dict (`bool`, *optional*, defaults to `True`):
397
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
398
+ of a plain tuple.
399
+ cross_attention_kwargs (`dict`, *optional*):
400
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
401
+ `self.processor` in
402
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
403
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
404
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
405
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf), where `guidance_scale` is defined as `φ` in equation 16 of
406
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
407
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
408
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
409
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
410
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
411
+ explained in section 2.2 of
412
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
413
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
414
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
415
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
416
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
417
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
418
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
419
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
420
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
421
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
422
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
423
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
424
+ micro-conditioning as explained in section 2.2 of
425
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
426
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
427
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
428
+ To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
429
+ micro-conditioning as explained in section 2.2 of
430
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
431
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
432
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
433
+ To negatively condition the generation process based on a target image resolution. It should be the same
434
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
435
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
436
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
437
+ callback_on_step_end (`Callable`, *optional*):
438
+ A function that is called at the end of each denoising step during inference. It is invoked
439
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
440
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
441
+ `callback_on_step_end_tensor_inputs`.
442
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
443
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
444
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
445
+ `._callback_tensor_inputs` attribute of your pipeline class.
446
+
447
+ Examples:
448
+
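+ A minimal illustrative sketch (all names and values here are assumptions): `pipe` stands for an
+ already loaded `MVAdapterI2MVSDXLPipeline` with its custom adapter initialized, `control_images`
+ for the per-view camera/geometry maps, and `reference_image` for a single reference PIL image.
+ 
+ >>> num_views = 6
+ >>> output = pipe(
+ ...     prompt="high quality",
+ ...     height=768,
+ ...     width=768,
+ ...     num_images_per_prompt=num_views,  # one generated image per view
+ ...     control_image=control_images,
+ ...     reference_image=reference_image,
+ ...     reference_conditioning_scale=1.0,
+ ...     mv_scale=1.0,
+ ...     num_inference_steps=30,
+ ...     guidance_scale=3.0,
+ ... )
+ >>> images = output.images  # list of `num_views` PIL images
+ 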
449
+ Returns:
450
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
451
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
452
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
453
+ """
454
+
455
+ callback = kwargs.pop("callback", None)
456
+ callback_steps = kwargs.pop("callback_steps", None)
457
+
458
+ if callback is not None:
459
+ deprecate(
460
+ "callback",
461
+ "1.0.0",
462
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
463
+ )
464
+ if callback_steps is not None:
465
+ deprecate(
466
+ "callback_steps",
467
+ "1.0.0",
468
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
469
+ )
470
+
471
+ # 0. Default height and width to unet
472
+ height = height or self.default_sample_size * self.vae_scale_factor
473
+ width = width or self.default_sample_size * self.vae_scale_factor
474
+
475
+ original_size = original_size or (height, width)
476
+ target_size = target_size or (height, width)
477
+
478
+ # 1. Check inputs. Raise error if not correct
479
+ self.check_inputs(
480
+ prompt,
481
+ prompt_2,
482
+ height,
483
+ width,
484
+ callback_steps,
485
+ negative_prompt,
486
+ negative_prompt_2,
487
+ prompt_embeds,
488
+ negative_prompt_embeds,
489
+ pooled_prompt_embeds,
490
+ negative_pooled_prompt_embeds,
491
+ ip_adapter_image,
492
+ ip_adapter_image_embeds,
493
+ callback_on_step_end_tensor_inputs,
494
+ )
495
+
496
+ self._guidance_scale = guidance_scale
497
+ self._guidance_rescale = guidance_rescale
498
+ self._clip_skip = clip_skip
499
+ self._cross_attention_kwargs = cross_attention_kwargs
500
+ self._denoising_end = denoising_end
501
+ self._interrupt = False
502
+
503
+ # 2. Define call parameters
504
+ if prompt is not None and isinstance(prompt, str):
505
+ batch_size = 1
506
+ elif prompt is not None and isinstance(prompt, list):
507
+ batch_size = len(prompt)
508
+ else:
509
+ batch_size = prompt_embeds.shape[0]
510
+
511
+ device = self._execution_device
512
+
513
+ # 3. Encode input prompt
514
+ lora_scale = (
515
+ self.cross_attention_kwargs.get("scale", None)
516
+ if self.cross_attention_kwargs is not None
517
+ else None
518
+ )
519
+
520
+ (
521
+ prompt_embeds,
522
+ negative_prompt_embeds,
523
+ pooled_prompt_embeds,
524
+ negative_pooled_prompt_embeds,
525
+ ) = self.encode_prompt(
526
+ prompt=prompt,
527
+ prompt_2=prompt_2,
528
+ device=device,
529
+ num_images_per_prompt=num_images_per_prompt,
530
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
531
+ negative_prompt=negative_prompt,
532
+ negative_prompt_2=negative_prompt_2,
533
+ prompt_embeds=prompt_embeds,
534
+ negative_prompt_embeds=negative_prompt_embeds,
535
+ pooled_prompt_embeds=pooled_prompt_embeds,
536
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
537
+ lora_scale=lora_scale,
538
+ clip_skip=self.clip_skip,
539
+ )
540
+
541
+ # 4. Prepare timesteps
542
+ timesteps, num_inference_steps = retrieve_timesteps(
543
+ self.scheduler, num_inference_steps, device, timesteps
544
+ )
545
+
546
+ # 5. Prepare latent variables
547
+ num_channels_latents = self.unet.config.in_channels
548
+ latents = self.prepare_latents(
549
+ batch_size * num_images_per_prompt,
550
+ num_channels_latents,
551
+ height,
552
+ width,
553
+ prompt_embeds.dtype,
554
+ device,
555
+ generator,
556
+ latents,
557
+ )
558
+
559
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
560
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
561
+
562
+ # 7. Prepare added time ids & embeddings
563
+ add_text_embeds = pooled_prompt_embeds
564
+ if self.text_encoder_2 is None:
565
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
566
+ else:
567
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
568
+
569
+ add_time_ids = self._get_add_time_ids(
570
+ original_size,
571
+ crops_coords_top_left,
572
+ target_size,
573
+ dtype=prompt_embeds.dtype,
574
+ text_encoder_projection_dim=text_encoder_projection_dim,
575
+ )
576
+ if negative_original_size is not None and negative_target_size is not None:
577
+ negative_add_time_ids = self._get_add_time_ids(
578
+ negative_original_size,
579
+ negative_crops_coords_top_left,
580
+ negative_target_size,
581
+ dtype=prompt_embeds.dtype,
582
+ text_encoder_projection_dim=text_encoder_projection_dim,
583
+ )
584
+ else:
585
+ negative_add_time_ids = add_time_ids
586
+
587
+ if self.do_classifier_free_guidance:
588
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
589
+ add_text_embeds = torch.cat(
590
+ [negative_pooled_prompt_embeds, add_text_embeds], dim=0
591
+ )
592
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
593
+
594
+ prompt_embeds = prompt_embeds.to(device)
595
+ add_text_embeds = add_text_embeds.to(device)
596
+ add_time_ids = add_time_ids.to(device).repeat(
597
+ batch_size * num_images_per_prompt, 1
598
+ )
599
+
600
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
601
+ image_embeds = self.prepare_ip_adapter_image_embeds(
602
+ ip_adapter_image,
603
+ ip_adapter_image_embeds,
604
+ device,
605
+ batch_size * num_images_per_prompt,
606
+ self.do_classifier_free_guidance,
607
+ )
608
+
609
+ # Preprocess reference image
610
+ reference_image = self.image_processor.preprocess(reference_image)
611
+ reference_latents = self.prepare_image_latents(
612
+ reference_image,
613
+ timesteps[:1].repeat(batch_size * num_images_per_prompt), # unused since add_noise=False
614
+ batch_size,
615
+ 1,
616
+ prompt_embeds.dtype,
617
+ device,
618
+ generator,
619
+ add_noise=False,
620
+ )
621
+
622
+ with torch.no_grad():
623
+ ref_timesteps = torch.zeros_like(timesteps[0])
624
+ ref_hidden_states = {}
625
+
626
+ self.unet(
627
+ reference_latents,
628
+ ref_timesteps,
629
+ encoder_hidden_states=prompt_embeds[-1:],
630
+ added_cond_kwargs={
631
+ "text_embeds": add_text_embeds[-1:],
632
+ "time_ids": add_time_ids[-1:],
633
+ },
634
+ cross_attention_kwargs={
635
+ "cache_hidden_states": ref_hidden_states,
636
+ "use_mv": False,
637
+ "use_ref": False,
638
+ },
639
+ return_dict=False,
640
+ )
641
+ ref_hidden_states = {
642
+ k: v.repeat_interleave(num_images_per_prompt, dim=0)
643
+ for k, v in ref_hidden_states.items()
644
+ }
645
+ if self.do_classifier_free_guidance:
646
+ ref_hidden_states = {
647
+ k: torch.cat([torch.zeros_like(v), v], dim=0)
648
+ for k, v in ref_hidden_states.items()
649
+ }
650
+
651
+ cross_attention_kwargs = {
652
+ "mv_scale": mv_scale,
653
+ "ref_hidden_states": {k: v.clone() for k, v in ref_hidden_states.items()},
654
+ "ref_scale": reference_conditioning_scale,
655
+ "num_views": num_images_per_prompt,
656
+ **(self.cross_attention_kwargs or {}),
657
+ }
658
+
659
+ # Preprocess control image
660
+ control_image_feature = self.prepare_control_image(
661
+ image=control_image,
662
+ width=width,
663
+ height=height,
664
+ batch_size=batch_size * num_images_per_prompt,
665
+ num_images_per_prompt=1, # NOTE: always 1 for control images
666
+ device=device,
667
+ dtype=latents.dtype,
668
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
669
+ )
670
+ control_image_feature = control_image_feature.to(
671
+ device=device, dtype=latents.dtype
672
+ )
673
+
674
+ adapter_state = self.cond_encoder(control_image_feature)
675
+ for i, state in enumerate(adapter_state):
676
+ adapter_state[i] = state * control_conditioning_scale
677
+
678
+ # 8. Denoising loop
679
+ num_warmup_steps = max(
680
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
681
+ )
682
+
683
+ # 8.1 Apply denoising_end
684
+ if (
685
+ self.denoising_end is not None
686
+ and isinstance(self.denoising_end, float)
687
+ and self.denoising_end > 0
688
+ and self.denoising_end < 1
689
+ ):
690
+ discrete_timestep_cutoff = int(
691
+ round(
692
+ self.scheduler.config.num_train_timesteps
693
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
694
+ )
695
+ )
696
+ num_inference_steps = len(
697
+ list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))
698
+ )
699
+ timesteps = timesteps[:num_inference_steps]
700
+
701
+ # 9. Optionally get Guidance Scale Embedding
702
+ timestep_cond = None
703
+ if self.unet.config.time_cond_proj_dim is not None:
704
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
705
+ batch_size * num_images_per_prompt
706
+ )
707
+ timestep_cond = self.get_guidance_scale_embedding(
708
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
709
+ ).to(device=device, dtype=latents.dtype)
710
+
711
+ self._num_timesteps = len(timesteps)
712
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
713
+ for i, t in enumerate(timesteps):
714
+ if self.interrupt:
715
+ continue
716
+
717
+ # expand the latents if we are doing classifier free guidance
718
+ latent_model_input = (
719
+ torch.cat([latents] * 2)
720
+ if self.do_classifier_free_guidance
721
+ else latents
722
+ )
723
+
724
+ latent_model_input = self.scheduler.scale_model_input(
725
+ latent_model_input, t
726
+ )
727
+
728
+ added_cond_kwargs = {
729
+ "text_embeds": add_text_embeds,
730
+ "time_ids": add_time_ids,
731
+ }
732
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
733
+ added_cond_kwargs["image_embeds"] = image_embeds
734
+
735
+ if i < int(num_inference_steps * control_conditioning_factor):
736
+ down_intrablock_additional_residuals = [
737
+ state.clone() for state in adapter_state
738
+ ]
739
+ else:
740
+ down_intrablock_additional_residuals = None
741
+
742
+ # predict the noise residual
743
+ noise_pred = self.unet(
744
+ latent_model_input,
745
+ t,
746
+ encoder_hidden_states=prompt_embeds,
747
+ timestep_cond=timestep_cond,
748
+ cross_attention_kwargs=cross_attention_kwargs,
749
+ down_intrablock_additional_residuals=down_intrablock_additional_residuals,
750
+ added_cond_kwargs=added_cond_kwargs,
751
+ return_dict=False,
752
+ )[0]
753
+
754
+ # perform guidance
755
+ if self.do_classifier_free_guidance:
756
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
757
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
758
+ noise_pred_text - noise_pred_uncond
759
+ )
760
+
761
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
762
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
763
+ noise_pred = rescale_noise_cfg(
764
+ noise_pred,
765
+ noise_pred_text,
766
+ guidance_rescale=self.guidance_rescale,
767
+ )
768
+
769
+ # compute the previous noisy sample x_t -> x_t-1
770
+ latents_dtype = latents.dtype
771
+ latents = self.scheduler.step(
772
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
773
+ )[0]
774
+ if latents.dtype != latents_dtype:
775
+ if torch.backends.mps.is_available():
776
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
777
+ latents = latents.to(latents_dtype)
778
+
779
+ if callback_on_step_end is not None:
780
+ callback_kwargs = {}
781
+ for k in callback_on_step_end_tensor_inputs:
782
+ callback_kwargs[k] = locals()[k]
783
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
784
+
785
+ latents = callback_outputs.pop("latents", latents)
786
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
787
+ negative_prompt_embeds = callback_outputs.pop(
788
+ "negative_prompt_embeds", negative_prompt_embeds
789
+ )
790
+ add_text_embeds = callback_outputs.pop(
791
+ "add_text_embeds", add_text_embeds
792
+ )
793
+ negative_pooled_prompt_embeds = callback_outputs.pop(
794
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
795
+ )
796
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
797
+ negative_add_time_ids = callback_outputs.pop(
798
+ "negative_add_time_ids", negative_add_time_ids
799
+ )
800
+
801
+ # call the callback, if provided
802
+ if i == len(timesteps) - 1 or (
803
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
804
+ ):
805
+ progress_bar.update()
806
+ if callback is not None and i % callback_steps == 0:
807
+ step_idx = i // getattr(self.scheduler, "order", 1)
808
+ callback(step_idx, t, latents)
809
+
810
+ if not output_type == "latent":
811
+ # make sure the VAE is in float32 mode, as it overflows in float16
812
+ needs_upcasting = (
813
+ self.vae.dtype == torch.float16 and self.vae.config.force_upcast
814
+ )
815
+
816
+ if needs_upcasting:
817
+ self.upcast_vae()
818
+ latents = latents.to(
819
+ next(iter(self.vae.post_quant_conv.parameters())).dtype
820
+ )
821
+ elif latents.dtype != self.vae.dtype:
822
+ if torch.backends.mps.is_available():
823
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
824
+ self.vae = self.vae.to(latents.dtype)
825
+
826
+ # unscale/denormalize the latents
827
+ # denormalize with the mean and std if available and not None
828
+ has_latents_mean = (
829
+ hasattr(self.vae.config, "latents_mean")
830
+ and self.vae.config.latents_mean is not None
831
+ )
832
+ has_latents_std = (
833
+ hasattr(self.vae.config, "latents_std")
834
+ and self.vae.config.latents_std is not None
835
+ )
836
+ if has_latents_mean and has_latents_std:
837
+ latents_mean = (
838
+ torch.tensor(self.vae.config.latents_mean)
839
+ .view(1, 4, 1, 1)
840
+ .to(latents.device, latents.dtype)
841
+ )
842
+ latents_std = (
843
+ torch.tensor(self.vae.config.latents_std)
844
+ .view(1, 4, 1, 1)
845
+ .to(latents.device, latents.dtype)
846
+ )
847
+ latents = (
848
+ latents * latents_std / self.vae.config.scaling_factor
849
+ + latents_mean
850
+ )
851
+ else:
852
+ latents = latents / self.vae.config.scaling_factor
853
+
854
+ image = self.vae.decode(latents, return_dict=False)[0]
855
+
856
+ # cast back to fp16 if needed
857
+ if needs_upcasting:
858
+ self.vae.to(dtype=torch.float16)
859
+ else:
860
+ image = latents
861
+
862
+ if not output_type == "latent":
863
+ # apply watermark if available
864
+ if self.watermark is not None:
865
+ image = self.watermark.apply_watermark(image)
866
+
867
+ image = self.image_processor.postprocess(image, output_type=output_type)
868
+
869
+ # Offload all models
870
+ self.maybe_free_model_hooks()
871
+
872
+ if not return_dict:
873
+ return (image,)
874
+
875
+ return StableDiffusionXLPipelineOutput(images=image)
876
+
877
+ ### NEW: adapters ###
878
+ def _init_custom_adapter(
879
+ self,
880
+ # Multi-view adapter
881
+ num_views: int = 1,
882
+ self_attn_processor: Any = DecoupledMVRowSelfAttnProcessor2_0,
883
+ # Condition encoder
884
+ cond_in_channels: int = 6,
885
+ # For training
886
+ copy_attn_weights: bool = True,
887
+ zero_init_module_keys: List[str] = [],
888
+ ):
889
+ # Condition encoder
890
+ self.cond_encoder = T2IAdapter(
891
+ in_channels=cond_in_channels,
892
+ channels=(320, 640, 1280, 1280),
893
+ num_res_blocks=2,
894
+ downscale_factor=16,
895
+ adapter_type="full_adapter_xl",
896
+ )
897
+
898
+ # set custom attn processor for multi-view attention and image cross-attention
899
+ self.unet: UNet2DConditionModel
900
+ set_unet_2d_condition_attn_processor(
901
+ self.unet,
902
+ set_self_attn_proc_func=lambda name, hs, cad, ap: self_attn_processor(
903
+ query_dim=hs,
904
+ inner_dim=hs,
905
+ num_views=num_views,
906
+ name=name,
907
+ use_mv=True,
908
+ use_ref=True,
909
+ ),
910
+ set_cross_attn_proc_func=lambda name, hs, cad, ap: self_attn_processor(
911
+ query_dim=hs,
912
+ inner_dim=hs,
913
+ num_views=num_views,
914
+ name=name,
915
+ use_mv=False,
916
+ use_ref=False,
917
+ ),
918
+ )
919
+
920
+ # copy decoupled attention weights from original unet
921
+ if copy_attn_weights:
922
+ state_dict = self.unet.state_dict()
923
+ for key in state_dict.keys():
924
+ if "_mv" in key:
925
+ compatible_key = key.replace("_mv", "").replace("processor.", "")
926
+ elif "_ref" in key:
927
+ compatible_key = key.replace("_ref", "").replace("processor.", "")
928
+ else:
929
+ compatible_key = key
930
+
931
+ is_zero_init_key = any([k in key for k in zero_init_module_keys])
932
+ if is_zero_init_key:
933
+ state_dict[key] = torch.zeros_like(state_dict[compatible_key])
934
+ else:
935
+ state_dict[key] = state_dict[compatible_key].clone()
936
+ self.unet.load_state_dict(state_dict)
937
+
938
+ def _load_custom_adapter(self, state_dict):
939
+ self.unet.load_state_dict(state_dict, strict=False)
940
+ self.cond_encoder.load_state_dict(state_dict, strict=False)
941
+
942
+ def _save_custom_adapter(
943
+ self,
944
+ include_keys: Optional[List[str]] = None,
945
+ exclude_keys: Optional[List[str]] = None,
946
+ ):
947
+ def include_fn(k):
948
+ is_included = False
949
+
950
+ if include_keys is not None:
951
+ is_included = is_included or any([key in k for key in include_keys])
952
+ if exclude_keys is not None:
953
+ is_included = is_included and not any(
954
+ [key in k for key in exclude_keys]
955
+ )
956
+
957
+ return is_included
958
+
959
+ state_dict = {k: v for k, v in self.unet.state_dict().items() if include_fn(k)}
960
+ state_dict.update(self.cond_encoder.state_dict())
961
+
962
+ return state_dict
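The adapter hooks above only install, extract, and restore the MV-Adapter-specific weights; they leave the rest of the SDXL UNet untouched. Below is a minimal usage sketch (an editor's illustration, not part of this commit) of how they might be exercised together. `pipe` is assumed to be an instance of the pipeline class defined in this file, and the view count and output path are placeholders.

import torch


def attach_and_save_adapter(pipe, num_views=6, out_path="mvadapter_weights.pt"):
    # Install the decoupled multi-view ("_mv") and reference ("_ref") attention
    # branches plus the 6-channel T2IAdapter condition encoder, copying the base
    # attention weights as initialization (see _init_custom_adapter above).
    pipe._init_custom_adapter(num_views=num_views, cond_in_channels=6)

    # Keep only the newly added UNet parameters (keys containing "_mv"/"_ref");
    # _save_custom_adapter always appends the condition encoder weights as well.
    adapter_state = pipe._save_custom_adapter(include_keys=["_mv", "_ref"])
    torch.save(adapter_state, out_path)


def load_adapter(pipe, num_views=6, path="mvadapter_weights.pt"):
    # The decoupled processors must exist before loading, so install them first;
    # loading is non-strict, so only adapter keys present in the file are updated.
    pipe._init_custom_adapter(num_views=num_views, cond_in_channels=6)
    pipe._load_custom_adapter(torch.load(path, map_location="cpu"))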
mvadapter/pipelines/pipeline_mvadapter_t2mv_sd.py ADDED
@@ -0,0 +1,634 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import inspect
15
+ from typing import Any, Callable, Dict, List, Optional, Union
16
+
17
+ import torch
18
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
19
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
20
+ from diffusers.models import AutoencoderKL, T2IAdapter, UNet2DConditionModel
21
+ from diffusers.pipelines.stable_diffusion.pipeline_output import (
22
+ StableDiffusionPipelineOutput,
23
+ )
24
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
25
+ StableDiffusionPipeline,
26
+ rescale_noise_cfg,
27
+ retrieve_timesteps,
28
+ )
29
+ from diffusers.pipelines.stable_diffusion.safety_checker import (
30
+ StableDiffusionSafetyChecker,
31
+ )
32
+ from diffusers.schedulers import KarrasDiffusionSchedulers
33
+ from diffusers.utils import deprecate, is_torch_xla_available, logging
34
+ from diffusers.utils.torch_utils import randn_tensor
35
+ from packaging import version
36
+ from transformers import (
37
+ CLIPImageProcessor,
38
+ CLIPTextModel,
39
+ CLIPTokenizer,
40
+ CLIPVisionModelWithProjection,
41
+ )
42
+
43
+ from ..loaders import CustomAdapterMixin
44
+ from ..models.attention_processor import (
45
+ DecoupledMVRowSelfAttnProcessor2_0,
46
+ set_unet_2d_condition_attn_processor,
47
+ )
48
+
49
+ if is_torch_xla_available():
50
+ import torch_xla.core.xla_model as xm
51
+
52
+ XLA_AVAILABLE = True
53
+ else:
54
+ XLA_AVAILABLE = False
55
+
56
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
57
+
58
+
59
+ class MVAdapterT2MVSDPipeline(StableDiffusionPipeline, CustomAdapterMixin):
60
+ def __init__(
61
+ self,
62
+ vae: AutoencoderKL,
63
+ text_encoder: CLIPTextModel,
64
+ tokenizer: CLIPTokenizer,
65
+ unet: UNet2DConditionModel,
66
+ scheduler: KarrasDiffusionSchedulers,
67
+ safety_checker: StableDiffusionSafetyChecker,
68
+ feature_extractor: CLIPImageProcessor,
69
+ image_encoder: CLIPVisionModelWithProjection = None,
70
+ requires_safety_checker: bool = True,
71
+ ):
72
+ super().__init__(
73
+ vae=vae,
74
+ text_encoder=text_encoder,
75
+ tokenizer=tokenizer,
76
+ unet=unet,
77
+ scheduler=scheduler,
78
+ safety_checker=safety_checker,
79
+ feature_extractor=feature_extractor,
80
+ image_encoder=image_encoder,
81
+ requires_safety_checker=requires_safety_checker,
82
+ )
83
+
84
+ self.control_image_processor = VaeImageProcessor(
85
+ vae_scale_factor=self.vae_scale_factor,
86
+ do_convert_rgb=True,
87
+ do_normalize=False,
88
+ )
89
+
90
+ def prepare_control_image(
91
+ self,
92
+ image,
93
+ width,
94
+ height,
95
+ batch_size,
96
+ num_images_per_prompt,
97
+ device,
98
+ dtype,
99
+ do_classifier_free_guidance=False,
100
+ ):
101
+ assert hasattr(
102
+ self, "control_image_processor"
103
+ ), "control_image_processor is not initialized"
104
+
105
+ image = self.control_image_processor.preprocess(
106
+ image, height=height, width=width
107
+ ).to(dtype=torch.float32)
108
+ image_batch_size = image.shape[0]
109
+
110
+ if image_batch_size == 1:
111
+ repeat_by = batch_size
112
+ else:
113
+ # image batch size is the same as prompt batch size
114
+ repeat_by = num_images_per_prompt # always 1 for control image
115
+
116
+ image = image.repeat_interleave(repeat_by, dim=0)
117
+
118
+ image = image.to(device=device, dtype=dtype)
119
+
120
+ if do_classifier_free_guidance:
121
+ image = torch.cat([image] * 2)
122
+
123
+ return image
124
+
125
+ @torch.no_grad()
126
+ def __call__(
127
+ self,
128
+ prompt: Union[str, List[str]] = None,
129
+ height: Optional[int] = None,
130
+ width: Optional[int] = None,
131
+ num_inference_steps: int = 50,
132
+ timesteps: List[int] = None,
133
+ sigmas: List[float] = None,
134
+ guidance_scale: float = 7.5,
135
+ negative_prompt: Optional[Union[str, List[str]]] = None,
136
+ num_images_per_prompt: Optional[int] = 1,
137
+ eta: float = 0.0,
138
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
139
+ latents: Optional[torch.Tensor] = None,
140
+ prompt_embeds: Optional[torch.Tensor] = None,
141
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
142
+ ip_adapter_image: Optional[PipelineImageInput] = None,
143
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
144
+ output_type: Optional[str] = "pil",
145
+ return_dict: bool = True,
146
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
147
+ guidance_rescale: float = 0.0,
148
+ clip_skip: Optional[int] = None,
149
+ callback_on_step_end: Optional[
150
+ Union[
151
+ Callable[[int, int, Dict], None],
152
+ PipelineCallback,
153
+ MultiPipelineCallbacks,
154
+ ]
155
+ ] = None,
156
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
157
+ # NEW
158
+ mv_scale: float = 1.0,
159
+ # Camera or geometry condition
160
+ control_image: Optional[PipelineImageInput] = None,
161
+ control_conditioning_scale: Optional[float] = 1.0,
162
+ control_conditioning_factor: float = 1.0,
163
+ # Optional. controlnet
164
+ controlnet_image: Optional[PipelineImageInput] = None,
165
+ controlnet_conditioning_scale: Optional[float] = 1.0,
166
+ **kwargs,
167
+ ):
168
+ r"""
169
+ The call function to the pipeline for generation.
170
+
171
+ Args:
172
+ prompt (`str` or `List[str]`, *optional*):
173
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
174
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
175
+ The height in pixels of the generated image.
176
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
177
+ The width in pixels of the generated image.
178
+ num_inference_steps (`int`, *optional*, defaults to 50):
179
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
180
+ expense of slower inference.
181
+ timesteps (`List[int]`, *optional*):
182
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
183
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
184
+ passed will be used. Must be in descending order.
185
+ sigmas (`List[float]`, *optional*):
186
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
187
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
188
+ will be used.
189
+ guidance_scale (`float`, *optional*, defaults to 7.5):
190
+ A higher guidance scale value encourages the model to generate images closely linked to the text
191
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
192
+ negative_prompt (`str` or `List[str]`, *optional*):
193
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
194
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
195
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
196
+ The number of images to generate per prompt.
197
+ eta (`float`, *optional*, defaults to 0.0):
198
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
199
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
200
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
201
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
202
+ generation deterministic.
203
+ latents (`torch.Tensor`, *optional*):
204
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
205
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
206
+ tensor is generated by sampling using the supplied random `generator`.
207
+ prompt_embeds (`torch.Tensor`, *optional*):
208
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
209
+ provided, text embeddings are generated from the `prompt` input argument.
210
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
211
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
212
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
213
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
214
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
215
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
216
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
217
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
218
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
219
+ output_type (`str`, *optional*, defaults to `"pil"`):
220
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
221
+ return_dict (`bool`, *optional*, defaults to `True`):
222
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
223
+ plain tuple.
224
+ cross_attention_kwargs (`dict`, *optional*):
225
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
226
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
227
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
228
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
229
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
230
+ using zero terminal SNR.
231
+ clip_skip (`int`, *optional*):
232
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
233
+ the output of the pre-final layer will be used for computing the prompt embeddings.
234
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
235
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
236
+ each denoising step during inference, with the following arguments: `callback_on_step_end(self:
237
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
238
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
239
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
240
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
241
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
242
+ `._callback_tensor_inputs` attribute of your pipeline class.
243
+
244
+ Examples:
245
+
246
+ Returns:
247
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
248
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
249
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
250
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
251
+ "not-safe-for-work" (nsfw) content.
252
+ """
253
+
254
+ callback = kwargs.pop("callback", None)
255
+ callback_steps = kwargs.pop("callback_steps", None)
256
+
257
+ if callback is not None:
258
+ deprecate(
259
+ "callback",
260
+ "1.0.0",
261
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
262
+ )
263
+ if callback_steps is not None:
264
+ deprecate(
265
+ "callback_steps",
266
+ "1.0.0",
267
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
268
+ )
269
+
270
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
271
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
272
+
273
+ # 0. Default height and width to unet
274
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
275
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
276
+ # to deal with lora scaling and other possible forward hooks
277
+
278
+ # 1. Check inputs. Raise error if not correct
279
+ self.check_inputs(
280
+ prompt,
281
+ height,
282
+ width,
283
+ callback_steps,
284
+ negative_prompt,
285
+ prompt_embeds,
286
+ negative_prompt_embeds,
287
+ ip_adapter_image,
288
+ ip_adapter_image_embeds,
289
+ callback_on_step_end_tensor_inputs,
290
+ )
291
+
292
+ self._guidance_scale = guidance_scale
293
+ self._guidance_rescale = guidance_rescale
294
+ self._clip_skip = clip_skip
295
+ self._cross_attention_kwargs = cross_attention_kwargs
296
+ self._interrupt = False
297
+
298
+ # 2. Define call parameters
299
+ if prompt is not None and isinstance(prompt, str):
300
+ batch_size = 1
301
+ elif prompt is not None and isinstance(prompt, list):
302
+ batch_size = len(prompt)
303
+ else:
304
+ batch_size = prompt_embeds.shape[0]
305
+
306
+ device = self._execution_device
307
+
308
+ # 3. Encode input prompt
309
+ lora_scale = (
310
+ self.cross_attention_kwargs.get("scale", None)
311
+ if self.cross_attention_kwargs is not None
312
+ else None
313
+ )
314
+
315
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
316
+ prompt,
317
+ device,
318
+ num_images_per_prompt,
319
+ self.do_classifier_free_guidance,
320
+ negative_prompt,
321
+ prompt_embeds=prompt_embeds,
322
+ negative_prompt_embeds=negative_prompt_embeds,
323
+ lora_scale=lora_scale,
324
+ clip_skip=self.clip_skip,
325
+ )
326
+
327
+ # For classifier free guidance, we need to do two forward passes.
328
+ # Here we concatenate the unconditional and text embeddings into a single batch
329
+ # to avoid doing two forward passes
330
+ if self.do_classifier_free_guidance:
331
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
332
+
333
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
334
+ image_embeds = self.prepare_ip_adapter_image_embeds(
335
+ ip_adapter_image,
336
+ ip_adapter_image_embeds,
337
+ device,
338
+ batch_size * num_images_per_prompt,
339
+ self.do_classifier_free_guidance,
340
+ )
341
+
342
+ # 4. Prepare timesteps
343
+ timesteps, num_inference_steps = retrieve_timesteps(
344
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
345
+ )
346
+
347
+ # 5. Prepare latent variables
348
+ num_channels_latents = self.unet.config.in_channels
349
+ latents = self.prepare_latents(
350
+ batch_size * num_images_per_prompt,
351
+ num_channels_latents,
352
+ height,
353
+ width,
354
+ prompt_embeds.dtype,
355
+ device,
356
+ generator,
357
+ latents,
358
+ )
359
+
360
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
361
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
362
+
363
+ # 6.1 Add image embeds for IP-Adapter
364
+ added_cond_kwargs = (
365
+ {"image_embeds": image_embeds}
366
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
367
+ else None
368
+ )
369
+
370
+ # 6.2 Optionally get Guidance Scale Embedding
371
+ timestep_cond = None
372
+ if self.unet.config.time_cond_proj_dim is not None:
373
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
374
+ batch_size * num_images_per_prompt
375
+ )
376
+ timestep_cond = self.get_guidance_scale_embedding(
377
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
378
+ ).to(device=device, dtype=latents.dtype)
379
+
380
+ # Preprocess control image
381
+ control_image_feature = self.prepare_control_image(
382
+ image=control_image,
383
+ width=width,
384
+ height=height,
385
+ batch_size=batch_size * num_images_per_prompt,
386
+ num_images_per_prompt=1, # NOTE: always 1 for control images
387
+ device=device,
388
+ dtype=latents.dtype,
389
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
390
+ )
391
+ control_image_feature = control_image_feature.to(
392
+ device=device, dtype=latents.dtype
393
+ )
394
+
395
+ adapter_state = self.cond_encoder(control_image_feature)
396
+ for i, state in enumerate(adapter_state):
397
+ adapter_state[i] = state * control_conditioning_scale
398
+
399
+ # Preprocess controlnet image if provided
400
+ do_controlnet = controlnet_image is not None and hasattr(self, "controlnet")
401
+ if do_controlnet:
402
+ controlnet_image = self.prepare_control_image(
403
+ image=controlnet_image,
404
+ width=width,
405
+ height=height,
406
+ batch_size=batch_size * num_images_per_prompt,
407
+ num_images_per_prompt=1, # NOTE: always 1 for control images
408
+ device=device,
409
+ dtype=latents.dtype,
410
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
411
+ )
412
+ controlnet_image = controlnet_image.to(device=device, dtype=latents.dtype)
413
+
414
+ # 7. Denoising loop
415
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
416
+ self._num_timesteps = len(timesteps)
417
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
418
+ for i, t in enumerate(timesteps):
419
+ if self.interrupt:
420
+ continue
421
+
422
+ # expand the latents if we are doing classifier free guidance
423
+ latent_model_input = (
424
+ torch.cat([latents] * 2)
425
+ if self.do_classifier_free_guidance
426
+ else latents
427
+ )
428
+ latent_model_input = self.scheduler.scale_model_input(
429
+ latent_model_input, t
430
+ )
431
+
432
+ if i < int(num_inference_steps * control_conditioning_factor):
433
+ down_intrablock_additional_residuals = [
434
+ state.clone() for state in adapter_state
435
+ ]
436
+ else:
437
+ down_intrablock_additional_residuals = None
438
+
439
+ unet_add_kwargs = {}
440
+
441
+ # Do controlnet if provided
442
+ if do_controlnet:
443
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
444
+ latent_model_input,
445
+ t,
446
+ encoder_hidden_states=prompt_embeds,
447
+ controlnet_cond=controlnet_image,
448
+ conditioning_scale=controlnet_conditioning_scale,
449
+ guess_mode=False,
450
+ added_cond_kwargs=added_cond_kwargs,
451
+ return_dict=False,
452
+ )
453
+ unet_add_kwargs.update(
454
+ {
455
+ "down_block_additional_residuals": down_block_res_samples,
456
+ "mid_block_additional_residual": mid_block_res_sample,
457
+ }
458
+ )
459
+
460
+ # predict the noise residual
461
+ noise_pred = self.unet(
462
+ latent_model_input,
463
+ t,
464
+ encoder_hidden_states=prompt_embeds,
465
+ timestep_cond=timestep_cond,
466
+ cross_attention_kwargs={
467
+ "mv_scale": mv_scale,
468
+ "num_views": num_images_per_prompt,
469
+ **(self.cross_attention_kwargs or {}),
470
+ },
471
+ down_intrablock_additional_residuals=down_intrablock_additional_residuals,
472
+ added_cond_kwargs=added_cond_kwargs,
473
+ return_dict=False,
474
+ **unet_add_kwargs,
475
+ )[0]
476
+
477
+ # perform guidance
478
+ if self.do_classifier_free_guidance:
479
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
480
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
481
+ noise_pred_text - noise_pred_uncond
482
+ )
483
+
484
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
485
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
486
+ noise_pred = rescale_noise_cfg(
487
+ noise_pred,
488
+ noise_pred_text,
489
+ guidance_rescale=self.guidance_rescale,
490
+ )
491
+
492
+ # compute the previous noisy sample x_t -> x_t-1
493
+ latents = self.scheduler.step(
494
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
495
+ )[0]
496
+
497
+ if callback_on_step_end is not None:
498
+ callback_kwargs = {}
499
+ for k in callback_on_step_end_tensor_inputs:
500
+ callback_kwargs[k] = locals()[k]
501
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
502
+
503
+ latents = callback_outputs.pop("latents", latents)
504
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
505
+ negative_prompt_embeds = callback_outputs.pop(
506
+ "negative_prompt_embeds", negative_prompt_embeds
507
+ )
508
+
509
+ # call the callback, if provided
510
+ if i == len(timesteps) - 1 or (
511
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
512
+ ):
513
+ progress_bar.update()
514
+ if callback is not None and i % callback_steps == 0:
515
+ step_idx = i // getattr(self.scheduler, "order", 1)
516
+ callback(step_idx, t, latents)
517
+
518
+ if XLA_AVAILABLE:
519
+ xm.mark_step()
520
+
521
+ if not output_type == "latent":
522
+ image = self.vae.decode(
523
+ latents / self.vae.config.scaling_factor,
524
+ return_dict=False,
525
+ generator=generator,
526
+ )[0]
527
+ image, has_nsfw_concept = self.run_safety_checker(
528
+ image, device, prompt_embeds.dtype
529
+ )
530
+ else:
531
+ image = latents
532
+ has_nsfw_concept = None
533
+
534
+ if has_nsfw_concept is None:
535
+ do_denormalize = [True] * image.shape[0]
536
+ else:
537
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
538
+ image = self.image_processor.postprocess(
539
+ image, output_type=output_type, do_denormalize=do_denormalize
540
+ )
541
+
542
+ # Offload all models
543
+ self.maybe_free_model_hooks()
544
+
545
+ if not return_dict:
546
+ return (image, has_nsfw_concept)
547
+
548
+ return StableDiffusionPipelineOutput(
549
+ images=image, nsfw_content_detected=has_nsfw_concept
550
+ )
551
+
552
+ ### NEW: adapters ###
553
+ def _init_custom_adapter(
554
+ self,
555
+ # Multi-view adapter
556
+ num_views: int = 1,
557
+ self_attn_processor: Any = DecoupledMVRowSelfAttnProcessor2_0,
558
+ # Condition encoder
559
+ cond_in_channels: int = 6,
560
+ # For training
561
+ copy_attn_weights: bool = True,
562
+ zero_init_module_keys: List[str] = [],
563
+ ):
564
+ # Condition encoder
565
+ self.cond_encoder = T2IAdapter(
566
+ in_channels=cond_in_channels,
567
+ channels=self.unet.config.block_out_channels,
568
+ num_res_blocks=self.unet.config.layers_per_block,
569
+ downscale_factor=8,
570
+ )
571
+
572
+ # set custom attn processor for multi-view attention
573
+ self.unet: UNet2DConditionModel
574
+ set_unet_2d_condition_attn_processor(
575
+ self.unet,
576
+ set_self_attn_proc_func=lambda name, hs, cad, ap: self_attn_processor(
577
+ query_dim=hs,
578
+ inner_dim=hs,
579
+ num_views=num_views,
580
+ name=name,
581
+ use_mv=True,
582
+ use_ref=False,
583
+ ),
584
+ set_cross_attn_proc_func=lambda name, hs, cad, ap: self_attn_processor(
585
+ query_dim=hs,
586
+ inner_dim=hs,
587
+ num_views=num_views,
588
+ name=name,
589
+ use_mv=False,
590
+ use_ref=False,
591
+ ),
592
+ )
593
+
594
+ # copy decoupled attention weights from original unet
595
+ if copy_attn_weights:
596
+ state_dict = self.unet.state_dict()
597
+ for key in state_dict.keys():
598
+ if "_mv" in key:
599
+ compatible_key = key.replace("_mv", "").replace("processor.", "")
600
+ else:
601
+ compatible_key = key
602
+
603
+ is_zero_init_key = any([k in key for k in zero_init_module_keys])
604
+ if is_zero_init_key:
605
+ state_dict[key] = torch.zeros_like(state_dict[compatible_key])
606
+ else:
607
+ state_dict[key] = state_dict[compatible_key].clone()
608
+ self.unet.load_state_dict(state_dict)
609
+
610
+ def _load_custom_adapter(self, state_dict):
611
+ self.unet.load_state_dict(state_dict, strict=False)
612
+ self.cond_encoder.load_state_dict(state_dict, strict=False)
613
+
614
+ def _save_custom_adapter(
615
+ self,
616
+ include_keys: Optional[List[str]] = None,
617
+ exclude_keys: Optional[List[str]] = None,
618
+ ):
619
+ def include_fn(k):
620
+ is_included = False
621
+
622
+ if include_keys is not None:
623
+ is_included = is_included or any([key in k for key in include_keys])
624
+ if exclude_keys is not None:
625
+ is_included = is_included and not any(
626
+ [key in k for key in exclude_keys]
627
+ )
628
+
629
+ return is_included
630
+
631
+ state_dict = {k: v for k, v in self.unet.state_dict().items() if include_fn(k)}
632
+ state_dict.update(self.cond_encoder.state_dict())
633
+
634
+ return state_dict
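As a point of orientation, here is a minimal text-to-multiview sketch (an editor's illustration, not part of this commit) showing how `MVAdapterT2MVSDPipeline` and the hooks above might be wired together. The base checkpoint id, adapter weight path, and the dummy conditioning tensor are placeholders; in practice the repository's own utilities prepare the per-view camera/geometry maps.

import torch

from mvadapter.pipelines.pipeline_mvadapter_t2mv_sd import MVAdapterT2MVSDPipeline

NUM_VIEWS = 6  # illustrative; must match the adapter's training setup

# Placeholder base checkpoint; any SD 1.5-style checkpoint with the standard
# component layout should load into the subclass via from_pretrained.
pipe = MVAdapterT2MVSDPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Install the row-wise multi-view attention processors and the T2IAdapter-based
# condition encoder, then load trained adapter weights (non-strict). The weight
# path is a placeholder.
pipe._init_custom_adapter(num_views=NUM_VIEWS)
pipe._load_custom_adapter(torch.load("mvadapter_t2mv_sd.pt", map_location="cpu"))

pipe.to("cuda")
# The condition encoder is created outside register_modules, so move it explicitly.
pipe.cond_encoder.to("cuda")

# Dummy stand-in for the 6-channel per-view camera/geometry maps that would
# normally be rendered for each of the NUM_VIEWS views.
camera_maps = torch.zeros(NUM_VIEWS, 6, 512, 512)

views = pipe(
    prompt="a wooden treasure chest, 3d asset",
    num_inference_steps=30,
    guidance_scale=7.5,
    num_images_per_prompt=NUM_VIEWS,  # one generated image per view
    mv_scale=1.0,                     # strength of the cross-view attention
    control_image=camera_maps,
    control_conditioning_scale=1.0,
).images                              # list of NUM_VIEWS PIL images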
mvadapter/pipelines/pipeline_mvadapter_t2mv_sdxl.py ADDED
@@ -0,0 +1,801 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
19
+ from diffusers.models import AutoencoderKL, T2IAdapter, UNet2DConditionModel
20
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_output import (
21
+ StableDiffusionXLPipelineOutput,
22
+ )
23
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import (
24
+ StableDiffusionXLPipeline,
25
+ rescale_noise_cfg,
26
+ retrieve_timesteps,
27
+ )
28
+ from diffusers.schedulers import KarrasDiffusionSchedulers
29
+ from diffusers.utils import deprecate, logging
30
+ from transformers import (
31
+ CLIPImageProcessor,
32
+ CLIPTextModel,
33
+ CLIPTextModelWithProjection,
34
+ CLIPTokenizer,
35
+ CLIPVisionModelWithProjection,
36
+ )
37
+
38
+ from ..loaders import CustomAdapterMixin
39
+ from ..models.attention_processor import (
40
+ DecoupledMVRowSelfAttnProcessor2_0,
41
+ set_unet_2d_condition_attn_processor,
42
+ )
43
+
44
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
45
+
46
+
47
+ class MVAdapterT2MVSDXLPipeline(StableDiffusionXLPipeline, CustomAdapterMixin):
48
+ def __init__(
49
+ self,
50
+ vae: AutoencoderKL,
51
+ text_encoder: CLIPTextModel,
52
+ text_encoder_2: CLIPTextModelWithProjection,
53
+ tokenizer: CLIPTokenizer,
54
+ tokenizer_2: CLIPTokenizer,
55
+ unet: UNet2DConditionModel,
56
+ scheduler: KarrasDiffusionSchedulers,
57
+ image_encoder: CLIPVisionModelWithProjection = None,
58
+ feature_extractor: CLIPImageProcessor = None,
59
+ force_zeros_for_empty_prompt: bool = True,
60
+ add_watermarker: Optional[bool] = None,
61
+ ):
62
+ super().__init__(
63
+ vae=vae,
64
+ text_encoder=text_encoder,
65
+ text_encoder_2=text_encoder_2,
66
+ tokenizer=tokenizer,
67
+ tokenizer_2=tokenizer_2,
68
+ unet=unet,
69
+ scheduler=scheduler,
70
+ image_encoder=image_encoder,
71
+ feature_extractor=feature_extractor,
72
+ force_zeros_for_empty_prompt=force_zeros_for_empty_prompt,
73
+ add_watermarker=add_watermarker,
74
+ )
75
+
76
+ self.control_image_processor = VaeImageProcessor(
77
+ vae_scale_factor=self.vae_scale_factor,
78
+ do_convert_rgb=True,
79
+ do_normalize=False,
80
+ )
81
+
82
+ def prepare_control_image(
83
+ self,
84
+ image,
85
+ width,
86
+ height,
87
+ batch_size,
88
+ num_images_per_prompt,
89
+ device,
90
+ dtype,
91
+ do_classifier_free_guidance=False,
92
+ ):
93
+ assert hasattr(
94
+ self, "control_image_processor"
95
+ ), "control_image_processor is not initialized"
96
+
97
+ image = self.control_image_processor.preprocess(
98
+ image, height=height, width=width
99
+ ).to(dtype=torch.float32)
100
+ image_batch_size = image.shape[0]
101
+
102
+ if image_batch_size == 1:
103
+ repeat_by = batch_size
104
+ else:
105
+ # image batch size is the same as prompt batch size
106
+ repeat_by = num_images_per_prompt # always 1 for control image
107
+
108
+ image = image.repeat_interleave(repeat_by, dim=0)
109
+
110
+ image = image.to(device=device, dtype=dtype)
111
+
112
+ if do_classifier_free_guidance:
113
+ image = torch.cat([image] * 2)
114
+
115
+ return image
116
+
117
+ @torch.no_grad()
118
+ def __call__(
119
+ self,
120
+ prompt: Union[str, List[str]] = None,
121
+ prompt_2: Optional[Union[str, List[str]]] = None,
122
+ height: Optional[int] = None,
123
+ width: Optional[int] = None,
124
+ num_inference_steps: int = 50,
125
+ timesteps: List[int] = None,
126
+ denoising_end: Optional[float] = None,
127
+ guidance_scale: float = 5.0,
128
+ negative_prompt: Optional[Union[str, List[str]]] = None,
129
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
130
+ num_images_per_prompt: Optional[int] = 1,
131
+ eta: float = 0.0,
132
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
133
+ latents: Optional[torch.FloatTensor] = None,
134
+ prompt_embeds: Optional[torch.FloatTensor] = None,
135
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
136
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
137
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
138
+ ip_adapter_image: Optional[PipelineImageInput] = None,
139
+ ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
140
+ output_type: Optional[str] = "pil",
141
+ return_dict: bool = True,
142
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
143
+ guidance_rescale: float = 0.0,
144
+ original_size: Optional[Tuple[int, int]] = None,
145
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
146
+ target_size: Optional[Tuple[int, int]] = None,
147
+ negative_original_size: Optional[Tuple[int, int]] = None,
148
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
149
+ negative_target_size: Optional[Tuple[int, int]] = None,
150
+ clip_skip: Optional[int] = None,
151
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
152
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
153
+ # NEW
154
+ mv_scale: float = 1.0,
155
+ # Camera or geometry condition
156
+ control_image: Optional[PipelineImageInput] = None,
157
+ control_conditioning_scale: Optional[float] = 1.0,
158
+ control_conditioning_factor: float = 1.0,
159
+ # Optional. controlnet
160
+ controlnet_image: Optional[PipelineImageInput] = None,
161
+ controlnet_conditioning_scale: Optional[float] = 1.0,
162
+ **kwargs,
163
+ ):
164
+ r"""
165
+ Function invoked when calling the pipeline for generation.
166
+
167
+ Args:
168
+ prompt (`str` or `List[str]`, *optional*):
169
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
170
+ instead.
171
+ prompt_2 (`str` or `List[str]`, *optional*):
172
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
173
+ used in both text-encoders
174
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
175
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
176
+ Anything below 512 pixels won't work well for
177
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
178
+ and checkpoints that are not specifically fine-tuned on low resolutions.
179
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
180
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
181
+ Anything below 512 pixels won't work well for
182
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
183
+ and checkpoints that are not specifically fine-tuned on low resolutions.
184
+ num_inference_steps (`int`, *optional*, defaults to 50):
185
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
186
+ expense of slower inference.
187
+ timesteps (`List[int]`, *optional*):
188
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
189
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
190
+ passed will be used. Must be in descending order.
191
+ denoising_end (`float`, *optional*):
192
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
193
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
194
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
195
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
196
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
197
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
198
+ guidance_scale (`float`, *optional*, defaults to 5.0):
199
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
200
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
201
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
202
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
203
+ usually at the expense of lower image quality.
204
+ negative_prompt (`str` or `List[str]`, *optional*):
205
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
206
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
207
+ less than `1`).
208
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
209
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
210
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
211
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
212
+ The number of images to generate per prompt.
213
+ eta (`float`, *optional*, defaults to 0.0):
214
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
215
+ [`schedulers.DDIMScheduler`], will be ignored for others.
216
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
217
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
218
+ to make generation deterministic.
219
+ latents (`torch.FloatTensor`, *optional*):
220
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
221
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
222
+ tensor will be generated by sampling using the supplied random `generator`.
223
+ prompt_embeds (`torch.FloatTensor`, *optional*):
224
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
225
+ provided, text embeddings will be generated from `prompt` input argument.
226
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
227
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
228
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
229
+ argument.
230
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
231
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
232
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
233
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
234
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
235
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
236
+ input argument.
237
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
238
+ ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
239
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
240
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
241
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
242
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
243
+ output_type (`str`, *optional*, defaults to `"pil"`):
244
+ The output format of the generated image. Choose between
245
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
246
+ return_dict (`bool`, *optional*, defaults to `True`):
247
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
248
+ of a plain tuple.
249
+ cross_attention_kwargs (`dict`, *optional*):
250
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
251
+ `self.processor` in
252
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
253
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
254
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
255
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
256
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
257
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
258
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
259
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
260
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
261
+ explained in section 2.2 of
262
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
263
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
264
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
265
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
266
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
267
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
268
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
269
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
270
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
271
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
272
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
273
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
274
+ micro-conditioning as explained in section 2.2 of
275
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
276
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
277
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
278
+ To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
279
+ micro-conditioning as explained in section 2.2 of
280
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
281
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
282
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
283
+ To negatively condition the generation process based on a target image resolution. It should be the same
284
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
285
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
286
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
287
+ callback_on_step_end (`Callable`, *optional*):
288
+ A function that is called at the end of each denoising step during inference. The function is called
289
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
290
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
291
+ `callback_on_step_end_tensor_inputs`.
292
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
293
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
294
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
295
+ `._callback_tensor_inputs` attribute of your pipeline class.
296
+
297
+ Examples:
298
+
299
+ Returns:
300
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
301
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
302
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
303
+ """
304
+
305
+ callback = kwargs.pop("callback", None)
306
+ callback_steps = kwargs.pop("callback_steps", None)
307
+
308
+ if callback is not None:
309
+ deprecate(
310
+ "callback",
311
+ "1.0.0",
312
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
313
+ )
314
+ if callback_steps is not None:
315
+ deprecate(
316
+ "callback_steps",
317
+ "1.0.0",
318
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
319
+ )
320
+
321
+ # 0. Default height and width to unet
322
+ height = height or self.default_sample_size * self.vae_scale_factor
323
+ width = width or self.default_sample_size * self.vae_scale_factor
324
+
325
+ original_size = original_size or (height, width)
326
+ target_size = target_size or (height, width)
327
+
328
+ # 1. Check inputs. Raise error if not correct
329
+ self.check_inputs(
330
+ prompt,
331
+ prompt_2,
332
+ height,
333
+ width,
334
+ callback_steps,
335
+ negative_prompt,
336
+ negative_prompt_2,
337
+ prompt_embeds,
338
+ negative_prompt_embeds,
339
+ pooled_prompt_embeds,
340
+ negative_pooled_prompt_embeds,
341
+ ip_adapter_image,
342
+ ip_adapter_image_embeds,
343
+ callback_on_step_end_tensor_inputs,
344
+ )
345
+
346
+ self._guidance_scale = guidance_scale
347
+ self._guidance_rescale = guidance_rescale
348
+ self._clip_skip = clip_skip
349
+ self._cross_attention_kwargs = cross_attention_kwargs
350
+ self._denoising_end = denoising_end
351
+ self._interrupt = False
352
+
353
+ # 2. Define call parameters
354
+ if prompt is not None and isinstance(prompt, str):
355
+ batch_size = 1
356
+ elif prompt is not None and isinstance(prompt, list):
357
+ batch_size = len(prompt)
358
+ else:
359
+ batch_size = prompt_embeds.shape[0]
360
+
361
+ device = self._execution_device
362
+
363
+ # 3. Encode input prompt
364
+ lora_scale = (
365
+ self.cross_attention_kwargs.get("scale", None)
366
+ if self.cross_attention_kwargs is not None
367
+ else None
368
+ )
369
+
370
+ (
371
+ prompt_embeds,
372
+ negative_prompt_embeds,
373
+ pooled_prompt_embeds,
374
+ negative_pooled_prompt_embeds,
375
+ ) = self.encode_prompt(
376
+ prompt=prompt,
377
+ prompt_2=prompt_2,
378
+ device=device,
379
+ num_images_per_prompt=num_images_per_prompt,
380
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
381
+ negative_prompt=negative_prompt,
382
+ negative_prompt_2=negative_prompt_2,
383
+ prompt_embeds=prompt_embeds,
384
+ negative_prompt_embeds=negative_prompt_embeds,
385
+ pooled_prompt_embeds=pooled_prompt_embeds,
386
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
387
+ lora_scale=lora_scale,
388
+ clip_skip=self.clip_skip,
389
+ )
390
+
391
+ # 4. Prepare timesteps
392
+ timesteps, num_inference_steps = retrieve_timesteps(
393
+ self.scheduler, num_inference_steps, device, timesteps
394
+ )
395
+
396
+ # 5. Prepare latent variables
397
+ num_channels_latents = self.unet.config.in_channels
398
+ latents = self.prepare_latents(
399
+ batch_size * num_images_per_prompt,
400
+ num_channels_latents,
401
+ height,
402
+ width,
403
+ prompt_embeds.dtype,
404
+ device,
405
+ generator,
406
+ latents,
407
+ )
408
+
409
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
410
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
411
+
412
+ # 7. Prepare added time ids & embeddings
413
+ add_text_embeds = pooled_prompt_embeds
414
+ if self.text_encoder_2 is None:
415
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
416
+ else:
417
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
418
+
419
+ add_time_ids = self._get_add_time_ids(
420
+ original_size,
421
+ crops_coords_top_left,
422
+ target_size,
423
+ dtype=prompt_embeds.dtype,
424
+ text_encoder_projection_dim=text_encoder_projection_dim,
425
+ )
426
+ if negative_original_size is not None and negative_target_size is not None:
427
+ negative_add_time_ids = self._get_add_time_ids(
428
+ negative_original_size,
429
+ negative_crops_coords_top_left,
430
+ negative_target_size,
431
+ dtype=prompt_embeds.dtype,
432
+ text_encoder_projection_dim=text_encoder_projection_dim,
433
+ )
434
+ else:
435
+ negative_add_time_ids = add_time_ids
436
+
437
+ if self.do_classifier_free_guidance:
438
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
439
+ add_text_embeds = torch.cat(
440
+ [negative_pooled_prompt_embeds, add_text_embeds], dim=0
441
+ )
442
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
443
+
444
+ prompt_embeds = prompt_embeds.to(device)
445
+ add_text_embeds = add_text_embeds.to(device)
446
+ add_time_ids = add_time_ids.to(device).repeat(
447
+ batch_size * num_images_per_prompt, 1
448
+ )
449
+
450
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
451
+ image_embeds = self.prepare_ip_adapter_image_embeds(
452
+ ip_adapter_image,
453
+ ip_adapter_image_embeds,
454
+ device,
455
+ batch_size * num_images_per_prompt,
456
+ self.do_classifier_free_guidance,
457
+ )
458
+
459
+ # Preprocess control image
460
+ control_image_feature = self.prepare_control_image(
461
+ image=control_image,
462
+ width=width,
463
+ height=height,
464
+ batch_size=batch_size * num_images_per_prompt,
465
+ num_images_per_prompt=1, # NOTE: always 1 for control images
466
+ device=device,
467
+ dtype=latents.dtype,
468
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
469
+ )
470
+ control_image_feature = control_image_feature.to(
471
+ device=device, dtype=latents.dtype
472
+ )
473
+
474
+ adapter_state = self.cond_encoder(control_image_feature)
475
+ for i, state in enumerate(adapter_state):
476
+ adapter_state[i] = state * control_conditioning_scale
477
+
478
+ # Preprocess controlnet image if provided
479
+ do_controlnet = controlnet_image is not None and hasattr(self, "controlnet")
480
+ if do_controlnet:
481
+ controlnet_image = self.prepare_control_image(
482
+ image=controlnet_image,
483
+ width=width,
484
+ height=height,
485
+ batch_size=batch_size * num_images_per_prompt,
486
+ num_images_per_prompt=1, # NOTE: always 1 for control images
487
+ device=device,
488
+ dtype=latents.dtype,
489
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
490
+ )
491
+ controlnet_image = controlnet_image.to(device=device, dtype=latents.dtype)
492
+
493
+ # 8. Denoising loop
494
+ num_warmup_steps = max(
495
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
496
+ )
497
+
498
+ # 8.1 Apply denoising_end
499
+ if (
500
+ self.denoising_end is not None
501
+ and isinstance(self.denoising_end, float)
502
+ and self.denoising_end > 0
503
+ and self.denoising_end < 1
504
+ ):
505
+ discrete_timestep_cutoff = int(
506
+ round(
507
+ self.scheduler.config.num_train_timesteps
508
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
509
+ )
510
+ )
511
+ num_inference_steps = len(
512
+ list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))
513
+ )
514
+ timesteps = timesteps[:num_inference_steps]
515
+
516
+ # 9. Optionally get Guidance Scale Embedding
517
+ timestep_cond = None
518
+ if self.unet.config.time_cond_proj_dim is not None:
519
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
520
+ batch_size * num_images_per_prompt
521
+ )
522
+ timestep_cond = self.get_guidance_scale_embedding(
523
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
524
+ ).to(device=device, dtype=latents.dtype)
525
+
526
+ self._num_timesteps = len(timesteps)
527
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
528
+ for i, t in enumerate(timesteps):
529
+ if self.interrupt:
530
+ continue
531
+
532
+ # expand the latents if we are doing classifier free guidance
533
+ latent_model_input = (
534
+ torch.cat([latents] * 2)
535
+ if self.do_classifier_free_guidance
536
+ else latents
537
+ )
538
+
539
+ latent_model_input = self.scheduler.scale_model_input(
540
+ latent_model_input, t
541
+ )
542
+
543
+ added_cond_kwargs = {
544
+ "text_embeds": add_text_embeds,
545
+ "time_ids": add_time_ids,
546
+ }
547
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
548
+ added_cond_kwargs["image_embeds"] = image_embeds
549
+
550
+ if i < int(num_inference_steps * control_conditioning_factor):
551
+ down_intrablock_additional_residuals = [
552
+ state.clone() for state in adapter_state
553
+ ]
554
+ else:
555
+ down_intrablock_additional_residuals = None
556
+
557
+ unet_add_kwargs = {}
558
+
559
+ # Do controlnet if provided
560
+ if do_controlnet:
561
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
562
+ latent_model_input,
563
+ t,
564
+ encoder_hidden_states=prompt_embeds,
565
+ controlnet_cond=controlnet_image,
566
+ conditioning_scale=controlnet_conditioning_scale,
567
+ guess_mode=False,
568
+ added_cond_kwargs=added_cond_kwargs,
569
+ return_dict=False,
570
+ )
571
+ unet_add_kwargs.update(
572
+ {
573
+ "down_block_additional_residuals": down_block_res_samples,
574
+ "mid_block_additional_residual": mid_block_res_sample,
575
+ }
576
+ )
577
+
578
+ # predict the noise residual
579
+ noise_pred = self.unet(
580
+ latent_model_input,
581
+ t,
582
+ encoder_hidden_states=prompt_embeds,
583
+ timestep_cond=timestep_cond,
584
+ cross_attention_kwargs={
585
+ "mv_scale": mv_scale,
586
+ "num_views": num_images_per_prompt,
587
+ **(self.cross_attention_kwargs or {}),
588
+ },
589
+ down_intrablock_additional_residuals=down_intrablock_additional_residuals,
590
+ added_cond_kwargs=added_cond_kwargs,
591
+ return_dict=False,
592
+ **unet_add_kwargs,
593
+ )[0]
594
+
595
+ # perform guidance
596
+ if self.do_classifier_free_guidance:
597
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
598
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
599
+ noise_pred_text - noise_pred_uncond
600
+ )
601
+
602
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
603
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
604
+ noise_pred = rescale_noise_cfg(
605
+ noise_pred,
606
+ noise_pred_text,
607
+ guidance_rescale=self.guidance_rescale,
608
+ )
609
+
610
+ # compute the previous noisy sample x_t -> x_t-1
611
+ latents_dtype = latents.dtype
612
+ latents = self.scheduler.step(
613
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
614
+ )[0]
615
+ if latents.dtype != latents_dtype:
616
+ if torch.backends.mps.is_available():
617
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
618
+ latents = latents.to(latents_dtype)
619
+
620
+ if callback_on_step_end is not None:
621
+ callback_kwargs = {}
622
+ for k in callback_on_step_end_tensor_inputs:
623
+ callback_kwargs[k] = locals()[k]
624
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
625
+
626
+ latents = callback_outputs.pop("latents", latents)
627
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
628
+ negative_prompt_embeds = callback_outputs.pop(
629
+ "negative_prompt_embeds", negative_prompt_embeds
630
+ )
631
+ add_text_embeds = callback_outputs.pop(
632
+ "add_text_embeds", add_text_embeds
633
+ )
634
+ negative_pooled_prompt_embeds = callback_outputs.pop(
635
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
636
+ )
637
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
638
+ negative_add_time_ids = callback_outputs.pop(
639
+ "negative_add_time_ids", negative_add_time_ids
640
+ )
641
+
642
+ # call the callback, if provided
643
+ if i == len(timesteps) - 1 or (
644
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
645
+ ):
646
+ progress_bar.update()
647
+ if callback is not None and i % callback_steps == 0:
648
+ step_idx = i // getattr(self.scheduler, "order", 1)
649
+ callback(step_idx, t, latents)
650
+
651
+ if not output_type == "latent":
652
+ # make sure the VAE is in float32 mode, as it overflows in float16
653
+ needs_upcasting = (
654
+ self.vae.dtype == torch.float16 and self.vae.config.force_upcast
655
+ )
656
+
657
+ if needs_upcasting:
658
+ self.upcast_vae()
659
+ latents = latents.to(
660
+ next(iter(self.vae.post_quant_conv.parameters())).dtype
661
+ )
662
+ elif latents.dtype != self.vae.dtype:
663
+ if torch.backends.mps.is_available():
664
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
665
+ self.vae = self.vae.to(latents.dtype)
666
+
667
+ # unscale/denormalize the latents
668
+ # denormalize with the mean and std if available and not None
669
+ has_latents_mean = (
670
+ hasattr(self.vae.config, "latents_mean")
671
+ and self.vae.config.latents_mean is not None
672
+ )
673
+ has_latents_std = (
674
+ hasattr(self.vae.config, "latents_std")
675
+ and self.vae.config.latents_std is not None
676
+ )
677
+ if has_latents_mean and has_latents_std:
678
+ latents_mean = (
679
+ torch.tensor(self.vae.config.latents_mean)
680
+ .view(1, 4, 1, 1)
681
+ .to(latents.device, latents.dtype)
682
+ )
683
+ latents_std = (
684
+ torch.tensor(self.vae.config.latents_std)
685
+ .view(1, 4, 1, 1)
686
+ .to(latents.device, latents.dtype)
687
+ )
688
+ latents = (
689
+ latents * latents_std / self.vae.config.scaling_factor
690
+ + latents_mean
691
+ )
692
+ else:
693
+ latents = latents / self.vae.config.scaling_factor
694
+
695
+ image = self.vae.decode(latents, return_dict=False)[0]
696
+
697
+ # cast back to fp16 if needed
698
+ if needs_upcasting:
699
+ self.vae.to(dtype=torch.float16)
700
+ else:
701
+ image = latents
702
+
703
+ if not output_type == "latent":
704
+ # apply watermark if available
705
+ if self.watermark is not None:
706
+ image = self.watermark.apply_watermark(image)
707
+
708
+ image = self.image_processor.postprocess(image, output_type=output_type)
709
+
710
+ # Offload all models
711
+ self.maybe_free_model_hooks()
712
+
713
+ if not return_dict:
714
+ return (image,)
715
+
716
+ return StableDiffusionXLPipelineOutput(images=image)
717
+
718
+ ### NEW: adapters ###
719
+ def _init_custom_adapter(
720
+ self,
721
+ # Multi-view adapter
722
+ num_views: int = 1,
723
+ self_attn_processor: Any = DecoupledMVRowSelfAttnProcessor2_0,
724
+ # Condition encoder
725
+ cond_in_channels: int = 6,
726
+ # For training
727
+ copy_attn_weights: bool = True,
728
+ zero_init_module_keys: List[str] = [],
729
+ ):
730
+ # Condition encoder
731
+ self.cond_encoder = T2IAdapter(
732
+ in_channels=cond_in_channels,
733
+ channels=(320, 640, 1280, 1280),
734
+ num_res_blocks=2,
735
+ downscale_factor=16,
736
+ adapter_type="full_adapter_xl",
737
+ )
738
+
739
+ # set custom attn processor for multi-view attention
740
+ self.unet: UNet2DConditionModel
741
+ set_unet_2d_condition_attn_processor(
742
+ self.unet,
743
+ set_self_attn_proc_func=lambda name, hs, cad, ap: self_attn_processor(
744
+ query_dim=hs,
745
+ inner_dim=hs,
746
+ num_views=num_views,
747
+ name=name,
748
+ use_mv=True,
749
+ use_ref=False,
750
+ ),
751
+ set_cross_attn_proc_func=lambda name, hs, cad, ap: self_attn_processor(
752
+ query_dim=hs,
753
+ inner_dim=hs,
754
+ num_views=num_views,
755
+ name=name,
756
+ use_mv=False,
757
+ use_ref=False,
758
+ ),
759
+ )
760
+
761
+ # copy decoupled attention weights from original unet
762
+ if copy_attn_weights:
763
+ state_dict = self.unet.state_dict()
764
+ for key in state_dict.keys():
765
+ if "_mv" in key:
766
+ compatible_key = key.replace("_mv", "").replace("processor.", "")
767
+ else:
768
+ compatible_key = key
769
+
770
+ is_zero_init_key = any([k in key for k in zero_init_module_keys])
771
+ if is_zero_init_key:
772
+ state_dict[key] = torch.zeros_like(state_dict[compatible_key])
773
+ else:
774
+ state_dict[key] = state_dict[compatible_key].clone()
775
+ self.unet.load_state_dict(state_dict)
776
+
777
+ def _load_custom_adapter(self, state_dict):
778
+ self.unet.load_state_dict(state_dict, strict=False)
779
+ self.cond_encoder.load_state_dict(state_dict, strict=False)
780
+
781
+ def _save_custom_adapter(
782
+ self,
783
+ include_keys: Optional[List[str]] = None,
784
+ exclude_keys: Optional[List[str]] = None,
785
+ ):
786
+ def include_fn(k):
787
+ is_included = False
788
+
789
+ if include_keys is not None:
790
+ is_included = is_included or any([key in k for key in include_keys])
791
+ if exclude_keys is not None:
792
+ is_included = is_included and not any(
793
+ [key in k for key in exclude_keys]
794
+ )
795
+
796
+ return is_included
797
+
798
+ state_dict = {k: v for k, v in self.unet.state_dict().items() if include_fn(k)}
799
+ state_dict.update(self.cond_encoder.state_dict())
800
+
801
+ return state_dict
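For context, a minimal sketch of how the adapter helpers above might be used to persist and reload only the custom weights. `pipe`, the `"_mv"` filter key, and the file name are illustrative assumptions based on the key-renaming logic in `_init_custom_adapter`, not a documented API.

import torch

# Assumes `pipe` is an instance of this pipeline with _init_custom_adapter() already called.
# Keep only the decoupled multi-view ("_mv") attention weights from the UNet; the
# T2IAdapter condition encoder is always included by _save_custom_adapter.
adapter_state = pipe._save_custom_adapter(include_keys=["_mv"])
torch.save(adapter_state, "mvadapter_weights.pt")  # illustrative file name

# Later: restore the weights into a freshly initialized pipeline (both loads are non-strict).
state = torch.load("mvadapter_weights.pt", map_location="cpu")
pipe._load_custom_adapter(state)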
mvadapter/schedulers/scheduler_utils.py ADDED
@@ -0,0 +1,70 @@
1
+ import torch
2
+
3
+
4
+ def get_sigmas(noise_scheduler, timesteps, n_dim=4, dtype=torch.float32, device=None):
5
+ sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
6
+ schedule_timesteps = noise_scheduler.timesteps.to(device)
7
+ timesteps = timesteps.to(device)
8
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
9
+ sigma = sigmas[step_indices].flatten()
10
+ while len(sigma.shape) < n_dim:
11
+ sigma = sigma.unsqueeze(-1)
12
+ return sigma
13
+
14
+
15
+ def SNR_to_betas(snr):
16
+ """
17
+ Converts SNR to betas
18
+ """
19
+ # Derivation: snr = alpha_t^2 / (1 - alpha_t^2), where alpha_t^2 = alphas_cumprod[t],
20
+ # so alpha_t = sqrt(snr / (1 + snr)); the betas are then recovered from the ratio of
21
+ # consecutive alphas_cumprod values.
22
+ alpha_t = (snr / (1 + snr)) ** 0.5
23
+ alphas_cumprod = alpha_t**2
24
+ alphas = alphas_cumprod / torch.cat(
25
+ [torch.ones(1, device=snr.device), alphas_cumprod[:-1]]
26
+ )
27
+ betas = 1 - alphas
28
+ return betas
29
+
30
+
31
+ def compute_snr(timesteps, noise_scheduler):
32
+ """
33
+ Computes SNR as per Min-SNR-Diffusion-Training/guided_diffusion/gaussian_diffusion.py at 521b624bd70c67cee4bdf49225915f5
34
+ """
35
+ alphas_cumprod = noise_scheduler.alphas_cumprod
36
+ sqrt_alphas_cumprod = alphas_cumprod**0.5
37
+ sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
38
+
39
+ # Expand the tensors.
40
+ # Adapted from Min-SNR-Diffusion-Training/guided_diffusion/gaussian_diffusion.py at 521b624bd70c67cee4bdf49225915f5
41
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
42
+ timesteps
43
+ ].float()
44
+ while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
45
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
46
+ alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
47
+
48
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
49
+ device=timesteps.device
50
+ )[timesteps].float()
51
+ while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
52
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
53
+ sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
54
+
55
+ # Compute SNR.
56
+ snr = (alpha / sigma) ** 2
57
+ return snr
58
+
59
+
60
+ def compute_alpha(timesteps, noise_scheduler):
61
+ alphas_cumprod = noise_scheduler.alphas_cumprod
62
+ sqrt_alphas_cumprod = alphas_cumprod**0.5
63
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
64
+ timesteps
65
+ ].float()
66
+ while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
67
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
68
+ alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
69
+
70
+ return alpha
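As a quick sanity check of the helpers above, a hedged sketch that recovers a beta schedule from the SNR curve of a stock diffusers `DDPMScheduler`; the scheduler settings are illustrative.

import torch
from diffusers import DDPMScheduler

from mvadapter.schedulers.scheduler_utils import SNR_to_betas, compute_snr

scheduler = DDPMScheduler(num_train_timesteps=1000)
timesteps = torch.arange(0, scheduler.config.num_train_timesteps)

snr = compute_snr(timesteps, scheduler)  # per-timestep (alpha / sigma) ** 2
betas = SNR_to_betas(snr)                # beta schedule that reproduces this SNR curve
print(snr.shape, betas.shape)            # torch.Size([1000]) torch.Size([1000])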
mvadapter/schedulers/scheduling_shift_snr.py ADDED
@@ -0,0 +1,138 @@
1
+ from typing import Any
2
+
3
+ import torch
4
+
5
+ from .scheduler_utils import SNR_to_betas, compute_snr
6
+
7
+
8
+ class ShiftSNRScheduler:
9
+ def __init__(
10
+ self,
11
+ noise_scheduler: Any,
12
+ timesteps: Any,
13
+ shift_scale: float,
14
+ scheduler_class: Any,
15
+ ):
16
+ self.noise_scheduler = noise_scheduler
17
+ self.timesteps = timesteps
18
+ self.shift_scale = shift_scale
19
+ self.scheduler_class = scheduler_class
20
+
21
+ def _get_shift_scheduler(self):
22
+ """
23
+ Prepare scheduler for shifted betas.
24
+
25
+ :return: A scheduler object configured with shifted betas
26
+ """
27
+ snr = compute_snr(self.timesteps, self.noise_scheduler)
28
+ shifted_betas = SNR_to_betas(snr / self.shift_scale)
29
+
30
+ return self.scheduler_class.from_config(
31
+ self.noise_scheduler.config, trained_betas=shifted_betas.numpy()
32
+ )
33
+
34
+ def _get_interpolated_shift_scheduler(self):
35
+ """
36
+ Prepare scheduler for shifted betas and interpolate with the original betas in log space.
37
+
38
+ :return: A scheduler object configured with interpolated shifted betas
39
+ """
40
+ snr = compute_snr(self.timesteps, self.noise_scheduler)
41
+ shifted_snr = snr / self.shift_scale
42
+
43
+ weighting = self.timesteps.float() / (
44
+ self.noise_scheduler.config.num_train_timesteps - 1
45
+ )
46
+ interpolated_snr = torch.exp(
47
+ torch.log(snr) * (1 - weighting) + torch.log(shifted_snr) * weighting
48
+ )
49
+
50
+ shifted_betas = SNR_to_betas(interpolated_snr)
51
+
52
+ return self.scheduler_class.from_config(
53
+ self.noise_scheduler.config, trained_betas=shifted_betas.numpy()
54
+ )
55
+
56
+ @classmethod
57
+ def from_scheduler(
58
+ cls,
59
+ noise_scheduler: Any,
60
+ shift_mode: str = "default",
61
+ timesteps: Any = None,
62
+ shift_scale: float = 1.0,
63
+ scheduler_class: Any = None,
64
+ ):
65
+ # Check input
66
+ if timesteps is None:
67
+ timesteps = torch.arange(0, noise_scheduler.config.num_train_timesteps)
68
+ if scheduler_class is None:
69
+ scheduler_class = noise_scheduler.__class__
70
+
71
+ # Create scheduler
72
+ shift_scheduler = cls(
73
+ noise_scheduler=noise_scheduler,
74
+ timesteps=timesteps,
75
+ shift_scale=shift_scale,
76
+ scheduler_class=scheduler_class,
77
+ )
78
+
79
+ if shift_mode == "default":
80
+ return shift_scheduler._get_shift_scheduler()
81
+ elif shift_mode == "interpolated":
82
+ return shift_scheduler._get_interpolated_shift_scheduler()
83
+ else:
84
+ raise ValueError(f"Unknown shift_mode: {shift_mode}")
85
+
86
+
87
+ if __name__ == "__main__":
88
+ """
89
+ Compare the alpha values for different noise schedulers.
90
+ """
91
+ import matplotlib.pyplot as plt
92
+ from diffusers import DDPMScheduler
93
+
94
+ from .scheduler_utils import compute_alpha
95
+
96
+ # Base
97
+ timesteps = torch.arange(0, 1000)
98
+ noise_scheduler_base = DDPMScheduler.from_pretrained(
99
+ "runwayml/stable-diffusion-v1-5", subfolder="scheduler"
100
+ )
101
+ alpha = compute_alpha(timesteps, noise_scheduler_base)
102
+ plt.plot(timesteps.numpy(), alpha.numpy(), label="Base")
103
+
104
+ # Kolors
105
+ num_train_timesteps_ = 1100
106
+ timesteps_ = torch.arange(0, num_train_timesteps_)
107
+ noise_kwargs = {"beta_end": 0.014, "num_train_timesteps": num_train_timesteps_}
108
+ noise_scheduler_kolors = DDPMScheduler.from_config(
109
+ noise_scheduler_base.config, **noise_kwargs
110
+ )
111
+ alpha = compute_alpha(timesteps_, noise_scheduler_kolors)
112
+ plt.plot(timesteps_.numpy(), alpha.numpy(), label="Kolors")
113
+
114
+ # Shift betas
115
+ shift_scale = 8.0
116
+ noise_scheduler_shift = ShiftSNRScheduler.from_scheduler(
117
+ noise_scheduler_base, shift_mode="default", shift_scale=shift_scale
118
+ )
119
+ alpha = compute_alpha(timesteps, noise_scheduler_shift)
120
+ plt.plot(timesteps.numpy(), alpha.numpy(), label="Shift Noise (scale 8.0)")
121
+
122
+ # Shift betas (interpolated)
123
+ noise_scheduler_inter = ShiftSNRScheduler.from_scheduler(
124
+ noise_scheduler_base, shift_mode="interpolated", shift_scale=shift_scale
125
+ )
126
+ alpha = compute_alpha(timesteps, noise_scheduler_inter)
127
+ plt.plot(timesteps.numpy(), alpha.numpy(), label="Interpolated (scale 8.0)")
128
+
129
+ # ZeroSNR
130
+ noise_scheduler = DDPMScheduler.from_config(
131
+ noise_scheduler_base.config, rescale_betas_zero_snr=True
132
+ )
133
+ alpha = compute_alpha(timesteps, noise_scheduler)
134
+ plt.plot(timesteps.numpy(), alpha.numpy(), label="ZeroSNR")
135
+
136
+ plt.legend()
137
+ plt.grid()
138
+ plt.savefig("check_alpha.png")
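A minimal usage sketch for the wrapper above, assuming a diffusers `DDPMScheduler` as the base noise scheduler; the SDXL model id and the shift scale of 8.0 are illustrative choices, not requirements of the class.

from diffusers import DDPMScheduler

from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler

base_scheduler = DDPMScheduler.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler"  # illustrative model id
)
noise_scheduler = ShiftSNRScheduler.from_scheduler(
    base_scheduler,
    shift_mode="interpolated",   # or "default" for a plain SNR shift
    shift_scale=8.0,             # larger values push training towards higher noise levels
    scheduler_class=DDPMScheduler,
)
# `noise_scheduler` is an ordinary DDPMScheduler built from the shifted betas and can be
# used for add_noise() during training just like the original scheduler.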
mvadapter/utils/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .camera import get_camera, get_orthogonal_camera
2
+ from .geometry import get_plucker_embeds_from_cameras_ortho
3
+ from .saving import make_image_grid, tensor_to_image
mvadapter/utils/camera.py ADDED
@@ -0,0 +1,211 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+ from typing import List, Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import trimesh
9
+ from PIL import Image
10
+ from torch import BoolTensor, FloatTensor
11
+
12
+ LIST_TYPE = Union[list, np.ndarray, torch.Tensor]
13
+
14
+
15
+ def list_to_pt(
16
+ x: LIST_TYPE, dtype: Optional[torch.dtype] = None, device: Optional[str] = None
17
+ ) -> torch.Tensor:
18
+ if isinstance(x, list) or isinstance(x, np.ndarray):
19
+ return torch.tensor(x, dtype=dtype, device=device)
20
+ return x.to(dtype=dtype, device=device)
21
+
22
+
23
+ def get_c2w(
24
+ elevation_deg: LIST_TYPE,
25
+ distance: LIST_TYPE,
26
+ azimuth_deg: Optional[LIST_TYPE],
27
+ num_views: Optional[int] = 1,
28
+ device: Optional[str] = None,
29
+ ) -> torch.FloatTensor:
30
+ if azimuth_deg is None:
31
+ assert (
32
+ num_views is not None
33
+ ), "num_views must be provided if azimuth_deg is None."
34
+ azimuth_deg = torch.linspace(
35
+ 0, 360, num_views + 1, dtype=torch.float32, device=device
36
+ )[:-1]
37
+ else:
38
+ num_views = len(azimuth_deg)
39
+ azimuth_deg = list_to_pt(azimuth_deg, dtype=torch.float32, device=device)
40
+ elevation_deg = list_to_pt(elevation_deg, dtype=torch.float32, device=device)
41
+ camera_distances = list_to_pt(distance, dtype=torch.float32, device=device)
42
+ elevation = elevation_deg * math.pi / 180
43
+ azimuth = azimuth_deg * math.pi / 180
44
+ camera_positions = torch.stack(
45
+ [
46
+ camera_distances * torch.cos(elevation) * torch.cos(azimuth),
47
+ camera_distances * torch.cos(elevation) * torch.sin(azimuth),
48
+ camera_distances * torch.sin(elevation),
49
+ ],
50
+ dim=-1,
51
+ )
52
+ center = torch.zeros_like(camera_positions)
53
+ up = torch.tensor([0, 0, 1], dtype=torch.float32, device=device)[None, :].repeat(
54
+ num_views, 1
55
+ )
56
+ lookat = F.normalize(center - camera_positions, dim=-1)
57
+ right = F.normalize(torch.cross(lookat, up, dim=-1), dim=-1)
58
+ up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)
59
+ c2w3x4 = torch.cat(
60
+ [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
61
+ dim=-1,
62
+ )
63
+ c2w = torch.cat([c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1)
64
+ c2w[:, 3, 3] = 1.0
65
+ return c2w
66
+
67
+
68
+ def get_projection_matrix(
69
+ fovy_deg: LIST_TYPE,
70
+ aspect_wh: float = 1.0,
71
+ near: float = 0.1,
72
+ far: float = 100.0,
73
+ device: Optional[str] = None,
74
+ ) -> torch.FloatTensor:
75
+ fovy_deg = list_to_pt(fovy_deg, dtype=torch.float32, device=device)
76
+ batch_size = fovy_deg.shape[0]
77
+ fovy = fovy_deg * math.pi / 180
78
+ tan_half_fovy = torch.tan(fovy / 2)
79
+ projection_matrix = torch.zeros(
80
+ batch_size, 4, 4, dtype=torch.float32, device=device
81
+ )
82
+ projection_matrix[:, 0, 0] = 1 / (aspect_wh * tan_half_fovy)
83
+ projection_matrix[:, 1, 1] = -1 / tan_half_fovy
84
+ projection_matrix[:, 2, 2] = -(far + near) / (far - near)
85
+ projection_matrix[:, 2, 3] = -2 * far * near / (far - near)
86
+ projection_matrix[:, 3, 2] = -1
87
+ return projection_matrix
88
+
89
+
90
+ def get_orthogonal_projection_matrix(
91
+ batch_size: int,
92
+ left: float,
93
+ right: float,
94
+ bottom: float,
95
+ top: float,
96
+ near: float = 0.1,
97
+ far: float = 100.0,
98
+ device: Optional[str] = None,
99
+ ) -> torch.FloatTensor:
100
+ projection_matrix = torch.zeros(
101
+ batch_size, 4, 4, dtype=torch.float32, device=device
102
+ )
103
+ projection_matrix[:, 0, 0] = 2 / (right - left)
104
+ projection_matrix[:, 1, 1] = -2 / (top - bottom)
105
+ projection_matrix[:, 2, 2] = -2 / (far - near)
106
+ projection_matrix[:, 0, 3] = -(right + left) / (right - left)
107
+ projection_matrix[:, 1, 3] = -(top + bottom) / (top - bottom)
108
+ projection_matrix[:, 2, 3] = -(far + near) / (far - near)
109
+ projection_matrix[:, 3, 3] = 1
110
+ return projection_matrix
111
+
112
+
113
+ @dataclass
114
+ class Camera:
115
+ c2w: Optional[torch.FloatTensor]
116
+ w2c: torch.FloatTensor
117
+ proj_mtx: torch.FloatTensor
118
+ mvp_mtx: torch.FloatTensor
119
+ cam_pos: Optional[torch.FloatTensor]
120
+
121
+ def __getitem__(self, index):
122
+ if isinstance(index, int):
123
+ sl = slice(index, index + 1)
124
+ elif isinstance(index, slice):
125
+ sl = index
126
+ else:
127
+ raise NotImplementedError
128
+
129
+ return Camera(
130
+ c2w=self.c2w[sl] if self.c2w is not None else None,
131
+ w2c=self.w2c[sl],
132
+ proj_mtx=self.proj_mtx[sl],
133
+ mvp_mtx=self.mvp_mtx[sl],
134
+ cam_pos=self.cam_pos[sl] if self.cam_pos is not None else None,
135
+ )
136
+
137
+ def to(self, device: Optional[str] = None):
138
+ if self.c2w is not None:
139
+ self.c2w = self.c2w.to(device)
140
+ self.w2c = self.w2c.to(device)
141
+ self.proj_mtx = self.proj_mtx.to(device)
142
+ self.mvp_mtx = self.mvp_mtx.to(device)
143
+ if self.cam_pos is not None:
144
+ self.cam_pos = self.cam_pos.to(device)
145
+
146
+ def __len__(self):
147
+ return self.c2w.shape[0]
148
+
149
+
150
+ def get_camera(
151
+ elevation_deg: Optional[LIST_TYPE] = None,
152
+ distance: Optional[LIST_TYPE] = None,
153
+ fovy_deg: Optional[LIST_TYPE] = None,
154
+ azimuth_deg: Optional[LIST_TYPE] = None,
155
+ num_views: Optional[int] = 1,
156
+ c2w: Optional[torch.FloatTensor] = None,
157
+ w2c: Optional[torch.FloatTensor] = None,
158
+ proj_mtx: Optional[torch.FloatTensor] = None,
159
+ aspect_wh: float = 1.0,
160
+ near: float = 0.1,
161
+ far: float = 100.0,
162
+ device: Optional[str] = None,
163
+ ):
164
+ if w2c is None:
165
+ if c2w is None:
166
+ c2w = get_c2w(elevation_deg, distance, azimuth_deg, num_views, device)
167
+ camera_positions = c2w[:, :3, 3]
168
+ w2c = torch.linalg.inv(c2w)
169
+ else:
170
+ camera_positions = None
171
+ c2w = None
172
+ if proj_mtx is None:
173
+ proj_mtx = get_projection_matrix(
174
+ fovy_deg, aspect_wh=aspect_wh, near=near, far=far, device=device
175
+ )
176
+ mvp_mtx = proj_mtx @ w2c
177
+ return Camera(
178
+ c2w=c2w, w2c=w2c, proj_mtx=proj_mtx, mvp_mtx=mvp_mtx, cam_pos=camera_positions
179
+ )
180
+
181
+
182
+ def get_orthogonal_camera(
183
+ elevation_deg: LIST_TYPE,
184
+ distance: LIST_TYPE,
185
+ left: float,
186
+ right: float,
187
+ bottom: float,
188
+ top: float,
189
+ azimuth_deg: Optional[LIST_TYPE] = None,
190
+ num_views: Optional[int] = 1,
191
+ near: float = 0.1,
192
+ far: float = 100.0,
193
+ device: Optional[str] = None,
194
+ ):
195
+ c2w = get_c2w(elevation_deg, distance, azimuth_deg, num_views, device)
196
+ camera_positions = c2w[:, :3, 3]
197
+ w2c = torch.linalg.inv(c2w)
198
+ proj_mtx = get_orthogonal_projection_matrix(
199
+ batch_size=c2w.shape[0],
200
+ left=left,
201
+ right=right,
202
+ bottom=bottom,
203
+ top=top,
204
+ near=near,
205
+ far=far,
206
+ device=device,
207
+ )
208
+ mvp_mtx = proj_mtx @ w2c
209
+ return Camera(
210
+ c2w=c2w, w2c=w2c, proj_mtx=proj_mtx, mvp_mtx=mvp_mtx, cam_pos=camera_positions
211
+ )
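A small sketch of building a multi-view orthographic rig with the helper above; the six azimuths, the camera distance of 1.8, and the ±0.55 frustum bounds are illustrative values, not defaults taken from this repository.

from mvadapter.utils.camera import get_orthogonal_camera

cameras = get_orthogonal_camera(
    elevation_deg=[0.0] * 6,                              # one elevation per view
    distance=[1.8] * 6,                                   # one camera distance per view
    left=-0.55, right=0.55, bottom=-0.55, top=0.55,       # orthographic frustum bounds
    azimuth_deg=[0.0, 60.0, 120.0, 180.0, 240.0, 300.0],  # six views around the object
)
print(len(cameras))           # 6
print(cameras.mvp_mtx.shape)  # torch.Size([6, 4, 4]) model-view-projection matrices
front_view = cameras[0]       # indexing returns a Camera with all matrices kept in sync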
mvadapter/utils/geometry.py ADDED
@@ -0,0 +1,253 @@
1
+ from typing import List, Optional, Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch.nn import functional as F
6
+
7
+
8
+ def get_position_map_from_depth(depth, mask, intrinsics, extrinsics, image_wh=None):
9
+ """Compute the position map from the depth map and the camera parameters for a batch of views.
10
+
11
+ Args:
12
+ depth (torch.Tensor): The depth maps with the shape (B, H, W, 1).
13
+ mask (torch.Tensor): The masks with the shape (B, H, W, 1).
14
+ intrinsics (torch.Tensor): The camera intrinsics matrices with the shape (B, 3, 3).
15
+ extrinsics (torch.Tensor): The camera extrinsics matrices with the shape (B, 4, 4).
16
+ image_wh (Tuple[int, int]): The image width and height.
17
+
18
+ Returns:
19
+ torch.Tensor: The position maps with the shape (B, H, W, 3).
20
+ """
21
+ if image_wh is None:
22
+ image_wh = depth.shape[2], depth.shape[1]
23
+
24
+ B, H, W, _ = depth.shape
25
+ depth = depth.squeeze(-1)
26
+
27
+ u_coord, v_coord = torch.meshgrid(
28
+ torch.arange(image_wh[0]), torch.arange(image_wh[1]), indexing="xy"
29
+ )
30
+ u_coord = u_coord.type_as(depth).unsqueeze(0).expand(B, -1, -1)
31
+ v_coord = v_coord.type_as(depth).unsqueeze(0).expand(B, -1, -1)
32
+
33
+ # Compute the position map by back-projecting depth pixels to 3D space
34
+ x = (
35
+ (u_coord - intrinsics[:, 0, 2].unsqueeze(-1).unsqueeze(-1))
36
+ * depth
37
+ / intrinsics[:, 0, 0].unsqueeze(-1).unsqueeze(-1)
38
+ )
39
+ y = (
40
+ (v_coord - intrinsics[:, 1, 2].unsqueeze(-1).unsqueeze(-1))
41
+ * depth
42
+ / intrinsics[:, 1, 1].unsqueeze(-1).unsqueeze(-1)
43
+ )
44
+ z = depth
45
+
46
+ # Concatenate to form the 3D coordinates in the camera frame
47
+ camera_coords = torch.stack([x, y, z], dim=-1)
48
+
49
+ # Apply the extrinsic matrix to get coordinates in the world frame
50
+ coords_homogeneous = torch.nn.functional.pad(
51
+ camera_coords, (0, 1), "constant", 1.0
52
+ ) # Add a homogeneous coordinate
53
+ world_coords = torch.matmul(
54
+ coords_homogeneous.view(B, -1, 4), extrinsics.transpose(1, 2)
55
+ ).view(B, H, W, 4)
56
+
57
+ # Apply the mask to the position map
58
+ position_map = world_coords[..., :3] * mask
59
+
60
+ return position_map
61
+
62
+
63
+ def get_position_map_from_depth_ortho(
64
+ depth, mask, extrinsics, ortho_scale, image_wh=None
65
+ ):
66
+ """Compute the position map from the depth map and the camera parameters for a batch of views
67
+ using orthographic projection with a given ortho_scale.
68
+
69
+ Args:
70
+ depth (torch.Tensor): The depth maps with the shape (B, H, W, 1).
71
+ mask (torch.Tensor): The masks with the shape (B, H, W, 1).
72
+ extrinsics (torch.Tensor): The camera extrinsics matrices with the shape (B, 4, 4).
73
+ ortho_scale (torch.Tensor): The scaling factor for the orthographic projection with the shape (B, 1, 1, 1).
74
+ image_wh (Tuple[int, int]): Optional. The image width and height.
75
+
76
+ Returns:
77
+ torch.Tensor: The position maps with the shape (B, H, W, 3).
78
+ """
79
+ if image_wh is None:
80
+ image_wh = depth.shape[2], depth.shape[1]
81
+
82
+ B, H, W, _ = depth.shape
83
+ depth = depth.squeeze(-1)
84
+
85
+ # Generating grid of coordinates in the image space
86
+ u_coord, v_coord = torch.meshgrid(
87
+ torch.arange(0, image_wh[0]), torch.arange(0, image_wh[1]), indexing="xy"
88
+ )
89
+ u_coord = u_coord.type_as(depth).unsqueeze(0).expand(B, -1, -1)
90
+ v_coord = v_coord.type_as(depth).unsqueeze(0).expand(B, -1, -1)
91
+
92
+ # Compute the position map using orthographic projection with ortho_scale
93
+ x = (u_coord - image_wh[0] / 2) * ortho_scale / image_wh[0]
94
+ y = (v_coord - image_wh[1] / 2) * ortho_scale / image_wh[1]
95
+ z = depth
96
+
97
+ # Concatenate to form the 3D coordinates in the camera frame
98
+ camera_coords = torch.stack([x, y, z], dim=-1)
99
+
100
+ # Apply the extrinsic matrix to get coordinates in the world frame
101
+ coords_homogeneous = torch.nn.functional.pad(
102
+ camera_coords, (0, 1), "constant", 1.0
103
+ ) # Add a homogeneous coordinate
104
+ world_coords = torch.matmul(
105
+ coords_homogeneous.view(B, -1, 4), extrinsics.transpose(1, 2)
106
+ ).view(B, H, W, 4)
107
+
108
+ # Apply the mask to the position map
109
+ position_map = world_coords[..., :3] * mask
110
+
111
+ return position_map
112
+
113
+
114
+ def get_opencv_from_blender(matrix_world, fov=None, image_size=None):
115
+ # convert matrix_world to opencv format extrinsics
116
+ opencv_world_to_cam = matrix_world.inverse()
117
+ opencv_world_to_cam[1, :] *= -1
118
+ opencv_world_to_cam[2, :] *= -1
119
+ R, T = opencv_world_to_cam[:3, :3], opencv_world_to_cam[:3, 3]
120
+
121
+ if fov is None: # orthographic camera
122
+ return R, T
123
+
124
+ R, T = R.unsqueeze(0), T.unsqueeze(0)
125
+ # convert fov to opencv format intrinsics
126
+ focal = 1 / np.tan(fov / 2)
127
+ intrinsics = np.diag(np.array([focal, focal, 1])).astype(np.float32)
128
+ opencv_cam_matrix = (
129
+ torch.from_numpy(intrinsics).unsqueeze(0).float().to(matrix_world.device)
130
+ )
131
+ opencv_cam_matrix[:, :2, -1] += torch.tensor([image_size / 2, image_size / 2]).to(
132
+ matrix_world.device
133
+ )
134
+ opencv_cam_matrix[:, [0, 1], [0, 1]] *= image_size / 2
135
+
136
+ return R, T, opencv_cam_matrix
137
+
138
+
139
+ def get_ray_directions(
140
+ H: int,
141
+ W: int,
142
+ focal: float,
143
+ principal: Optional[Tuple[float, float]] = None,
144
+ use_pixel_centers: bool = True,
145
+ ) -> torch.Tensor:
146
+ """
147
+ Get ray directions for all pixels in camera coordinate.
148
+ Args:
149
+ H, W, focal, principal, use_pixel_centers: image height, width, focal length, principal point and whether to use pixel centers
150
+ Outputs:
151
+ directions: (H, W, 3), the direction of the rays in camera coordinate
152
+ """
153
+ pixel_center = 0.5 if use_pixel_centers else 0
154
+ cx, cy = (W / 2, H / 2) if principal is None else principal  # parenthesized so the default applies to both coordinates
155
+ i, j = torch.meshgrid(
156
+ torch.arange(W, dtype=torch.float32) + pixel_center,
157
+ torch.arange(H, dtype=torch.float32) + pixel_center,
158
+ indexing="xy",
159
+ )
160
+ directions = torch.stack(
161
+ [(i - cx) / focal, -(j - cy) / focal, -torch.ones_like(i)], -1
162
+ )
163
+ return F.normalize(directions, dim=-1)
164
+
165
+
166
+ def get_rays(
167
+ directions: torch.Tensor, c2w: torch.Tensor
168
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
169
+ """
170
+ Get ray origins and directions from camera coordinates to world coordinates
171
+ Args:
172
+ directions: (H, W, 3) ray directions in camera coordinates
173
+ c2w: (4, 4) camera-to-world transformation matrix
174
+ Outputs:
175
+ rays_o, rays_d: (H, W, 3) ray origins and directions in world coordinates
176
+ """
177
+ # Rotate ray directions from camera coordinate to the world coordinate
178
+ rays_d = directions @ c2w[:3, :3].T
179
+ rays_o = c2w[:3, 3].expand(rays_d.shape)
180
+ return rays_o, rays_d
181
+
182
+
183
+ def compute_plucker_embed(
184
+ c2w: torch.Tensor, image_width: int, image_height: int, focal: float
185
+ ) -> torch.Tensor:
186
+ """
187
+ Computes Plucker coordinates for a camera.
188
+ Args:
189
+ c2w: (4, 4) camera-to-world transformation matrix
190
+ image_width: Image width
191
+ image_height: Image height
192
+ focal: Focal length of the camera
193
+ Returns:
194
+ plucker: (6, H, W) Plucker embedding
195
+ """
196
+ directions = get_ray_directions(image_height, image_width, focal)
197
+ rays_o, rays_d = get_rays(directions, c2w)
198
+ # Cross product to get Plucker coordinates
199
+ cross = torch.cross(rays_o, rays_d, dim=-1)
200
+ plucker = torch.cat((rays_d, cross), dim=-1)
201
+ return plucker.permute(2, 0, 1)
202
+
203
+
204
+ def get_plucker_embeds_from_cameras(
205
+ c2w: List[torch.Tensor], fov: List[float], image_size: int
206
+ ) -> torch.Tensor:
207
+ """
208
+ Given lists of camera transformations and fov, returns the batched plucker embeddings.
209
+ Args:
210
+ c2w: list of camera-to-world transformation matrices
211
+ fov: list of field of view values
212
+ image_size: size of the image
213
+ Returns:
214
+ plucker_embeds: (B, 6, H, W) batched plucker embeddings
215
+ """
216
+ plucker_embeds = []
217
+ for cam_matrix, cam_fov in zip(c2w, fov):
218
+ focal = 0.5 * image_size / np.tan(0.5 * cam_fov)
219
+ plucker = compute_plucker_embed(cam_matrix, image_size, image_size, focal)
220
+ plucker_embeds.append(plucker)
221
+ return torch.stack(plucker_embeds)
222
+
223
+
224
+ def get_plucker_embeds_from_cameras_ortho(
225
+ c2w: List[torch.Tensor], ortho_scale: List[float], image_size: int
226
+ ):
227
+ """
228
+ Given lists of camera transformations and orthographic scales, returns the batched Plucker embeddings.
229
+
230
+ Parameters:
231
+ c2w: list of camera-to-world transformation matrices
232
+ ortho_scale: list of orthographic scale values
233
+ image_size: size of the image
234
+
235
+ Returns:
236
+ plucker_embeds: plucker embeddings (B, 6, H, W)
237
+ """
238
+ plucker_embeds = []
239
+ # compute pairwise mask and plucker embeddings
240
+ for cam_matrix, scale in zip(c2w, ortho_scale):
241
+ # blender to opencv to pytorch3d
242
+ R, T = get_opencv_from_blender(cam_matrix)
243
+ cam_pos = -R.T @ T
244
+ view_dir = R.T @ torch.tensor([0, 0, 1]).float().to(cam_matrix.device)
245
+ # normalize camera position
246
+ cam_pos = F.normalize(cam_pos, dim=0)
247
+ plucker = torch.concat([view_dir, cam_pos])
248
+ plucker = plucker.unsqueeze(-1).unsqueeze(-1).repeat(1, image_size, image_size)
249
+ plucker_embeds.append(plucker)
250
+
251
+ plucker_embeds = torch.stack(plucker_embeds)
252
+
253
+ return plucker_embeds
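A hedged sketch tying the two utility modules together: the orthographic cameras from camera.py are converted into the per-view Plucker-style embeddings computed above. The scale values and resolution are illustrative.

from mvadapter.utils.camera import get_orthogonal_camera
from mvadapter.utils.geometry import get_plucker_embeds_from_cameras_ortho

cameras = get_orthogonal_camera(
    elevation_deg=[0.0] * 6,
    distance=[1.8] * 6,
    left=-0.55, right=0.55, bottom=-0.55, top=0.55,
    azimuth_deg=[0.0, 60.0, 120.0, 180.0, 240.0, 300.0],
)
plucker = get_plucker_embeds_from_cameras_ortho(
    c2w=list(cameras.c2w),  # list of (4, 4) camera-to-world matrices
    ortho_scale=[1.1] * 6,  # illustrative orthographic scales
    image_size=256,
)
print(plucker.shape)        # torch.Size([6, 6, 256, 256]) -> (views, channels, H, W)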
mvadapter/utils/logging.py ADDED
@@ -0,0 +1,340 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Optuna, Hugging Face
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Logging utilities."""
16
+
17
+ import logging
18
+ import os
19
+ import sys
20
+ import threading
21
+ from logging import CRITICAL # NOQA
22
+ from logging import DEBUG # NOQA
23
+ from logging import ERROR # NOQA
24
+ from logging import FATAL # NOQA
25
+ from logging import INFO # NOQA
26
+ from logging import NOTSET # NOQA
27
+ from logging import WARN # NOQA
28
+ from logging import WARNING # NOQA
29
+ from typing import Dict, Optional
30
+
31
+ from tqdm import auto as tqdm_lib
32
+
33
+ _lock = threading.Lock()
34
+ _default_handler: Optional[logging.Handler] = None
35
+
36
+ log_levels = {
37
+ "debug": logging.DEBUG,
38
+ "info": logging.INFO,
39
+ "warning": logging.WARNING,
40
+ "error": logging.ERROR,
41
+ "critical": logging.CRITICAL,
42
+ }
43
+
44
+ _default_log_level = logging.INFO
45
+
46
+ _tqdm_active = True
47
+
48
+
49
+ def _get_default_logging_level() -> int:
50
+ """
51
+ If the LATEXTURE_VERBOSITY env var is set to one of the valid choices, return that as the new default
52
+ level; otherwise fall back to `_default_log_level`.
53
+ """
54
+ env_level_str = os.getenv("LATEXTURE_VERBOSITY", None)
55
+ if env_level_str:
56
+ if env_level_str in log_levels:
57
+ return log_levels[env_level_str]
58
+ else:
59
+ logging.getLogger().warning(
60
+ f"Unknown option LATEXTURE_VERBOSITY={env_level_str}, "
61
+ f"has to be one of: { ', '.join(log_levels.keys()) }"
62
+ )
63
+ return _default_log_level
64
+
65
+
66
+ def _get_library_name() -> str:
67
+ return __name__.split(".")[0]
68
+
69
+
70
+ def _get_library_root_logger() -> logging.Logger:
71
+ return logging.getLogger(_get_library_name())
72
+
73
+
74
+ def _configure_library_root_logger() -> None:
75
+ global _default_handler
76
+
77
+ with _lock:
78
+ if _default_handler:
79
+ # This library has already configured the library root logger.
80
+ return
81
+ _default_handler = logging.StreamHandler() # Set sys.stderr as stream.
82
+ _default_handler.flush = sys.stderr.flush
83
+
84
+ # Apply our default configuration to the library root logger.
85
+ library_root_logger = _get_library_root_logger()
86
+ library_root_logger.addHandler(_default_handler)
87
+ library_root_logger.setLevel(_get_default_logging_level())
88
+ library_root_logger.propagate = False
89
+
90
+ enable_explicit_format()
91
+
92
+
93
+ def _reset_library_root_logger() -> None:
94
+ global _default_handler
95
+
96
+ with _lock:
97
+ if not _default_handler:
98
+ return
99
+
100
+ library_root_logger = _get_library_root_logger()
101
+ library_root_logger.removeHandler(_default_handler)
102
+ library_root_logger.setLevel(logging.NOTSET)
103
+ _default_handler = None
104
+
105
+
106
+ def get_log_levels_dict() -> Dict[str, int]:
107
+ return log_levels
108
+
109
+
110
+ def get_logger(name: Optional[str] = None) -> logging.Logger:
111
+ """
112
+ Return a logger with the specified name.
113
+
114
+ This function is not supposed to be directly accessed unless you are writing a custom diffusers module.
115
+ """
116
+
117
+ if name is None:
118
+ name = _get_library_name()
119
+
120
+ _configure_library_root_logger()
121
+ return logging.getLogger(name)
122
+
123
+
124
+ def get_verbosity() -> int:
125
+ """
126
+ Return the current level for the 🤗 Diffusers' root logger as an `int`.
127
+
128
+ Returns:
129
+ `int`:
130
+ Logging level integers which can be one of:
131
+
132
+ - `50`: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
133
+ - `40`: `diffusers.logging.ERROR`
134
+ - `30`: `diffusers.logging.WARNING` or `diffusers.logging.WARN`
135
+ - `20`: `diffusers.logging.INFO`
136
+ - `10`: `diffusers.logging.DEBUG`
137
+
138
+ """
139
+
140
+ _configure_library_root_logger()
141
+ return _get_library_root_logger().getEffectiveLevel()
142
+
143
+
144
+ def set_verbosity(verbosity: int) -> None:
145
+ """
146
+ Set the verbosity level for the 🤗 Diffusers' root logger.
147
+
148
+ Args:
149
+ verbosity (`int`):
150
+ Logging level which can be one of:
151
+
152
+ - `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
153
+ - `diffusers.logging.ERROR`
154
+ - `diffusers.logging.WARNING` or `diffusers.logging.WARN`
155
+ - `diffusers.logging.INFO`
156
+ - `diffusers.logging.DEBUG`
157
+ """
158
+
159
+ _configure_library_root_logger()
160
+ _get_library_root_logger().setLevel(verbosity)
161
+
162
+
163
+ def set_verbosity_info() -> None:
164
+ """Set the verbosity to the `INFO` level."""
165
+ return set_verbosity(INFO)
166
+
167
+
168
+ def set_verbosity_warning() -> None:
169
+ """Set the verbosity to the `WARNING` level."""
170
+ return set_verbosity(WARNING)
171
+
172
+
173
+ def set_verbosity_debug() -> None:
174
+ """Set the verbosity to the `DEBUG` level."""
175
+ return set_verbosity(DEBUG)
176
+
177
+
178
+ def set_verbosity_error() -> None:
179
+ """Set the verbosity to the `ERROR` level."""
180
+ return set_verbosity(ERROR)
181
+
182
+
183
+ def disable_default_handler() -> None:
184
+ """Disable the default handler of the 🤗 Diffusers' root logger."""
185
+
186
+ _configure_library_root_logger()
187
+
188
+ assert _default_handler is not None
189
+ _get_library_root_logger().removeHandler(_default_handler)
190
+
191
+
192
+ def enable_default_handler() -> None:
193
+ """Enable the default handler of the 🤗 Diffusers' root logger."""
194
+
195
+ _configure_library_root_logger()
196
+
197
+ assert _default_handler is not None
198
+ _get_library_root_logger().addHandler(_default_handler)
199
+
200
+
201
+ def add_handler(handler: logging.Handler) -> None:
202
+ """adds a handler to the HuggingFace Diffusers' root logger."""
203
+
204
+ _configure_library_root_logger()
205
+
206
+ assert handler is not None
207
+ _get_library_root_logger().addHandler(handler)
208
+
209
+
210
+ def remove_handler(handler: logging.Handler) -> None:
211
+ """removes given handler from the HuggingFace Diffusers' root logger."""
212
+
213
+ _configure_library_root_logger()
214
+
215
+ assert handler is not None and handler in _get_library_root_logger().handlers
216
+ _get_library_root_logger().removeHandler(handler)
217
+
218
+
219
+ def disable_propagation() -> None:
220
+ """
221
+ Disable propagation of the library log outputs. Note that log propagation is disabled by default.
222
+ """
223
+
224
+ _configure_library_root_logger()
225
+ _get_library_root_logger().propagate = False
226
+
227
+
228
+ def enable_propagation() -> None:
229
+ """
230
+ Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to prevent
231
+ double logging if the root logger has been configured.
232
+ """
233
+
234
+ _configure_library_root_logger()
235
+ _get_library_root_logger().propagate = True
236
+
237
+
238
+ def enable_explicit_format() -> None:
239
+ """
240
+ Enable explicit formatting for every 🤗 Diffusers' logger. The explicit formatter is as follows:
241
+ ```
242
+ [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
243
+ ```
244
+ All handlers currently bound to the root logger are affected by this method.
245
+ """
246
+ handlers = _get_library_root_logger().handlers
247
+
248
+ for handler in handlers:
249
+ formatter = logging.Formatter(
250
+ "[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s"
251
+ )
252
+ handler.setFormatter(formatter)
253
+
254
+
255
+ def reset_format() -> None:
256
+ """
257
+ Resets the formatting for 🤗 Diffusers' loggers.
258
+
259
+ All handlers currently bound to the root logger are affected by this method.
260
+ """
261
+ handlers = _get_library_root_logger().handlers
262
+
263
+ for handler in handlers:
264
+ handler.setFormatter(None)
265
+
266
+
267
+ def warning_advice(self, *args, **kwargs) -> None:
268
+ """
269
+ This method is identical to `logger.warning()`, but if env var LATEXTURE_NO_ADVISORY_WARNINGS=1 is set, this
270
+ warning will not be printed
271
+ """
272
+ no_advisory_warnings = os.getenv("LATEXTURE_NO_ADVISORY_WARNINGS", False)
273
+ if no_advisory_warnings:
274
+ return
275
+ self.warning(*args, **kwargs)
276
+
277
+
278
+ logging.Logger.warning_advice = warning_advice
279
+
280
+
281
+ class EmptyTqdm:
282
+ """Dummy tqdm which doesn't do anything."""
283
+
284
+ def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
285
+ self._iterator = args[0] if args else None
286
+
287
+ def __iter__(self):
288
+ return iter(self._iterator)
289
+
290
+ def __getattr__(self, _):
291
+ """Return empty function."""
292
+
293
+ def empty_fn(*args, **kwargs): # pylint: disable=unused-argument
294
+ return
295
+
296
+ return empty_fn
297
+
298
+ def __enter__(self):
299
+ return self
300
+
301
+ def __exit__(self, type_, value, traceback):
302
+ return
303
+
304
+
305
+ class _tqdm_cls:
306
+ def __call__(self, *args, **kwargs):
307
+ if _tqdm_active:
308
+ return tqdm_lib.tqdm(*args, **kwargs)
309
+ else:
310
+ return EmptyTqdm(*args, **kwargs)
311
+
312
+ def set_lock(self, *args, **kwargs):
313
+ self._lock = None
314
+ if _tqdm_active:
315
+ return tqdm_lib.tqdm.set_lock(*args, **kwargs)
316
+
317
+ def get_lock(self):
318
+ if _tqdm_active:
319
+ return tqdm_lib.tqdm.get_lock()
320
+
321
+
322
+ tqdm = _tqdm_cls()
323
+
324
+
325
+ def is_progress_bar_enabled() -> bool:
326
+ """Return a boolean indicating whether tqdm progress bars are enabled."""
327
+ global _tqdm_active
328
+ return bool(_tqdm_active)
329
+
330
+
331
+ def enable_progress_bar() -> None:
332
+ """Enable tqdm progress bar."""
333
+ global _tqdm_active
334
+ _tqdm_active = True
335
+
336
+
337
+ def disable_progress_bar() -> None:
338
+ """Disable tqdm progress bar."""
339
+ global _tqdm_active
340
+ _tqdm_active = False
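Finally, a short sketch of how this logging module is meant to be consumed elsewhere in the package; the message text is illustrative.

from mvadapter.utils import logging

logger = logging.get_logger(__name__)  # routed through the library root logger
logger.info("Loading multi-view adapter weights ...")

logging.set_verbosity_debug()          # programmatic verbosity control ...
# ... or via the environment variable read in _get_default_logging_level():
#   LATEXTURE_VERBOSITY=warning python app.py
logging.disable_progress_bar()         # silences the tqdm wrapper exported by this module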
mvadapter/utils/render.py ADDED
@@ -0,0 +1,499 @@
1
+ import math
2
+ from abc import ABC, abstractmethod
3
+ from dataclasses import dataclass
4
+ from datetime import datetime
5
+ from typing import List, Optional, Union
6
+
7
+ import numpy as np
8
+ import nvdiffrast.torch as dr
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import trimesh
12
+ from PIL import Image
13
+ from torch import BoolTensor, FloatTensor
14
+
15
+ from . import logging
16
+ from .camera import Camera
17
+
18
+ logger = logging.get_logger(__name__)
19
+
20
+
21
+ def dot(x: torch.FloatTensor, y: torch.FloatTensor) -> torch.FloatTensor:
22
+ return torch.sum(x * y, -1, keepdim=True)
23
+
24
+
25
+ @dataclass
26
+ class TexturedMesh:
27
+ v_pos: torch.FloatTensor
28
+ t_pos_idx: torch.LongTensor
29
+
30
+ # texture coordinates
31
+ v_tex: Optional[torch.FloatTensor] = None
32
+ t_tex_idx: Optional[torch.LongTensor] = None
33
+
34
+ # texture map
35
+ texture: Optional[torch.FloatTensor] = None
36
+
37
+ # vertices, faces after vertex merging
38
+ _stitched_v_pos: Optional[torch.FloatTensor] = None
39
+ _stitched_t_pos_idx: Optional[torch.LongTensor] = None
40
+
41
+ _v_nrm: Optional[torch.FloatTensor] = None
42
+
43
+ @property
44
+ def v_nrm(self) -> torch.FloatTensor:
45
+ if self._v_nrm is None:
46
+ self._v_nrm = self._compute_vertex_normal()
47
+ return self._v_nrm
48
+
49
+ def set_stitched_mesh(
50
+ self, v_pos: torch.FloatTensor, t_pos_idx: torch.LongTensor
51
+ ) -> None:
52
+ self._stitched_v_pos = v_pos
53
+ self._stitched_t_pos_idx = t_pos_idx
54
+
55
+ @property
56
+ def stitched_v_pos(self) -> torch.FloatTensor:
57
+ if self._stitched_v_pos is None:
58
+ logger.warning("Stitched vertices not available, using original vertices!")
59
+ return self.v_pos
60
+ return self._stitched_v_pos
61
+
62
+ @property
63
+ def stitched_t_pos_idx(self) -> torch.LongTensor:
64
+ if self._stitched_t_pos_idx is None:
65
+ logger.warning("Stitched faces not available, using original faces!")
66
+ return self.t_pos_idx
67
+ return self._stitched_t_pos_idx
68
+
69
+ def _compute_vertex_normal(self) -> torch.FloatTensor:
70
+ if self._stitched_v_pos is None or self._stitched_t_pos_idx is None:
71
+ logger.warning(
72
+ "Stitched vertices and faces not available, computing vertex normals on original mesh, which can be erroneous!"
73
+ )
74
+ v_pos, t_pos_idx = self.v_pos, self.t_pos_idx
75
+ else:
76
+ v_pos, t_pos_idx = self._stitched_v_pos, self._stitched_t_pos_idx
77
+
78
+ i0 = t_pos_idx[:, 0]
79
+ i1 = t_pos_idx[:, 1]
80
+ i2 = t_pos_idx[:, 2]
81
+
82
+ v0 = v_pos[i0, :]
83
+ v1 = v_pos[i1, :]
84
+ v2 = v_pos[i2, :]
85
+
86
+ face_normals = torch.cross(v1 - v0, v2 - v0)
87
+
88
+ # Splat face normals to vertices
89
+ v_nrm = torch.zeros_like(v_pos)
90
+ v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
91
+ v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
92
+ v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
93
+
94
+ # Normalize, replace zero (degenerated) normals with some default value
95
+ v_nrm = torch.where(
96
+ dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
97
+ )
98
+ v_nrm = F.normalize(v_nrm, dim=1)
99
+
100
+ if torch.is_anomaly_enabled():
101
+ assert torch.all(torch.isfinite(v_nrm))
102
+
103
+ return v_nrm
104
+
105
+ def to(self, device: Optional[str] = None):
106
+ self.v_pos = self.v_pos.to(device)
107
+ self.t_pos_idx = self.t_pos_idx.to(device)
108
+ if self.v_tex is not None:
109
+ self.v_tex = self.v_tex.to(device)
110
+ if self.t_tex_idx is not None:
111
+ self.t_tex_idx = self.t_tex_idx.to(device)
112
+ if self.texture is not None:
113
+ self.texture = self.texture.to(device)
114
+ if self._stitched_v_pos is not None:
115
+ self._stitched_v_pos = self._stitched_v_pos.to(device)
116
+ if self._stitched_t_pos_idx is not None:
117
+ self._stitched_t_pos_idx = self._stitched_t_pos_idx.to(device)
118
+ if self._v_nrm is not None:
119
+ self._v_nrm = self._v_nrm.to(device)
120
+
121
+
122
+ def load_mesh(
123
+ mesh_path: str,
124
+ rescale: bool = False,
125
+ move_to_center: bool = False,
126
+ scale: float = 0.5,
127
+ flip_uv: bool = True,
128
+ merge_vertices: bool = True,
129
+ default_uv_size: int = 2048,
130
+ shape_init_mesh_up: str = "+y",
131
+ shape_init_mesh_front: str = "+x",
132
+ front_x_to_y: bool = False,
133
+ device: Optional[str] = None,
134
+ return_transform: bool = False,
135
+ ) -> TexturedMesh:
136
+ scene = trimesh.load(mesh_path, force="mesh", process=False)
137
+ if isinstance(scene, trimesh.Trimesh):
138
+ mesh = scene
139
+ elif isinstance(scene, trimesh.scene.Scene):
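+ # Flatten a multi-geometry scene into a single mesh; per-part materials may not survive concatenation.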
140
+ mesh = trimesh.Trimesh()
141
+ for obj in scene.geometry.values():
142
+ mesh = trimesh.util.concatenate([mesh, obj])
143
+ else:
144
+ raise ValueError(f"Unknown mesh type at {mesh_path}.")
145
+
146
+ # move to center
147
+ if move_to_center:
148
+ centroid = mesh.vertices.mean(0)
149
+ mesh.vertices = mesh.vertices - centroid
150
+
151
+ # rescale
152
+ if rescale:
153
+ max_scale = np.abs(mesh.vertices).max()
154
+ mesh.vertices = mesh.vertices / max_scale * scale
155
+
156
+ dirs = ["+x", "+y", "+z", "-x", "-y", "-z"]
157
+ dir2vec = {
158
+ "+x": np.array([1, 0, 0]),
159
+ "+y": np.array([0, 1, 0]),
160
+ "+z": np.array([0, 0, 1]),
161
+ "-x": np.array([-1, 0, 0]),
162
+ "-y": np.array([0, -1, 0]),
163
+ "-z": np.array([0, 0, -1]),
164
+ }
165
+ if shape_init_mesh_up not in dirs or shape_init_mesh_front not in dirs:
166
+ raise ValueError(
167
+ f"shape_init_mesh_up and shape_init_mesh_front must be one of {dirs}."
168
+ )
169
+ if shape_init_mesh_up[1] == shape_init_mesh_front[1]:
170
+ raise ValueError(
171
+ "shape_init_mesh_up and shape_init_mesh_front must be orthogonal."
172
+ )
173
+ z_, x_ = (
174
+ dir2vec[shape_init_mesh_up],
175
+ dir2vec[shape_init_mesh_front],
176
+ )
177
+ y_ = np.cross(z_, x_)
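+ # mesh2std rotates vertices so the declared front axis lands on +x and the up axis on +z (y = up x front completes the right-handed frame).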
178
+ std2mesh = np.stack([x_, y_, z_], axis=0).T
179
+ mesh2std = np.linalg.inv(std2mesh)
180
+ mesh.vertices = np.dot(mesh2std, mesh.vertices.T).T
181
+ if front_x_to_y:
182
+ x = mesh.vertices[:, 1].copy()
183
+ y = -mesh.vertices[:, 0].copy()
184
+ mesh.vertices[:, 0] = x
185
+ mesh.vertices[:, 1] = y
186
+
187
+ v_pos = torch.tensor(mesh.vertices, dtype=torch.float32)
188
+ t_pos_idx = torch.tensor(mesh.faces, dtype=torch.int64)
189
+
190
+ if hasattr(mesh, "visual") and hasattr(mesh.visual, "uv"):
191
+ v_tex = torch.tensor(mesh.visual.uv, dtype=torch.float32)
192
+ if flip_uv:
193
+ v_tex[:, 1] = 1.0 - v_tex[:, 1]
194
+ t_tex_idx = t_pos_idx.clone()
195
+ if (
196
+ hasattr(mesh.visual.material, "baseColorTexture")
197
+ and mesh.visual.material.baseColorTexture
198
+ ):
199
+ texture = torch.tensor(
200
+ np.array(mesh.visual.material.baseColorTexture) / 255.0,
201
+ dtype=torch.float32,
202
+ )[..., :3]
203
+ else:
204
+ texture = torch.zeros(
205
+ (default_uv_size, default_uv_size, 3), dtype=torch.float32
206
+ )
207
+ else:
208
+ v_tex = None
209
+ t_tex_idx = None
210
+ texture = None
211
+
212
+ textured_mesh = TexturedMesh(
213
+ v_pos=v_pos,
214
+ t_pos_idx=t_pos_idx,
215
+ v_tex=v_tex,
216
+ t_tex_idx=t_tex_idx,
217
+ texture=texture,
218
+ )
219
+
220
+ if merge_vertices:
221
+ mesh.merge_vertices(merge_tex=True)
222
+ textured_mesh.set_stitched_mesh(
223
+ torch.tensor(mesh.vertices, dtype=torch.float32),
224
+ torch.tensor(mesh.faces, dtype=torch.int64),
225
+ )
226
+
227
+ textured_mesh.to(device)
228
+
229
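+ # NOTE: return_transform assumes move_to_center and rescale were enabled; otherwise centroid and max_scale are undefined here.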
+ if return_transform:
230
+ return textured_mesh, np.array(centroid), max_scale / scale
231
+
232
+ return textured_mesh
233
+
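+ # Example usage (a minimal sketch; "assets/model.glb" is a hypothetical path):
+ #   mesh = load_mesh("assets/model.glb", rescale=True, move_to_center=True, device="cuda")
+ #   print(mesh.v_pos.shape, mesh.t_pos_idx.shape, mesh.v_nrm.shape)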
234
+
235
+ @dataclass
236
+ class RenderOutput:
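+ # Per-view rasterization buffers produced by render(): attr = sampled texture color, mask = pixel coverage, depth = normalized view-space depth, normal = interpolated object-space normals, pos = interpolated object-space positions.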
237
+ attr: Optional[torch.FloatTensor] = None
238
+ mask: Optional[torch.BoolTensor] = None
239
+ depth: Optional[torch.FloatTensor] = None
240
+ normal: Optional[torch.FloatTensor] = None
241
+ pos: Optional[torch.FloatTensor] = None
242
+
243
+
244
+ class NVDiffRastContextWrapper:
245
+ def __init__(self, device: str, context_type: str = "gl"):
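+ # "gl" uses nvdiffrast's OpenGL rasterizer (requires an EGL/OpenGL-capable environment); "cuda" uses the pure CUDA rasterizer.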
246
+ if context_type == "gl":
247
+ self.ctx = dr.RasterizeGLContext(device=device)
248
+ elif context_type == "cuda":
249
+ self.ctx = dr.RasterizeCudaContext(device=device)
250
+ else:
251
+ raise NotImplementedError
252
+
253
+ def rasterize(self, pos, tri, resolution, ranges=None, grad_db=True):
254
+ """
255
+ Rasterize triangles.
256
+
257
+ All input tensors must be contiguous and reside in GPU memory except for the ranges tensor that, if specified, has to reside in CPU memory. The output tensors will be contiguous and reside in GPU memory.
258
+
259
+ Arguments:
260
+ glctx Rasterizer context (RasterizeGLContext or RasterizeCudaContext); supplied internally by this wrapper via self.ctx.
261
+ pos Vertex position tensor with dtype torch.float32. To enable range mode, this tensor should have a 2D shape [num_vertices, 4]. To enable instanced mode, use a 3D shape [minibatch_size, num_vertices, 4].
262
+ tri Triangle tensor with shape [num_triangles, 3] and dtype torch.int32.
263
+ resolution Output resolution as integer tuple (height, width).
264
+ ranges In range mode, tensor with shape [minibatch_size, 2] and dtype torch.int32, specifying start indices and counts into tri. Ignored in instanced mode.
265
+ grad_db Propagate gradients of image-space derivatives of barycentrics into pos in backward pass. Ignored if using an OpenGL context that was not configured to output image-space derivatives.
266
+ Returns:
267
+ A tuple of two tensors. The first output tensor has shape [minibatch_size, height, width, 4] and contains the main rasterizer output in order (u, v, z/w, triangle_id). If the OpenGL context was configured to output image-space derivatives of barycentrics, the second output tensor will also have shape [minibatch_size, height, width, 4] and contain said derivatives in order (du/dX, du/dY, dv/dX, dv/dY). Otherwise it will be an empty tensor with shape [minibatch_size, height, width, 0].
268
+ """
269
+ return dr.rasterize(
270
+ self.ctx, pos.float(), tri.int(), resolution, ranges, grad_db
271
+ )
272
+
273
+ def interpolate(self, attr, rast, tri, rast_db=None, diff_attrs=None):
274
+ """
275
+ Interpolate vertex attributes.
276
+
277
+ All input tensors must be contiguous and reside in GPU memory. The output tensors will be contiguous and reside in GPU memory.
278
+
279
+ Arguments:
280
+ attr Attribute tensor with dtype torch.float32. Shape is [num_vertices, num_attributes] in range mode, or [minibatch_size, num_vertices, num_attributes] in instanced mode. Broadcasting is supported along the minibatch axis.
281
+ rast Main output tensor from rasterize().
282
+ tri Triangle tensor with shape [num_triangles, 3] and dtype torch.int32.
283
+ rast_db (Optional) Tensor containing image-space derivatives of barycentrics, i.e., the second output tensor from rasterize(). Enables computing image-space derivatives of attributes.
284
+ diff_attrs (Optional) List of attribute indices for which image-space derivatives are to be computed. Special value 'all' is equivalent to list [0, 1, ..., num_attributes - 1].
285
+ Returns:
286
+ A tuple of two tensors. The first output tensor contains interpolated attributes and has shape [minibatch_size, height, width, num_attributes]. If rast_db and diff_attrs were specified, the second output tensor contains the image-space derivatives of the selected attributes and has shape [minibatch_size, height, width, 2 * len(diff_attrs)]. The derivatives of the first selected attribute A will be on channels 0 and 1 as (dA/dX, dA/dY), etc. Otherwise, the second output tensor will be an empty tensor with shape [minibatch_size, height, width, 0].
287
+ """
288
+ return dr.interpolate(attr.float(), rast, tri.int(), rast_db, diff_attrs)
289
+
290
+ def texture(
291
+ self,
292
+ tex,
293
+ uv,
294
+ uv_da=None,
295
+ mip_level_bias=None,
296
+ mip=None,
297
+ filter_mode="auto",
298
+ boundary_mode="wrap",
299
+ max_mip_level=None,
300
+ ):
301
+ """
302
+ Perform texture sampling.
303
+
304
+ All input tensors must be contiguous and reside in GPU memory. The output tensor will be contiguous and reside in GPU memory.
305
+
306
+ Arguments:
307
+ tex Texture tensor with dtype torch.float32. For 2D textures, must have shape [minibatch_size, tex_height, tex_width, tex_channels]. For cube map textures, must have shape [minibatch_size, 6, tex_height, tex_width, tex_channels] where tex_width and tex_height are equal. Note that boundary_mode must also be set to 'cube' to enable cube map mode. Broadcasting is supported along the minibatch axis.
308
+ uv Tensor containing per-pixel texture coordinates. When sampling a 2D texture, must have shape [minibatch_size, height, width, 2]. When sampling a cube map texture, must have shape [minibatch_size, height, width, 3].
309
+ uv_da (Optional) Tensor containing image-space derivatives of texture coordinates. Must have same shape as uv except for the last dimension that is to be twice as long.
310
+ mip_level_bias (Optional) Per-pixel bias for mip level selection. If uv_da is omitted, determines mip level directly. Must have shape [minibatch_size, height, width].
311
+ mip (Optional) Preconstructed mipmap stack from a texture_construct_mip() call, or a list of tensors specifying a custom mipmap stack. When specifying a custom mipmap stack, the tensors in the list must follow the same format as tex except for width and height that must follow the usual rules for mipmap sizes. The base level texture is still supplied in tex and must not be included in the list. Gradients of a custom mipmap stack are not automatically propagated to base texture but the mipmap tensors will receive gradients of their own. If a mipmap stack is not specified but the chosen filter mode requires it, the mipmap stack is constructed internally and discarded afterwards.
312
+ filter_mode Texture filtering mode to be used. Valid values are 'auto', 'nearest', 'linear', 'linear-mipmap-nearest', and 'linear-mipmap-linear'. Mode 'auto' selects 'linear' if neither uv_da or mip_level_bias is specified, and 'linear-mipmap-linear' when at least one of them is specified, these being the highest-quality modes possible depending on the availability of the image-space derivatives of the texture coordinates or direct mip level information.
313
+ boundary_mode Valid values are 'wrap', 'clamp', 'zero', and 'cube'. If tex defines a cube map, this must be set to 'cube'. The default mode 'wrap' takes fractional part of texture coordinates. Mode 'clamp' clamps texture coordinates to the centers of the boundary texels. Mode 'zero' virtually extends the texture with all-zero values in all directions.
314
+ max_mip_level If specified, limits the number of mipmaps constructed and used in mipmap-based filter modes.
315
+ Returns:
316
+ A tensor containing the results of the texture sampling with shape [minibatch_size, height, width, tex_channels]. Cube map fetches with invalid uv coordinates (e.g., zero vectors) output all zeros and do not propagate gradients.
317
+ """
318
+ return dr.texture(
319
+ tex.float(),
320
+ uv.float(),
321
+ uv_da,
322
+ mip_level_bias,
323
+ mip,
324
+ filter_mode,
325
+ boundary_mode,
326
+ max_mip_level,
327
+ )
328
+
329
+ def antialias(
330
+ self, color, rast, pos, tri, topology_hash=None, pos_gradient_boost=1.0
331
+ ):
332
+ """
333
+ Perform antialiasing.
334
+
335
+ All input tensors must be contiguous and reside in GPU memory. The output tensor will be contiguous and reside in GPU memory.
336
+
337
+ Note that silhouette edge determination is based on vertex indices in the triangle tensor. For it to work properly, a vertex belonging to multiple triangles must be referred to using the same vertex index in each triangle. Otherwise, nvdiffrast will always classify the adjacent edges as silhouette edges, which leads to bad performance and potentially incorrect gradients. If you are unsure whether your data is good, check which pixels are modified by the antialias operation and compare to the example in the documentation.
338
+
339
+ Arguments:
340
+ color Input image to antialias with shape [minibatch_size, height, width, num_channels].
341
+ rast Main output tensor from rasterize().
342
+ pos Vertex position tensor used in the rasterization operation.
343
+ tri Triangle tensor used in the rasterization operation.
344
+ topology_hash (Optional) Preconstructed topology hash for the triangle tensor. If not specified, the topology hash is constructed internally and discarded afterwards.
345
+ pos_gradient_boost (Optional) Multiplier for gradients propagated to pos.
346
+ Returns:
347
+ A tensor containing the antialiased image with the same shape as color input tensor.
348
+ """
349
+ return dr.antialias(
350
+ color.float(),
351
+ rast,
352
+ pos.float(),
353
+ tri.int(),
354
+ topology_hash,
355
+ pos_gradient_boost,
356
+ )
357
+
358
+
359
+ def get_clip_space_position(pos: torch.FloatTensor, mvp_mtx: torch.FloatTensor):
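+ # Append w = 1 and multiply by the transposed MVP matrix, giving clip-space positions of shape [batch, num_vertices, 4].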
360
+ pos_homo = torch.cat([pos, torch.ones([pos.shape[0], 1]).to(pos)], dim=-1)
361
+ return torch.matmul(pos_homo, mvp_mtx.permute(0, 2, 1))
362
+
363
+
364
+ def transform_points_homo(pos: torch.FloatTensor, mtx: torch.FloatTensor):
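+ # Apply a batched 4x4 transform to points of shape [batch, ..., 3] via homogeneous coordinates, returning the transformed 3D points.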
365
+ batch_size = pos.shape[0]
366
+ pos_shape = pos.shape[1:-1]
367
+ pos = pos.reshape(batch_size, -1, 3)
368
+ pos_homo = torch.cat([pos, torch.ones_like(pos[..., 0:1])], dim=-1)
369
+ pos = (pos_homo.unsqueeze(2) * mtx.unsqueeze(1)).sum(-1)[..., :3]
370
+ pos = pos.reshape(batch_size, *pos_shape, 3)
371
+ return pos
372
+
373
+
374
+ class DepthNormalizationStrategy(ABC):
375
+ @abstractmethod
376
+ def __init__(self, *args, **kwargs):
377
+ pass
378
+
379
+ @abstractmethod
380
+ def __call__(
381
+ self, depth: torch.FloatTensor, mask: torch.BoolTensor
382
+ ) -> torch.FloatTensor:
383
+ pass
384
+
385
+
386
+ class DepthControlNetNormalization(DepthNormalizationStrategy):
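+ # Inverted per-view min-max normalization: nearer surfaces map toward near_clip (brighter), farther toward far_clip, and background pixels are set to bg_value.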
387
+ def __init__(
388
+ self, far_clip: float = 0.25, near_clip: float = 1.0, bg_value: float = 0.0
389
+ ):
390
+ self.far_clip = far_clip
391
+ self.near_clip = near_clip
392
+ self.bg_value = bg_value
393
+
394
+ def __call__(
395
+ self, depth: torch.FloatTensor, mask: torch.BoolTensor
396
+ ) -> torch.FloatTensor:
397
+ batch_size = depth.shape[0]
398
+ min_depth = depth.view(batch_size, -1).min(dim=-1)[0][:, None, None]
399
+ max_depth = depth.view(batch_size, -1).max(dim=-1)[0][:, None, None]
400
+ depth = 1.0 - ((depth - min_depth) / (max_depth - min_depth + 1e-5)).clamp(
401
+ 0.0, 1.0
402
+ )
403
+ depth = depth * (self.near_clip - self.far_clip) + self.far_clip
404
+ depth[~mask] = self.bg_value
405
+ return depth
406
+
407
+
408
+ class Zero123PlusPlusNormalization(DepthNormalizationStrategy):
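+ # Plain per-view min-max normalization (near = 0, far = 1) with background pixels set to bg_value.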
409
+ def __init__(self, bg_value: float = 0.8):
410
+ self.bg_value = bg_value
411
+
412
+ def __call__(self, depth: torch.FloatTensor, mask: torch.BoolTensor) -> torch.FloatTensor:
413
+ batch_size = depth.shape[0]
414
+ min_depth = depth.view(batch_size, -1).min(dim=-1)[0][:, None, None]
415
+ max_depth = depth.view(batch_size, -1).max(dim=-1)[0][:, None, None]
416
+ depth = ((depth - min_depth) / (max_depth - min_depth + 1e-5)).clamp(0.0, 1.0)
417
+ depth[~mask] = self.bg_value
418
+ return depth
419
+
420
+
421
+ class SimpleNormalization(DepthNormalizationStrategy):
422
+ def __init__(
423
+ self,
424
+ scale: float = 1.0,
425
+ offset: float = -1.0,
426
+ clamp: bool = True,
427
+ bg_value: float = 1.0,
428
+ ):
429
+ self.scale = scale
430
+ self.offset = offset
431
+ self.clamp = clamp
432
+ self.bg_value = bg_value
433
+
434
+ def __call__(self, depth: torch.FloatTensor, mask: torch.BoolTensor) -> torch.FloatTensor:
435
+ depth = depth * self.scale + self.offset
436
+ if self.clamp:
437
+ depth = depth.clamp(0.0, 1.0)
438
+ depth[~mask] = self.bg_value
439
+ return depth
440
+
441
+
442
+ def render(
443
+ ctx: NVDiffRastContextWrapper,
444
+ mesh: TexturedMesh,
445
+ cam: Camera,
446
+ height: int,
447
+ width: int,
448
+ render_attr: bool = True,
449
+ render_depth: bool = True,
450
+ render_normal: bool = True,
451
+ depth_normalization_strategy: DepthNormalizationStrategy = DepthControlNetNormalization(),
452
+ attr_background: Union[float, torch.FloatTensor] = 0.5,
453
+ antialias_attr=False,
454
+ normal_background: Union[float, torch.FloatTensor] = 0.5,
455
+ texture_override=None,
456
+ texture_filter_mode: str = "linear",
457
+ ) -> RenderOutput:
458
+ output_dict = {}
459
+
460
+ v_pos_clip = get_clip_space_position(mesh.v_pos, cam.mvp_mtx)
461
+ rast, _ = ctx.rasterize(v_pos_clip, mesh.t_pos_idx, (height, width), grad_db=True)
462
+ mask = rast[..., 3] > 0
463
+
464
+ gb_pos, _ = ctx.interpolate(mesh.v_pos[None], rast, mesh.t_pos_idx)
465
+ output_dict.update({"mask": mask, "pos": gb_pos})
466
+
467
+ if render_depth:
468
+ gb_pos_vs = transform_points_homo(gb_pos, cam.w2c)
469
+ gb_depth = -gb_pos_vs[..., 2]
470
+ # set background pixels to min depth value for correct min/max calculation
471
+ gb_depth = torch.where(
472
+ mask,
473
+ gb_depth,
474
+ gb_depth.view(gb_depth.shape[0], -1).min(dim=-1)[0][:, None, None],
475
+ )
476
+ gb_depth = depth_normalization_strategy(gb_depth, mask)
477
+ output_dict["depth"] = gb_depth
478
+
479
+ if render_attr:
480
+ tex_c, _ = ctx.interpolate(mesh.v_tex[None], rast, mesh.t_tex_idx)
481
+ texture = (
482
+ texture_override[None]
483
+ if texture_override is not None
484
+ else mesh.texture[None]
485
+ )
486
+ gb_rgb_fg = ctx.texture(texture, tex_c, filter_mode=texture_filter_mode)
487
+ gb_rgb_bg = torch.ones_like(gb_rgb_fg) * attr_background
488
+ gb_rgb = torch.where(mask[..., None], gb_rgb_fg, gb_rgb_bg)
489
+ if antialias_attr:
490
+ gb_rgb = ctx.antialias(gb_rgb, rast, v_pos_clip, mesh.t_pos_idx)
491
+ output_dict["attr"] = gb_rgb
492
+
493
+ if render_normal:
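+ # Vertex normals live on the stitched (vertex-merged) topology, so interpolate them with the stitched face indices.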
494
+ gb_nrm, _ = ctx.interpolate(mesh.v_nrm[None], rast, mesh.stitched_t_pos_idx)
495
+ gb_nrm = F.normalize(gb_nrm, dim=-1, p=2)
496
+ gb_nrm[~mask] = normal_background
497
+ output_dict["normal"] = gb_nrm
498
+
499
+ return RenderOutput(**output_dict)
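+ 
+ # Example usage (a minimal sketch; the camera is assumed to come from mvadapter.utils.get_orthogonal_camera):
+ #   ctx = NVDiffRastContextWrapper(device="cuda", context_type="cuda")
+ #   mesh = load_mesh("assets/model.glb", rescale=True, device="cuda")  # hypothetical path
+ #   out = render(ctx, mesh, cam, height=768, width=768)
+ #   depth_maps, normal_maps = out.depth, out.normal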
mvadapter/utils/saving.py ADDED
@@ -0,0 +1,88 @@
1
+ import math
2
+ from typing import List, Optional, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ from PIL import Image
7
+
8
+
9
+ def tensor_to_image(
10
+ data: Union[Image.Image, torch.Tensor, np.ndarray],
11
+ batched: bool = False,
12
+ format: str = "HWC",
13
+ ) -> Union[Image.Image, List[Image.Image]]:
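+ # Float inputs are assumed to be in [0, 1]; boolean masks become 0/255; "CHW" tensors are transposed to HWC before conversion.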
14
+ if isinstance(data, Image.Image):
15
+ return data
16
+ if isinstance(data, torch.Tensor):
17
+ data = data.detach().cpu().numpy()
18
+ if data.dtype == np.float32 or data.dtype == np.float16:
19
+ data = (data * 255).astype(np.uint8)
20
+ elif data.dtype == np.bool_:
21
+ data = data.astype(np.uint8) * 255
22
+ assert data.dtype == np.uint8
23
+ if format == "CHW":
24
+ if batched and data.ndim == 4:
25
+ data = data.transpose((0, 2, 3, 1))
26
+ elif not batched and data.ndim == 3:
27
+ data = data.transpose((1, 2, 0))
28
+
29
+ if batched:
30
+ return [Image.fromarray(d) for d in data]
31
+ return Image.fromarray(data)
32
+
33
+
34
+ def largest_factor_near_sqrt(n: int) -> int:
35
+ """
36
+ Finds the largest factor of n that is closest to the square root of n.
37
+
38
+ Args:
39
+ n (int): The integer for which to find the largest factor near its square root.
40
+
41
+ Returns:
42
+ int: The largest factor of n that is closest to the square root of n.
43
+ """
44
+ sqrt_n = int(math.sqrt(n)) # Get the integer part of the square root
45
+
46
+ # First, check if the square root itself is a factor
47
+ if sqrt_n * sqrt_n == n:
48
+ return sqrt_n
49
+
50
+ # Otherwise, find the largest factor by iterating from sqrt_n downwards
51
+ for i in range(sqrt_n, 0, -1):
52
+ if n % i == 0:
53
+ return i
54
+
55
+ # If n is 1, return 1
56
+ return 1
57
+
58
+
59
+ def make_image_grid(
60
+ images: List[Image.Image],
61
+ rows: Optional[int] = None,
62
+ cols: Optional[int] = None,
63
+ resize: Optional[int] = None,
64
+ ) -> Image.Image:
65
+ """
66
+ Prepares a single grid of images. Useful for visualization purposes.
67
+ """
68
+ if rows is None and cols is not None:
69
+ assert len(images) % cols == 0
70
+ rows = len(images) // cols
71
+ elif cols is None and rows is not None:
72
+ assert len(images) % rows == 0
73
+ cols = len(images) // rows
74
+ elif rows is None and cols is None:
75
+ rows = largest_factor_near_sqrt(len(images))
76
+ cols = len(images) // rows
77
+
78
+ assert len(images) == rows * cols
79
+
80
+ if resize is not None:
81
+ images = [img.resize((resize, resize)) for img in images]
82
+
83
+ w, h = images[0].size
84
+ grid = Image.new("RGB", size=(cols * w, rows * h))
85
+
86
+ for i, img in enumerate(images):
87
+ grid.paste(img, box=(i % cols * w, i // cols * h))
88
+ return grid
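+ 
+ # Example (a minimal sketch): tile 6 view images into a 2x3 grid at 512 px per cell
+ #   grid = make_image_grid(view_images, rows=2, cols=3, resize=512)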
requirements.txt ADDED
@@ -0,0 +1,23 @@
1
+ torchvision
2
+ diffusers
3
+ transformers==4.49.0
4
+ einops
5
+ huggingface_hub
6
+ opencv-python
7
+ trimesh==4.5.3
8
+ omegaconf
9
+ scikit-image
10
+ numpy
11
+ peft
12
+ scipy==1.11.4
13
+ jaxtyping
14
+ typeguard
15
+ pymeshlab==2022.2.post4
16
+ open3d
17
+ timm
18
+ kornia
19
+ ninja
20
+ https://huggingface.co/spaces/JeffreyXiang/TRELLIS/resolve/main/wheels/nvdiffrast-0.3.3-cp310-cp310-linux_x86_64.whl?download=true
21
+ cvcuda_cu12
22
+ gltflib
23
+ torch-cluster