Create scheduler/__main__.py
Browse files- scheduler/__main__.py +266 -0
    	
        scheduler/__main__.py
    ADDED
    
    | @@ -0,0 +1,266 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from dataclasses import dataclass
         | 
| 2 | 
            +
            from typing import Optional, Tuple, Union
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            from diffusers.configuration_utils import ConfigMixin, register_to_config
         | 
| 5 | 
            +
            from diffusers.utils import BaseOutput
         | 
| 6 | 
            +
            from diffusers.utils.torch_utils import randn_tensor
         | 
| 7 | 
            +
            from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            @dataclass
         | 
| 10 | 
            +
            class SdeVeOutput(BaseOutput):
         | 
| 11 | 
            +
                """
         | 
| 12 | 
            +
                Output class for the scheduler's `step` function output.
         | 
| 13 | 
            +
                Args:
         | 
| 14 | 
            +
                    prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
         | 
| 15 | 
            +
                        Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
         | 
| 16 | 
            +
                        denoising loop.
         | 
| 17 | 
            +
                    prev_sample_mean (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
         | 
| 18 | 
            +
                        Mean averaged `prev_sample` over previous timesteps.
         | 
| 19 | 
            +
                """
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                prev_sample: torch.FloatTensor
         | 
| 22 | 
            +
                prev_sample_mean: torch.FloatTensor
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
            class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         | 
| 26 | 
            +
                """
         | 
| 27 | 
            +
                `ScoreSdeVeScheduler` is a variance exploding stochastic differential equation (SDE) scheduler.
         | 
| 28 | 
            +
                This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
         | 
| 29 | 
            +
                methods the library implements for all schedulers such as loading and saving.
         | 
| 30 | 
            +
                Args:
         | 
| 31 | 
            +
                    num_train_timesteps (`int`, defaults to 1000):
         | 
| 32 | 
            +
                        The number of diffusion steps to train the model.
         | 
| 33 | 
            +
                    snr (`float`, defaults to 0.15):
         | 
| 34 | 
            +
                        A coefficient weighting the step from the `model_output` sample (from the network) to the random noise.
         | 
| 35 | 
            +
                    sigma_min (`float`, defaults to 0.01):
         | 
| 36 | 
            +
                        The initial noise scale for the sigma sequence in the sampling procedure. The minimum sigma should mirror
         | 
| 37 | 
            +
                        the distribution of the data.
         | 
| 38 | 
            +
                    sigma_max (`float`, defaults to 1348.0):
         | 
| 39 | 
            +
                        The maximum value used for the range of continuous timesteps passed into the model.
         | 
| 40 | 
            +
                    sampling_eps (`float`, defaults to 1e-5):
         | 
| 41 | 
            +
                        The end value of sampling where timesteps decrease progressively from 1 to epsilon.
         | 
| 42 | 
            +
                    correct_steps (`int`, defaults to 1):
         | 
| 43 | 
            +
                        The number of correction steps performed on a produced sample.
         | 
| 44 | 
            +
                """
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                order = 1
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                @register_to_config
         | 
| 49 | 
            +
                def __init__(
         | 
| 50 | 
            +
                    self,
         | 
| 51 | 
            +
                    num_train_timesteps: int = 2000,
         | 
| 52 | 
            +
                    snr: float = 0.15,
         | 
| 53 | 
            +
                    sigma_min: float = 0.01,
         | 
| 54 | 
            +
                    sigma_max: float = 1348.0,
         | 
| 55 | 
            +
                    sampling_eps: float = 1e-5,
         | 
| 56 | 
            +
                    correct_steps: int = 1,
         | 
| 57 | 
            +
                ):
         | 
| 58 | 
            +
                    # standard deviation of the initial noise distribution
         | 
| 59 | 
            +
                    self.init_noise_sigma = sigma_max
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                    # setable values
         | 
| 62 | 
            +
                    self.timesteps = None
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                    self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
         | 
| 67 | 
            +
                    """
         | 
| 68 | 
            +
                    Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         | 
| 69 | 
            +
                    current timestep.
         | 
| 70 | 
            +
                    Args:
         | 
| 71 | 
            +
                        sample (`torch.FloatTensor`):
         | 
| 72 | 
            +
                            The input sample.
         | 
| 73 | 
            +
                        timestep (`int`, *optional*):
         | 
| 74 | 
            +
                            The current timestep in the diffusion chain.
         | 
| 75 | 
            +
                    Returns:
         | 
| 76 | 
            +
                        `torch.FloatTensor`:
         | 
| 77 | 
            +
                            A scaled input sample.
         | 
| 78 | 
            +
                    """
         | 
