from diffusion.respace import SpacedDiffusion
from .gaussian_diffusion import _extract_into_tensor
import torch as th


class InpaintingGaussianDiffusion(SpacedDiffusion):
    """SpacedDiffusion variant that keeps a fixed span of leading features noise-free.

    The first ``inpainting_feats`` channels along dim 1 are excluded from the
    noise used by :meth:`q_sample` (forward process) and :meth:`p_sample`
    (reverse step). The original comments describe this span as the root
    trajectory. Everything else is inherited from the parent class.
    """

    # Number of leading feature channels (dim 1) that receive no noise.
    # Defaults to the previously hard-coded value 10; override on a subclass
    # or instance to inpaint a different span.
    inpainting_feats = 10

    def _zero_inpainted(self, noise):
        """Zero channels ``[:, :inpainting_feats]`` of ``noise`` in place and return it."""
        # Equivalent to the former ``noise *= 1. - mask`` with mask[:, :k] = 1,
        # without allocating the mask tensors.
        noise[:, : self.inpainting_feats] = 0.0
        return noise

    def q_sample(self, x_start, t, noise=None, model_kwargs=None):
        """Diffuse ``x_start`` to step ``t``, leaving the inpainted channels noise-free.

        Overrides the parent's ``q_sample``. The only difference is that
        channels ``[:, :inpainting_feats]`` of the noise are zeroed, so those
        channels of the result are ``x_start`` scaled by the extracted
        ``sqrt_alphas_cumprod`` factor alone.

        :param x_start: clean data, indexed as ``[batch, feature, ...]``.
        :param t: diffusion step indices used to look up the schedule
            coefficients via ``_extract_into_tensor``.
        :param noise: optional noise with the same shape as ``x_start``; drawn
            from ``th.randn_like`` when omitted. NOTE: a caller-supplied tensor
            is modified IN PLACE (its inpainted channels are zeroed), exactly as
            before. NOTE(review): callers that reuse ``noise`` afterwards (e.g.
            as an epsilon-prediction target) see the masked version — confirm
            this is intended before changing it.
        :param model_kwargs: unused here; accepted for interface compatibility.
        :return: the noised sample, same shape as ``x_start``.
        """
        if noise is None:
            noise = th.randn_like(x_start)
        assert noise.shape == x_start.shape

        self._zero_inpainted(noise)

        return (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )

    def p_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        const_noise=False,
    ):
        """Take one reverse-diffusion step, keeping the inpainted channels noise-free.

        Overrides the parent's ``p_sample``. Freshly drawn noise has channels
        ``[:, :inpainting_feats]`` zeroed before being added. Those channels of
        the sample therefore equal ``out["mean"]`` (after the optional
        ``cond_fn`` adjustment).

        :param model: the denoising model, passed through to ``p_mean_variance``.
        :param x: current sample; indexed as ``[batch, feature, ...]``.
        :param t: per-batch step indices. No noise is added where ``t == 0``.
        :param clip_denoised: forwarded to ``p_mean_variance``.
        :param denoised_fn: forwarded to ``p_mean_variance``.
        :param cond_fn: if given, ``out["mean"]`` is replaced by
            ``self.condition_mean(cond_fn, out, x, t, model_kwargs=...)``.
        :param model_kwargs: forwarded to ``p_mean_variance`` and ``condition_mean``.
        :param const_noise: if True, the noise drawn for batch element 0 is
            reused for every batch element.
        :return: dict with ``"sample"`` (the new sample) and ``"pred_xstart"``
            (taken unchanged from ``p_mean_variance``'s output).
        """
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        # Draw noise after p_mean_variance, as before, to preserve RNG order.
        noise = th.randn_like(x)
        # Number of trailing singleton dims needed to broadcast per-batch values over x.
        trailing_ones = [1] * (x.dim() - 1)
        if const_noise:
            # Broadcast the first element's noise over the batch (any rank of x).
            noise = noise[[0]].repeat(x.shape[0], *trailing_ones)

        self._zero_inpainted(noise)

        # Per-batch 0/1 gate: no noise is added at the final step (t == 0).
        nonzero_mask = (t != 0).float().view(-1, *trailing_ones)
        if cond_fn is not None:
            out["mean"] = self.condition_mean(
                cond_fn, out, x, t, model_kwargs=model_kwargs
            )
        sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}