Profakerr committed on
Commit 2aa689b · verified · 1 Parent(s): 09f2d9c

Delete pipeline_fill_sd_xl.py

Files changed (1)
  1. pipeline_fill_sd_xl.py +0 -559
pipeline_fill_sd_xl.py DELETED
@@ -1,559 +0,0 @@
- # Copyright 2024 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- from typing import List, Optional, Union
-
- import cv2
- import PIL.Image
- import torch
- import torch.nn.functional as F
- from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
- from diffusers.models import AutoencoderKL, UNet2DConditionModel
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
- from diffusers.schedulers import KarrasDiffusionSchedulers
- from diffusers.utils.torch_utils import randn_tensor
- from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
-
- from controlnet_union import ControlNetModel_Union
-
-
- def latents_to_rgb(latents):
-     weights = ((60, -60, 25, -70), (60, -5, 15, -50), (60, 10, -5, -35))
-
-     weights_tensor = torch.t(
-         torch.tensor(weights, dtype=latents.dtype).to(latents.device)
-     )
-     biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to(
-         latents.device
-     )
-     rgb_tensor = torch.einsum(
-         "...lxy,lr -> ...rxy", latents, weights_tensor
-     ) + biases_tensor.unsqueeze(-1).unsqueeze(-1)
-     image_array = rgb_tensor.clamp(0, 255)[0].byte().cpu().numpy()
-     image_array = image_array.transpose(1, 2, 0)  # Change the order of dimensions
-
-     denoised_image = cv2.fastNlMeansDenoisingColored(image_array, None, 10, 10, 7, 21)
-     blurred_image = cv2.GaussianBlur(denoised_image, (5, 5), 0)
-     final_image = PIL.Image.fromarray(blurred_image)
-
-     width, height = final_image.size
-     final_image = final_image.resize(
-         (width * 8, height * 8), PIL.Image.Resampling.LANCZOS
-     )
-
-     return final_image
-
-
- def retrieve_timesteps(
-     scheduler,
-     num_inference_steps: Optional[int] = None,
-     device: Optional[Union[str, torch.device]] = None,
-     **kwargs,
- ):
-     scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
-     timesteps = scheduler.timesteps
-
-     return timesteps, num_inference_steps
-
-
- class StableDiffusionXLFillPipeline(DiffusionPipeline, StableDiffusionMixin):
-     model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
-     _optional_components = [
-         "tokenizer",
-         "tokenizer_2",
-         "text_encoder",
-         "text_encoder_2",
-     ]
-
-     def __init__(
-         self,
-         vae: AutoencoderKL,
-         text_encoder: CLIPTextModel,
-         text_encoder_2: CLIPTextModelWithProjection,
-         tokenizer: CLIPTokenizer,
-         tokenizer_2: CLIPTokenizer,
-         unet: UNet2DConditionModel,
-         controlnet: ControlNetModel_Union,
-         scheduler: KarrasDiffusionSchedulers,
-         force_zeros_for_empty_prompt: bool = True,
-     ):
-         super().__init__()
-
-         self.register_modules(
-             vae=vae,
-             text_encoder=text_encoder,
-             text_encoder_2=text_encoder_2,
-             tokenizer=tokenizer,
-             tokenizer_2=tokenizer_2,
-             unet=unet,
-             controlnet=controlnet,
-             scheduler=scheduler,
-         )
-         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
-         self.image_processor = VaeImageProcessor(
-             vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
-         )
-         self.control_image_processor = VaeImageProcessor(
-             vae_scale_factor=self.vae_scale_factor,
-             do_convert_rgb=True,
-             do_normalize=False,
-         )
-
-         self.register_to_config(
-             force_zeros_for_empty_prompt=force_zeros_for_empty_prompt
-         )
-
-     def encode_prompt(
-         self,
-         prompt: str,
-         device: Optional[torch.device] = None,
-         do_classifier_free_guidance: bool = True,
-     ):
-         device = device or self._execution_device
-         prompt = [prompt] if isinstance(prompt, str) else prompt
-
-         if prompt is not None:
-             batch_size = len(prompt)
-
-         # Define tokenizers and text encoders
-         tokenizers = (
-             [self.tokenizer, self.tokenizer_2]
-             if self.tokenizer is not None
-             else [self.tokenizer_2]
-         )
-         text_encoders = (
-             [self.text_encoder, self.text_encoder_2]
-             if self.text_encoder is not None
-             else [self.text_encoder_2]
-         )
-
-         prompt_2 = prompt
-         prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
-
-         # textual inversion: process multi-vector tokens if necessary
-         prompt_embeds_list = []
-         prompts = [prompt, prompt_2]
-         for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
-             text_inputs = tokenizer(
-                 prompt,
-                 padding="max_length",
-                 max_length=tokenizer.model_max_length,
-                 truncation=True,
-                 return_tensors="pt",
-             )
-
-             text_input_ids = text_inputs.input_ids
-
-             prompt_embeds = text_encoder(
-                 text_input_ids.to(device), output_hidden_states=True
-             )
-
-             # We are only ALWAYS interested in the pooled output of the final text encoder
-             pooled_prompt_embeds = prompt_embeds[0]
-             prompt_embeds = prompt_embeds.hidden_states[-2]
-             prompt_embeds_list.append(prompt_embeds)
-
-         prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
-
-         # get unconditional embeddings for classifier free guidance
-         zero_out_negative_prompt = True
-         negative_prompt_embeds = None
-         negative_pooled_prompt_embeds = None
-
-         if do_classifier_free_guidance and zero_out_negative_prompt:
-             negative_prompt_embeds = torch.zeros_like(prompt_embeds)
-             negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
-         elif do_classifier_free_guidance and negative_prompt_embeds is None:
-             negative_prompt = ""
-             negative_prompt_2 = negative_prompt
-
-             # normalize str to list
-             negative_prompt = (
-                 batch_size * [negative_prompt]
-                 if isinstance(negative_prompt, str)
-                 else negative_prompt
-             )
-             negative_prompt_2 = (
-                 batch_size * [negative_prompt_2]
-                 if isinstance(negative_prompt_2, str)
-                 else negative_prompt_2
-             )
-
-             uncond_tokens: List[str]
-             if prompt is not None and type(prompt) is not type(negative_prompt):
-                 raise TypeError(
-                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                     f" {type(prompt)}."
-                 )
-             elif batch_size != len(negative_prompt):
-                 raise ValueError(
-                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                     f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                     " the batch size of `prompt`."
-                 )
-             else:
-                 uncond_tokens = [negative_prompt, negative_prompt_2]
-
-             negative_prompt_embeds_list = []
-             for negative_prompt, tokenizer, text_encoder in zip(
-                 uncond_tokens, tokenizers, text_encoders
-             ):
-                 max_length = prompt_embeds.shape[1]
-                 uncond_input = tokenizer(
-                     negative_prompt,
-                     padding="max_length",
-                     max_length=max_length,
-                     truncation=True,
-                     return_tensors="pt",
-                 )
-
-                 negative_prompt_embeds = text_encoder(
-                     uncond_input.input_ids.to(device),
-                     output_hidden_states=True,
-                 )
-                 # We are only ALWAYS interested in the pooled output of the final text encoder
-                 negative_pooled_prompt_embeds = negative_prompt_embeds[0]
-                 negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
-
-                 negative_prompt_embeds_list.append(negative_prompt_embeds)
-
-             negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
-
-         prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
-
-         bs_embed, seq_len, _ = prompt_embeds.shape
-         # duplicate text embeddings for each generation per prompt, using mps friendly method
-         prompt_embeds = prompt_embeds.repeat(1, 1, 1)
-         prompt_embeds = prompt_embeds.view(bs_embed * 1, seq_len, -1)
-
-         if do_classifier_free_guidance:
-             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-             seq_len = negative_prompt_embeds.shape[1]
-
-             if self.text_encoder_2 is not None:
-                 negative_prompt_embeds = negative_prompt_embeds.to(
-                     dtype=self.text_encoder_2.dtype, device=device
-                 )
-             else:
-                 negative_prompt_embeds = negative_prompt_embeds.to(
-                     dtype=self.unet.dtype, device=device
-                 )
-
-             negative_prompt_embeds = negative_prompt_embeds.repeat(1, 1, 1)
-             negative_prompt_embeds = negative_prompt_embeds.view(
-                 batch_size * 1, seq_len, -1
-             )
-
-         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view(bs_embed * 1, -1)
-         if do_classifier_free_guidance:
-             negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(
-                 1, 1
-             ).view(bs_embed * 1, -1)
-
-         return (
-             prompt_embeds,
-             negative_prompt_embeds,
-             pooled_prompt_embeds,
-             negative_pooled_prompt_embeds,
-         )
-
-     def check_inputs(
-         self,
-         prompt_embeds,
-         negative_prompt_embeds,
-         pooled_prompt_embeds,
-         negative_pooled_prompt_embeds,
-         image,
-         controlnet_conditioning_scale=1.0,
-     ):
-         if prompt_embeds is None:
-             raise ValueError(
-                 "Provide `prompt_embeds`. Cannot leave `prompt_embeds` undefined."
-             )
-
-         if negative_prompt_embeds is None:
-             raise ValueError(
-                 "Provide `negative_prompt_embeds`. Cannot leave `negative_prompt_embeds` undefined."
-             )
-
-         if prompt_embeds.shape != negative_prompt_embeds.shape:
-             raise ValueError(
-                 "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
-                 f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
-                 f" {negative_prompt_embeds.shape}."
-             )
-
-         if prompt_embeds is not None and pooled_prompt_embeds is None:
-             raise ValueError(
-                 "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
-             )
-
-         if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
-             raise ValueError(
-                 "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
-             )
-
-         # Check `image`
-         is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
-             self.controlnet, torch._dynamo.eval_frame.OptimizedModule
-         )
-         if (
-             isinstance(self.controlnet, ControlNetModel_Union)
-             or is_compiled
-             and isinstance(self.controlnet._orig_mod, ControlNetModel_Union)
-         ):
-             if not isinstance(image, PIL.Image.Image):
-                 raise TypeError(
-                     f"image must be passed and has to be a PIL image, but is {type(image)}"
-                 )
-
-         else:
-             assert False
-
-         # Check `controlnet_conditioning_scale`
-         if (
-             isinstance(self.controlnet, ControlNetModel_Union)
-             or is_compiled
-             and isinstance(self.controlnet._orig_mod, ControlNetModel_Union)
-         ):
-             if not isinstance(controlnet_conditioning_scale, float):
-                 raise TypeError(
-                     "For single controlnet: `controlnet_conditioning_scale` must be type `float`."
-                 )
-         else:
-             assert False
-
-     def prepare_image(self, image, device, dtype, do_classifier_free_guidance=False):
-         image = self.control_image_processor.preprocess(image).to(dtype=torch.float32)
-
-         image_batch_size = image.shape[0]
-
-         image = image.repeat_interleave(image_batch_size, dim=0)
-         image = image.to(device=device, dtype=dtype)
-
-         if do_classifier_free_guidance:
-             image = torch.cat([image] * 2)
-
-         return image
-
-     def prepare_latents(
-         self, batch_size, num_channels_latents, height, width, dtype, device
-     ):
-         shape = (
-             batch_size,
-             num_channels_latents,
-             int(height) // self.vae_scale_factor,
-             int(width) // self.vae_scale_factor,
-         )
-
-         latents = randn_tensor(shape, device=device, dtype=dtype)
-
-         # scale the initial noise by the standard deviation required by the scheduler
-         latents = latents * self.scheduler.init_noise_sigma
-         return latents
-
-     @property
-     def guidance_scale(self):
-         return self._guidance_scale
-
-     # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-     # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-     # corresponds to doing no classifier free guidance.
-     @property
-     def do_classifier_free_guidance(self):
-         return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
-
-     @property
-     def num_timesteps(self):
-         return self._num_timesteps
-
-     @torch.no_grad()
-     def __call__(
-         self,
-         prompt_embeds: torch.Tensor,
-         negative_prompt_embeds: torch.Tensor,
-         pooled_prompt_embeds: torch.Tensor,
-         negative_pooled_prompt_embeds: torch.Tensor,
-         image: PipelineImageInput = None,
-         num_inference_steps: int = 8,
-         guidance_scale: float = 1.5,
-         controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
-     ):
-         # 1. Check inputs. Raise error if not correct
-         self.check_inputs(
-             prompt_embeds,
-             negative_prompt_embeds,
-             pooled_prompt_embeds,
-             negative_pooled_prompt_embeds,
-             image,
-             controlnet_conditioning_scale,
-         )
-
-         self._guidance_scale = guidance_scale
-
-         # 2. Define call parameters
-         batch_size = 1
-         device = self._execution_device
-
-         # 4. Prepare image
-         if isinstance(self.controlnet, ControlNetModel_Union):
-             image = self.prepare_image(
-                 image=image,
-                 device=device,
-                 dtype=self.controlnet.dtype,
-                 do_classifier_free_guidance=self.do_classifier_free_guidance,
-             )
-             height, width = image.shape[-2:]
-         else:
-             assert False
-
-         # 5. Prepare timesteps
-         timesteps, num_inference_steps = retrieve_timesteps(
-             self.scheduler, num_inference_steps, device
-         )
-         self._num_timesteps = len(timesteps)
-
-         # 6. Prepare latent variables
-         num_channels_latents = self.unet.config.in_channels
-         latents = self.prepare_latents(
-             batch_size,
-             num_channels_latents,
-             height,
-             width,
-             prompt_embeds.dtype,
-             device,
-         )
-
-         # 7 Prepare added time ids & embeddings
-         add_text_embeds = pooled_prompt_embeds
-
-         add_time_ids = negative_add_time_ids = torch.tensor(
-             image.shape[-2:] + torch.Size([0, 0]) + image.shape[-2:]
-         ).unsqueeze(0)
-
-         if self.do_classifier_free_guidance:
-             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
-             add_text_embeds = torch.cat(
-                 [negative_pooled_prompt_embeds, add_text_embeds], dim=0
-             )
-             add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
-
-         prompt_embeds = prompt_embeds.to(device)
-         add_text_embeds = add_text_embeds.to(device)
-         add_time_ids = add_time_ids.to(device).repeat(batch_size, 1)
-
-         controlnet_image_list = [0, 0, 0, 0, 0, 0, image, 0]
-         union_control_type = (
-             torch.Tensor([0, 0, 0, 0, 0, 0, 1, 0])
-             .to(device, dtype=prompt_embeds.dtype)
-             .repeat(batch_size * 2, 1)
-         )
-
-         added_cond_kwargs = {
-             "text_embeds": add_text_embeds,
-             "time_ids": add_time_ids,
-             "control_type": union_control_type,
-         }
-
-         controlnet_prompt_embeds = prompt_embeds
-         controlnet_added_cond_kwargs = added_cond_kwargs
-
-         # 8. Denoising loop
-         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-
-         with self.progress_bar(total=num_inference_steps) as progress_bar:
-             for i, t in enumerate(timesteps):
-                 # expand the latents if we are doing classifier free guidance
-                 latent_model_input = (
-                     torch.cat([latents] * 2)
-                     if self.do_classifier_free_guidance
-                     else latents
-                 )
-                 latent_model_input = self.scheduler.scale_model_input(
-                     latent_model_input, t
-                 )
-
-                 # controlnet(s) inference
-                 control_model_input = latent_model_input
-
-                 down_block_res_samples, mid_block_res_sample = self.controlnet(
-                     control_model_input,
-                     t,
-                     encoder_hidden_states=controlnet_prompt_embeds,
-                     controlnet_cond_list=controlnet_image_list,
-                     conditioning_scale=controlnet_conditioning_scale,
-                     guess_mode=False,
-                     added_cond_kwargs=controlnet_added_cond_kwargs,
-                     return_dict=False,
-                 )
-
-                 # predict the noise residual
-                 noise_pred = self.unet(
-                     latent_model_input,
-                     t,
-                     encoder_hidden_states=prompt_embeds,
-                     timestep_cond=None,
-                     cross_attention_kwargs={},
-                     down_block_additional_residuals=down_block_res_samples,
-                     mid_block_additional_residual=mid_block_res_sample,
-                     added_cond_kwargs=added_cond_kwargs,
-                     return_dict=False,
-                 )[0]
-
-                 # perform guidance
-                 if self.do_classifier_free_guidance:
-                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                     noise_pred = noise_pred_uncond + guidance_scale * (
-                         noise_pred_text - noise_pred_uncond
-                     )
-
-                 # compute the previous noisy sample x_t -> x_t-1
-                 latents = self.scheduler.step(
-                     noise_pred, t, latents, return_dict=False
-                 )[0]
-
-                 if i == 2:
-                     prompt_embeds = prompt_embeds[-1:]
-                     add_text_embeds = add_text_embeds[-1:]
-                     add_time_ids = add_time_ids[-1:]
-                     union_control_type = union_control_type[-1:]
-
-                     added_cond_kwargs = {
-                         "text_embeds": add_text_embeds,
-                         "time_ids": add_time_ids,
-                         "control_type": union_control_type,
-                     }
-
-                     controlnet_prompt_embeds = prompt_embeds
-                     controlnet_added_cond_kwargs = added_cond_kwargs
-
-                     image = image[-1:]
-                     controlnet_image_list = [0, 0, 0, 0, 0, 0, image, 0]
-
-                     self._guidance_scale = 0.0
-
-                 if i == len(timesteps) - 1 or (
-                     (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
-                 ):
-                     progress_bar.update()
-                     yield latents_to_rgb(latents)
-
-         latents = latents / self.vae.config.scaling_factor
-         image = self.vae.decode(latents, return_dict=False)[0]
-         image = self.image_processor.postprocess(image)[0]
-
-         # Offload all models
-         self.maybe_free_model_hooks()
-
-         yield image
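
For reference, a minimal sketch of how the deleted pipeline was driven from the calling side, based only on the signatures in the file above (encode_prompt and the generator-style __call__). The pipeline object `pipe`, the input `masked_image`, and the prompt text are placeholders assumed to exist; they are not part of this commit.

# Hypothetical usage sketch; `pipe` is an already-constructed
# StableDiffusionXLFillPipeline and `masked_image` is a PIL.Image
# prepared elsewhere (both are assumptions, not part of this commit).
(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = pipe.encode_prompt("example prompt")

# __call__ is a generator: it yields low-resolution RGB previews while
# denoising and the decoded full-resolution PIL image as its last item.
result = None
for output in pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
    image=masked_image,
    num_inference_steps=8,
    guidance_scale=1.5,
):
    result = output  # after the loop, `result` holds the final image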