| 79 | 
            +
                    return sample
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                def set_timesteps(
         | 
| 82 | 
            +
                    self, num_inference_steps: int, sampling_eps: float = None, device: Union[str, torch.device] = None
         | 
| 83 | 
            +
                ):
         | 
| 84 | 
            +
                    """
         | 
| 85 | 
            +
                    Sets the continuous timesteps used for the diffusion chain (to be run before inference).
         | 
| 86 | 
            +
                    Args:
         | 
| 87 | 
            +
                        num_inference_steps (`int`):
         | 
| 88 | 
            +
                            The number of diffusion steps used when generating samples with a pre-trained model.
         | 
| 89 | 
            +
                        sampling_eps (`float`, *optional*):
         | 
| 90 | 
            +
                            The final timestep value (overrides value given during scheduler instantiation).
         | 
| 91 | 
            +
                        device (`str` or `torch.device`, *optional*):
         | 
| 92 | 
            +
                            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         | 
| 93 | 
            +
                    """
         | 
| 94 | 
            +
                    sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                    self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps, device=device)
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                def set_sigmas(
         | 
| 99 | 
            +
                    self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None
         | 
| 100 | 
            +
                ):
         | 
| 101 | 
            +
                    """
         | 
| 102 | 
            +
                    Sets the noise scales used for the diffusion chain (to be run before inference). The sigmas control the weight
         | 
| 103 | 
            +
                    of the `drift` and `diffusion` components of the sample update.
         | 
| 104 | 
            +
                    Args:
         | 
| 105 | 
            +
                        num_inference_steps (`int`):
         | 
| 106 | 
            +
                            The number of diffusion steps used when generating samples with a pre-trained model.
         | 
| 107 | 
            +
                        sigma_min (`float`, optional):
         | 
| 108 | 
            +
                            The initial noise scale value (overrides value given during scheduler instantiation).
         | 
| 109 | 
            +
                        sigma_max (`float`, optional):
         | 
| 110 | 
            +
                            The final noise scale value (overrides value given during scheduler instantiation).
         | 
| 111 | 
            +
                        sampling_eps (`float`, optional):
         | 
| 112 | 
            +
                            The final timestep value (overrides value given during scheduler instantiation).
         | 
| 113 | 
            +
                    """
         | 
| 114 | 
            +
                    sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
         | 
| 115 | 
            +
                    sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
         | 
| 116 | 
            +
                    sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
         | 
| 117 | 
            +
                    if self.timesteps is None:
         | 
| 118 | 
            +
                        self.set_timesteps(num_inference_steps, sampling_eps)
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                    self.sigmas = sigma_min * (sigma_max / sigma_min) ** (self.timesteps / sampling_eps)
         | 
| 121 | 
            +
                    self.discrete_sigmas = torch.exp(torch.linspace(math.log(sigma_min), math.log(sigma_max), num_inference_steps))
         | 
| 122 | 
            +
                    self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                def get_adjacent_sigma(self, timesteps, t):
         | 
| 125 | 
            +
                    return torch.where(
         | 
| 126 | 
            +
                        timesteps == 0,
         | 
| 127 | 
            +
                        torch.zeros_like(t.to(timesteps.device)),
         | 
| 128 | 
            +
                        self.discrete_sigmas[timesteps - 1].to(timesteps.device),
         | 
| 129 | 
            +
                    )
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                def step_pred(
         | 
| 132 | 
            +
                    self,
         | 
| 133 | 
            +
                    model_output: torch.FloatTensor,
         | 
| 134 | 
            +
                    timestep: int,
         | 
| 135 | 
            +
                    sample: torch.FloatTensor,
         | 
| 136 | 
            +
                    generator: Optional[torch.Generator] = None,
         | 
| 137 | 
            +
                    return_dict: bool = True,
         | 
| 138 | 
            +
                ) -> Union[SdeVeOutput, Tuple]:
         | 
