Create custom_pipeline/sde_ve_pipeline.py
custom_pipeline/sde_ve_pipeline.py (ADDED)
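For reference, the sampling loop in this file follows the predictor-corrector scheme for the variance-exploding SDE (Song et al., 2021): at each noise level it first runs `correct_steps` corrector updates (annealed Langevin dynamics, `step_correct`), then one predictor update (a reverse-diffusion step, `step_pred`). A sketch of the two updates, assuming s_theta(x, sigma_i) denotes the score predicted by the UNet and z ~ N(0, I):

% Corrector update with step size \epsilon_i (annealed Langevin dynamics):
x \leftarrow x + \epsilon_i \, s_\theta(x, \sigma_i) + \sqrt{2\epsilon_i} \, z
% Predictor update (reverse-diffusion step toward the next noise level \sigma_{i-1}):
x \leftarrow x + (\sigma_i^2 - \sigma_{i-1}^2) \, s_\theta(x, \sigma_i) + \sqrt{\sigma_i^2 - \sigma_{i-1}^2} \, z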
from typing import List, Optional, Tuple, Union

import torch

from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from diffusers.utils.torch_utils import randn_tensor

from ..scheduler import ScoreSdeVeScheduler
from ..unet import UNet2DModel


class ScoreSdeVePipeline(DiffusionPipeline):
    r"""
    Pipeline for unconditional image generation.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Parameters:
        unet ([`UNet2DModel`]):
            A `UNet2DModel` to denoise the encoded image.
        scheduler ([`ScoreSdeVeScheduler`]):
            A `ScoreSdeVeScheduler` to be used in combination with `unet` to denoise the encoded image.
    """

    unet: UNet2DModel
    scheduler: ScoreSdeVeScheduler

    def __init__(self, unet: UNet2DModel, scheduler: ScoreSdeVeScheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        num_inference_steps: int = 2000,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        **kwargs,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        The call function to the pipeline for generation.

        Args:
            batch_size (`int`, *optional*, defaults to 1):
                The number of images to generate.
            num_inference_steps (`int`, *optional*, defaults to 2000):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            generator (`torch.Generator`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return an [`ImagePipelineOutput`] instead of a plain tuple.

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images.
        """
        img_size = self.unet.config.sample_size
        shape = (batch_size, 3, img_size, img_size)

        model = self.unet

        # Start from pure Gaussian noise scaled to the scheduler's initial sigma.
        sample = randn_tensor(shape, generator=generator, device=self.device) * self.scheduler.init_noise_sigma

        self.scheduler.set_timesteps(num_inference_steps)
        self.scheduler.set_sigmas(num_inference_steps)

        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
            # Broadcast the current noise level across the batch.
            sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device)

            # Corrector: a few steps of annealed Langevin dynamics at the current noise level.
            for _ in range(self.scheduler.config.correct_steps):
                model_output = self.unet(sample, sigma_t).sample
                sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample

            # Predictor: one reverse-diffusion step toward the next (lower) noise level.
            model_output = model(sample, sigma_t).sample
            output = self.scheduler.step_pred(model_output, t, sample, generator=generator)

            sample, sample_mean = output.prev_sample, output.prev_sample_mean

        # Return the noise-free mean of the final step, mapped to [0, 1] HWC arrays.
        sample = sample_mean.clamp(0, 1)
        sample = sample.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            sample = self.numpy_to_pil(sample)

        if not return_dict:
            return (sample,)

        return ImagePipelineOutput(images=sample)
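A minimal usage sketch. Assumptions not present in the commit: the file is importable under a package root (here called `mypackage`) that exposes the `scheduler` and `unet` modules required by the relative imports above, and the custom classes can load a public SDE-VE checkpoint such as "google/ncsnpp-church-256" (illustrative, not verified here):

import torch

from mypackage.custom_pipeline.sde_ve_pipeline import ScoreSdeVePipeline  # hypothetical package root
from mypackage.scheduler import ScoreSdeVeScheduler  # assumed local module
from mypackage.unet import UNet2DModel  # assumed local module

# Checkpoint name is illustrative; adjust to whatever these custom classes support.
unet = UNet2DModel.from_pretrained("google/ncsnpp-church-256")
scheduler = ScoreSdeVeScheduler.from_pretrained("google/ncsnpp-church-256")

pipe = ScoreSdeVePipeline(unet=unet, scheduler=scheduler)
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

# Seeded generator for reproducible samples.
generator = torch.Generator(device=pipe.device).manual_seed(0)
image = pipe(batch_size=1, num_inference_steps=2000, generator=generator).images[0]
image.save("sde_ve_sample.png")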
