from abc import ABC, abstractmethod | |
from typing import Any, Dict, Optional, Tuple | |
import torch.nn as nn | |
from torch import torch | |
from shap_e.models.renderer import Renderer | |
from shap_e.util.collections import AttrDict | |
from .bottleneck import latent_bottleneck_from_config, latent_warp_from_config | |
from .params_proj import flatten_param_shapes, params_proj_from_config | |
class Encoder(nn.Module, ABC):
    """
    Abstract base for modules that encode a batch of data into the latent
    information a renderer needs.
    """

    def __init__(self, *, device: torch.device, param_shapes: Dict[str, Tuple[int]]):
        """
        Instantiate the encoder with information about the renderer's input
        parameters. This information can be used to create output layers to
        generate the necessary latents.

        :param device: the device the encoder's parameters should live on.
        :param param_shapes: maps each renderer parameter name to its shape.
        """
        super().__init__()
        self.param_shapes = param_shapes
        self.device = device

    @abstractmethod
    def forward(self, batch: AttrDict, options: Optional[AttrDict] = None) -> AttrDict:
        """
        Encode a batch of data into a batch of latent information.
        """
class VectorEncoder(Encoder):
    """
    Encoder that maps each data point to a single flat latent vector, which is
    passed through a bottleneck and a warp before being projected to renderer
    parameters.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        param_shapes: Dict[str, Tuple[int]],
        params_proj: Dict[str, Any],
        d_latent: int,
        latent_bottleneck: Optional[Dict[str, Any]] = None,
        latent_warp: Optional[Dict[str, Any]] = None,
    ):
        """
        :param params_proj: config for projecting latents to renderer params.
        :param d_latent: dimensionality of the flat latent vector.
        :param latent_bottleneck: optional bottleneck config; identity if None.
        :param latent_warp: optional warp config; identity if None.
        """
        super().__init__(device=device, param_shapes=param_shapes)
        # Default both transforms to identity when no config is supplied.
        if latent_bottleneck is None:
            latent_bottleneck = dict(name="identity")
        if latent_warp is None:
            latent_warp = dict(name="identity")
        self.d_latent = d_latent
        self.params_proj = params_proj_from_config(
            params_proj, device=device, param_shapes=param_shapes, d_latent=d_latent
        )
        self.latent_bottleneck = latent_bottleneck_from_config(
            latent_bottleneck, device=device, d_latent=d_latent
        )
        self.latent_warp = latent_warp_from_config(latent_warp, device=device)

    def forward(self, batch: AttrDict, options: Optional[AttrDict] = None) -> AttrDict:
        """
        Encode the batch to a bottlenecked latent, then project it to params.
        """
        h = self.encode_to_bottleneck(batch, options=options)
        return self.bottleneck_to_params(h, options=options)

    def encode_to_bottleneck(
        self, batch: AttrDict, options: Optional[AttrDict] = None
    ) -> torch.Tensor:
        """
        Encode the batch to a vector, apply the bottleneck, and warp the result.
        """
        return self.latent_warp.warp(
            self.latent_bottleneck(self.encode_to_vector(batch, options=options), options=options),
            options=options,
        )

    @abstractmethod
    def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor:
        """
        Encode the batch into a single latent vector.
        """

    def bottleneck_to_params(
        self, vector: torch.Tensor, options: Optional[AttrDict] = None
    ) -> AttrDict:
        """
        Un-warp the bottlenecked vector and project it to renderer parameters.
        """
        return self.params_proj(self.latent_warp.unwarp(vector, options=options), options=options)
class ChannelsEncoder(VectorEncoder):
    """
    Encoder that produces a per-data-point set of latent channels of shape
    [batch_size, latent_ctx, latent_width]; their flattened view feeds the
    VectorEncoder pipeline.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        param_shapes: Dict[str, Tuple[int]],
        params_proj: Dict[str, Any],
        d_latent: int,
        latent_bottleneck: Optional[Dict[str, Any]] = None,
        latent_warp: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(
            device=device,
            param_shapes=param_shapes,
            params_proj=params_proj,
            d_latent=d_latent,
            latent_bottleneck=latent_bottleneck,
            latent_warp=latent_warp,
        )
        self.flat_shapes = flatten_param_shapes(param_shapes)
        # Total number of latent tokens across all flattened parameter groups.
        self.latent_ctx = sum(flat[0] for flat in self.flat_shapes.values())

    @abstractmethod
    def encode_to_channels(
        self, batch: AttrDict, options: Optional[AttrDict] = None
    ) -> torch.Tensor:
        """
        Encode the batch into a per-data-point set of latents.

        :return: [batch_size, latent_ctx, latent_width]
        """

    def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor:
        """
        Flatten the per-point channels into a single latent vector per point.
        """
        return self.encode_to_channels(batch, options=options).flatten(1)

    def bottleneck_to_channels(
        self, vector: torch.Tensor, options: Optional[AttrDict] = None
    ) -> torch.Tensor:
        """
        Reshape a flat latent back into [batch_size, latent_ctx, latent_width].
        """
        _ = options
        return vector.view(vector.shape[0], self.latent_ctx, -1)

    def bottleneck_to_params(
        self, vector: torch.Tensor, options: Optional[AttrDict] = None
    ) -> AttrDict:
        """
        Un-warp the vector, restore its channel shape, and project to params.
        """
        # Pass options through to unwarp for consistency with
        # VectorEncoder.bottleneck_to_params.
        return self.params_proj(
            self.bottleneck_to_channels(self.latent_warp.unwarp(vector, options=options)),
            options=options,
        )