| 139 | 
            +
                    """
         | 
| 140 | 
            +
                    Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
         | 
| 141 | 
            +
                    process from the learned model outputs (most often the predicted noise).
         | 
| 142 | 
            +
                    Args:
         | 
| 143 | 
            +
                        model_output (`torch.FloatTensor`):
         | 
| 144 | 
            +
                            The direct output from learned diffusion model.
         | 
| 145 | 
            +
                        timestep (`int`):
         | 
| 146 | 
            +
                            The current discrete timestep in the diffusion chain.
         | 
| 147 | 
            +
                        sample (`torch.FloatTensor`):
         | 
| 148 | 
            +
                            A current instance of a sample created by the diffusion process.
         | 
| 149 | 
            +
                        generator (`torch.Generator`, *optional*):
         | 
| 150 | 
            +
                            A random number generator.
         | 
| 151 | 
            +
                        return_dict (`bool`, *optional*, defaults to `True`):
         | 
| 152 | 
            +
                            Whether or not to return a [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`.
         | 
| 153 | 
            +
                    Returns:
         | 
| 154 | 
            +
                        [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`:
         | 
| 155 | 
            +
                            If return_dict is `True`, [`~schedulers.scheduling_sde_ve.SdeVeOutput`] is returned, otherwise a tuple
         | 
| 156 | 
            +
                            is returned where the first element is the sample tensor.
         | 
| 157 | 
            +
                    """
         | 
| 158 | 
            +
                    if self.timesteps is None:
         | 
| 159 | 
            +
                        raise ValueError(
         | 
| 160 | 
            +
                            "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
         | 
| 161 | 
            +
                        )
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                    timestep = timestep * torch.ones(
         | 
| 164 | 
            +
                        sample.shape[0], device=sample.device
         | 
| 165 | 
            +
                    )  # torch.repeat_interleave(timestep, sample.shape[0])
         | 
| 166 | 
            +
                    timesteps = (timestep * (len(self.timesteps) - 1)).long()
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                    # mps requires indices to be in the same device, so we use cpu as is the default with cuda
         | 
| 169 | 
            +
                    timesteps = timesteps.to(self.discrete_sigmas.device)
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                    sigma = self.discrete_sigmas[timesteps].to(sample.device)
         | 
| 172 | 
            +
                    adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device)
         | 
| 173 | 
            +
                    drift = torch.zeros_like(sample)
         | 
| 174 | 
            +
                    diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                    # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
         | 
| 177 | 
            +
                    # also equation 47 shows the analog from SDE models to ancestral sampling methods
         | 
| 178 | 
            +
                    diffusion = diffusion.flatten()
         | 
| 179 | 
            +
                    while len(diffusion.shape) < len(sample.shape):
         | 
| 180 | 
            +
                        diffusion = diffusion.unsqueeze(-1)
         | 
| 181 | 
            +
                    drift = drift - diffusion**2 * model_output
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                    #  equation 6: sample noise for the diffusion term of
         | 
| 184 | 
            +
                    noise = randn_tensor(
         | 
| 185 | 
            +
                        sample.shape, layout=sample.layout, generator=generator, device=sample.device, dtype=sample.dtype
         | 
| 186 | 
            +
                    )
         | 
| 187 | 
            +
                    prev_sample_mean = sample - drift  # subtract because `dt` is a small negative timestep
         | 
| 188 | 
            +
                    # TODO is the variable diffusion the correct scaling term for the noise?
         | 
| 189 | 
            +
                    prev_sample = prev_sample_mean + diffusion * noise  # add impact of diffusion field g
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                    if not return_dict:
         | 
| 192 | 
            +
                        return (prev_sample, prev_sample_mean)
         | 
| 193 | 
            +
             | 
| 194 | 
            +
                    return SdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean)
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                def step_correct(
         | 
| 197 | 
            +
                    self,
         | 
| 198 | 
            +
                    model_output: torch.FloatTensor,
         | 
| 199 | 
            +
                    sample: torch.FloatTensor,
         | 
| 200 | 
            +
                    generator: Optional[torch.Generator] = None,
         | 
| 201 | 
            +
                    return_dict: bool = True,
         | 
| 202 | 
            +
                ) -> Union[SchedulerOutput, Tuple]:
         | 
