## lvdm/models/autoencoder.py

import os
from functools import partial
from dataclasses import dataclass
import torch
import numpy as np
from einops import rearrange
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import ModelMixin
from diffusers.utils import BaseOutput
from ..modules.ae_modules import Encoder, Decoder
from ..modules.ae_dualref_modules import VideoDecoder
from ..utils import instantiate_from_config


@dataclass
class DecoderOutput(BaseOutput):
"""
Output of decoding method.
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Decoded output sample of the model. Output of the last layer of the model.
"""
    sample: torch.FloatTensor


@dataclass
class AutoencoderKLOutput(BaseOutput):
"""
Output of AutoencoderKL encoding method.
Args:
latent_dist (`DiagonalGaussianDistribution`):
Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
`DiagonalGaussianDistribution` allows for sampling latents from the distribution.
"""
latent_dist: "DiagonalGaussianDistribution"
class AutoencoderKL(ModelMixin, ConfigMixin):
@register_to_config
def __init__(self,
ddconfig,
embed_dim,
image_key="image",
input_dim=4,
use_checkpoint=False,
):
super().__init__()
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
self.input_dim = input_dim
        self.use_checkpoint = use_checkpoint

def encode(self, x, return_hidden_states=False, **kwargs):
if return_hidden_states:
h, hidden = self.encoder(x, return_hidden_states)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return AutoencoderKLOutput(latent_dist=posterior), hidden
else:
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return AutoencoderKLOutput(latent_dist=posterior)
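
    ## Usage sketch (illustrative):
    ##   out, hidden = self.encode(x, return_hidden_states=True)
    ##   z = out.latent_dist.sample()   # reparameterized draw from the diagonal Gaussian
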
    def decode(self, z, **kwargs):
        if len(kwargs) == 0:  ## standard AutoencoderKL path
            z = self.post_quant_conv(z)
        dec = self.decoder(z, **kwargs)  ## kwargs are forwarded for the SVD-style VideoDecoder
        return dec
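
    ## Decode sketch (illustrative): with no extra kwargs this is the plain
    ## image-decoder path; AutoencoderKL_Dualref below instead calls
    ##   self.decode(z, ref_context=hidden_states_first_last)
    ## which skips post_quant_conv and hands the kwargs to VideoDecoder.
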
    def forward(self, input, sample_posterior=True, **additional_decode_kwargs):
        forward_temp = partial(self._forward, sample_posterior=sample_posterior, **additional_decode_kwargs)
        if self.use_checkpoint and self.training:
            ## gradient checkpointing: recompute activations in backward to save memory
            return checkpoint(forward_temp, input, use_reentrant=False)
        return forward_temp(input)
def _forward(self, input, sample_posterior=True, **additional_decode_kwargs):
posterior = self.encode(input)[0]
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z, **additional_decode_kwargs)
        ## input and dec have matching shapes, e.g. torch.Size([16, 3, 256, 256])
return dec, posterior
    def get_input(self, batch, k):
        x = batch[k]
        if x.dim() == 5 and self.input_dim == 4:
            ## flatten video batches [b, c, t, h, w] into frame batches [(b t), c, h, w]
            b, c, t, h, w = x.shape
            self.b = b
            self.t = t
            x = rearrange(x, 'b c t h w -> (b t) c h w')
return x
    def get_last_layer(self):
        ## exposed for adaptive loss weighting in VAE-GAN style training
        return self.decoder.conv_out.weight
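
## Round-trip sketch for AutoencoderKL (illustrative, not from the original file):
##   ae = AutoencoderKL(ddconfig=..., embed_dim=4)
##   posterior = ae.encode(images).latent_dist   # images: [B, 3, H, W]
##   recon = ae.decode(posterior.sample())       # recon:  [B, 3, H, W]
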
class AutoencoderKL_Dualref(AutoencoderKL):
@register_to_config
def __init__(self,
ddconfig,
embed_dim,
image_key="image",
input_dim=4,
use_checkpoint=False,
):
        super().__init__(ddconfig, embed_dim, image_key, input_dim, use_checkpoint)
        ## swap in the dual-reference video decoder, which consumes first/last-frame
        ## encoder features passed as `ref_context`
        self.decoder = VideoDecoder(**ddconfig)
def _forward(self, input, batch_size, sample_posterior=True, **additional_decode_kwargs):
posterior, hidden_states = self.encode(input, return_hidden_states=True)
        ## keep only the first- and last-frame encoder features as reference context
        hidden_states_first_last = []
for hid in hidden_states:
hid = rearrange(hid, '(b t) c h w -> b c t h w', b=batch_size)
hid_new = torch.cat([hid[:, :, 0:1], hid[:, :, -1:]], dim=2)
hidden_states_first_last.append(hid_new)
if sample_posterior:
z = posterior[0].sample()
else:
z = posterior[0].mode()
dec = self.decode(z, ref_context=hidden_states_first_last, **additional_decode_kwargs)
        ## input and dec have matching shapes, e.g. torch.Size([16, 3, 256, 256])
return dec, posterior
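

if __name__ == "__main__":
    ## Minimal smoke test (a sketch, not part of the original file). The ddconfig
    ## below is illustrative and assumes ae_modules.Encoder/Decoder follow the
    ## standard ldm/VideoCrafter signature. Because this module uses relative
    ## imports, run it from the repo root with: python -m lvdm.models.autoencoder
    demo_ddconfig = {
        "double_z": True,
        "z_channels": 4,
        "resolution": 64,
        "in_channels": 3,
        "out_ch": 3,
        "ch": 32,
        "ch_mult": [1, 2, 4],
        "num_res_blocks": 1,
        "attn_resolutions": [],
        "dropout": 0.0,
    }
    ae = AutoencoderKL(ddconfig=demo_ddconfig, embed_dim=4).eval()
    x = torch.randn(1, 3, 64, 64)
    with torch.no_grad():
        posterior = ae.encode(x).latent_dist
        recon = ae.decode(posterior.mode())
    ## two downsampling stages (len(ch_mult) - 1) shrink each side by 4x:
    ## latents [1, 4, 16, 16], reconstruction [1, 3, 64, 64]
    print(posterior.mode().shape, recon.shape)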