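"""
Encoders, decoders, and the Transmitter wrapper used to map data batches to
latent vectors or channels and to project those latents back to renderer
parameters.
"""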
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn

from shap_e.models.renderer import Renderer
from shap_e.util.collections import AttrDict

from .bottleneck import latent_bottleneck_from_config, latent_warp_from_config
from .params_proj import flatten_param_shapes, params_proj_from_config


class Encoder(nn.Module, ABC):
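    """
    Abstract base class for modules that encode a data batch into the latents
    consumed by a renderer.
    """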
def __init__(self, *, device: torch.device, param_shapes: Dict[str, Tuple[int]]):
"""
Instantiate the encoder with information about the renderer's input
parameters. This information can be used to create output layers to
generate the necessary latents.
"""
super().__init__()
self.param_shapes = param_shapes
self.device = device
@abstractmethod
def forward(self, batch: AttrDict, options: Optional[AttrDict] = None) -> AttrDict:
"""
Encode a batch of data into a batch of latent information.
"""
class VectorEncoder(Encoder):
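    """
    Encoder that maps a batch to a single latent vector, passes it through an
    optional bottleneck and warp, and projects the result to renderer
    parameters.
    """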
def __init__(
self,
*,
device: torch.device,
param_shapes: Dict[str, Tuple[int]],
params_proj: Dict[str, Any],
d_latent: int,
latent_bottleneck: Optional[Dict[str, Any]] = None,
latent_warp: Optional[Dict[str, Any]] = None,
):
super().__init__(device=device, param_shapes=param_shapes)
if latent_bottleneck is None:
latent_bottleneck = dict(name="identity")
if latent_warp is None:
latent_warp = dict(name="identity")
self.d_latent = d_latent
self.params_proj = params_proj_from_config(
params_proj, device=device, param_shapes=param_shapes, d_latent=d_latent
)
self.latent_bottleneck = latent_bottleneck_from_config(
latent_bottleneck, device=device, d_latent=d_latent
)
self.latent_warp = latent_warp_from_config(latent_warp, device=device)
def forward(self, batch: AttrDict, options: Optional[AttrDict] = None) -> AttrDict:
h = self.encode_to_bottleneck(batch, options=options)
return self.bottleneck_to_params(h, options=options)
def encode_to_bottleneck(
self, batch: AttrDict, options: Optional[AttrDict] = None
) -> torch.Tensor:
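        # Encode to a vector, squeeze it through the bottleneck, then warp the result.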
return self.latent_warp.warp(
self.latent_bottleneck(self.encode_to_vector(batch, options=options), options=options),
options=options,
)
@abstractmethod
def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor:
"""
Encode the batch into a single latent vector.
"""
def bottleneck_to_params(
self, vector: torch.Tensor, options: Optional[AttrDict] = None
) -> AttrDict:
_ = options
        return self.params_proj(self.latent_warp.unwarp(vector, options=options), options=options)


class ChannelsEncoder(VectorEncoder):
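    """
    Encoder that produces per-channel latents of shape
    [batch_size, latent_ctx, latent_width]. The channels are flattened into a
    single vector for the bottleneck and reshaped back before projection to
    parameters.
    """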
def __init__(
self,
*,
device: torch.device,
param_shapes: Dict[str, Tuple[int]],
params_proj: Dict[str, Any],
d_latent: int,
latent_bottleneck: Optional[Dict[str, Any]] = None,
latent_warp: Optional[Dict[str, Any]] = None,
):
super().__init__(
device=device,
param_shapes=param_shapes,
params_proj=params_proj,
d_latent=d_latent,
latent_bottleneck=latent_bottleneck,
latent_warp=latent_warp,
)
self.flat_shapes = flatten_param_shapes(param_shapes)
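        # One latent channel per row of each flattened parameter shape.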
self.latent_ctx = sum(flat[0] for flat in self.flat_shapes.values())
@abstractmethod
def encode_to_channels(
self, batch: AttrDict, options: Optional[AttrDict] = None
) -> torch.Tensor:
"""
Encode the batch into a per-data-point set of latents.
:return: [batch_size, latent_ctx, latent_width]
"""
def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor:
return self.encode_to_channels(batch, options=options).flatten(1)
def bottleneck_to_channels(
self, vector: torch.Tensor, options: Optional[AttrDict] = None
) -> torch.Tensor:
_ = options
return vector.view(vector.shape[0], self.latent_ctx, -1)
def bottleneck_to_params(
self, vector: torch.Tensor, options: Optional[AttrDict] = None
) -> AttrDict:
_ = options
# if vector.requires_grad:
# vector.register_hook(lambda grad: print("latent grad", grad.min(), grad.max()))
return self.params_proj(
self.bottleneck_to_channels(self.latent_warp.unwarp(vector)), options=options
        )


class Transmitter(nn.Module):
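    """
    Composition of an Encoder and a Renderer: the encoder turns a batch into
    renderer parameters, which are then used to render that batch.
    """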
def __init__(self, encoder: Encoder, renderer: Renderer):
super().__init__()
self.encoder = encoder
self.renderer = renderer
def forward(self, batch: AttrDict, options: Optional[AttrDict] = None) -> AttrDict:
"""
Transmit the batch through the encoder and then the renderer.
"""
params = self.encoder(batch, options=options)
        return self.renderer(batch, params=params, options=options)


class VectorDecoder(nn.Module):
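    """
    Decoder-side counterpart of VectorEncoder: carries the parameter
    projection, latent warp, and renderer needed to turn a latent vector into
    renderer parameters, without an encoder.
    """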
def __init__(
self,
*,
device: torch.device,
param_shapes: Dict[str, Tuple[int]],
params_proj: Dict[str, Any],
d_latent: int,
latent_warp: Optional[Dict[str, Any]] = None,
renderer: Renderer,
):
super().__init__()
self.device = device
self.param_shapes = param_shapes
if latent_warp is None:
latent_warp = dict(name="identity")
self.d_latent = d_latent
self.params_proj = params_proj_from_config(
params_proj, device=device, param_shapes=param_shapes, d_latent=d_latent
)
self.latent_warp = latent_warp_from_config(latent_warp, device=device)
self.renderer = renderer
def bottleneck_to_params(
self, vector: torch.Tensor, options: Optional[AttrDict] = None
) -> AttrDict:
_ = options
        return self.params_proj(self.latent_warp.unwarp(vector, options=options), options=options)


class ChannelsDecoder(VectorDecoder):
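    """
    Decoder-side counterpart of ChannelsEncoder: reshapes the latent vector
    into latent_ctx channels before projecting to renderer parameters.
    """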
def __init__(
self,
*,
latent_ctx: int,
**kwargs,
):
super().__init__(**kwargs)
self.latent_ctx = latent_ctx
def bottleneck_to_channels(
self, vector: torch.Tensor, options: Optional[AttrDict] = None
) -> torch.Tensor:
_ = options
return vector.view(vector.shape[0], self.latent_ctx, -1)
def bottleneck_to_params(
self, vector: torch.Tensor, options: Optional[AttrDict] = None
) -> AttrDict:
_ = options
return self.params_proj(
self.bottleneck_to_channels(self.latent_warp.unwarp(vector)), options=options
)
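

# Usage sketch (illustrative only, not part of the library). `PointMeanEncoder`,
# the batch field `points`, the params_proj config name "linear", and the
# `renderer`/`param_shapes` objects are all assumptions for the example.
#
#   class PointMeanEncoder(VectorEncoder):
#       """Toy encoder: mean-pool the input points and project to d_latent."""
#
#       def __init__(self, **kwargs):
#           super().__init__(**kwargs)
#           self.proj = nn.Linear(3, self.d_latent, device=self.device)
#
#       def encode_to_vector(self, batch, options=None):
#           _ = options
#           # batch.points is assumed to have shape [batch_size, n_points, 3].
#           return self.proj(batch.points.mean(dim=1))
#
#   encoder = PointMeanEncoder(
#       device=torch.device("cpu"),
#       param_shapes=param_shapes,        # shapes of the renderer's parameters
#       params_proj=dict(name="linear"),  # assumed config name
#       d_latent=1024,
#   )
#   model = Transmitter(encoder, renderer)
#   rendered = model(AttrDict(points=points))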