| 203 | 
            +
                    """
         | 
| 204 | 
            +
                    Correct the predicted sample based on the `model_output` of the network. This is often run repeatedly after
         | 
| 205 | 
            +
                    making the prediction for the previous timestep.
         | 
| 206 | 
            +
                    Args:
         | 
| 207 | 
            +
                        model_output (`torch.FloatTensor`):
         | 
| 208 | 
            +
                            The direct output from learned diffusion model.
         | 
| 209 | 
            +
                        sample (`torch.FloatTensor`):
         | 
| 210 | 
            +
                            A current instance of a sample created by the diffusion process.
         | 
| 211 | 
            +
                        generator (`torch.Generator`, *optional*):
         | 
| 212 | 
            +
                            A random number generator.
         | 
| 213 | 
            +
                        return_dict (`bool`, *optional*, defaults to `True`):
         | 
| 214 | 
            +
                            Whether or not to return a [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`.
         | 
| 215 | 
            +
                    Returns:
         | 
| 216 | 
            +
                        [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`:
         | 
| 217 | 
            +
                            If return_dict is `True`, [`~schedulers.scheduling_sde_ve.SdeVeOutput`] is returned, otherwise a tuple
         | 
| 218 | 
            +
                            is returned where the first element is the sample tensor.
         | 
| 219 | 
            +
                    """
         | 
| 220 | 
            +
                    if self.timesteps is None:
         | 
| 221 | 
            +
                        raise ValueError(
         | 
| 222 | 
            +
                            "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
         | 
| 223 | 
            +
                        )
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                    # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
         | 
| 226 | 
            +
                    # sample noise for correction
         | 
| 227 | 
            +
                    noise = randn_tensor(sample.shape, layout=sample.layout, generator=generator, device=sample.device).to(sample.device)
         | 
| 228 | 
            +
             | 
| 229 | 
            +
                    # compute step size from the model_output, the noise, and the snr
         | 
| 230 | 
            +
                    grad_norm = torch.norm(model_output.reshape(model_output.shape[0], -1), dim=-1).mean()
         | 
| 231 | 
            +
                    noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
         | 
| 232 | 
            +
                    step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
         | 
| 233 | 
            +
                    step_size = step_size * torch.ones(sample.shape[0]).to(sample.device)
         | 
| 234 | 
            +
                    # self.repeat_scalar(step_size, sample.shape[0])
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                    # compute corrected sample: model_output term and noise term
         | 
| 237 | 
            +
                    step_size = step_size.flatten()
         | 
| 238 | 
            +
                    while len(step_size.shape) < len(sample.shape):
         | 
| 239 | 
            +
                        step_size = step_size.unsqueeze(-1)
         | 
| 240 | 
            +
                    prev_sample_mean = sample + step_size * model_output
         | 
| 241 | 
            +
                    prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise
         | 
| 242 | 
            +
             | 
| 243 | 
            +
                    if not return_dict:
         | 
| 244 | 
            +
                        return (prev_sample,)
         | 
| 245 | 
            +
             | 
| 246 | 
            +
                    return SchedulerOutput(prev_sample=prev_sample)
         | 
| 247 | 
            +
             | 
| 248 | 
            +
                def add_noise(
         | 
| 249 | 
            +
                    self,
         | 
| 250 | 
            +
                    original_samples: torch.FloatTensor,
         | 
| 251 | 
            +
                    noise: torch.FloatTensor,
         | 
| 252 | 
            +
                    timesteps: torch.FloatTensor,
         | 
| 253 | 
            +
                ) -> torch.FloatTensor:
         | 
| 254 | 
            +
                    # Make sure sigmas and timesteps have the same device and dtype as original_samples
         | 
| 255 | 
            +
                    timesteps = timesteps.to(original_samples.device)
         | 
| 256 | 
            +
                    sigmas = self.config.sigma_min * (self.config.sigma_max / self.config.sigma_min) ** timesteps
         | 
| 257 | 
            +
                    noise = (
         | 
| 258 | 
            +
                        noise * sigmas[:, None, None, None]
         | 
| 259 | 
            +
                        if noise is not None
         | 
| 260 | 
            +
                        else torch.randn_like(original_samples) * sigmas[:, None, None, None]
         | 
| 261 | 
            +
                    )
         | 
| 262 | 
            +
                    noisy_samples = noise + original_samples
         | 
| 263 | 
            +
                    return noisy_samples
         | 
| 264 | 
            +
             | 
| 265 | 
            +
                def __len__(self):
         | 
| 266 | 
            +
                    return self.config.num_train_timesteps
         | 
