scoresdeve-ema-multi-dsprites-64 / sde_ve_pipeline.py
giulio98's picture
Update sde_ve_pipeline.py
9f7dfeb verified
raw
history blame
32.6 kB
from typing import List, Optional, Tuple, Union
import torch
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput
from diffusers.models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
@dataclass
class UNet2DOutput(BaseOutput):
    """
    Container for the result of a [`UNet2DModel`] forward pass.

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Hidden states produced by the final layer of the model.
    """

    sample: torch.FloatTensor
class UNet2DModel(ModelMixin, ConfigMixin):
    r"""
    A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
            1)`.
        in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
        center_input_sample (`bool`, *optional*, defaults to `False`):
            Whether to apply `sample = 2 * sample - 1` to the input before the first convolution.
        time_embedding_type (`str`, *optional*, defaults to `"positional"`):
            Type of time embedding to use: `"positional"`, `"fourier"` or `"learned"`. With `"fourier"` the network
            output is additionally divided by the timestep (the SDE-VE pipeline passes the noise scale sigma as the
            timestep). `"learned"` requires `num_train_timesteps`.
        freq_shift (`int`, *optional*, defaults to 0): Frequency shift for the positional time embedding.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip sin to cos for the positional time embedding.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
            Tuple of block output channels.
        layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
        mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
        downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
        downsample_type (`str`, *optional*, defaults to `conv`):
            The downsample type for downsampling layers. Choose between "conv" and "resnet"
        upsample_type (`str`, *optional*, defaults to `conv`):
            The upsample type for upsampling layers. Choose between "conv" and "resnet"
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use in the ResNet blocks.
        attention_head_dim (`int`, *optional*, defaults to `8`):
            The attention head dimension. If `None`, each block uses its own channel count (a single head).
        norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
        attn_norm_num_groups (`int`, *optional*, defaults to `None`):
            If set to an integer, a group norm layer will be created in the mid block's [`Attention`] layer with the
            given number of groups. If left as `None`, the group norm layer will only be created if
            `resnet_time_scale_shift` is set to `default`, and if created will have `norm_num_groups` groups.
        norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
        add_attention (`bool`, *optional*, defaults to `True`):
            Whether to add attention to the mid block.
        class_embed_type (`str`, *optional*, defaults to `None`):
            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
            `"timestep"`, or `"identity"`.
        num_class_embeds (`int`, *optional*, defaults to `None`):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class
            conditioning with `class_embed_type` equal to `None`.
        num_train_timesteps (`int`, *optional*, defaults to `None`):
            Number of rows of the learned timestep embedding table. Required when `time_embedding_type="learned"`.
        set_W_to_weight (`bool`, *optional*, defaults to `True`):
            Forwarded to [`GaussianFourierProjection`] when `time_embedding_type="fourier"`.
    """

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[Union[int, Tuple[int, int]]] = None,
        in_channels: int = 3,
        out_channels: int = 3,
        center_input_sample: bool = False,
        time_embedding_type: str = "positional",
        freq_shift: int = 0,
        flip_sin_to_cos: bool = True,
        down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
        up_block_types: Tuple[str, ...] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
        block_out_channels: Tuple[int, ...] = (224, 448, 672, 896),
        layers_per_block: int = 2,
        mid_block_scale_factor: float = 1,
        downsample_padding: int = 1,
        downsample_type: str = "conv",
        upsample_type: str = "conv",
        dropout: float = 0.0,
        act_fn: str = "silu",
        attention_head_dim: Optional[int] = 8,
        norm_num_groups: int = 32,
        attn_norm_num_groups: Optional[int] = None,
        norm_eps: float = 1e-5,
        resnet_time_scale_shift: str = "default",
        add_attention: bool = True,
        class_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
        num_train_timesteps: Optional[int] = None,
        set_W_to_weight: Optional[bool] = True,
    ):
        super().__init__()

        self.sample_size = sample_size
        # Width of the time (and class) embedding passed to every block as `temb`.
        time_embed_dim = block_out_channels[0] * 4

        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        # input
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))

        # time
        if time_embedding_type == "fourier":
            self.time_proj = GaussianFourierProjection(
                embedding_size=block_out_channels[0], scale=16, set_W_to_weight=set_W_to_weight
            )
            # The Fourier projection output is sized as twice `embedding_size`.
            timestep_input_dim = 2 * block_out_channels[0]
        elif time_embedding_type == "positional":
            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
            timestep_input_dim = block_out_channels[0]
        elif time_embedding_type == "learned":
            if num_train_timesteps is None:
                raise ValueError("`num_train_timesteps` must be set when `time_embedding_type` is 'learned'.")
            self.time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0])
            timestep_input_dim = block_out_channels[0]
        else:
            # Without this branch `timestep_input_dim` would be unbound below.
            raise ValueError(
                f"Unknown `time_embedding_type`: {time_embedding_type!r}. Choose from 'fourier', 'positional' or 'learned'."
            )

        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        else:
            self.class_embedding = None

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
                downsample_padding=downsample_padding,
                resnet_time_scale_shift=resnet_time_scale_shift,
                downsample_type=downsample_type,
                dropout=dropout,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            temb_channels=time_embed_dim,
            dropout=dropout,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift=resnet_time_scale_shift,
            attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
            resnet_groups=norm_num_groups,
            attn_groups=attn_norm_num_groups,
            add_attention=add_attention,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
            is_final_block = i == len(block_out_channels) - 1

            # One extra layer per up block to consume the skip connection coming from `conv_in`/downsamplers.
            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
                resnet_time_scale_shift=resnet_time_scale_shift,
                upsample_type=upsample_type,
                dropout=dropout,
            )
            self.up_blocks.append(up_block)

        # out
        num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        class_labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet2DOutput, Tuple]:
        r"""
        The [`UNet2DModel`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor with the following shape `(batch, channel, height, width)`.
            timestep (`torch.FloatTensor` or `float` or `int`):
                The current timestep (or, for the SDE-VE pipeline, the noise scale) the sample is conditioned on.
                A scalar or a tensor with one entry per batch element.
            class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.

        Returns:
            [`~models.unet_2d.UNet2DOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
                returned where the first element is the sample tensor.

        Raises:
            ValueError: if `class_labels` is missing while class conditioning is configured, or provided while it
                is not.
        """
        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # A Python float is kept as a float: truncating it to an integer would change the noise level and, for
            # the "fourier" embedding, could even divide the output by zero below. Integers, and anything used to
            # index the "learned" embedding table, stay integral.
            keep_float = isinstance(timesteps, float) and self.config.time_embedding_type != "learned"
            timesteps = torch.tensor(
                [timesteps], dtype=torch.float32 if keep_float else torch.long, device=sample.device
            )
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None]
        # The time embedding must be computed on the same device as the model input.
        timesteps = timesteps.to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)
        emb = self.time_embedding(t_emb)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when doing class conditioning")

            if self.config.class_embed_type == "timestep":
                # Same fp32 -> model dtype cast as for `t_emb` above.
                class_labels = self.time_proj(class_labels).to(dtype=self.dtype)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb
        elif class_labels is not None:
            raise ValueError("class_embedding needs to be initialized in order to use class conditioning")

        # 2. pre-process
        skip_sample = sample
        sample = self.conv_in(sample)

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "skip_conv"):
                sample, res_samples, skip_sample = downsample_block(
                    hidden_states=sample, temb=emb, skip_sample=skip_sample
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid_block(sample, emb)

        # 5. up
        skip_sample = None
        for upsample_block in self.up_blocks:
            # Each up block consumes as many skip tensors as it has ResNet layers, taken from the end of the stack.
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            if hasattr(upsample_block, "skip_conv"):
                sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
            else:
                sample = upsample_block(sample, res_samples, emb)

        # 6. post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if skip_sample is not None:
            sample += skip_sample

        if self.config.time_embedding_type == "fourier":
            # Divide the output by the per-sample timestep (noise scale), broadcast over all non-batch dims.
            timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
            sample = sample / timesteps

        if not return_dict:
            return (sample,)

        return UNet2DOutput(sample=sample)
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
@dataclass
class SdeVeOutput(BaseOutput):
    """
    Output class for the scheduler's `step_pred` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        prev_sample_mean (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The previous-timestep sample *before* the stochastic diffusion-noise term is added, i.e. the mean of the
            reverse step (`prev_sample = prev_sample_mean + diffusion * noise`). Useful as the noise-free output of
            the final step.
    """

    prev_sample: torch.FloatTensor
    prev_sample_mean: torch.FloatTensor
class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
    """
    `ScoreSdeVeScheduler` is a variance exploding stochastic differential equation (SDE) scheduler.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    The noise scale follows the geometric schedule `sigma(t) = sigma_min * (sigma_max / sigma_min) ** t` for continuous
    timesteps `t` running from 1 down to `sampling_eps`.

    Args:
        num_train_timesteps (`int`, defaults to 2000):
            The number of diffusion steps to train the model.
        snr (`float`, defaults to 0.15):
            A coefficient weighting the step from the `model_output` sample (from the network) to the random noise.
        sigma_min (`float`, defaults to 0.01):
            The smallest noise scale of the sigma schedule (reached at the end of sampling). The minimum sigma should
            mirror the distribution of the data.
        sigma_max (`float`, defaults to 1348.0):
            The largest noise scale of the sigma schedule; also the standard deviation of the initial noise
            (`init_noise_sigma`).
        sampling_eps (`float`, defaults to 1e-5):
            The end value of sampling where timesteps decrease progressively from 1 to epsilon.
        correct_steps (`int`, defaults to 1):
            The number of correction steps performed on a produced sample.
    """

    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 2000,
        snr: float = 0.15,
        sigma_min: float = 0.01,
        sigma_max: float = 1348.0,
        sampling_eps: float = 1e-5,
        correct_steps: int = 1,
    ):
        # standard deviation of the initial noise distribution
        self.init_noise_sigma = sigma_max

        # setable values
        self.timesteps = None

        self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)

    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. This scheduler applies no scaling.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.FloatTensor`:
                The input sample, unchanged.
        """
        return sample

    def set_timesteps(
        self, num_inference_steps: int, sampling_eps: float = None, device: Union[str, torch.device] = None
    ):
        """
        Sets the continuous timesteps used for the diffusion chain (to be run before inference).

        The timesteps are `num_inference_steps` evenly spaced values from 1 down to `sampling_eps`.

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            sampling_eps (`float`, *optional*):
                The final timestep value (overrides value given during scheduler instantiation).
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps

        self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps, device=device)

    def set_sigmas(
        self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None
    ):
        """
        Sets the noise scales used for the diffusion chain (to be run before inference). The sigmas control the weight
        of the `drift` and `diffusion` components of the sample update.

        Two tables are built:
          * `discrete_sigmas`: `num_inference_steps` geometrically spaced values ascending from `sigma_min` to
            `sigma_max`, indexed by `step_pred`.
          * `sigmas`: `sigma_min * (sigma_max / sigma_min) ** t` evaluated at each entry of `self.timesteps`.

        If `set_timesteps` has not been called yet, it is called here with `num_inference_steps`.

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            sigma_min (`float`, optional):
                The smallest noise scale value (overrides value given during scheduler instantiation).
            sigma_max (`float`, optional):
                The largest noise scale value (overrides value given during scheduler instantiation).
            sampling_eps (`float`, optional):
                The final timestep value (overrides value given during scheduler instantiation).
        """
        sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
        sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
        sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
        if self.timesteps is None:
            self.set_timesteps(num_inference_steps, sampling_eps)

        self.discrete_sigmas = torch.exp(torch.linspace(math.log(sigma_min), math.log(sigma_max), num_inference_steps))
        # Vectorized form of the per-timestep geometric schedule (same elementwise ops as a Python loop over t).
        self.sigmas = sigma_min * (sigma_max / sigma_min) ** self.timesteps

    def get_adjacent_sigma(self, timesteps, t):
        """
        Return the discrete sigma one index below each entry of `timesteps`, or 0 where the index is 0.

        Args:
            timesteps (`torch.LongTensor`): Indices into `self.discrete_sigmas`.
            t (`torch.Tensor`): Tensor whose shape/dtype is used for the zero fallback.
        """
        # For index 0, `timesteps - 1` wraps to -1, but `torch.where` discards that value in favour of zero.
        return torch.where(
            timesteps == 0,
            torch.zeros_like(t.to(timesteps.device)),
            self.discrete_sigmas[timesteps - 1].to(timesteps.device),
        )

    def step_pred(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[int, float, torch.Tensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[SdeVeOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model output, which is interpreted as the score `grad_x log p_t(x)`.

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model (the estimated score).
            timestep (`float` or `torch.Tensor`):
                The current continuous timestep in `[sampling_eps, 1]` (an entry of `self.timesteps`). It is mapped to
                an index into `self.discrete_sigmas` via `long(timestep * (len(self.timesteps) - 1))`.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_sde_ve.SdeVeOutput`] is returned, otherwise a tuple
                `(prev_sample, prev_sample_mean)` is returned.

        Raises:
            ValueError: if `set_timesteps` has not been run.
        """
        if self.timesteps is None:
            raise ValueError(
                "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
            )

        # One copy of the timestep per batch element.
        timestep = timestep * torch.ones(sample.shape[0], device=sample.device)
        timesteps = (timestep * (len(self.timesteps) - 1)).long()

        # mps requires indices to be in the same device, so we use cpu as is the default with cuda
        timesteps = timesteps.to(self.discrete_sigmas.device)

        sigma = self.discrete_sigmas[timesteps].to(sample.device)
        adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device)
        drift = torch.zeros_like(sample)
        diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5

        # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
        # also equation 47 shows the analog from SDE models to ancestral sampling methods
        # Reshape the per-sample diffusion coefficient so it broadcasts over all non-batch dims.
        diffusion = diffusion.flatten()
        while len(diffusion.shape) < len(sample.shape):
            diffusion = diffusion.unsqueeze(-1)
        drift = drift - diffusion**2 * model_output

        # equation 6: sample noise for the diffusion term of
        noise = randn_tensor(
            sample.shape, layout=sample.layout, generator=generator, device=sample.device, dtype=sample.dtype
        )
        prev_sample_mean = sample - drift  # subtract because `dt` is a small negative timestep
        # TODO is the variable diffusion the correct scaling term for the noise?
        prev_sample = prev_sample_mean + diffusion * noise  # add impact of diffusion field g

        if not return_dict:
            return (prev_sample, prev_sample_mean)

        return SdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean)

    def step_correct(
        self,
        model_output: torch.FloatTensor,
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Correct the predicted sample based on the `model_output` of the network. This is often run repeatedly after
        making the prediction for the previous timestep.

        The step size is `2 * (snr * ||noise|| / ||model_output||) ** 2`, using batch-averaged norms, and the update is
        `sample + step_size * model_output + sqrt(2 * step_size) * noise`.

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        Raises:
            ValueError: if `set_timesteps` has not been run.
        """
        if self.timesteps is None:
            raise ValueError(
                "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
            )

        # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
        # sample noise for correction
        noise = randn_tensor(sample.shape, layout=sample.layout, generator=generator, device=sample.device).to(sample.device)

        # compute step size from the model_output, the noise, and the snr
        grad_norm = torch.norm(model_output.reshape(model_output.shape[0], -1), dim=-1).mean()
        noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
        step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
        step_size = step_size * torch.ones(sample.shape[0]).to(sample.device)

        # compute corrected sample: model_output term and noise term
        # Reshape the per-sample step size so it broadcasts over all non-batch dims.
        step_size = step_size.flatten()
        while len(step_size.shape) < len(sample.shape):
            step_size = step_size.unsqueeze(-1)
        prev_sample_mean = sample + step_size * model_output
        prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        Perturb `original_samples` with noise of scale `sigma(t) = sigma_min * (sigma_max / sigma_min) ** t`.

        Args:
            original_samples (`torch.FloatTensor`): Clean samples; the first dimension is the batch.
            noise (`torch.FloatTensor` or `None`): Unit-variance noise; drawn with `torch.randn_like` when `None`.
            timesteps (`torch.FloatTensor`): Continuous timesteps, one per batch element (or a single scalar).

        Returns:
            `torch.FloatTensor`: `original_samples + noise * sigma(timesteps)`, in `original_samples`' dtype.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        timesteps = timesteps.to(original_samples.device)
        sigmas = self.config.sigma_min * (self.config.sigma_max / self.config.sigma_min) ** timesteps
        # One sigma per batch element, broadcast over all remaining dimensions (works for any sample rank).
        sigmas = sigmas.to(original_samples.dtype).reshape(-1, *([1] * (original_samples.ndim - 1)))
        if noise is None:
            noise = torch.randn_like(original_samples)
        noisy_samples = noise * sigmas + original_samples
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
class ScoreSdeVePipeline(DiffusionPipeline):
    r"""
    Pipeline for unconditional image generation.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Parameters:
        unet ([`UNet2DModel`]):
            A `UNet2DModel` to denoise the encoded image.
        scheduler ([`ScoreSdeVeScheduler`]):
            A `ScoreSdeVeScheduler` to be used in combination with `unet` to denoise the encoded image.
    """

    unet: UNet2DModel
    scheduler: ScoreSdeVeScheduler

    def __init__(self, unet, scheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        num_inference_steps: int = 2000,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        **kwargs,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        The call function to the pipeline for generation.

        Each step runs `scheduler.config.correct_steps` corrector updates (`step_correct`) followed by one predictor
        update (`step_pred`); the UNet is conditioned on the continuous noise scale of the current step.

        Args:
            batch_size (`int`, *optional*, defaults to 1):
                The number of images to generate.
            num_inference_steps (`int`, *optional*, defaults to 2000):
                The number of predictor steps. Must be at least 1.
            generator (`torch.Generator` or `List[torch.Generator]`, `optional`):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            output_type (`str`, `optional`, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple.
            kwargs:
                Accepted for interface compatibility and ignored.

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images.

        Raises:
            ValueError: if `num_inference_steps < 1` or `unet.config.sample_size` is not set.
        """
        if num_inference_steps < 1:
            # With zero steps the loop never produces `sample_mean`.
            raise ValueError(f"`num_inference_steps` must be a positive integer, got {num_inference_steps}.")

        sample_size = self.unet.config.sample_size
        if sample_size is None:
            raise ValueError("`unet.config.sample_size` must be set to know the size of the images to generate.")
        # `sample_size` may be a single int (square images) or a (height, width) pair.
        height, width = (sample_size, sample_size) if isinstance(sample_size, int) else sample_size
        shape = (batch_size, self.unet.config.in_channels, height, width)

        sample = randn_tensor(shape, generator=generator, device=self.device) * self.scheduler.init_noise_sigma

        self.scheduler.set_timesteps(num_inference_steps)
        self.scheduler.set_sigmas(num_inference_steps)

        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
            # The model is conditioned on the current noise scale, one copy per batch element.
            sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device)

            # correction step
            for _ in range(self.scheduler.config.correct_steps):
                model_output = self.unet(sample, sigma_t).sample
                sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample

            # prediction step
            model_output = self.unet(sample, sigma_t).sample
            output = self.scheduler.step_pred(model_output, t, sample, generator=generator)

            sample, sample_mean = output.prev_sample, output.prev_sample_mean

        # The final image is the noise-free mean of the last predictor step, clipped to [0, 1].
        sample = sample_mean.clamp(0, 1)
        sample = sample.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            sample = self.numpy_to_pil(sample)

        if not return_dict:
            return (sample,)

        return ImagePipelineOutput(images=sample)