class Transmitter(nn.Module):
    """
    Chains an encoder and a renderer: the encoder turns a batch into renderer
    parameters, and the renderer consumes those parameters with the batch.
    """

    def __init__(self, encoder: Encoder, renderer: Renderer):
        super().__init__()
        self.encoder = encoder
        self.renderer = renderer

    def forward(self, batch: AttrDict, options: Optional[AttrDict] = None) -> AttrDict:
        """
        Transmit the batch through the encoder and then the renderer.
        """
        return self.renderer(
            batch, params=self.encoder(batch, options=options), options=options
        )
class VectorDecoder(nn.Module):
    """
    Decoder counterpart of VectorEncoder: un-warps a flat latent vector and
    projects it to parameters for an attached renderer.
    """

    def __init__(
        self,
        *,
        device: torch.device,
        param_shapes: Dict[str, Tuple[int]],
        params_proj: Dict[str, Any],
        d_latent: int,
        latent_warp: Optional[Dict[str, Any]] = None,
        renderer: Renderer,
    ):
        """
        :param param_shapes: maps each renderer parameter name to its shape.
        :param params_proj: config for projecting latents to renderer params.
        :param d_latent: dimensionality of the flat latent vector.
        :param latent_warp: optional warp config; identity if None.
        :param renderer: renderer that will consume the projected parameters.
        """
        super().__init__()
        self.device = device
        self.param_shapes = param_shapes
        # Default the warp to identity when no config is supplied.
        if latent_warp is None:
            latent_warp = dict(name="identity")
        self.d_latent = d_latent
        self.params_proj = params_proj_from_config(
            params_proj, device=device, param_shapes=param_shapes, d_latent=d_latent
        )
        self.latent_warp = latent_warp_from_config(latent_warp, device=device)
        self.renderer = renderer

    def bottleneck_to_params(
        self, vector: torch.Tensor, options: Optional[AttrDict] = None
    ) -> AttrDict:
        """
        Un-warp the bottlenecked vector and project it to renderer parameters.
        """
        return self.params_proj(self.latent_warp.unwarp(vector, options=options), options=options)
class ChannelsDecoder(VectorDecoder):
    """
    Decoder counterpart of ChannelsEncoder: restores the per-data-point channel
    shape of a flat latent before projecting it to renderer parameters.
    """

    def __init__(
        self,
        *,
        latent_ctx: int,
        **kwargs,
    ):
        """
        :param latent_ctx: number of latent tokens per data point.
        :param kwargs: forwarded to VectorDecoder.__init__.
        """
        super().__init__(**kwargs)
        self.latent_ctx = latent_ctx

    def bottleneck_to_channels(
        self, vector: torch.Tensor, options: Optional[AttrDict] = None
    ) -> torch.Tensor:
        """
        Reshape a flat latent back into [batch_size, latent_ctx, latent_width].
        """
        _ = options
        return vector.view(vector.shape[0], self.latent_ctx, -1)

    def bottleneck_to_params(
        self, vector: torch.Tensor, options: Optional[AttrDict] = None
    ) -> AttrDict:
        """
        Un-warp the vector, restore its channel shape, and project to params.
        """
        # Pass options through to unwarp for consistency with
        # VectorDecoder.bottleneck_to_params.
        return self.params_proj(
            self.bottleneck_to_channels(self.latent_warp.unwarp(vector, options=options)),
            options=options,
        )