File size: 8,565 Bytes
15acbf0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
import inspect
from typing import List, Optional, Tuple, Union
import torch
from diffusers.models import UNet2DModel, VQModel
from diffusers.schedulers import DDIMScheduler
from diffusers.utils import randn_tensor
from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
import copy
class SDMLDMPipeline(DiffusionPipeline):
    r"""
    Latent diffusion pipeline conditioned on a segmentation map.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Parameters:
        vae ([`VQModel`]):
            Vector-quantized (VQ) Model to encode and decode images to and from latent representations.
        unet ([`UNet2DModel`]):
            U-Net architecture to denoise the encoded image latents. It is called as `unet(latents, segmap, t)`,
            so it must accept a segmentation map as its second positional argument; `config.segmap_channels` is
            read when no segmap is supplied.
        scheduler ([`SchedulerMixin`]):
            [`DDIMScheduler`] is to be used in combination with `unet` to denoise the encoded image latents.
        torch_dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
            dtype the initial latents and the segmentation map are cast to before denoising.
        resolution (`int`, *optional*, defaults to 512):
            Used to derive the latent size when `resolution_type` is neither `"crack"` nor `"crack_256"`.
        resolution_type (`str`, *optional*, defaults to `"city"`):
            `"crack"` selects a 64x64 latent, `"crack_256"` a 256x256 latent. Any other value computes the latent
            size as `(resolution // 4, 1440 // (sc * 4))` with `sc = 1080 // resolution`.
    """

    def __init__(self, vae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler, torch_dtype=torch.float16, resolution=512, resolution_type="city"):
        super().__init__()
        self.register_modules(vae=vae, unet=unet, scheduler=scheduler)
        self.torch_dtype = torch_dtype
        self.resolution = resolution
        self.resolution_type = resolution_type

    def _configure_sample_size(self) -> Tuple[int, int]:
        """Set ``self.unet.config.sample_size`` for ``self.resolution_type`` and return it as ``(height, width)``."""
        if self.resolution_type == "crack":
            sample_size = (64, 64)
        elif self.resolution_type == "crack_256":
            sample_size = (256, 256)
        else:
            # The hard-coded 1080x1440 reference size is scaled by the same integer factor on both axes,
            # then divided by 4 to reach latent resolution.
            sc = 1080 // self.resolution
            sample_size = (self.resolution // 4, 1440 // (sc * 4))
        # Written back to the config because the U-Net / callers may read it from there.
        self.unet.config.sample_size = sample_size
        return sample_size

    def _predict_noise(self, latent_model_input, segmap, t, debug):
        """Run the U-Net once and return ``(sample, extra)``; ``extra`` is the debug payload, or ``None``."""
        if debug:
            # In debug mode the U-Net returns an `(output, extra)` pair instead of a bare output object.
            output, extra = self.unet(latent_model_input, segmap, t)
            return output.sample, extra
        return self.unet(latent_model_input, segmap, t).sample, None

    def _decode_latents(self, latents, output_type):
        """Decode a latent batch with the VAE into images in ``[0, 1]`` (numpy, or PIL if ``output_type == "pil"``)."""
        latents = latents / self.vae.config.scaling_factor
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        # Cast to float32 so `.numpy()` also works when `torch_dtype` is bfloat16.
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)
        return image

    @torch.no_grad()
    def __call__(
        self,
        segmap=None,
        batch_size: int = 8,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        eta: float = 0.0,
        num_inference_steps: int = 1000,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        every_step_save: Optional[int] = None,
        s: float = 1,
        num_evolution_per_mask=10,
        debug=False,
        **kwargs,
    ) -> Union[Tuple, ImagePipelineOutput]:
        r"""
        Args:
            segmap (`torch.Tensor`, *optional*):
                Conditioning map passed to the U-Net at every step. If omitted, an all-zeros map of shape
                `(batch_size, unet.config.segmap_channels, h, w)` at latent resolution is used. It is moved to the
                pipeline device and cast to `torch_dtype`.
            batch_size (`int`, *optional*, defaults to 8):
                Number of images to generate. When `batch_size == 1` and `num_evolution_per_mask > 1`,
                `num_evolution_per_mask` noise samples are drawn instead.
            generator (`torch.Generator`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            eta (`float`, *optional*, defaults to 0.0):
                Forwarded to `scheduler.step` only if that method accepts an `eta` argument.
            num_inference_steps (`int`, *optional*, defaults to 1000):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
            every_step_save (`int`, *optional*):
                If set, the latents after every `every_step_save`-th step are kept and decoded; the returned
                images are then a list with one decoded batch per saved step.
            s (`float`, *optional*, defaults to 1):
                Guidance scale. When `s > 1`, a second U-Net pass with an all-zeros segmap is run and the prediction
                becomes `cond + s * (cond - uncond)`.
            num_evolution_per_mask (`int`, *optional*, defaults to 10):
                Number of noise samples drawn when `batch_size == 1`.
            debug (`bool`, *optional*, defaults to `False`):
                Runs the U-Net in debug mode and returns the U-Net's extra debug payload from the last denoising
                step instead of images. The U-Net's previous `debug` flag is restored afterwards.

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.model.ImagePipelineOutput`] if `return_dict` is
            True, otherwise a `tuple`. When returning a tuple, the first element is the generated images (or, with
            `every_step_save`, a list of image batches).
        """
        height, width = self._configure_sample_size()

        if segmap is None:
            print("Didn't inpute any segmap, use the empty as the input")
            segmap = torch.zeros(batch_size, self.unet.config.segmap_channels, height, width)
        segmap = segmap.to(self.device).type(self.torch_dtype)

        # A single mask is sampled `num_evolution_per_mask` times instead of once.
        if batch_size == 1 and num_evolution_per_mask > batch_size:
            num_samples = num_evolution_per_mask
        else:
            num_samples = batch_size
        latents = randn_tensor(
            (num_samples, self.unet.config.in_channels, height, width),
            generator=generator,
        )
        latents = latents.to(self.device).type(self.torch_dtype)
        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        self.scheduler.set_timesteps(num_inference_steps=num_inference_steps)

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        extra_kwargs = {}
        if "eta" in inspect.signature(self.scheduler.step).parameters:
            extra_kwargs["eta"] = eta

        # When the scheduler models a learned variance, the U-Net output carries extra variance channels that are
        # split off before guidance and re-attached for `scheduler.step`. `variance_type` may be missing or None.
        variance_type = getattr(self.scheduler, "variance_type", None)
        learned_variance = variance_type is not None and "learn" in variance_type
        latent_channels = latents.shape[1]

        step_latents = []
        debug_outputs = []
        previous_unet_debug = getattr(self.unet, "debug", False)
        if debug:
            self.unet.debug = True
        try:
            for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
                latent_model_input = self.scheduler.scale_model_input(latents, t)
                noise_prediction, extra = self._predict_noise(latent_model_input, segmap, t, debug)
                if debug:
                    debug_outputs.append(extra)

                if learned_variance:
                    model_pred, var_pred = torch.split(noise_prediction, latent_channels, dim=1)
                else:
                    model_pred = noise_prediction

                if s > 1.0:
                    # Guidance: compare against an unconditional pass with an all-zeros segmap.
                    uncond_pred, _ = self._predict_noise(latent_model_input, torch.zeros_like(segmap), t, debug)
                    if learned_variance:
                        uncond_pred, _ = torch.split(uncond_pred, latent_channels, dim=1)
                    model_pred = model_pred + s * (model_pred - uncond_pred)

                # Always step with the (possibly guided) prediction, with or without learned variance.
                if learned_variance:
                    model_output = torch.cat((model_pred, var_pred), dim=1)
                else:
                    model_output = model_pred
                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(model_output, t, latents, **extra_kwargs).prev_sample

                if every_step_save is not None and (i + 1) % every_step_save == 0:
                    step_latents.append(latents.clone())
        finally:
            # Do not leave the U-Net in debug mode for later calls that expect a plain output object.
            if debug:
                self.unet.debug = previous_unet_debug

        if debug:
            return debug_outputs[-1]

        # decode the image latents with the VAE (same decoding for saved intermediate steps and the final batch)
        if every_step_save is not None:
            image = [self._decode_latents(step, output_type) for step in step_latents]
        else:
            image = self._decode_latents(latents, output_type)

        if not return_dict:
            return (image,)
        return ImagePipelineOutput(images=image)