import enum
import math
from typing import Callable

import numpy as np
import torch as th

from . import path
from .integrators import ode, sde
from .utils import mean_flat, time_shift, get_lin_function


class ModelType(enum.Enum):
    """
    Which type of output the model predicts.
    """

    NOISE = enum.auto()  # the model predicts epsilon
    SCORE = enum.auto()  # the model predicts \nabla \log p(x)
    VELOCITY = enum.auto()  # the model predicts v(x)


class PathType(enum.Enum):
    """
    Which type of path to use.
    """

    LINEAR = enum.auto()
    GVP = enum.auto()
    VP = enum.auto()


class WeightType(enum.Enum):
    """
    Which type of weighting to use.
    """

    NONE = enum.auto()
    VELOCITY = enum.auto()
    LIKELIHOOD = enum.auto()

class Transport:
    def __init__(self, *, model_type, path_type, loss_type, train_eps, sample_eps, snr_type, do_shift):
        path_options = {
            PathType.LINEAR: path.ICPlan,
            PathType.GVP: path.GVPCPlan,
            PathType.VP: path.VPCPlan,
        }

        self.loss_type = loss_type
        self.model_type = model_type
        self.path_sampler = path_options[path_type]()
        self.train_eps = train_eps
        self.sample_eps = sample_eps

        self.snr_type = snr_type
        self.do_shift = do_shift

    def prior_logp(self, z):
        """
        Standard multivariate normal prior
        Assume z is batched
        """
        shape = th.tensor(z.size())
        N = th.prod(shape[1:])
        _fn = lambda x: -N / 2.0 * np.log(2 * np.pi) - th.sum(x**2) / 2.0
        return th.vmap(_fn)(z)
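
    # Illustrative sketch (not part of the original file): prior_logp evaluates the
    # standard normal log-density per batch element,
    #   log N(z; 0, I) = -(N / 2) * log(2 * pi) - ||z||^2 / 2,  with N = prod(z.shape[1:]).
    # Assuming `transport` is a configured Transport instance:
    #   z = th.randn(4, 3, 8, 8)                  # hypothetical batch of 4 latents
    #   logp = transport.prior_logp(z)            # tensor of shape [4]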

    def check_interval(
        self,
        train_eps,
        sample_eps,
        *,
        diffusion_form="SBDM",
        sde=False,
        reverse=False,
        eval=False,
        last_step_size=0.0,
    ):
        t0 = 0
        t1 = 1
        eps = train_eps if not eval else sample_eps
        if type(self.path_sampler) in [path.VPCPlan]:
            t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size
        elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) and (
            self.model_type != ModelType.VELOCITY or sde
        ):  # avoid numerical issue by taking a first semi-implicit step
            t0 = eps if (diffusion_form == "SBDM" and sde) or self.model_type != ModelType.VELOCITY else 0
            t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size

        if reverse:
            t0, t1 = 1 - t0, 1 - t1

        return t0, t1
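
    # Note (illustrative, not from the original file): check_interval clips the
    # integration endpoints away from 0 and 1 by `eps` for path/model combinations
    # that are numerically delicate at the boundary (e.g. VP paths near t = 1, or
    # score/noise predictions near t = 0), and optionally reserves room for a final
    # step of size `last_step_size` when sampling with an SDE.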

    def sample(self, x1, snr_type=None):
        """Sampling x0 & t based on shape of x1 (if needed)
        Args:
          x1 - data point; [batch, *dim]
        """
        if isinstance(x1, (list, tuple)):
            x0 = [th.randn_like(img_start) for img_start in x1]
        else:
            x0 = th.randn_like(x1)

        t0, t1 = self.check_interval(self.train_eps, self.sample_eps)

        if snr_type is None:
            snr_type = self.snr_type
        if snr_type.startswith("uniform"):
            if "_" in snr_type:
                _, t0, t1 = snr_type.split("_")
                t0, t1 = float(t0), float(t1)
            t = th.rand((len(x1),)) * (t1 - t0) + t0
        elif snr_type == "lognorm":
            u = th.normal(mean=0.0, std=1.0, size=(len(x1),))
            t = 1 / (1 + th.exp(-u)) * (t1 - t0) + t0
        else:
            raise NotImplementedError("Not implemented snr_type %s" % snr_type)

        if self.do_shift:
            base_shift: float = 0.5
            max_shift: float = 1.15
            mu = get_lin_function(y1=base_shift, y2=max_shift)(x1.shape[1])
            t = time_shift(mu, 1.0, t)

        t = t.to(x1[0])
        return t, x0, x1
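
    # Illustrative sketch (not part of the original file): `sample` pairs each data
    # point with Gaussian noise and a timestep drawn according to `snr_type`, e.g.
    #   t, x0, x1 = transport.sample(x1)                               # uses self.snr_type
    #   t, x0, x1 = transport.sample(x1, snr_type="uniform_0.2_0.8")   # t ~ U(0.2, 0.8)
    #   t, x0, x1 = transport.sample(x1, snr_type="lognorm")           # t = sigmoid(u), u ~ N(0, 1)
    # Note: the do_shift branch reads x1.shape[1], so it assumes x1 is a single tensor
    # rather than a list of tensors.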

    def training_losses(self, model, x1, model_kwargs=None, extra_kwargs=None):
        """Loss for training the score model
        Args:
        - model: backbone model; could be score, noise, or velocity
        - x1: datapoint
        - model_kwargs: additional arguments for the model
        """
        if model_kwargs is None:
            model_kwargs = {}
        if extra_kwargs is None:
            extra_kwargs = {}

        t, x0, x1 = self.sample(x1)
        t, xt, ut = self.path_sampler.plan(t, x0, x1)
        B = len(x0)

        if "cond" in extra_kwargs and extra_kwargs["cond"] is not None:
            out = model(th.cat((xt, extra_kwargs["cond"]), dim=-1), timesteps=1 - t, **model_kwargs)
        else:
            out = model(xt, timesteps=1 - t, **model_kwargs)
        model_output = -out
        terms = {}

        if self.model_type == ModelType.VELOCITY:
            if isinstance(x1, (list, tuple)):
                assert len(model_output) == len(ut) == len(x1)
                for i in range(B):
                    assert (
                        model_output[i].shape == ut[i].shape == x1[i].shape
                    ), f"{model_output[i].shape} {ut[i].shape} {x1[i].shape}"
                terms["task_loss"] = th.stack(
                    [((ut[i] - model_output[i]) ** 2).mean() for i in range(B)],
                    dim=0,
                )
            else:
                if "img_mask" in model_kwargs:
                    # mask out padded image tokens so they do not contribute to the loss
                    B, L, D = model_output.shape
                    img_mask = model_kwargs["img_mask"]
                    mask_loss = (model_output - ut) * img_mask.unsqueeze(-1)  # [B, L, D]
                    terms["task_loss"] = (mask_loss**2).sum(dim=list(range(1, ut.dim()))) / (img_mask.sum(dim=1) * D)
                else:
                    terms["task_loss"] = mean_flat((model_output - ut) ** 2)

        terms["loss"] = terms["task_loss"]
        terms["task_loss"] = terms["task_loss"].clone().detach()
        terms["t"] = t
        return terms
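
    # Illustrative training-step sketch (assumptions: a velocity-prediction backbone
    # `net` taking (x, timesteps=...) and an optimizer `opt`; these names are hypothetical):
    #   transport = Transport(model_type=ModelType.VELOCITY, path_type=PathType.LINEAR,
    #                         loss_type=None, train_eps=0.0, sample_eps=0.0,
    #                         snr_type="uniform", do_shift=False)
    #   terms = transport.training_losses(net, x1, model_kwargs={})
    #   terms["loss"].mean().backward()
    #   opt.step()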

    def get_drift(self):
        """member function for obtaining the drift of the probability flow ODE"""

        def score_ode(x, t, model, **model_kwargs):
            drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
            model_output = model(x, t, **model_kwargs)
            return -drift_mean + drift_var * model_output  # by change of variable

        def noise_ode(x, t, model, **model_kwargs):
            drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
            sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))
            model_output = model(x, t, **model_kwargs)
            score = model_output / -sigma_t
            return -drift_mean + drift_var * score

        def velocity_ode(x, t, model, **model_kwargs):
            if "cond" in model_kwargs and model_kwargs["cond"] is not None:
                x = th.cat((x, model_kwargs["cond"]), dim=-1)
            model_kwargs.pop("cond", None)  # drop "cond" (if present) before calling the model
            model_output = model(x, timesteps=t, **model_kwargs)
            return model_output

        if self.model_type == ModelType.NOISE:
            drift_fn = noise_ode
        elif self.model_type == ModelType.SCORE:
            drift_fn = score_ode
        else:
            drift_fn = velocity_ode

        def body_fn(x, t, model, **model_kwargs):
            model_output = drift_fn(x, t, model, **model_kwargs)
            assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape"
            return model_output

        return body_fn
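
    # Note (illustrative, not from the original file): get_drift returns the drift of
    # the probability flow ODE in a prediction-agnostic form. For a velocity model the
    # drift is the network output itself; for score/noise models it is recovered via
    #   dx/dt = -drift_mean(x, t) + drift_var(x, t) * score(x, t),
    # where a noise prediction is converted with score = -eps_pred / sigma_t.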

    def get_score(
        self,
    ):
        """member function for obtaining score of
        x_t = alpha_t * x + sigma_t * eps"""
        if self.model_type == ModelType.NOISE:
            score_fn = (
                lambda x, t, model, **kwargs: model(x, t, **kwargs)
                / -self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0]
            )
        elif self.model_type == ModelType.SCORE:
            score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs)
        elif self.model_type == ModelType.VELOCITY:
            score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(
                model(x, t, **kwargs), x, t
            )
        else:
            raise NotImplementedError()

        return score_fn


class Sampler:
    """Sampler class for the transport model"""

    def __init__(
        self,
        transport,
    ):
        """Constructor for a general sampler; supporting different sampling methods
        Args:
        - transport: a transport object specifying model prediction & interpolant type
        """

        self.transport = transport
        self.drift = self.transport.get_drift()
        self.score = self.transport.get_score()

    def __get_sde_diffusion_and_drift(
        self,
        *,
        diffusion_form="SBDM",
        diffusion_norm=1.0,
    ):
        def diffusion_fn(x, t):
            diffusion = self.transport.path_sampler.compute_diffusion(x, t, form=diffusion_form, norm=diffusion_norm)
            return diffusion

        sde_drift = lambda x, t, model, **kwargs: self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(
            x, t, model, **kwargs
        )

        sde_diffusion = diffusion_fn

        return sde_drift, sde_diffusion
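
    # Note (illustrative, not from the original file): the sampling SDE drift is
    # assembled as
    #   drift_sde(x, t) = drift_pf(x, t) + g(x, t) * score(x, t),
    # where drift_pf is the probability-flow ODE drift from get_drift and g is the
    # diffusion coefficient returned by compute_diffusion.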

    def __get_last_step(
        self,
        sde_drift,
        *,
        last_step,
        last_step_size,
    ):
        """Get the last step function of the SDE solver"""

        if last_step is None:
            last_step_fn = lambda x, t, model, **model_kwargs: x
        elif last_step == "Mean":
            last_step_fn = (
                lambda x, t, model, **model_kwargs: x + sde_drift(x, t, model, **model_kwargs) * last_step_size
            )
        elif last_step == "Tweedie":
            alpha = self.transport.path_sampler.compute_alpha_t  # simple aliasing; the original name was too long
            sigma = self.transport.path_sampler.compute_sigma_t
            last_step_fn = lambda x, t, model, **model_kwargs: x / alpha(t)[0][0] + (sigma(t)[0][0] ** 2) / alpha(t)[0][
                0
            ] * self.score(x, t, model, **model_kwargs)
        elif last_step == "Euler":
            last_step_fn = (
                lambda x, t, model, **model_kwargs: x + self.drift(x, t, model, **model_kwargs) * last_step_size
            )
        else:
            raise NotImplementedError()

        return last_step_fn
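
    # Note (illustrative, not from the original file): the "Tweedie" branch applies a
    # Tweedie-style denoising step,
    #   x_0 ≈ x_t / alpha_t + (sigma_t**2 / alpha_t) * score(x_t, t),
    # while "Mean" and "Euler" take one explicit Euler step of size last_step_size
    # along the SDE drift and the probability-flow drift, respectively.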

    def sample_sde(
        self,
        *,
        sampling_method="Euler",
        diffusion_form="SBDM",
        diffusion_norm=1.0,
        last_step="Mean",
        last_step_size=0.04,
        num_steps=250,
    ):
        """returns a sampling function with given SDE settings
        Args:
        - sampling_method: type of sampler used in solving the SDE; defaults to Euler-Maruyama
        - diffusion_form: functional form of the diffusion coefficient; defaults to matching SBDM
        - diffusion_norm: magnitude of the diffusion coefficient; defaults to 1
        - last_step: type of the last step; defaults to "Mean" (None applies no extra step)
        - last_step_size: size of the last step; defaults to 0.04
        - num_steps: total number of integration steps of the SDE
        """

        if last_step is None:
            last_step_size = 0.0

        sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift(
            diffusion_form=diffusion_form,
            diffusion_norm=diffusion_norm,
        )

        t0, t1 = self.transport.check_interval(
            self.transport.train_eps,
            self.transport.sample_eps,
            diffusion_form=diffusion_form,
            sde=True,
            eval=True,
            reverse=False,
            last_step_size=last_step_size,
        )

        _sde = sde(
            sde_drift,
            sde_diffusion,
            t0=t0,
            t1=t1,
            num_steps=num_steps,
            sampler_type=sampling_method,
        )

        last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size)

        def _sample(init, model, **model_kwargs):
            xs = _sde.sample(init, model, **model_kwargs)
            ts = th.ones(init.size(0), device=init.device) * t1
            x = last_step_fn(xs[-1], ts, model, **model_kwargs)
            xs.append(x)

            assert len(xs) == num_steps, "Number of samples does not match the number of steps"

            return xs

        return _sample
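
    # Illustrative usage sketch (names are hypothetical, not from the original file):
    #   sampler = Sampler(transport)
    #   sample_fn = sampler.sample_sde(sampling_method="Euler", num_steps=250)
    #   z = th.randn(4, 16, 64)                   # prior noise with the model's shape
    #   xs = sample_fn(z, net)                    # list of intermediate states
    #   x_final = xs[-1]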

    def sample_ode(
        self,
        *,
        sampling_method="dopri5",
        num_steps=50,
        atol=1e-6,
        rtol=1e-3,
        reverse=False,
        do_shift=True,
        time_shifting_factor=None,
        strength=None,
    ):
        """returns a sampling function with given ODE settings
        Args:
        - sampling_method: type of sampler used in solving the ODE; defaults to dopri5
        - num_steps:
            - fixed solver (Euler, Heun): the actual number of integration steps performed
            - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
        - atol: absolute error tolerance for the solver
        - rtol: relative error tolerance for the solver
        - reverse: whether to reverse the integration interval
        - do_shift / time_shifting_factor: timestep shifting options forwarded to the ODE integrator
        - strength: if given, moves the start of integration to (t1 - t0) * strength + t0
        """
        # for flux
        drift = lambda x, t, model, **kwargs: -self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs)

        t0, t1 = self.transport.check_interval(
            self.transport.train_eps,
            self.transport.sample_eps,
            sde=False,
            eval=True,
            reverse=reverse,
            last_step_size=0.0,
        )
        if strength is not None:
            t0 = (t1 - t0) * strength + t0
        _ode = ode(
            drift=drift,
            t0=t0,
            t1=t1,
            sampler_type=sampling_method,
            num_steps=num_steps,
            atol=atol,
            rtol=rtol,
            do_shift=do_shift,
            time_shifting_factor=time_shifting_factor,
        )

        return _ode.sample
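
    # Illustrative usage sketch (names are hypothetical, not from the original file):
    #   sample_fn = sampler.sample_ode(num_steps=30, do_shift=True)
    #   xs = sample_fn(z, net)
    #   x_final = xs[-1]
    # Passing strength in (0, 1) starts the integration part-way through [t0, t1],
    # since t0 is moved to (t1 - t0) * strength + t0 before the solver is built.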

    def sample_ode_likelihood(
        self,
        *,
        sampling_method="dopri5",
        num_steps=50,
        atol=1e-6,
        rtol=1e-3,
    ):
        """returns a sampling function for calculating likelihood with given ODE settings
        Args:
        - sampling_method: type of sampler used in solving the ODE; defaults to dopri5
        - num_steps:
            - fixed solver (Euler, Heun): the actual number of integration steps performed
            - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
        - atol: absolute error tolerance for the solver
        - rtol: relative error tolerance for the solver
        """

        def _likelihood_drift(x, t, model, **model_kwargs):
            x, _ = x
            # Rademacher probe for the Skilling-Hutchinson divergence estimator
            eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1
            t = th.ones_like(t) * (1 - t)
            with th.enable_grad():
                x.requires_grad = True
                grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * eps), x)[0]
                logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size()))))
                drift = self.drift(x, t, model, **model_kwargs)
            return (-drift, logp_grad)

        t0, t1 = self.transport.check_interval(
            self.transport.train_eps,
            self.transport.sample_eps,
            sde=False,
            eval=True,
            reverse=False,
            last_step_size=0.0,
        )

        _ode = ode(
            drift=_likelihood_drift,
            t0=t0,
            t1=t1,
            sampler_type=sampling_method,
            num_steps=num_steps,
            atol=atol,
            rtol=rtol,
        )

        def _sample_fn(x, model, **model_kwargs):
            init_logp = th.zeros(x.size(0)).to(x)
            input = (x, init_logp)
            drift, delta_logp = _ode.sample(input, model, **model_kwargs)
            drift, delta_logp = drift[-1], delta_logp[-1]
            prior_logp = self.transport.prior_logp(drift)
            logp = prior_logp - delta_logp
            return logp, drift

        return _sample_fn
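
    # Illustrative usage sketch (names are hypothetical, not from the original file):
    #   likelihood_fn = sampler.sample_ode_likelihood(num_steps=50)
    #   logp, z_terminal = likelihood_fn(x1, net)
    # logp combines the prior log-density of the terminal state with the accumulated
    # divergence term (estimated with Rademacher probes), following the standard
    # continuous normalizing flow change-of-variables computation.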