from functools import partial
from diffusers import DDPMScheduler, DPMSolverMultistepScheduler, UniPCMultistepScheduler, DPMSolverSinglestepScheduler
from diffusers.pipeline_utils import DiffusionPipeline
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from typing import List, Optional, Tuple, Union
import numpy as np
from diffusers.schedulers.scheduling_utils import SchedulerOutput
from diffusers.schedulers.scheduling_ddpm import DDPMSchedulerOutput
from diffusers.utils import randn_tensor, BaseOutput


def _scale_noise_by_variance(variance_type, variance, noise):
    """Turn the value returned by a scheduler's ``_get_variance`` into additive noise.

    ``_get_variance`` returns a different quantity depending on ``variance_type``
    (see ``ModifiedUniPCScheduler._get_variance`` below):

    * ``"fixed_small_log"``: already a standard deviation, used as-is;
    * ``"learned_range"`` / ``"fixed_large_log"``: a log-variance, so the std is
      ``exp(0.5 * v)`` (taking ``sqrt`` of a negative log-variance would give NaN);
    * anything else: a variance, so the std is ``sqrt(v)``.

    NOTE(review): for ModifiedDDPMScheduler this assumes the parent
    ``DDPMScheduler._get_variance`` follows the same conventions -- verify for
    your diffusers version.

    Args:
        variance_type: the scheduler's ``variance_type`` string.
        variance: value returned by ``_get_variance``.
        noise: unit Gaussian noise shaped like the sample.

    Returns:
        ``noise`` scaled by the standard deviation implied by ``variance``.
    """
    if variance_type == "fixed_small_log":
        return variance * noise
    if variance_type in ("learned_range", "fixed_large_log"):
        return torch.exp(0.5 * variance) * noise
    return (variance ** 0.5) * noise


