File size: 4,309 Bytes
19c4ddf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

import numpy as np
import torch.nn as nn
from torch import torch

from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.util.collections import AttrDict


class LatentBottleneck(nn.Module, ABC):
    """
    Abstract nn.Module base for latent bottlenecks.

    The constructor records the given device and d_latent on the instance;
    subclasses implement forward() to process a latent tensor.
    """

    def __init__(self, *, device: torch.device, d_latent: int):
        """
        :param device: stored as self.device.
        :param d_latent: stored as self.d_latent (presumably the latent
                         dimensionality -- confirm against callers).
        """
        super().__init__()
        self.d_latent = d_latent
        self.device = device

    @abstractmethod
    def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        """
        Process the latent x; must be overridden by subclasses.

        :param x: the input tensor.
        :param options: optional extra settings, may be None.
        :return: defined by the concrete subclass.
        """


class LatentWarp(nn.Module, ABC):
    """
    Abstract nn.Module base exposing a warp() / unwarp() pair of methods.

    The constructor only records the given device on the instance.
    """

    def __init__(self, *, device: torch.device):
        """
        :param device: stored as self.device.
        """
        super().__init__()
        self.device = device

    @abstractmethod
    def warp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        """
        Apply the warp to x; must be overridden by subclasses.

        :param x: the input tensor.
        :param options: optional extra settings, may be None.
        :return: defined by the concrete subclass.
        """

    @abstractmethod
    def unwarp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        """
        Apply the counterpart of warp() to x; must be overridden by subclasses.

        :param x: the input tensor.
        :param options: optional extra settings, may be None.
        :return: defined by the concrete subclass.
        """


class IdentityLatentWarp(LatentWarp):
    """
    A warp that leaves its input untouched in both directions.
    """

    # NOTE(review): both methods are annotated as returning AttrDict but
    # actually hand back the input tensor as-is -- confirm the intended type.

    def warp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        """
        Return x unchanged; options is ignored.
        """
        del options  # intentionally unused
        return x

    def unwarp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        """
        Return x unchanged; options is ignored.
        """
        del options  # intentionally unused
        return x


class Tan2LatentWarp(LatentWarp):
    """
    Warp built from two nested tangents: warp(x) = tan(coeff1 * tan(x)) / scale.

    unwarp() applies the matching nested arctangents in reverse order.
    """

    def __init__(self, *, coeff1: float = 1.0, device: torch.device):
        """
        :param coeff1: multiplier applied between the inner and outer tangent.
        :param device: forwarded to LatentWarp.
        """
        super().__init__(device=device)
        self.coeff1 = coeff1
        # Normalizer equal to the un-normalized warp evaluated at 1.0, so that
        # warp(1.0) comes out as 1.0.
        inner_at_one = np.tan(1.0)
        self.scale = np.tan(inner_at_one * coeff1)

    def warp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        """
        Compute tan(coeff1 * tan(x)) / scale in float32, then cast the result
        back to the dtype of x. options is ignored.
        """
        del options  # intentionally unused
        inner = x.float().tan() * self.coeff1
        outer = inner.tan() / self.scale
        return outer.to(x.dtype)

    def unwarp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        """
        Compute arctan(arctan(x * scale) / coeff1) in float32, then cast the
        result back to the dtype of x. options is ignored.
        """
        del options  # intentionally unused
        rescaled = x.float() * self.scale
        inner = rescaled.arctan() / self.coeff1
        return inner.arctan().to(x.dtype)


class IdentityLatentBottleneck(LatentBottleneck):
    """
    A bottleneck that passes the latent through untouched.
    """

    # NOTE(review): annotated as returning AttrDict but the input tensor is
    # returned as-is -- confirm the intended type.
    def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        """
        Return x unchanged; options is ignored.
        """
        del options  # intentionally unused
        return x


class ClampNoiseBottleneck(LatentBottleneck):
    """
    Bottleneck that squashes the latent with tanh and, while the module is in
    training mode, adds scaled Gaussian noise on top.
    """

    def __init__(self, *, device: torch.device, d_latent: int, noise_scale: float):
        """
        :param device: forwarded to LatentBottleneck.
        :param d_latent: forwarded to LatentBottleneck.
        :param noise_scale: multiplier applied to the standard-normal noise
                            added during training.
        """
        super().__init__(device=device, d_latent=d_latent)
        self.noise_scale = noise_scale

    def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        """
        Apply tanh to x; in training mode also add randn noise * noise_scale.

        :param x: the input tensor.
        :param options: ignored.
        :return: tanh(x) in eval mode, tanh(x) plus noise in training mode.
        """
        del options  # intentionally unused
        squashed = x.tanh()
        if self.training:
            noise = torch.randn_like(squashed) * self.noise_scale
            return squashed + noise
        # Evaluation mode is deterministic: no noise is drawn at all.
        return squashed


