blanchon commited on
Commit
a9af355
Β·
0 Parent(s):

Initial commit

Browse files
.gitattributes ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Auto detect text files and perform LF normalization
2
+ *.7z filter=lfs diff=lfs merge=lfs -text
3
+ *.arrow filter=lfs diff=lfs merge=lfs -text
4
+ *.bin filter=lfs diff=lfs merge=lfs -text
5
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
6
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
7
+ *.ftz filter=lfs diff=lfs merge=lfs -text
8
+ *.gz filter=lfs diff=lfs merge=lfs -text
9
+ *.h5 filter=lfs diff=lfs merge=lfs -text
10
+ *.joblib filter=lfs diff=lfs merge=lfs -text
11
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
12
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
13
+ *.model filter=lfs diff=lfs merge=lfs -text
14
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
15
+ *.npy filter=lfs diff=lfs merge=lfs -text
16
+ *.npz filter=lfs diff=lfs merge=lfs -text
17
+ *.onnx filter=lfs diff=lfs merge=lfs -text
18
+ *.ot filter=lfs diff=lfs merge=lfs -text
19
+ *.parquet filter=lfs diff=lfs merge=lfs -text
20
+ *.pb filter=lfs diff=lfs merge=lfs -text
21
+ *.pickle filter=lfs diff=lfs merge=lfs -text
22
+ *.pkl filter=lfs diff=lfs merge=lfs -text
23
+ *.pt filter=lfs diff=lfs merge=lfs -text
24
+ *.pth filter=lfs diff=lfs merge=lfs -text
25
+ *.rar filter=lfs diff=lfs merge=lfs -text
26
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
27
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
29
+ *.tar filter=lfs diff=lfs merge=lfs -text
30
+ *.tflite filter=lfs diff=lfs merge=lfs -text
31
+ *.tgz filter=lfs diff=lfs merge=lfs -text
32
+ *.wasm filter=lfs diff=lfs merge=lfs -text
33
+ *.xz filter=lfs diff=lfs merge=lfs -text
34
+ *.zip filter=lfs diff=lfs merge=lfs -text
35
+ *.zst filter=lfs diff=lfs merge=lfs -text
36
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
37
+ * text=auto
38
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ __pycache__/
2
+
3
+ .venv/
4
+ rgb2x/model_cache
5
+ x2rgb/model_cache
LICENSE ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ADOBE RESEARCH LICENSE
2
+
3
+ This license agreement (the β€œLicense”) between Adobe Inc., having a place of business at 345 Park Avenue, San Jose, California 95110-2704 (β€œAdobe”), and you, the individual or entity exercising rights under this License (β€œyou” or β€œyour”), sets forth the terms for your use of certain research materials that are owned by Adobe (the β€œLicensed Materials”). By exercising rights under this License, you accept and agree to be bound by its terms. If you are exercising rights under this license on behalf of an entity, then β€œyou” means you and such entity, and you (personally) represent and warrant that you (personally) have all necessary authority to bind that entity to the terms of this License.
4
+
5
+ 1. **GRANT OF LICENSE.**
6
+
7
+ 1.1. Adobe grants you a nonexclusive, worldwide, royalty-free, revocable, fully paid license to (A) reproduce, use, modify, and publicly display the Licensed Materials for noncommercial research purposes only; and (B) redistribute the Licensed Materials, and modifications or derivative works thereof, for noncommercial research purposes only, provided that you give recipients a copy of this License.
8
+
9
+ 1.2. You may add your own copyright statement to your modifications and may provide additional or different license terms for use, reproduction, modification, public display, and redistribution of your modifications and derivative works, provided that such license terms limit the use, reproduction, modification, public display, and redistribution of such modifications and derivative works to noncommercial research purposes only.
10
+
11
+ 1.3. For purposes of this License, noncommercial research purposes include academic research and teaching but do not include commercial licensing or distribution, development of commercial products, or any other activity which results in commercial gain.
12
+
13
+
14
+ 2. **OWNERSHIP AND ATTRIBUTION.** Adobe and its licensors own all right, title, and interest in the Licensed Materials. You must keep intact any copyright or other notices or disclaimers in the Licensed Materials.
15
+
16
+ 3. **DISCLAIMER OF WARRANTIES.** THE LICENSED MATERIALS ARE PROVIDED β€œAS IS” WITHOUT WARRANTY OF ANY KIND. THE ENTIRE RISK AS TO THE RESULTS AND PERFORMANCE OF THE LICENSED MATERIALS IS ASSUMED BY YOU. ADOBE DISCLAIMS ALL WARRANTIES, EXPRESS, IMPLIED OR STATUTORY, WITH REGARD TO ANY LICENSED MATERIALS PROVIDED UNDER THIS LICENSE, INCLUDING, BUT NOT LIMITED TO, ANY IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NONINFRINGEMENT OF THIRD-PARTY RIGHTS.
17
+
18
+ 4. **LIMITATION OF LIABILITY.** IN NO EVENT WILL ADOBE BE LIABLE FOR ANY ACTUAL, INCIDENTAL, SPECIAL OR CONSEQUENTIAL DAMAGES OF ANY NATURE WHATSOEVER, INCLUDING WITHOUT LIMITATION, LOSS OF PROFITS OR OTHER COMMERCIAL LOSS, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF ANY LICENSED MATERIALS PROVIDED UNDER THIS LICENSE, EVEN IF ADOBE HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
19
+
20
+ 5. **TERM AND TERMINATION.**
21
+
22
+ 5.1. The License is effective upon acceptance by you and will remain in effect unless terminated earlier as permitted under this License.
23
+
24
+ 5.2. If you breach any material provision of this License, then your rights will terminate immediately.
25
+
26
+ 5.3. All clauses which by their nature should survive the termination of this License will survive such termination. In addition, and without limiting the generality of the preceding sentence, Sections 2 (Ownership and Attribution), 3 (Disclaimer of Warranties), and 4 (Limitation of Liability) will survive termination of this License.
README.md ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Rgbx
3
+ emoji: πŸš€
4
+ colorFrom: gray
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 5.5.0
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+ <h1 align="center"> RGB↔X: Image Decomposition and Synthesis Using Material- and Lighting-aware Diffusion Models </h1>
12
+
13
+ <p align="center"><a href="https://zheng95z.github.io/" target="_blank">Zheng Zeng</a>, <a href="https://valentin.deschaintre.fr/" target="_blank">Valentin Deschaintre</a>, <a href="https://www.iliyan.com/" target="_blank">Iliyan Georgiev</a>, <a href="https://yannickhold.com/" target="_blank">Yannick Hold-Geoffroy</a>, <a href="https://yiweihu.netlify.app/" target="_blank">Yiwei Hu</a>, <a href="https://luanfujun.com/" target="_blank">Fujun Luan</a>, <a href="https://sites.cs.ucsb.edu/~lingqi/" target="_blank">Ling-Qi Yan</a>, <a href="http://www.miloshasan.net/" target="_blank">MiloΕ‘ HaΕ‘an</a></p>
14
+
15
+ <p align="center">ACM SIGGRAPH 2024</p>
16
+
17
+ <p align="center"><img src="assets/rgbx24_teaser.png"></p>
18
+
19
+ The three areas of realistic forward rendering, per-pixel inverse rendering, and generative image synthesis may seem like separate and unrelated sub-fields of graphics and vision. However, recent work has demonstrated improved estimation of per-pixel intrinsic channels (albedo, roughness, metallicity) based on a diffusion architecture; we call this the RGB→X problem. We further show that the reverse problem of synthesizing realistic images given intrinsic channels, X→RGB, can also be addressed in a diffusion framework.
20
+
21
+ Focusing on the image domain of interior scenes, we introduce an improved diffusion model for RGB→X, which also estimates lighting, as well as the first diffusion X→RGB model capable of synthesizing realistic images from (full or partial) intrinsic channels. Our X→RGB model explores a middle ground between traditional rendering and generative models: we can specify only certain appearance properties that should be followed, and give freedom to the model to hallucinate a plausible version of the rest.
22
+
23
+ This flexibility makes it possible to use a mix of heterogeneous training datasets, which differ in the available channels. We use multiple existing datasets and extend them with our own synthetic and real data, resulting in a model capable of extracting scene properties better than previous work and of generating highly realistic images of interior scenes.
24
+
25
+ ## Structure
26
+ ```
27
+ β”œβ”€β”€ assets <- Assets used by the README.md
28
+ β”œβ”€β”€ rgb2x <- Code for the RGBβ†’X model
29
+ β”‚ β”œβ”€β”€ example <- Example photo
30
+ β”‚ └── model_cache <- Model weights (automatically downloaded when running the inference script)
31
+ β”œβ”€β”€ x2rgb <- Code for the Xβ†’RGB model
32
+ β”‚ β”œβ”€β”€ example <- Example photo
33
+ β”‚ └── model_cache <- Model weights (automatically downloaded when running the inference script)
34
+ β”œβ”€β”€ environment.yaml <- Env file for creating conda environment
35
+ β”œβ”€β”€ LICENSE
36
+ └── README.md
37
+ ```
38
+
39
+ ## Model Weights
40
+ You don't need to manually download the model weights. The weights will be downloaded automatically to `/rgb2x/model_cache/` and `/x2rgb/model_cache/` when you run the inference scripts.
41
+
42
+ You can manually acquire the weights by cloning the models from Hugging Face:
43
+ ```bash
44
+ git-lfs install
45
+ git clone https://huggingface.co/zheng95z/x-to-rgb
46
+ git clone https://huggingface.co/zheng95z/rgb-to-x
47
+ ```
48
+
49
+ ## Installation
50
+ Create a conda environment using the provided `environment.yaml` file.
51
+
52
+ ```bash
53
+ conda env create -n rgbx -f environment.yaml
54
+ conda activate rgbx
55
+ ```
56
+
57
+ Note that this environment is only compatible with NVIDIA GPUs. Additionally, we recommend using a GPU with a minimum of 12GB of memory.
58
+
59
+ ## Inference
60
+ When you run the inference scripts, gradio demos will be hosted on your local machine. You can access the demos by opening the URLs (shown in the terminal) in your browser.
61
+
62
+ ### RGB→X
63
+ ```bash
64
+ cd rgb2x
65
+ python gradio_demo_rgb2x.py
66
+ ```
67
+
68
+ **Please note that the metallicity channel prediction might behave differently between the demo and the paper. This is because the demo utilizes a checkpoint that predicts roughness and metallicity separately, whereas in the paper, we used a checkpoint where the roughness and metallicity channels were combined into a single RGB image (with the blue channel set to 0). Unfortunately, the latter checkpoint was lost during the transition between computing platforms, and we apologize for the inconvenience. We plan to resolve this issue and will provide an updated demo in the near future.**
69
+
70
+ ### X→RGB
71
+ ```bash
72
+ cd x2rgb
73
+ python gradio_demo_x2rgb.py
74
+ ```
75
+
76
+ ## Acknowledgements
77
+
78
+ This implementation builds upon Hugging Face’s [Diffusers](https://github.com/huggingface/diffusers) library. We also acknowledge [Gradio](https://www.gradio.app/) for providing an easy-to-use interface that allowed us to create the inference demos for our models.
assets/rgbx24_teaser.png ADDED

Git LFS Details

  • SHA256: 19c429cd26b2eb8dd9565d11d7a7a1107f350c82be9e9ef7c1e813e7b6eb43b4
  • Pointer size: 132 Bytes
  • Size of remote file: 2.05 MB
environment.yml ADDED
Binary file (648 Bytes). View file
 
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.5.1
2
+ torchaudio==2.5.1
3
+ torchvision==0.20.1
4
+ diffusers==0.20.0
5
+ gradio==5.5.0
6
+ imageio==2.34.1
7
+ numpy==1.26.4
8
+ opencv-python==4.9.0.80
9
+ transformers==4.40.2
10
+ spaces==0.30.4
rgb2x/example/Castlereagh_corridor_photo.png ADDED

Git LFS Details

  • SHA256: 8f77a445168dd92b97e214034f11291b8b3c0d98f3f12e34d591f56c39998fb4
  • Pointer size: 132 Bytes
  • Size of remote file: 1.05 MB
rgb2x/gradio_demo_rgb2x.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import os
3
+ from typing import cast
4
+ import gradio as gr
5
+ from PIL import Image
6
+ import torch
7
+ import torchvision
8
+ from diffusers import DDIMScheduler
9
+ from load_image import load_exr_image, load_ldr_image
10
+ from pipeline_rgb2x import StableDiffusionAOVMatEstPipeline
11
+
12
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
13
+
14
+ current_directory = os.path.dirname(os.path.abspath(__file__))
15
+
16
+ _pipe = StableDiffusionAOVMatEstPipeline.from_pretrained(
17
+ "zheng95z/rgb-to-x",
18
+ torch_dtype=torch.float16,
19
+ cache_dir=os.path.join(current_directory, "model_cache"),
20
+ ).to("cuda")
21
+ pipe = cast(StableDiffusionAOVMatEstPipeline, _pipe)
22
+ pipe.scheduler = DDIMScheduler.from_config(
23
+ pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
24
+ )
25
+ pipe.set_progress_bar_config(disable=True)
26
+ pipe.to("cuda")
27
+ pipe = cast(StableDiffusionAOVMatEstPipeline, pipe)
28
+
29
+
30
+ @spaces.GPU
31
+ def generate(
32
+ photo,
33
+ seed: int,
34
+ inference_step: int,
35
+ num_samples: int,
36
+ ) -> list[Image.Image]:
37
+ generator = torch.Generator(device="cuda").manual_seed(seed)
38
+
39
+ if photo.name.endswith(".exr"):
40
+ photo = load_exr_image(photo.name, tonemaping=True, clamp=True).to("cuda")
41
+ elif (
42
+ photo.name.endswith(".png")
43
+ or photo.name.endswith(".jpg")
44
+ or photo.name.endswith(".jpeg")
45
+ ):
46
+ photo = load_ldr_image(photo.name, from_srgb=True).to("cuda")
47
+
48
+ # Check if the width and height are multiples of 8. If not, crop it using torchvision.transforms.CenterCrop
49
+ old_height = photo.shape[1]
50
+ old_width = photo.shape[2]
51
+ new_height = old_height
52
+ new_width = old_width
53
+ radio = old_height / old_width
54
+ max_side = 1000
55
+ if old_height > old_width:
56
+ new_height = max_side
57
+ new_width = int(new_height / radio)
58
+ else:
59
+ new_width = max_side
60
+ new_height = int(new_width * radio)
61
+
62
+ if new_width % 8 != 0 or new_height % 8 != 0:
63
+ new_width = new_width // 8 * 8
64
+ new_height = new_height // 8 * 8
65
+
66
+ photo = torchvision.transforms.Resize((new_height, new_width))(photo)
67
+
68
+ required_aovs = ["albedo", "normal", "roughness", "metallic", "irradiance"]
69
+ prompts = {
70
+ "albedo": "Albedo (diffuse basecolor)",
71
+ "normal": "Camera-space Normal",
72
+ "roughness": "Roughness",
73
+ "metallic": "Metallicness",
74
+ "irradiance": "Irradiance (diffuse lighting)",
75
+ }
76
+
77
+ return_list = []
78
+ for i in range(num_samples):
79
+ for aov_name in required_aovs:
80
+ prompt = prompts[aov_name]
81
+ generated_image = pipe(
82
+ prompt=prompt,
83
+ photo=photo,
84
+ num_inference_steps=inference_step,
85
+ height=new_height,
86
+ width=new_width,
87
+ generator=generator,
88
+ required_aovs=[aov_name],
89
+ ).images[0][0] # type: ignore
90
+
91
+ generated_image = torchvision.transforms.Resize((old_height, old_width))(
92
+ generated_image
93
+ )
94
+
95
+ generated_image = (generated_image, f"Generated {aov_name} {i}")
96
+ return_list.append(generated_image)
97
+
98
+ return return_list
99
+
100
+
101
+ with gr.Blocks() as demo:
102
+ with gr.Row():
103
+ gr.Markdown("## Model RGB -> X (Realistic image -> Intrinsic channels)")
104
+ with gr.Row():
105
+ # Input side
106
+ with gr.Column():
107
+ gr.Markdown("### Given Image")
108
+ photo = gr.File(label="Photo", file_types=[".exr", ".png", ".jpg"])
109
+
110
+ gr.Markdown("### Parameters")
111
+ run_button = gr.Button(value="Run")
112
+ with gr.Accordion("Advanced options", open=False):
113
+ seed = gr.Slider(
114
+ label="Seed",
115
+ minimum=-1,
116
+ maximum=2147483647,
117
+ step=1,
118
+ randomize=True,
119
+ )
120
+ inference_step = gr.Slider(
121
+ label="Inference Step",
122
+ minimum=1,
123
+ maximum=100,
124
+ step=1,
125
+ value=50,
126
+ )
127
+ num_samples = gr.Slider(
128
+ label="Samples",
129
+ minimum=1,
130
+ maximum=100,
131
+ step=1,
132
+ value=1,
133
+ )
134
+
135
+ # Output side
136
+ with gr.Column():
137
+ gr.Markdown("### Output Gallery")
138
+ result_gallery = gr.Gallery(
139
+ label="Output",
140
+ show_label=False,
141
+ elem_id="gallery",
142
+ columns=2,
143
+ )
144
+
145
+ run_button.click(
146
+ fn=generate,
147
+ inputs=[photo, seed, inference_step, num_samples],
148
+ outputs=result_gallery,
149
+ queue=True,
150
+ )
151
+
152
+
153
+ if __name__ == "__main__":
154
+ demo.launch(debug=False, share=False, show_api=False)
rgb2x/load_image.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import cv2
4
+ import torch
5
+
6
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
7
+ import numpy as np
8
+
9
+
10
+ def convert_rgb_2_XYZ(rgb):
11
+ # Reference: https://web.archive.org/web/20191027010220/http://www.brucelindbloom.com/index.html?Eqn_RGB_XYZ_Matrix.html
12
+ # rgb: (h, w, 3)
13
+ # XYZ: (h, w, 3)
14
+ XYZ = torch.ones_like(rgb)
15
+ XYZ[:, :, 0] = (
16
+ 0.4124564 * rgb[:, :, 0] + 0.3575761 * rgb[:, :, 1] + 0.1804375 * rgb[:, :, 2]
17
+ )
18
+ XYZ[:, :, 1] = (
19
+ 0.2126729 * rgb[:, :, 0] + 0.7151522 * rgb[:, :, 1] + 0.0721750 * rgb[:, :, 2]
20
+ )
21
+ XYZ[:, :, 2] = (
22
+ 0.0193339 * rgb[:, :, 0] + 0.1191920 * rgb[:, :, 1] + 0.9503041 * rgb[:, :, 2]
23
+ )
24
+ return XYZ
25
+
26
+
27
+ def convert_XYZ_2_Yxy(XYZ):
28
+ # XYZ: (h, w, 3)
29
+ # Yxy: (h, w, 3)
30
+ Yxy = torch.ones_like(XYZ)
31
+ Yxy[:, :, 0] = XYZ[:, :, 1]
32
+ sum = torch.sum(XYZ, dim=2)
33
+ inv_sum = 1.0 / torch.clamp(sum, min=1e-4)
34
+ Yxy[:, :, 1] = XYZ[:, :, 0] * inv_sum
35
+ Yxy[:, :, 2] = XYZ[:, :, 1] * inv_sum
36
+ return Yxy
37
+
38
+
39
+ def convert_rgb_2_Yxy(rgb):
40
+ # rgb: (h, w, 3)
41
+ # Yxy: (h, w, 3)
42
+ return convert_XYZ_2_Yxy(convert_rgb_2_XYZ(rgb))
43
+
44
+
45
+ def convert_XYZ_2_rgb(XYZ):
46
+ # XYZ: (h, w, 3)
47
+ # rgb: (h, w, 3)
48
+ rgb = torch.ones_like(XYZ)
49
+ rgb[:, :, 0] = (
50
+ 3.2404542 * XYZ[:, :, 0] - 1.5371385 * XYZ[:, :, 1] - 0.4985314 * XYZ[:, :, 2]
51
+ )
52
+ rgb[:, :, 1] = (
53
+ -0.9692660 * XYZ[:, :, 0] + 1.8760108 * XYZ[:, :, 1] + 0.0415560 * XYZ[:, :, 2]
54
+ )
55
+ rgb[:, :, 2] = (
56
+ 0.0556434 * XYZ[:, :, 0] - 0.2040259 * XYZ[:, :, 1] + 1.0572252 * XYZ[:, :, 2]
57
+ )
58
+ return rgb
59
+
60
+
61
+ def convert_Yxy_2_XYZ(Yxy):
62
+ # Yxy: (h, w, 3)
63
+ # XYZ: (h, w, 3)
64
+ XYZ = torch.ones_like(Yxy)
65
+ XYZ[:, :, 0] = Yxy[:, :, 1] / torch.clamp(Yxy[:, :, 2], min=1e-6) * Yxy[:, :, 0]
66
+ XYZ[:, :, 1] = Yxy[:, :, 0]
67
+ XYZ[:, :, 2] = (
68
+ (1.0 - Yxy[:, :, 1] - Yxy[:, :, 2])
69
+ / torch.clamp(Yxy[:, :, 2], min=1e-4)
70
+ * Yxy[:, :, 0]
71
+ )
72
+ return XYZ
73
+
74
+
75
+ def convert_Yxy_2_rgb(Yxy):
76
+ # Yxy: (h, w, 3)
77
+ # rgb: (h, w, 3)
78
+ return convert_XYZ_2_rgb(convert_Yxy_2_XYZ(Yxy))
79
+
80
+
81
+ def load_ldr_image(image_path, from_srgb=False, clamp=False, normalize=False):
82
+ # Load png or jpg image
83
+ image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
84
+ image = torch.from_numpy(image.astype(np.float32) / 255.0) # (h, w, c)
85
+ image[~torch.isfinite(image)] = 0
86
+ if from_srgb:
87
+ # Convert from sRGB to linear RGB
88
+ image = image**2.2
89
+ if clamp:
90
+ image = torch.clamp(image, min=0.0, max=1.0)
91
+ if normalize:
92
+ # Normalize to [-1, 1]
93
+ image = image * 2.0 - 1.0
94
+ image = torch.nn.functional.normalize(image, dim=-1, eps=1e-6)
95
+ return image.permute(2, 0, 1) # returns (c, h, w)
96
+
97
+
98
+ def load_exr_image(image_path, tonemaping=False, clamp=False, normalize=False):
99
+ image = cv2.cvtColor(cv2.imread(image_path, -1), cv2.COLOR_BGR2RGB)
100
+ image = torch.from_numpy(image.astype("float32")) # (h, w, c)
101
+ image[~torch.isfinite(image)] = 0
102
+ if tonemaping:
103
+ # Exposure adjuestment
104
+ image_Yxy = convert_rgb_2_Yxy(image)
105
+ lum = (
106
+ image[:, :, 0:1] * 0.2125
107
+ + image[:, :, 1:2] * 0.7154
108
+ + image[:, :, 2:3] * 0.0721
109
+ )
110
+ lum = torch.log(torch.clamp(lum, min=1e-6))
111
+ lum_mean = torch.exp(torch.mean(lum))
112
+ lp = image_Yxy[:, :, 0:1] * 0.18 / torch.clamp(lum_mean, min=1e-6)
113
+ image_Yxy[:, :, 0:1] = lp
114
+ image = convert_Yxy_2_rgb(image_Yxy)
115
+ if clamp:
116
+ image = torch.clamp(image, min=0.0, max=1.0)
117
+ if normalize:
118
+ image = torch.nn.functional.normalize(image, dim=-1, eps=1e-6)
119
+ return image.permute(2, 0, 1) # returns (c, h, w)
rgb2x/pipeline_rgb2x.py ADDED
@@ -0,0 +1,821 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ from dataclasses import dataclass
3
+ from typing import Callable, List, Optional, Union
4
+
5
+ import numpy as np
6
+ import PIL
7
+ import torch
8
+ from diffusers.configuration_utils import register_to_config
9
+ from diffusers.image_processor import VaeImageProcessor
10
+ from diffusers.loaders import (
11
+ LoraLoaderMixin,
12
+ TextualInversionLoaderMixin,
13
+ )
14
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
15
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
16
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
17
+ rescale_noise_cfg,
18
+ )
19
+ from diffusers.schedulers import KarrasDiffusionSchedulers
20
+ from diffusers.utils import (
21
+ CONFIG_NAME,
22
+ BaseOutput,
23
+ deprecate,
24
+ logging,
25
+ )
26
+ from diffusers.utils.torch_utils import randn_tensor
27
+ from transformers import CLIPTextModel, CLIPTokenizer
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
+ class VaeImageProcrssorAOV(VaeImageProcessor):
33
+ """
34
+ Image processor for VAE AOV.
35
+
36
+ Args:
37
+ do_resize (`bool`, *optional*, defaults to `True`):
38
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
39
+ vae_scale_factor (`int`, *optional*, defaults to `8`):
40
+ VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
41
+ resample (`str`, *optional*, defaults to `lanczos`):
42
+ Resampling filter to use when resizing the image.
43
+ do_normalize (`bool`, *optional*, defaults to `True`):
44
+ Whether to normalize the image to [-1,1].
45
+ """
46
+
47
+ config_name = CONFIG_NAME
48
+
49
+ @register_to_config
50
+ def __init__(
51
+ self,
52
+ do_resize: bool = True,
53
+ vae_scale_factor: int = 8,
54
+ resample: str = "lanczos",
55
+ do_normalize: bool = True,
56
+ ):
57
+ super().__init__()
58
+
59
+ def postprocess(
60
+ self,
61
+ image: torch.FloatTensor,
62
+ output_type: str = "pil",
63
+ do_denormalize: Optional[List[bool]] = None,
64
+ do_gamma_correction: bool = True,
65
+ ):
66
+ if not isinstance(image, torch.Tensor):
67
+ raise ValueError(
68
+ f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
69
+ )
70
+ if output_type not in ["latent", "pt", "np", "pil"]:
71
+ deprecation_message = (
72
+ f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
73
+ "`pil`, `np`, `pt`, `latent`"
74
+ )
75
+ deprecate(
76
+ "Unsupported output_type",
77
+ "1.0.0",
78
+ deprecation_message,
79
+ standard_warn=False,
80
+ )
81
+ output_type = "np"
82
+
83
+ if output_type == "latent":
84
+ return image
85
+
86
+ if do_denormalize is None:
87
+ do_denormalize = [self.config.do_normalize] * image.shape[0]
88
+
89
+ image = torch.stack(
90
+ [
91
+ self.denormalize(image[i]) if do_denormalize[i] else image[i]
92
+ for i in range(image.shape[0])
93
+ ]
94
+ )
95
+
96
+ # Gamma correction
97
+ if do_gamma_correction:
98
+ image = torch.pow(image, 1.0 / 2.2)
99
+
100
+ if output_type == "pt":
101
+ return image
102
+
103
+ image = self.pt_to_numpy(image)
104
+
105
+ if output_type == "np":
106
+ return image
107
+
108
+ if output_type == "pil":
109
+ return self.numpy_to_pil(image)
110
+
111
+ def preprocess_normal(
112
+ self,
113
+ image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
114
+ height: Optional[int] = None,
115
+ width: Optional[int] = None,
116
+ ) -> torch.Tensor:
117
+ image = torch.stack([image], axis=0)
118
+ return image
119
+
120
+
121
+ @dataclass
122
+ class StableDiffusionAOVPipelineOutput(BaseOutput):
123
+ """
124
+ Output class for Stable Diffusion AOV pipelines.
125
+
126
+ Args:
127
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
128
+ List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
129
+ num_channels)`.
130
+ nsfw_content_detected (`List[bool]`)
131
+ List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or
132
+ `None` if safety checking could not be performed.
133
+ """
134
+
135
+ images: Union[List[PIL.Image.Image], np.ndarray]
136
+
137
+
138
+ class StableDiffusionAOVMatEstPipeline(
139
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin
140
+ ):
141
+ r"""
142
+ Pipeline for AOVs.
143
+
144
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
145
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
146
+
147
+ The pipeline also inherits the following loading methods:
148
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
149
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
150
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
151
+
152
+ Args:
153
+ vae ([`AutoencoderKL`]):
154
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
155
+ text_encoder ([`~transformers.CLIPTextModel`]):
156
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
157
+ tokenizer ([`~transformers.CLIPTokenizer`]):
158
+ A `CLIPTokenizer` to tokenize text.
159
+ unet ([`UNet2DConditionModel`]):
160
+ A `UNet2DConditionModel` to denoise the encoded image latents.
161
+ scheduler ([`SchedulerMixin`]):
162
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
163
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
164
+ """
165
+
166
+ def __init__(
167
+ self,
168
+ vae: AutoencoderKL,
169
+ text_encoder: CLIPTextModel,
170
+ tokenizer: CLIPTokenizer,
171
+ unet: UNet2DConditionModel,
172
+ scheduler: KarrasDiffusionSchedulers,
173
+ ):
174
+ super().__init__()
175
+
176
+ self.register_modules(
177
+ vae=vae,
178
+ text_encoder=text_encoder,
179
+ tokenizer=tokenizer,
180
+ unet=unet,
181
+ scheduler=scheduler,
182
+ )
183
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
184
+ self.image_processor = VaeImageProcrssorAOV(
185
+ vae_scale_factor=self.vae_scale_factor
186
+ )
187
+ self.register_to_config()
188
+
189
+ def _encode_prompt(
190
+ self,
191
+ prompt,
192
+ device,
193
+ num_images_per_prompt,
194
+ do_classifier_free_guidance,
195
+ negative_prompt=None,
196
+ prompt_embeds: Optional[torch.FloatTensor] = None,
197
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
198
+ ):
199
+ r"""
200
+ Encodes the prompt into text encoder hidden states.
201
+
202
+ Args:
203
+ prompt (`str` or `List[str]`, *optional*):
204
+ prompt to be encoded
205
+ device: (`torch.device`):
206
+ torch device
207
+ num_images_per_prompt (`int`):
208
+ number of images that should be generated per prompt
209
+ do_classifier_free_guidance (`bool`):
210
+ whether to use classifier free guidance or not
211
+ negative_ prompt (`str` or `List[str]`, *optional*):
212
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
213
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
214
+ less than `1`).
215
+ prompt_embeds (`torch.FloatTensor`, *optional*):
216
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
217
+ provided, text embeddings will be generated from `prompt` input argument.
218
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
219
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
220
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
221
+ argument.
222
+ """
223
+ if prompt is not None and isinstance(prompt, str):
224
+ batch_size = 1
225
+ elif prompt is not None and isinstance(prompt, list):
226
+ batch_size = len(prompt)
227
+ else:
228
+ batch_size = prompt_embeds.shape[0]
229
+
230
+ if prompt_embeds is None:
231
+ # textual inversion: procecss multi-vector tokens if necessary
232
+ if isinstance(self, TextualInversionLoaderMixin):
233
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
234
+
235
+ text_inputs = self.tokenizer(
236
+ prompt,
237
+ padding="max_length",
238
+ max_length=self.tokenizer.model_max_length,
239
+ truncation=True,
240
+ return_tensors="pt",
241
+ )
242
+ text_input_ids = text_inputs.input_ids
243
+ untruncated_ids = self.tokenizer(
244
+ prompt, padding="longest", return_tensors="pt"
245
+ ).input_ids
246
+
247
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[
248
+ -1
249
+ ] and not torch.equal(text_input_ids, untruncated_ids):
250
+ removed_text = self.tokenizer.batch_decode(
251
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
252
+ )
253
+ logger.warning(
254
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
255
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
256
+ )
257
+
258
+ if (
259
+ hasattr(self.text_encoder.config, "use_attention_mask")
260
+ and self.text_encoder.config.use_attention_mask
261
+ ):
262
+ attention_mask = text_inputs.attention_mask.to(device)
263
+ else:
264
+ attention_mask = None
265
+
266
+ prompt_embeds = self.text_encoder(
267
+ text_input_ids.to(device),
268
+ attention_mask=attention_mask,
269
+ )
270
+ prompt_embeds = prompt_embeds[0]
271
+
272
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
273
+
274
+ bs_embed, seq_len, _ = prompt_embeds.shape
275
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
276
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
277
+ prompt_embeds = prompt_embeds.view(
278
+ bs_embed * num_images_per_prompt, seq_len, -1
279
+ )
280
+
281
+ # get unconditional embeddings for classifier free guidance
282
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
283
+ uncond_tokens: List[str]
284
+ if negative_prompt is None:
285
+ uncond_tokens = [""] * batch_size
286
+ elif type(prompt) is not type(negative_prompt):
287
+ raise TypeError(
288
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
289
+ f" {type(prompt)}."
290
+ )
291
+ elif isinstance(negative_prompt, str):
292
+ uncond_tokens = [negative_prompt]
293
+ elif batch_size != len(negative_prompt):
294
+ raise ValueError(
295
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
296
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
297
+ " the batch size of `prompt`."
298
+ )
299
+ else:
300
+ uncond_tokens = negative_prompt
301
+
302
+ # textual inversion: procecss multi-vector tokens if necessary
303
+ if isinstance(self, TextualInversionLoaderMixin):
304
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
305
+
306
+ max_length = prompt_embeds.shape[1]
307
+ uncond_input = self.tokenizer(
308
+ uncond_tokens,
309
+ padding="max_length",
310
+ max_length=max_length,
311
+ truncation=True,
312
+ return_tensors="pt",
313
+ )
314
+
315
+ if (
316
+ hasattr(self.text_encoder.config, "use_attention_mask")
317
+ and self.text_encoder.config.use_attention_mask
318
+ ):
319
+ attention_mask = uncond_input.attention_mask.to(device)
320
+ else:
321
+ attention_mask = None
322
+
323
+ negative_prompt_embeds = self.text_encoder(
324
+ uncond_input.input_ids.to(device),
325
+ attention_mask=attention_mask,
326
+ )
327
+ negative_prompt_embeds = negative_prompt_embeds[0]
328
+
329
+ if do_classifier_free_guidance:
330
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
331
+ seq_len = negative_prompt_embeds.shape[1]
332
+
333
+ negative_prompt_embeds = negative_prompt_embeds.to(
334
+ dtype=self.text_encoder.dtype, device=device
335
+ )
336
+
337
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
338
+ 1, num_images_per_prompt, 1
339
+ )
340
+ negative_prompt_embeds = negative_prompt_embeds.view(
341
+ batch_size * num_images_per_prompt, seq_len, -1
342
+ )
343
+
344
+ # For classifier free guidance, we need to do two forward passes.
345
+ # Here we concatenate the unconditional and text embeddings into a single batch
346
+ # to avoid doing two forward passes
347
+ # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]
348
+ prompt_embeds = torch.cat(
349
+ [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]
350
+ )
351
+
352
+ return prompt_embeds
353
+
354
+ def prepare_extra_step_kwargs(self, generator, eta):
355
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
356
+ # eta (Ξ·) is only used with the DDIMScheduler, it will be ignored for other schedulers.
357
+ # eta corresponds to Ξ· in DDIM paper: https://arxiv.org/abs/2010.02502
358
+ # and should be between [0, 1]
359
+
360
+ accepts_eta = "eta" in set(
361
+ inspect.signature(self.scheduler.step).parameters.keys()
362
+ )
363
+ extra_step_kwargs = {}
364
+ if accepts_eta:
365
+ extra_step_kwargs["eta"] = eta
366
+
367
+ # check if the scheduler accepts generator
368
+ accepts_generator = "generator" in set(
369
+ inspect.signature(self.scheduler.step).parameters.keys()
370
+ )
371
+ if accepts_generator:
372
+ extra_step_kwargs["generator"] = generator
373
+ return extra_step_kwargs
374
+
375
+ def check_inputs(
376
+ self,
377
+ prompt,
378
+ callback_steps,
379
+ negative_prompt=None,
380
+ prompt_embeds=None,
381
+ negative_prompt_embeds=None,
382
+ ):
383
+ if (callback_steps is None) or (
384
+ callback_steps is not None
385
+ and (not isinstance(callback_steps, int) or callback_steps <= 0)
386
+ ):
387
+ raise ValueError(
388
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
389
+ f" {type(callback_steps)}."
390
+ )
391
+
392
+ if prompt is not None and prompt_embeds is not None:
393
+ raise ValueError(
394
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
395
+ " only forward one of the two."
396
+ )
397
+ elif prompt is None and prompt_embeds is None:
398
+ raise ValueError(
399
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
400
+ )
401
+ elif prompt is not None and (
402
+ not isinstance(prompt, str) and not isinstance(prompt, list)
403
+ ):
404
+ raise ValueError(
405
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
406
+ )
407
+
408
+ if negative_prompt is not None and negative_prompt_embeds is not None:
409
+ raise ValueError(
410
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
411
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
412
+ )
413
+
414
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
415
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
416
+ raise ValueError(
417
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
418
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
419
+ f" {negative_prompt_embeds.shape}."
420
+ )
421
+
422
+ def prepare_latents(
423
+ self,
424
+ batch_size,
425
+ num_channels_latents,
426
+ height,
427
+ width,
428
+ dtype,
429
+ device,
430
+ generator,
431
+ latents=None,
432
+ ):
433
+ shape = (
434
+ batch_size,
435
+ num_channels_latents,
436
+ height // self.vae_scale_factor,
437
+ width // self.vae_scale_factor,
438
+ )
439
+ if isinstance(generator, list) and len(generator) != batch_size:
440
+ raise ValueError(
441
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
442
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
443
+ )
444
+
445
+ if latents is None:
446
+ latents = randn_tensor(
447
+ shape, generator=generator, device=device, dtype=dtype
448
+ )
449
+ else:
450
+ latents = latents.to(device)
451
+
452
+ # scale the initial noise by the standard deviation required by the scheduler
453
+ latents = latents * self.scheduler.init_noise_sigma
454
+ return latents
455
+
456
+ def prepare_image_latents(
457
+ self,
458
+ image,
459
+ batch_size,
460
+ num_images_per_prompt,
461
+ dtype,
462
+ device,
463
+ do_classifier_free_guidance,
464
+ generator=None,
465
+ ):
466
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
467
+ raise ValueError(
468
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
469
+ )
470
+
471
+ image = image.to(device=device, dtype=dtype)
472
+
473
+ batch_size = batch_size * num_images_per_prompt
474
+
475
+ if image.shape[1] == 4:
476
+ image_latents = image
477
+ else:
478
+ if isinstance(generator, list) and len(generator) != batch_size:
479
+ raise ValueError(
480
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
481
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
482
+ )
483
+
484
+ if isinstance(generator, list):
485
+ image_latents = [
486
+ self.vae.encode(image[i : i + 1]).latent_dist.mode()
487
+ for i in range(batch_size)
488
+ ]
489
+ image_latents = torch.cat(image_latents, dim=0)
490
+ else:
491
+ image_latents = self.vae.encode(image).latent_dist.mode()
492
+
493
+ if (
494
+ batch_size > image_latents.shape[0]
495
+ and batch_size % image_latents.shape[0] == 0
496
+ ):
497
+ # expand image_latents for batch_size
498
+ deprecation_message = (
499
+ f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
500
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
501
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
502
+ " your script to pass as many initial images as text prompts to suppress this warning."
503
+ )
504
+ deprecate(
505
+ "len(prompt) != len(image)",
506
+ "1.0.0",
507
+ deprecation_message,
508
+ standard_warn=False,
509
+ )
510
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
511
+ image_latents = torch.cat(
512
+ [image_latents] * additional_image_per_prompt, dim=0
513
+ )
514
+ elif (
515
+ batch_size > image_latents.shape[0]
516
+ and batch_size % image_latents.shape[0] != 0
517
+ ):
518
+ raise ValueError(
519
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
520
+ )
521
+ else:
522
+ image_latents = torch.cat([image_latents], dim=0)
523
+
524
+ if do_classifier_free_guidance:
525
+ uncond_image_latents = torch.zeros_like(image_latents)
526
+ image_latents = torch.cat(
527
+ [image_latents, image_latents, uncond_image_latents], dim=0
528
+ )
529
+
530
+ return image_latents
531
+
532
+ @torch.no_grad()
533
+ def __call__(
534
+ self,
535
+ prompt: Union[str, List[str]] = None,
536
+ photo: Union[
537
+ torch.FloatTensor,
538
+ PIL.Image.Image,
539
+ np.ndarray,
540
+ List[torch.FloatTensor],
541
+ List[PIL.Image.Image],
542
+ List[np.ndarray],
543
+ ] = None,
544
+ height: Optional[int] = None,
545
+ width: Optional[int] = None,
546
+ num_inference_steps: int = 100,
547
+ required_aovs: List[str] = ["albedo"],
548
+ negative_prompt: Optional[Union[str, List[str]]] = None,
549
+ num_images_per_prompt: Optional[int] = 1,
550
+ use_default_scaling_factor: Optional[bool] = False,
551
+ guidance_scale: float = 0.0,
552
+ image_guidance_scale: float = 0.0,
553
+ guidance_rescale: float = 0.0,
554
+ eta: float = 0.0,
555
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
556
+ latents: Optional[torch.FloatTensor] = None,
557
+ prompt_embeds: Optional[torch.FloatTensor] = None,
558
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
559
+ output_type: Optional[str] = "pil",
560
+ return_dict: bool = True,
561
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
562
+ callback_steps: int = 1,
563
+ ):
564
+ r"""
565
+ The call function to the pipeline for generation.
566
+
567
+ Args:
568
+ prompt (`str` or `List[str]`, *optional*):
569
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
570
+ image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
571
+ `Image` or tensor representing an image batch to be repainted according to `prompt`. Can also accept
572
+ image latents as `image`, but if passing latents directly it is not encoded again.
573
+ num_inference_steps (`int`, *optional*, defaults to 100):
574
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
575
+ expense of slower inference.
576
+ guidance_scale (`float`, *optional*, defaults to 7.5):
577
+ A higher guidance scale value encourages the model to generate images closely linked to the text
578
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
579
+ image_guidance_scale (`float`, *optional*, defaults to 1.5):
580
+ Push the generated image towards the inital `image`. Image guidance scale is enabled by setting
581
+ `image_guidance_scale > 1`. Higher image guidance scale encourages generated images that are closely
582
+ linked to the source `image`, usually at the expense of lower image quality. This pipeline requires a
583
+ value of at least `1`.
584
+ negative_prompt (`str` or `List[str]`, *optional*):
585
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
586
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
587
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
588
+ The number of images to generate per prompt.
589
+ eta (`float`, *optional*, defaults to 0.0):
590
+ Corresponds to parameter eta (Ξ·) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
591
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
592
+ generator (`torch.Generator`, *optional*):
593
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
594
+ generation deterministic.
595
+ latents (`torch.FloatTensor`, *optional*):
596
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
597
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
598
+ tensor is generated by sampling using the supplied random `generator`.
599
+ prompt_embeds (`torch.FloatTensor`, *optional*):
600
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
601
+ provided, text embeddings are generated from the `prompt` input argument.
602
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
603
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
604
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
605
+ output_type (`str`, *optional*, defaults to `"pil"`):
606
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
607
+ return_dict (`bool`, *optional*, defaults to `True`):
608
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
609
+ plain tuple.
610
+ callback (`Callable`, *optional*):
611
+ A function that calls every `callback_steps` steps during inference. The function is called with the
612
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
613
+ callback_steps (`int`, *optional*, defaults to 1):
614
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
615
+ every step.
616
+
617
+ Examples:
618
+
619
+ ```py
620
+ >>> import PIL
621
+ >>> import requests
622
+ >>> import torch
623
+ >>> from io import BytesIO
624
+
625
+ >>> from diffusers import StableDiffusionInstructPix2PixPipeline
626
+
627
+
628
+ >>> def download_image(url):
629
+ ... response = requests.get(url)
630
+ ... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
631
+
632
+
633
+ >>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png"
634
+
635
+ >>> image = download_image(img_url).resize((512, 512))
636
+
637
+ >>> pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
638
+ ... "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
639
+ ... )
640
+ >>> pipe = pipe.to("cuda")
641
+
642
+ >>> prompt = "make the mountains snowy"
643
+ >>> image = pipe(prompt=prompt, image=image).images[0]
644
+ ```
645
+
646
+ Returns:
647
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
648
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
649
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
650
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
651
+ "not-safe-for-work" (nsfw) content.
652
+ """
653
+ # 0. Check inputs
654
+ self.check_inputs(
655
+ prompt,
656
+ callback_steps,
657
+ negative_prompt,
658
+ prompt_embeds,
659
+ negative_prompt_embeds,
660
+ )
661
+
662
+ # 1. Define call parameters
663
+ if prompt is not None and isinstance(prompt, str):
664
+ batch_size = 1
665
+ elif prompt is not None and isinstance(prompt, list):
666
+ batch_size = len(prompt)
667
+ else:
668
+ batch_size = prompt_embeds.shape[0]
669
+
670
+ device = self._execution_device
671
+ do_classifier_free_guidance = (
672
+ guidance_scale > 1.0 and image_guidance_scale >= 1.0
673
+ )
674
+ # check if scheduler is in sigmas space
675
+ scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
676
+
677
+ # 2. Encode input prompt
678
+ prompt_embeds = self._encode_prompt(
679
+ prompt,
680
+ device,
681
+ num_images_per_prompt,
682
+ do_classifier_free_guidance,
683
+ negative_prompt,
684
+ prompt_embeds=prompt_embeds,
685
+ negative_prompt_embeds=negative_prompt_embeds,
686
+ )
687
+
688
+ # 3. Preprocess image
689
+ # Normalize image to [-1,1]
690
+ preprocessed_photo = self.image_processor.preprocess(photo)
691
+
692
+ # 4. set timesteps
693
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
694
+ timesteps = self.scheduler.timesteps
695
+
696
+ # 5. Prepare Image latents
697
+ image_latents = self.prepare_image_latents(
698
+ preprocessed_photo,
699
+ batch_size,
700
+ num_images_per_prompt,
701
+ prompt_embeds.dtype,
702
+ device,
703
+ do_classifier_free_guidance,
704
+ generator,
705
+ )
706
+ image_latents = image_latents * self.vae.config.scaling_factor
707
+
708
+ height, width = image_latents.shape[-2:]
709
+ height = height * self.vae_scale_factor
710
+ width = width * self.vae_scale_factor
711
+
712
+ # 6. Prepare latent variables
713
+ num_channels_latents = self.unet.config.out_channels
714
+ latents = self.prepare_latents(
715
+ batch_size * num_images_per_prompt,
716
+ num_channels_latents,
717
+ height,
718
+ width,
719
+ prompt_embeds.dtype,
720
+ device,
721
+ generator,
722
+ latents,
723
+ )
724
+
725
+ # 7. Check that shapes of latents and image match the UNet channels
726
+ num_channels_image = image_latents.shape[1]
727
+ if num_channels_latents + num_channels_image != self.unet.config.in_channels:
728
+ raise ValueError(
729
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
730
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
731
+ f" `num_channels_image`: {num_channels_image} "
732
+ f" = {num_channels_latents+num_channels_image}. Please verify the config of"
733
+ " `pipeline.unet` or your `image` input."
734
+ )
735
+
736
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
737
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
738
+
739
+ # 9. Denoising loop
740
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
741
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
742
+ for i, t in enumerate(timesteps):
743
+ # Expand the latents if we are doing classifier free guidance.
744
+ # The latents are expanded 3 times because for pix2pix the guidance\
745
+ # is applied for both the text and the input image.
746
+ latent_model_input = (
747
+ torch.cat([latents] * 3) if do_classifier_free_guidance else latents
748
+ )
749
+
750
+ # concat latents, image_latents in the channel dimension
751
+ scaled_latent_model_input = self.scheduler.scale_model_input(
752
+ latent_model_input, t
753
+ )
754
+ scaled_latent_model_input = torch.cat(
755
+ [scaled_latent_model_input, image_latents], dim=1
756
+ )
757
+
758
+ # predict the noise residual
759
+ noise_pred = self.unet(
760
+ scaled_latent_model_input,
761
+ t,
762
+ encoder_hidden_states=prompt_embeds,
763
+ return_dict=False,
764
+ )[0]
765
+
766
+ # perform guidance
767
+ if do_classifier_free_guidance:
768
+ (
769
+ noise_pred_text,
770
+ noise_pred_image,
771
+ noise_pred_uncond,
772
+ ) = noise_pred.chunk(3)
773
+ noise_pred = (
774
+ noise_pred_uncond
775
+ + guidance_scale * (noise_pred_text - noise_pred_image)
776
+ + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
777
+ )
778
+
779
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
780
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
781
+ noise_pred = rescale_noise_cfg(
782
+ noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
783
+ )
784
+
785
+ # compute the previous noisy sample x_t -> x_t-1
786
+ latents = self.scheduler.step(
787
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
788
+ )[0]
789
+
790
+ # call the callback, if provided
791
+ if i == len(timesteps) - 1 or (
792
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
793
+ ):
794
+ progress_bar.update()
795
+ if callback is not None and i % callback_steps == 0:
796
+ callback(i, t, latents)
797
+
798
+ aov_latents = latents / self.vae.config.scaling_factor
799
+ aov = self.vae.decode(aov_latents, return_dict=False)[0]
800
+ do_denormalize = [True] * aov.shape[0]
801
+ aov_name = required_aovs[0]
802
+ if aov_name == "albedo" or aov_name == "irradiance":
803
+ do_gamma_correction = True
804
+ else:
805
+ do_gamma_correction = False
806
+
807
+ if aov_name == "roughness" or aov_name == "metallic":
808
+ aov = aov[:, 0:1].repeat(1, 3, 1, 1)
809
+
810
+ aov = self.image_processor.postprocess(
811
+ aov,
812
+ output_type=output_type,
813
+ do_denormalize=do_denormalize,
814
+ do_gamma_correction=do_gamma_correction,
815
+ )
816
+ aovs = [aov]
817
+
818
+ # Offload last model to CPU
819
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
820
+ self.final_offload_hook.offload()
821
+ return StableDiffusionAOVPipelineOutput(images=aovs)
x2rgb/example/kitchen-albedo.png ADDED

Git LFS Details

  • SHA256: d2b3e2ae5001c4214d5c87041e57933708cbb424eca6f7a2659c2c3e91a6a8ce
  • Pointer size: 131 Bytes
  • Size of remote file: 313 kB
x2rgb/example/kitchen-irradiance.png ADDED

Git LFS Details

  • SHA256: 259b873bba6405d72a87a30321f4572a47d296726368eba2f0303d4ab3bcd269
  • Pointer size: 131 Bytes
  • Size of remote file: 959 kB
x2rgb/example/kitchen-metallic.png ADDED

Git LFS Details

  • SHA256: cd6fec250659c8915c821b9063b851da4da59c2459c56fa08338cb81c5e6b70d
  • Pointer size: 130 Bytes
  • Size of remote file: 33.9 kB
x2rgb/example/kitchen-normal.png ADDED

Git LFS Details

  • SHA256: abf769887d2ee8fa050f56f50285502fcf8dbb8b69c28c1f3910ad1ee2874068
  • Pointer size: 131 Bytes
  • Size of remote file: 415 kB
x2rgb/example/kitchen-ref.png ADDED

Git LFS Details

  • SHA256: 19e57fc6737291cb59611786c9894fd4c2bedb0ba14b875942241195afff3534
  • Pointer size: 132 Bytes
  • Size of remote file: 1.04 MB
x2rgb/example/kitchen-roughness.png ADDED

Git LFS Details

  • SHA256: 9d1195686031d170151798b00d095c48a40e1a8a508d65c5a841fcabb0ae8fad
  • Pointer size: 130 Bytes
  • Size of remote file: 84 kB
x2rgb/gradio_demo_x2rgb.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import os
3
+ from typing import cast
4
+ import gradio as gr
5
+ import numpy as np
6
+ import torch
7
+ from PIL import Image
8
+ from diffusers import DDIMScheduler
9
+ from load_image import load_exr_image, load_ldr_image
10
+ from pipeline_x2rgb import StableDiffusionAOVDropoutPipeline
11
+
12
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
13
+
14
+ current_directory = os.path.dirname(os.path.abspath(__file__))
15
+
16
+ _pipe = StableDiffusionAOVDropoutPipeline.from_pretrained(
17
+ "zheng95z/x-to-rgb",
18
+ torch_dtype=torch.float16,
19
+ cache_dir=os.path.join(current_directory, "model_cache"),
20
+ ).to("cuda")
21
+ pipe = cast(StableDiffusionAOVDropoutPipeline, _pipe)
22
+ pipe.scheduler = DDIMScheduler.from_config(
23
+ pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
24
+ )
25
+ pipe.set_progress_bar_config(disable=True)
26
+ pipe.to("cuda")
27
+ pipe = cast(StableDiffusionAOVDropoutPipeline, pipe)
28
+
29
+
30
+ @spaces.GPU
31
+ def generate(
32
+ albedo,
33
+ normal,
34
+ roughness,
35
+ metallic,
36
+ irradiance,
37
+ prompt: str,
38
+ seed: int,
39
+ inference_step: int,
40
+ num_samples: int,
41
+ guidance_scale: float,
42
+ image_guidance_scale: float,
43
+ ) -> list[Image.Image]:
44
+ generator = torch.Generator(device="cuda").manual_seed(seed)
45
+
46
+ # Load and process each intrinsic channel image
47
+ def process_image(file, **kwargs):
48
+ if file is None:
49
+ return None
50
+ if file.name.endswith(".exr"):
51
+ return load_exr_image(file.name, **kwargs).to("cuda")
52
+ elif file.name.endswith((".png", ".jpg", ".jpeg")):
53
+ return load_ldr_image(file.name, **kwargs).to("cuda")
54
+ return None
55
+
56
+ albedo_image = process_image(albedo, clamp=True)
57
+ normal_image = process_image(normal, normalize=True)
58
+ roughness_image = process_image(roughness, clamp=True)
59
+ metallic_image = process_image(metallic, clamp=True)
60
+ irradiance_image = process_image(irradiance, tonemaping=True, clamp=True)
61
+
62
+ # Set default height and width based on the first available image
63
+ height, width = 768, 768
64
+ for img in [
65
+ albedo_image,
66
+ normal_image,
67
+ roughness_image,
68
+ metallic_image,
69
+ irradiance_image,
70
+ ]:
71
+ if img is not None:
72
+ height, width = img.shape[1], img.shape[2]
73
+ break
74
+
75
+ required_aovs = ["albedo", "normal", "roughness", "metallic", "irradiance"]
76
+ return_list = []
77
+
78
+ for i in range(num_samples):
79
+ generated_image = pipe(
80
+ prompt=prompt,
81
+ albedo=albedo_image,
82
+ normal=normal_image,
83
+ roughness=roughness_image,
84
+ metallic=metallic_image,
85
+ irradiance=irradiance_image,
86
+ num_inference_steps=inference_step,
87
+ height=height,
88
+ width=width,
89
+ generator=generator,
90
+ required_aovs=required_aovs,
91
+ guidance_scale=guidance_scale,
92
+ image_guidance_scale=image_guidance_scale,
93
+ guidance_rescale=0.7,
94
+ output_type="np",
95
+ ).images[0] # type: ignore
96
+
97
+ return_list.append((generated_image, f"Generated Image {i}"))
98
+
99
+ # Append additional images to the output gallery
100
+ def post_process_image(img, **kwargs):
101
+ if img is not None:
102
+ return (img.cpu().numpy().transpose(1, 2, 0), kwargs.get("label", "Image"))
103
+ return np.zeros((height, width, 3))
104
+
105
+ return_list.extend(
106
+ [
107
+ post_process_image(albedo_image, label="Albedo"),
108
+ post_process_image(normal_image, label="Normal"),
109
+ post_process_image(roughness_image, label="Roughness"),
110
+ post_process_image(metallic_image, label="Metallic"),
111
+ post_process_image(irradiance_image, label="Irradiance"),
112
+ ]
113
+ )
114
+
115
+ return return_list
116
+
117
+
118
+ with gr.Blocks() as demo:
119
+ with gr.Row():
120
+ gr.Markdown("## Model X -> RGB (Intrinsic channels -> realistic image)")
121
+ with gr.Row():
122
+ # Input side
123
+ with gr.Column():
124
+ gr.Markdown("### Given intrinsic channels")
125
+ albedo = gr.File(label="Albedo", file_types=[".exr", ".png", ".jpg"])
126
+ normal = gr.File(label="Normal", file_types=[".exr", ".png", ".jpg"])
127
+ roughness = gr.File(label="Roughness", file_types=[".exr", ".png", ".jpg"])
128
+ metallic = gr.File(label="Metallic", file_types=[".exr", ".png", ".jpg"])
129
+ irradiance = gr.File(
130
+ label="Irradiance", file_types=[".exr", ".png", ".jpg"]
131
+ )
132
+
133
+ gr.Markdown("### Parameters")
134
+ prompt = gr.Textbox(label="Prompt")
135
+ run_button = gr.Button(value="Run")
136
+ with gr.Accordion("Advanced options", open=False):
137
+ seed = gr.Slider(
138
+ label="Seed",
139
+ minimum=-1,
140
+ maximum=2147483647,
141
+ step=1,
142
+ randomize=True,
143
+ )
144
+ inference_step = gr.Slider(
145
+ label="Inference Step",
146
+ minimum=1,
147
+ maximum=100,
148
+ step=1,
149
+ value=50,
150
+ )
151
+ num_samples = gr.Slider(
152
+ label="Samples",
153
+ minimum=1,
154
+ maximum=100,
155
+ step=1,
156
+ value=1,
157
+ )
158
+ guidance_scale = gr.Slider(
159
+ label="Guidance Scale",
160
+ minimum=0.0,
161
+ maximum=10.0,
162
+ step=0.1,
163
+ value=7.5,
164
+ )
165
+ image_guidance_scale = gr.Slider(
166
+ label="Image Guidance Scale",
167
+ minimum=0.0,
168
+ maximum=10.0,
169
+ step=0.1,
170
+ value=1.5,
171
+ )
172
+
173
+ # Output side
174
+ with gr.Column():
175
+ gr.Markdown("### Output Gallery")
176
+ result_gallery = gr.Gallery(
177
+ label="Output",
178
+ show_label=False,
179
+ elem_id="gallery",
180
+ columns=2,
181
+ )
182
+
183
+ run_button.click(
184
+ fn=generate,
185
+ inputs=[
186
+ albedo,
187
+ normal,
188
+ roughness,
189
+ metallic,
190
+ irradiance,
191
+ prompt,
192
+ seed,
193
+ inference_step,
194
+ num_samples,
195
+ guidance_scale,
196
+ image_guidance_scale,
197
+ ],
198
+ outputs=result_gallery,
199
+ queue=True,
200
+ )
201
+
202
+
203
+ if __name__ == "__main__":
204
+ demo.launch(debug=False, share=False, show_api=False)
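Since `generate` only reads uploaded files through their `.name` attribute, the demo can also be driven without the UI. A hedged sketch, assuming the script is importable (importing it downloads the model and requires a CUDA device) and is run from the x2rgb directory; the `upload` helper is illustrative, not part of the repository:

from types import SimpleNamespace

from gradio_demo_x2rgb import generate

def upload(path: str):
    # gr.File values are only used via `.name`, so a tiny stand-in object suffices
    return SimpleNamespace(name=path)

results = generate(
    albedo=upload("example/kitchen-albedo.png"),
    normal=upload("example/kitchen-normal.png"),
    roughness=upload("example/kitchen-roughness.png"),
    metallic=upload("example/kitchen-metallic.png"),
    irradiance=upload("example/kitchen-irradiance.png"),
    prompt="a photo of a modern kitchen",
    seed=42,
    inference_step=50,
    num_samples=1,
    guidance_scale=7.5,
    image_guidance_scale=1.5,
)
# the first entry is the generated sample; the remaining entries are the conditioning channels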
x2rgb/load_image.py ADDED
@@ -0,0 +1,119 @@
1
+ import os
2
+
3
+ import cv2
4
+ import torch
5
+
6
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
7
+ import numpy as np
8
+
9
+
10
+ def convert_rgb_2_XYZ(rgb):
11
+ # Reference: https://web.archive.org/web/20191027010220/http://www.brucelindbloom.com/index.html?Eqn_RGB_XYZ_Matrix.html
12
+ # rgb: (h, w, 3)
13
+ # XYZ: (h, w, 3)
14
+ XYZ = torch.ones_like(rgb)
15
+ XYZ[:, :, 0] = (
16
+ 0.4124564 * rgb[:, :, 0] + 0.3575761 * rgb[:, :, 1] + 0.1804375 * rgb[:, :, 2]
17
+ )
18
+ XYZ[:, :, 1] = (
19
+ 0.2126729 * rgb[:, :, 0] + 0.7151522 * rgb[:, :, 1] + 0.0721750 * rgb[:, :, 2]
20
+ )
21
+ XYZ[:, :, 2] = (
22
+ 0.0193339 * rgb[:, :, 0] + 0.1191920 * rgb[:, :, 1] + 0.9503041 * rgb[:, :, 2]
23
+ )
24
+ return XYZ
25
+
26
+
27
+ def convert_XYZ_2_Yxy(XYZ):
28
+ # XYZ: (h, w, 3)
29
+ # Yxy: (h, w, 3)
30
+ Yxy = torch.ones_like(XYZ)
31
+ Yxy[:, :, 0] = XYZ[:, :, 1]
32
+ sum = torch.sum(XYZ, dim=2)
33
+ inv_sum = 1.0 / torch.clamp(sum, min=1e-4)
34
+ Yxy[:, :, 1] = XYZ[:, :, 0] * inv_sum
35
+ Yxy[:, :, 2] = XYZ[:, :, 1] * inv_sum
36
+ return Yxy
37
+
38
+
39
+ def convert_rgb_2_Yxy(rgb):
40
+ # rgb: (h, w, 3)
41
+ # Yxy: (h, w, 3)
42
+ return convert_XYZ_2_Yxy(convert_rgb_2_XYZ(rgb))
43
+
44
+
45
+ def convert_XYZ_2_rgb(XYZ):
46
+ # XYZ: (h, w, 3)
47
+ # rgb: (h, w, 3)
48
+ rgb = torch.ones_like(XYZ)
49
+ rgb[:, :, 0] = (
50
+ 3.2404542 * XYZ[:, :, 0] - 1.5371385 * XYZ[:, :, 1] - 0.4985314 * XYZ[:, :, 2]
51
+ )
52
+ rgb[:, :, 1] = (
53
+ -0.9692660 * XYZ[:, :, 0] + 1.8760108 * XYZ[:, :, 1] + 0.0415560 * XYZ[:, :, 2]
54
+ )
55
+ rgb[:, :, 2] = (
56
+ 0.0556434 * XYZ[:, :, 0] - 0.2040259 * XYZ[:, :, 1] + 1.0572252 * XYZ[:, :, 2]
57
+ )
58
+ return rgb
59
+
60
+
61
+ def convert_Yxy_2_XYZ(Yxy):
62
+ # Yxy: (h, w, 3)
63
+ # XYZ: (h, w, 3)
64
+ XYZ = torch.ones_like(Yxy)
65
+ XYZ[:, :, 0] = Yxy[:, :, 1] / torch.clamp(Yxy[:, :, 2], min=1e-6) * Yxy[:, :, 0]
66
+ XYZ[:, :, 1] = Yxy[:, :, 0]
67
+ XYZ[:, :, 2] = (
68
+ (1.0 - Yxy[:, :, 1] - Yxy[:, :, 2])
69
+ / torch.clamp(Yxy[:, :, 2], min=1e-4)
70
+ * Yxy[:, :, 0]
71
+ )
72
+ return XYZ
73
+
74
+
75
+ def convert_Yxy_2_rgb(Yxy):
76
+ # Yxy: (h, w, 3)
77
+ # rgb: (h, w, 3)
78
+ return convert_XYZ_2_rgb(convert_Yxy_2_XYZ(Yxy))
79
+
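# The RGB <-> XYZ <-> Yxy helpers above form an approximate round trip; the clamps on the
# chromaticity denominators keep near-black pixels finite. A quick sanity check (illustrative
# only, not part of the committed file; assumes it is run from the x2rgb directory):

import torch
from load_image import convert_rgb_2_Yxy, convert_Yxy_2_rgb

rgb = torch.rand(4, 4, 3)  # (h, w, 3), linear RGB
back = convert_Yxy_2_rgb(convert_rgb_2_Yxy(rgb))
print(torch.allclose(rgb, back, atol=1e-3))  # True for non-degenerate pixels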
80
+
81
+ def load_ldr_image(image_path, from_srgb=False, clamp=False, normalize=False):
82
+ # Load png or jpg image
83
+ image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
84
+ image = torch.from_numpy(image.astype(np.float32) / 255.0) # (h, w, c)
85
+ image[~torch.isfinite(image)] = 0
86
+ if from_srgb:
87
+ # Convert from sRGB to linear RGB
88
+ image = image**2.2
89
+ if clamp:
90
+ image = torch.clamp(image, min=0.0, max=1.0)
91
+ if normalize:
92
+ # Normalize to [-1, 1]
93
+ image = image * 2.0 - 1.0
94
+ image = torch.nn.functional.normalize(image, dim=-1, eps=1e-6)
95
+ return image.permute(2, 0, 1) # returns (c, h, w)
96
+
97
+
98
+ def load_exr_image(image_path, tonemaping=False, clamp=False, normalize=False):
99
+ image = cv2.cvtColor(cv2.imread(image_path, -1), cv2.COLOR_BGR2RGB)
100
+ image = torch.from_numpy(image.astype("float32")) # (h, w, c)
101
+ image[~torch.isfinite(image)] = 0
102
+ if tonemaping:
103
+ # Exposure adjustment
104
+ image_Yxy = convert_rgb_2_Yxy(image)
105
+ lum = (
106
+ image[:, :, 0:1] * 0.2125
107
+ + image[:, :, 1:2] * 0.7154
108
+ + image[:, :, 2:3] * 0.0721
109
+ )
110
+ lum = torch.log(torch.clamp(lum, min=1e-6))
111
+ lum_mean = torch.exp(torch.mean(lum))
112
+ lp = image_Yxy[:, :, 0:1] * 0.18 / torch.clamp(lum_mean, min=1e-6)
113
+ image_Yxy[:, :, 0:1] = lp
114
+ image = convert_Yxy_2_rgb(image_Yxy)
115
+ if clamp:
116
+ image = torch.clamp(image, min=0.0, max=1.0)
117
+ if normalize:
118
+ image = torch.nn.functional.normalize(image, dim=-1, eps=1e-6)
119
+ return image.permute(2, 0, 1) # returns (c, h, w)
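Both loaders return CHW float tensors: LDR inputs can be linearized from sRGB, clamped to [0, 1], or remapped to unit normals, while EXR inputs can additionally be exposure-normalized with a log-average-luminance key of 0.18 before clamping. A hedged usage sketch (paths assume the x2rgb working directory; the EXR line is a placeholder since only PNG examples ship with the commit):

from load_image import load_exr_image, load_ldr_image

albedo = load_ldr_image("example/kitchen-albedo.png", clamp=True)      # (3, h, w), values in [0, 1]
normal = load_ldr_image("example/kitchen-normal.png", normalize=True)  # unit-length normals per pixel
# irradiance = load_exr_image("path/to/irradiance.exr", tonemaping=True, clamp=True)
print(albedo.shape, float(albedo.min()), float(albedo.max()))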
x2rgb/pipeline_x2rgb.py ADDED
@@ -0,0 +1,967 @@
1
+ import inspect
2
+ from dataclasses import dataclass
3
+ from typing import Callable, List, Optional, Union
4
+
5
+ import numpy as np
6
+ import PIL
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from diffusers.configuration_utils import register_to_config
10
+ from diffusers.image_processor import VaeImageProcessor
11
+ from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
12
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
13
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
14
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
15
+ rescale_noise_cfg,
16
+ )
17
+ from diffusers.schedulers import KarrasDiffusionSchedulers
18
+ from diffusers.utils import CONFIG_NAME, BaseOutput, deprecate, logging, randn_tensor
19
+ from transformers import CLIPTextModel, CLIPTokenizer
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ class VaeImageProcrssorAOV(VaeImageProcessor):
25
+ """
26
+ Image processor for VAE AOV.
27
+
28
+ Args:
29
+ do_resize (`bool`, *optional*, defaults to `True`):
30
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
31
+ vae_scale_factor (`int`, *optional*, defaults to `8`):
32
+ VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
33
+ resample (`str`, *optional*, defaults to `lanczos`):
34
+ Resampling filter to use when resizing the image.
35
+ do_normalize (`bool`, *optional*, defaults to `True`):
36
+ Whether to normalize the image to [-1,1].
37
+ """
38
+
39
+ config_name = CONFIG_NAME
40
+
41
+ @register_to_config
42
+ def __init__(
43
+ self,
44
+ do_resize: bool = True,
45
+ vae_scale_factor: int = 8,
46
+ resample: str = "lanczos",
47
+ do_normalize: bool = True,
48
+ ):
49
+ super().__init__()
50
+
51
+ def postprocess(
52
+ self,
53
+ image: torch.FloatTensor,
54
+ output_type: str = "pil",
55
+ do_denormalize: Optional[List[bool]] = None,
56
+ do_gamma_correction: bool = True,
57
+ ):
58
+ if not isinstance(image, torch.Tensor):
59
+ raise ValueError(
60
+ f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
61
+ )
62
+ if output_type not in ["latent", "pt", "np", "pil"]:
63
+ deprecation_message = (
64
+ f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
65
+ "`pil`, `np`, `pt`, `latent`"
66
+ )
67
+ deprecate(
68
+ "Unsupported output_type",
69
+ "1.0.0",
70
+ deprecation_message,
71
+ standard_warn=False,
72
+ )
73
+ output_type = "np"
74
+
75
+ if output_type == "latent":
76
+ return image
77
+
78
+ if do_denormalize is None:
79
+ do_denormalize = [self.config.do_normalize] * image.shape[0]
80
+
81
+ image = torch.stack(
82
+ [
83
+ self.denormalize(image[i]) if do_denormalize[i] else image[i]
84
+ for i in range(image.shape[0])
85
+ ]
86
+ )
87
+
88
+ # Gamma correction
89
+ if do_gamma_correction:
90
+ image = torch.pow(image, 1.0 / 2.2)
91
+
92
+ if output_type == "pt":
93
+ return image
94
+
95
+ image = self.pt_to_numpy(image)
96
+
97
+ if output_type == "np":
98
+ return image
99
+
100
+ if output_type == "pil":
101
+ return self.numpy_to_pil(image)
102
+
103
+ def preprocess_normal(
104
+ self,
105
+ image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
106
+ height: Optional[int] = None,
107
+ width: Optional[int] = None,
108
+ ) -> torch.Tensor:
109
+ image = torch.stack([image], axis=0)
110
+ return image
111
+
112
+
113
+ @dataclass
114
+ class StableDiffusionAOVPipelineOutput(BaseOutput):
115
+ """
116
+ Output class for Stable Diffusion AOV pipelines.
117
+
118
+ Args:
119
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
120
+ List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
121
+ num_channels)`.
122
+ predicted_x0_images (`List[PIL.Image.Image]` or `np.ndarray`, *optional*)
123
+ Intermediate predicted-x0 images decoded during denoising, populated only when the pipeline is
124
+ called with `return_predicted_x0s=True`; otherwise `None`.
125
+ """
126
+
127
+ images: Union[List[PIL.Image.Image], np.ndarray]
128
+ predicted_x0_images: Optional[Union[List[PIL.Image.Image], np.ndarray]] = None
129
+
130
+
131
+ class StableDiffusionAOVDropoutPipeline(
132
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin
133
+ ):
134
+ r"""
135
+ Pipeline for AOVs.
136
+
137
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
138
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
139
+
140
+ The pipeline also inherits the following loading methods:
141
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
142
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
143
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
144
+
145
+ Args:
146
+ vae ([`AutoencoderKL`]):
147
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
148
+ text_encoder ([`~transformers.CLIPTextModel`]):
149
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
150
+ tokenizer ([`~transformers.CLIPTokenizer`]):
151
+ A `CLIPTokenizer` to tokenize text.
152
+ unet ([`UNet2DConditionModel`]):
153
+ A `UNet2DConditionModel` to denoise the encoded image latents.
154
+ scheduler ([`SchedulerMixin`]):
155
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
156
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
157
+ """
158
+
159
+ def __init__(
160
+ self,
161
+ vae: AutoencoderKL,
162
+ text_encoder: CLIPTextModel,
163
+ tokenizer: CLIPTokenizer,
164
+ unet: UNet2DConditionModel,
165
+ scheduler: KarrasDiffusionSchedulers,
166
+ ):
167
+ super().__init__()
168
+
169
+ self.register_modules(
170
+ vae=vae,
171
+ text_encoder=text_encoder,
172
+ tokenizer=tokenizer,
173
+ unet=unet,
174
+ scheduler=scheduler,
175
+ )
176
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
177
+ self.image_processor = VaeImageProcrssorAOV(
178
+ vae_scale_factor=self.vae_scale_factor
179
+ )
180
+ self.register_to_config()
181
+
182
+ def _encode_prompt(
183
+ self,
184
+ prompt,
185
+ device,
186
+ num_images_per_prompt,
187
+ do_classifier_free_guidance,
188
+ negative_prompt=None,
189
+ prompt_embeds: Optional[torch.FloatTensor] = None,
190
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
191
+ ):
192
+ r"""
193
+ Encodes the prompt into text encoder hidden states.
194
+
195
+ Args:
196
+ prompt (`str` or `List[str]`, *optional*):
197
+ prompt to be encoded
198
+ device: (`torch.device`):
199
+ torch device
200
+ num_images_per_prompt (`int`):
201
+ number of images that should be generated per prompt
202
+ do_classifier_free_guidance (`bool`):
203
+ whether to use classifier free guidance or not
204
+ negative_prompt (`str` or `List[str]`, *optional*):
205
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
206
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
207
+ less than `1`).
208
+ prompt_embeds (`torch.FloatTensor`, *optional*):
209
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
210
+ provided, text embeddings will be generated from `prompt` input argument.
211
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
212
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
213
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
214
+ argument.
215
+ """
216
+ if prompt is not None and isinstance(prompt, str):
217
+ batch_size = 1
218
+ elif prompt is not None and isinstance(prompt, list):
219
+ batch_size = len(prompt)
220
+ else:
221
+ batch_size = prompt_embeds.shape[0]
222
+
223
+ if prompt_embeds is None:
224
+ # textual inversion: process multi-vector tokens if necessary
225
+ if isinstance(self, TextualInversionLoaderMixin):
226
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
227
+
228
+ text_inputs = self.tokenizer(
229
+ prompt,
230
+ padding="max_length",
231
+ max_length=self.tokenizer.model_max_length,
232
+ truncation=True,
233
+ return_tensors="pt",
234
+ )
235
+ text_input_ids = text_inputs.input_ids
236
+ untruncated_ids = self.tokenizer(
237
+ prompt, padding="longest", return_tensors="pt"
238
+ ).input_ids
239
+
240
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[
241
+ -1
242
+ ] and not torch.equal(text_input_ids, untruncated_ids):
243
+ removed_text = self.tokenizer.batch_decode(
244
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
245
+ )
246
+ logger.warning(
247
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
248
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
249
+ )
250
+
251
+ if (
252
+ hasattr(self.text_encoder.config, "use_attention_mask")
253
+ and self.text_encoder.config.use_attention_mask
254
+ ):
255
+ attention_mask = text_inputs.attention_mask.to(device)
256
+ else:
257
+ attention_mask = None
258
+
259
+ prompt_embeds = self.text_encoder(
260
+ text_input_ids.to(device),
261
+ attention_mask=attention_mask,
262
+ )
263
+ prompt_embeds = prompt_embeds[0]
264
+
265
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
266
+
267
+ bs_embed, seq_len, _ = prompt_embeds.shape
268
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
269
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
270
+ prompt_embeds = prompt_embeds.view(
271
+ bs_embed * num_images_per_prompt, seq_len, -1
272
+ )
273
+
274
+ # get unconditional embeddings for classifier free guidance
275
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
276
+ uncond_tokens: List[str]
277
+ if negative_prompt is None:
278
+ uncond_tokens = [""] * batch_size
279
+ elif type(prompt) is not type(negative_prompt):
280
+ raise TypeError(
281
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
282
+ f" {type(prompt)}."
283
+ )
284
+ elif isinstance(negative_prompt, str):
285
+ uncond_tokens = [negative_prompt]
286
+ elif batch_size != len(negative_prompt):
287
+ raise ValueError(
288
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
289
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
290
+ " the batch size of `prompt`."
291
+ )
292
+ else:
293
+ uncond_tokens = negative_prompt
294
+
295
+ # textual inversion: process multi-vector tokens if necessary
296
+ if isinstance(self, TextualInversionLoaderMixin):
297
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
298
+
299
+ max_length = prompt_embeds.shape[1]
300
+ uncond_input = self.tokenizer(
301
+ uncond_tokens,
302
+ padding="max_length",
303
+ max_length=max_length,
304
+ truncation=True,
305
+ return_tensors="pt",
306
+ )
307
+
308
+ if (
309
+ hasattr(self.text_encoder.config, "use_attention_mask")
310
+ and self.text_encoder.config.use_attention_mask
311
+ ):
312
+ attention_mask = uncond_input.attention_mask.to(device)
313
+ else:
314
+ attention_mask = None
315
+
316
+ negative_prompt_embeds = self.text_encoder(
317
+ uncond_input.input_ids.to(device),
318
+ attention_mask=attention_mask,
319
+ )
320
+ negative_prompt_embeds = negative_prompt_embeds[0]
321
+
322
+ if do_classifier_free_guidance:
323
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
324
+ seq_len = negative_prompt_embeds.shape[1]
325
+
326
+ negative_prompt_embeds = negative_prompt_embeds.to(
327
+ dtype=self.text_encoder.dtype, device=device
328
+ )
329
+
330
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
331
+ 1, num_images_per_prompt, 1
332
+ )
333
+ negative_prompt_embeds = negative_prompt_embeds.view(
334
+ batch_size * num_images_per_prompt, seq_len, -1
335
+ )
336
+
337
+ # For classifier free guidance, we need to do two forward passes.
338
+ # Here we concatenate the unconditional and text embeddings into a single batch
339
+ # to avoid doing two forward passes
340
+ # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]
341
+ prompt_embeds = torch.cat(
342
+ [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]
343
+ )
344
+
345
+ return prompt_embeds
346
+
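# For classifier-free guidance with an image condition, _encode_prompt returns a single batch with
# three copies of the text embeddings ordered [text, negative, negative], so one UNet forward pass
# covers every guidance branch. A shape-only sketch of that layout (illustrative; the zeros merely
# stand in for the "" negative-prompt embedding):

import torch

batch_size, seq_len, dim = 1, 77, 768  # typical CLIP text-encoder shapes
text = torch.randn(batch_size, seq_len, dim)
negative = torch.zeros(batch_size, seq_len, dim)
prompt_embeds = torch.cat([text, negative, negative])
print(prompt_embeds.shape)  # torch.Size([3, 77, 768])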
347
+ def prepare_extra_step_kwargs(self, generator, eta):
348
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
349
+ # eta (Ξ·) is only used with the DDIMScheduler, it will be ignored for other schedulers.
350
+ # eta corresponds to Ξ· in DDIM paper: https://arxiv.org/abs/2010.02502
351
+ # and should be between [0, 1]
352
+
353
+ accepts_eta = "eta" in set(
354
+ inspect.signature(self.scheduler.step).parameters.keys()
355
+ )
356
+ extra_step_kwargs = {}
357
+ if accepts_eta:
358
+ extra_step_kwargs["eta"] = eta
359
+
360
+ # check if the scheduler accepts generator
361
+ accepts_generator = "generator" in set(
362
+ inspect.signature(self.scheduler.step).parameters.keys()
363
+ )
364
+ if accepts_generator:
365
+ extra_step_kwargs["generator"] = generator
366
+ return extra_step_kwargs
367
+
368
+ def check_inputs(
369
+ self,
370
+ prompt,
371
+ callback_steps,
372
+ negative_prompt=None,
373
+ prompt_embeds=None,
374
+ negative_prompt_embeds=None,
375
+ ):
376
+ if (callback_steps is None) or (
377
+ callback_steps is not None
378
+ and (not isinstance(callback_steps, int) or callback_steps <= 0)
379
+ ):
380
+ raise ValueError(
381
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
382
+ f" {type(callback_steps)}."
383
+ )
384
+
385
+ if prompt is not None and prompt_embeds is not None:
386
+ raise ValueError(
387
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
388
+ " only forward one of the two."
389
+ )
390
+ elif prompt is None and prompt_embeds is None:
391
+ raise ValueError(
392
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
393
+ )
394
+ elif prompt is not None and (
395
+ not isinstance(prompt, str) and not isinstance(prompt, list)
396
+ ):
397
+ raise ValueError(
398
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
399
+ )
400
+
401
+ if negative_prompt is not None and negative_prompt_embeds is not None:
402
+ raise ValueError(
403
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
404
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
405
+ )
406
+
407
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
408
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
409
+ raise ValueError(
410
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
411
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
412
+ f" {negative_prompt_embeds.shape}."
413
+ )
414
+
415
+ def prepare_latents(
416
+ self,
417
+ batch_size,
418
+ num_channels_latents,
419
+ height,
420
+ width,
421
+ dtype,
422
+ device,
423
+ generator,
424
+ latents=None,
425
+ ):
426
+ shape = (
427
+ batch_size,
428
+ num_channels_latents,
429
+ height // self.vae_scale_factor,
430
+ width // self.vae_scale_factor,
431
+ )
432
+ if isinstance(generator, list) and len(generator) != batch_size:
433
+ raise ValueError(
434
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
435
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
436
+ )
437
+
438
+ if latents is None:
439
+ latents = randn_tensor(
440
+ shape, generator=generator, device=device, dtype=dtype
441
+ )
442
+ else:
443
+ latents = latents.to(device)
444
+
445
+ # scale the initial noise by the standard deviation required by the scheduler
446
+ latents = latents * self.scheduler.init_noise_sigma
447
+ return latents
448
+
449
+ def prepare_image_latents(
450
+ self,
451
+ image,
452
+ batch_size,
453
+ num_images_per_prompt,
454
+ dtype,
455
+ device,
456
+ do_classifier_free_guidance,
457
+ generator=None,
458
+ ):
459
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
460
+ raise ValueError(
461
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
462
+ )
463
+
464
+ image = image.to(device=device, dtype=dtype)
465
+
466
+ batch_size = batch_size * num_images_per_prompt
467
+
468
+ if image.shape[1] == 4:
469
+ image_latents = image
470
+ else:
471
+ if isinstance(generator, list) and len(generator) != batch_size:
472
+ raise ValueError(
473
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
474
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
475
+ )
476
+
477
+ if isinstance(generator, list):
478
+ image_latents = [
479
+ self.vae.encode(image[i : i + 1]).latent_dist.mode()
480
+ for i in range(batch_size)
481
+ ]
482
+ image_latents = torch.cat(image_latents, dim=0)
483
+ else:
484
+ image_latents = self.vae.encode(image).latent_dist.mode()
485
+
486
+ if (
487
+ batch_size > image_latents.shape[0]
488
+ and batch_size % image_latents.shape[0] == 0
489
+ ):
490
+ # expand image_latents for batch_size
491
+ deprecation_message = (
492
+ f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
493
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
494
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
495
+ " your script to pass as many initial images as text prompts to suppress this warning."
496
+ )
497
+ deprecate(
498
+ "len(prompt) != len(image)",
499
+ "1.0.0",
500
+ deprecation_message,
501
+ standard_warn=False,
502
+ )
503
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
504
+ image_latents = torch.cat(
505
+ [image_latents] * additional_image_per_prompt, dim=0
506
+ )
507
+ elif (
508
+ batch_size > image_latents.shape[0]
509
+ and batch_size % image_latents.shape[0] != 0
510
+ ):
511
+ raise ValueError(
512
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
513
+ )
514
+ else:
515
+ image_latents = torch.cat([image_latents], dim=0)
516
+
517
+ if do_classifier_free_guidance:
518
+ uncond_image_latents = torch.zeros_like(image_latents)
519
+ image_latents = torch.cat(
520
+ [image_latents, image_latents, uncond_image_latents], dim=0
521
+ )
522
+
523
+ return image_latents
524
+
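# The image-conditioning latents use the complementary ordering [image, image, zeros], so the three
# chunks of the UNet output correspond to (text + image), (image only), and (unconditional)
# predictions. The denoising loop below combines them as
#   uncond + guidance_scale * (text - image) + image_guidance_scale * (image - uncond).
# A small numerical sketch of that combination (illustrative only):

import torch

noise_pred_text, noise_pred_image, noise_pred_uncond = torch.randn(3, 1, 4, 8, 8).unbind(0)
guidance_scale, image_guidance_scale = 7.5, 1.5
noise_pred = (
    noise_pred_uncond
    + guidance_scale * (noise_pred_text - noise_pred_image)
    + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
)
print(noise_pred.shape)  # torch.Size([1, 4, 8, 8])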
525
+ @torch.no_grad()
526
+ def __call__(
527
+ self,
528
+ height: int,
529
+ width: int,
530
+ prompt: Union[str, List[str]] = None,
531
+ albedo: Optional[
532
+ Union[
533
+ torch.FloatTensor,
534
+ PIL.Image.Image,
535
+ np.ndarray,
536
+ List[torch.FloatTensor],
537
+ List[PIL.Image.Image],
538
+ List[np.ndarray],
539
+ ]
540
+ ] = None,
541
+ normal: Optional[
542
+ Union[
543
+ torch.FloatTensor,
544
+ PIL.Image.Image,
545
+ np.ndarray,
546
+ List[torch.FloatTensor],
547
+ List[PIL.Image.Image],
548
+ List[np.ndarray],
549
+ ]
550
+ ] = None,
551
+ roughness: Optional[
552
+ Union[
553
+ torch.FloatTensor,
554
+ PIL.Image.Image,
555
+ np.ndarray,
556
+ List[torch.FloatTensor],
557
+ List[PIL.Image.Image],
558
+ List[np.ndarray],
559
+ ]
560
+ ] = None,
561
+ metallic: Optional[
562
+ Union[
563
+ torch.FloatTensor,
564
+ PIL.Image.Image,
565
+ np.ndarray,
566
+ List[torch.FloatTensor],
567
+ List[PIL.Image.Image],
568
+ List[np.ndarray],
569
+ ]
570
+ ] = None,
571
+ irradiance: Optional[
572
+ Union[
573
+ torch.FloatTensor,
574
+ PIL.Image.Image,
575
+ np.ndarray,
576
+ List[torch.FloatTensor],
577
+ List[PIL.Image.Image],
578
+ List[np.ndarray],
579
+ ]
580
+ ] = None,
581
+ guidance_scale: float = 0.0,
582
+ image_guidance_scale: float = 0.0,
583
+ guidance_rescale: float = 0.0,
584
+ num_inference_steps: int = 100,
585
+ required_aovs: List[str] = ["albedo"],
586
+ return_predicted_x0s: bool = False,
587
+ negative_prompt: Optional[Union[str, List[str]]] = None,
588
+ num_images_per_prompt: Optional[int] = 1,
589
+ eta: float = 0.0,
590
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
591
+ latents: Optional[torch.FloatTensor] = None,
592
+ prompt_embeds: Optional[torch.FloatTensor] = None,
593
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
594
+ output_type: Optional[str] = "pil",
595
+ return_dict: bool = True,
596
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
597
+ callback_steps: int = 1,
598
+ ):
599
+ r"""
600
+ The call function to the pipeline for generation.
601
+
602
+ Args:
603
+ prompt (`str` or `List[str]`, *optional*):
604
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
605
+ image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
606
+ `Image` or tensor representing an image batch to be repainted according to `prompt`. Can also accept
607
+ image latents as `image`, but if passing latents directly it is not encoded again.
608
+ num_inference_steps (`int`, *optional*, defaults to 100):
609
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
610
+ expense of slower inference.
611
+ guidance_scale (`float`, *optional*, defaults to 0.0):
612
+ A higher guidance scale value encourages the model to generate images closely linked to the text
613
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
614
+ image_guidance_scale (`float`, *optional*, defaults to 0.0):
615
+ Push the generated image towards the initial `image`. Image guidance scale is enabled by setting
616
+ `image_guidance_scale > 1`. Higher image guidance scale encourages generated images that are closely
617
+ linked to the source `image`, usually at the expense of lower image quality. This pipeline requires a
618
+ value of at least `1`.
619
+ negative_prompt (`str` or `List[str]`, *optional*):
620
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
621
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
622
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
623
+ The number of images to generate per prompt.
624
+ eta (`float`, *optional*, defaults to 0.0):
625
+ Corresponds to parameter eta (Ξ·) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
626
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
627
+ generator (`torch.Generator`, *optional*):
628
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
629
+ generation deterministic.
630
+ latents (`torch.FloatTensor`, *optional*):
631
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
632
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
633
+ tensor is generated by sampling using the supplied random `generator`.
634
+ prompt_embeds (`torch.FloatTensor`, *optional*):
635
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
636
+ provided, text embeddings are generated from the `prompt` input argument.
637
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
638
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
639
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
640
+ output_type (`str`, *optional*, defaults to `"pil"`):
641
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
642
+ return_dict (`bool`, *optional*, defaults to `True`):
643
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
644
+ plain tuple.
645
+ callback (`Callable`, *optional*):
646
+ A function that calls every `callback_steps` steps during inference. The function is called with the
647
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
648
+ callback_steps (`int`, *optional*, defaults to 1):
649
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
650
+ every step.
651
+
652
+ Examples:
653
+
654
+ ```py
655
+ >>> import PIL
656
+ >>> import requests
657
+ >>> import torch
658
+ >>> from io import BytesIO
659
+
660
+ >>> from diffusers import StableDiffusionInstructPix2PixPipeline
661
+
662
+
663
+ >>> def download_image(url):
664
+ ... response = requests.get(url)
665
+ ... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
666
+
667
+
668
+ >>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png"
669
+
670
+ >>> image = download_image(img_url).resize((512, 512))
671
+
672
+ >>> pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
673
+ ... "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
674
+ ... )
675
+ >>> pipe = pipe.to("cuda")
676
+
677
+ >>> prompt = "make the mountains snowy"
678
+ >>> image = pipe(prompt=prompt, image=image).images[0]
679
+ ```
680
+
681
+ Returns:
682
+ [`StableDiffusionAOVPipelineOutput`] or `np.ndarray`/`List[PIL.Image.Image]`:
683
+ If `return_dict` is `True`, a [`StableDiffusionAOVPipelineOutput`] is returned with the generated
684
+ images and, when `return_predicted_x0s=True`, the intermediate predicted-x0 images decoded during
685
+ sampling. If `return_dict` is `False`, only the generated images are returned, without a
686
+ wrapping output class.
687
+ """
688
+ # 0. Check inputs
689
+ self.check_inputs(
690
+ prompt,
691
+ callback_steps,
692
+ negative_prompt,
693
+ prompt_embeds,
694
+ negative_prompt_embeds,
695
+ )
696
+
697
+ # 1. Define call parameters
698
+ if prompt is not None and isinstance(prompt, str):
699
+ batch_size = 1
700
+ elif prompt is not None and isinstance(prompt, list):
701
+ batch_size = len(prompt)
702
+ else:
703
+ batch_size = prompt_embeds.shape[0]
704
+
705
+ device = self._execution_device
706
+ do_classifier_free_guidance = (
707
+ guidance_scale >= 1.0 and image_guidance_scale >= 1.0
708
+ )
709
+ # check if scheduler is in sigmas space
710
+ scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
711
+
712
+ # 2. Encode input prompt
713
+ prompt_embeds = self._encode_prompt(
714
+ prompt,
715
+ device,
716
+ num_images_per_prompt,
717
+ do_classifier_free_guidance,
718
+ negative_prompt,
719
+ prompt_embeds=prompt_embeds,
720
+ negative_prompt_embeds=negative_prompt_embeds,
721
+ )
722
+
723
+ # 3. Preprocess image
724
+ # For normal, the preprocessing does nothing
725
+ # For others, the preprocessing remaps the values to [-1, 1]
726
+ preprocessed_aovs = {}
727
+ for aov_name in required_aovs:
728
+ if aov_name == "albedo":
729
+ if albedo is not None:
730
+ preprocessed_aovs[aov_name] = self.image_processor.preprocess(
731
+ albedo
732
+ )
733
+ else:
734
+ preprocessed_aovs[aov_name] = None
735
+
736
+ if aov_name == "normal":
737
+ if normal is not None:
738
+ preprocessed_aovs[aov_name] = (
739
+ self.image_processor.preprocess_normal(normal)
740
+ )
741
+ else:
742
+ preprocessed_aovs[aov_name] = None
743
+
744
+ if aov_name == "roughness":
745
+ if roughness is not None:
746
+ preprocessed_aovs[aov_name] = self.image_processor.preprocess(
747
+ roughness
748
+ )
749
+ else:
750
+ preprocessed_aovs[aov_name] = None
751
+ if aov_name == "metallic":
752
+ if metallic is not None:
753
+ preprocessed_aovs[aov_name] = self.image_processor.preprocess(
754
+ metallic
755
+ )
756
+ else:
757
+ preprocessed_aovs[aov_name] = None
758
+ if aov_name == "irradiance":
759
+ if irradiance is not None:
760
+ preprocessed_aovs[aov_name] = self.image_processor.preprocess(
761
+ irradiance
762
+ )
763
+ else:
764
+ preprocessed_aovs[aov_name] = None
765
+
766
+ # 4. set timesteps
767
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
768
+ timesteps = self.scheduler.timesteps
769
+
770
+ # 5. Prepare latent variables
771
+ num_channels_latents = self.vae.config.latent_channels
772
+ latents = self.prepare_latents(
773
+ batch_size * num_images_per_prompt,
774
+ num_channels_latents,
775
+ height,
776
+ width,
777
+ prompt_embeds.dtype,
778
+ device,
779
+ generator,
780
+ latents,
781
+ )
782
+
783
+ height_latent, width_latent = latents.shape[-2:]
784
+
785
+ # 6. Prepare Image latents
786
+ image_latents = []
787
+ # Magic scaling factors for each AOV (calculated from the training data)
788
+ scaling_factors = {
789
+ "albedo": 0.17301377137652138,
790
+ "normal": 0.17483895473058078,
791
+ "roughness": 0.1680724853626448,
792
+ "metallic": 0.13135013390855135,
793
+ }
794
+ for aov_name, aov in preprocessed_aovs.items():
795
+ if aov is None:
796
+ image_latent = torch.zeros(
797
+ batch_size,
798
+ num_channels_latents,
799
+ height_latent,
800
+ width_latent,
801
+ dtype=prompt_embeds.dtype,
802
+ device=device,
803
+ )
804
+ if aov_name == "irradiance":
805
+ image_latent = image_latent[:, 0:3]
806
+ if do_classifier_free_guidance:
807
+ image_latents.append(
808
+ torch.cat([image_latent, image_latent, image_latent], dim=0)
809
+ )
810
+ else:
811
+ image_latents.append(image_latent)
812
+ else:
813
+ if aov_name == "irradiance":
814
+ image_latent = F.interpolate(
815
+ aov.to(device=device, dtype=prompt_embeds.dtype),
816
+ size=(height_latent, width_latent),
817
+ mode="bilinear",
818
+ align_corners=False,
819
+ antialias=True,
820
+ )
821
+ if do_classifier_free_guidance:
822
+ uncond_image_latent = torch.zeros_like(image_latent)
823
+ image_latent = torch.cat(
824
+ [image_latent, image_latent, uncond_image_latent], dim=0
825
+ )
826
+ else:
827
+ scaling_factor = scaling_factors[aov_name]
828
+ image_latent = (
829
+ self.prepare_image_latents(
830
+ aov,
831
+ batch_size,
832
+ num_images_per_prompt,
833
+ prompt_embeds.dtype,
834
+ device,
835
+ do_classifier_free_guidance,
836
+ generator,
837
+ )
838
+ * scaling_factor
839
+ )
840
+ image_latents.append(image_latent)
841
+ image_latents = torch.cat(image_latents, dim=1)
842
+
843
+ # 7. Check that shapes of latents and image match the UNet channels
844
+ num_channels_image = image_latents.shape[1]
845
+ if num_channels_latents + num_channels_image != self.unet.config.in_channels:
846
+ raise ValueError(
847
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
848
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
849
+ f" `num_channels_image`: {num_channels_image} "
850
+ f" = {num_channels_latents+num_channels_image}. Please verify the config of"
851
+ " `pipeline.unet` or your `image` input."
852
+ )
853
+
854
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
855
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
856
+
857
+ predicted_x0s = []
858
+
859
+ # 9. Denoising loop
860
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
861
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
862
+ for i, t in enumerate(timesteps):
863
+ # Expand the latents if we are doing classifier free guidance.
864
+ # The latents are expanded 3 times because for pix2pix the guidance
865
+ # is applied for both the text and the input image.
866
+ latent_model_input = (
867
+ torch.cat([latents] * 3) if do_classifier_free_guidance else latents
868
+ )
869
+
870
+ # concat latents, image_latents in the channel dimension
871
+ scaled_latent_model_input = self.scheduler.scale_model_input(
872
+ latent_model_input, t
873
+ )
874
+ scaled_latent_model_input = torch.cat(
875
+ [scaled_latent_model_input, image_latents], dim=1
876
+ )
877
+
878
+ # predict the noise residual
879
+ noise_pred = self.unet(
880
+ scaled_latent_model_input,
881
+ t,
882
+ encoder_hidden_states=prompt_embeds,
883
+ return_dict=False,
884
+ )[0]
885
+
886
+ # perform guidance
887
+ if do_classifier_free_guidance:
888
+ (
889
+ noise_pred_text,
890
+ noise_pred_image,
891
+ noise_pred_uncond,
892
+ ) = noise_pred.chunk(3)
893
+ noise_pred = (
894
+ noise_pred_uncond
895
+ + guidance_scale * (noise_pred_text - noise_pred_image)
896
+ + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
897
+ )
898
+
899
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
900
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
901
+ noise_pred = rescale_noise_cfg(
902
+ noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
903
+ )
904
+
905
+ # compute the previous noisy sample x_t -> x_t-1
906
+ output = self.scheduler.step(
907
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=True
908
+ )
909
+
910
+ latents = output[0]
911
+
912
+ if return_predicted_x0s:
913
+ predicted_x0s.append(output[1])
914
+
915
+ # call the callback, if provided
916
+ if i == len(timesteps) - 1 or (
917
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
918
+ ):
919
+ progress_bar.update()
920
+ if callback is not None and i % callback_steps == 0:
921
+ callback(i, t, latents)
922
+
923
+ if not output_type == "latent":
924
+ image = self.vae.decode(
925
+ latents / self.vae.config.scaling_factor, return_dict=False
926
+ )[0]
927
+
928
+ if return_predicted_x0s:
929
+ predicted_x0_images = [
930
+ self.vae.decode(
931
+ predicted_x0 / self.vae.config.scaling_factor, return_dict=False
932
+ )[0]
933
+ for predicted_x0 in predicted_x0s
934
+ ]
935
+ else:
936
+ image = latents
937
+ predicted_x0_images = predicted_x0s
938
+
939
+ do_denormalize = [True] * image.shape[0]
940
+
941
+ image = self.image_processor.postprocess(
942
+ image, output_type=output_type, do_denormalize=do_denormalize
943
+ )
944
+
945
+ if return_predicted_x0s:
946
+ predicted_x0_images = [
947
+ self.image_processor.postprocess(
948
+ predicted_x0_image,
949
+ output_type=output_type,
950
+ do_denormalize=do_denormalize,
951
+ )
952
+ for predicted_x0_image in predicted_x0_images
953
+ ]
954
+
955
+ # Offload last model to CPU
956
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
957
+ self.final_offload_hook.offload()
958
+
959
+ if not return_dict:
960
+ return image
961
+
962
+ if return_predicted_x0s:
963
+ return StableDiffusionAOVPipelineOutput(
964
+ images=image, predicted_x0_images=predicted_x0_images
965
+ )
966
+ else:
967
+ return StableDiffusionAOVPipelineOutput(images=image)
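Putting the pieces together, the Gradio demo above is a thin wrapper around the following call pattern. A hedged end-to-end sketch: the model id and scheduler settings mirror the demo, the paths assume the x2rgb working directory, the inputs are assumed to be sized to a multiple of 8, and omitted channels (roughness, metallic, irradiance here) are simply dropped out by the pipeline:

import torch
from diffusers import DDIMScheduler

from load_image import load_ldr_image
from pipeline_x2rgb import StableDiffusionAOVDropoutPipeline

pipe = StableDiffusionAOVDropoutPipeline.from_pretrained(
    "zheng95z/x-to-rgb", torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
)

albedo = load_ldr_image("example/kitchen-albedo.png", clamp=True).to("cuda")
normal = load_ldr_image("example/kitchen-normal.png", normalize=True).to("cuda")

result = pipe(
    prompt="a photo of a kitchen",
    albedo=albedo,
    normal=normal,
    height=albedo.shape[1],
    width=albedo.shape[2],
    num_inference_steps=50,
    required_aovs=["albedo", "normal", "roughness", "metallic", "irradiance"],
    guidance_scale=7.5,
    image_guidance_scale=1.5,
    guidance_rescale=0.7,
    generator=torch.Generator(device="cuda").manual_seed(42),
    output_type="pil",
)
result.images[0].save("kitchen-generated.png")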