import os
from functools import partial
from dataclasses import dataclass
import torch
import numpy as np
from einops import rearrange
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import ModelMixin
from diffusers.utils import BaseOutput
from ..modules.ae_modules import Encoder, Decoder
from ..modules.ae_dualref_modules import VideoDecoder
from ..utils import instantiate_from_config
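
# This module wraps the LDM-style KL autoencoder (Encoder/Decoder from
# ae_modules) behind diffusers' ModelMixin/ConfigMixin interface, and adds a
# dual-reference variant whose VideoDecoder consumes encoder hidden states of
# the first and last frames as reference context.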


@dataclass
class DecoderOutput(BaseOutput):
    """
    Output of decoding method.

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Decoded output sample of the model. Output of the last layer of the model.
    """

    sample: torch.FloatTensor


@dataclass
class AutoencoderKLOutput(BaseOutput):
    """
    Output of AutoencoderKL encoding method.

    Args:
        latent_dist (`DiagonalGaussianDistribution`):
            Encoded outputs of `Encoder` represented as the mean and logvar of
            `DiagonalGaussianDistribution`. `DiagonalGaussianDistribution`
            allows for sampling latents from the distribution.
    """

    latent_dist: "DiagonalGaussianDistribution"
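

# Note: diffusers' DiagonalGaussianDistribution chunks the `moments` tensor
# into mean and logvar along the channel axis; `.sample()` draws a latent via
# the reparameterization trick, while `.mode()` returns the mean.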


class AutoencoderKL(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self,
                 ddconfig,
                 embed_dim,
                 image_key="image",
                 input_dim=4,
                 use_checkpoint=False,
                 ):
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        assert ddconfig["double_z"], "double_z must be True so the encoder emits mean and logvar"
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        self.input_dim = input_dim
        self.use_checkpoint = use_checkpoint

    def encode(self, x, return_hidden_states=False, **kwargs):
        if return_hidden_states:
            # Also collect intermediate encoder feature maps; the dual-reference
            # decoder below uses them as reference context.
            h, hidden = self.encoder(x, return_hidden_states)
            moments = self.quant_conv(h)
            posterior = DiagonalGaussianDistribution(moments)
            return AutoencoderKLOutput(latent_dist=posterior), hidden
        else:
            h = self.encoder(x)
            moments = self.quant_conv(h)
            posterior = DiagonalGaussianDistribution(moments)
            return AutoencoderKLOutput(latent_dist=posterior)
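
    # Usage sketch (hypothetical tensors; in practice `ddconfig` comes from the
    # project's YAML configs):
    #   vae = AutoencoderKL(ddconfig, embed_dim=4)
    #   posterior = vae.encode(x).latent_dist   # x: (b, 3, H, W)
    #   z = posterior.sample()                  # z: (b, embed_dim, H/f, W/f), f = downsampling factor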

    def decode(self, z, **kwargs):
        if len(kwargs) == 0:
            # Plain AutoencoderKL path: map latents back to decoder channels.
            z = self.post_quant_conv(z)
        # Any extra kwargs (e.g. ref_context) are forwarded to the SVD-style decoder.
        dec = self.decoder(z, **kwargs)
        return dec

    def forward(self, input, sample_posterior=True, **additional_decode_kwargs):
        forward_temp = partial(self._forward, sample_posterior=sample_posterior, **additional_decode_kwargs)
        if self.use_checkpoint:
            # torch.utils.checkpoint.checkpoint takes the function followed by
            # its tensor arguments; the previous call used the signature of a
            # custom checkpoint helper (func, inputs, params, flag) instead.
            return checkpoint(forward_temp, input, use_reentrant=False)
        return forward_temp(input)

    def _forward(self, input, sample_posterior=True, **additional_decode_kwargs):
        posterior = self.encode(input)[0]
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z, **additional_decode_kwargs)
        ## e.g. input and dec are both torch.Size([16, 3, 256, 256])
        return dec, posterior

    def get_input(self, batch, k):
        x = batch[k]
        if x.dim() == 5 and self.input_dim == 4:
            b, c, t, h, w = x.shape
            self.b = b
            self.t = t
            # Fold the temporal axis into the batch axis so the 2D encoder can
            # treat every frame as an independent image.
            x = rearrange(x, 'b c t h w -> (b t) c h w')
        return x
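
    # Shape example (hypothetical): a (2, 3, 16, 256, 256) video batch
    # [b, c, t, h, w] becomes (32, 3, 256, 256) frame images.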

    def get_last_layer(self):
        return self.decoder.conv_out.weight


class AutoencoderKL_Dualref(AutoencoderKL):
    @register_to_config
    def __init__(self,
                 ddconfig,
                 embed_dim,
                 image_key="image",
                 input_dim=4,
                 use_checkpoint=False,
                 ):
        super().__init__(ddconfig, embed_dim, image_key, input_dim, use_checkpoint)
        # Replace the 2D decoder with the dual-reference VideoDecoder.
        self.decoder = VideoDecoder(**ddconfig)

    def _forward(self, input, batch_size, sample_posterior=True, **additional_decode_kwargs):
        # Note: requires an explicit batch_size, which the inherited forward()
        # does not supply; callers must provide it.
        posterior, hidden_states = self.encode(input, return_hidden_states=True)

        ### use only the first and last frames' hidden states as reference context
        hidden_states_first_last = []
        for hid in hidden_states:
            hid = rearrange(hid, '(b t) c h w -> b c t h w', b=batch_size)
            hid_new = torch.cat([hid[:, :, 0:1], hid[:, :, -1:]], dim=2)
            hidden_states_first_last.append(hid_new)

        if sample_posterior:
            z = posterior[0].sample()
        else:
            z = posterior[0].mode()
        dec = self.decode(z, ref_context=hidden_states_first_last, **additional_decode_kwargs)
        ## e.g. input and dec are both torch.Size([16, 3, 256, 256])
        return dec, posterior
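

# Usage sketch for the dual-reference variant (hypothetical shapes; the real
# `ddconfig` is supplied via the project's YAML configs, and VideoDecoder may
# require further kwargs forwarded through **additional_decode_kwargs):
#   vae = AutoencoderKL_Dualref(ddconfig, embed_dim=4)
#   frames = vae.get_input(batch, vae.image_key)   # (b*t, 3, H, W)
#   dec, posterior = vae._forward(frames, batch_size=b)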