### Testing the DDPM Scheduler for Variant
class ModifiedDDPMScheduler(DDPMScheduler):
    """Experimental copy of ``DDPMScheduler.step`` used to test variance variants.

    There is deliberately no ``__init__`` override. diffusers'
    ``ConfigMixin.from_config`` derives the accepted config keys from the
    signature of ``cls.__init__``, so a pass-through
    ``__init__(self, *args, **kwargs)`` would make ``from_config`` silently
    ignore the stored configuration. The inherited constructor accepts exactly
    the same arguments.
    """

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        generator=None,
        return_dict: bool = True,
    ) -> Union[DDPMSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            generator: random number generator.
            return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class

        Returns:
            [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
        """
        t = timestep
        prev_t = self.previous_timestep(t)

        # A model with a learned variance outputs twice the sample's channels on
        # dim 1: the regular prediction first, then the variance parameters.
        if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
            model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
        else:
            predicted_variance = None

        # 1. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        current_beta_t = 1 - current_alpha_t

        # 2. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
                " `v_prediction`  for the DDPMScheduler."
            )

        # 3. Clip or threshold "predicted x_0"
        if self.config.thresholding:
            pred_original_sample = self._threshold_sample(pred_original_sample)
        elif self.config.clip_sample:
            pred_original_sample = pred_original_sample.clamp(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )

        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
        current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t

        # 5. Compute predicted previous sample µ_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

        # 6. Add noise (none on the final step t == 0)
        variance = 0
        if t > 0:
            variance_noise = randn_tensor(
                model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
            )
            variance = _scale_noise_by_variance(
                self.variance_type,
                self._get_variance(t, predicted_variance=predicted_variance),
                variance_noise,
            )

        pred_prev_sample = pred_prev_sample + variance

        if not return_dict:
            return (pred_prev_sample,)

        return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)


class ModifiedUniPCScheduler(UniPCMultistepScheduler):
    '''
    UniPCMultistepScheduler with DDPM-style noise injection.

    ``step`` runs the parent's UniPC update and then adds Gaussian noise whose
    scale comes from the DDPM posterior variance computed by ``_get_variance``.
    For "learned"/"learned_range" it uses the variance predicted by the model.
    ``previous_timestep`` and ``_get_variance`` are ports of the DDPMScheduler
    helpers of the same names.
    '''

    def __init__(self, variance_type: str = "fixed_small", *args, **kwargs):
        """Create the scheduler.

        Args:
            variance_type: one of "fixed_small", "fixed_small_log", "fixed_large",
                "fixed_large_log", "learned", "learned_range". It selects how the
                injected noise is scaled.
            *args, **kwargs: forwarded to ``UniPCMultistepScheduler.__init__``.

        NOTE(review): this signature hides the parent's parameters. diffusers'
        ``from_config`` derives accepted keys from ``cls.__init__``, so
        ``ModifiedUniPCScheduler.from_config(cfg)`` presumably forwards only
        ``variance_type`` from ``cfg``. Prefer constructing with explicit kwargs.
        """
        super().__init__(*args, **kwargs)
        self.custom_timesteps = False
        self.variance_type = variance_type
        # ``_get_variance`` uses ``self.one`` when the previous timestep is
        # negative. DDPMScheduler defines it; set it here in case the UniPC
        # parent does not.
        self.one = torch.tensor(1.0)
        # NOTE(review): this sets an attribute on the (frozen) config object
        # rather than re-registering the config. Confirm that your diffusers
        # version reads ``config.timestep_spacing`` by attribute access.
        self.config.timestep_spacing = "leading"

    def previous_timestep(self, timestep):
        """Return the timestep that follows ``timestep`` in the sampling schedule.

        With custom timesteps this is the next entry of ``self.timesteps``
        (``-1`` after the last one). Otherwise it assumes a constant stride of
        ``num_train_timesteps // num_inference_steps``.
        """
        if self.custom_timesteps:
            index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
            if index == self.timesteps.shape[0] - 1:
                prev_t = torch.tensor(-1)
            else:
                prev_t = self.timesteps[index + 1]
        else:
            num_inference_steps = (
                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
            )
            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps

        return prev_t

    def _get_variance(self, t, predicted_variance=None, variance_type="learned_range"):
        """Compute the noise scale for timestep ``t``.

        Args:
            t: current timestep.
            predicted_variance: model-predicted variance parameters. Required for
                "learned" and "learned_range".
            variance_type: which variance to return. ``None`` means use
                ``self.variance_type``.

        Returns:
            A variance ("fixed_small", "fixed_large"), a standard deviation
            ("fixed_small_log"), a log-variance ("fixed_large_log",
            "learned_range") or ``predicted_variance`` itself ("learned").

        Raises:
            ValueError: a learned variance type was requested without
                ``predicted_variance``.
        """
        prev_t = self.previous_timestep(t)

        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev

        # DDPM posterior variance, formula (7) of https://arxiv.org/pdf/2006.11239.pdf
        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
        # Clamp so the logs below never see 0.
        variance = torch.clamp(variance, min=1e-20)

        if variance_type is None:
            # UniPC's config has no ``variance_type``; the instance attribute holds it.
            variance_type = self.variance_type

        if variance_type in ("learned", "learned_range") and predicted_variance is None:
            raise ValueError(
                f"variance_type={variance_type!r} needs a model-predicted variance (a model_output with twice "
                "the sample's channels), but predicted_variance is None."
            )

        if variance_type == "fixed_small":
            pass
        elif variance_type == "fixed_small_log":
            # Standard deviation computed in log space.
            variance = torch.exp(0.5 * torch.log(variance))
        elif variance_type == "fixed_large":
            variance = current_beta_t
        elif variance_type == "fixed_large_log":
            variance = torch.log(current_beta_t)
        elif variance_type == "learned":
            return predicted_variance
        elif variance_type == "learned_range":
            # Interpolate in log space between the posterior variance and beta_t.
            # predicted_variance is mapped from [-1, 1] to [0, 1].
            min_log = torch.log(variance)
            max_log = torch.log(current_beta_t)
            frac = (predicted_variance + 1) / 2
            variance = frac * max_log + (1 - frac) * min_log

        return variance

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """Run one UniPC update, then add noise scaled by ``_get_variance``.

        Args:
            model_output: model output. With twice the sample's channels on dim 1
                and a learned ``variance_type``, the second half is the predicted
                variance.
            timestep: current discrete timestep.
            sample: current sample.
            return_dict: return a ``DDPMSchedulerOutput`` instead of a tuple.

        Returns:
            ``(prev_sample,)`` or ``DDPMSchedulerOutput``. Its
            ``pred_original_sample`` field holds ``prev_sample``, because the
            UniPC update does not expose an x_0 estimate; do not treat it as x_0.
        """
        if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
            model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
        else:
            predicted_variance = None

        prev_sample = super().step(model_output, timestep, sample, return_dict=False)[0]

        # No noise on the final step (timestep == 0).
        variance = 0
        if timestep > 0:
            variance_noise = randn_tensor(
                model_output.shape, generator=None, device=model_output.device, dtype=model_output.dtype
            )
            variance = _scale_noise_by_variance(
                self.variance_type,
                # Pass the configured type explicitly; the parameter default
                # ("learned_range") would ignore self.variance_type.
                self._get_variance(
                    timestep, predicted_variance=predicted_variance, variance_type=self.variance_type
                ),
                variance_noise,
            )

        prev_sample = prev_sample + variance

        if not return_dict:
            return (prev_sample,)

        return DDPMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=prev_sample)


def build_proc(sch_cfg=None, _sch=None, **kwargs):
    """Instantiate scheduler class ``_sch``.

    * kwargs and no config: ``_sch(**kwargs)``, a fully custom scheduler.
    * dict-like config (plain ``dict`` or diffusers ``FrozenDict``):
      ``_sch.from_config(sch_cfg, **kwargs)``. kwargs override config entries.
    * anything else (e.g. a model path / repo id):
      ``_sch.from_config(sch_cfg, subfolder="scheduler", **kwargs)``.

    Binding extra kwargs with ``partial`` (e.g. ``algorithm_type``) therefore
    overrides the given config rather than discarding it.
    """
    if kwargs and sch_cfg is None:
        return _sch(**kwargs)
    if isinstance(sch_cfg, dict):
        # ``from_config`` takes the config as its first positional argument;
        # unpacking it into keyword arguments makes diffusers reject the call.
        return _sch.from_config(sch_cfg, **kwargs)
    return _sch.from_config(sch_cfg, subfolder="scheduler", **kwargs)


# Name -> builder. Each builder is called as builder(config) or builder(**kwargs).
scheduler_factory = {
    'UniPC' : partial(build_proc, _sch=UniPCMultistepScheduler),
    'modifiedUniPC' : partial(build_proc, _sch=ModifiedUniPCScheduler),
    # DPM family
    'DDPM' : partial(build_proc, _sch=DDPMScheduler),
    'DPMSolver' : partial(build_proc, _sch=DPMSolverMultistepScheduler, algorithm_type='dpmsolver'),
    'DPMSolver++' : partial(build_proc, _sch=DPMSolverMultistepScheduler),
    'DPMSolverSingleStep' : partial(build_proc, _sch=DPMSolverSinglestepScheduler)
}


def scheduler_setup(
    pipe: DiffusionPipeline = None,
    scheduler_type: str = 'UniPC',
    from_config=None,
    *,
    use_factory: bool = False,
    **kwargs,
):
    """Replace ``pipe.scheduler`` and return ``pipe``.

    By default (``use_factory=False``) this keeps the experiment's hard-coded
    behaviour: it installs ``ModifiedUniPCScheduler(variance_type="learned_range")``
    and ignores ``scheduler_type``, ``from_config`` and ``kwargs``.

    With ``use_factory=True`` the scheduler comes from ``scheduler_factory``:
    built from ``kwargs`` if any are given, otherwise from ``from_config`` (or
    ``pipe.scheduler.config``).

    Raises:
        TypeError: ``pipe`` is not a DiffusionPipeline.
        ValueError: ``use_factory`` is set and ``scheduler_type`` is unknown.
    """
    if not isinstance(pipe, DiffusionPipeline):
        raise TypeError(f'pipe should be DiffusionPipeline, but given {type(pipe)}\n')

    sch_cfg = from_config if from_config else pipe.scheduler.config

    if use_factory:
        if scheduler_type not in scheduler_factory:
            raise ValueError(
                f"unknown scheduler_type {scheduler_type!r}; expected one of {sorted(scheduler_factory)}"
            )
        build = scheduler_factory[scheduler_type]
        # Customised kwargs build a scheduler from scratch (no config).
        pipe.scheduler = build(**kwargs) if kwargs else build(sch_cfg)
        return pipe

    print("Scheduler type in Scheduler_factory.py is Hard-coded to modifyUniPC, Please change it back to AutoDetect functionality if you want to change scheudler")
    pipe.scheduler = ModifiedUniPCScheduler(variance_type="learned_range", )
    return pipe
# unittest of scheduler..
if __name__ == "__main__":
    from diffusers import StableDiffusionPipeline

    def ld_mod():
        """Load a DDPM scheduler plus an fp16 VAE and fp16 SDM UNet on CUDA for the SDM-LDM demo."""
        from diffusers import AutoencoderKL

        # NOTE(review): `SDMUNet2DModel` is not imported anywhere in this file.
        # Import it from the project module that defines it before calling ld_mod().
        noise_scheduler = DDPMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
        vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to("cuda").to(torch.float16)
        unet = SDMUNet2DModel.from_pretrained("/data/harry/Data_generation/diffusers-main/examples/VAESDM/LDM-sdm-model/checkpoint-46000", subfolder="unet").to("cuda").to(torch.float16)
        return noise_scheduler, vae, unet

    path = "CompVis/stable-diffusion-v1-4"
    pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)

    # change scheduler
    # customized args : once you customized, customize forever ~ no from_config
    #   pipe = scheduler_setup(pipe, 'DPMSolver++', thresholding=True)
    # from_config:
    # NOTE(review): scheduler_setup may be hard-coded to install
    # ModifiedUniPCScheduler regardless of `scheduler_type`; check it before
    # relying on 'DPMSolverSingleStep' here.
    pipe = scheduler_setup(pipe, 'DPMSolverSingleStep')
    pipe = pipe.to("cuda")

    prompt = "a highly realistic photo of green turtle"
    generator = torch.manual_seed(0)
    # only 15 steps are needed for good results => 2-4 seconds on GPU
    image = pipe(prompt, generator=generator, num_inference_steps=15).images[0]
    # save image
    image.save("turtle.png")

    # Alternative demo (disabled): load & wrap submodules into the pipe API.
    #   from Pipline import SDMLDMPipeline
    #   noise_scheduler, vae, unet = ld_mod()
    #   pipe = SDMLDMPipeline(
    #       unet=unet,
    #       vqvae=vae,
    #       scheduler=noise_scheduler,
    #       torch_dtype=torch.float16
    #   )
    #   # change scheduler
    #   pipe = scheduler_setup(pipe, 'DPMSolverSingleStep')
    #   pipe = pipe.to("cuda")