Commit · a9af355 · 0 parent(s)

Initial commit

Files changed:
- .gitattributes +38 -0
- .gitignore +5 -0
- LICENSE +26 -0
- README.md +78 -0
- assets/rgbx24_teaser.png +3 -0
- environment.yml +0 -0
- requirements.txt +10 -0
- rgb2x/example/Castlereagh_corridor_photo.png +3 -0
- rgb2x/gradio_demo_rgb2x.py +154 -0
- rgb2x/load_image.py +119 -0
- rgb2x/pipeline_rgb2x.py +821 -0
- x2rgb/example/kitchen-albedo.png +3 -0
- x2rgb/example/kitchen-irradiance.png +3 -0
- x2rgb/example/kitchen-metallic.png +3 -0
- x2rgb/example/kitchen-normal.png +3 -0
- x2rgb/example/kitchen-ref.png +3 -0
- x2rgb/example/kitchen-roughness.png +3 -0
- x2rgb/gradio_demo_x2rgb.py +204 -0
- x2rgb/load_image.py +119 -0
- x2rgb/pipeline_x2rgb.py +967 -0
.gitattributes
ADDED
@@ -0,0 +1,38 @@
# Auto detect text files and perform LF normalization
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
* text=auto
*.png filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,5 @@
__pycache__/

.venv/
rgb2x/model_cache
x2rgb/model_cache
LICENSE
ADDED
@@ -0,0 +1,26 @@
# ADOBE RESEARCH LICENSE

This license agreement (the "License") between Adobe Inc., having a place of business at 345 Park Avenue, San Jose, California 95110-2704 ("Adobe"), and you, the individual or entity exercising rights under this License ("you" or "your"), sets forth the terms for your use of certain research materials that are owned by Adobe (the "Licensed Materials"). By exercising rights under this License, you accept and agree to be bound by its terms. If you are exercising rights under this license on behalf of an entity, then "you" means you and such entity, and you (personally) represent and warrant that you (personally) have all necessary authority to bind that entity to the terms of this License.

1. **GRANT OF LICENSE.**

1.1. Adobe grants you a nonexclusive, worldwide, royalty-free, revocable, fully paid license to (A) reproduce, use, modify, and publicly display the Licensed Materials for noncommercial research purposes only; and (B) redistribute the Licensed Materials, and modifications or derivative works thereof, for noncommercial research purposes only, provided that you give recipients a copy of this License.

1.2. You may add your own copyright statement to your modifications and may provide additional or different license terms for use, reproduction, modification, public display, and redistribution of your modifications and derivative works, provided that such license terms limit the use, reproduction, modification, public display, and redistribution of such modifications and derivative works to noncommercial research purposes only.

1.3. For purposes of this License, noncommercial research purposes include academic research and teaching but do not include commercial licensing or distribution, development of commercial products, or any other activity which results in commercial gain.


2. **OWNERSHIP AND ATTRIBUTION.** Adobe and its licensors own all right, title, and interest in the Licensed Materials. You must keep intact any copyright or other notices or disclaimers in the Licensed Materials.

3. **DISCLAIMER OF WARRANTIES.** THE LICENSED MATERIALS ARE PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND. THE ENTIRE RISK AS TO THE RESULTS AND PERFORMANCE OF THE LICENSED MATERIALS IS ASSUMED BY YOU. ADOBE DISCLAIMS ALL WARRANTIES, EXPRESS, IMPLIED OR STATUTORY, WITH REGARD TO ANY LICENSED MATERIALS PROVIDED UNDER THIS LICENSE, INCLUDING, BUT NOT LIMITED TO, ANY IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NONINFRINGEMENT OF THIRD-PARTY RIGHTS.

4. **LIMITATION OF LIABILITY.** IN NO EVENT WILL ADOBE BE LIABLE FOR ANY ACTUAL, INCIDENTAL, SPECIAL OR CONSEQUENTIAL DAMAGES OF ANY NATURE WHATSOEVER, INCLUDING WITHOUT LIMITATION, LOSS OF PROFITS OR OTHER COMMERCIAL LOSS, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF ANY LICENSED MATERIALS PROVIDED UNDER THIS LICENSE, EVEN IF ADOBE HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.

5. **TERM AND TERMINATION.**

5.1. The License is effective upon acceptance by you and will remain in effect unless terminated earlier as permitted under this License.

5.2. If you breach any material provision of this License, then your rights will terminate immediately.

5.3. All clauses which by their nature should survive the termination of this License will survive such termination. In addition, and without limiting the generality of the preceding sentence, Sections 2 (Ownership and Attribution), 3 (Disclaimer of Warranties), and 4 (Limitation of Liability) will survive termination of this License.
README.md
ADDED
@@ -0,0 +1,78 @@
---
title: Rgbx
emoji: π
colorFrom: gray
colorTo: purple
sdk: gradio
sdk_version: 5.5.0
app_file: app.py
pinned: false
---
<h1 align="center"> RGB↔X: Image Decomposition and Synthesis Using Material- and Lighting-aware Diffusion Models </h1>

<p align="center"><a href="https://zheng95z.github.io/" target="_blank">Zheng Zeng</a>, <a href="https://valentin.deschaintre.fr/" target="_blank">Valentin Deschaintre</a>, <a href="https://www.iliyan.com/" target="_blank">Iliyan Georgiev</a>, <a href="https://yannickhold.com/" target="_blank">Yannick Hold-Geoffroy</a>, <a href="https://yiweihu.netlify.app/" target="_blank">Yiwei Hu</a>, <a href="https://luanfujun.com/" target="_blank">Fujun Luan</a>, <a href="https://sites.cs.ucsb.edu/~lingqi/" target="_blank">Ling-Qi Yan</a>, <a href="http://www.miloshasan.net/" target="_blank">Miloš Hašan</a></p>

<p align="center">ACM SIGGRAPH 2024</p>

<p align="center"><img src="assets/rgbx24_teaser.png"></p>

The three areas of realistic forward rendering, per-pixel inverse rendering, and generative image synthesis may seem like separate and unrelated sub-fields of graphics and vision. However, recent work has demonstrated improved estimation of per-pixel intrinsic channels (albedo, roughness, metallicity) based on a diffusion architecture; we call this the RGB→X problem. We further show that the reverse problem of synthesizing realistic images given intrinsic channels, X→RGB, can also be addressed in a diffusion framework.

Focusing on the image domain of interior scenes, we introduce an improved diffusion model for RGB→X, which also estimates lighting, as well as the first diffusion X→RGB model capable of synthesizing realistic images from (full or partial) intrinsic channels. Our X→RGB model explores a middle ground between traditional rendering and generative models: we can specify only certain appearance properties that should be followed, and give freedom to the model to hallucinate a plausible version of the rest.

This flexibility makes it possible to use a mix of heterogeneous training datasets, which differ in the available channels. We use multiple existing datasets and extend them with our own synthetic and real data, resulting in a model capable of extracting scene properties better than previous work and of generating highly realistic images of interior scenes.

## Structure
```
├── assets            <- Assets used by the README.md
├── rgb2x             <- Code for the RGB→X model
│   ├── example       <- Example photo
│   └── model_cache   <- Model weights (automatically downloaded when running the inference script)
├── x2rgb             <- Code for the X→RGB model
│   ├── example       <- Example photo
│   └── model_cache   <- Model weights (automatically downloaded when running the inference script)
├── environment.yaml  <- Env file for creating conda environment
├── LICENSE
└── README.md
```

## Model Weights
You don't need to manually download the model weights. The weights will be downloaded automatically to `/rgb2x/model_cache/` and `/x2rgb/model_cache/` when you run the inference scripts.

You can manually acquire the weights by cloning the models from Hugging Face:
```bash
git-lfs install
git clone https://huggingface.co/zheng95z/x-to-rgb
git clone https://huggingface.co/zheng95z/rgb-to-x
```
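
For reference, the automatic download is triggered the first time the pipeline is constructed. The following is a minimal sketch mirroring what `rgb2x/gradio_demo_rgb2x.py` in this commit does (run it from inside `rgb2x/` so that `pipeline_rgb2x` is importable):

```python
import torch
from pipeline_rgb2x import StableDiffusionAOVMatEstPipeline

# The first call downloads the weights from the Hugging Face Hub into model_cache/;
# subsequent calls reuse the cached copy.
pipe = StableDiffusionAOVMatEstPipeline.from_pretrained(
    "zheng95z/rgb-to-x",
    torch_dtype=torch.float16,
    cache_dir="model_cache",
).to("cuda")
```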
## Installation
Create a conda environment using the provided `environment.yaml` file.

```bash
conda env create -n rgbx -f environment.yaml
conda activate rgbx
```

Note that this environment is only compatible with NVIDIA GPUs. Additionally, we recommend using a GPU with a minimum of 12GB of memory.
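
If you want to confirm that a suitable GPU is visible before launching a demo, a small optional check using standard PyTorch calls looks like this (the 12GB figure is just the recommendation above):

```python
import torch

# Verify a CUDA device is available and report its memory.
assert torch.cuda.is_available(), "An NVIDIA GPU with CUDA support is required."
total_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
print(f"GPU: {torch.cuda.get_device_name(0)} ({total_gb:.1f} GB)")
if total_gb < 12:
    print("Warning: less than the recommended 12GB of GPU memory.")
```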
## Inference
When you run the inference scripts, gradio demos will be hosted on your local machine. You can access the demos by opening the URLs (shown in the terminal) in your browser.

### RGB→X
```bash
cd rgb2x
python gradio_demo_rgb2x.py
```

**Please note that the metallicity channel prediction might behave differently between the demo and the paper. This is because the demo utilizes a checkpoint that predicts roughness and metallicity separately, whereas in the paper, we used a checkpoint where the roughness and metallicity channels were combined into a single RGB image (with the blue channel set to 0). Unfortunately, the latter checkpoint was lost during the transition between computing platforms, and we apologize for the inconvenience. We plan to resolve this issue and will provide an updated demo in the near future.**

### X→RGB
```bash
cd x2rgb
python gradio_demo_x2rgb.py
```

## Acknowledgements

This implementation builds upon Hugging Face's [Diffusers](https://github.com/huggingface/diffusers) library. We also acknowledge [Gradio](https://www.gradio.app/) for providing an easy-to-use interface that allowed us to create the inference demos for our models.
assets/rgbx24_teaser.png
ADDED
Binary file (stored with Git LFS)
environment.yml
ADDED
Binary file (648 Bytes)
requirements.txt
ADDED
@@ -0,0 +1,10 @@
torch==2.5.1
torchaudio==2.5.1
torchvision==0.20.1
diffusers==0.20.0
gradio==5.5.0
imageio==2.34.1
numpy==1.26.4
opencv-python==4.9.0.80
transformers==4.40.2
spaces==0.30.4
rgb2x/example/Castlereagh_corridor_photo.png
ADDED
Binary file (stored with Git LFS)
rgb2x/gradio_demo_rgb2x.py
ADDED
@@ -0,0 +1,154 @@
import spaces
import os
from typing import cast
import gradio as gr
from PIL import Image
import torch
import torchvision
from diffusers import DDIMScheduler
from load_image import load_exr_image, load_ldr_image
from pipeline_rgb2x import StableDiffusionAOVMatEstPipeline

os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"

current_directory = os.path.dirname(os.path.abspath(__file__))

_pipe = StableDiffusionAOVMatEstPipeline.from_pretrained(
    "zheng95z/rgb-to-x",
    torch_dtype=torch.float16,
    cache_dir=os.path.join(current_directory, "model_cache"),
).to("cuda")
pipe = cast(StableDiffusionAOVMatEstPipeline, _pipe)
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
)
pipe.set_progress_bar_config(disable=True)
pipe.to("cuda")
pipe = cast(StableDiffusionAOVMatEstPipeline, pipe)


@spaces.GPU
def generate(
    photo,
    seed: int,
    inference_step: int,
    num_samples: int,
) -> list[Image.Image]:
    generator = torch.Generator(device="cuda").manual_seed(seed)

    if photo.name.endswith(".exr"):
        photo = load_exr_image(photo.name, tonemaping=True, clamp=True).to("cuda")
    elif (
        photo.name.endswith(".png")
        or photo.name.endswith(".jpg")
        or photo.name.endswith(".jpeg")
    ):
        photo = load_ldr_image(photo.name, from_srgb=True).to("cuda")

    # Resize so the longer side is at most 1000 px and both sides are multiples of 8
    # (required by the VAE); the original size is restored after generation.
    old_height = photo.shape[1]
    old_width = photo.shape[2]
    new_height = old_height
    new_width = old_width
    ratio = old_height / old_width
    max_side = 1000
    if old_height > old_width:
        new_height = max_side
        new_width = int(new_height / ratio)
    else:
        new_width = max_side
        new_height = int(new_width * ratio)

    if new_width % 8 != 0 or new_height % 8 != 0:
        new_width = new_width // 8 * 8
        new_height = new_height // 8 * 8

    photo = torchvision.transforms.Resize((new_height, new_width))(photo)

    required_aovs = ["albedo", "normal", "roughness", "metallic", "irradiance"]
    prompts = {
        "albedo": "Albedo (diffuse basecolor)",
        "normal": "Camera-space Normal",
        "roughness": "Roughness",
        "metallic": "Metallicness",
        "irradiance": "Irradiance (diffuse lighting)",
    }

    return_list = []
    for i in range(num_samples):
        for aov_name in required_aovs:
            prompt = prompts[aov_name]
            generated_image = pipe(
                prompt=prompt,
                photo=photo,
                num_inference_steps=inference_step,
                height=new_height,
                width=new_width,
                generator=generator,
                required_aovs=[aov_name],
            ).images[0][0]  # type: ignore

            generated_image = torchvision.transforms.Resize((old_height, old_width))(
                generated_image
            )

            generated_image = (generated_image, f"Generated {aov_name} {i}")
            return_list.append(generated_image)

    return return_list


with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("## Model RGB -> X (Realistic image -> Intrinsic channels)")
    with gr.Row():
        # Input side
        with gr.Column():
            gr.Markdown("### Given Image")
            photo = gr.File(label="Photo", file_types=[".exr", ".png", ".jpg"])

            gr.Markdown("### Parameters")
            run_button = gr.Button(value="Run")
            with gr.Accordion("Advanced options", open=False):
                seed = gr.Slider(
                    label="Seed",
                    minimum=-1,
                    maximum=2147483647,
                    step=1,
                    randomize=True,
                )
                inference_step = gr.Slider(
                    label="Inference Step",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                )
                num_samples = gr.Slider(
                    label="Samples",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=1,
                )

        # Output side
        with gr.Column():
            gr.Markdown("### Output Gallery")
            result_gallery = gr.Gallery(
                label="Output",
                show_label=False,
                elem_id="gallery",
                columns=2,
            )

    run_button.click(
        fn=generate,
        inputs=[photo, seed, inference_step, num_samples],
        outputs=result_gallery,
        queue=True,
    )


if __name__ == "__main__":
    demo.launch(debug=False, share=False, show_api=False)
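The script above is a Gradio front end around the pipeline; the same model can be driven programmatically. Below is a minimal sketch under the same assumptions as the demo (run from inside `rgb2x/`; the example photo ships in `rgb2x/example/`; the demo additionally caps the longer side at 1000 px, and the `save` call assumes the output behaves like a PIL image, as its use in the gallery suggests):

```python
import torch
import torchvision
from diffusers import DDIMScheduler
from load_image import load_ldr_image
from pipeline_rgb2x import StableDiffusionAOVMatEstPipeline

pipe = StableDiffusionAOVMatEstPipeline.from_pretrained(
    "zheng95z/rgb-to-x", torch_dtype=torch.float16, cache_dir="model_cache"
).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
)

photo = load_ldr_image("example/Castlereagh_corridor_photo.png", from_srgb=True).to("cuda")
# Round the spatial size down to multiples of 8, as the demo does for the VAE.
height = photo.shape[1] // 8 * 8
width = photo.shape[2] // 8 * 8
photo = torchvision.transforms.Resize((height, width))(photo)

albedo = pipe(
    prompt="Albedo (diffuse basecolor)",
    photo=photo,
    num_inference_steps=50,
    height=height,
    width=width,
    generator=torch.Generator(device="cuda").manual_seed(42),
    required_aovs=["albedo"],
).images[0][0]
albedo.save("albedo.png")
```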
rgb2x/load_image.py
ADDED
@@ -0,0 +1,119 @@
import os

import cv2
import torch

os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
import numpy as np


def convert_rgb_2_XYZ(rgb):
    # Reference: https://web.archive.org/web/20191027010220/http://www.brucelindbloom.com/index.html?Eqn_RGB_XYZ_Matrix.html
    # rgb: (h, w, 3)
    # XYZ: (h, w, 3)
    XYZ = torch.ones_like(rgb)
    XYZ[:, :, 0] = (
        0.4124564 * rgb[:, :, 0] + 0.3575761 * rgb[:, :, 1] + 0.1804375 * rgb[:, :, 2]
    )
    XYZ[:, :, 1] = (
        0.2126729 * rgb[:, :, 0] + 0.7151522 * rgb[:, :, 1] + 0.0721750 * rgb[:, :, 2]
    )
    XYZ[:, :, 2] = (
        0.0193339 * rgb[:, :, 0] + 0.1191920 * rgb[:, :, 1] + 0.9503041 * rgb[:, :, 2]
    )
    return XYZ


def convert_XYZ_2_Yxy(XYZ):
    # XYZ: (h, w, 3)
    # Yxy: (h, w, 3)
    Yxy = torch.ones_like(XYZ)
    Yxy[:, :, 0] = XYZ[:, :, 1]
    sum = torch.sum(XYZ, dim=2)
    inv_sum = 1.0 / torch.clamp(sum, min=1e-4)
    Yxy[:, :, 1] = XYZ[:, :, 0] * inv_sum
    Yxy[:, :, 2] = XYZ[:, :, 1] * inv_sum
    return Yxy


def convert_rgb_2_Yxy(rgb):
    # rgb: (h, w, 3)
    # Yxy: (h, w, 3)
    return convert_XYZ_2_Yxy(convert_rgb_2_XYZ(rgb))


def convert_XYZ_2_rgb(XYZ):
    # XYZ: (h, w, 3)
    # rgb: (h, w, 3)
    rgb = torch.ones_like(XYZ)
    rgb[:, :, 0] = (
        3.2404542 * XYZ[:, :, 0] - 1.5371385 * XYZ[:, :, 1] - 0.4985314 * XYZ[:, :, 2]
    )
    rgb[:, :, 1] = (
        -0.9692660 * XYZ[:, :, 0] + 1.8760108 * XYZ[:, :, 1] + 0.0415560 * XYZ[:, :, 2]
    )
    rgb[:, :, 2] = (
        0.0556434 * XYZ[:, :, 0] - 0.2040259 * XYZ[:, :, 1] + 1.0572252 * XYZ[:, :, 2]
    )
    return rgb


def convert_Yxy_2_XYZ(Yxy):
    # Yxy: (h, w, 3)
    # XYZ: (h, w, 3)
    XYZ = torch.ones_like(Yxy)
    XYZ[:, :, 0] = Yxy[:, :, 1] / torch.clamp(Yxy[:, :, 2], min=1e-6) * Yxy[:, :, 0]
    XYZ[:, :, 1] = Yxy[:, :, 0]
    XYZ[:, :, 2] = (
        (1.0 - Yxy[:, :, 1] - Yxy[:, :, 2])
        / torch.clamp(Yxy[:, :, 2], min=1e-4)
        * Yxy[:, :, 0]
    )
    return XYZ


def convert_Yxy_2_rgb(Yxy):
    # Yxy: (h, w, 3)
    # rgb: (h, w, 3)
    return convert_XYZ_2_rgb(convert_Yxy_2_XYZ(Yxy))


def load_ldr_image(image_path, from_srgb=False, clamp=False, normalize=False):
    # Load png or jpg image
    image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
    image = torch.from_numpy(image.astype(np.float32) / 255.0)  # (h, w, c)
    image[~torch.isfinite(image)] = 0
    if from_srgb:
        # Convert from sRGB to linear RGB
        image = image**2.2
    if clamp:
        image = torch.clamp(image, min=0.0, max=1.0)
    if normalize:
        # Normalize to [-1, 1]
        image = image * 2.0 - 1.0
        image = torch.nn.functional.normalize(image, dim=-1, eps=1e-6)
    return image.permute(2, 0, 1)  # returns (c, h, w)


def load_exr_image(image_path, tonemaping=False, clamp=False, normalize=False):
    image = cv2.cvtColor(cv2.imread(image_path, -1), cv2.COLOR_BGR2RGB)
    image = torch.from_numpy(image.astype("float32"))  # (h, w, c)
    image[~torch.isfinite(image)] = 0
    if tonemaping:
        # Exposure adjustment
        image_Yxy = convert_rgb_2_Yxy(image)
        lum = (
            image[:, :, 0:1] * 0.2125
            + image[:, :, 1:2] * 0.7154
            + image[:, :, 2:3] * 0.0721
        )
        lum = torch.log(torch.clamp(lum, min=1e-6))
        lum_mean = torch.exp(torch.mean(lum))
        lp = image_Yxy[:, :, 0:1] * 0.18 / torch.clamp(lum_mean, min=1e-6)
        image_Yxy[:, :, 0:1] = lp
        image = convert_Yxy_2_rgb(image_Yxy)
    if clamp:
        image = torch.clamp(image, min=0.0, max=1.0)
    if normalize:
        image = torch.nn.functional.normalize(image, dim=-1, eps=1e-6)
    return image.permute(2, 0, 1)  # returns (c, h, w)
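The helpers above are the standard sRGB (D65) RGB↔XYZ matrices plus the xyY ("Yxy") decomposition; `load_exr_image` uses them to apply a log-average-luminance exposure adjustment (key value 0.18) before optional clamping. A small round-trip sketch, with a hypothetical random input and assuming this module is importable as `load_image`:

```python
import torch
from load_image import convert_rgb_2_Yxy, convert_Yxy_2_rgb

rgb = torch.rand(4, 4, 3)  # hypothetical (h, w, 3) linear-RGB values
back = convert_Yxy_2_rgb(convert_rgb_2_Yxy(rgb))
# Differences stay near float32 precision, up to the small clamps (1e-4, 1e-6) above.
print(torch.max(torch.abs(back - rgb)).item())
```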
rgb2x/pipeline_rgb2x.py
ADDED
@@ -0,0 +1,821 @@
import inspect
from dataclasses import dataclass
from typing import Callable, List, Optional, Union

import numpy as np
import PIL
import torch
from diffusers.configuration_utils import register_to_config
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import (
    LoraLoaderMixin,
    TextualInversionLoaderMixin,
)
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
    rescale_noise_cfg,
)
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
    CONFIG_NAME,
    BaseOutput,
    deprecate,
    logging,
)
from diffusers.utils.torch_utils import randn_tensor
from transformers import CLIPTextModel, CLIPTokenizer

logger = logging.get_logger(__name__)


class VaeImageProcrssorAOV(VaeImageProcessor):
    """
    Image processor for VAE AOV.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
        vae_scale_factor (`int`, *optional*, defaults to `8`):
            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
        resample (`str`, *optional*, defaults to `lanczos`):
            Resampling filter to use when resizing the image.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image to [-1,1].
    """

    config_name = CONFIG_NAME

    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        resample: str = "lanczos",
        do_normalize: bool = True,
    ):
        super().__init__()

    def postprocess(
        self,
        image: torch.FloatTensor,
        output_type: str = "pil",
        do_denormalize: Optional[List[bool]] = None,
        do_gamma_correction: bool = True,
    ):
        if not isinstance(image, torch.Tensor):
            raise ValueError(
                f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
            )
        if output_type not in ["latent", "pt", "np", "pil"]:
            deprecation_message = (
                f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
                "`pil`, `np`, `pt`, `latent`"
            )
            deprecate(
                "Unsupported output_type",
                "1.0.0",
                deprecation_message,
                standard_warn=False,
            )
            output_type = "np"

        if output_type == "latent":
            return image

        if do_denormalize is None:
            do_denormalize = [self.config.do_normalize] * image.shape[0]

        image = torch.stack(
            [
                self.denormalize(image[i]) if do_denormalize[i] else image[i]
                for i in range(image.shape[0])
            ]
        )

        # Gamma correction
        if do_gamma_correction:
            image = torch.pow(image, 1.0 / 2.2)

        if output_type == "pt":
            return image

        image = self.pt_to_numpy(image)

        if output_type == "np":
            return image

        if output_type == "pil":
            return self.numpy_to_pil(image)

    def preprocess_normal(
        self,
        image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
        height: Optional[int] = None,
        width: Optional[int] = None,
    ) -> torch.Tensor:
        image = torch.stack([image], axis=0)
        return image


@dataclass
class StableDiffusionAOVPipelineOutput(BaseOutput):
    """
    Output class for Stable Diffusion AOV pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
            num_channels)`.
        nsfw_content_detected (`List[bool]`)
            List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or
            `None` if safety checking could not be performed.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]


class StableDiffusionAOVMatEstPipeline(
    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin
):
    r"""
    Pipeline for AOVs.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    The pipeline also inherits the following loading methods:
        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
        - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
        - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`~transformers.CLIPTextModel`]):
            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
        tokenizer ([`~transformers.CLIPTokenizer`]):
            A `CLIPTokenizer` to tokenize text.
        unet ([`UNet2DConditionModel`]):
            A `UNet2DConditionModel` to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcrssorAOV(
            vae_scale_factor=self.vae_scale_factor
        )
        self.register_to_config()

    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
        """
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(
                prompt, padding="longest", return_tensors="pt"
            ).input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[
                -1
            ] and not torch.equal(text_input_ids, untruncated_ids):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if (
                hasattr(self.text_encoder.config, "use_attention_mask")
                and self.text_encoder.config.use_attention_mask
            ):
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask,
            )
            prompt_embeds = prompt_embeds[0]

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(
            bs_embed * num_images_per_prompt, seq_len, -1
        )

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if (
                hasattr(self.text_encoder.config, "use_attention_mask")
                and self.text_encoder.config.use_attention_mask
            ):
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(
                dtype=self.text_encoder.dtype, device=device
            )

            negative_prompt_embeds = negative_prompt_embeds.repeat(
                1, num_images_per_prompt, 1
            )
            negative_prompt_embeds = negative_prompt_embeds.view(
                batch_size * num_images_per_prompt, seq_len, -1
            )

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]
            prompt_embeds = torch.cat(
                [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]
            )

        return prompt_embeds

    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(
        self,
        prompt,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        if (callback_steps is None) or (
            callback_steps is not None
            and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (
            not isinstance(prompt, str) and not isinstance(prompt, list)
        ):
            raise ValueError(
                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
            )

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        shape = (
            batch_size,
            num_channels_latents,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(
                shape, generator=generator, device=device, dtype=dtype
            )
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def prepare_image_latents(
        self,
        image,
        batch_size,
        num_images_per_prompt,
        dtype,
        device,
        do_classifier_free_guidance,
        generator=None,
    ):
        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )

        image = image.to(device=device, dtype=dtype)

        batch_size = batch_size * num_images_per_prompt

        if image.shape[1] == 4:
            image_latents = image
        else:
            if isinstance(generator, list) and len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )

            if isinstance(generator, list):
                image_latents = [
                    self.vae.encode(image[i : i + 1]).latent_dist.mode()
                    for i in range(batch_size)
                ]
                image_latents = torch.cat(image_latents, dim=0)
            else:
                image_latents = self.vae.encode(image).latent_dist.mode()

        if (
            batch_size > image_latents.shape[0]
            and batch_size % image_latents.shape[0] == 0
        ):
            # expand image_latents for batch_size
            deprecation_message = (
                f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
                " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
                " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
                " your script to pass as many initial images as text prompts to suppress this warning."
            )
            deprecate(
                "len(prompt) != len(image)",
                "1.0.0",
                deprecation_message,
                standard_warn=False,
            )
            additional_image_per_prompt = batch_size // image_latents.shape[0]
            image_latents = torch.cat(
                [image_latents] * additional_image_per_prompt, dim=0
            )
        elif (
            batch_size > image_latents.shape[0]
            and batch_size % image_latents.shape[0] != 0
        ):
            raise ValueError(
                f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            image_latents = torch.cat([image_latents], dim=0)

        if do_classifier_free_guidance:
            uncond_image_latents = torch.zeros_like(image_latents)
            image_latents = torch.cat(
                [image_latents, image_latents, uncond_image_latents], dim=0
            )

        return image_latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        photo: Union[
            torch.FloatTensor,
            PIL.Image.Image,
            np.ndarray,
            List[torch.FloatTensor],
            List[PIL.Image.Image],
            List[np.ndarray],
        ] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 100,
        required_aovs: List[str] = ["albedo"],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        use_default_scaling_factor: Optional[bool] = False,
        guidance_scale: float = 0.0,
        image_guidance_scale: float = 0.0,
        guidance_rescale: float = 0.0,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                `Image` or tensor representing an image batch to be repainted according to `prompt`. Can also accept
                image latents as `image`, but if passing latents directly it is not encoded again.
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            image_guidance_scale (`float`, *optional*, defaults to 1.5):
                Push the generated image towards the initial `image`. Image guidance scale is enabled by setting
                `image_guidance_scale > 1`. Higher image guidance scale encourages generated images that are closely
                linked to the source `image`, usually at the expense of lower image quality. This pipeline requires a
                value of at least `1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.

        Examples:

        ```py
        >>> import PIL
        >>> import requests
        >>> import torch
        >>> from io import BytesIO

        >>> from diffusers import StableDiffusionInstructPix2PixPipeline


        >>> def download_image(url):
        ...     response = requests.get(url)
        ...     return PIL.Image.open(BytesIO(response.content)).convert("RGB")


        >>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png"

        >>> image = download_image(img_url).resize((512, 512))

        >>> pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
        ...     "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
        ... )
        >>> pipe = pipe.to("cuda")

        >>> prompt = "make the mountains snowy"
        >>> image = pipe(prompt=prompt, image=image).images[0]
        ```

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.
        """
        # 0. Check inputs
        self.check_inputs(
            prompt,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
        )

        # 1. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        do_classifier_free_guidance = (
            guidance_scale > 1.0 and image_guidance_scale >= 1.0
        )
        # check if scheduler is in sigmas space
        scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")

        # 2. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 3. Preprocess image
        # Normalize image to [-1,1]
        preprocessed_photo = self.image_processor.preprocess(photo)

        # 4. set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare Image latents
        image_latents = self.prepare_image_latents(
            preprocessed_photo,
            batch_size,
            num_images_per_prompt,
            prompt_embeds.dtype,
            device,
            do_classifier_free_guidance,
            generator,
        )
        image_latents = image_latents * self.vae.config.scaling_factor

        height, width = image_latents.shape[-2:]
        height = height * self.vae_scale_factor
        width = width * self.vae_scale_factor

        # 6. Prepare latent variables
        num_channels_latents = self.unet.config.out_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 7. Check that shapes of latents and image match the UNet channels
        num_channels_image = image_latents.shape[1]
        if num_channels_latents + num_channels_image != self.unet.config.in_channels:
            raise ValueError(
                f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
                f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
                f" `num_channels_image`: {num_channels_image} "
                f" = {num_channels_latents+num_channels_image}. Please verify the config of"
                " `pipeline.unet` or your `image` input."
            )

        # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 9. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # Expand the latents if we are doing classifier free guidance.
                # The latents are expanded 3 times because for pix2pix the guidance
                # is applied for both the text and the input image.
                latent_model_input = (
                    torch.cat([latents] * 3) if do_classifier_free_guidance else latents
                )

                # concat latents, image_latents in the channel dimension
                scaled_latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t
                )
                scaled_latent_model_input = torch.cat(
                    [scaled_latent_model_input, image_latents], dim=1
                )

                # predict the noise residual
                noise_pred = self.unet(
                    scaled_latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    return_dict=False,
                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    (
                        noise_pred_text,
                        noise_pred_image,
                        noise_pred_uncond,
                    ) = noise_pred.chunk(3)
                    noise_pred = (
                        noise_pred_uncond
                        + guidance_scale * (noise_pred_text - noise_pred_image)
                        + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
                    )

                if do_classifier_free_guidance and guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(
                        noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
                    )

                # compute the previous noisy sample x_t -> x_t-1
# compute the previous noisy sample x_t -> x_t-1
|
786 |
+
latents = self.scheduler.step(
|
787 |
+
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
|
788 |
+
)[0]
|
789 |
+
|
790 |
+
# call the callback, if provided
|
791 |
+
if i == len(timesteps) - 1 or (
|
792 |
+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
793 |
+
):
|
794 |
+
progress_bar.update()
|
795 |
+
if callback is not None and i % callback_steps == 0:
|
796 |
+
callback(i, t, latents)
|
797 |
+
|
798 |
+
aov_latents = latents / self.vae.config.scaling_factor
|
799 |
+
aov = self.vae.decode(aov_latents, return_dict=False)[0]
|
800 |
+
do_denormalize = [True] * aov.shape[0]
|
801 |
+
aov_name = required_aovs[0]
|
802 |
+
if aov_name == "albedo" or aov_name == "irradiance":
|
803 |
+
do_gamma_correction = True
|
804 |
+
else:
|
805 |
+
do_gamma_correction = False
|
806 |
+
|
807 |
+
if aov_name == "roughness" or aov_name == "metallic":
|
808 |
+
aov = aov[:, 0:1].repeat(1, 3, 1, 1)
|
809 |
+
|
810 |
+
aov = self.image_processor.postprocess(
|
811 |
+
aov,
|
812 |
+
output_type=output_type,
|
813 |
+
do_denormalize=do_denormalize,
|
814 |
+
do_gamma_correction=do_gamma_correction,
|
815 |
+
)
|
816 |
+
aovs = [aov]
|
817 |
+
|
818 |
+
# Offload last model to CPU
|
819 |
+
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
820 |
+
self.final_offload_hook.offload()
|
821 |
+
return StableDiffusionAOVPipelineOutput(images=aovs)
|
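The guidance step in the denoising loop above combines three UNet predictions (text-and-image conditioned, image-only conditioned, and fully unconditional). A minimal standalone sketch of that combination, with a hypothetical helper name (`combine_guidance`) and random tensors standing in for real UNet outputs:

```python
import torch


def combine_guidance(
    noise_pred: torch.Tensor, guidance_scale: float, image_guidance_scale: float
) -> torch.Tensor:
    # noise_pred stacks three predictions along the batch dimension,
    # in the order this pipeline uses: [text+image, image-only, unconditional].
    noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
    return (
        noise_pred_uncond
        + guidance_scale * (noise_pred_text - noise_pred_image)
        + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
    )


# Toy usage: the latent shape is illustrative only.
fake_preds = torch.randn(3, 4, 96, 96)
guided = combine_guidance(fake_preds, guidance_scale=7.5, image_guidance_scale=1.5)
print(guided.shape)  # torch.Size([1, 4, 96, 96])
```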
x2rgb/example/kitchen-albedo.png
ADDED (binary image, tracked with Git LFS)
x2rgb/example/kitchen-irradiance.png
ADDED (binary image, tracked with Git LFS)
x2rgb/example/kitchen-metallic.png
ADDED (binary image, tracked with Git LFS)
x2rgb/example/kitchen-normal.png
ADDED (binary image, tracked with Git LFS)
x2rgb/example/kitchen-ref.png
ADDED (binary image, tracked with Git LFS)
x2rgb/example/kitchen-roughness.png
ADDED (binary image, tracked with Git LFS)
x2rgb/gradio_demo_x2rgb.py
ADDED
@@ -0,0 +1,204 @@
import spaces
import os
from typing import cast
import gradio as gr
import numpy as np
import torch
from PIL import Image
from diffusers import DDIMScheduler
from load_image import load_exr_image, load_ldr_image
from pipeline_x2rgb import StableDiffusionAOVDropoutPipeline

os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"

current_directory = os.path.dirname(os.path.abspath(__file__))

_pipe = StableDiffusionAOVDropoutPipeline.from_pretrained(
    "zheng95z/x-to-rgb",
    torch_dtype=torch.float16,
    cache_dir=os.path.join(current_directory, "model_cache"),
).to("cuda")
pipe = cast(StableDiffusionAOVDropoutPipeline, _pipe)
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
)
pipe.set_progress_bar_config(disable=True)
pipe.to("cuda")
pipe = cast(StableDiffusionAOVDropoutPipeline, pipe)


@spaces.GPU
def generate(
    albedo,
    normal,
    roughness,
    metallic,
    irradiance,
    prompt: str,
    seed: int,
    inference_step: int,
    num_samples: int,
    guidance_scale: float,
    image_guidance_scale: float,
) -> list[Image.Image]:
    generator = torch.Generator(device="cuda").manual_seed(seed)

    # Load and process each intrinsic channel image
    def process_image(file, **kwargs):
        if file is None:
            return None
        if file.name.endswith(".exr"):
            return load_exr_image(file.name, **kwargs).to("cuda")
        elif file.name.endswith((".png", ".jpg", ".jpeg")):
            return load_ldr_image(file.name, **kwargs).to("cuda")
        return None

    albedo_image = process_image(albedo, clamp=True)
    normal_image = process_image(normal, normalize=True)
    roughness_image = process_image(roughness, clamp=True)
    metallic_image = process_image(metallic, clamp=True)
    irradiance_image = process_image(irradiance, tonemaping=True, clamp=True)

    # Set default height and width based on the first available image
    height, width = 768, 768
    for img in [
        albedo_image,
        normal_image,
        roughness_image,
        metallic_image,
        irradiance_image,
    ]:
        if img is not None:
            height, width = img.shape[1], img.shape[2]
            break

    required_aovs = ["albedo", "normal", "roughness", "metallic", "irradiance"]
    return_list = []

    for i in range(num_samples):
        generated_image = pipe(
            prompt=prompt,
            albedo=albedo_image,
            normal=normal_image,
            roughness=roughness_image,
            metallic=metallic_image,
            irradiance=irradiance_image,
            num_inference_steps=inference_step,
            height=height,
            width=width,
            generator=generator,
            required_aovs=required_aovs,
            guidance_scale=guidance_scale,
            image_guidance_scale=image_guidance_scale,
            guidance_rescale=0.7,
            output_type="np",
        ).images[0]  # type: ignore

        return_list.append((generated_image, f"Generated Image {i}"))

    # Append additional images to the output gallery
    def post_process_image(img, **kwargs):
        if img is not None:
            return (img.cpu().numpy().transpose(1, 2, 0), kwargs.get("label", "Image"))
        return np.zeros((height, width, 3))

    return_list.extend(
        [
            post_process_image(albedo_image, label="Albedo"),
            post_process_image(normal_image, label="Normal"),
            post_process_image(roughness_image, label="Roughness"),
            post_process_image(metallic_image, label="Metallic"),
            post_process_image(irradiance_image, label="Irradiance"),
        ]
    )

    return return_list


with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("## Model X -> RGB (Intrinsic channels -> realistic image)")
    with gr.Row():
        # Input side
        with gr.Column():
            gr.Markdown("### Given intrinsic channels")
            albedo = gr.File(label="Albedo", file_types=[".exr", ".png", ".jpg"])
            normal = gr.File(label="Normal", file_types=[".exr", ".png", ".jpg"])
            roughness = gr.File(label="Roughness", file_types=[".exr", ".png", ".jpg"])
            metallic = gr.File(label="Metallic", file_types=[".exr", ".png", ".jpg"])
            irradiance = gr.File(
                label="Irradiance", file_types=[".exr", ".png", ".jpg"]
            )

            gr.Markdown("### Parameters")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button(value="Run")
            with gr.Accordion("Advanced options", open=False):
                seed = gr.Slider(
                    label="Seed",
                    minimum=-1,
                    maximum=2147483647,
                    step=1,
                    randomize=True,
                )
                inference_step = gr.Slider(
                    label="Inference Step",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                )
                num_samples = gr.Slider(
                    label="Samples",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=1,
                )
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=7.5,
                )
                image_guidance_scale = gr.Slider(
                    label="Image Guidance Scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=1.5,
                )

        # Output side
        with gr.Column():
            gr.Markdown("### Output Gallery")
            result_gallery = gr.Gallery(
                label="Output",
                show_label=False,
                elem_id="gallery",
                columns=2,
            )

    run_button.click(
        fn=generate,
        inputs=[
            albedo,
            normal,
            roughness,
            metallic,
            irradiance,
            prompt,
            seed,
            inference_step,
            num_samples,
            guidance_scale,
            image_guidance_scale,
        ],
        outputs=result_gallery,
        queue=True,
    )


if __name__ == "__main__":
    demo.launch(debug=False, share=False, show_api=False)
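The demo above only wires the Gradio UI to `generate`; the pipeline itself can be called directly. A minimal sketch of such a call, assuming it is run from inside `x2rgb/` on a CUDA machine; the prompt, seed, and output filename are placeholders, and any AOV left as `None` is simply dropped (the pipeline zero-fills its latent):

```python
import torch
from diffusers import DDIMScheduler
from load_image import load_ldr_image
from pipeline_x2rgb import StableDiffusionAOVDropoutPipeline

pipe = StableDiffusionAOVDropoutPipeline.from_pretrained(
    "zheng95z/x-to-rgb", torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
)

# Two of the bundled example maps; the remaining AOVs are omitted here.
albedo = load_ldr_image("example/kitchen-albedo.png", clamp=True).to("cuda")
normal = load_ldr_image("example/kitchen-normal.png", normalize=True).to("cuda")

image = pipe(
    prompt="a photo of a kitchen",
    albedo=albedo,
    normal=normal,
    height=albedo.shape[1],
    width=albedo.shape[2],
    num_inference_steps=50,
    guidance_scale=7.5,
    image_guidance_scale=1.5,
    guidance_rescale=0.7,
    required_aovs=["albedo", "normal", "roughness", "metallic", "irradiance"],
    generator=torch.Generator(device="cuda").manual_seed(42),
    output_type="pil",
).images[0]
image.save("kitchen-generated.png")
```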
x2rgb/load_image.py
ADDED
@@ -0,0 +1,119 @@
import os

import cv2
import torch

os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
import numpy as np


def convert_rgb_2_XYZ(rgb):
    # Reference: https://web.archive.org/web/20191027010220/http://www.brucelindbloom.com/index.html?Eqn_RGB_XYZ_Matrix.html
    # rgb: (h, w, 3)
    # XYZ: (h, w, 3)
    XYZ = torch.ones_like(rgb)
    XYZ[:, :, 0] = (
        0.4124564 * rgb[:, :, 0] + 0.3575761 * rgb[:, :, 1] + 0.1804375 * rgb[:, :, 2]
    )
    XYZ[:, :, 1] = (
        0.2126729 * rgb[:, :, 0] + 0.7151522 * rgb[:, :, 1] + 0.0721750 * rgb[:, :, 2]
    )
    XYZ[:, :, 2] = (
        0.0193339 * rgb[:, :, 0] + 0.1191920 * rgb[:, :, 1] + 0.9503041 * rgb[:, :, 2]
    )
    return XYZ


def convert_XYZ_2_Yxy(XYZ):
    # XYZ: (h, w, 3)
    # Yxy: (h, w, 3)
    Yxy = torch.ones_like(XYZ)
    Yxy[:, :, 0] = XYZ[:, :, 1]
    sum = torch.sum(XYZ, dim=2)
    inv_sum = 1.0 / torch.clamp(sum, min=1e-4)
    Yxy[:, :, 1] = XYZ[:, :, 0] * inv_sum
    Yxy[:, :, 2] = XYZ[:, :, 1] * inv_sum
    return Yxy


def convert_rgb_2_Yxy(rgb):
    # rgb: (h, w, 3)
    # Yxy: (h, w, 3)
    return convert_XYZ_2_Yxy(convert_rgb_2_XYZ(rgb))


def convert_XYZ_2_rgb(XYZ):
    # XYZ: (h, w, 3)
    # rgb: (h, w, 3)
    rgb = torch.ones_like(XYZ)
    rgb[:, :, 0] = (
        3.2404542 * XYZ[:, :, 0] - 1.5371385 * XYZ[:, :, 1] - 0.4985314 * XYZ[:, :, 2]
    )
    rgb[:, :, 1] = (
        -0.9692660 * XYZ[:, :, 0] + 1.8760108 * XYZ[:, :, 1] + 0.0415560 * XYZ[:, :, 2]
    )
    rgb[:, :, 2] = (
        0.0556434 * XYZ[:, :, 0] - 0.2040259 * XYZ[:, :, 1] + 1.0572252 * XYZ[:, :, 2]
    )
    return rgb


def convert_Yxy_2_XYZ(Yxy):
    # Yxy: (h, w, 3)
    # XYZ: (h, w, 3)
    XYZ = torch.ones_like(Yxy)
    XYZ[:, :, 0] = Yxy[:, :, 1] / torch.clamp(Yxy[:, :, 2], min=1e-6) * Yxy[:, :, 0]
    XYZ[:, :, 1] = Yxy[:, :, 0]
    XYZ[:, :, 2] = (
        (1.0 - Yxy[:, :, 1] - Yxy[:, :, 2])
        / torch.clamp(Yxy[:, :, 2], min=1e-4)
        * Yxy[:, :, 0]
    )
    return XYZ


def convert_Yxy_2_rgb(Yxy):
    # Yxy: (h, w, 3)
    # rgb: (h, w, 3)
    return convert_XYZ_2_rgb(convert_Yxy_2_XYZ(Yxy))


def load_ldr_image(image_path, from_srgb=False, clamp=False, normalize=False):
    # Load png or jpg image
    image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
    image = torch.from_numpy(image.astype(np.float32) / 255.0)  # (h, w, c)
    image[~torch.isfinite(image)] = 0
    if from_srgb:
        # Convert from sRGB to linear RGB
        image = image**2.2
    if clamp:
        image = torch.clamp(image, min=0.0, max=1.0)
    if normalize:
        # Normalize to [-1, 1]
        image = image * 2.0 - 1.0
        image = torch.nn.functional.normalize(image, dim=-1, eps=1e-6)
    return image.permute(2, 0, 1)  # returns (c, h, w)


def load_exr_image(image_path, tonemaping=False, clamp=False, normalize=False):
    image = cv2.cvtColor(cv2.imread(image_path, -1), cv2.COLOR_BGR2RGB)
    image = torch.from_numpy(image.astype("float32"))  # (h, w, c)
    image[~torch.isfinite(image)] = 0
    if tonemaping:
        # Exposure adjustment
        image_Yxy = convert_rgb_2_Yxy(image)
        lum = (
            image[:, :, 0:1] * 0.2125
            + image[:, :, 1:2] * 0.7154
            + image[:, :, 2:3] * 0.0721
        )
        lum = torch.log(torch.clamp(lum, min=1e-6))
        lum_mean = torch.exp(torch.mean(lum))
        lp = image_Yxy[:, :, 0:1] * 0.18 / torch.clamp(lum_mean, min=1e-6)
        image_Yxy[:, :, 0:1] = lp
        image = convert_Yxy_2_rgb(image_Yxy)
    if clamp:
        image = torch.clamp(image, min=0.0, max=1.0)
    if normalize:
        image = torch.nn.functional.normalize(image, dim=-1, eps=1e-6)
    return image.permute(2, 0, 1)  # returns (c, h, w)
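The `tonemaping` branch of `load_exr_image` is a log-average exposure adjustment: it converts the HDR image to Yxy, computes the log-average luminance, and rescales the luminance channel so the log-average maps to middle grey (0.18). A condensed restatement of just the scale computation, with a hypothetical helper name (`exposure_key`) and a toy input, not part of the repository:

```python
import torch


def exposure_key(rgb: torch.Tensor, key: float = 0.18) -> torch.Tensor:
    # Luminance weights and clamps mirror the loader above.
    lum = 0.2125 * rgb[..., 0] + 0.7154 * rgb[..., 1] + 0.0721 * rgb[..., 2]
    log_avg = torch.exp(torch.mean(torch.log(torch.clamp(lum, min=1e-6))))
    # Multiplier applied to the Y channel in Yxy space.
    return key / torch.clamp(log_avg, min=1e-6)


hdr = torch.rand(4, 4, 3) * 10.0  # toy HDR image
print(exposure_key(hdr))
```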
x2rgb/pipeline_x2rgb.py
ADDED
@@ -0,0 +1,967 @@
1 |
+
import inspect
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from typing import Callable, List, Optional, Union
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import PIL
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from diffusers.configuration_utils import register_to_config
|
10 |
+
from diffusers.image_processor import VaeImageProcessor
|
11 |
+
from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
12 |
+
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
13 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
14 |
+
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
|
15 |
+
rescale_noise_cfg,
|
16 |
+
)
|
17 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
18 |
+
from diffusers.utils import CONFIG_NAME, BaseOutput, deprecate, logging, randn_tensor
|
19 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
20 |
+
|
21 |
+
logger = logging.get_logger(__name__)
|
22 |
+
|
23 |
+
|
24 |
+
class VaeImageProcrssorAOV(VaeImageProcessor):
|
25 |
+
"""
|
26 |
+
Image processor for VAE AOV.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
30 |
+
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
|
31 |
+
vae_scale_factor (`int`, *optional*, defaults to `8`):
|
32 |
+
VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
|
33 |
+
resample (`str`, *optional*, defaults to `lanczos`):
|
34 |
+
Resampling filter to use when resizing the image.
|
35 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
36 |
+
Whether to normalize the image to [-1,1].
|
37 |
+
"""
|
38 |
+
|
39 |
+
config_name = CONFIG_NAME
|
40 |
+
|
41 |
+
@register_to_config
|
42 |
+
def __init__(
|
43 |
+
self,
|
44 |
+
do_resize: bool = True,
|
45 |
+
vae_scale_factor: int = 8,
|
46 |
+
resample: str = "lanczos",
|
47 |
+
do_normalize: bool = True,
|
48 |
+
):
|
49 |
+
super().__init__()
|
50 |
+
|
51 |
+
def postprocess(
|
52 |
+
self,
|
53 |
+
image: torch.FloatTensor,
|
54 |
+
output_type: str = "pil",
|
55 |
+
do_denormalize: Optional[List[bool]] = None,
|
56 |
+
do_gamma_correction: bool = True,
|
57 |
+
):
|
58 |
+
if not isinstance(image, torch.Tensor):
|
59 |
+
raise ValueError(
|
60 |
+
f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
|
61 |
+
)
|
62 |
+
if output_type not in ["latent", "pt", "np", "pil"]:
|
63 |
+
deprecation_message = (
|
64 |
+
f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
|
65 |
+
"`pil`, `np`, `pt`, `latent`"
|
66 |
+
)
|
67 |
+
deprecate(
|
68 |
+
"Unsupported output_type",
|
69 |
+
"1.0.0",
|
70 |
+
deprecation_message,
|
71 |
+
standard_warn=False,
|
72 |
+
)
|
73 |
+
output_type = "np"
|
74 |
+
|
75 |
+
if output_type == "latent":
|
76 |
+
return image
|
77 |
+
|
78 |
+
if do_denormalize is None:
|
79 |
+
do_denormalize = [self.config.do_normalize] * image.shape[0]
|
80 |
+
|
81 |
+
image = torch.stack(
|
82 |
+
[
|
83 |
+
self.denormalize(image[i]) if do_denormalize[i] else image[i]
|
84 |
+
for i in range(image.shape[0])
|
85 |
+
]
|
86 |
+
)
|
87 |
+
|
88 |
+
# Gamma correction
|
89 |
+
if do_gamma_correction:
|
90 |
+
image = torch.pow(image, 1.0 / 2.2)
|
91 |
+
|
92 |
+
if output_type == "pt":
|
93 |
+
return image
|
94 |
+
|
95 |
+
image = self.pt_to_numpy(image)
|
96 |
+
|
97 |
+
if output_type == "np":
|
98 |
+
return image
|
99 |
+
|
100 |
+
if output_type == "pil":
|
101 |
+
return self.numpy_to_pil(image)
|
102 |
+
|
103 |
+
def preprocess_normal(
|
104 |
+
self,
|
105 |
+
image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
|
106 |
+
height: Optional[int] = None,
|
107 |
+
width: Optional[int] = None,
|
108 |
+
) -> torch.Tensor:
|
109 |
+
image = torch.stack([image], axis=0)
|
110 |
+
return image
|
111 |
+
|
112 |
+
|
113 |
+
@dataclass
|
114 |
+
class StableDiffusionAOVPipelineOutput(BaseOutput):
|
115 |
+
"""
|
116 |
+
Output class for Stable Diffusion AOV pipelines.
|
117 |
+
|
118 |
+
Args:
|
119 |
+
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
120 |
+
List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
|
121 |
+
num_channels)`.
|
122 |
+
nsfw_content_detected (`List[bool]`)
|
123 |
+
List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or
|
124 |
+
`None` if safety checking could not be performed.
|
125 |
+
"""
|
126 |
+
|
127 |
+
images: Union[List[PIL.Image.Image], np.ndarray]
|
128 |
+
predicted_x0_images: Optional[Union[List[PIL.Image.Image], np.ndarray]] = None
|
129 |
+
|
130 |
+
|
131 |
+
class StableDiffusionAOVDropoutPipeline(
|
132 |
+
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin
|
133 |
+
):
|
134 |
+
r"""
|
135 |
+
Pipeline for AOVs.
|
136 |
+
|
137 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
138 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
139 |
+
|
140 |
+
The pipeline also inherits the following loading methods:
|
141 |
+
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
142 |
+
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
143 |
+
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
144 |
+
|
145 |
+
Args:
|
146 |
+
vae ([`AutoencoderKL`]):
|
147 |
+
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
148 |
+
text_encoder ([`~transformers.CLIPTextModel`]):
|
149 |
+
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
150 |
+
tokenizer ([`~transformers.CLIPTokenizer`]):
|
151 |
+
A `CLIPTokenizer` to tokenize text.
|
152 |
+
unet ([`UNet2DConditionModel`]):
|
153 |
+
A `UNet2DConditionModel` to denoise the encoded image latents.
|
154 |
+
scheduler ([`SchedulerMixin`]):
|
155 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
156 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
157 |
+
"""
|
158 |
+
|
159 |
+
def __init__(
|
160 |
+
self,
|
161 |
+
vae: AutoencoderKL,
|
162 |
+
text_encoder: CLIPTextModel,
|
163 |
+
tokenizer: CLIPTokenizer,
|
164 |
+
unet: UNet2DConditionModel,
|
165 |
+
scheduler: KarrasDiffusionSchedulers,
|
166 |
+
):
|
167 |
+
super().__init__()
|
168 |
+
|
169 |
+
self.register_modules(
|
170 |
+
vae=vae,
|
171 |
+
text_encoder=text_encoder,
|
172 |
+
tokenizer=tokenizer,
|
173 |
+
unet=unet,
|
174 |
+
scheduler=scheduler,
|
175 |
+
)
|
176 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
177 |
+
self.image_processor = VaeImageProcrssorAOV(
|
178 |
+
vae_scale_factor=self.vae_scale_factor
|
179 |
+
)
|
180 |
+
self.register_to_config()
|
181 |
+
|
182 |
+
def _encode_prompt(
|
183 |
+
self,
|
184 |
+
prompt,
|
185 |
+
device,
|
186 |
+
num_images_per_prompt,
|
187 |
+
do_classifier_free_guidance,
|
188 |
+
negative_prompt=None,
|
189 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
190 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
191 |
+
):
|
192 |
+
r"""
|
193 |
+
Encodes the prompt into text encoder hidden states.
|
194 |
+
|
195 |
+
Args:
|
196 |
+
prompt (`str` or `List[str]`, *optional*):
|
197 |
+
prompt to be encoded
|
198 |
+
device: (`torch.device`):
|
199 |
+
torch device
|
200 |
+
num_images_per_prompt (`int`):
|
201 |
+
number of images that should be generated per prompt
|
202 |
+
do_classifier_free_guidance (`bool`):
|
203 |
+
whether to use classifier free guidance or not
|
204 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
205 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
206 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
207 |
+
less than `1`).
|
208 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
209 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
210 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
211 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
212 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
213 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
214 |
+
argument.
|
215 |
+
"""
|
216 |
+
if prompt is not None and isinstance(prompt, str):
|
217 |
+
batch_size = 1
|
218 |
+
elif prompt is not None and isinstance(prompt, list):
|
219 |
+
batch_size = len(prompt)
|
220 |
+
else:
|
221 |
+
batch_size = prompt_embeds.shape[0]
|
222 |
+
|
223 |
+
if prompt_embeds is None:
|
224 |
+
# textual inversion: process multi-vector tokens if necessary
|
225 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
226 |
+
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
227 |
+
|
228 |
+
text_inputs = self.tokenizer(
|
229 |
+
prompt,
|
230 |
+
padding="max_length",
|
231 |
+
max_length=self.tokenizer.model_max_length,
|
232 |
+
truncation=True,
|
233 |
+
return_tensors="pt",
|
234 |
+
)
|
235 |
+
text_input_ids = text_inputs.input_ids
|
236 |
+
untruncated_ids = self.tokenizer(
|
237 |
+
prompt, padding="longest", return_tensors="pt"
|
238 |
+
).input_ids
|
239 |
+
|
240 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[
|
241 |
+
-1
|
242 |
+
] and not torch.equal(text_input_ids, untruncated_ids):
|
243 |
+
removed_text = self.tokenizer.batch_decode(
|
244 |
+
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
245 |
+
)
|
246 |
+
logger.warning(
|
247 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
248 |
+
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
249 |
+
)
|
250 |
+
|
251 |
+
if (
|
252 |
+
hasattr(self.text_encoder.config, "use_attention_mask")
|
253 |
+
and self.text_encoder.config.use_attention_mask
|
254 |
+
):
|
255 |
+
attention_mask = text_inputs.attention_mask.to(device)
|
256 |
+
else:
|
257 |
+
attention_mask = None
|
258 |
+
|
259 |
+
prompt_embeds = self.text_encoder(
|
260 |
+
text_input_ids.to(device),
|
261 |
+
attention_mask=attention_mask,
|
262 |
+
)
|
263 |
+
prompt_embeds = prompt_embeds[0]
|
264 |
+
|
265 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
266 |
+
|
267 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
268 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
269 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
270 |
+
prompt_embeds = prompt_embeds.view(
|
271 |
+
bs_embed * num_images_per_prompt, seq_len, -1
|
272 |
+
)
|
273 |
+
|
274 |
+
# get unconditional embeddings for classifier free guidance
|
275 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
276 |
+
uncond_tokens: List[str]
|
277 |
+
if negative_prompt is None:
|
278 |
+
uncond_tokens = [""] * batch_size
|
279 |
+
elif type(prompt) is not type(negative_prompt):
|
280 |
+
raise TypeError(
|
281 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
282 |
+
f" {type(prompt)}."
|
283 |
+
)
|
284 |
+
elif isinstance(negative_prompt, str):
|
285 |
+
uncond_tokens = [negative_prompt]
|
286 |
+
elif batch_size != len(negative_prompt):
|
287 |
+
raise ValueError(
|
288 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
289 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
290 |
+
" the batch size of `prompt`."
|
291 |
+
)
|
292 |
+
else:
|
293 |
+
uncond_tokens = negative_prompt
|
294 |
+
|
295 |
+
# textual inversion: process multi-vector tokens if necessary
|
296 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
297 |
+
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
|
298 |
+
|
299 |
+
max_length = prompt_embeds.shape[1]
|
300 |
+
uncond_input = self.tokenizer(
|
301 |
+
uncond_tokens,
|
302 |
+
padding="max_length",
|
303 |
+
max_length=max_length,
|
304 |
+
truncation=True,
|
305 |
+
return_tensors="pt",
|
306 |
+
)
|
307 |
+
|
308 |
+
if (
|
309 |
+
hasattr(self.text_encoder.config, "use_attention_mask")
|
310 |
+
and self.text_encoder.config.use_attention_mask
|
311 |
+
):
|
312 |
+
attention_mask = uncond_input.attention_mask.to(device)
|
313 |
+
else:
|
314 |
+
attention_mask = None
|
315 |
+
|
316 |
+
negative_prompt_embeds = self.text_encoder(
|
317 |
+
uncond_input.input_ids.to(device),
|
318 |
+
attention_mask=attention_mask,
|
319 |
+
)
|
320 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
321 |
+
|
322 |
+
if do_classifier_free_guidance:
|
323 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
324 |
+
seq_len = negative_prompt_embeds.shape[1]
|
325 |
+
|
326 |
+
negative_prompt_embeds = negative_prompt_embeds.to(
|
327 |
+
dtype=self.text_encoder.dtype, device=device
|
328 |
+
)
|
329 |
+
|
330 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(
|
331 |
+
1, num_images_per_prompt, 1
|
332 |
+
)
|
333 |
+
negative_prompt_embeds = negative_prompt_embeds.view(
|
334 |
+
batch_size * num_images_per_prompt, seq_len, -1
|
335 |
+
)
|
336 |
+
|
337 |
+
# For classifier free guidance, we need to do two forward passes.
|
338 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
339 |
+
# to avoid doing two forward passes
|
340 |
+
# pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]
|
341 |
+
prompt_embeds = torch.cat(
|
342 |
+
[prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]
|
343 |
+
)
|
344 |
+
|
345 |
+
return prompt_embeds
|
346 |
+
|
347 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
348 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
349 |
+
# eta (Ξ·) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
350 |
+
# eta corresponds to Ξ· in DDIM paper: https://arxiv.org/abs/2010.02502
|
351 |
+
# and should be between [0, 1]
|
352 |
+
|
353 |
+
accepts_eta = "eta" in set(
|
354 |
+
inspect.signature(self.scheduler.step).parameters.keys()
|
355 |
+
)
|
356 |
+
extra_step_kwargs = {}
|
357 |
+
if accepts_eta:
|
358 |
+
extra_step_kwargs["eta"] = eta
|
359 |
+
|
360 |
+
# check if the scheduler accepts generator
|
361 |
+
accepts_generator = "generator" in set(
|
362 |
+
inspect.signature(self.scheduler.step).parameters.keys()
|
363 |
+
)
|
364 |
+
if accepts_generator:
|
365 |
+
extra_step_kwargs["generator"] = generator
|
366 |
+
return extra_step_kwargs
|
367 |
+
|
368 |
+
def check_inputs(
|
369 |
+
self,
|
370 |
+
prompt,
|
371 |
+
callback_steps,
|
372 |
+
negative_prompt=None,
|
373 |
+
prompt_embeds=None,
|
374 |
+
negative_prompt_embeds=None,
|
375 |
+
):
|
376 |
+
if (callback_steps is None) or (
|
377 |
+
callback_steps is not None
|
378 |
+
and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
379 |
+
):
|
380 |
+
raise ValueError(
|
381 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
382 |
+
f" {type(callback_steps)}."
|
383 |
+
)
|
384 |
+
|
385 |
+
if prompt is not None and prompt_embeds is not None:
|
386 |
+
raise ValueError(
|
387 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
388 |
+
" only forward one of the two."
|
389 |
+
)
|
390 |
+
elif prompt is None and prompt_embeds is None:
|
391 |
+
raise ValueError(
|
392 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
393 |
+
)
|
394 |
+
elif prompt is not None and (
|
395 |
+
not isinstance(prompt, str) and not isinstance(prompt, list)
|
396 |
+
):
|
397 |
+
raise ValueError(
|
398 |
+
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
|
399 |
+
)
|
400 |
+
|
401 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
402 |
+
raise ValueError(
|
403 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
404 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
405 |
+
)
|
406 |
+
|
407 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
408 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
409 |
+
raise ValueError(
|
410 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
411 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
412 |
+
f" {negative_prompt_embeds.shape}."
|
413 |
+
)
|
414 |
+
|
415 |
+
def prepare_latents(
|
416 |
+
self,
|
417 |
+
batch_size,
|
418 |
+
num_channels_latents,
|
419 |
+
height,
|
420 |
+
width,
|
421 |
+
dtype,
|
422 |
+
device,
|
423 |
+
generator,
|
424 |
+
latents=None,
|
425 |
+
):
|
426 |
+
shape = (
|
427 |
+
batch_size,
|
428 |
+
num_channels_latents,
|
429 |
+
height // self.vae_scale_factor,
|
430 |
+
width // self.vae_scale_factor,
|
431 |
+
)
|
432 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
433 |
+
raise ValueError(
|
434 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
435 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
436 |
+
)
|
437 |
+
|
438 |
+
if latents is None:
|
439 |
+
latents = randn_tensor(
|
440 |
+
shape, generator=generator, device=device, dtype=dtype
|
441 |
+
)
|
442 |
+
else:
|
443 |
+
latents = latents.to(device)
|
444 |
+
|
445 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
446 |
+
latents = latents * self.scheduler.init_noise_sigma
|
447 |
+
return latents
|
448 |
+
|
449 |
+
def prepare_image_latents(
|
450 |
+
self,
|
451 |
+
image,
|
452 |
+
batch_size,
|
453 |
+
num_images_per_prompt,
|
454 |
+
dtype,
|
455 |
+
device,
|
456 |
+
do_classifier_free_guidance,
|
457 |
+
generator=None,
|
458 |
+
):
|
459 |
+
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
|
460 |
+
raise ValueError(
|
461 |
+
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
|
462 |
+
)
|
463 |
+
|
464 |
+
image = image.to(device=device, dtype=dtype)
|
465 |
+
|
466 |
+
batch_size = batch_size * num_images_per_prompt
|
467 |
+
|
468 |
+
if image.shape[1] == 4:
|
469 |
+
image_latents = image
|
470 |
+
else:
|
471 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
472 |
+
raise ValueError(
|
473 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
474 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
475 |
+
)
|
476 |
+
|
477 |
+
if isinstance(generator, list):
|
478 |
+
image_latents = [
|
479 |
+
self.vae.encode(image[i : i + 1]).latent_dist.mode()
|
480 |
+
for i in range(batch_size)
|
481 |
+
]
|
482 |
+
image_latents = torch.cat(image_latents, dim=0)
|
483 |
+
else:
|
484 |
+
image_latents = self.vae.encode(image).latent_dist.mode()
|
485 |
+
|
486 |
+
if (
|
487 |
+
batch_size > image_latents.shape[0]
|
488 |
+
and batch_size % image_latents.shape[0] == 0
|
489 |
+
):
|
490 |
+
# expand image_latents for batch_size
|
491 |
+
deprecation_message = (
|
492 |
+
f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
|
493 |
+
" images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
|
494 |
+
" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
|
495 |
+
" your script to pass as many initial images as text prompts to suppress this warning."
|
496 |
+
)
|
497 |
+
deprecate(
|
498 |
+
"len(prompt) != len(image)",
|
499 |
+
"1.0.0",
|
500 |
+
deprecation_message,
|
501 |
+
standard_warn=False,
|
502 |
+
)
|
503 |
+
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
504 |
+
image_latents = torch.cat(
|
505 |
+
[image_latents] * additional_image_per_prompt, dim=0
|
506 |
+
)
|
507 |
+
elif (
|
508 |
+
batch_size > image_latents.shape[0]
|
509 |
+
and batch_size % image_latents.shape[0] != 0
|
510 |
+
):
|
511 |
+
raise ValueError(
|
512 |
+
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
513 |
+
)
|
514 |
+
else:
|
515 |
+
image_latents = torch.cat([image_latents], dim=0)
|
516 |
+
|
517 |
+
if do_classifier_free_guidance:
|
518 |
+
uncond_image_latents = torch.zeros_like(image_latents)
|
519 |
+
image_latents = torch.cat(
|
520 |
+
[image_latents, image_latents, uncond_image_latents], dim=0
|
521 |
+
)
|
522 |
+
|
523 |
+
return image_latents
|
524 |
+
|
525 |
+
@torch.no_grad()
|
526 |
+
def __call__(
|
527 |
+
self,
|
528 |
+
height: int,
|
529 |
+
width: int,
|
530 |
+
prompt: Union[str, List[str]] = None,
|
531 |
+
albedo: Optional[
|
532 |
+
Union[
|
533 |
+
torch.FloatTensor,
|
534 |
+
PIL.Image.Image,
|
535 |
+
np.ndarray,
|
536 |
+
List[torch.FloatTensor],
|
537 |
+
List[PIL.Image.Image],
|
538 |
+
List[np.ndarray],
|
539 |
+
]
|
540 |
+
] = None,
|
541 |
+
normal: Optional[
|
542 |
+
Union[
|
543 |
+
torch.FloatTensor,
|
544 |
+
PIL.Image.Image,
|
545 |
+
np.ndarray,
|
546 |
+
List[torch.FloatTensor],
|
547 |
+
List[PIL.Image.Image],
|
548 |
+
List[np.ndarray],
|
549 |
+
]
|
550 |
+
] = None,
|
551 |
+
roughness: Optional[
|
552 |
+
Union[
|
553 |
+
torch.FloatTensor,
|
554 |
+
PIL.Image.Image,
|
555 |
+
np.ndarray,
|
556 |
+
List[torch.FloatTensor],
|
557 |
+
List[PIL.Image.Image],
|
558 |
+
List[np.ndarray],
|
559 |
+
]
|
560 |
+
] = None,
|
561 |
+
metallic: Optional[
|
562 |
+
Union[
|
563 |
+
torch.FloatTensor,
|
564 |
+
PIL.Image.Image,
|
565 |
+
np.ndarray,
|
566 |
+
List[torch.FloatTensor],
|
567 |
+
List[PIL.Image.Image],
|
568 |
+
List[np.ndarray],
|
569 |
+
]
|
570 |
+
] = None,
|
571 |
+
irradiance: Optional[
|
572 |
+
Union[
|
573 |
+
torch.FloatTensor,
|
574 |
+
PIL.Image.Image,
|
575 |
+
np.ndarray,
|
576 |
+
List[torch.FloatTensor],
|
577 |
+
List[PIL.Image.Image],
|
578 |
+
List[np.ndarray],
|
579 |
+
]
|
580 |
+
] = None,
|
581 |
+
guidance_scale: float = 0.0,
|
582 |
+
image_guidance_scale: float = 0.0,
|
583 |
+
guidance_rescale: float = 0.0,
|
584 |
+
num_inference_steps: int = 100,
|
585 |
+
required_aovs: List[str] = ["albedo"],
|
586 |
+
return_predicted_x0s: bool = False,
|
587 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
588 |
+
num_images_per_prompt: Optional[int] = 1,
|
589 |
+
eta: float = 0.0,
|
590 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
591 |
+
latents: Optional[torch.FloatTensor] = None,
|
592 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
593 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
594 |
+
output_type: Optional[str] = "pil",
|
595 |
+
return_dict: bool = True,
|
596 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
597 |
+
callback_steps: int = 1,
|
598 |
+
):
|
599 |
+
r"""
|
600 |
+
The call function to the pipeline for generation.
|
601 |
+
|
602 |
+
Args:
|
603 |
+
prompt (`str` or `List[str]`, *optional*):
|
604 |
+
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
605 |
+
image (`torch.FloatTensor`, `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
606 |
+
`Image` or tensor representing an image batch to be repainted according to `prompt`. Can also accept
|
607 |
+
image latents as `image`, but if passing latents directly it is not encoded again.
|
608 |
+
num_inference_steps (`int`, *optional*, defaults to 100):
|
609 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
610 |
+
expense of slower inference.
|
611 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
612 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
613 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
614 |
+
image_guidance_scale (`float`, *optional*, defaults to 1.5):
|
615 |
+
Push the generated image towards the initial `image`. Image guidance scale is enabled by setting
|
616 |
+
`image_guidance_scale > 1`. Higher image guidance scale encourages generated images that are closely
|
617 |
+
linked to the source `image`, usually at the expense of lower image quality. This pipeline requires a
|
618 |
+
value of at least `1`.
|
619 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
620 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
621 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
622 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
623 |
+
The number of images to generate per prompt.
|
624 |
+
eta (`float`, *optional*, defaults to 0.0):
|
625 |
+
Corresponds to parameter eta (Ξ·) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
626 |
+
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
627 |
+
generator (`torch.Generator`, *optional*):
|
628 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
629 |
+
generation deterministic.
|
630 |
+
latents (`torch.FloatTensor`, *optional*):
|
631 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
632 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
633 |
+
tensor is generated by sampling using the supplied random `generator`.
|
634 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
635 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
636 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
637 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
638 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
639 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
640 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
641 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
642 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
643 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
644 |
+
plain tuple.
|
645 |
+
callback (`Callable`, *optional*):
|
646 |
+
A function that is called every `callback_steps` steps during inference. The function is called with the
|
647 |
+
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
648 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
649 |
+
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
650 |
+
every step.
|
651 |
+
|
652 |
+
Examples:
|
653 |
+
|
654 |
+
```py
|
655 |
+
>>> import PIL
|
656 |
+
>>> import requests
|
657 |
+
>>> import torch
|
658 |
+
>>> from io import BytesIO
|
659 |
+
|
660 |
+
>>> from diffusers import StableDiffusionInstructPix2PixPipeline
|
661 |
+
|
662 |
+
|
663 |
+
>>> def download_image(url):
|
664 |
+
... response = requests.get(url)
|
665 |
+
... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
|
666 |
+
|
667 |
+
|
668 |
+
>>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png"
|
669 |
+
|
670 |
+
>>> image = download_image(img_url).resize((512, 512))
|
671 |
+
|
672 |
+
>>> pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
|
673 |
+
... "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
|
674 |
+
... )
|
675 |
+
>>> pipe = pipe.to("cuda")
|
676 |
+
|
677 |
+
>>> prompt = "make the mountains snowy"
|
678 |
+
>>> image = pipe(prompt=prompt, image=image).images[0]
|
679 |
+
```
|
680 |
+
|
681 |
+
Returns:
|
682 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
683 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
684 |
+
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
685 |
+
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
686 |
+
"not-safe-for-work" (nsfw) content.
|
687 |
+
"""
|
688 |
+
# 0. Check inputs
|
689 |
+
self.check_inputs(
|
690 |
+
prompt,
|
691 |
+
callback_steps,
|
692 |
+
negative_prompt,
|
693 |
+
prompt_embeds,
|
694 |
+
negative_prompt_embeds,
|
695 |
+
)
|
696 |
+
|
697 |
+
# 1. Define call parameters
|
698 |
+
if prompt is not None and isinstance(prompt, str):
|
699 |
+
batch_size = 1
|
700 |
+
elif prompt is not None and isinstance(prompt, list):
|
701 |
+
batch_size = len(prompt)
|
702 |
+
else:
|
703 |
+
batch_size = prompt_embeds.shape[0]
|
704 |
+
|
705 |
+
device = self._execution_device
|
706 |
+
do_classifier_free_guidance = (
|
707 |
+
guidance_scale >= 1.0 and image_guidance_scale >= 1.0
|
708 |
+
)
|
709 |
+
# check if scheduler is in sigmas space
|
710 |
+
scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
|
711 |
+
|
712 |
+
# 2. Encode input prompt
|
713 |
+
prompt_embeds = self._encode_prompt(
|
714 |
+
prompt,
|
715 |
+
device,
|
716 |
+
num_images_per_prompt,
|
717 |
+
do_classifier_free_guidance,
|
718 |
+
negative_prompt,
|
719 |
+
prompt_embeds=prompt_embeds,
|
720 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
721 |
+
)
|
722 |
+
|
723 |
+
# 3. Preprocess image
|
724 |
+
# For normal, the preprocessing does nothing
|
725 |
+
# For others, the preprocessing remap the values to [-1, 1]
|
726 |
+
preprocessed_aovs = {}
|
727 |
+
for aov_name in required_aovs:
|
728 |
+
if aov_name == "albedo":
|
729 |
+
if albedo is not None:
|
730 |
+
preprocessed_aovs[aov_name] = self.image_processor.preprocess(
|
731 |
+
albedo
|
732 |
+
)
|
733 |
+
else:
|
734 |
+
preprocessed_aovs[aov_name] = None
|
735 |
+
|
736 |
+
if aov_name == "normal":
|
737 |
+
if normal is not None:
|
738 |
+
preprocessed_aovs[aov_name] = (
|
739 |
+
self.image_processor.preprocess_normal(normal)
|
740 |
+
)
|
741 |
+
else:
|
742 |
+
preprocessed_aovs[aov_name] = None
|
743 |
+
|
744 |
+
if aov_name == "roughness":
|
745 |
+
if roughness is not None:
|
746 |
+
preprocessed_aovs[aov_name] = self.image_processor.preprocess(
|
747 |
+
roughness
|
748 |
+
)
|
749 |
+
else:
|
750 |
+
preprocessed_aovs[aov_name] = None
|
751 |
+
if aov_name == "metallic":
|
752 |
+
if metallic is not None:
|
753 |
+
preprocessed_aovs[aov_name] = self.image_processor.preprocess(
|
754 |
+
metallic
|
755 |
+
)
|
756 |
+
else:
|
757 |
+
preprocessed_aovs[aov_name] = None
|
758 |
+
if aov_name == "irradiance":
|
759 |
+
if irradiance is not None:
|
760 |
+
preprocessed_aovs[aov_name] = self.image_processor.preprocess(
|
761 |
+
irradiance
|
762 |
+
)
|
763 |
+
else:
|
764 |
+
preprocessed_aovs[aov_name] = None
|
765 |
+
|
766 |
+
# 4. set timesteps
|
767 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
768 |
+
timesteps = self.scheduler.timesteps
|
769 |
+
|
770 |
+
# 5. Prepare latent variables
|
771 |
+
num_channels_latents = self.vae.config.latent_channels
|
772 |
+
latents = self.prepare_latents(
|
773 |
+
batch_size * num_images_per_prompt,
|
774 |
+
num_channels_latents,
|
775 |
+
height,
|
776 |
+
width,
|
777 |
+
prompt_embeds.dtype,
|
778 |
+
device,
|
779 |
+
generator,
|
780 |
+
latents,
|
781 |
+
)
|
782 |
+
|
783 |
+
height_latent, width_latent = latents.shape[-2:]
|
784 |
+
|
785 |
+
# 6. Prepare Image latents
|
786 |
+
image_latents = []
|
787 |
+
# Per-AOV latent scaling factors (computed from the training data)
|
788 |
+
scaling_factors = {
|
789 |
+
"albedo": 0.17301377137652138,
|
790 |
+
"normal": 0.17483895473058078,
|
791 |
+
"roughness": 0.1680724853626448,
|
792 |
+
"metallic": 0.13135013390855135,
|
793 |
+
}
|
794 |
+
for aov_name, aov in preprocessed_aovs.items():
|
795 |
+
if aov is None:
|
796 |
+
image_latent = torch.zeros(
|
797 |
+
batch_size,
|
798 |
+
num_channels_latents,
|
799 |
+
height_latent,
|
800 |
+
width_latent,
|
801 |
+
dtype=prompt_embeds.dtype,
|
802 |
+
device=device,
|
803 |
+
)
|
804 |
+
if aov_name == "irradiance":
|
805 |
+
image_latent = image_latent[:, 0:3]
|
806 |
+
if do_classifier_free_guidance:
|
807 |
+
image_latents.append(
|
808 |
+
torch.cat([image_latent, image_latent, image_latent], dim=0)
|
809 |
+
)
|
810 |
+
else:
|
811 |
+
image_latents.append(image_latent)
|
812 |
+
else:
|
813 |
+
if aov_name == "irradiance":
|
814 |
+
image_latent = F.interpolate(
|
815 |
+
aov.to(device=device, dtype=prompt_embeds.dtype),
|
816 |
+
size=(height_latent, width_latent),
|
817 |
+
mode="bilinear",
|
818 |
+
align_corners=False,
|
819 |
+
antialias=True,
|
820 |
+
)
|
821 |
+
if do_classifier_free_guidance:
|
822 |
+
uncond_image_latent = torch.zeros_like(image_latent)
|
823 |
+
image_latent = torch.cat(
|
824 |
+
[image_latent, image_latent, uncond_image_latent], dim=0
|
825 |
+
)
|
826 |
+
else:
|
827 |
+
scaling_factor = scaling_factors[aov_name]
|
828 |
+
image_latent = (
|
829 |
+
self.prepare_image_latents(
|
830 |
+
aov,
|
831 |
+
batch_size,
|
832 |
+
num_images_per_prompt,
|
833 |
+
prompt_embeds.dtype,
|
834 |
+
device,
|
835 |
+
do_classifier_free_guidance,
|
836 |
+
generator,
|
837 |
+
)
|
838 |
+
* scaling_factor
|
839 |
+
)
|
840 |
+
image_latents.append(image_latent)
|
841 |
+
image_latents = torch.cat(image_latents, dim=1)
|
842 |
+
|
843 |
+
# 7. Check that shapes of latents and image match the UNet channels
|
844 |
+
num_channels_image = image_latents.shape[1]
|
845 |
+
if num_channels_latents + num_channels_image != self.unet.config.in_channels:
|
846 |
+
raise ValueError(
|
847 |
+
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
848 |
+
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
849 |
+
f" `num_channels_image`: {num_channels_image} "
|
850 |
+
f" = {num_channels_latents+num_channels_image}. Please verify the config of"
|
851 |
+
" `pipeline.unet` or your `image` input."
|
852 |
+
)
|
853 |
+
|
854 |
+
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
855 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
856 |
+
|
857 |
+
predicted_x0s = []
|
858 |
+
|
859 |
+
# 9. Denoising loop
|
860 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
861 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
862 |
+
for i, t in enumerate(timesteps):
|
863 |
+
# Expand the latents if we are doing classifier free guidance.
|
864 |
+
# The latents are expanded 3 times because for pix2pix the guidance\
|
865 |
+
# is applied for both the text and the input image.
|
866 |
+
latent_model_input = (
|
867 |
+
torch.cat([latents] * 3) if do_classifier_free_guidance else latents
|
868 |
+
)
|
869 |
+
|
870 |
+
# concat latents, image_latents in the channel dimension
|
871 |
+
scaled_latent_model_input = self.scheduler.scale_model_input(
|
872 |
+
latent_model_input, t
|
873 |
+
)
|
874 |
+
scaled_latent_model_input = torch.cat(
|
875 |
+
[scaled_latent_model_input, image_latents], dim=1
|
876 |
+
)
|
877 |
+
|
878 |
+
# predict the noise residual
|
879 |
+
noise_pred = self.unet(
|
880 |
+
scaled_latent_model_input,
|
881 |
+
t,
|
882 |
+
encoder_hidden_states=prompt_embeds,
|
883 |
+
return_dict=False,
|
884 |
+
)[0]
|
885 |
+
|
886 |
+
# perform guidance
|
887 |
+
if do_classifier_free_guidance:
|
888 |
+
(
|
889 |
+
noise_pred_text,
|
890 |
+
noise_pred_image,
|
891 |
+
noise_pred_uncond,
|
892 |
+
) = noise_pred.chunk(3)
|
893 |
+
noise_pred = (
|
894 |
+
noise_pred_uncond
|
895 |
+
+ guidance_scale * (noise_pred_text - noise_pred_image)
|
896 |
+
+ image_guidance_scale * (noise_pred_image - noise_pred_uncond)
|
897 |
+
)
|
898 |
+
|
899 |
+
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
900 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
901 |
+
noise_pred = rescale_noise_cfg(
|
902 |
+
noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
|
903 |
+
)
|
904 |
+
|
905 |
+
# compute the previous noisy sample x_t -> x_t-1
|
906 |
+
output = self.scheduler.step(
|
907 |
+
noise_pred, t, latents, **extra_step_kwargs, return_dict=True
|
908 |
+
)
|
909 |
+
|
910 |
+
latents = output[0]
|
911 |
+
|
912 |
+
if return_predicted_x0s:
|
913 |
+
predicted_x0s.append(output[1])
|
914 |
+
|
915 |
+
# call the callback, if provided
|
916 |
+
if i == len(timesteps) - 1 or (
|
917 |
+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
918 |
+
):
|
919 |
+
progress_bar.update()
|
920 |
+
if callback is not None and i % callback_steps == 0:
|
921 |
+
callback(i, t, latents)
|
922 |
+
|
923 |
+
if not output_type == "latent":
|
924 |
+
image = self.vae.decode(
|
925 |
+
latents / self.vae.config.scaling_factor, return_dict=False
|
926 |
+
)[0]
|
927 |
+
|
928 |
+
if return_predicted_x0s:
|
929 |
+
predicted_x0_images = [
|
930 |
+
self.vae.decode(
|
931 |
+
predicted_x0 / self.vae.config.scaling_factor, return_dict=False
|
932 |
+
)[0]
|
933 |
+
for predicted_x0 in predicted_x0s
|
934 |
+
]
|
935 |
+
else:
|
936 |
+
image = latents
|
937 |
+
predicted_x0_images = predicted_x0s
|
938 |
+
|
939 |
+
do_denormalize = [True] * image.shape[0]
|
940 |
+
|
941 |
+
image = self.image_processor.postprocess(
|
942 |
+
image, output_type=output_type, do_denormalize=do_denormalize
|
943 |
+
)
|
944 |
+
|
945 |
+
if return_predicted_x0s:
|
946 |
+
predicted_x0_images = [
|
947 |
+
self.image_processor.postprocess(
|
948 |
+
predicted_x0_image,
|
949 |
+
output_type=output_type,
|
950 |
+
do_denormalize=do_denormalize,
|
951 |
+
)
|
952 |
+
for predicted_x0_image in predicted_x0_images
|
953 |
+
]
|
954 |
+
|
955 |
+
# Offload last model to CPU
|
956 |
+
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
957 |
+
self.final_offload_hook.offload()
|
958 |
+
|
959 |
+
if not return_dict:
|
960 |
+
return image
|
961 |
+
|
962 |
+
if return_predicted_x0s:
|
963 |
+
return StableDiffusionAOVPipelineOutput(
|
964 |
+
images=image, predicted_x0_images=predicted_x0_images
|
965 |
+
)
|
966 |
+
else:
|
967 |
+
return StableDiffusionAOVPipelineOutput(images=image)
|
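
The guidance step in the denoising loop above is the InstructPix2Pix-style three-way classifier-free guidance: the UNet runs on a tripled batch and its output is split into a text-conditioned, an AOV-conditioned, and an unconditional prediction, which are recombined with two separate scales. Below is a minimal standalone sketch of that recombination; the tensor shape and the two scale values are illustrative assumptions, not values taken from this repository.

import torch

# Hypothetical UNet output for a batch that was tripled for classifier-free
# guidance, stacked as [text+AOV conditioned, AOV conditioned, unconditional].
noise_pred = torch.randn(3, 4, 64, 64)

guidance_scale = 7.5        # text guidance weight (assumed value)
image_guidance_scale = 1.5  # AOV/image guidance weight (assumed value)

# Split along the batch dimension, in the same order the latents were stacked.
noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)

# Same combination as in the denoising loop: start from the unconditional
# prediction, add the text direction relative to the AOV-conditioned
# prediction, then the AOV direction relative to the unconditional one.
guided = (
    noise_pred_uncond
    + guidance_scale * (noise_pred_text - noise_pred_image)
    + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
)
print(guided.shape)  # torch.Size([1, 4, 64, 64])

Setting both guidance_scale and image_guidance_scale to 1.0 collapses the expression to noise_pred_text, i.e. the plain conditional prediction with no extra guidance.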