Commit 7c0f7b6 committed by blanchon
Parent: cebcba2

Remove X2RGB and add examples
rgb2x/gradio_demo_rgb2x.py CHANGED
@@ -141,6 +141,18 @@ with gr.Blocks() as demo:
                 elem_id="gallery",
                 columns=2,
             )
+            examples = gr.Examples(
+                examples=[
+                    [
+                        "rgb2x/example/Castlereagh_corridor_photo.png",
+                    ]
+                ],
+                inputs=[photo],
+                outputs=[result_gallery],
+                fn=generate,
+                cache_mode="eager",
+                cache_examples=True,
+            )
 
         run_button.click(
             fn=generate,
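For context, a minimal self-contained sketch of how the added gr.Examples block behaves, assuming a Gradio version that supports cache_mode (the photo/result_gallery components and the trivial generate stub below are placeholders, not the Space's real implementation): with cache_examples=True and cache_mode="eager", Gradio runs fn on every example row at launch and serves the cached gallery output when a user clicks the example.

    import gradio as gr

    def generate(photo):
        # Placeholder for the Space's real generate() function, which returns
        # a list of intrinsic-channel images for the gallery.
        return [photo]

    with gr.Blocks() as demo:
        photo = gr.Image(type="filepath", label="Photo")        # assumed input component
        result_gallery = gr.Gallery(label="Output", columns=2)  # assumed output component
        gr.Examples(
            # Path taken from the commit; swap in any local image to run this sketch.
            examples=[["rgb2x/example/Castlereagh_corridor_photo.png"]],
            inputs=[photo],
            outputs=[result_gallery],
            fn=generate,
            cache_mode="eager",      # pre-compute example outputs at startup
            cache_examples=True,
        )

    if __name__ == "__main__":
        demo.launch()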
x2rgb/example/kitchen-albedo.png DELETED

Git LFS Details

  • SHA256: d2b3e2ae5001c4214d5c87041e57933708cbb424eca6f7a2659c2c3e91a6a8ce
  • Pointer size: 131 Bytes
  • Size of remote file: 313 kB
x2rgb/example/kitchen-irradiance.png DELETED

Git LFS Details

  • SHA256: 259b873bba6405d72a87a30321f4572a47d296726368eba2f0303d4ab3bcd269
  • Pointer size: 131 Bytes
  • Size of remote file: 959 kB
x2rgb/example/kitchen-metallic.png DELETED

Git LFS Details

  • SHA256: cd6fec250659c8915c821b9063b851da4da59c2459c56fa08338cb81c5e6b70d
  • Pointer size: 130 Bytes
  • Size of remote file: 33.9 kB
x2rgb/example/kitchen-normal.png DELETED

Git LFS Details

  • SHA256: abf769887d2ee8fa050f56f50285502fcf8dbb8b69c28c1f3910ad1ee2874068
  • Pointer size: 131 Bytes
  • Size of remote file: 415 kB
x2rgb/example/kitchen-ref.png DELETED

Git LFS Details

  • SHA256: 19e57fc6737291cb59611786c9894fd4c2bedb0ba14b875942241195afff3534
  • Pointer size: 132 Bytes
  • Size of remote file: 1.04 MB
x2rgb/example/kitchen-roughness.png DELETED

Git LFS Details

  • SHA256: 9d1195686031d170151798b00d095c48a40e1a8a508d65c5a841fcabb0ae8fad
  • Pointer size: 130 Bytes
  • Size of remote file: 84 kB
x2rgb/gradio_demo_x2rgb.py DELETED
@@ -1,204 +0,0 @@
1
- import spaces
2
- import os
3
- from typing import cast
4
- import gradio as gr
5
- import numpy as np
6
- import torch
7
- from PIL import Image
8
- from diffusers import DDIMScheduler
9
- from load_image import load_exr_image, load_ldr_image
10
- from pipeline_x2rgb import StableDiffusionAOVDropoutPipeline
11
-
12
- os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
13
-
14
- current_directory = os.path.dirname(os.path.abspath(__file__))
15
-
16
- _pipe = StableDiffusionAOVDropoutPipeline.from_pretrained(
17
- "zheng95z/x-to-rgb",
18
- torch_dtype=torch.float16,
19
- cache_dir=os.path.join(current_directory, "model_cache"),
20
- ).to("cuda")
21
- pipe = cast(StableDiffusionAOVDropoutPipeline, _pipe)
22
- pipe.scheduler = DDIMScheduler.from_config(
23
- pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
24
- )
25
- pipe.set_progress_bar_config(disable=True)
26
- pipe.to("cuda")
27
- pipe = cast(StableDiffusionAOVDropoutPipeline, pipe)
28
-
29
-
30
- @spaces.GPU
31
- def generate(
32
- albedo,
33
- normal,
34
- roughness,
35
- metallic,
36
- irradiance,
37
- prompt: str,
38
- seed: int,
39
- inference_step: int,
40
- num_samples: int,
41
- guidance_scale: float,
42
- image_guidance_scale: float,
43
- ) -> list[Image.Image]:
44
- generator = torch.Generator(device="cuda").manual_seed(seed)
45
-
46
- # Load and process each intrinsic channel image
47
- def process_image(file, **kwargs):
48
- if file is None:
49
- return None
50
- if file.name.endswith(".exr"):
51
- return load_exr_image(file.name, **kwargs).to("cuda")
52
- elif file.name.endswith((".png", ".jpg", ".jpeg")):
53
- return load_ldr_image(file.name, **kwargs).to("cuda")
54
- return None
55
-
56
- albedo_image = process_image(albedo, clamp=True)
57
- normal_image = process_image(normal, normalize=True)
58
- roughness_image = process_image(roughness, clamp=True)
59
- metallic_image = process_image(metallic, clamp=True)
60
- irradiance_image = process_image(irradiance, tonemaping=True, clamp=True)
61
-
62
- # Set default height and width based on the first available image
63
- height, width = 768, 768
64
- for img in [
65
- albedo_image,
66
- normal_image,
67
- roughness_image,
68
- metallic_image,
69
- irradiance_image,
70
- ]:
71
- if img is not None:
72
- height, width = img.shape[1], img.shape[2]
73
- break
74
-
75
- required_aovs = ["albedo", "normal", "roughness", "metallic", "irradiance"]
76
- return_list = []
77
-
78
- for i in range(num_samples):
79
- generated_image = pipe(
80
- prompt=prompt,
81
- albedo=albedo_image,
82
- normal=normal_image,
83
- roughness=roughness_image,
84
- metallic=metallic_image,
85
- irradiance=irradiance_image,
86
- num_inference_steps=inference_step,
87
- height=height,
88
- width=width,
89
- generator=generator,
90
- required_aovs=required_aovs,
91
- guidance_scale=guidance_scale,
92
- image_guidance_scale=image_guidance_scale,
93
- guidance_rescale=0.7,
94
- output_type="np",
95
- ).images[0] # type: ignore
96
-
97
- return_list.append((generated_image, f"Generated Image {i}"))
98
-
99
- # Append additional images to the output gallery
100
- def post_process_image(img, **kwargs):
101
- if img is not None:
102
- return (img.cpu().numpy().transpose(1, 2, 0), kwargs.get("label", "Image"))
103
- return np.zeros((height, width, 3))
104
-
105
- return_list.extend(
106
- [
107
- post_process_image(albedo_image, label="Albedo"),
108
- post_process_image(normal_image, label="Normal"),
109
- post_process_image(roughness_image, label="Roughness"),
110
- post_process_image(metallic_image, label="Metallic"),
111
- post_process_image(irradiance_image, label="Irradiance"),
112
- ]
113
- )
114
-
115
- return return_list
116
-
117
-
118
- with gr.Blocks() as demo:
119
- with gr.Row():
120
- gr.Markdown("## Model X -> RGB (Intrinsic channels -> realistic image)")
121
- with gr.Row():
122
- # Input side
123
- with gr.Column():
124
- gr.Markdown("### Given intrinsic channels")
125
- albedo = gr.File(label="Albedo", file_types=[".exr", ".png", ".jpg"])
126
- normal = gr.File(label="Normal", file_types=[".exr", ".png", ".jpg"])
127
- roughness = gr.File(label="Roughness", file_types=[".exr", ".png", ".jpg"])
128
- metallic = gr.File(label="Metallic", file_types=[".exr", ".png", ".jpg"])
129
- irradiance = gr.File(
130
- label="Irradiance", file_types=[".exr", ".png", ".jpg"]
131
- )
132
-
133
- gr.Markdown("### Parameters")
134
- prompt = gr.Textbox(label="Prompt")
135
- run_button = gr.Button(value="Run")
136
- with gr.Accordion("Advanced options", open=False):
137
- seed = gr.Slider(
138
- label="Seed",
139
- minimum=-1,
140
- maximum=2147483647,
141
- step=1,
142
- randomize=True,
143
- )
144
- inference_step = gr.Slider(
145
- label="Inference Step",
146
- minimum=1,
147
- maximum=100,
148
- step=1,
149
- value=50,
150
- )
151
- num_samples = gr.Slider(
152
- label="Samples",
153
- minimum=1,
154
- maximum=100,
155
- step=1,
156
- value=1,
157
- )
158
- guidance_scale = gr.Slider(
159
- label="Guidance Scale",
160
- minimum=0.0,
161
- maximum=10.0,
162
- step=0.1,
163
- value=7.5,
164
- )
165
- image_guidance_scale = gr.Slider(
166
- label="Image Guidance Scale",
167
- minimum=0.0,
168
- maximum=10.0,
169
- step=0.1,
170
- value=1.5,
171
- )
172
-
173
- # Output side
174
- with gr.Column():
175
- gr.Markdown("### Output Gallery")
176
- result_gallery = gr.Gallery(
177
- label="Output",
178
- show_label=False,
179
- elem_id="gallery",
180
- columns=2,
181
- )
182
-
183
- run_button.click(
184
- fn=generate,
185
- inputs=[
186
- albedo,
187
- normal,
188
- roughness,
189
- metallic,
190
- irradiance,
191
- prompt,
192
- seed,
193
- inference_step,
194
- num_samples,
195
- guidance_scale,
196
- image_guidance_scale,
197
- ],
198
- outputs=result_gallery,
199
- queue=True,
200
- )
201
-
202
-
203
- if __name__ == "__main__":
204
- demo.launch(debug=False, share=False, show_api=False)
x2rgb/load_image.py DELETED
@@ -1,119 +0,0 @@
1
- import os
2
-
3
- import cv2
4
- import torch
5
-
6
- os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
7
- import numpy as np
8
-
9
-
10
- def convert_rgb_2_XYZ(rgb):
11
- # Reference: https://web.archive.org/web/20191027010220/http://www.brucelindbloom.com/index.html?Eqn_RGB_XYZ_Matrix.html
12
- # rgb: (h, w, 3)
13
- # XYZ: (h, w, 3)
14
- XYZ = torch.ones_like(rgb)
15
- XYZ[:, :, 0] = (
16
- 0.4124564 * rgb[:, :, 0] + 0.3575761 * rgb[:, :, 1] + 0.1804375 * rgb[:, :, 2]
17
- )
18
- XYZ[:, :, 1] = (
19
- 0.2126729 * rgb[:, :, 0] + 0.7151522 * rgb[:, :, 1] + 0.0721750 * rgb[:, :, 2]
20
- )
21
- XYZ[:, :, 2] = (
22
- 0.0193339 * rgb[:, :, 0] + 0.1191920 * rgb[:, :, 1] + 0.9503041 * rgb[:, :, 2]
23
- )
24
- return XYZ
25
-
26
-
27
- def convert_XYZ_2_Yxy(XYZ):
28
- # XYZ: (h, w, 3)
29
- # Yxy: (h, w, 3)
30
- Yxy = torch.ones_like(XYZ)
31
- Yxy[:, :, 0] = XYZ[:, :, 1]
32
- sum = torch.sum(XYZ, dim=2)
33
- inv_sum = 1.0 / torch.clamp(sum, min=1e-4)
34
- Yxy[:, :, 1] = XYZ[:, :, 0] * inv_sum
35
- Yxy[:, :, 2] = XYZ[:, :, 1] * inv_sum
36
- return Yxy
37
-
38
-
39
- def convert_rgb_2_Yxy(rgb):
40
- # rgb: (h, w, 3)
41
- # Yxy: (h, w, 3)
42
- return convert_XYZ_2_Yxy(convert_rgb_2_XYZ(rgb))
43
-
44
-
45
- def convert_XYZ_2_rgb(XYZ):
46
- # XYZ: (h, w, 3)
47
- # rgb: (h, w, 3)
48
- rgb = torch.ones_like(XYZ)
49
- rgb[:, :, 0] = (
50
- 3.2404542 * XYZ[:, :, 0] - 1.5371385 * XYZ[:, :, 1] - 0.4985314 * XYZ[:, :, 2]
51
- )
52
- rgb[:, :, 1] = (
53
- -0.9692660 * XYZ[:, :, 0] + 1.8760108 * XYZ[:, :, 1] + 0.0415560 * XYZ[:, :, 2]
54
- )
55
- rgb[:, :, 2] = (
56
- 0.0556434 * XYZ[:, :, 0] - 0.2040259 * XYZ[:, :, 1] + 1.0572252 * XYZ[:, :, 2]
57
- )
58
- return rgb
59
-
60
-
61
- def convert_Yxy_2_XYZ(Yxy):
62
- # Yxy: (h, w, 3)
63
- # XYZ: (h, w, 3)
64
- XYZ = torch.ones_like(Yxy)
65
- XYZ[:, :, 0] = Yxy[:, :, 1] / torch.clamp(Yxy[:, :, 2], min=1e-6) * Yxy[:, :, 0]
66
- XYZ[:, :, 1] = Yxy[:, :, 0]
67
- XYZ[:, :, 2] = (
68
- (1.0 - Yxy[:, :, 1] - Yxy[:, :, 2])
69
- / torch.clamp(Yxy[:, :, 2], min=1e-4)
70
- * Yxy[:, :, 0]
71
- )
72
- return XYZ
73
-
74
-
75
- def convert_Yxy_2_rgb(Yxy):
76
- # Yxy: (h, w, 3)
77
- # rgb: (h, w, 3)
78
- return convert_XYZ_2_rgb(convert_Yxy_2_XYZ(Yxy))
79
-
80
-
81
- def load_ldr_image(image_path, from_srgb=False, clamp=False, normalize=False):
82
- # Load png or jpg image
83
- image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
84
- image = torch.from_numpy(image.astype(np.float32) / 255.0) # (h, w, c)
85
- image[~torch.isfinite(image)] = 0
86
- if from_srgb:
87
- # Convert from sRGB to linear RGB
88
- image = image**2.2
89
- if clamp:
90
- image = torch.clamp(image, min=0.0, max=1.0)
91
- if normalize:
92
- # Normalize to [-1, 1]
93
- image = image * 2.0 - 1.0
94
- image = torch.nn.functional.normalize(image, dim=-1, eps=1e-6)
95
- return image.permute(2, 0, 1) # returns (c, h, w)
96
-
97
-
98
- def load_exr_image(image_path, tonemaping=False, clamp=False, normalize=False):
99
- image = cv2.cvtColor(cv2.imread(image_path, -1), cv2.COLOR_BGR2RGB)
100
- image = torch.from_numpy(image.astype("float32")) # (h, w, c)
101
- image[~torch.isfinite(image)] = 0
102
- if tonemaping:
103
- # Exposure adjuestment
104
- image_Yxy = convert_rgb_2_Yxy(image)
105
- lum = (
106
- image[:, :, 0:1] * 0.2125
107
- + image[:, :, 1:2] * 0.7154
108
- + image[:, :, 2:3] * 0.0721
109
- )
110
- lum = torch.log(torch.clamp(lum, min=1e-6))
111
- lum_mean = torch.exp(torch.mean(lum))
112
- lp = image_Yxy[:, :, 0:1] * 0.18 / torch.clamp(lum_mean, min=1e-6)
113
- image_Yxy[:, :, 0:1] = lp
114
- image = convert_Yxy_2_rgb(image_Yxy)
115
- if clamp:
116
- image = torch.clamp(image, min=0.0, max=1.0)
117
- if normalize:
118
- image = torch.nn.functional.normalize(image, dim=-1, eps=1e-6)
119
- return image.permute(2, 0, 1) # returns (c, h, w)
x2rgb/pipeline_x2rgb.py DELETED
@@ -1,967 +0,0 @@
1
- import inspect
2
- from dataclasses import dataclass
3
- from typing import Callable, List, Optional, Union
4
-
5
- import numpy as np
6
- import PIL
7
- import torch
8
- import torch.nn.functional as F
9
- from diffusers.configuration_utils import register_to_config
10
- from diffusers.image_processor import VaeImageProcessor
11
- from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
12
- from diffusers.models import AutoencoderKL, UNet2DConditionModel
13
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
14
- from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
15
- rescale_noise_cfg,
16
- )
17
- from diffusers.schedulers import KarrasDiffusionSchedulers
18
- from diffusers.utils import CONFIG_NAME, BaseOutput, deprecate, logging, randn_tensor
19
- from transformers import CLIPTextModel, CLIPTokenizer
20
-
21
- logger = logging.get_logger(__name__)
22
-
23
-
24
- class VaeImageProcrssorAOV(VaeImageProcessor):
25
- """
26
- Image processor for VAE AOV.
27
-
28
- Args:
29
- do_resize (`bool`, *optional*, defaults to `True`):
30
- Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
31
- vae_scale_factor (`int`, *optional*, defaults to `8`):
32
- VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
33
- resample (`str`, *optional*, defaults to `lanczos`):
34
- Resampling filter to use when resizing the image.
35
- do_normalize (`bool`, *optional*, defaults to `True`):
36
- Whether to normalize the image to [-1,1].
37
- """
38
-
39
- config_name = CONFIG_NAME
40
-
41
- @register_to_config
42
- def __init__(
43
- self,
44
- do_resize: bool = True,
45
- vae_scale_factor: int = 8,
46
- resample: str = "lanczos",
47
- do_normalize: bool = True,
48
- ):
49
- super().__init__()
50
-
51
- def postprocess(
52
- self,
53
- image: torch.FloatTensor,
54
- output_type: str = "pil",
55
- do_denormalize: Optional[List[bool]] = None,
56
- do_gamma_correction: bool = True,
57
- ):
58
- if not isinstance(image, torch.Tensor):
59
- raise ValueError(
60
- f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
61
- )
62
- if output_type not in ["latent", "pt", "np", "pil"]:
63
- deprecation_message = (
64
- f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
65
- "`pil`, `np`, `pt`, `latent`"
66
- )
67
- deprecate(
68
- "Unsupported output_type",
69
- "1.0.0",
70
- deprecation_message,
71
- standard_warn=False,
72
- )
73
- output_type = "np"
74
-
75
- if output_type == "latent":
76
- return image
77
-
78
- if do_denormalize is None:
79
- do_denormalize = [self.config.do_normalize] * image.shape[0]
80
-
81
- image = torch.stack(
82
- [
83
- self.denormalize(image[i]) if do_denormalize[i] else image[i]
84
- for i in range(image.shape[0])
85
- ]
86
- )
87
-
88
- # Gamma correction
89
- if do_gamma_correction:
90
- image = torch.pow(image, 1.0 / 2.2)
91
-
92
- if output_type == "pt":
93
- return image
94
-
95
- image = self.pt_to_numpy(image)
96
-
97
- if output_type == "np":
98
- return image
99
-
100
- if output_type == "pil":
101
- return self.numpy_to_pil(image)
102
-
103
- def preprocess_normal(
104
- self,
105
- image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
106
- height: Optional[int] = None,
107
- width: Optional[int] = None,
108
- ) -> torch.Tensor:
109
- image = torch.stack([image], axis=0)
110
- return image
111
-
112
-
113
- @dataclass
114
- class StableDiffusionAOVPipelineOutput(BaseOutput):
115
- """
116
- Output class for Stable Diffusion AOV pipelines.
117
-
118
- Args:
119
- images (`List[PIL.Image.Image]` or `np.ndarray`)
120
- List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
121
- num_channels)`.
122
- nsfw_content_detected (`List[bool]`)
123
- List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or
124
- `None` if safety checking could not be performed.
125
- """
126
-
127
- images: Union[List[PIL.Image.Image], np.ndarray]
128
- predicted_x0_images: Optional[Union[List[PIL.Image.Image], np.ndarray]] = None
129
-
130
-
131
- class StableDiffusionAOVDropoutPipeline(
132
- DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin
133
- ):
134
- r"""
135
- Pipeline for AOVs.
136
-
137
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
138
- implemented for all pipelines (downloading, saving, running on a particular device, etc.).
139
-
140
- The pipeline also inherits the following loading methods:
141
- - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
142
- - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
143
- - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
144
-
145
- Args:
146
- vae ([`AutoencoderKL`]):
147
- Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
148
- text_encoder ([`~transformers.CLIPTextModel`]):
149
- Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
150
- tokenizer ([`~transformers.CLIPTokenizer`]):
151
- A `CLIPTokenizer` to tokenize text.
152
- unet ([`UNet2DConditionModel`]):
153
- A `UNet2DConditionModel` to denoise the encoded image latents.
154
- scheduler ([`SchedulerMixin`]):
155
- A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
156
- [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
157
- """
158
-
159
- def __init__(
160
- self,
161
- vae: AutoencoderKL,
162
- text_encoder: CLIPTextModel,
163
- tokenizer: CLIPTokenizer,
164
- unet: UNet2DConditionModel,
165
- scheduler: KarrasDiffusionSchedulers,
166
- ):
167
- super().__init__()
168
-
169
- self.register_modules(
170
- vae=vae,
171
- text_encoder=text_encoder,
172
- tokenizer=tokenizer,
173
- unet=unet,
174
- scheduler=scheduler,
175
- )
176
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
177
- self.image_processor = VaeImageProcrssorAOV(
178
- vae_scale_factor=self.vae_scale_factor
179
- )
180
- self.register_to_config()
181
-
182
- def _encode_prompt(
183
- self,
184
- prompt,
185
- device,
186
- num_images_per_prompt,
187
- do_classifier_free_guidance,
188
- negative_prompt=None,
189
- prompt_embeds: Optional[torch.FloatTensor] = None,
190
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
191
- ):
192
- r"""
193
- Encodes the prompt into text encoder hidden states.
194
-
195
- Args:
196
- prompt (`str` or `List[str]`, *optional*):
197
- prompt to be encoded
198
- device: (`torch.device`):
199
- torch device
200
- num_images_per_prompt (`int`):
201
- number of images that should be generated per prompt
202
- do_classifier_free_guidance (`bool`):
203
- whether to use classifier free guidance or not
204
- negative_ prompt (`str` or `List[str]`, *optional*):
205
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
206
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
207
- less than `1`).
208
- prompt_embeds (`torch.FloatTensor`, *optional*):
209
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
210
- provided, text embeddings will be generated from `prompt` input argument.
211
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
212
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
213
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
214
- argument.
215
- """
216
- if prompt is not None and isinstance(prompt, str):
217
- batch_size = 1
218
- elif prompt is not None and isinstance(prompt, list):
219
- batch_size = len(prompt)
220
- else:
221
- batch_size = prompt_embeds.shape[0]
222
-
223
- if prompt_embeds is None:
224
- # textual inversion: procecss multi-vector tokens if necessary
225
- if isinstance(self, TextualInversionLoaderMixin):
226
- prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
227
-
228
- text_inputs = self.tokenizer(
229
- prompt,
230
- padding="max_length",
231
- max_length=self.tokenizer.model_max_length,
232
- truncation=True,
233
- return_tensors="pt",
234
- )
235
- text_input_ids = text_inputs.input_ids
236
- untruncated_ids = self.tokenizer(
237
- prompt, padding="longest", return_tensors="pt"
238
- ).input_ids
239
-
240
- if untruncated_ids.shape[-1] >= text_input_ids.shape[
241
- -1
242
- ] and not torch.equal(text_input_ids, untruncated_ids):
243
- removed_text = self.tokenizer.batch_decode(
244
- untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
245
- )
246
- logger.warning(
247
- "The following part of your input was truncated because CLIP can only handle sequences up to"
248
- f" {self.tokenizer.model_max_length} tokens: {removed_text}"
249
- )
250
-
251
- if (
252
- hasattr(self.text_encoder.config, "use_attention_mask")
253
- and self.text_encoder.config.use_attention_mask
254
- ):
255
- attention_mask = text_inputs.attention_mask.to(device)
256
- else:
257
- attention_mask = None
258
-
259
- prompt_embeds = self.text_encoder(
260
- text_input_ids.to(device),
261
- attention_mask=attention_mask,
262
- )
263
- prompt_embeds = prompt_embeds[0]
264
-
265
- prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
266
-
267
- bs_embed, seq_len, _ = prompt_embeds.shape
268
- # duplicate text embeddings for each generation per prompt, using mps friendly method
269
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
270
- prompt_embeds = prompt_embeds.view(
271
- bs_embed * num_images_per_prompt, seq_len, -1
272
- )
273
-
274
- # get unconditional embeddings for classifier free guidance
275
- if do_classifier_free_guidance and negative_prompt_embeds is None:
276
- uncond_tokens: List[str]
277
- if negative_prompt is None:
278
- uncond_tokens = [""] * batch_size
279
- elif type(prompt) is not type(negative_prompt):
280
- raise TypeError(
281
- f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
282
- f" {type(prompt)}."
283
- )
284
- elif isinstance(negative_prompt, str):
285
- uncond_tokens = [negative_prompt]
286
- elif batch_size != len(negative_prompt):
287
- raise ValueError(
288
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
289
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
290
- " the batch size of `prompt`."
291
- )
292
- else:
293
- uncond_tokens = negative_prompt
294
-
295
- # textual inversion: procecss multi-vector tokens if necessary
296
- if isinstance(self, TextualInversionLoaderMixin):
297
- uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
298
-
299
- max_length = prompt_embeds.shape[1]
300
- uncond_input = self.tokenizer(
301
- uncond_tokens,
302
- padding="max_length",
303
- max_length=max_length,
304
- truncation=True,
305
- return_tensors="pt",
306
- )
307
-
308
- if (
309
- hasattr(self.text_encoder.config, "use_attention_mask")
310
- and self.text_encoder.config.use_attention_mask
311
- ):
312
- attention_mask = uncond_input.attention_mask.to(device)
313
- else:
314
- attention_mask = None
315
-
316
- negative_prompt_embeds = self.text_encoder(
317
- uncond_input.input_ids.to(device),
318
- attention_mask=attention_mask,
319
- )
320
- negative_prompt_embeds = negative_prompt_embeds[0]
321
-
322
- if do_classifier_free_guidance:
323
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
324
- seq_len = negative_prompt_embeds.shape[1]
325
-
326
- negative_prompt_embeds = negative_prompt_embeds.to(
327
- dtype=self.text_encoder.dtype, device=device
328
- )
329
-
330
- negative_prompt_embeds = negative_prompt_embeds.repeat(
331
- 1, num_images_per_prompt, 1
332
- )
333
- negative_prompt_embeds = negative_prompt_embeds.view(
334
- batch_size * num_images_per_prompt, seq_len, -1
335
- )
336
-
337
- # For classifier free guidance, we need to do two forward passes.
338
- # Here we concatenate the unconditional and text embeddings into a single batch
339
- # to avoid doing two forward passes
340
- # pix2pix has two negative embeddings, and unlike in other pipelines latents are ordered [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]
341
- prompt_embeds = torch.cat(
342
- [prompt_embeds, negative_prompt_embeds, negative_prompt_embeds]
343
- )
344
-
345
- return prompt_embeds
346
-
347
- def prepare_extra_step_kwargs(self, generator, eta):
348
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
349
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
350
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
351
- # and should be between [0, 1]
352
-
353
- accepts_eta = "eta" in set(
354
- inspect.signature(self.scheduler.step).parameters.keys()
355
- )
356
- extra_step_kwargs = {}
357
- if accepts_eta:
358
- extra_step_kwargs["eta"] = eta
359
-
360
- # check if the scheduler accepts generator
361
- accepts_generator = "generator" in set(
362
- inspect.signature(self.scheduler.step).parameters.keys()
363
- )
364
- if accepts_generator:
365
- extra_step_kwargs["generator"] = generator
366
- return extra_step_kwargs
367
-
368
- def check_inputs(
369
- self,
370
- prompt,
371
- callback_steps,
372
- negative_prompt=None,
373
- prompt_embeds=None,
374
- negative_prompt_embeds=None,
375
- ):
376
- if (callback_steps is None) or (
377
- callback_steps is not None
378
- and (not isinstance(callback_steps, int) or callback_steps <= 0)
379
- ):
380
- raise ValueError(
381
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
382
- f" {type(callback_steps)}."
383
- )
384
-
385
- if prompt is not None and prompt_embeds is not None:
386
- raise ValueError(
387
- f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
388
- " only forward one of the two."
389
- )
390
- elif prompt is None and prompt_embeds is None:
391
- raise ValueError(
392
- "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
393
- )
394
- elif prompt is not None and (
395
- not isinstance(prompt, str) and not isinstance(prompt, list)
396
- ):
397
- raise ValueError(
398
- f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
399
- )
400
-
401
- if negative_prompt is not None and negative_prompt_embeds is not None:
402
- raise ValueError(
403
- f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
404
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
405
- )
406
-
407
- if prompt_embeds is not None and negative_prompt_embeds is not None:
408
- if prompt_embeds.shape != negative_prompt_embeds.shape:
409
- raise ValueError(
410
- "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
411
- f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
412
- f" {negative_prompt_embeds.shape}."
413
- )
414
-
415
- def prepare_latents(
416
- self,
417
- batch_size,
418
- num_channels_latents,
419
- height,
420
- width,
421
- dtype,
422
- device,
423
- generator,
424
- latents=None,
425
- ):
426
- shape = (
427
- batch_size,
428
- num_channels_latents,
429
- height // self.vae_scale_factor,
430
- width // self.vae_scale_factor,
431
- )
432
- if isinstance(generator, list) and len(generator) != batch_size:
433
- raise ValueError(
434
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
435
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
436
- )
437
-
438
- if latents is None:
439
- latents = randn_tensor(
440
- shape, generator=generator, device=device, dtype=dtype
441
- )
442
- else:
443
- latents = latents.to(device)
444
-
445
- # scale the initial noise by the standard deviation required by the scheduler
446
- latents = latents * self.scheduler.init_noise_sigma
447
- return latents
448
-
449
- def prepare_image_latents(
450
- self,
451
- image,
452
- batch_size,
453
- num_images_per_prompt,
454
- dtype,
455
- device,
456
- do_classifier_free_guidance,
457
- generator=None,
458
- ):
459
- if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
460
- raise ValueError(
461
- f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
462
- )
463
-
464
- image = image.to(device=device, dtype=dtype)
465
-
466
- batch_size = batch_size * num_images_per_prompt
467
-
468
- if image.shape[1] == 4:
469
- image_latents = image
470
- else:
471
- if isinstance(generator, list) and len(generator) != batch_size:
472
- raise ValueError(
473
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
474
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
475
- )
476
-
477
- if isinstance(generator, list):
478
- image_latents = [
479
- self.vae.encode(image[i : i + 1]).latent_dist.mode()
480
- for i in range(batch_size)
481
- ]
482
- image_latents = torch.cat(image_latents, dim=0)
483
- else:
484
- image_latents = self.vae.encode(image).latent_dist.mode()
485
-
486
- if (
487
- batch_size > image_latents.shape[0]
488
- and batch_size % image_latents.shape[0] == 0
489
- ):
490
- # expand image_latents for batch_size
491
- deprecation_message = (
492
- f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
493
- " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
494
- " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
495
- " your script to pass as many initial images as text prompts to suppress this warning."
496
- )
497
- deprecate(
498
- "len(prompt) != len(image)",
499
- "1.0.0",
500
- deprecation_message,
501
- standard_warn=False,
502
- )
503
- additional_image_per_prompt = batch_size // image_latents.shape[0]
504
- image_latents = torch.cat(
505
- [image_latents] * additional_image_per_prompt, dim=0
506
- )
507
- elif (
508
- batch_size > image_latents.shape[0]
509
- and batch_size % image_latents.shape[0] != 0
510
- ):
511
- raise ValueError(
512
- f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
513
- )
514
- else:
515
- image_latents = torch.cat([image_latents], dim=0)
516
-
517
- if do_classifier_free_guidance:
518
- uncond_image_latents = torch.zeros_like(image_latents)
519
- image_latents = torch.cat(
520
- [image_latents, image_latents, uncond_image_latents], dim=0
521
- )
522
-
523
- return image_latents
524
-
525
- @torch.no_grad()
526
- def __call__(
527
- self,
528
- height: int,
529
- width: int,
530
- prompt: Union[str, List[str]] = None,
531
- albedo: Optional[
532
- Union[
533
- torch.FloatTensor,
534
- PIL.Image.Image,
535
- np.ndarray,
536
- List[torch.FloatTensor],
537
- List[PIL.Image.Image],
538
- List[np.ndarray],
539
- ]
540
- ] = None,
541
- normal: Optional[
542
- Union[
543
- torch.FloatTensor,
544
- PIL.Image.Image,
545
- np.ndarray,
546
- List[torch.FloatTensor],
547
- List[PIL.Image.Image],
548
- List[np.ndarray],
549
- ]
550
- ] = None,
551
- roughness: Optional[
552
- Union[
553
- torch.FloatTensor,
554
- PIL.Image.Image,
555
- np.ndarray,
556
- List[torch.FloatTensor],
557
- List[PIL.Image.Image],
558
- List[np.ndarray],
559
- ]
560
- ] = None,
561
- metallic: Optional[
562
- Union[
563
- torch.FloatTensor,
564
- PIL.Image.Image,
565
- np.ndarray,
566
- List[torch.FloatTensor],
567
- List[PIL.Image.Image],
568
- List[np.ndarray],
569
- ]
570
- ] = None,
571
- irradiance: Optional[
572
- Union[
573
- torch.FloatTensor,
574
- PIL.Image.Image,
575
- np.ndarray,
576
- List[torch.FloatTensor],
577
- List[PIL.Image.Image],
578
- List[np.ndarray],
579
- ]
580
- ] = None,
581
- guidance_scale: float = 0.0,
582
- image_guidance_scale: float = 0.0,
583
- guidance_rescale: float = 0.0,
584
- num_inference_steps: int = 100,
585
- required_aovs: List[str] = ["albedo"],
586
- return_predicted_x0s: bool = False,
587
- negative_prompt: Optional[Union[str, List[str]]] = None,
588
- num_images_per_prompt: Optional[int] = 1,
589
- eta: float = 0.0,
590
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
591
- latents: Optional[torch.FloatTensor] = None,
592
- prompt_embeds: Optional[torch.FloatTensor] = None,
593
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
594
- output_type: Optional[str] = "pil",
595
- return_dict: bool = True,
596
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
597
- callback_steps: int = 1,
598
- ):
599
- r"""
600
- The call function to the pipeline for generation.
601
-
602
- Args:
603
- prompt (`str` or `List[str]`, *optional*):
604
- The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
605
- image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
606
- `Image` or tensor representing an image batch to be repainted according to `prompt`. Can also accept
607
- image latents as `image`, but if passing latents directly it is not encoded again.
608
- num_inference_steps (`int`, *optional*, defaults to 100):
609
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
610
- expense of slower inference.
611
- guidance_scale (`float`, *optional*, defaults to 7.5):
612
- A higher guidance scale value encourages the model to generate images closely linked to the text
613
- `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
614
- image_guidance_scale (`float`, *optional*, defaults to 1.5):
615
- Push the generated image towards the inital `image`. Image guidance scale is enabled by setting
616
- `image_guidance_scale > 1`. Higher image guidance scale encourages generated images that are closely
617
- linked to the source `image`, usually at the expense of lower image quality. This pipeline requires a
618
- value of at least `1`.
619
- negative_prompt (`str` or `List[str]`, *optional*):
620
- The prompt or prompts to guide what to not include in image generation. If not defined, you need to
621
- pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
622
- num_images_per_prompt (`int`, *optional*, defaults to 1):
623
- The number of images to generate per prompt.
624
- eta (`float`, *optional*, defaults to 0.0):
625
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
626
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
627
- generator (`torch.Generator`, *optional*):
628
- A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
629
- generation deterministic.
630
- latents (`torch.FloatTensor`, *optional*):
631
- Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
632
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
633
- tensor is generated by sampling using the supplied random `generator`.
634
- prompt_embeds (`torch.FloatTensor`, *optional*):
635
- Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
636
- provided, text embeddings are generated from the `prompt` input argument.
637
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
638
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
639
- not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
640
- output_type (`str`, *optional*, defaults to `"pil"`):
641
- The output format of the generated image. Choose between `PIL.Image` or `np.array`.
642
- return_dict (`bool`, *optional*, defaults to `True`):
643
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
644
- plain tuple.
645
- callback (`Callable`, *optional*):
646
- A function that calls every `callback_steps` steps during inference. The function is called with the
647
- following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
648
- callback_steps (`int`, *optional*, defaults to 1):
649
- The frequency at which the `callback` function is called. If not specified, the callback is called at
650
- every step.
651
-
652
- Examples:
653
-
654
- ```py
655
- >>> import PIL
656
- >>> import requests
657
- >>> import torch
658
- >>> from io import BytesIO
659
-
660
- >>> from diffusers import StableDiffusionInstructPix2PixPipeline
661
-
662
-
663
- >>> def download_image(url):
664
- ... response = requests.get(url)
665
- ... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
666
-
667
-
668
- >>> img_url = "https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png"
669
-
670
- >>> image = download_image(img_url).resize((512, 512))
671
-
672
- >>> pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
673
- ... "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
674
- ... )
675
- >>> pipe = pipe.to("cuda")
676
-
677
- >>> prompt = "make the mountains snowy"
678
- >>> image = pipe(prompt=prompt, image=image).images[0]
679
- ```
680
-
681
- Returns:
682
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
683
- If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
684
- otherwise a `tuple` is returned where the first element is a list with the generated images and the
685
- second element is a list of `bool`s indicating whether the corresponding generated image contains
686
- "not-safe-for-work" (nsfw) content.
687
- """
688
- # 0. Check inputs
689
- self.check_inputs(
690
- prompt,
691
- callback_steps,
692
- negative_prompt,
693
- prompt_embeds,
694
- negative_prompt_embeds,
695
- )
696
-
697
- # 1. Define call parameters
698
- if prompt is not None and isinstance(prompt, str):
699
- batch_size = 1
700
- elif prompt is not None and isinstance(prompt, list):
701
- batch_size = len(prompt)
702
- else:
703
- batch_size = prompt_embeds.shape[0]
704
-
705
- device = self._execution_device
706
- do_classifier_free_guidance = (
707
- guidance_scale >= 1.0 and image_guidance_scale >= 1.0
708
- )
709
- # check if scheduler is in sigmas space
710
- scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
711
-
712
- # 2. Encode input prompt
713
- prompt_embeds = self._encode_prompt(
714
- prompt,
715
- device,
716
- num_images_per_prompt,
717
- do_classifier_free_guidance,
718
- negative_prompt,
719
- prompt_embeds=prompt_embeds,
720
- negative_prompt_embeds=negative_prompt_embeds,
721
- )
722
-
723
- # 3. Preprocess image
724
- # For normal, the preprocessing does nothing
725
- # For others, the preprocessing remap the values to [-1, 1]
726
- preprocessed_aovs = {}
727
- for aov_name in required_aovs:
728
- if aov_name == "albedo":
729
- if albedo is not None:
730
- preprocessed_aovs[aov_name] = self.image_processor.preprocess(
731
- albedo
732
- )
733
- else:
734
- preprocessed_aovs[aov_name] = None
735
-
736
- if aov_name == "normal":
737
- if normal is not None:
738
- preprocessed_aovs[aov_name] = (
739
- self.image_processor.preprocess_normal(normal)
740
- )
741
- else:
742
- preprocessed_aovs[aov_name] = None
743
-
744
- if aov_name == "roughness":
745
- if roughness is not None:
746
- preprocessed_aovs[aov_name] = self.image_processor.preprocess(
747
- roughness
748
- )
749
- else:
750
- preprocessed_aovs[aov_name] = None
751
- if aov_name == "metallic":
752
- if metallic is not None:
753
- preprocessed_aovs[aov_name] = self.image_processor.preprocess(
754
- metallic
755
- )
756
- else:
757
- preprocessed_aovs[aov_name] = None
758
- if aov_name == "irradiance":
759
- if irradiance is not None:
760
- preprocessed_aovs[aov_name] = self.image_processor.preprocess(
761
- irradiance
762
- )
763
- else:
764
- preprocessed_aovs[aov_name] = None
765
-
766
- # 4. set timesteps
767
- self.scheduler.set_timesteps(num_inference_steps, device=device)
768
- timesteps = self.scheduler.timesteps
769
-
770
- # 5. Prepare latent variables
771
- num_channels_latents = self.vae.config.latent_channels
772
- latents = self.prepare_latents(
773
- batch_size * num_images_per_prompt,
774
- num_channels_latents,
775
- height,
776
- width,
777
- prompt_embeds.dtype,
778
- device,
779
- generator,
780
- latents,
781
- )
782
-
783
- height_latent, width_latent = latents.shape[-2:]
784
-
785
- # 6. Prepare Image latents
786
- image_latents = []
787
- # Magicial scaling factors for each AOV (calculated from the training data)
788
- scaling_factors = {
789
- "albedo": 0.17301377137652138,
790
- "normal": 0.17483895473058078,
791
- "roughness": 0.1680724853626448,
792
- "metallic": 0.13135013390855135,
793
- }
794
- for aov_name, aov in preprocessed_aovs.items():
795
- if aov is None:
796
- image_latent = torch.zeros(
797
- batch_size,
798
- num_channels_latents,
799
- height_latent,
800
- width_latent,
801
- dtype=prompt_embeds.dtype,
802
- device=device,
803
- )
804
- if aov_name == "irradiance":
805
- image_latent = image_latent[:, 0:3]
806
- if do_classifier_free_guidance:
807
- image_latents.append(
808
- torch.cat([image_latent, image_latent, image_latent], dim=0)
809
- )
810
- else:
811
- image_latents.append(image_latent)
812
- else:
813
- if aov_name == "irradiance":
814
- image_latent = F.interpolate(
815
- aov.to(device=device, dtype=prompt_embeds.dtype),
816
- size=(height_latent, width_latent),
817
- mode="bilinear",
818
- align_corners=False,
819
- antialias=True,
820
- )
821
- if do_classifier_free_guidance:
822
- uncond_image_latent = torch.zeros_like(image_latent)
823
- image_latent = torch.cat(
824
- [image_latent, image_latent, uncond_image_latent], dim=0
825
- )
826
- else:
827
- scaling_factor = scaling_factors[aov_name]
828
- image_latent = (
829
- self.prepare_image_latents(
830
- aov,
831
- batch_size,
832
- num_images_per_prompt,
833
- prompt_embeds.dtype,
834
- device,
835
- do_classifier_free_guidance,
836
- generator,
837
- )
838
- * scaling_factor
839
- )
840
- image_latents.append(image_latent)
841
- image_latents = torch.cat(image_latents, dim=1)
842
-
843
- # 7. Check that shapes of latents and image match the UNet channels
844
- num_channels_image = image_latents.shape[1]
845
- if num_channels_latents + num_channels_image != self.unet.config.in_channels:
846
- raise ValueError(
847
- f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
848
- f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
849
- f" `num_channels_image`: {num_channels_image} "
850
- f" = {num_channels_latents+num_channels_image}. Please verify the config of"
851
- " `pipeline.unet` or your `image` input."
852
- )
853
-
854
- # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
855
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
856
-
857
- predicted_x0s = []
858
-
859
- # 9. Denoising loop
860
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
861
- with self.progress_bar(total=num_inference_steps) as progress_bar:
862
- for i, t in enumerate(timesteps):
863
- # Expand the latents if we are doing classifier free guidance.
864
- # The latents are expanded 3 times because for pix2pix the guidance\
865
- # is applied for both the text and the input image.
866
- latent_model_input = (
867
- torch.cat([latents] * 3) if do_classifier_free_guidance else latents
868
- )
869
-
870
- # concat latents, image_latents in the channel dimension
871
- scaled_latent_model_input = self.scheduler.scale_model_input(
872
- latent_model_input, t
873
- )
874
- scaled_latent_model_input = torch.cat(
875
- [scaled_latent_model_input, image_latents], dim=1
876
- )
877
-
878
- # predict the noise residual
879
- noise_pred = self.unet(
880
- scaled_latent_model_input,
881
- t,
882
- encoder_hidden_states=prompt_embeds,
883
- return_dict=False,
884
- )[0]
885
-
886
- # perform guidance
887
- if do_classifier_free_guidance:
888
- (
889
- noise_pred_text,
890
- noise_pred_image,
891
- noise_pred_uncond,
892
- ) = noise_pred.chunk(3)
893
- noise_pred = (
894
- noise_pred_uncond
895
- + guidance_scale * (noise_pred_text - noise_pred_image)
896
- + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
897
- )
898
-
899
- if do_classifier_free_guidance and guidance_rescale > 0.0:
900
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
901
- noise_pred = rescale_noise_cfg(
902
- noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
903
- )
904
-
905
- # compute the previous noisy sample x_t -> x_t-1
906
- output = self.scheduler.step(
907
- noise_pred, t, latents, **extra_step_kwargs, return_dict=True
908
- )
909
-
910
- latents = output[0]
911
-
912
- if return_predicted_x0s:
913
- predicted_x0s.append(output[1])
914
-
915
- # call the callback, if provided
916
- if i == len(timesteps) - 1 or (
917
- (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
918
- ):
919
- progress_bar.update()
920
- if callback is not None and i % callback_steps == 0:
921
- callback(i, t, latents)
922
-
923
- if not output_type == "latent":
924
- image = self.vae.decode(
925
- latents / self.vae.config.scaling_factor, return_dict=False
926
- )[0]
927
-
928
- if return_predicted_x0s:
929
- predicted_x0_images = [
930
- self.vae.decode(
931
- predicted_x0 / self.vae.config.scaling_factor, return_dict=False
932
- )[0]
933
- for predicted_x0 in predicted_x0s
934
- ]
935
- else:
936
- image = latents
937
- predicted_x0_images = predicted_x0s
938
-
939
- do_denormalize = [True] * image.shape[0]
940
-
941
- image = self.image_processor.postprocess(
942
- image, output_type=output_type, do_denormalize=do_denormalize
943
- )
944
-
945
- if return_predicted_x0s:
946
- predicted_x0_images = [
947
- self.image_processor.postprocess(
948
- predicted_x0_image,
949
- output_type=output_type,
950
- do_denormalize=do_denormalize,
951
- )
952
- for predicted_x0_image in predicted_x0_images
953
- ]
954
-
955
- # Offload last model to CPU
956
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
957
- self.final_offload_hook.offload()
958
-
959
- if not return_dict:
960
- return image
961
-
962
- if return_predicted_x0s:
963
- return StableDiffusionAOVPipelineOutput(
964
- images=image, predicted_x0_images=predicted_x0_images
965
- )
966
- else:
967
- return StableDiffusionAOVPipelineOutput(images=image)