Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,250 Bytes
1da48bb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
from diffusion.respace import SpacedDiffusion
from .gaussian_diffusion import _extract_into_tensor
import torch as th
class InpaintingGaussianDiffusion(SpacedDiffusion):
    """SpacedDiffusion variant that never adds noise to the leading features.

    Both the forward noising step (``q_sample``) and the reverse sampling
    step (``p_sample``) zero the Gaussian noise on the first
    ``inpainted_feature_count`` entries of dim 1, so those entries are only
    ever scaled / set to the predicted mean. The original author's note
    describes these entries as the "root trajectory".
    """

    # Number of leading entries along dim 1 that are kept noise-free.
    # Subclasses may override this to inpaint a different feature range.
    inpainted_feature_count = 10

    def _zero_inpainted_noise(self, noise):
        """Zero ``noise`` IN PLACE on the inpainted features and return it.

        :param noise: tensor indexed as [batch, feature, ...].
        :return: the same tensor object, with ``noise[:, :k]`` multiplied by
                 zero where ``k`` is ``inpainted_feature_count``.
        """
        # Multiplicative mask (rather than direct assignment) keeps the exact
        # arithmetic of the original `noise *= 1. - inpainting_mask`.
        keep_mask = th.ones_like(noise)  # same dtype/device as `noise`
        keep_mask[:, : self.inpainted_feature_count] = 0
        noise *= keep_mask
        return noise

    def q_sample(self, x_start, t, noise=None, model_kwargs=None):
        """
        overrides q_sample to use the inpainting mask
        same usage as in GaussianDiffusion

        :param x_start: the un-noised input tensor.
        :param t: timestep indices passed to ``_extract_into_tensor``.
        :param noise: optional noise with the same shape as ``x_start``;
                      drawn from a standard normal when omitted.
                      NOTE: a caller-supplied tensor is masked in place, so
                      the caller observes the zeroed inpainted features too.
        :param model_kwargs: unused here; accepted for signature
                             compatibility with callers.
        :return: ``sqrt_alphas_cumprod[t] * x_start
                 + sqrt_one_minus_alphas_cumprod[t] * masked_noise``
                 (coefficients looked up via ``_extract_into_tensor``,
                 presumably broadcast to ``x_start.shape``).
        """
        if noise is None:
            noise = th.randn_like(x_start)
        assert noise.shape == x_start.shape

        noise = self._zero_inpainted_noise(noise)

        signal_coef = _extract_into_tensor(
            self.sqrt_alphas_cumprod, t, x_start.shape
        )
        noise_coef = _extract_into_tensor(
            self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
        )
        return signal_coef * x_start + noise_coef * noise

    def p_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        const_noise=False,
    ):
        """
        overrides p_sample to use the inpainting mask
        same usage as in GaussianDiffusion

        :param model: the model forwarded to ``p_mean_variance``.
        :param x: the current tensor at timestep ``t``.
        :param t: timestep indices, one per batch element.
        :param clip_denoised: forwarded to ``p_mean_variance``.
        :param denoised_fn: forwarded to ``p_mean_variance``.
        :param cond_fn: if not None, the mean is replaced by
                        ``condition_mean(cond_fn, ...)``.
        :param model_kwargs: forwarded to ``p_mean_variance`` and
                             ``condition_mean``.
        :param const_noise: if True, the first batch element's noise is
                            repeated for the whole batch (requires 4-D ``x``).
        :return: a dict containing the following keys:
                 - 'sample': mean plus masked noise scaled by the std
                   (no noise is added where ``t == 0``).
                 - 'pred_xstart': taken from ``p_mean_variance``.
        """
        # Run the model before drawing noise so the RNG call order matches
        # the original implementation.
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )

        noise = th.randn_like(x)
        if const_noise:
            noise = noise[[0]].repeat(x.shape[0], 1, 1, 1)
        noise = self._zero_inpainted_noise(noise)

        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0

        if cond_fn is not None:
            out["mean"] = self.condition_mean(
                cond_fn, out, x, t, model_kwargs=model_kwargs
            )

        sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}