Spaces:
Runtime error
Runtime error
from abc import ABC, abstractmethod | |
from typing import Any, Dict, Optional | |
import numpy as np | |
import torch.nn as nn | |
from torch import torch | |
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config | |
from shap_e.util.collections import AttrDict | |
class LatentBottleneck(nn.Module, ABC):
    """Abstract base for modules that transform a latent tensor.

    Subclasses must implement :meth:`forward`. The base class only records its
    keyword-only construction arguments.

    Args:
        device: stored as ``self.device``; not used by the base class itself.
        d_latent: stored as ``self.d_latent``; not used by the base class itself.
    """

    def __init__(self, *, device: torch.device, d_latent: int):
        super().__init__()
        self.device = device
        self.d_latent = d_latent

    # Without @abstractmethod the ABC base offers no protection: the base class
    # could be instantiated and forward() would silently return None.
    @abstractmethod
    def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        """Apply the bottleneck to ``x``.

        Args:
            x: the latent tensor to transform.
            options: optional extra settings; the subclasses in this module ignore it.

        NOTE(review): annotated as returning ``AttrDict``, but the identity and
        clamp-noise subclasses in this module return a tensor — confirm the
        intended return type.
        """
        raise NotImplementedError
class LatentWarp(nn.Module, ABC):
    """Abstract base for an invertible elementwise transform of latents.

    Subclasses must implement both :meth:`warp` and its inverse :meth:`unwarp`.

    Args:
        device: stored as ``self.device``; not used by the base class itself.
    """

    def __init__(self, *, device: torch.device):
        super().__init__()
        self.device = device

    # Both directions are abstract so that a subclass forgetting either one
    # fails at construction time instead of silently returning None.
    @abstractmethod
    def warp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        """Map ``x`` into warped space.

        NOTE(review): annotated as returning ``AttrDict``, but the subclasses in
        this module return a tensor — confirm the intended return type.
        """
        raise NotImplementedError

    @abstractmethod
    def unwarp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        """Invert :meth:`warp`, mapping warped values back to the original space."""
        raise NotImplementedError
class IdentityLatentWarp(LatentWarp):
    """No-op warp: ``warp`` and ``unwarp`` both hand back the input object itself."""

    def warp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        del options  # accepted for interface compatibility only
        return x

    def unwarp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        del options  # accepted for interface compatibility only
        return x
class Tan2LatentWarp(LatentWarp):
    """Invertible elementwise warp built from two nested tangents.

    ``warp(x) = tan(coeff1 * tan(x)) / scale`` with
    ``scale = tan(coeff1 * tan(1))``. This maps 0 to 0 and 1 to (up to float
    rounding) 1. ``unwarp(y) = arctan(arctan(scale * y) / coeff1)`` is its
    inverse on the principal branch, i.e. while ``|coeff1 * tan(x)| < pi/2``.
    With the default ``coeff1=1.0`` that covers ``x`` in ``[-1, 1]``.

    Both directions compute in float32 (via ``.float()``) and cast the result
    back to the input's dtype.

    Args:
        coeff1: multiplier applied between the two tangents.
        device: forwarded to :class:`LatentWarp`.
    """

    def __init__(self, *, coeff1: float = 1.0, device: torch.device):
        super().__init__(device=device)
        self.coeff1 = coeff1
        # Normaliser chosen so that an input of 1 warps to ~1.
        self.scale = np.tan(np.tan(1.0) * coeff1)

    def warp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        del options  # unused
        inner = torch.tan(x.float()) * self.coeff1
        warped = torch.tan(inner) / self.scale
        return warped.to(x.dtype)

    def unwarp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        del options  # unused
        outer = torch.atan(x.float() * self.scale)
        restored = torch.atan(outer / self.coeff1)
        return restored.to(x.dtype)
class IdentityLatentBottleneck(LatentBottleneck):
    """Pass-through bottleneck: ``forward`` returns its input object unchanged."""

    def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        del options  # accepted for interface compatibility only
        return x
