import os
from functools import partial
from dataclasses import dataclass

import torch
import numpy as np
from einops import rearrange
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import ModelMixin
from diffusers.utils import BaseOutput

from ..modules.ae_modules import Encoder, Decoder
from ..modules.ae_dualref_modules import VideoDecoder
from ..utils import instantiate_from_config


@dataclass
class DecoderOutput(BaseOutput):
    """
    Output of decoding method.

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Decoded output sample of the model. Output of the last layer of the model.
    """

    sample: torch.FloatTensor


@dataclass
class AutoencoderKLOutput(BaseOutput):
    """
    Output of AutoencoderKL encoding method.

    Args:
        latent_dist (`DiagonalGaussianDistribution`):
            Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
            `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
    """

    latent_dist: "DiagonalGaussianDistribution"


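# A minimal usage sketch of how these outputs are typically consumed; `vae`
# and `pixel_values` are illustrative names, not defined in this module:
#
#   out = vae.encode(pixel_values)      # -> AutoencoderKLOutput
#   z = out.latent_dist.sample()        # stochastic draw from the posterior
#   z = out.latent_dist.mode()          # or the deterministic mode

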
class AutoencoderKL(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        ddconfig,
        embed_dim,
        image_key="image",
        input_dim=4,
        use_checkpoint=False,
    ):
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        self.input_dim = input_dim
        self.use_checkpoint = use_checkpoint

    def encode(self, x, return_hidden_states=False, **kwargs):
        if return_hidden_states:
            h, hidden = self.encoder(x, return_hidden_states)
            moments = self.quant_conv(h)
            posterior = DiagonalGaussianDistribution(moments)
            return AutoencoderKLOutput(latent_dist=posterior), hidden
        else:
            h = self.encoder(x)
            moments = self.quant_conv(h)
            posterior = DiagonalGaussianDistribution(moments)
            return AutoencoderKLOutput(latent_dist=posterior)

    def decode(self, z, **kwargs):
        if len(kwargs) == 0:  ## use the original decoder in AutoencoderKL
            z = self.post_quant_conv(z)
        dec = self.decoder(z, **kwargs)  ## change for SVD decoder by adding **kwargs
        return dec
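
    # NOTE: when decoder kwargs are supplied (the SVD-style VideoDecoder path
    # used by AutoencoderKL_Dualref below), `post_quant_conv` is bypassed and
    # `z` is passed to the decoder unchanged.
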
    def forward(self, input, sample_posterior=True, **additional_decode_kwargs):
        forward_temp = partial(self._forward, sample_posterior=sample_posterior, **additional_decode_kwargs)
        if self.use_checkpoint:
            # torch.utils.checkpoint takes the function followed by its tensor
            # arguments; use_reentrant=False allows gradients to reach
            # parameters that are not explicit inputs of `forward_temp`.
            return checkpoint(forward_temp, input, use_reentrant=False)
        return forward_temp(input)

    def _forward(self, input, sample_posterior=True, **additional_decode_kwargs):
        posterior = self.encode(input)[0]
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z, **additional_decode_kwargs)
        return dec, posterior

    def get_input(self, batch, k):
        x = batch[k]
        if x.dim() == 5 and self.input_dim == 4:
            b, c, t, h, w = x.shape
            self.b = b
            self.t = t
            x = rearrange(x, 'b c t h w -> (b t) c h w')
        return x

    def get_last_layer(self):
        return self.decoder.conv_out.weight


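# A hedged construction sketch. The `ddconfig` values below are illustrative
# assumptions in the style of LDM-family VAE configs, not this repository's
# shipped configuration:
#
#   ddconfig = dict(double_z=True, z_channels=4, resolution=256, in_channels=3,
#                   out_ch=3, ch=128, ch_mult=[1, 2, 4, 4], num_res_blocks=2,
#                   attn_resolutions=[], dropout=0.0)
#   vae = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
#   dec, posterior = vae(torch.randn(1, 3, 256, 256))

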
class AutoencoderKL_Dualref(AutoencoderKL):
    def __init__(
        self,
        ddconfig,
        embed_dim,
        image_key="image",
        input_dim=4,
        use_checkpoint=False,
    ):
        super().__init__(ddconfig, embed_dim, image_key, input_dim, use_checkpoint)
        self.decoder = VideoDecoder(**ddconfig)

    def _forward(self, input, batch_size, sample_posterior=True, **additional_decode_kwargs):
        posterior, hidden_states = self.encode(input, return_hidden_states=True)

        ## use only the first and last hidden states as the reference context
        hidden_states_first_last = []
        for hid in hidden_states:
            hid = rearrange(hid, '(b t) c h w -> b c t h w', b=batch_size)
            hid_new = torch.cat([hid[:, :, 0:1], hid[:, :, -1:]], dim=2)
            hidden_states_first_last.append(hid_new)

        # `posterior` is an AutoencoderKLOutput; indexing it with 0 yields its
        # `latent_dist` (BaseOutput supports tuple-style indexing).
        if sample_posterior:
            z = posterior[0].sample()
        else:
            z = posterior[0].mode()
        dec = self.decode(z, ref_context=hidden_states_first_last, **additional_decode_kwargs)
        return dec, posterior


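# A hedged usage sketch for the dual-reference variant. Shapes and any extra
# decoder kwargs accepted by VideoDecoder are assumptions, not verified here:
#
#   video = torch.randn(1, 3, 16, 256, 256)              # (b, c, t, h, w)
#   frames = vae.get_input({"image": video}, "image")    # -> (b*t, c, h, w)
#   dec, posterior = vae(frames, batch_size=1)
#
# `batch_size` travels through forward's **additional_decode_kwargs into
# _forward, which uses it to rebuild the temporal axis and assemble the
# first/last-frame ref_context for the VideoDecoder.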