Profakerr committed
Commit bd92452 · verified · 1 parent: 2aa689b

Upload 2 files

Files changed (2)
  1. app.py +253 -0
  2. pipeline_fill_sd_xl.py +521 -0
app.py ADDED
@@ -0,0 +1,253 @@
import gradio as gr
import spaces
from RealESRGAN import RealESRGAN
import torch
from diffusers import AutoencoderKL, TCDScheduler, DPMSolverMultistepScheduler
from diffusers.models.model_loading_utils import load_state_dict
from gradio_imageslider import ImageSlider
from huggingface_hub import hf_hub_download
from PIL import ImageDraw, ImageFont, Image

from controlnet_union import ControlNetModel_Union
from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline

MODELS = {
    "RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
}

# Load the ProMax variant of the Union ControlNet (config + weights) from the Hub.
config_file = hf_hub_download(
    "xinsir/controlnet-union-sdxl-1.0",
    filename="config_promax.json",
)

config = ControlNetModel_Union.load_config(config_file)
controlnet_model = ControlNetModel_Union.from_config(config)
model_file = hf_hub_download(
    "xinsir/controlnet-union-sdxl-1.0",
    filename="diffusion_pytorch_model_promax.safetensors",
)
state_dict = load_state_dict(model_file)
model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
    controlnet_model, state_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
)
model.to(device="cuda", dtype=torch.float16)

vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
).to("cuda")

pipe = StableDiffusionXLFillPipeline.from_pretrained(
    "SG161222/RealVisXL_V5.0_Lightning",
    torch_dtype=torch.float16,
    vae=vae,
    controlnet=model,
    variant="fp16",
).to("cuda")

# Note: `algorithm_type` and `use_karras_sigmas` are DPMSolverMultistepScheduler
# options, not TCDScheduler ones, so `from_config` is expected to ignore them.
pipe.scheduler = TCDScheduler.from_config(
    pipe.scheduler.config, algorithm_type="dpmsolver++", use_karras_sigmas=True
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model2 = RealESRGAN(device, scale=2)
model2.load_weights('weights/RealESRGAN_x2.pth', download=True)
model4 = RealESRGAN(device, scale=4)
model4.load_weights('weights/RealESRGAN_x4.pth', download=True)

@spaces.GPU
def inference(image, size):
    global model2
    global model4
    if image is None:
        raise gr.Error("Image not uploaded")

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    if size == '2x':
        try:
            result = model2.predict(image.convert('RGB'))
        except torch.cuda.OutOfMemoryError as e:
            print(e)
            # Re-create the model and retry once after an out-of-memory error.
            model2 = RealESRGAN(device, scale=2)
            model2.load_weights('weights/RealESRGAN_x2.pth', download=False)
            result = model2.predict(image.convert('RGB'))
    elif size == '4x':
        try:
            result = model4.predict(image.convert('RGB'))
        except torch.cuda.OutOfMemoryError as e:
            print(e)
            model4 = RealESRGAN(device, scale=4)
            model4.load_weights('weights/RealESRGAN_x4.pth', download=False)
            result = model4.predict(image.convert('RGB'))

    print(f"Image size ({device}): {size} ... OK")
    return result


def add_watermark(image, text="ProFaker", font_path="BRLNSDB.TTF", font_size=25):
    # Load the Berlin Sans Demi font with the specified size
    font = ImageFont.truetype(font_path, font_size)

    # Position the watermark in the bottom right corner, adjusting for text size
    text_bbox = font.getbbox(text)
    text_width, text_height = text_bbox[2], text_bbox[3]
    watermark_position = (image.width - text_width - 100, image.height - text_height - 150)

    # Draw the watermark text with a translucent white color
    draw = ImageDraw.Draw(image)
    draw.text(watermark_position, text, font=font, fill=(255, 255, 255, 150))  # RGBA for transparency

    return image


@spaces.GPU
def fill_image(prompt, negative_prompt, image, model_selection, paste_back, guidance_scale, num_steps, size):
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(prompt, "cuda", True, negative_prompt=negative_prompt)

    source = image["background"]
    mask = image["layers"][0]

    # Build a binary mask from the painted layer's alpha channel and black out
    # the masked region in the ControlNet conditioning image.
    alpha_channel = mask.split()[3]
    binary_mask = alpha_channel.point(lambda p: p > 0 and 255)
    cnet_image = source.copy()
    cnet_image.paste(0, (0, 0), binary_mask)

    # The pipeline is a generator: it yields intermediate previews, then the
    # final decoded image as its last value.
    for image in pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        image=cnet_image,
        guidance_scale=guidance_scale,
        num_inference_steps=num_steps,
    ):
        yield image, cnet_image

    print(f"{model_selection=}")
    print(f"{paste_back=}")

    if paste_back:
        image = image.convert("RGBA")
        cnet_image.paste(image, (0, 0), binary_mask)
    else:
        cnet_image = image

    cnet_image = add_watermark(cnet_image)
    if size != "0":
        cnet_image = inference(cnet_image, size)
    yield source, cnet_image


def clear_result():
    return gr.update(value=None)


title = """<h1 align="center">ProFaker</h1>"""

with gr.Blocks() as demo:
    gr.HTML(title)
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                info="Describe what to inpaint the mask with",
                lines=3,
            )

            with gr.Accordion("Advanced Options", open=False):
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    info="Describe what you don't want in the mask",
                    lines=3,
                )
                guidance_scale = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=1.5,
                    step=0.1,
                    label="Guidance Scale"
                )
                num_steps = gr.Slider(
                    minimum=5,
                    maximum=100,
                    value=10,
                    step=1,
                    label="Steps"
                )
                size = gr.Radio(["0", "2x", "4x"], type="value", value="0", label="Image Quality")

            input_image = gr.ImageMask(
                type="pil", label="Input Image", crop_size=(1024, 1024), layers=False
            )

        with gr.Column():
            model_selection = gr.Dropdown(
                choices=list(MODELS.keys()),
                value="RealVisXL V5.0 Lightning",
                label="Model",
            )

            with gr.Row():
                with gr.Column():
                    run_button = gr.Button("Generate")

                with gr.Column():
                    paste_back = gr.Checkbox(True, label="Paste back original")

            result = ImageSlider(
                interactive=False,
                label="Generated Image",
                type="pil"
            )

            use_as_input_button = gr.Button("Use as Input Image", visible=False)

    def use_output_as_input(output_image):
        return gr.update(value=output_image[1])

    use_as_input_button.click(
        fn=use_output_as_input, inputs=[result], outputs=[input_image]
    )

    run_button.click(
        fn=clear_result,
        inputs=None,
        outputs=result,
    ).then(
        fn=lambda: gr.update(visible=False),
        inputs=None,
        outputs=use_as_input_button,
    ).then(
        fn=fill_image,
        inputs=[prompt, negative_prompt, input_image, model_selection, paste_back, guidance_scale, num_steps, size],
        outputs=result,
    ).then(
        fn=lambda: gr.update(visible=True),
        inputs=None,
        outputs=use_as_input_button,
    )

    prompt.submit(
        fn=clear_result,
        inputs=None,
        outputs=result,
    ).then(
        fn=lambda: gr.update(visible=False),
        inputs=None,
        outputs=use_as_input_button,
    ).then(
        fn=fill_image,
        inputs=[prompt, negative_prompt, input_image, model_selection, paste_back, guidance_scale, num_steps, size],
        outputs=result,
    ).then(
        fn=lambda: gr.update(visible=True),
        inputs=None,
        outputs=use_as_input_button,
    )


demo.queue(max_size=12).launch(share=False)
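
For reference, `fill_image` above is a generator: while the diffusion loop runs it yields `(preview, conditioning_image)` pairs, and its final value is `(source, finished_image)`, which is what the `ImageSlider` shows as a before/after pair. A minimal sketch of consuming it outside the Gradio UI, assuming the Space's dependencies are installed; `input.png` and the hand-built editor dict are hypothetical stand-ins for the `gr.ImageMask` payload:

from PIL import Image

from app import fill_image  # hypothetical: import the module defined above

background = Image.open("input.png").convert("RGBA")            # hypothetical input image
mask_layer = Image.new("RGBA", background.size, (0, 0, 0, 0))   # paint the region to fill here
editor_value = {"background": background, "layers": [mask_layer]}

last = None
for pair in fill_image(
    "a wooden bench", "blurry", editor_value,
    "RealVisXL V5.0 Lightning", True, 1.5, 10, "0",
):
    last = pair  # (preview, conditioning) while running; (source, result) at the end

source, result = last
result.save("result.png")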
pipeline_fill_sd_xl.py ADDED
@@ -0,0 +1,521 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Union

import cv2
import PIL.Image
import torch
import torch.nn.functional as F
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers import DPMSolverMultistepScheduler
from diffusers.utils.torch_utils import randn_tensor
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

from controlnet_union import ControlNetModel_Union

def latents_to_rgb(latents):
    weights = ((60, -60, 25, -70), (60, -5, 15, -50), (60, 10, -5, -35))

    weights_tensor = torch.t(
        torch.tensor(weights, dtype=latents.dtype).to(latents.device)
    )
    biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to(
        latents.device
    )
    rgb_tensor = torch.einsum(
        "...lxy,lr -> ...rxy", latents, weights_tensor
    ) + biases_tensor.unsqueeze(-1).unsqueeze(-1)
    image_array = rgb_tensor.clamp(0, 255)[0].byte().cpu().numpy()
    image_array = image_array.transpose(1, 2, 0)  # Change the order of dimensions

    denoised_image = cv2.fastNlMeansDenoisingColored(image_array, None, 10, 10, 7, 21)
    blurred_image = cv2.GaussianBlur(denoised_image, (5, 5), 0)
    final_image = PIL.Image.fromarray(blurred_image)

    width, height = final_image.size
    final_image = final_image.resize(
        (width * 8, height * 8), PIL.Image.Resampling.LANCZOS
    )

    return final_image


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    **kwargs,
):
    scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
    timesteps = scheduler.timesteps

    return timesteps, num_inference_steps

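# Note on the preview path above: the 4-channel SDXL latent is mapped to RGB with
# the fixed 3x4 linear projection in `weights`/`biases`, lightly denoised and
# blurred, then upscaled 8x (the VAE downsampling factor), so e.g. a 128x128
# latent produces a rough 1024x1024 preview without running the VAE decoder.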

class StableDiffusionXLFillPipeline(DiffusionPipeline, StableDiffusionMixin):
    model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
    _optional_components = [
        "tokenizer",
        "tokenizer_2",
        "text_encoder",
        "text_encoder_2",
    ]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        text_encoder_2: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        tokenizer_2: CLIPTokenizer,
        unet: UNet2DConditionModel,
        controlnet: ControlNetModel_Union,
        scheduler: DPMSolverMultistepScheduler,
        force_zeros_for_empty_prompt: bool = True,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
        )
        self.control_image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor,
            do_convert_rgb=True,
            do_normalize=False,
        )

        self.register_to_config(
            force_zeros_for_empty_prompt=force_zeros_for_empty_prompt
        )

    def encode_prompt(
        self,
        prompt: str,
        device: Optional[torch.device] = None,
        do_classifier_free_guidance: bool = True,
        negative_prompt: Optional[str] = None,
    ):
        device = device or self._execution_device
        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt is not None:
            batch_size = len(prompt)
        else:
            raise ValueError("Prompt cannot be None")

        # Handle negative prompt
        if negative_prompt is None:
            negative_prompt = "" if do_classifier_free_guidance else None
        negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

        # Define tokenizers and text encoders
        tokenizers = (
            [self.tokenizer, self.tokenizer_2]
            if self.tokenizer is not None
            else [self.tokenizer_2]
        )
        text_encoders = (
            [self.text_encoder, self.text_encoder_2]
            if self.text_encoder is not None
            else [self.text_encoder_2]
        )

        prompt_2 = prompt
        prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
        negative_prompt_2 = negative_prompt
        negative_prompt_2 = [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2

        # Process prompt embeddings
        prompt_embeds_list = []
        prompts = [prompt, prompt_2]
        for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
            text_inputs = tokenizer(
                prompt,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )

            text_input_ids = text_inputs.input_ids
            prompt_embeds = text_encoder(
                text_input_ids.to(device),
                output_hidden_states=True,
            )

            # We are only ALWAYS interested in the pooled output of the final text encoder
            pooled_prompt_embeds = prompt_embeds[0]
            prompt_embeds = prompt_embeds.hidden_states[-2]
            prompt_embeds_list.append(prompt_embeds)

        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

        # Process negative prompt embeddings
        negative_prompt_embeds_list = []
        if do_classifier_free_guidance:
            negative_prompts = [negative_prompt, negative_prompt_2]
            for neg_prompt, tokenizer, text_encoder in zip(negative_prompts, tokenizers, text_encoders):
                uncond_input = tokenizer(
                    neg_prompt,
                    padding="max_length",
                    max_length=text_inputs.input_ids.shape[-1],
                    truncation=True,
                    return_tensors="pt",
                )

                negative_prompt_embeds = text_encoder(
                    uncond_input.input_ids.to(device),
                    output_hidden_states=True,
                )
                # Get pooled and hidden state embeddings
                negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
                negative_prompt_embeds_list.append(negative_prompt_embeds)

            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
        else:
            negative_prompt_embeds = None
            negative_pooled_prompt_embeds = None

        # Convert to proper dtype
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
        if negative_prompt_embeds is not None:
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)

        # Reshape embeddings
        bs_embed, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, 1, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * 1, seq_len, -1)

        if do_classifier_free_guidance:
            seq_len = negative_prompt_embeds.shape[1]
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, 1, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * 1, seq_len, -1)

        # Handle pooled embeddings
        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view(bs_embed * 1, -1)
        if do_classifier_free_guidance:
            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, 1).view(bs_embed * 1, -1)

        return (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        )

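    # Note on encode_prompt above: each prompt is run through both SDXL text
    # encoders and the two penultimate hidden states are concatenated along the
    # feature dimension (768 + 1280 = 2048 for SDXL), while only the pooled
    # projection of the second encoder is kept, matching the SDXL conditioning
    # layout expected by the UNet's `added_cond_kwargs`.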

    def check_inputs(
        self,
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
        image,
        controlnet_conditioning_scale=1.0,
    ):
        if prompt_embeds is None:
            raise ValueError(
                "Provide `prompt_embeds`. Cannot leave `prompt_embeds` undefined."
            )

        if negative_prompt_embeds is None:
            raise ValueError(
                "Provide `negative_prompt_embeds`. Cannot leave `negative_prompt_embeds` undefined."
            )

        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

        if prompt_embeds is not None and pooled_prompt_embeds is None:
            raise ValueError(
                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
            )

        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
            raise ValueError(
                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
            )

        # Check `image`
        is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
            self.controlnet, torch._dynamo.eval_frame.OptimizedModule
        )
        if (
            isinstance(self.controlnet, ControlNetModel_Union)
            or is_compiled
            and isinstance(self.controlnet._orig_mod, ControlNetModel_Union)
        ):
            if not isinstance(image, PIL.Image.Image):
                raise TypeError(
                    f"image must be passed and has to be a PIL image, but is {type(image)}"
                )
        else:
            assert False

        # Check `controlnet_conditioning_scale`
        if (
            isinstance(self.controlnet, ControlNetModel_Union)
            or is_compiled
            and isinstance(self.controlnet._orig_mod, ControlNetModel_Union)
        ):
            if not isinstance(controlnet_conditioning_scale, float):
                raise TypeError(
                    "For single controlnet: `controlnet_conditioning_scale` must be type `float`."
                )
        else:
            assert False

    def prepare_image(self, image, device, dtype, do_classifier_free_guidance=False):
        image = self.control_image_processor.preprocess(image).to(dtype=torch.float32)

        image_batch_size = image.shape[0]

        image = image.repeat_interleave(image_batch_size, dim=0)
        image = image.to(device=device, dtype=dtype)

        if do_classifier_free_guidance:
            image = torch.cat([image] * 2)

        return image

    def prepare_latents(
        self, batch_size, num_channels_latents, height, width, dtype, device
    ):
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )

        latents = randn_tensor(shape, device=device, dtype=dtype)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    @property
    def guidance_scale(self):
        return self._guidance_scale

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @torch.no_grad()
    def __call__(
        self,
        prompt_embeds: torch.Tensor,
        negative_prompt_embeds: torch.Tensor,
        pooled_prompt_embeds: torch.Tensor,
        negative_pooled_prompt_embeds: torch.Tensor,
        image: PipelineImageInput = None,
        num_inference_steps: int = 15,
        guidance_scale: float = 1.5,
        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
    ):
        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
            image,
            controlnet_conditioning_scale,
        )

        self._guidance_scale = guidance_scale

        # 2. Define call parameters
        batch_size = 1
        device = self._execution_device

        # 4. Prepare image
        if isinstance(self.controlnet, ControlNetModel_Union):
            image = self.prepare_image(
                image=image,
                device=device,
                dtype=self.controlnet.dtype,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
            )
            height, width = image.shape[-2:]
        else:
            assert False

        # 5. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device
        )
        self._num_timesteps = len(timesteps)

        # 6. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
        )

        # 7. Prepare added time ids & embeddings
        add_text_embeds = pooled_prompt_embeds

        # SDXL micro-conditioning: (original_size, crop_top_left, target_size),
        # taken directly from the conditioning image with no cropping.
        add_time_ids = negative_add_time_ids = torch.tensor(
            image.shape[-2:] + torch.Size([0, 0]) + image.shape[-2:]
        ).unsqueeze(0)

        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat(
                [negative_pooled_prompt_embeds, add_text_embeds], dim=0
            )
            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device).repeat(batch_size, 1)

        # The Union ControlNet takes one conditioning image per control type;
        # only a single slot is used here, flagged in `union_control_type`.
        controlnet_image_list = [0, 0, 0, 0, 0, 0, image, 0]
        union_control_type = (
            torch.Tensor([0, 0, 0, 0, 0, 0, 1, 0])
            .to(device, dtype=prompt_embeds.dtype)
            .repeat(batch_size * 2, 1)
        )

        added_cond_kwargs = {
            "text_embeds": add_text_embeds,
            "time_ids": add_time_ids,
            "control_type": union_control_type,
        }

        controlnet_prompt_embeds = prompt_embeds
        controlnet_added_cond_kwargs = added_cond_kwargs

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = (
                    torch.cat([latents] * 2)
                    if self.do_classifier_free_guidance
                    else latents
                )
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t
                )

                # controlnet(s) inference
                control_model_input = latent_model_input

                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    control_model_input,
                    t,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond_list=controlnet_image_list,
                    conditioning_scale=controlnet_conditioning_scale,
                    guess_mode=False,
                    added_cond_kwargs=controlnet_added_cond_kwargs,
                    return_dict=False,
                )

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=None,
                    cross_attention_kwargs={},
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents, return_dict=False
                )[0]

                # After the third step, drop the unconditional half of the batch and
                # switch off classifier-free guidance for the remaining steps.
                if i == 2:
                    prompt_embeds = prompt_embeds[-1:]
                    add_text_embeds = add_text_embeds[-1:]
                    add_time_ids = add_time_ids[-1:]
                    union_control_type = union_control_type[-1:]

                    added_cond_kwargs = {
                        "text_embeds": add_text_embeds,
                        "time_ids": add_time_ids,
                        "control_type": union_control_type,
                    }

                    controlnet_prompt_embeds = prompt_embeds
                    controlnet_added_cond_kwargs = added_cond_kwargs

                    image = image[-1:]
                    controlnet_image_list = [0, 0, 0, 0, 0, 0, image, 0]

                    self._guidance_scale = 0.0

                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()
                    # Stream a cheap latent-space preview after each completed step.
                    yield latents_to_rgb(latents)

        latents = latents / self.vae.config.scaling_factor
        image = self.vae.decode(latents, return_dict=False)[0]
        image = self.image_processor.postprocess(image)[0]

        # Offload all models
        self.maybe_free_model_hooks()

        yield image
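
For orientation, a minimal sketch of driving `StableDiffusionXLFillPipeline` directly, mirroring what app.py above does; `masked.png` is a hypothetical conditioning image with the region to fill painted black, and the checkpoints are the ones app.py loads. Since `__call__` is a generator, the loop drains the latent previews and the last yielded value is the final decoded PIL image:

import torch
from diffusers import AutoencoderKL, TCDScheduler
from diffusers.models.model_loading_utils import load_state_dict
from huggingface_hub import hf_hub_download
from PIL import Image

from controlnet_union import ControlNetModel_Union
from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline

# Load the Union ControlNet (ProMax) the same way app.py does.
config = ControlNetModel_Union.load_config(
    hf_hub_download("xinsir/controlnet-union-sdxl-1.0", filename="config_promax.json")
)
controlnet = ControlNetModel_Union.from_config(config)
weights = hf_hub_download(
    "xinsir/controlnet-union-sdxl-1.0", filename="diffusion_pytorch_model_promax.safetensors"
)
controlnet, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
    controlnet, load_state_dict(weights), weights, "xinsir/controlnet-union-sdxl-1.0"
)
controlnet.to(device="cuda", dtype=torch.float16)

vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipe = StableDiffusionXLFillPipeline.from_pretrained(
    "SG161222/RealVisXL_V5.0_Lightning",
    torch_dtype=torch.float16,
    vae=vae,
    controlnet=controlnet,
    variant="fp16",
).to("cuda")
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)

prompt_embeds, neg_embeds, pooled, neg_pooled = pipe.encode_prompt(
    "a wooden bench in a park", "cuda", True, negative_prompt="blurry, low quality"
)

cnet_image = Image.open("masked.png")  # hypothetical input: region to fill painted black

result = None
for frame in pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=neg_embeds,
    pooled_prompt_embeds=pooled,
    negative_pooled_prompt_embeds=neg_pooled,
    image=cnet_image,
    guidance_scale=1.5,
    num_inference_steps=10,
):
    result = frame  # latent-space previews while running; the decoded PIL image last

result.save("filled.png")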