class ClampDiffusionNoiseBottleneck(LatentBottleneck):
    """
    Bottleneck that squashes the latent with tanh and, while the module is in
    training mode, passes it through the diffusion object's q_sample() at a
    randomly drawn timestep per batch element.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        d_latent: int,
        diffusion: Dict[str, Any],
        diffusion_prob: float = 1.0,
    ):
        """
        :param device: forwarded to LatentBottleneck.
        :param d_latent: forwarded to LatentBottleneck.
        :param diffusion: config dict handed to diffusion_from_config().
        :param diffusion_prob: per-element probability of keeping the randomly
                               drawn timestep instead of falling back to 0.
        """
        super().__init__(device=device, d_latent=d_latent)
        self.diffusion_prob = diffusion_prob
        self.diffusion = diffusion_from_config(diffusion)

    def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict:
        """
        Apply tanh to x; in training mode also run q_sample() on the result.

        :param x: the input tensor; its length gives the number of timesteps
                  drawn.
        :param options: ignored.
        :return: tanh(x) in eval mode, otherwise the output of
                 self.diffusion.q_sample (presumably a noised version of
                 tanh(x) -- confirm against the diffusion implementation).
        """
        del options  # intentionally unused
        clamped = x.tanh()
        if not self.training:
            return clamped

        batch_size = len(clamped)
        dev = clamped.device
        # One integer timestep per element, drawn from [0, num_timesteps).
        timesteps = torch.randint(
            low=0, high=self.diffusion.num_timesteps, size=(batch_size,), device=dev
        )
        # Elements that lose the coin flip are assigned timestep 0 instead.
        keep_random_t = torch.rand(batch_size, device=dev) < self.diffusion_prob
        timesteps = torch.where(keep_random_t, timesteps, torch.zeros_like(timesteps))
        return self.diffusion.q_sample(clamped, timesteps)


def latent_bottleneck_from_config(config: Dict[str, Any], device: torch.device, d_latent: int):
    """
    Build a latent bottleneck from a config dict.

    The "name" entry selects the class ("clamp_noise", "identity" or
    "clamp_diffusion_noise"); every other entry is passed to the constructor
    as a keyword argument, together with device and d_latent.

    :param config: dict with a "name" key plus constructor keyword arguments.
                   It is not modified.
    :param device: forwarded to the bottleneck constructor.
    :param d_latent: forwarded to the bottleneck constructor.
    :return: the constructed bottleneck module.
    :raises KeyError: if config has no "name" entry.
    :raises ValueError: if the name is not one of the known bottlenecks.
    """
    # Pop from a shallow copy so the caller's dict keeps its "name" entry and
    # the same config can be reused to build another instance.
    kwargs = dict(config)
    name = kwargs.pop("name")
    if name == "clamp_noise":
        return ClampNoiseBottleneck(**kwargs, device=device, d_latent=d_latent)
    elif name == "identity":
        return IdentityLatentBottleneck(**kwargs, device=device, d_latent=d_latent)
    elif name == "clamp_diffusion_noise":
        return ClampDiffusionNoiseBottleneck(**kwargs, device=device, d_latent=d_latent)
    else:
        raise ValueError(f"unknown latent bottleneck: {name}")


def latent_warp_from_config(config: Dict[str, Any], device: torch.device):
    """
    Build a latent warp from a config dict.

    The "name" entry selects the class ("identity" or "tan2"); every other
    entry is passed to the constructor as a keyword argument, together with
    device. A short message naming the chosen warp is printed to stdout.

    :param config: dict with a "name" key plus constructor keyword arguments.
                   It is not modified.
    :param device: forwarded to the warp constructor.
    :return: the constructed warp module.
    :raises KeyError: if config has no "name" entry.
    :raises ValueError: if the name is not one of the known warps.
    """
    # Pop from a shallow copy so the caller's dict keeps its "name" entry and
    # the same config can be reused to build another instance.
    kwargs = dict(config)
    name = kwargs.pop("name")
    if name == "identity":
        print("identity warp")
        return IdentityLatentWarp(**kwargs, device=device)
    elif name == "tan2":
        print("tan2 warp")
        return Tan2LatentWarp(**kwargs, device=device)
    else:
        raise ValueError(f"unknown latent warping function: {name}")