zysong212 committed
Commit 5a3d50d · 1 Parent(s): b425273

first commit

app.py CHANGED
@@ -1,60 +1,93 @@
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
 
5
- # import spaces #[uncomment to use ZeroGPU]
6
- from diffusers import DiffusionPipeline
7
  import torch
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
11
 
12
- if torch.cuda.is_available():
13
- torch_dtype = torch.float16
14
- else:
15
- torch_dtype = torch.float32
16
 
17
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
  pipe = pipe.to(device)
19
 
20
- MAX_SEED = np.iinfo(np.int32).max
21
- MAX_IMAGE_SIZE = 1024
22
 
23
 
24
  # @spaces.GPU #[uncomment to use ZeroGPU]
25
  def infer(
26
- prompt,
27
- negative_prompt,
28
- seed,
29
- randomize_seed,
30
- width,
31
- height,
32
- guidance_scale,
33
- num_inference_steps,
34
  progress=gr.Progress(track_tqdm=True),
35
  ):
36
- if randomize_seed:
37
- seed = random.randint(0, MAX_SEED)
38
-
39
- generator = torch.Generator().manual_seed(seed)
40
-
41
- image = pipe(
42
- prompt=prompt,
43
- negative_prompt=negative_prompt,
44
- guidance_scale=guidance_scale,
45
- num_inference_steps=num_inference_steps,
46
- width=width,
47
- height=height,
48
- generator=generator,
49
- ).images[0]
50
-
51
- return image, seed
52
-
53
-
54
- examples = [
55
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
56
- "An astronaut riding a green horse",
57
- "A delicious ceviche cheesecake slice",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  ]
59
 
60
  css = """
@@ -62,93 +95,150 @@ css = """
62
  margin: 0 auto;
63
  max-width: 640px;
64
  }
65
  """
66
 
67
  with gr.Blocks(css=css) as demo:
68
- with gr.Column(elem_id="col-container"):
69
- gr.Markdown(" # Text-to-Image Gradio Template")
70
-
71
- with gr.Row():
72
- prompt = gr.Text(
73
- label="Prompt",
74
- show_label=False,
75
- max_lines=1,
76
- placeholder="Enter your prompt",
77
- container=False,
78
- )
79
-
80
- run_button = gr.Button("Run", scale=0, variant="primary")
81
-
82
- result = gr.Image(label="Result", show_label=False)
83
-
84
- with gr.Accordion("Advanced Settings", open=False):
85
- negative_prompt = gr.Text(
86
- label="Negative prompt",
87
- max_lines=1,
88
- placeholder="Enter a negative prompt",
89
- visible=False,
90
- )
91
-
92
- seed = gr.Slider(
93
- label="Seed",
94
- minimum=0,
95
- maximum=MAX_SEED,
96
- step=1,
97
- value=0,
98
- )
99
-
100
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
101
-
102
- with gr.Row():
103
- width = gr.Slider(
104
- label="Width",
105
- minimum=256,
106
- maximum=MAX_IMAGE_SIZE,
107
- step=32,
108
- value=1024, # Replace with defaults that work for your model
109
- )
110
-
111
- height = gr.Slider(
112
- label="Height",
113
- minimum=256,
114
- maximum=MAX_IMAGE_SIZE,
115
- step=32,
116
- value=1024, # Replace with defaults that work for your model
117
- )
118
-
119
- with gr.Row():
120
- guidance_scale = gr.Slider(
121
- label="Guidance scale",
122
- minimum=0.0,
123
- maximum=10.0,
124
- step=0.1,
125
- value=0.0, # Replace with defaults that work for your model
126
- )
127
-
128
- num_inference_steps = gr.Slider(
129
- label="Number of inference steps",
130
- minimum=1,
131
- maximum=50,
132
- step=1,
133
- value=2, # Replace with defaults that work for your model
134
- )
135
-
136
- gr.Examples(examples=examples, inputs=[prompt])
137
- gr.on(
138
- triggers=[run_button.click, prompt.submit],
139
- fn=infer,
140
- inputs=[
141
- prompt,
142
- negative_prompt,
143
- seed,
144
- randomize_seed,
145
- width,
146
- height,
147
- guidance_scale,
148
- num_inference_steps,
149
- ],
150
- outputs=[result, seed],
151
  )
152
 
153
- if __name__ == "__main__":
154
- demo.launch()
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
+ import logging
5
+ import os
6
+ from glob import glob
7
 
8
+ import numpy as np
 
9
  import torch
10
+ from PIL import Image
11
+ from tqdm.auto import tqdm
12
+
13
+ from depthmaster import DepthMasterPipeline
14
+ from depthmaster.modules.unet_2d_condition import UNet2DConditionModel
15
+
16
+ def load_example(example_image):
17
+ # Return the selected image
18
+ return example_image
19
+
20
 
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
+ model_repo_id = "zysong212/DepthMaster" # Replace with the model you would like to use
23
+
24
+ # if torch.cuda.is_available():
25
+ # torch_dtype = torch.float16
26
+ # else:
27
+ torch_dtype = torch.float32
28
+
29
+ # pipe = DepthMasterPipeline.from_pretrained('eval', torch_dtype=torch_dtype)
30
+ # unet = UNet2DConditionModel.from_pretrained(os.path.join('eval', f'unet'))
31
+ pipe = DepthMasterPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
32
+ unet = UNet2DConditionModel.from_pretrained(model_repo_id, subfolder="unet", torch_dtype=torch_dtype)
33
+ pipe.unet = unet
34
 
35
+ try:
36
+ pipe.enable_xformers_memory_efficient_attention()
37
+ except ImportError:
38
+ pass # run without xformers
39
 
 
40
  pipe = pipe.to(device)
41
 
42
+ # MAX_SEED = np.iinfo(np.int32).max
43
+ # MAX_IMAGE_SIZE = 1024
44
 
45
 
46
  # @spaces.GPU #[uncomment to use ZeroGPU]
47
  def infer(
48
+ input_image,
49
  progress=gr.Progress(track_tqdm=True),
50
  ):
51
+ # if randomize_seed:
52
+ # seed = random.randint(0, MAX_SEED)
53
+
54
+ # generator = torch.Generator().manual_seed(seed)
55
+
56
+ # image = pipe(
57
+ # prompt=prompt,
58
+ # negative_prompt=negative_prompt,
59
+ # guidance_scale=guidance_scale,
60
+ # num_inference_steps=num_inference_steps,
61
+ # width=width,
62
+ # height=height,
63
+ # generator=generator,
64
+ # ).images[0]
65
+ pipe_out = pipe(
66
+ input_image,
67
+ processing_res=768,
68
+ match_input_res=True,
69
+ batch_size=1,
70
+ color_map="Spectral",
71
+ show_progress_bar=True,
72
+ resample_method="bilinear",
73
+ )
74
+
75
+ # depth_pred: np.ndarray = pipe_out.depth_np
76
+ depth_colored: Image.Image = pipe_out.depth_colored
77
+
78
+
79
+ return depth_colored
80
+
81
+
82
+ # Default example image paths
83
+ example_images = [
84
+ "wild_example/000000000776.jpg",
85
+ "wild_example/800x.jpg",
86
+ "wild_example/000000055950.jpg",
87
+ "wild_example/53441037037_c2cbd91ad2_k.jpg",
88
+ "wild_example/53501906161_6109e3da29_b.jpg",
89
+ "wild_example/m_1e31af1c.jpg",
90
+ "wild_example/sg-11134201-7rd5x-lvlh48byidbqca.jpg"
91
  ]
92
 
93
  css = """
 
95
  margin: 0 auto;
96
  max-width: 640px;
97
  }
98
+ #example-gallery {
99
+ height: 80px; /* thumbnail height */
100
+ width: auto; /* keep the aspect ratio */
101
+ margin: 0 auto; /* spacing between images */
102
+ cursor: pointer; /* show a pointer (hand) cursor */
103
+ }
104
  """
105
 
106
  with gr.Blocks(css=css) as demo:
107
+ gr.Markdown("# DepthMaster")
108
+ gr.Markdown("Official demo for DepthMaster. Please refer to our [paper](https://arxiv.org/abs/2501.02576), [project page](https://indu1ge.github.io/DepthMaster_page/), and [github](https://github.com/indu1ge/DepthMaster) for more details.")
109
+ gr.Markdown(" ### Depth Estimation with DepthMaster.")
110
+ # with gr.Column(elem_id="col-container"):
111
+ # gr.Markdown(" # Depth Estimation")
112
+ with gr.Row():
113
+ with gr.Column():
114
+ input_image = gr.Image(label="Input Image", type="pil", elem_id="input-image", interactive=True)
115
+ with gr.Column():
116
+ depth_map = gr.Image(label="Depth Map with Slider View", type="pil", interactive=False, elem_id="depth-map")
117
+
118
+ # Compute button
119
+ compute_button = gr.Button("Compute Depth")
120
+
121
+ # # Example image selector
122
+ # with gr.Row():
123
+ # gr.Markdown("### example images")
124
+ # with gr.Row(elem_id="example-gallery"):
125
+ # example_gallery = gr.Gallery(
126
+ # label="",
127
+ # value=example_images,
128
+ # elem_id="example-gallery",
129
+ # show_label=False,
130
+ # interactive=True,
131
+ # columns=10
132
+ # )
133
+
134
+ # Action to run when a default example image is clicked
135
+ # example_gallery.select(
136
+ # fn=lambda img_path: img_path, # callback: return the selected path
137
+ # inputs=[],
138
+ # outputs=input_image # write the selection to Input Image
139
+ # )
140
+ # example_gallery.click(
141
+ # fn=load_example, # callback for selecting an image
142
+ # inputs=[example_gallery], # input: the image the user clicked
143
+ # outputs=[input_image] # output: update Input Image
144
+ # )
145
+
146
+
147
+ # Wire up the compute button's callback
148
+ compute_button.click(
149
+ fn=infer, # callback function
150
+ inputs=input_image, # input
151
+ outputs=depth_map # output
152
  )
153
 
154
+ # Launch the Gradio app
155
+ demo.launch()
156
+ # with gr.Column(scale=45):
157
+ # img_in = gr.Image(type="pil")
158
+ # with gr.Column(scale=45):
159
+ # img_out =
160
+
161
+ # with gr.Row():
162
+ # prompt = gr.Text(
163
+ # label="Prompt",
164
+ # show_label=False,
165
+ # max_lines=1,
166
+ # placeholder="Enter your prompt",
167
+ # container=False,
168
+ # )
169
+
170
+ # run_button = gr.Button("Run", scale=0, variant="primary")
171
+
172
+ # result = gr.Image(label="Result", show_label=False)
173
+
174
+ # with gr.Accordion("Advanced Settings", open=False):
175
+ # negative_prompt = gr.Text(
176
+ # label="Negative prompt",
177
+ # max_lines=1,
178
+ # placeholder="Enter a negative prompt",
179
+ # visible=False,
180
+ # )
181
+
182
+ # seed = gr.Slider(
183
+ # label="Seed",
184
+ # minimum=0,
185
+ # maximum=MAX_SEED,
186
+ # step=1,
187
+ # value=0,
188
+ # )
189
+
190
+ # randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
191
+
192
+ # with gr.Row():
193
+ # width = gr.Slider(
194
+ # label="Width",
195
+ # minimum=256,
196
+ # maximum=MAX_IMAGE_SIZE,
197
+ # step=32,
198
+ # value=1024, # Replace with defaults that work for your model
199
+ # )
200
+
201
+ # height = gr.Slider(
202
+ # label="Height",
203
+ # minimum=256,
204
+ # maximum=MAX_IMAGE_SIZE,
205
+ # step=32,
206
+ # value=1024, # Replace with defaults that work for your model
207
+ # )
208
+
209
+ # with gr.Row():
210
+ # guidance_scale = gr.Slider(
211
+ # label="Guidance scale",
212
+ # minimum=0.0,
213
+ # maximum=10.0,
214
+ # step=0.1,
215
+ # value=0.0, # Replace with defaults that work for your model
216
+ # )
217
+
218
+ # num_inference_steps = gr.Slider(
219
+ # label="Number of inference steps",
220
+ # minimum=1,
221
+ # maximum=50,
222
+ # step=1,
223
+ # value=2, # Replace with defaults that work for your model
224
+ # )
225
+
226
+ # gr.Examples(examples=examples, inputs=[prompt])
227
+ # gr.on(
228
+ # triggers=[run_button.click, prompt.submit],
229
+ # fn=infer,
230
+ # inputs=[
231
+ # prompt,
232
+ # negative_prompt,
233
+ # seed,
234
+ # randomize_seed,
235
+ # # width,
236
+ # # height,
237
+ # # guidance_scale,
238
+ # # num_inference_steps,
239
+ # ],
240
+ # outputs=[result, seed],
241
+ # )
242
+
243
+ # if __name__ == "__main__":
244
+ # demo.launch()
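
For reference, the following is a minimal sketch of the inference path that the new app.py wires into Gradio, runnable on its own. It only uses calls that appear in the diff above and mirrors the float32 loading hard-coded there; the input path "input.jpg" is a placeholder, not a file from this repo.

```python
# Minimal sketch of the app.py inference path (not part of the commit).
import torch
from PIL import Image

from depthmaster import DepthMasterPipeline
from depthmaster.modules.unet_2d_condition import UNet2DConditionModel

device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "zysong212/DepthMaster"

# Load the pipeline and swap in the customized UNet, exactly as app.py does.
pipe = DepthMasterPipeline.from_pretrained(model_repo_id, torch_dtype=torch.float32)
pipe.unet = UNet2DConditionModel.from_pretrained(
    model_repo_id, subfolder="unet", torch_dtype=torch.float32
)
pipe = pipe.to(device)

# "input.jpg" is a placeholder path; any RGB image works.
image = Image.open("input.jpg").convert("RGB")
pipe_out = pipe(
    image,
    processing_res=768,
    match_input_res=True,
    batch_size=1,
    color_map="Spectral",
    show_progress_bar=True,
    resample_method="bilinear",
)
pipe_out.depth_colored.save("depth_colored.png")
```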
depthmaster/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ # Last modified: 2025-01-14
2
+ #
3
+ # Copyright 2025 Ziyang Song, USTC. All rights reserved.
4
+ #
5
+ # This file has been modified from the original version.
6
+ # Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ # --------------------------------------------------------------------------
20
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
21
+ # Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
22
+ # More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
23
+ # --------------------------------------------------------------------------
24
+
25
+
26
+ from .depthmaster_pipeline import DepthMasterPipeline, DepthMasterDepthOutput # noqa: F401
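
The re-export above is what lets app.py write `from depthmaster import DepthMasterPipeline`. A quick sanity check of that wiring, assuming the package directory is on the Python path:

```python
# Sanity check of the package re-export (not part of the commit).
from depthmaster import DepthMasterPipeline, DepthMasterDepthOutput
from depthmaster.depthmaster_pipeline import DepthMasterPipeline as _Direct

assert DepthMasterPipeline is _Direct  # __init__.py simply re-exports the class
print(DepthMasterDepthOutput.__name__)
```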
depthmaster/depthmaster_pipeline.py ADDED
@@ -0,0 +1,387 @@
1
+ # Last modified: 2025-01-14
2
+ #
3
+ # Copyright 2025 Ziyang Song, USTC. All rights reserved.
4
+ #
5
+ # This file has been modified from the original version.
6
+ # Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ # --------------------------------------------------------------------------
20
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
21
+ # Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
22
+ # More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
23
+ # --------------------------------------------------------------------------
24
+
25
+
26
+ import logging
27
+ from typing import Dict, Optional, Union
28
+
29
+ import numpy as np
30
+ import torch
31
+ from diffusers import (
32
+ AutoencoderKL,
33
+ DiffusionPipeline,
34
+ # UNet2DConditionModel,
35
+ )
36
+ from depthmaster.modules.unet_2d_condition import UNet2DConditionModel
37
+ from diffusers.utils import BaseOutput
38
+ from PIL import Image
39
+ from torch.utils.data import DataLoader, TensorDataset
40
+ from torchvision.transforms import InterpolationMode
41
+ from torchvision.transforms.functional import pil_to_tensor, resize
42
+ from tqdm.auto import tqdm
43
+ from transformers import CLIPTextModel, CLIPTokenizer
44
+
45
+ from .util.image_util import (
46
+ chw2hwc,
47
+ colorize_depth_maps,
48
+ get_tv_resample_method,
49
+ resize_max_res,
50
+ )
51
+
52
+ class DepthMasterDepthOutput(BaseOutput):
53
+ """
54
+ Output class for monocular depth prediction pipeline.
55
+
56
+ Args:
57
+ depth_np (`np.ndarray`):
58
+ Predicted depth map, with depth values in the range of [0, 1].
59
+ depth_colored (`PIL.Image.Image`):
60
+ Colorized depth map, with the shape of [3, H, W] and values in [0, 1].
61
+ uncertainty (`None` or `np.ndarray`):
62
+ Uncalibrated uncertainty (MAD, median absolute deviation) coming from ensembling.
63
+ """
64
+
65
+ depth_np: np.ndarray
66
+ depth_colored: Union[None, Image.Image]
67
+ uncertainty: Union[None, np.ndarray]
68
+
69
+
70
+ class DepthMasterPipeline(DiffusionPipeline):
71
+ """
72
+ Pipeline for monocular depth estimation using DepthMaster.
73
+
74
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
75
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
76
+
77
+ Args:
78
+ unet (`UNet2DConditionModel`):
79
+ Conditional U-Net to denoise the depth latent, conditioned on image latent.
80
+ vae (`AutoencoderKL`):
81
+ Variational Auto-Encoder (VAE) Model to encode and decode images and depth maps
82
+ to and from latent representations.
83
+ scheduler (`DDIMScheduler`):
84
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
85
+ text_encoder (`CLIPTextModel`):
86
+ Text-encoder, for empty text embedding.
87
+ tokenizer (`CLIPTokenizer`):
88
+ CLIP tokenizer.
89
+ scale_invariant (`bool`, *optional*):
90
+ A model property specifying whether the predicted depth maps are scale-invariant. This value must be set in
91
+ the model config. When used together with the `shift_invariant=True` flag, the model is also called
92
+ "affine-invariant". NB: overriding this value is not supported.
93
+ shift_invariant (`bool`, *optional*):
94
+ A model property specifying whether the predicted depth maps are shift-invariant. This value must be set in
95
+ the model config. When used together with the `scale_invariant=True` flag, the model is also called
96
+ "affine-invariant". NB: overriding this value is not supported.
97
+ default_denoising_steps (`int`, *optional*):
98
+ The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable
99
+ quality with the given model. This value must be set in the model config. When the pipeline is called
100
+ without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure
101
+ reasonable results with various model flavors compatible with the pipeline, such as those relying on very
102
+ short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`).
103
+ default_processing_resolution (`int`, *optional*):
104
+ The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
105
+ the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
106
+ default value is used. This is required to ensure reasonable results with various model flavors trained
107
+ with varying optimal processing resolution values.
108
+ """
109
+
110
+ rgb_latent_scale_factor = 0.18215
111
+ depth_latent_scale_factor = 0.18215
112
+
113
+ def __init__(
114
+ self,
115
+ unet: UNet2DConditionModel,
116
+ vae: AutoencoderKL,
117
+ text_encoder: CLIPTextModel,
118
+ tokenizer: CLIPTokenizer,
119
+ scale_invariant: Optional[bool] = True,
120
+ shift_invariant: Optional[bool] = True,
121
+ default_processing_resolution: Optional[int] = None,
122
+ ):
123
+ super().__init__()
124
+
125
+ # unet = UNet2DConditionModel.from_pretrained('/zssd/szy/Marigold_rgb2d/ckpt/eval/unet')
126
+
127
+ self.register_modules(
128
+ unet=unet,
129
+ vae=vae,
130
+ text_encoder=text_encoder,
131
+ tokenizer=tokenizer,
132
+ )
133
+ self.register_to_config(
134
+ scale_invariant=scale_invariant,
135
+ shift_invariant=shift_invariant,
136
+ default_processing_resolution=default_processing_resolution,
137
+ )
138
+
139
+ self.scale_invariant = scale_invariant
140
+ self.shift_invariant = shift_invariant
141
+ self.default_processing_resolution = default_processing_resolution
142
+
143
+ self.empty_text_embed = None
144
+
145
+ @torch.no_grad()
146
+ def __call__(
147
+ self,
148
+ input_image: Union[Image.Image, torch.Tensor],
149
+ processing_res: Optional[int] = None,
150
+ match_input_res: bool = True,
151
+ resample_method: str = "bilinear",
152
+ batch_size: int = 0,
153
+ color_map: str = "Spectral",
154
+ show_progress_bar: bool = True,
155
+ ) -> DepthMasterDepthOutput:
156
+ """
157
+ Function invoked when calling the pipeline.
158
+
159
+ Args:
160
+ input_image (`Image`):
161
+ Input RGB (or gray-scale) image.
162
+ processing_res (`int`, *optional*, defaults to `None`):
163
+ Effective processing resolution. When set to `0`, processes at the original image resolution. This
164
+ produces crisper predictions, but may also lead to the overall loss of global context. The default
165
+ value `None` resolves to the optimal value from the model config.
166
+ match_input_res (`bool`, *optional*, defaults to `True`):
167
+ Resize depth prediction to match input resolution.
168
+ Only valid if `processing_res` > 0.
169
+ resample_method: (`str`, *optional*, defaults to `bilinear`):
170
+ Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`, defaults to: `bilinear`.
171
+ batch_size (`int`, *optional*, defaults to `0`):
172
+ Inference batch size, no bigger than `num_ensemble`.
173
+ If set to 0, the script will automatically decide the proper batch size.
174
+ show_progress_bar (`bool`, *optional*, defaults to `True`):
175
+ Display a progress bar of diffusion denoising.
176
+ color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
177
+ Colormap used to colorize the depth map.
178
+ Returns:
179
+ `DepthMasterDepthOutput`: Output class for DepthMaster monocular depth prediction pipeline, including:
180
+ - **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
181
+ - **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and values in [0, 1], None if `color_map` is `None`
182
+ """
183
+ # Model-specific optimal default values leading to fast and reasonable results.
184
+ if processing_res is None:
185
+ processing_res = self.default_processing_resolution
186
+
187
+ assert processing_res >= 0
188
+
189
+ resample_method: InterpolationMode = get_tv_resample_method(resample_method)
190
+
191
+ # ----------------- Image Preprocess -----------------
192
+ # Convert to torch tensor
193
+ if isinstance(input_image, Image.Image):
194
+ input_image = input_image.convert("RGB")
195
+ # convert to torch tensor [H, W, rgb] -> [rgb, H, W]
196
+ rgb = pil_to_tensor(input_image)
197
+ rgb = rgb.unsqueeze(0) # [1, rgb, H, W]
198
+ elif isinstance(input_image, torch.Tensor):
199
+ rgb = input_image
200
+ else:
201
+ raise TypeError(f"Unknown input type: {type(input_image) = }")
202
+ input_size = rgb.shape
203
+ assert (
204
+ 4 == rgb.dim() and 3 == input_size[-3]
205
+ ), f"Wrong input shape {input_size}, expected [1, rgb, H, W]"
206
+ # --------------- Image Processing ------------------------
207
+ # Resize image
208
+ if processing_res > 0:
209
+ rgb = resize_max_res(
210
+ rgb,
211
+ max_edge_resolution=processing_res,
212
+ resample_method=resample_method,
213
+ )
214
+
215
+ # Normalize rgb values
216
+ rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
217
+ rgb_norm = rgb_norm.to(self.dtype)
218
+ assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
219
+
220
+ # ----------------- Predicting depth -----------------
221
+ # Batch repeated input image
222
+ duplicated_rgb = rgb_norm.expand(1, -1, -1, -1)
223
+ single_rgb_dataset = TensorDataset(duplicated_rgb)
224
+ # find the batch size
225
+ if batch_size > 0:
226
+ _bs = batch_size
227
+ else:
228
+ _bs = 1
229
+
230
+ single_rgb_loader = DataLoader(
231
+ single_rgb_dataset, batch_size=_bs, shuffle=False
232
+ )
233
+
234
+ # Predict depth maps (batched)
235
+ depth_pred_ls = []
236
+ if show_progress_bar:
237
+ iterable = tqdm(
238
+ single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
239
+ )
240
+ else:
241
+ iterable = single_rgb_loader
242
+ for batch in iterable:
243
+ (batched_img,) = batch # image batch, already normalized to [-1, 1]
244
+ depth_pred_raw = self.single_infer(
245
+ rgb_in=batched_img,
246
+ )
247
+ depth_pred_ls.append(depth_pred_raw.detach())
248
+ depth_preds = torch.concat(depth_pred_ls, dim=0)
249
+ torch.cuda.empty_cache() # clear vram cache for ensembling
250
+
251
+ depth_pred = depth_preds
252
+ pred_uncert = None
253
+
254
+ # Resize back to original resolution
255
+ if match_input_res:
256
+ depth_pred = resize(
257
+ depth_pred,
258
+ input_size[-2:],
259
+ interpolation=resample_method,
260
+ antialias=True,
261
+ )
262
+
263
+ # Convert to numpy
264
+ depth_pred = depth_pred.squeeze()
265
+ depth_pred = depth_pred.cpu().numpy()
266
+ if pred_uncert is not None:
267
+ pred_uncert = pred_uncert.squeeze().cpu().numpy()
268
+
269
+ # Clip output range
270
+ depth_pred = depth_pred.clip(0, 1)
271
+
272
+ # Colorize
273
+ if color_map is not None:
274
+ depth_colored = colorize_depth_maps(
275
+ depth_pred, 0, 1, cmap=color_map
276
+ ).squeeze() # [3, H, W], value in (0, 1)
277
+ depth_colored = (depth_colored * 255).astype(np.uint8)
278
+ depth_colored_hwc = chw2hwc(depth_colored)
279
+ depth_colored_img = Image.fromarray(depth_colored_hwc)
280
+ else:
281
+ depth_colored_img = None
282
+
283
+ return DepthMasterDepthOutput(
284
+ depth_np=depth_pred,
285
+ depth_colored=depth_colored_img,
286
+ uncertainty=pred_uncert,
287
+ )
288
+
289
+
290
+ def encode_empty_text(self):
291
+ """
292
+ Encode text embedding for empty prompt
293
+ """
294
+ prompt = ""
295
+ text_inputs = self.tokenizer(
296
+ prompt,
297
+ padding="do_not_pad",
298
+ max_length=self.tokenizer.model_max_length,
299
+ truncation=True,
300
+ return_tensors="pt",
301
+ )
302
+ text_input_ids = text_inputs.input_ids.to(self.text_encoder.device) #[1,2]
303
+ self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype) #[1,2,1024]
304
+
305
+ @torch.no_grad()
306
+ def single_infer(
307
+ self,
308
+ rgb_in: torch.Tensor,
309
+ ) -> torch.Tensor:
310
+ """
311
+ Perform an individual depth prediction without ensembling.
312
+
313
+ Args:
314
+ rgb_in (`torch.Tensor`):
315
+ Input RGB image.
316
+ Returns:
317
+ `torch.Tensor`: Predicted depth map.
318
+ """
319
+ device = self.device
320
+ rgb_in = rgb_in.to(device)
321
+
322
+ # Encode image
323
+ rgb_latent = self.encode_rgb(rgb_in) # latent at 1/8 resolution with 4 channels
324
+
325
+
326
+ # Batched empty text embedding
327
+ if self.empty_text_embed is None:
328
+ self.encode_empty_text()
329
+ batch_empty_text_embed = self.empty_text_embed.repeat(
330
+ (rgb_latent.shape[0], 1, 1)
331
+ ).to(device) # [B, 2, 1024]
332
+
333
+
334
+ unet_output = self.unet(
335
+ rgb_latent,
336
+ 1,
337
+ encoder_hidden_states=batch_empty_text_embed,
338
+ ).sample # [B, 4, h, w]
339
+
340
+ torch.cuda.empty_cache()
341
+ depth = self.decode_depth(unet_output) # [B, 1, h, w]
342
+
343
+ # clip prediction
344
+ depth = torch.clip(depth, -1.0, 1.0)
345
+ # shift to [0, 1]
346
+ depth = (depth + 1.0) / 2.0
347
+
348
+ return depth
349
+
350
+ def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
351
+ """
352
+ Encode RGB image into latent.
353
+
354
+ Args:
355
+ rgb_in (`torch.Tensor`):
356
+ Input RGB image to be encoded.
357
+
358
+ Returns:
359
+ `torch.Tensor`: Image latent.
360
+ """
361
+ # encode
362
+ h = self.vae.encoder(rgb_in)
363
+ moments = self.vae.quant_conv(h)
364
+ mean, logvar = torch.chunk(moments, 2, dim=1)
365
+ # scale latent
366
+ rgb_latent = mean * self.rgb_latent_scale_factor
367
+ return rgb_latent
368
+
369
+ def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
370
+ """
371
+ Decode depth latent into depth map.
372
+
373
+ Args:
374
+ depth_latent (`torch.Tensor`):
375
+ Depth latent to be decoded.
376
+
377
+ Returns:
378
+ `torch.Tensor`: Decoded depth map.
379
+ """
380
+ # scale latent
381
+ depth_latent = depth_latent / self.depth_latent_scale_factor
382
+ # decode
383
+ z = self.vae.post_quant_conv(depth_latent)
384
+ stacked = self.vae.decoder(z)
385
+ # mean of output channels
386
+ depth_mean = stacked.mean(dim=1, keepdim=True)
387
+ return depth_mean
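
The pipeline returns a `DepthMasterDepthOutput` whose `depth_np` is a float map in [0, 1] and whose `depth_colored` is a PIL image (or `None` when `color_map=None`). A small helper for persisting a prediction is sketched below; storing relative depth as a 16-bit PNG is a common convention assumed here for illustration, not something this file prescribes.

```python
# Illustrative helper (not part of the commit) for saving a prediction.
import numpy as np
from PIL import Image

from depthmaster import DepthMasterDepthOutput


def save_prediction(pipe_out: DepthMasterDepthOutput, stem: str = "depth") -> None:
    """Write the outputs of DepthMasterPipeline.__call__ to disk."""
    depth = pipe_out.depth_np  # float array in [0, 1], matching the input resolution
    # Assumption: keep quantization error small by storing relative depth as 16-bit PNG.
    depth_u16 = (depth * 65535.0).round().astype(np.uint16)
    Image.fromarray(depth_u16).save(f"{stem}_16bit.png")
    if pipe_out.depth_colored is not None:  # None when color_map=None was requested
        pipe_out.depth_colored.save(f"{stem}_colored.png")
```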
depthmaster/modules/unet_2d_blocks.py ADDED
The diff for this file is too large to render. See raw diff
 
depthmaster/modules/unet_2d_condition.py ADDED
@@ -0,0 +1,1322 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.utils.checkpoint
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
23
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
24
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
25
+ from diffusers.models.activations import get_activation
26
+ from diffusers.models.attention_processor import (
27
+ ADDED_KV_ATTENTION_PROCESSORS,
28
+ CROSS_ATTENTION_PROCESSORS,
29
+ Attention,
30
+ AttentionProcessor,
31
+ AttnAddedKVProcessor,
32
+ AttnProcessor,
33
+ FusedAttnProcessor2_0,
34
+ )
35
+ from diffusers.models.embeddings import (
36
+ GaussianFourierProjection,
37
+ GLIGENTextBoundingboxProjection,
38
+ ImageHintTimeEmbedding,
39
+ ImageProjection,
40
+ ImageTimeEmbedding,
41
+ TextImageProjection,
42
+ TextImageTimeEmbedding,
43
+ TextTimeEmbedding,
44
+ TimestepEmbedding,
45
+ Timesteps,
46
+ )
47
+ from diffusers.models.modeling_utils import ModelMixin
48
+ from depthmaster.modules.unet_2d_blocks import (
49
+ get_down_block,
50
+ get_mid_block,
51
+ get_up_block,
52
+ BlockFE,
53
+ )
54
+
55
+
56
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
57
+
58
+
59
+ @dataclass
60
+ class UNet2DConditionOutput(BaseOutput):
61
+ """
62
+ The output of [`UNet2DConditionModel`].
63
+
64
+ Args:
65
+ sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
66
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
67
+ """
68
+
69
+ sample: torch.Tensor = None
70
+ feat_64: torch.Tensor = None
71
+
72
+
73
+ class UNet2DConditionModel(
74
+ ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin
75
+ ):
76
+ r"""
77
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
78
+ shaped output.
79
+
80
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
81
+ for all models (such as downloading or saving).
82
+
83
+ Parameters:
84
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
85
+ Height and width of input/output sample.
86
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
87
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
88
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
89
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
90
+ Whether to flip the sin to cos in the time embedding.
91
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
92
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
93
+ The tuple of downsample blocks to use.
94
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
95
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
96
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
97
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
98
+ The tuple of upsample blocks to use.
99
+ only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
100
+ Whether to include self-attention in the basic transformer blocks, see
101
+ [`~models.attention.BasicTransformerBlock`].
102
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
103
+ The tuple of output channels for each block.
104
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
105
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
106
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
107
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
108
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
109
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
110
+ If `None`, normalization and activation layers is skipped in post-processing.
111
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
112
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
113
+ The dimension of the cross attention features.
114
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
115
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
116
+ [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`],
117
+ [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
118
+ reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
119
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
120
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
121
+ [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`],
122
+ [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
123
+ encoder_hid_dim (`int`, *optional*, defaults to None):
124
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
125
+ dimension to `cross_attention_dim`.
126
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
127
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
128
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
129
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
130
+ num_attention_heads (`int`, *optional*):
131
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
132
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
133
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
134
+ class_embed_type (`str`, *optional*, defaults to `None`):
135
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
136
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
137
+ addition_embed_type (`str`, *optional*, defaults to `None`):
138
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
139
+ "text". "text" will use the `TextTimeEmbedding` layer.
140
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
141
+ Dimension for the timestep embeddings.
142
+ num_class_embeds (`int`, *optional*, defaults to `None`):
143
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
144
+ class conditioning with `class_embed_type` equal to `None`.
145
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
146
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
147
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
148
+ An optional override for the dimension of the projected time embedding.
149
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
150
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
151
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
152
+ timestep_post_act (`str`, *optional*, defaults to `None`):
153
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
154
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
155
+ The dimension of `cond_proj` layer in the timestep embedding.
156
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
157
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
158
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
159
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
160
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
161
+ embeddings with the class embeddings.
162
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
163
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
164
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
165
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
166
+ otherwise.
167
+ """
168
+
169
+ _supports_gradient_checkpointing = True
170
+ _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
171
+
172
+ @register_to_config
173
+ def __init__(
174
+ self,
175
+ sample_size: Optional[int] = None,
176
+ in_channels: int = 4,
177
+ out_channels: int = 4,
178
+ center_input_sample: bool = False,
179
+ flip_sin_to_cos: bool = True,
180
+ freq_shift: int = 0,
181
+ down_block_types: Tuple[str] = (
182
+ "CrossAttnDownBlock2D",
183
+ "CrossAttnDownBlock2D",
184
+ "CrossAttnDownBlock2D",
185
+ "DownBlock2D",
186
+ ),
187
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
188
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
189
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
190
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
191
+ layers_per_block: Union[int, Tuple[int]] = 2,
192
+ downsample_padding: int = 1,
193
+ mid_block_scale_factor: float = 1,
194
+ dropout: float = 0.0,
195
+ act_fn: str = "silu",
196
+ norm_num_groups: Optional[int] = 32,
197
+ norm_eps: float = 1e-5,
198
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
199
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
200
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
201
+ encoder_hid_dim: Optional[int] = None,
202
+ encoder_hid_dim_type: Optional[str] = None,
203
+ attention_head_dim: Union[int, Tuple[int]] = 8,
204
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
205
+ dual_cross_attention: bool = False,
206
+ use_linear_projection: bool = False,
207
+ class_embed_type: Optional[str] = None,
208
+ addition_embed_type: Optional[str] = None,
209
+ addition_time_embed_dim: Optional[int] = None,
210
+ num_class_embeds: Optional[int] = None,
211
+ upcast_attention: bool = False,
212
+ resnet_time_scale_shift: str = "default",
213
+ resnet_skip_time_act: bool = False,
214
+ resnet_out_scale_factor: float = 1.0,
215
+ time_embedding_type: str = "positional",
216
+ time_embedding_dim: Optional[int] = None,
217
+ time_embedding_act_fn: Optional[str] = None,
218
+ timestep_post_act: Optional[str] = None,
219
+ time_cond_proj_dim: Optional[int] = None,
220
+ conv_in_kernel: int = 3,
221
+ conv_out_kernel: int = 3,
222
+ projection_class_embeddings_input_dim: Optional[int] = None,
223
+ attention_type: str = "default",
224
+ class_embeddings_concat: bool = False,
225
+ mid_block_only_cross_attention: Optional[bool] = None,
226
+ cross_attention_norm: Optional[str] = None,
227
+ addition_embed_type_num_heads: int = 64,
228
+ ):
229
+ super().__init__()
230
+ # print('loaded correct file')
231
+
232
+ self.sample_size = sample_size
233
+
234
+ if num_attention_heads is not None:
235
+ raise ValueError(
236
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
237
+ )
238
+
239
+ # If `num_attention_heads` is not defined (which is the case for most models)
240
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
241
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
242
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
243
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
244
+ # which is why we correct for the naming here.
245
+ num_attention_heads = num_attention_heads or attention_head_dim
246
+
247
+ # Check inputs
248
+ self._check_config(
249
+ down_block_types=down_block_types,
250
+ up_block_types=up_block_types,
251
+ only_cross_attention=only_cross_attention,
252
+ block_out_channels=block_out_channels,
253
+ layers_per_block=layers_per_block,
254
+ cross_attention_dim=cross_attention_dim,
255
+ transformer_layers_per_block=transformer_layers_per_block,
256
+ reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
257
+ attention_head_dim=attention_head_dim,
258
+ num_attention_heads=num_attention_heads,
259
+ )
260
+
261
+ # input
262
+ conv_in_padding = (conv_in_kernel - 1) // 2
263
+ self.conv_in = nn.Conv2d(
264
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
265
+ )
266
+
267
+ # time
268
+ time_embed_dim, timestep_input_dim = self._set_time_proj(
269
+ time_embedding_type,
270
+ block_out_channels=block_out_channels,
271
+ flip_sin_to_cos=flip_sin_to_cos,
272
+ freq_shift=freq_shift,
273
+ time_embedding_dim=time_embedding_dim,
274
+ )
275
+
276
+ self.time_embedding = TimestepEmbedding(
277
+ timestep_input_dim,
278
+ time_embed_dim,
279
+ act_fn=act_fn,
280
+ post_act_fn=timestep_post_act,
281
+ cond_proj_dim=time_cond_proj_dim,
282
+ )
283
+
284
+ self._set_encoder_hid_proj(
285
+ encoder_hid_dim_type,
286
+ cross_attention_dim=cross_attention_dim,
287
+ encoder_hid_dim=encoder_hid_dim,
288
+ )
289
+
290
+ # class embedding
291
+ self._set_class_embedding(
292
+ class_embed_type,
293
+ act_fn=act_fn,
294
+ num_class_embeds=num_class_embeds,
295
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
296
+ time_embed_dim=time_embed_dim,
297
+ timestep_input_dim=timestep_input_dim,
298
+ )
299
+
300
+ self._set_add_embedding(
301
+ addition_embed_type,
302
+ addition_embed_type_num_heads=addition_embed_type_num_heads,
303
+ addition_time_embed_dim=addition_time_embed_dim,
304
+ cross_attention_dim=cross_attention_dim,
305
+ encoder_hid_dim=encoder_hid_dim,
306
+ flip_sin_to_cos=flip_sin_to_cos,
307
+ freq_shift=freq_shift,
308
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
309
+ time_embed_dim=time_embed_dim,
310
+ )
311
+
312
+ if time_embedding_act_fn is None:
313
+ self.time_embed_act = None
314
+ else:
315
+ self.time_embed_act = get_activation(time_embedding_act_fn)
316
+
317
+ self.down_blocks = nn.ModuleList([])
318
+ self.up_blocks = nn.ModuleList([])
319
+
320
+ if isinstance(only_cross_attention, bool):
321
+ if mid_block_only_cross_attention is None:
322
+ mid_block_only_cross_attention = only_cross_attention
323
+
324
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
325
+
326
+ if mid_block_only_cross_attention is None:
327
+ mid_block_only_cross_attention = False
328
+
329
+ if isinstance(num_attention_heads, int):
330
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
331
+
332
+ if isinstance(attention_head_dim, int):
333
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
334
+
335
+ if isinstance(cross_attention_dim, int):
336
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
337
+
338
+ if isinstance(layers_per_block, int):
339
+ layers_per_block = [layers_per_block] * len(down_block_types)
340
+
341
+ if isinstance(transformer_layers_per_block, int):
342
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
343
+
344
+ if class_embeddings_concat:
345
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
346
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
347
+ # regular time embeddings
348
+ blocks_time_embed_dim = time_embed_dim * 2
349
+ else:
350
+ blocks_time_embed_dim = time_embed_dim
351
+
352
+ # down
353
+ output_channel = block_out_channels[0]
354
+ for i, down_block_type in enumerate(down_block_types):
355
+ input_channel = output_channel
356
+ output_channel = block_out_channels[i]
357
+ is_final_block = i == len(block_out_channels) - 1
358
+
359
+ down_block = get_down_block(
360
+ down_block_type,
361
+ num_layers=layers_per_block[i],
362
+ transformer_layers_per_block=transformer_layers_per_block[i],
363
+ in_channels=input_channel,
364
+ out_channels=output_channel,
365
+ temb_channels=blocks_time_embed_dim,
366
+ add_downsample=not is_final_block,
367
+ resnet_eps=norm_eps,
368
+ resnet_act_fn=act_fn,
369
+ resnet_groups=norm_num_groups,
370
+ cross_attention_dim=cross_attention_dim[i],
371
+ num_attention_heads=num_attention_heads[i],
372
+ downsample_padding=downsample_padding,
373
+ dual_cross_attention=dual_cross_attention,
374
+ use_linear_projection=use_linear_projection,
375
+ only_cross_attention=only_cross_attention[i],
376
+ upcast_attention=upcast_attention,
377
+ resnet_time_scale_shift=resnet_time_scale_shift,
378
+ attention_type=attention_type,
379
+ resnet_skip_time_act=resnet_skip_time_act,
380
+ resnet_out_scale_factor=resnet_out_scale_factor,
381
+ cross_attention_norm=cross_attention_norm,
382
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
383
+ dropout=dropout,
384
+ )
385
+ self.down_blocks.append(down_block)
386
+
387
+ # mid
388
+ self.mid_block = get_mid_block(
389
+ mid_block_type,
390
+ temb_channels=blocks_time_embed_dim,
391
+ in_channels=block_out_channels[-1],
392
+ resnet_eps=norm_eps,
393
+ resnet_act_fn=act_fn,
394
+ resnet_groups=norm_num_groups,
395
+ output_scale_factor=mid_block_scale_factor,
396
+ transformer_layers_per_block=transformer_layers_per_block[-1],
397
+ num_attention_heads=num_attention_heads[-1],
398
+ cross_attention_dim=cross_attention_dim[-1],
399
+ dual_cross_attention=dual_cross_attention,
400
+ use_linear_projection=use_linear_projection,
401
+ mid_block_only_cross_attention=mid_block_only_cross_attention,
402
+ upcast_attention=upcast_attention,
403
+ resnet_time_scale_shift=resnet_time_scale_shift,
404
+ attention_type=attention_type,
405
+ resnet_skip_time_act=resnet_skip_time_act,
406
+ cross_attention_norm=cross_attention_norm,
407
+ attention_head_dim=attention_head_dim[-1],
408
+ dropout=dropout,
409
+ )
410
+
411
+ self.fftblock = BlockFE()
412
+
413
+ # count how many layers upsample the images
414
+ self.num_upsamplers = 0
415
+
416
+ # up
417
+ reversed_block_out_channels = list(reversed(block_out_channels))
418
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
419
+ reversed_layers_per_block = list(reversed(layers_per_block))
420
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
421
+ reversed_transformer_layers_per_block = (
422
+ list(reversed(transformer_layers_per_block))
423
+ if reverse_transformer_layers_per_block is None
424
+ else reverse_transformer_layers_per_block
425
+ )
426
+ only_cross_attention = list(reversed(only_cross_attention))
427
+
428
+ output_channel = reversed_block_out_channels[0]
429
+ for i, up_block_type in enumerate(up_block_types):
430
+ is_final_block = i == len(block_out_channels) - 1
431
+
432
+ prev_output_channel = output_channel
433
+ output_channel = reversed_block_out_channels[i]
434
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
435
+
436
+ # add upsample block for all BUT final layer
437
+ if not is_final_block:
438
+ add_upsample = True
439
+ self.num_upsamplers += 1
440
+ else:
441
+ add_upsample = False
442
+
443
+ up_block = get_up_block(
444
+ up_block_type,
445
+ num_layers=reversed_layers_per_block[i] + 1,
446
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
447
+ in_channels=input_channel,
448
+ out_channels=output_channel,
449
+ prev_output_channel=prev_output_channel,
450
+ temb_channels=blocks_time_embed_dim,
451
+ add_upsample=add_upsample,
452
+ resnet_eps=norm_eps,
453
+ resnet_act_fn=act_fn,
454
+ resolution_idx=i,
455
+ resnet_groups=norm_num_groups,
456
+ cross_attention_dim=reversed_cross_attention_dim[i],
457
+ num_attention_heads=reversed_num_attention_heads[i],
458
+ dual_cross_attention=dual_cross_attention,
459
+ use_linear_projection=use_linear_projection,
460
+ only_cross_attention=only_cross_attention[i],
461
+ upcast_attention=upcast_attention,
462
+ resnet_time_scale_shift=resnet_time_scale_shift,
463
+ attention_type=attention_type,
464
+ resnet_skip_time_act=resnet_skip_time_act,
465
+ resnet_out_scale_factor=resnet_out_scale_factor,
466
+ cross_attention_norm=cross_attention_norm,
467
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
468
+ dropout=dropout,
469
+ )
470
+ self.up_blocks.append(up_block)
471
+ prev_output_channel = output_channel
472
+
473
+ # out
474
+ if norm_num_groups is not None:
475
+ self.conv_norm_out = nn.GroupNorm(
476
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
477
+ )
478
+
479
+ self.conv_act = get_activation(act_fn)
480
+
481
+ else:
482
+ self.conv_norm_out = None
483
+ self.conv_act = None
484
+
485
+ conv_out_padding = (conv_out_kernel - 1) // 2
486
+ self.conv_out = nn.Conv2d(
487
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
488
+ )
489
+
490
+ self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)
491
+
492
+ def _check_config(
493
+ self,
494
+ down_block_types: Tuple[str],
495
+ up_block_types: Tuple[str],
496
+ only_cross_attention: Union[bool, Tuple[bool]],
497
+ block_out_channels: Tuple[int],
498
+ layers_per_block: Union[int, Tuple[int]],
499
+ cross_attention_dim: Union[int, Tuple[int]],
500
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
501
+ reverse_transformer_layers_per_block: bool,
502
+ attention_head_dim: int,
503
+ num_attention_heads: Optional[Union[int, Tuple[int]]],
504
+ ):
505
+ if len(down_block_types) != len(up_block_types):
506
+ raise ValueError(
507
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
508
+ )
509
+
510
+ if len(block_out_channels) != len(down_block_types):
511
+ raise ValueError(
512
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
513
+ )
514
+
515
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
516
+ raise ValueError(
517
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
518
+ )
519
+
520
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
521
+ raise ValueError(
522
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
523
+ )
524
+
525
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
526
+ raise ValueError(
527
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
528
+ )
529
+
530
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
531
+ raise ValueError(
532
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
533
+ )
534
+
535
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
536
+ raise ValueError(
537
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
538
+ )
539
+ if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
540
+ for layer_number_per_block in transformer_layers_per_block:
541
+ if isinstance(layer_number_per_block, list):
542
+ raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
543
+
544
+ def _set_time_proj(
545
+ self,
546
+ time_embedding_type: str,
547
+ block_out_channels: int,
548
+ flip_sin_to_cos: bool,
549
+ freq_shift: float,
550
+ time_embedding_dim: int,
551
+ ) -> Tuple[int, int]:
552
+ if time_embedding_type == "fourier":
553
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
554
+ if time_embed_dim % 2 != 0:
555
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
556
+ self.time_proj = GaussianFourierProjection(
557
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
558
+ )
559
+ timestep_input_dim = time_embed_dim
560
+ elif time_embedding_type == "positional":
561
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
562
+
563
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
564
+ timestep_input_dim = block_out_channels[0]
565
+ else:
566
+ raise ValueError(
567
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
568
+ )
569
+
570
+ return time_embed_dim, timestep_input_dim
571
+
572
+ def _set_encoder_hid_proj(
573
+ self,
574
+ encoder_hid_dim_type: Optional[str],
575
+ cross_attention_dim: Union[int, Tuple[int]],
576
+ encoder_hid_dim: Optional[int],
577
+ ):
578
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
579
+ encoder_hid_dim_type = "text_proj"
580
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
581
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
582
+
583
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
584
+ raise ValueError(
585
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
586
+ )
587
+
588
+ if encoder_hid_dim_type == "text_proj":
589
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
590
+ elif encoder_hid_dim_type == "text_image_proj":
591
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
592
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
593
+ # case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1)
594
+ self.encoder_hid_proj = TextImageProjection(
595
+ text_embed_dim=encoder_hid_dim,
596
+ image_embed_dim=cross_attention_dim,
597
+ cross_attention_dim=cross_attention_dim,
598
+ )
599
+ elif encoder_hid_dim_type == "image_proj":
600
+ # Kandinsky 2.2
601
+ self.encoder_hid_proj = ImageProjection(
602
+ image_embed_dim=encoder_hid_dim,
603
+ cross_attention_dim=cross_attention_dim,
604
+ )
605
+ elif encoder_hid_dim_type is not None:
606
+ raise ValueError(
607
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
608
+ )
609
+ else:
610
+ self.encoder_hid_proj = None
611
+
612
+ def _set_class_embedding(
613
+ self,
614
+ class_embed_type: Optional[str],
615
+ act_fn: str,
616
+ num_class_embeds: Optional[int],
617
+ projection_class_embeddings_input_dim: Optional[int],
618
+ time_embed_dim: int,
619
+ timestep_input_dim: int,
620
+ ):
621
+ if class_embed_type is None and num_class_embeds is not None:
622
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
623
+ elif class_embed_type == "timestep":
624
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
625
+ elif class_embed_type == "identity":
626
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
627
+ elif class_embed_type == "projection":
628
+ if projection_class_embeddings_input_dim is None:
629
+ raise ValueError(
630
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
631
+ )
632
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
633
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
634
+ # 2. it projects from an arbitrary input dimension.
635
+ #
636
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
637
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
638
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
639
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
640
+ elif class_embed_type == "simple_projection":
641
+ if projection_class_embeddings_input_dim is None:
642
+ raise ValueError(
643
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
644
+ )
645
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
646
+ else:
647
+ self.class_embedding = None
648
+
649
+ def _set_add_embedding(
650
+ self,
651
+ addition_embed_type: str,
652
+ addition_embed_type_num_heads: int,
653
+ addition_time_embed_dim: Optional[int],
654
+ flip_sin_to_cos: bool,
655
+ freq_shift: float,
656
+ cross_attention_dim: Optional[int],
657
+ encoder_hid_dim: Optional[int],
658
+ projection_class_embeddings_input_dim: Optional[int],
659
+ time_embed_dim: int,
660
+ ):
661
+ if addition_embed_type == "text":
662
+ if encoder_hid_dim is not None:
663
+ text_time_embedding_from_dim = encoder_hid_dim
664
+ else:
665
+ text_time_embedding_from_dim = cross_attention_dim
666
+
667
+ self.add_embedding = TextTimeEmbedding(
668
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
669
+ )
670
+ elif addition_embed_type == "text_image":
671
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
672
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
673
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
674
+ self.add_embedding = TextImageTimeEmbedding(
675
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
676
+ )
677
+ elif addition_embed_type == "text_time":
678
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
679
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
680
+ elif addition_embed_type == "image":
681
+ # Kandinsky 2.2
682
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
683
+ elif addition_embed_type == "image_hint":
684
+ # Kandinsky 2.2 ControlNet
685
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
686
+ elif addition_embed_type is not None:
687
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
688
+
689
+ def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
690
+ if attention_type in ["gated", "gated-text-image"]:
691
+ positive_len = 768
692
+ if isinstance(cross_attention_dim, int):
693
+ positive_len = cross_attention_dim
694
+ elif isinstance(cross_attention_dim, (list, tuple)):
695
+ positive_len = cross_attention_dim[0]
696
+
697
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
698
+ self.position_net = GLIGENTextBoundingboxProjection(
699
+ positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
700
+ )
701
+
702
+ @property
703
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
704
+ r"""
705
+ Returns:
706
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
707
+ indexed by their weight names.
708
+ """
709
+ # set recursively
710
+ processors = {}
711
+
712
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
713
+ if hasattr(module, "get_processor"):
714
+ processors[f"{name}.processor"] = module.get_processor()
715
+
716
+ for sub_name, child in module.named_children():
717
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
718
+
719
+ return processors
720
+
721
+ for name, module in self.named_children():
722
+ fn_recursive_add_processors(name, module, processors)
723
+
724
+ return processors
725
+
726
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
727
+ r"""
728
+ Sets the attention processor to use to compute attention.
729
+
730
+ Parameters:
731
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
732
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
733
+ for **all** `Attention` layers.
734
+
735
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
736
+ processor. This is strongly recommended when setting trainable attention processors.
737
+
738
+ """
739
+ count = len(self.attn_processors.keys())
740
+
741
+ if isinstance(processor, dict) and len(processor) != count:
742
+ raise ValueError(
743
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
744
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
745
+ )
746
+
747
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
748
+ if hasattr(module, "set_processor"):
749
+ if not isinstance(processor, dict):
750
+ module.set_processor(processor)
751
+ else:
752
+ module.set_processor(processor.pop(f"{name}.processor"))
753
+
754
+ for sub_name, child in module.named_children():
755
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
756
+
757
+ for name, module in self.named_children():
758
+ fn_recursive_attn_processor(name, module, processor)
759
+
760
+ def set_default_attn_processor(self):
761
+ """
762
+ Disables custom attention processors and sets the default attention implementation.
763
+ """
764
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
765
+ processor = AttnAddedKVProcessor()
766
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
767
+ processor = AttnProcessor()
768
+ else:
769
+ raise ValueError(
770
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
771
+ )
772
+
773
+ self.set_attn_processor(processor)
774
+
775
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
776
+ r"""
777
+ Enable sliced attention computation.
778
+
779
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
780
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
781
+
782
+ Args:
783
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
784
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
785
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
786
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
787
+ must be a multiple of `slice_size`.
788
+ """
789
+ sliceable_head_dims = []
790
+
791
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
792
+ if hasattr(module, "set_attention_slice"):
793
+ sliceable_head_dims.append(module.sliceable_head_dim)
794
+
795
+ for child in module.children():
796
+ fn_recursive_retrieve_sliceable_dims(child)
797
+
798
+ # retrieve number of attention layers
799
+ for module in self.children():
800
+ fn_recursive_retrieve_sliceable_dims(module)
801
+
802
+ num_sliceable_layers = len(sliceable_head_dims)
803
+
804
+ if slice_size == "auto":
805
+ # half the attention head size is usually a good trade-off between
806
+ # speed and memory
807
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
808
+ elif slice_size == "max":
809
+ # make smallest slice possible
810
+ slice_size = num_sliceable_layers * [1]
811
+
812
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
813
+
814
+ if len(slice_size) != len(sliceable_head_dims):
815
+ raise ValueError(
816
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
817
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
818
+ )
819
+
820
+ for i in range(len(slice_size)):
821
+ size = slice_size[i]
822
+ dim = sliceable_head_dims[i]
823
+ if size is not None and size > dim:
824
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
825
+
826
+ # Recursively walk through all the children.
827
+ # Any children which exposes the set_attention_slice method
828
+ # gets the message
829
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
830
+ if hasattr(module, "set_attention_slice"):
831
+ module.set_attention_slice(slice_size.pop())
832
+
833
+ for child in module.children():
834
+ fn_recursive_set_attention_slice(child, slice_size)
835
+
836
+ reversed_slice_size = list(reversed(slice_size))
837
+ for module in self.children():
838
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
839
+
840
+ def _set_gradient_checkpointing(self, module, value=False):
841
+ if hasattr(module, "gradient_checkpointing"):
842
+ module.gradient_checkpointing = value
843
+
844
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
845
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
846
+
847
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
848
+
849
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
850
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
851
+
852
+ Args:
853
+ s1 (`float`):
854
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
855
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
856
+ s2 (`float`):
857
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
858
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
859
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
860
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
861
+ """
862
+ for i, upsample_block in enumerate(self.up_blocks):
863
+ setattr(upsample_block, "s1", s1)
864
+ setattr(upsample_block, "s2", s2)
865
+ setattr(upsample_block, "b1", b1)
866
+ setattr(upsample_block, "b2", b2)
867
+
868
+ def disable_freeu(self):
869
+ """Disables the FreeU mechanism."""
870
+ freeu_keys = {"s1", "s2", "b1", "b2"}
871
+ for i, upsample_block in enumerate(self.up_blocks):
872
+ for k in freeu_keys:
873
+ if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
874
+ setattr(upsample_block, k, None)
875
+
876
+ def fuse_qkv_projections(self):
877
+ """
878
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
879
+ are fused. For cross-attention modules, key and value projection matrices are fused.
880
+
881
+ <Tip warning={true}>
882
+
883
+ This API is 🧪 experimental.
884
+
885
+ </Tip>
886
+ """
887
+ self.original_attn_processors = None
888
+
889
+ for _, attn_processor in self.attn_processors.items():
890
+ if "Added" in str(attn_processor.__class__.__name__):
891
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
892
+
893
+ self.original_attn_processors = self.attn_processors
894
+
895
+ for module in self.modules():
896
+ if isinstance(module, Attention):
897
+ module.fuse_projections(fuse=True)
898
+
899
+ self.set_attn_processor(FusedAttnProcessor2_0())
900
+
901
+ def unfuse_qkv_projections(self):
902
+ """Disables the fused QKV projection if enabled.
903
+
904
+ <Tip warning={true}>
905
+
906
+ This API is 🧪 experimental.
907
+
908
+ </Tip>
909
+
910
+ """
911
+ if self.original_attn_processors is not None:
912
+ self.set_attn_processor(self.original_attn_processors)
913
+
914
+ def get_time_embed(
915
+ self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
916
+ ) -> Optional[torch.Tensor]:
917
+ timesteps = timestep
918
+ if not torch.is_tensor(timesteps):
919
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
920
+ # This would be a good case for the `match` statement (Python 3.10+)
921
+ is_mps = sample.device.type == "mps"
922
+ if isinstance(timestep, float):
923
+ dtype = torch.float32 if is_mps else torch.float64
924
+ else:
925
+ dtype = torch.int32 if is_mps else torch.int64
926
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
927
+ elif len(timesteps.shape) == 0:
928
+ timesteps = timesteps[None].to(sample.device)
929
+
930
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
931
+ timesteps = timesteps.expand(sample.shape[0])
932
+
933
+ t_emb = self.time_proj(timesteps)
934
+ # `Timesteps` does not contain any weights and will always return f32 tensors
935
+ # but time_embedding might actually be running in fp16. so we need to cast here.
936
+ # there might be better ways to encapsulate this.
937
+ t_emb = t_emb.to(dtype=sample.dtype)
938
+ return t_emb
939
+
940
+ def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
941
+ class_emb = None
942
+ if self.class_embedding is not None:
943
+ if class_labels is None:
944
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
945
+
946
+ if self.config.class_embed_type == "timestep":
947
+ class_labels = self.time_proj(class_labels)
948
+
949
+ # `Timesteps` does not contain any weights and will always return f32 tensors
950
+ # there might be better ways to encapsulate this.
951
+ class_labels = class_labels.to(dtype=sample.dtype)
952
+
953
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
954
+ return class_emb
955
+
956
+ def get_aug_embed(
957
+ self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
958
+ ) -> Optional[torch.Tensor]:
959
+ aug_emb = None
960
+ if self.config.addition_embed_type == "text":
961
+ aug_emb = self.add_embedding(encoder_hidden_states)
962
+ elif self.config.addition_embed_type == "text_image":
963
+ # Kandinsky 2.1 - style
964
+ if "image_embeds" not in added_cond_kwargs:
965
+ raise ValueError(
966
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
967
+ )
968
+
969
+ image_embs = added_cond_kwargs.get("image_embeds")
970
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
971
+ aug_emb = self.add_embedding(text_embs, image_embs)
972
+ elif self.config.addition_embed_type == "text_time":
973
+ # SDXL - style
974
+ if "text_embeds" not in added_cond_kwargs:
975
+ raise ValueError(
976
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
977
+ )
978
+ text_embeds = added_cond_kwargs.get("text_embeds")
979
+ if "time_ids" not in added_cond_kwargs:
980
+ raise ValueError(
981
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
982
+ )
983
+ time_ids = added_cond_kwargs.get("time_ids")
984
+ time_embeds = self.add_time_proj(time_ids.flatten())
985
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
986
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
987
+ add_embeds = add_embeds.to(emb.dtype)
988
+ aug_emb = self.add_embedding(add_embeds)
989
+ elif self.config.addition_embed_type == "image":
990
+ # Kandinsky 2.2 - style
991
+ if "image_embeds" not in added_cond_kwargs:
992
+ raise ValueError(
993
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
994
+ )
995
+ image_embs = added_cond_kwargs.get("image_embeds")
996
+ aug_emb = self.add_embedding(image_embs)
997
+ elif self.config.addition_embed_type == "image_hint":
998
+ # Kandinsky 2.2 - style
999
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
1000
+ raise ValueError(
1001
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1002
+ )
1003
+ image_embs = added_cond_kwargs.get("image_embeds")
1004
+ hint = added_cond_kwargs.get("hint")
1005
+ aug_emb = self.add_embedding(image_embs, hint)
1006
+ return aug_emb
1007
+
1008
+ def process_encoder_hidden_states(
1009
+ self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
1010
+ ) -> torch.Tensor:
1011
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
1012
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1013
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
1014
+ # Kandinsky 2.1 - style
1015
+ if "image_embeds" not in added_cond_kwargs:
1016
+ raise ValueError(
1017
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1018
+ )
1019
+
1020
+ image_embeds = added_cond_kwargs.get("image_embeds")
1021
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
1022
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
1023
+ # Kandinsky 2.2 - style
1024
+ if "image_embeds" not in added_cond_kwargs:
1025
+ raise ValueError(
1026
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1027
+ )
1028
+ image_embeds = added_cond_kwargs.get("image_embeds")
1029
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1030
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
1031
+ if "image_embeds" not in added_cond_kwargs:
1032
+ raise ValueError(
1033
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1034
+ )
1035
+
1036
+ if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None:
1037
+ encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states)
1038
+
1039
+ image_embeds = added_cond_kwargs.get("image_embeds")
1040
+ image_embeds = self.encoder_hid_proj(image_embeds)
1041
+ encoder_hidden_states = (encoder_hidden_states, image_embeds)
1042
+ return encoder_hidden_states
1043
+
1044
+ def forward(
1045
+ self,
1046
+ sample: torch.Tensor,
1047
+ timestep: Union[torch.Tensor, float, int],
1048
+ encoder_hidden_states: torch.Tensor,
1049
+ class_labels: Optional[torch.Tensor] = None,
1050
+ timestep_cond: Optional[torch.Tensor] = None,
1051
+ attention_mask: Optional[torch.Tensor] = None,
1052
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1053
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
1054
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1055
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
1056
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1057
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1058
+ return_dict: bool = True,
1059
+ ) -> Union[UNet2DConditionOutput, Tuple]:
1060
+ r"""
1061
+ The [`UNet2DConditionModel`] forward method.
1062
+
1063
+ Args:
1064
+ sample (`torch.Tensor`):
1065
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
1066
+ timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
1067
+ encoder_hidden_states (`torch.Tensor`):
1068
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
1069
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
1070
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
1071
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
1072
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
1073
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
1074
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
1075
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
1076
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
1077
+ negative values to the attention scores corresponding to "discard" tokens.
1078
+ cross_attention_kwargs (`dict`, *optional*):
1079
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1080
+ `self.processor` in
1081
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1082
+ added_cond_kwargs: (`dict`, *optional*):
1083
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
1084
+ are passed along to the UNet blocks.
1085
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
1086
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
1087
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
1088
+ A tensor that if specified is added to the residual of the middle unet block.
1089
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
1090
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
1091
+ encoder_attention_mask (`torch.Tensor`):
1092
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
1093
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
1094
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
1095
+ return_dict (`bool`, *optional*, defaults to `True`):
1096
+ Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
1097
+ tuple.
1098
+
1099
+ Returns:
1100
+ [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
1101
+ If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
1102
+ otherwise a `tuple` is returned where the first element is the sample tensor.
1103
+ """
1104
+ # By default samples have to be at least a multiple of the overall upsampling factor.
1105
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
1106
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
1107
+ # on the fly if necessary.
1108
+ default_overall_up_factor = 2**self.num_upsamplers
1109
+
1110
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
1111
+ forward_upsample_size = False
1112
+ upsample_size = None
1113
+
1114
+ for dim in sample.shape[-2:]:
1115
+ if dim % default_overall_up_factor != 0:
1116
+ # Forward upsample size to force interpolation output size.
1117
+ forward_upsample_size = True
1118
+ break
1119
+
1120
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
1121
+ # expects mask of shape:
1122
+ # [batch, key_tokens]
1123
+ # adds singleton query_tokens dimension:
1124
+ # [batch, 1, key_tokens]
1125
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
1126
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
1127
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
1128
+ if attention_mask is not None:
1129
+ # assume that mask is expressed as:
1130
+ # (1 = keep, 0 = discard)
1131
+ # convert mask into a bias that can be added to attention scores:
1132
+ # (keep = +0, discard = -10000.0)
1133
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
1134
+ attention_mask = attention_mask.unsqueeze(1)
1135
+
1136
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
1137
+ if encoder_attention_mask is not None:
1138
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
1139
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
1140
+
1141
+ # 0. center input if necessary
1142
+ if self.config.center_input_sample:
1143
+ sample = 2 * sample - 1.0
1144
+
1145
+ # 1. time
1146
+ t_emb = self.get_time_embed(sample=sample, timestep=timestep)
1147
+ emb = self.time_embedding(t_emb, timestep_cond)
1148
+ aug_emb = None
1149
+
1150
+ class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
1151
+ if class_emb is not None:
1152
+ if self.config.class_embeddings_concat:
1153
+ emb = torch.cat([emb, class_emb], dim=-1)
1154
+ else:
1155
+ emb = emb + class_emb
1156
+
1157
+ aug_emb = self.get_aug_embed(
1158
+ emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
1159
+ )
1160
+ if self.config.addition_embed_type == "image_hint":
1161
+ aug_emb, hint = aug_emb
1162
+ sample = torch.cat([sample, hint], dim=1)
1163
+
1164
+ emb = emb + aug_emb if aug_emb is not None else emb
1165
+
1166
+ if self.time_embed_act is not None:
1167
+ emb = self.time_embed_act(emb)
1168
+
1169
+ encoder_hidden_states = self.process_encoder_hidden_states(
1170
+ encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
1171
+ )
1172
+
1173
+ # 2. pre-process
1174
+ sample = self.conv_in(sample)
1175
+
1176
+ # 2.5 GLIGEN position net
1177
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
1178
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1179
+ gligen_args = cross_attention_kwargs.pop("gligen")
1180
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
1181
+
1182
+ # 3. down
1183
+ # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
1184
+ # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
1185
+ if cross_attention_kwargs is not None:
1186
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1187
+ lora_scale = cross_attention_kwargs.pop("scale", 1.0)
1188
+ else:
1189
+ lora_scale = 1.0
1190
+
1191
+ if USE_PEFT_BACKEND:
1192
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1193
+ scale_lora_layers(self, lora_scale)
1194
+
1195
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
1196
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1197
+ is_adapter = down_intrablock_additional_residuals is not None
1198
+ # maintain backward compatibility for legacy usage, where
1199
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
1200
+ # but can only use one or the other
1201
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
1202
+ deprecate(
1203
+ "T2I should not use down_block_additional_residuals",
1204
+ "1.3.0",
1205
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1206
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1207
+ for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
1208
+ standard_warn=False,
1209
+ )
1210
+ down_intrablock_additional_residuals = down_block_additional_residuals
1211
+ is_adapter = True
1212
+
1213
+ down_block_res_samples = (sample,)
1214
+ for downsample_block in self.down_blocks:
1215
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1216
+ # For t2i-adapter CrossAttnDownBlock2D
1217
+ additional_residuals = {}
1218
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1219
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
1220
+
1221
+ sample, res_samples = downsample_block(
1222
+ hidden_states=sample,
1223
+ temb=emb,
1224
+ encoder_hidden_states=encoder_hidden_states,
1225
+ attention_mask=attention_mask,
1226
+ cross_attention_kwargs=cross_attention_kwargs,
1227
+ encoder_attention_mask=encoder_attention_mask,
1228
+ **additional_residuals,
1229
+ )
1230
+ else:
1231
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
1232
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1233
+ sample += down_intrablock_additional_residuals.pop(0)
1234
+
1235
+ down_block_res_samples += res_samples
1236
+
1237
+ if is_controlnet:
1238
+ new_down_block_res_samples = ()
1239
+
1240
+ for down_block_res_sample, down_block_additional_residual in zip(
1241
+ down_block_res_samples, down_block_additional_residuals
1242
+ ):
1243
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
1244
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1245
+
1246
+ down_block_res_samples = new_down_block_res_samples
1247
+
1248
+ # 4. mid
1249
+ if self.mid_block is not None:
1250
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
1251
+ sample = self.mid_block(
1252
+ sample,
1253
+ emb,
1254
+ encoder_hidden_states=encoder_hidden_states,
1255
+ attention_mask=attention_mask,
1256
+ cross_attention_kwargs=cross_attention_kwargs,
1257
+ encoder_attention_mask=encoder_attention_mask,
1258
+ )
1259
+ else:
1260
+ sample = self.mid_block(sample, emb)
1261
+
1262
+ # To support T2I-Adapter-XL
1263
+ if (
1264
+ is_adapter
1265
+ and len(down_intrablock_additional_residuals) > 0
1266
+ and sample.shape == down_intrablock_additional_residuals[0].shape
1267
+ ):
1268
+ sample += down_intrablock_additional_residuals.pop(0)
1269
+
1270
+ if is_controlnet:
1271
+ sample = sample + mid_block_additional_residual
1272
+
1273
+ feat_64 = sample
1274
+
1275
+ # feature enhancement in the frequency domain (fftblock)
1276
+ sample = self.fftblock(sample)
1277
+
1278
+ # 5. up
1279
+ for i, upsample_block in enumerate(self.up_blocks):
1280
+ is_final_block = i == len(self.up_blocks) - 1
1281
+
1282
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1283
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1284
+
1285
+ # if we have not reached the final block and need to forward the
1286
+ # upsample size, we do it here
1287
+ if not is_final_block and forward_upsample_size:
1288
+ upsample_size = down_block_res_samples[-1].shape[2:]
1289
+
1290
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1291
+ sample = upsample_block(
1292
+ hidden_states=sample,
1293
+ temb=emb,
1294
+ res_hidden_states_tuple=res_samples,
1295
+ encoder_hidden_states=encoder_hidden_states,
1296
+ cross_attention_kwargs=cross_attention_kwargs,
1297
+ upsample_size=upsample_size,
1298
+ attention_mask=attention_mask,
1299
+ encoder_attention_mask=encoder_attention_mask,
1300
+ )
1301
+ else:
1302
+ sample = upsample_block(
1303
+ hidden_states=sample,
1304
+ temb=emb,
1305
+ res_hidden_states_tuple=res_samples,
1306
+ upsample_size=upsample_size,
1307
+ )
1308
+
1309
+ # 6. post-process
1310
+ if self.conv_norm_out:
1311
+ sample = self.conv_norm_out(sample)
1312
+ sample = self.conv_act(sample)
1313
+ sample = self.conv_out(sample)
1314
+
1315
+ if USE_PEFT_BACKEND:
1316
+ # remove `lora_scale` from each PEFT layer
1317
+ unscale_lora_layers(self, lora_scale)
1318
+
1319
+ if not return_dict:
1320
+ return (sample,)
1321
+
1322
+ return UNet2DConditionOutput(sample=sample, feat_64=feat_64)
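Note (usage sketch): the modified forward above differs from stock diffusers in two ways — the mid-block output is captured as `feat_64` before being passed through `self.fftblock`, and both tensors are returned via the output dataclass. A minimal sketch of how a caller might exercise this, assuming the class in this file is instantiated elsewhere in the repo; the shapes, timestep, and conditioning below are placeholders, not part of this commit:

```python
import torch

# Hypothetical usage sketch; all shapes and values below are placeholders.
def denoise_step(unet, latents, timestep, text_embeds):
    # The modified forward returns the denoised sample plus `feat_64`,
    # the mid-block features captured before the FFT-based enhancement block.
    out = unet(
        sample=latents,                     # (B, C, H/8, W/8) latent tensor
        timestep=timestep,                  # scalar or (B,) tensor of timesteps
        encoder_hidden_states=text_embeds,  # (B, seq_len, cross_attention_dim)
        return_dict=True,
    )
    return out.sample, out.feat_64

# e.g. latents = torch.randn(1, 4, 96, 96); text_embeds = torch.randn(1, 77, 1024)
# sample, feat_64 = denoise_step(unet, latents, torch.tensor([500]), text_embeds)
```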
depthmaster/util/batchsize.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Last modified: 2025-01-14
2
+ #
3
+ # Copyright 2025 Ziyang Song, USTC. All rights reserved.
4
+ #
5
+ # This file has been modified from the original version.
6
+ # Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ # --------------------------------------------------------------------------
20
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
21
+ # Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
22
+ # More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
23
+ # --------------------------------------------------------------------------
24
+
25
+
26
+ import torch
27
+ import math
28
+
29
+
30
+ # Search table for suggested max. inference batch size
31
+ bs_search_table = [
32
+ # tested on A100-PCIE-80GB
33
+ {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
34
+ {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
35
+ # tested on A100-PCIE-40GB
36
+ {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
37
+ {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
38
+ {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
39
+ {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
40
+ # tested on RTX3090, RTX4090
41
+ {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
42
+ {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
43
+ {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
44
+ {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
45
+ {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
46
+ {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
47
+ # tested on GTX1080Ti
48
+ {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
49
+ {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
50
+ {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
51
+ {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
52
+ {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
53
+ ]
54
+
55
+
56
+ def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
57
+ """
58
+ Automatically search for suitable operating batch size.
59
+
60
+ Args:
61
+ ensemble_size (`int`):
62
+ Number of predictions to be ensembled.
63
+ input_res (`int`):
64
+ Operating resolution of the input image.
65
+
66
+ Returns:
67
+ `int`: Operating batch size.
68
+ """
69
+ if not torch.cuda.is_available():
70
+ return 1
71
+
72
+ total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
73
+ filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
74
+ for settings in sorted(
75
+ filtered_bs_search_table,
76
+ key=lambda k: (k["res"], -k["total_vram"]),
77
+ ):
78
+ if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
79
+ bs = settings["bs"]
80
+ if bs > ensemble_size:
81
+ bs = ensemble_size
82
+ elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
83
+ bs = math.ceil(ensemble_size / 2)
84
+ return bs
85
+
86
+ return 1
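Note (usage sketch): `find_batch_size` picks the largest batch size from the lookup table that fits the detected GPU memory, then clamps it to the ensemble size (or to roughly half of it). A minimal example with illustrative values:

```python
import torch

from depthmaster.util.batchsize import find_batch_size

# Illustrative settings: 10 ensemble members at a 768-px processing resolution.
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
bs = find_batch_size(ensemble_size=10, input_res=768, dtype=dtype)
print(f"inference batch size: {bs}")  # falls back to 1 when no CUDA device is available
```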
depthmaster/util/ensemble.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Last modified: 2025-01-14
2
+ #
3
+ # Copyright 2025 Ziyang Song, USTC. All rights reserved.
4
+ #
5
+ # This file has been modified from the original version.
6
+ # Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ # --------------------------------------------------------------------------
20
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
21
+ # Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
22
+ # More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
23
+ # --------------------------------------------------------------------------
24
+
25
+
26
+ from functools import partial
27
+ from typing import Optional, Tuple
28
+
29
+ import numpy as np
30
+ import torch
31
+
32
+ from .image_util import get_tv_resample_method, resize_max_res
33
+
34
+
35
+ def inter_distances(tensors: torch.Tensor):
36
+ """
37
+ To calculate the distance between each two depth maps.
38
+ """
39
+ distances = []
40
+ for i, j in torch.combinations(torch.arange(tensors.shape[0])):
41
+ arr1 = tensors[i : i + 1]
42
+ arr2 = tensors[j : j + 1]
43
+ distances.append(arr1 - arr2)
44
+ dist = torch.concatenate(distances, dim=0)
45
+ return dist
46
+
47
+
48
+ def ensemble_depth(
49
+ depth: torch.Tensor,
50
+ scale_invariant: bool = True,
51
+ shift_invariant: bool = True,
52
+ output_uncertainty: bool = False,
53
+ reduction: str = "median",
54
+ regularizer_strength: float = 0.02,
55
+ max_iter: int = 2,
56
+ tol: float = 1e-3,
57
+ max_res: int = 1024,
58
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
59
+ """
60
+ Ensembles depth maps represented by the `depth` tensor with expected shape `(B, 1, H, W)`, where B is the
61
+ number of ensemble members for a given prediction of size `(H x W)`. Even though the function is designed for
62
+ depth maps, it can also be used with disparity maps as long as the input tensor values are non-negative. The
63
+ alignment happens when the predictions have one or more degrees of freedom, that is when they are either
64
+ affine-invariant (`scale_invariant=True` and `shift_invariant=True`), or just scale-invariant (only
65
+ `scale_invariant=True`). For absolute predictions (`scale_invariant=False` and `shift_invariant=False`)
66
+ alignment is skipped and only ensembling is performed.
67
+
68
+ Args:
69
+ depth (`torch.Tensor`):
70
+ Input ensemble depth maps.
71
+ scale_invariant (`bool`, *optional*, defaults to `True`):
72
+ Whether to treat predictions as scale-invariant.
73
+ shift_invariant (`bool`, *optional*, defaults to `True`):
74
+ Whether to treat predictions as shift-invariant.
75
+ output_uncertainty (`bool`, *optional*, defaults to `False`):
76
+ Whether to output uncertainty map.
77
+ reduction (`str`, *optional*, defaults to `"median"`):
78
+ Reduction method used to ensemble aligned predictions. The accepted values are: `"mean"` and
79
+ `"median"`.
80
+ regularizer_strength (`float`, *optional*, defaults to `0.02`):
81
+ Strength of the regularizer that pulls the aligned predictions to the unit range from 0 to 1.
82
+ max_iter (`int`, *optional*, defaults to `2`):
83
+ Maximum number of the alignment solver steps. Refer to `scipy.optimize.minimize` function, `options`
84
+ argument.
85
+ tol (`float`, *optional*, defaults to `1e-3`):
86
+ Alignment solver tolerance. The solver stops when the tolerance is reached.
87
+ max_res (`int`, *optional*, defaults to `1024`):
88
+ Resolution at which the alignment is performed; `None` matches the `processing_resolution`.
89
+ Returns:
90
+ A tensor of aligned and ensembled depth maps and optionally a tensor of uncertainties of the same shape:
91
+ `(1, 1, H, W)`.
92
+ """
93
+ if depth.dim() != 4 or depth.shape[1] != 1:
94
+ raise ValueError(f"Expecting 4D tensor of shape [B,1,H,W]; got {depth.shape}.")
95
+ if reduction not in ("mean", "median"):
96
+ raise ValueError(f"Unrecognized reduction method: {reduction}.")
97
+ if not scale_invariant and shift_invariant:
98
+ raise ValueError("Pure shift-invariant ensembling is not supported.")
99
+
100
+ def init_param(depth: torch.Tensor):
101
+ init_min = depth.reshape(ensemble_size, -1).min(dim=1).values
102
+ init_max = depth.reshape(ensemble_size, -1).max(dim=1).values
103
+
104
+ if scale_invariant and shift_invariant:
105
+ init_s = 1.0 / (init_max - init_min).clamp(min=1e-6)
106
+ init_t = -init_s * init_min
107
+ param = torch.cat((init_s, init_t)).cpu().numpy()
108
+ elif scale_invariant:
109
+ init_s = 1.0 / init_max.clamp(min=1e-6)
110
+ param = init_s.cpu().numpy()
111
+ else:
112
+ raise ValueError("Unrecognized alignment.")
113
+
114
+ return param
115
+
116
+ def align(depth: torch.Tensor, param: np.ndarray) -> torch.Tensor:
117
+ if scale_invariant and shift_invariant:
118
+ s, t = np.split(param, 2)
119
+ s = torch.from_numpy(s).to(depth).view(ensemble_size, 1, 1, 1)
120
+ t = torch.from_numpy(t).to(depth).view(ensemble_size, 1, 1, 1)
121
+ out = depth * s + t
122
+ elif scale_invariant:
123
+ s = torch.from_numpy(param).to(depth).view(ensemble_size, 1, 1, 1)
124
+ out = depth * s
125
+ else:
126
+ raise ValueError("Unrecognized alignment.")
127
+ return out
128
+
129
+ def ensemble(
130
+ depth_aligned: torch.Tensor, return_uncertainty: bool = False
131
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
132
+ uncertainty = None
133
+ if reduction == "mean":
134
+ prediction = torch.mean(depth_aligned, dim=0, keepdim=True)
135
+ if return_uncertainty:
136
+ uncertainty = torch.std(depth_aligned, dim=0, keepdim=True)
137
+ elif reduction == "median":
138
+ prediction = torch.median(depth_aligned, dim=0, keepdim=True).values
139
+ if return_uncertainty:
140
+ uncertainty = torch.median(
141
+ torch.abs(depth_aligned - prediction), dim=0, keepdim=True
142
+ ).values
143
+ else:
144
+ raise ValueError(f"Unrecognized reduction method: {reduction}.")
145
+ return prediction, uncertainty
146
+
147
+ def cost_fn(param: np.ndarray, depth: torch.Tensor) -> float:
148
+ cost = 0.0
149
+ depth_aligned = align(depth, param)
150
+
151
+ for i, j in torch.combinations(torch.arange(ensemble_size)):
152
+ diff = depth_aligned[i] - depth_aligned[j]
153
+ cost += (diff**2).mean().sqrt().item()
154
+
155
+ if regularizer_strength > 0:
156
+ prediction, _ = ensemble(depth_aligned, return_uncertainty=False)
157
+ err_near = (0.0 - prediction.min()).abs().item()
158
+ err_far = (1.0 - prediction.max()).abs().item()
159
+ cost += (err_near + err_far) * regularizer_strength
160
+
161
+ return cost
162
+
163
+ def compute_param(depth: torch.Tensor):
164
+ import scipy
165
+
166
+ depth_to_align = depth.to(torch.float32)
167
+ if max_res is not None and max(depth_to_align.shape[2:]) > max_res:
168
+ depth_to_align = resize_max_res(
169
+ depth_to_align, max_res, get_tv_resample_method("nearest-exact")
170
+ )
171
+
172
+ param = init_param(depth_to_align)
173
+
174
+ res = scipy.optimize.minimize(
175
+ partial(cost_fn, depth=depth_to_align),
176
+ param,
177
+ method="BFGS",
178
+ tol=tol,
179
+ options={"maxiter": max_iter, "disp": False},
180
+ )
181
+
182
+ return res.x
183
+
184
+ requires_aligning = scale_invariant or shift_invariant
185
+ ensemble_size = depth.shape[0]
186
+
187
+ if requires_aligning:
188
+ param = compute_param(depth)
189
+ depth = align(depth, param)
190
+
191
+ depth, uncertainty = ensemble(depth, return_uncertainty=output_uncertainty)
192
+
193
+ depth_max = depth.max()
194
+ if scale_invariant and shift_invariant:
195
+ depth_min = depth.min()
196
+ elif scale_invariant:
197
+ depth_min = 0
198
+ else:
199
+ raise ValueError("Unrecognized alignment.")
200
+ depth_range = (depth_max - depth_min).clamp(min=1e-6)
201
+ depth = (depth - depth_min) / depth_range
202
+ if output_uncertainty:
203
+ uncertainty /= depth_range
204
+
205
+ return depth, uncertainty # [1,1,H,W], [1,1,H,W]
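Note (usage sketch): `ensemble_depth` first aligns the E predictions (solving per-member scale/shift with BFGS on a downsized copy), then reduces them with the mean or median and rescales the result to [0, 1]. A minimal example with random stand-in predictions; real inputs come from the pipeline's ensemble loop:

```python
import torch

from depthmaster.util.ensemble import ensemble_depth

# Stand-in ensemble: 10 affine-invariant predictions for one 480x640 image, shape (E, 1, H, W).
depth_preds = torch.rand(10, 1, 480, 640)

depth, uncertainty = ensemble_depth(
    depth_preds,
    scale_invariant=True,
    shift_invariant=True,
    output_uncertainty=True,
    reduction="median",
)
# depth: (1, 1, H, W) scaled to [0, 1]; uncertainty: (1, 1, H, W)
```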
depthmaster/util/image_util.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Last modified: 2025-01-14
2
+ #
3
+ # Copyright 2025 Ziyang Song, USTC. All rights reserved.
4
+ #
5
+ # This file has been modified from the original version.
6
+ # Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ # --------------------------------------------------------------------------
20
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
21
+ # Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
22
+ # More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
23
+ # --------------------------------------------------------------------------
24
+
25
+
26
+ import matplotlib
27
+ import numpy as np
28
+ import torch
29
+ from torchvision.transforms import InterpolationMode
30
+ from torchvision.transforms.functional import resize
31
+
32
+
33
+ def colorize_depth_maps(
34
+ depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None
35
+ ):
36
+ """
37
+ Colorize depth maps.
38
+ """
39
+ assert len(depth_map.shape) >= 2, "Invalid dimension"
40
+
41
+ if isinstance(depth_map, torch.Tensor):
42
+ depth = depth_map.detach().squeeze().numpy()
43
+ elif isinstance(depth_map, np.ndarray):
44
+ depth = depth_map.copy().squeeze()
45
+ # reshape to [ (B,) H, W ]
46
+ if depth.ndim < 3:
47
+ depth = depth[np.newaxis, :, :]
48
+
49
+ # colorize
50
+ cm = matplotlib.colormaps[cmap]
51
+ depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
52
+ img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # value from 0 to 1
53
+ img_colored_np = np.rollaxis(img_colored_np, 3, 1)
54
+
55
+ if valid_mask is not None:
56
+ if isinstance(depth_map, torch.Tensor):
57
+ valid_mask = valid_mask.detach().numpy()
58
+ valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W]
59
+ if valid_mask.ndim < 3:
60
+ valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
61
+ else:
62
+ valid_mask = valid_mask[:, np.newaxis, :, :]
63
+ valid_mask = np.repeat(valid_mask, 3, axis=1)
64
+ img_colored_np[~valid_mask] = 0
65
+
66
+ if isinstance(depth_map, torch.Tensor):
67
+ img_colored = torch.from_numpy(img_colored_np).float()
68
+ elif isinstance(depth_map, np.ndarray):
69
+ img_colored = img_colored_np
70
+
71
+ return img_colored
72
+
73
+
74
+ def chw2hwc(chw):
75
+ assert 3 == len(chw.shape)
76
+ if isinstance(chw, torch.Tensor):
77
+ hwc = torch.permute(chw, (1, 2, 0))
78
+ elif isinstance(chw, np.ndarray):
79
+ hwc = np.moveaxis(chw, 0, -1)
80
+ return hwc
81
+
82
+
83
+ def resize_max_res(
84
+ img: torch.Tensor,
85
+ max_edge_resolution: int,
86
+ resample_method: InterpolationMode = InterpolationMode.BILINEAR,
87
+ ) -> torch.Tensor:
88
+ """
89
+ Resize image to limit maximum edge length while keeping aspect ratio.
90
+
91
+ Args:
92
+ img (`torch.Tensor`):
93
+ Image tensor to be resized. Expected shape: [B, C, H, W]
94
+ max_edge_resolution (`int`):
95
+ Maximum edge length (pixel).
96
+ resample_method (`PIL.Image.Resampling`):
97
+ Resampling method used to resize images.
98
+
99
+ Returns:
100
+ `torch.Tensor`: Resized image.
101
+ """
102
+ assert 4 == img.dim(), f"Invalid input shape {img.shape}"
103
+
104
+ original_height, original_width = img.shape[-2:]
105
+ downscale_factor = min(
106
+ max_edge_resolution / original_width, max_edge_resolution / original_height
107
+ )
108
+
109
+ new_width = int(original_width * downscale_factor)
110
+ new_height = int(original_height * downscale_factor)
111
+
112
+ resized_img = resize(img, (new_height, new_width), resample_method, antialias=True)
113
+ return resized_img
114
+
115
+
116
+ def get_tv_resample_method(method_str: str) -> InterpolationMode:
117
+ resample_method_dict = {
118
+ "bilinear": InterpolationMode.BILINEAR,
119
+ "bicubic": InterpolationMode.BICUBIC,
120
+ "nearest": InterpolationMode.NEAREST_EXACT,
121
+ "nearest-exact": InterpolationMode.NEAREST_EXACT,
122
+ }
123
+ resample_method = resample_method_dict.get(method_str, None)
124
+ if resample_method is None:
125
+ raise ValueError(f"Unknown resampling method: {resample_method}")
126
+ else:
127
+ return resample_method
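Note (usage sketch): the helpers above are typically combined to bring an RGB image down to the processing resolution and to render a predicted depth map with a matplotlib colormap. A minimal example with random tensors standing in for real data:

```python
import torch

from depthmaster.util.image_util import (
    colorize_depth_maps,
    get_tv_resample_method,
    resize_max_res,
)

# Resize a (B, 3, H, W) image so that its longer edge is at most 768 px, keeping the aspect ratio.
img = torch.rand(1, 3, 480, 640)
resized = resize_max_res(
    img, max_edge_resolution=768, resample_method=get_tv_resample_method("bilinear")
)

# Colorize a depth map in [0, 1]; the result is (B, 3, H, W) with values in [0, 1].
depth = torch.rand(480, 640)
colored = colorize_depth_maps(depth, min_depth=0.0, max_depth=1.0, cmap="Spectral")
```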
requirements.txt CHANGED
@@ -1,6 +1,129 @@
1
- accelerate
2
- diffusers
3
  invisible_watermark
4
- torch
5
- transformers
6
- xformers
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  invisible_watermark
2
+ absl-py==2.1.0
3
+ accelerate==0.31.0
4
+ aiohttp==3.9.5
5
+ aiosignal==1.3.1
6
+ antlr4-python3-runtime==4.9.3
7
+ asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1698341106958/work
8
+ async-timeout==4.0.3
9
+ attrs==23.2.0
10
+ bitsandbytes==0.43.1
11
+ certifi==2024.6.2
12
+ charset-normalizer==3.3.2
13
+ click==8.1.7
14
+ comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1710320294760/work
15
+ contourpy==1.2.1
16
+ cycler==0.12.1
17
+ datasets==2.19.2
18
+ debugpy @ file:///croot/debugpy_1690905042057/work
19
+ decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work
20
+ diffusers==0.29.0
21
+ dill==0.3.8
22
+ docker-pycreds==0.4.0
23
+ einops==0.8.0
24
+ entrypoints @ file:///home/conda/feedstock_root/build_artifacts/entrypoints_1643888246732/work
25
+ exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1704921103267/work
26
+ executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1698579936712/work
27
+ filelock==3.13.1
28
+ fonttools==4.53.0
29
+ frozenlist==1.4.1
30
+ fsspec==2024.2.0
31
+ gitdb==4.0.11
32
+ GitPython==3.1.43
33
+ grpcio==1.64.1
34
+ h5py==3.11.0
35
+ huggingface-hub==0.27.1
36
+ idna==3.7
37
+ imageio==2.34.1
38
+ imgaug==0.4.0
39
+ importlib_metadata==7.1.0
40
+ ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1717717528849/work
41
+ ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1717182742060/work
42
+ jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1696326070614/work
43
+ Jinja2==3.1.3
44
+ jupyter-client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1654730843242/work
45
+ jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1710257277185/work
46
+ kiwisolver==1.4.5
47
+ lazy_loader==0.4
48
+ lightning-utilities==0.11.2
49
+ Markdown==3.6
50
+ MarkupSafe==2.1.5
51
+ matplotlib==3.9.0
52
+ matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1713250518406/work
53
+ mpmath==1.3.0
54
+ multidict==6.0.5
55
+ multiprocess==0.70.16
56
+ nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1705850609492/work
57
+ networkx==3.2.1
58
+ numpy==1.26.3
59
+ nvidia-cublas-cu11==11.11.3.6
60
+ nvidia-cuda-cupti-cu11==11.8.87
61
+ nvidia-cuda-nvrtc-cu11==11.8.89
62
+ nvidia-cuda-runtime-cu11==11.8.89
63
+ nvidia-cudnn-cu11==8.7.0.84
64
+ nvidia-cufft-cu11==10.9.0.58
65
+ nvidia-curand-cu11==10.3.0.86
66
+ nvidia-cusolver-cu11==11.4.1.48
67
+ nvidia-cusparse-cu11==11.7.5.86
68
+ nvidia-nccl-cu11==2.20.5
69
+ nvidia-nvtx-cu11==11.8.86
70
+ omegaconf==2.3.0
71
+ opencv-python==4.10.0.82
72
+ packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1718189413536/work
73
+ pandas==2.2.2
74
+ parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1712320355065/work
75
+ peft==0.11.1
76
+ pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1706113125309/work
77
+ pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work
78
+ pillow==10.2.0
79
+ platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1715777629804/work
80
+ prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1718047967974/work
81
+ protobuf==4.25.3
82
+ psutil @ file:///opt/conda/conda-bld/psutil_1656431268089/work
83
+ ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
84
+ pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work
85
+ pyarrow==16.1.0
86
+ pyarrow-hotfix==0.6
87
+ Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1714846767233/work
88
+ pyparsing==3.1.2
89
+ python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1709299778482/work
90
+ pytorch-lightning==2.2.5
91
+ pytz==2024.1
92
+ PyYAML==6.0.1
93
+ pyzmq @ file:///croot/pyzmq_1705605076900/work
94
+ regex==2024.5.15
95
+ requests==2.32.3
96
+ safetensors==0.4.3
97
+ scikit-image==0.23.2
98
+ scipy==1.13.1
99
+ sentry-sdk==2.5.1
100
+ setproctitle==1.3.3
101
+ shapely==2.0.4
102
+ six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
103
+ smmap==5.0.1
104
+ stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work
105
+ sympy==1.12
106
+ tabulate==0.9.0
107
+ tensorboard==2.17.0
108
+ tensorboard-data-server==0.7.2
109
+ tifffile==2024.5.22
110
+ tokenizers==0.19.1
111
+ torch==2.3.0+cu118
112
+ torchaudio==2.3.1+cu118
113
+ torchmetrics==1.4.0.post0
114
+ torchvision==0.18.1+cu118
115
+ tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1648827254365/work
116
+ tqdm==4.66.4
117
+ traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1713535121073/work
118
+ transformers==4.41.2
119
+ triton==2.3.0
120
+ typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1717802530399/work
121
+ tzdata==2024.1
122
+ urllib3==2.2.1
123
+ wandb==0.17.1
124
+ wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1704731205417/work
125
+ Werkzeug==3.0.3
126
+ xformers==0.0.26.post1+cu118
127
+ xxhash==3.4.1
128
+ yarl==1.9.4
129
+ zipp==3.19.2
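Note that the `+cu118` local-version wheels pinned here (torch, torchaudio, torchvision, xformers) are not published on PyPI under those exact version strings, so installing this file as-is will most likely need the PyTorch CUDA 11.8 wheel index added, e.g. (assuming a CUDA 11.8 machine):

    pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118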
run.py ADDED
@@ -0,0 +1,253 @@
1
+ # Last modified: 2025-01-14
2
+ #
3
+ # Copyright 2025 Ziyang Song, USTC. All rights reserved.
4
+ #
5
+ # This file has been modified from the original version.
6
+ # Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ # --------------------------------------------------------------------------
20
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
21
+ # Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation
22
+ # More information about the method can be found at https://indu1ge.github.io/DepthMaster_page
23
+ # --------------------------------------------------------------------------
24
+
25
+
26
+ import argparse
27
+ import logging
28
+ import os
29
+ from glob import glob
30
+
31
+ import numpy as np
32
+ import torch
33
+ from PIL import Image
34
+ from tqdm.auto import tqdm
35
+
36
+ from depthmaster import DepthMasterPipeline
37
+
38
+ EXTENSION_LIST = [".jpg", ".png"]
39
+
40
+
41
+ if "__main__" == __name__:
42
+ logging.basicConfig(level=logging.INFO)
43
+
44
+ # -------------------- Arguments --------------------
45
+ parser = argparse.ArgumentParser(
46
+ description="Run single-image depth estimation using DepthMaster."
47
+ )
48
+ parser.add_argument(
49
+ "--checkpoint",
50
+ type=str,
51
+ default="ckpt/depthmaster",
52
+ help="Checkpoint path or hub name.",
53
+ )
54
+
55
+ parser.add_argument(
56
+ "--input_rgb_dir",
57
+ type=str,
58
+ required=True,
59
+ help="Path to the input image folder.",
60
+ )
61
+
62
+ parser.add_argument(
63
+ "--output_dir", type=str, required=True, help="Output directory."
64
+ )
65
+
66
+
67
+ parser.add_argument(
68
+ "--half_precision",
69
+ "--fp16",
70
+ action="store_true",
71
+ help="Run with half precision (16-bit float); this might lead to suboptimal results.",
72
+ )
73
+
74
+ # resolution setting
75
+ parser.add_argument(
76
+ "--processing_res",
77
+ type=int,
78
+ default=None,
79
+ help="Maximum resolution of processing. 0 uses the input image resolution. Default: None (use the checkpoint's default processing resolution).",
80
+ )
81
+ parser.add_argument(
82
+ "--output_processing_res",
83
+ action="store_true",
84
+ help="When the input is resized, output the depth at the resized processing resolution instead of the input resolution. Default: False.",
85
+ )
86
+ parser.add_argument(
87
+ "--resample_method",
88
+ choices=["bilinear", "bicubic", "nearest"],
89
+ default="bilinear",
90
+ help="Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`. Default: `bilinear`",
91
+ )
92
+
93
+ # depth map colormap
94
+ parser.add_argument(
95
+ "--color_map",
96
+ type=str,
97
+ default="Spectral",
98
+ help="Colormap used to render depth predictions.",
99
+ )
100
+
101
+ # other settings
102
+ parser.add_argument(
103
+ "--batch_size",
104
+ type=int,
105
+ default=0,
106
+ help="Inference batch size. Default: 0 (will be set automatically).",
107
+ )
108
+ parser.add_argument(
109
+ "--apple_silicon",
110
+ action="store_true",
111
+ help="Flag of running on Apple Silicon.",
112
+ )
113
+
114
+ args = parser.parse_args()
115
+
116
+ checkpoint_path = args.checkpoint
117
+ input_rgb_dir = args.input_rgb_dir
118
+ output_dir = args.output_dir
119
+
120
+ half_precision = args.half_precision
121
+
122
+ processing_res = args.processing_res
123
+ match_input_res = not args.output_processing_res
124
+ if 0 == processing_res and match_input_res is False:
125
+ logging.warning(
126
+ "Processing at native resolution without resizing output might NOT lead to exactly the same resolution, due to the padding and pooling properties of conv layers."
127
+ )
128
+ resample_method = args.resample_method
129
+
130
+ color_map = args.color_map
131
+ batch_size = args.batch_size
132
+ apple_silicon = args.apple_silicon
133
+ if apple_silicon and 0 == batch_size:
134
+ batch_size = 1 # set default batchsize
135
+
136
+ # -------------------- Preparation --------------------
137
+ # Output directories
138
+ output_dir_color = os.path.join(output_dir, "depth_colored")
139
+ output_dir_tif = os.path.join(output_dir, "depth_bw")
140
+ # output_dir_npy = os.path.join(output_dir, "depth_npy")
141
+ os.makedirs(output_dir, exist_ok=True)
142
+ os.makedirs(output_dir_color, exist_ok=True)
143
+ os.makedirs(output_dir_tif, exist_ok=True)
144
+ # os.makedirs(output_dir_npy, exist_ok=True)
145
+ logging.info(f"output dir = {output_dir}")
146
+
147
+ # -------------------- Device --------------------
148
+ if apple_silicon:
149
+ if torch.backends.mps.is_available() and torch.backends.mps.is_built():
150
+ device = torch.device("mps:0")
151
+ else:
152
+ device = torch.device("cpu")
153
+ logging.warning("MPS is not available. Running on CPU will be slow.")
154
+ else:
155
+ if torch.cuda.is_available():
156
+ device = torch.device("cuda")
157
+ else:
158
+ device = torch.device("cpu")
159
+ logging.warning("CUDA is not available. Running on CPU will be slow.")
160
+ logging.info(f"device = {device}")
161
+
162
+ # -------------------- Data --------------------
163
+ rgb_filename_list = glob(os.path.join(input_rgb_dir, "*"))
164
+ rgb_filename_list = [
165
+ f for f in rgb_filename_list if os.path.splitext(f)[1].lower() in EXTENSION_LIST
166
+ ]
167
+ rgb_filename_list = sorted(rgb_filename_list)
168
+ n_images = len(rgb_filename_list)
169
+ if n_images > 0:
170
+ logging.info(f"Found {n_images} images")
171
+ else:
172
+ logging.error(f"No image found in '{input_rgb_dir}'")
173
+ exit(1)
174
+
175
+ # -------------------- Model --------------------
176
+ if half_precision:
177
+ dtype = torch.float16
178
+ variant = "fp16"
179
+ logging.info(
180
+ f"Running with half precision ({dtype}); this might lead to suboptimal results."
181
+ )
182
+ else:
183
+ dtype = torch.float32
184
+ variant = None
185
+
186
+ pipe: DepthMasterPipeline = DepthMasterPipeline.from_pretrained(
187
+ checkpoint_path, variant=variant, torch_dtype=dtype
188
+ )
189
+
190
+ try:
191
+ pipe.enable_xformers_memory_efficient_attention()
192
+ except ImportError:
193
+ pass # run without xformers
194
+
195
+ pipe = pipe.to(device)
196
+ logging.info(
197
+ f"scale_invariant: {pipe.scale_invariant}, shift_invariant: {pipe.shift_invariant}"
198
+ )
199
+
200
+ # Print out config
201
+ logging.info(
202
+ f"Inference settings: checkpoint = `{checkpoint_path}`, "
203
+ f"processing resolution = {processing_res or pipe.default_processing_resolution}, "
204
+ f"color_map = {color_map}."
205
+ )
206
+
207
+ # -------------------- Inference and saving --------------------
208
+ with torch.no_grad():
209
+ os.makedirs(output_dir, exist_ok=True)
210
+
211
+ for rgb_path in tqdm(rgb_filename_list, desc="Estimating depth", leave=True):
212
+ # Read input image
213
+ input_image = Image.open(rgb_path)
214
+
215
+ # Predict depth
216
+ with torch.no_grad():
217
+ pipe_out = pipe(
218
+ input_image,
219
+ processing_res=processing_res,
220
+ match_input_res=match_input_res,
221
+ batch_size=batch_size,
222
+ color_map=color_map,
223
+ show_progress_bar=True,
224
+ resample_method=resample_method,
225
+ )
226
+
227
+ depth_pred: np.ndarray = pipe_out.depth_np
228
+ depth_colored: Image.Image = pipe_out.depth_colored
229
+
230
+ # Derive the output filename (saving as .npy is kept commented out below)
231
+ rgb_name_base = os.path.splitext(os.path.basename(rgb_path))[0]
232
+ pred_name_base = rgb_name_base + "_pred"
233
+ # npy_save_path = os.path.join(output_dir_npy, f"{pred_name_base}.npy")
234
+ # if os.path.exists(npy_save_path):
235
+ # logging.warning(f"Existing file: '{npy_save_path}' will be overwritten")
236
+ # np.save(npy_save_path, depth_pred)
237
+
238
+ # Save as 16-bit uint png
239
+ depth_to_save = (depth_pred * 65535.0).astype(np.uint16)
240
+ png_save_path = os.path.join(output_dir_tif, f"{pred_name_base}.png")
241
+ if os.path.exists(png_save_path):
242
+ logging.warning(f"Existing file: '{png_save_path}' will be overwritten")
243
+ Image.fromarray(depth_to_save).save(png_save_path, mode="I;16")
244
+
245
+ # Colorize
246
+ colored_save_path = os.path.join(
247
+ output_dir_color, f"{pred_name_base}_colored.png"
248
+ )
249
+ if os.path.exists(colored_save_path):
250
+ logging.warning(
251
+ f"Existing file: '{colored_save_path}' will be overwritten"
252
+ )
253
+ depth_colored.save(colored_save_path)
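For reference, a minimal way to invoke this script (all paths below are placeholders; the checkpoint argument can also be a Hugging Face hub name, as noted in the argument help above):

    python run.py \
        --checkpoint ckpt/depthmaster \
        --input_rgb_dir path/to/input_images \
        --output_dir path/to/output \
        --half_precision

The 16-bit PNGs written to `depth_bw` store the prediction scaled by 65535, so a result can be read back as a float map in [0, 1] with something like the following (the filename is a placeholder following the `<name>_pred.png` pattern above):

    import numpy as np
    from PIL import Image

    depth = np.asarray(
        Image.open("path/to/output/depth_bw/example_pred.png"), dtype=np.float32
    ) / 65535.0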