Update diffusion_webui/diffusion_models/controlnet/controlnet_inpaint/pipeline_stable_diffusion_controlnet_inpaint.py
diffusion_webui/diffusion_models/controlnet/controlnet_inpaint/pipeline_stable_diffusion_controlnet_inpaint.py  CHANGED

@@ -12,13 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-import numpy as np
-import PIL.Image
 import torch
-
+import PIL.Image
+import numpy as np
 
-
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import *
 
 EXAMPLE_DOC_STRING = """
     Examples:
@@ -98,15 +96,11 @@ def prepare_mask_and_masked_image(image, mask):
     """
     if isinstance(image, torch.Tensor):
         if not isinstance(mask, torch.Tensor):
-            raise TypeError(
-                f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not"
-            )
+            raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
 
         # Batch single image
         if image.ndim == 3:
-            assert (
-                image.shape[0] == 3
-            ), "Image outside a batch should be of shape (3, H, W)"
+            assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
             image = image.unsqueeze(0)
 
         # Batch and add channel dim for single mask
@@ -123,15 +117,9 @@ def prepare_mask_and_masked_image(image, mask):
         else:
             mask = mask.unsqueeze(1)
 
-        assert (
-            image.ndim == 4 and mask.ndim == 4
-        ), "Image and Mask must have 4 dimensions"
-        assert (
-            image.shape[-2:] == mask.shape[-2:]
-        ), "Image and Mask must have the same spatial dimensions"
-        assert (
-            image.shape[0] == mask.shape[0]
-        ), "Image and Mask must have the same batch size"
+        assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
+        assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
+        assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
 
         # Check image is in [-1, 1]
         if image.min() < -1 or image.max() > 1:
@@ -148,9 +136,7 @@ def prepare_mask_and_masked_image(image, mask):
         # Image as float32
         image = image.to(dtype=torch.float32)
     elif isinstance(mask, torch.Tensor):
-        raise TypeError(
-            f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not"
-        )
+        raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
     else:
         # preprocess image
         if isinstance(image, (PIL.Image.Image, np.ndarray)):
@@ -170,9 +156,7 @@ def prepare_mask_and_masked_image(image, mask):
             mask = [mask]
 
         if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
-            mask = np.concatenate(
-                [np.array(m.convert("L"))[None, None, :] for m in mask], axis=0
-            )
+            mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
             mask = mask.astype(np.float32) / 255.0
         elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
             mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
@@ -185,10 +169,7 @@ def prepare_mask_and_masked_image(image, mask):
 
     return mask, masked_image
 
-
-class StableDiffusionControlNetInpaintPipeline(
-    StableDiffusionControlNetPipeline
-):
+class StableDiffusionControlNetInpaintPipeline(StableDiffusionControlNetPipeline):
     r"""
     Pipeline for text-guided image inpainting using Stable Diffusion with ControlNet guidance.
 
@@ -217,28 +198,15 @@ class StableDiffusionControlNetInpaintPipeline(
         feature_extractor ([`CLIPFeatureExtractor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
-
+    
     def prepare_mask_latents(
-        self,
-        mask,
-        masked_image,
-        batch_size,
-        height,
-        width,
-        dtype,
-        device,
-        generator,
-        do_classifier_free_guidance,
+        self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
     ):
         # resize the mask to latents shape as we concatenate the mask to the latents
         # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
        # and half precision
         mask = torch.nn.functional.interpolate(
-            mask,
-            size=(
-                height // self.vae_scale_factor,
-                width // self.vae_scale_factor,
-            ),
+            mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
         )
         mask = mask.to(device=device, dtype=dtype)
 
@@ -247,19 +215,13 @@ class StableDiffusionControlNetInpaintPipeline(
         # encode the mask image into latents space so we can concatenate it to the latents
         if isinstance(generator, list):
             masked_image_latents = [
-                self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(
-                    generator=generator[i]
-                )
+                self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])
                 for i in range(batch_size)
             ]
             masked_image_latents = torch.cat(masked_image_latents, dim=0)
         else:
-            masked_image_latents = self.vae.encode(
-                masked_image
-            ).latent_dist.sample(generator=generator)
-        masked_image_latents = (
-            self.vae.config.scaling_factor * masked_image_latents
-        )
+            masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
+        masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
 
         # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
         if mask.shape[0] < batch_size:
@@ -277,35 +239,24 @@ class StableDiffusionControlNetInpaintPipeline(
                 f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
                 " Make sure the number of images that you pass is divisible by the total requested batch size."
             )
-            masked_image_latents = masked_image_latents.repeat(
-                batch_size // masked_image_latents.shape[0], 1, 1, 1
-            )
+            masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
 
         mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
         masked_image_latents = (
-            torch.cat([masked_image_latents] * 2)
-            if do_classifier_free_guidance
-            else masked_image_latents
+            torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
         )
 
         # aligning device to prevent device errors when concating it with the latent model input
-        masked_image_latents = masked_image_latents.to(
-            device=device, dtype=dtype
-        )
+        masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
         return mask, masked_image_latents
-
+    
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
-        prompt: Union[str, List[str]] = None,
+        prompt: Union[str, List[str]] = None,        
         image: Union[torch.FloatTensor, PIL.Image.Image] = None,
-        control_image: Union[
-            torch.FloatTensor,
-            PIL.Image.Image,
-            List[torch.FloatTensor],
-            List[PIL.Image.Image],
-        ] = None,
+        control_image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None,        
         mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
@@ -314,17 +265,13 @@ class StableDiffusionControlNetInpaintPipeline(
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
-        generator: Optional[
-            Union[torch.Generator, List[torch.Generator]]
-        ] = None,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[
-            Callable[[int, int, torch.FloatTensor], None]
-        ] = None,
+        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_conditioning_scale: float = 1.0,
@@ -346,7 +293,7 @@ class StableDiffusionControlNetInpaintPipeline(
                 `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
                 repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
                 to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
-                instead of 3, so the expected shape would be `(B, H, W, 1)`.
+                instead of 3, so the expected shape would be `(B, H, W, 1)`.            
             height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -415,14 +362,7 @@ class StableDiffusionControlNetInpaintPipeline(
 
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
-            prompt,
-            control_image,
-            height,
-            width,
-            callback_steps,
-            negative_prompt,
-            prompt_embeds,
-            negative_prompt_embeds,
+            prompt, control_image, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
         )
 
         # 2. Define call parameters
@@ -452,15 +392,15 @@ class StableDiffusionControlNetInpaintPipeline(
 
         # 4. Prepare image
         control_image = self.prepare_image(
-            control_image,
-            width,
-            height,
-            batch_size * num_images_per_prompt,
-            num_images_per_prompt,
-            device,
-            self.controlnet.dtype,
-        )
-
+            control_image,
+            width,
+            height,
+            batch_size * num_images_per_prompt,
+            num_images_per_prompt,
+            device,
+            self.controlnet.dtype,
+        )
+        
         if do_classifier_free_guidance:
             control_image = torch.cat([control_image] * 2)
 
@@ -469,7 +409,7 @@ class StableDiffusionControlNetInpaintPipeline(
         timesteps = self.scheduler.timesteps
 
         # 6. Prepare latent variables
-        num_channels_latents = self.controlnet.in_channels
+        num_channels_latents = self.controlnet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
@@ -480,7 +420,7 @@ class StableDiffusionControlNetInpaintPipeline(
             generator,
             latents,
         )
-
+        
         # EXTRA: prepare mask latents
         mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
         mask, masked_image_latents = self.prepare_mask_latents(
@@ -499,20 +439,12 @@ class StableDiffusionControlNetInpaintPipeline(
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
         # 8. Denoising loop
-        num_warmup_steps = (
-            len(timesteps) - num_inference_steps * self.scheduler.order
-        )
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
-                latent_model_input = (
-                    torch.cat([latents] * 2)
-                    if do_classifier_free_guidance
-                    else latents
-                )
-                latent_model_input = self.scheduler.scale_model_input(
-                    latent_model_input, t
-                )
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)                                
 
                 down_block_res_samples, mid_block_res_sample = self.controlnet(
                     latent_model_input,
@@ -529,9 +461,7 @@ class StableDiffusionControlNetInpaintPipeline(
                 mid_block_res_sample *= controlnet_conditioning_scale
 
                 # predict the noise residual
-                latent_model_input = torch.cat(
-                    [latent_model_input, mask, masked_image_latents], dim=1
-                )
+                latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
                 noise_pred = self.unet(
                     latent_model_input,
                     t,
@@ -544,30 +474,20 @@ class StableDiffusionControlNetInpaintPipeline(
                 # perform guidance
                 if do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (
-                        noise_pred_text - noise_pred_uncond
-                    )
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(
-                    noise_pred, t, latents, **extra_step_kwargs
-                ).prev_sample
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
                 # call the callback, if provided
-                if i == len(timesteps) - 1 or (
-                    (i + 1) > num_warmup_steps
-                    and (i + 1) % self.scheduler.order == 0
-                ):
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
                     if callback is not None and i % callback_steps == 0:
                         callback(i, t, latents)
 
         # If we do sequential model offloading, let's offload unet and controlnet
         # manually for max memory savings
-        if (
-            hasattr(self, "final_offload_hook")
-            and self.final_offload_hook is not None
-        ):
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.unet.to("cpu")
             self.controlnet.to("cpu")
             torch.cuda.empty_cache()
@@ -580,9 +500,7 @@ class StableDiffusionControlNetInpaintPipeline(
            image = self.decode_latents(latents)
 
             # 9. Run safety checker
-            image, has_nsfw_concept = self.run_safety_checker(
-                image, device, prompt_embeds.dtype
-            )
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
 
             # 10. Convert to PIL
             image = self.numpy_to_pil(image)
@@ -591,20 +509,13 @@ class StableDiffusionControlNetInpaintPipeline(
             image = self.decode_latents(latents)
 
             # 9. Run safety checker
-            image, has_nsfw_concept = self.run_safety_checker(
-                image, device, prompt_embeds.dtype
-            )
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
 
         # Offload last model to CPU
-        if (
-            hasattr(self, "final_offload_hook")
-            and self.final_offload_hook is not None
-        ):
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.final_offload_hook.offload()
 
         if not return_dict:
             return (image, has_nsfw_concept)
 
-        return StableDiffusionPipelineOutput(
-            images=image, nsfw_content_detected=has_nsfw_concept
-        )
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
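In short, this commit reflows the file from Black's multi-line style to single-line statements, reorders the imports, pulls the base pipeline's dependencies in via a star import from diffusers' pipeline_stable_diffusion_controlnet module, and reads the latent channel count from self.controlnet.config.in_channels rather than self.controlnet.in_channels. For context, below is a minimal usage sketch of the patched pipeline. It is not part of the commit: the checkpoint IDs, file names, and seed are illustrative assumptions, and it presumes a diffusers version of this era plus an inpainting Stable Diffusion checkpoint, since the denoising loop above feeds the UNet 9 channels (4 latent + 1 mask + 4 masked-image latents).

# Usage sketch (illustrative, not part of the commit).
import torch
from PIL import Image
from diffusers import ControlNetModel

# The pipeline class defined in the patched file; constructing it follows the
# usual from_pretrained pattern of its StableDiffusionControlNetPipeline base.
from diffusion_webui.diffusion_models.controlnet.controlnet_inpaint.pipeline_stable_diffusion_controlnet_inpaint import (
    StableDiffusionControlNetInpaintPipeline,
)

# Placeholder checkpoints: a canny ControlNet and an inpainting SD base model
# (the 9-channel UNet matches the mask/masked-latents concatenation above).
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")

# Placeholder inputs; mask is single-channel, white pixels get repainted.
init_image = Image.open("input.png").convert("RGB").resize((512, 512))
mask_image = Image.open("mask.png").convert("L").resize((512, 512))
control_image = Image.open("canny.png").convert("RGB").resize((512, 512))

result = pipe(
    prompt="a red brick house",
    image=init_image,             # image to be inpainted
    control_image=control_image,  # ControlNet conditioning image
    mask_image=mask_image,        # white = repaint, black = preserve
    num_inference_steps=30,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
result.save("out.png")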