class ClampNoiseBottleneck(LatentBottleneck):
    """Squash latents with ``tanh``; in training mode also add Gaussian noise.

    Args:
        device: forwarded to :class:`LatentBottleneck`.
        d_latent: forwarded to :class:`LatentBottleneck`.
        noise_scale: standard deviation multiplier for the additive
            ``randn_like`` noise applied while ``self.training`` is True.
    """

    def __init__(self, *, device: torch.device, d_latent: int, noise_scale: float):
        super().__init__(device=device, d_latent=d_latent)
        self.noise_scale = noise_scale

    def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        del options  # unused
        squashed = x.tanh()
        if self.training:
            # Noise is added after the tanh, so training outputs can leave (-1, 1).
            squashed = squashed + torch.randn_like(squashed) * self.noise_scale
        return squashed
class ClampDiffusionNoiseBottleneck(LatentBottleneck):
    """Squash latents with ``tanh``; in training mode noise them via a diffusion.

    Args:
        device: forwarded to :class:`LatentBottleneck`.
        d_latent: forwarded to :class:`LatentBottleneck`.
        diffusion: config passed to ``diffusion_from_config``; the result must
            expose ``num_timesteps`` and ``q_sample(x, t)``.
        diffusion_prob: per-sample probability of using a random timestep.
            Samples that miss use timestep 0.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        d_latent: int,
        diffusion: Dict[str, Any],
        diffusion_prob: float = 1.0,
    ):
        super().__init__(device=device, d_latent=d_latent)
        self.diffusion = diffusion_from_config(diffusion)
        self.diffusion_prob = diffusion_prob

    def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        del options  # unused
        clamped = x.tanh()
        if not self.training:
            return clamped

        batch = len(clamped)
        dev = clamped.device
        # One timestep per sample, drawn uniformly from [0, num_timesteps).
        steps = torch.randint(
            low=0, high=self.diffusion.num_timesteps, size=(batch,), device=dev
        )
        # Keep the random timestep with probability diffusion_prob, else use 0.
        # NOTE(review): timestep-0 samples still go through q_sample; whether that
        # is noise-free depends on the diffusion schedule — confirm.
        use_random = torch.rand(batch, device=dev) < self.diffusion_prob
        steps = torch.where(use_random, steps, torch.zeros_like(steps))
        return self.diffusion.q_sample(clamped, steps)
def latent_bottleneck_from_config(config: Dict[str, Any], device: torch.device, d_latent: int):
    """Build a latent bottleneck module from a config dict.

    Args:
        config: must contain ``"name"``, which is one of ``"clamp_noise"``,
            ``"identity"`` or ``"clamp_diffusion_noise"``. All other keys are
            passed as keyword arguments to the selected class. The caller's dict
            is left unmodified.
        device: forwarded to the bottleneck's constructor.
        d_latent: forwarded to the bottleneck's constructor.

    Returns:
        The constructed bottleneck module.

    Raises:
        KeyError: if ``config`` has no ``"name"`` key.
        ValueError: if ``"name"`` is not a known bottleneck.
    """
    # Pop from a shallow copy: mutating the caller's dict would make a second
    # call with the same config fail with KeyError.
    kwargs = dict(config)
    name = kwargs.pop("name")
    if name == "clamp_noise":
        return ClampNoiseBottleneck(**kwargs, device=device, d_latent=d_latent)
    if name == "identity":
        return IdentityLatentBottleneck(**kwargs, device=device, d_latent=d_latent)
    if name == "clamp_diffusion_noise":
        return ClampDiffusionNoiseBottleneck(**kwargs, device=device, d_latent=d_latent)
    raise ValueError(f"unknown latent bottleneck: {name}")
def latent_warp_from_config(config: Dict[str, Any], device: torch.device):
    """Build a latent warp module from a config dict.

    Args:
        config: must contain ``"name"``, which is ``"identity"`` or ``"tan2"``.
            All other keys are passed as keyword arguments to the selected
            class (e.g. ``coeff1`` for ``"tan2"``). The caller's dict is left
            unmodified.
        device: forwarded to the warp's constructor.

    Returns:
        The constructed warp module.

    Raises:
        KeyError: if ``config`` has no ``"name"`` key.
        ValueError: if ``"name"`` is not a known warp.
    """
    # Pop from a shallow copy: mutating the caller's dict would make a second
    # call with the same config fail with KeyError.
    kwargs = dict(config)
    name = kwargs.pop("name")
    if name == "identity":
        print("identity warp")
        return IdentityLatentWarp(**kwargs, device=device)
    if name == "tan2":
        print("tan2 warp")
        return Tan2LatentWarp(**kwargs, device=device)
    raise ValueError(f"unknown latent warping function: {name}")