from functools import partial
from dataclasses import dataclass

import torch
from einops import rearrange
from torch.utils.checkpoint import checkpoint
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import ModelMixin
from diffusers.utils import BaseOutput

from ..modules.ae_modules import Encoder, Decoder
from ..modules.ae_dualref_modules import VideoDecoder


@dataclass
class DecoderOutput(BaseOutput):
    """
    Output of decoding method.

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Decoded output sample of the model. Output of the last layer of the model.
    """

    sample: torch.FloatTensor


@dataclass
class AutoencoderKLOutput(BaseOutput):
    """
    Output of AutoencoderKL encoding method.

    Args:
        latent_dist (`DiagonalGaussianDistribution`):
            Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
            `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
    """

    latent_dist: "DiagonalGaussianDistribution"
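
# A minimal sketch (illustrative, not part of this module) of how the wrapped
# distribution behaves: `moments` is the channel-wise concatenation of mean and
# logvar produced by `quant_conv`, which `DiagonalGaussianDistribution` splits.
#
#   moments = torch.randn(1, 2 * 4, 32, 32)        # (B, 2 * embed_dim, h, w)
#   dist = DiagonalGaussianDistribution(moments)   # chunks into mean / logvar
#   z = dist.sample()                              # reparameterized draw, (1, 4, 32, 32)
#   z_det = dist.mode()                            # deterministic latent: the mean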


class AutoencoderKL(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self,
                 ddconfig,
                 embed_dim,
                 image_key="image",
                 input_dim=4,
                 use_checkpoint=False,
                 ):
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        # The encoder must emit both mean and logvar ("double_z"), i.e.
        # 2 * z_channels output channels, so a diagonal Gaussian can be formed.
        assert ddconfig["double_z"], "AutoencoderKL requires ddconfig['double_z']=True"
        # 1x1 convs mapping encoder moments to a 2*embed_dim latent (and back).
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        self.input_dim = input_dim
        self.use_checkpoint = use_checkpoint
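
    # Example `ddconfig` (a hedged sketch: these values mirror the common
    # Stable-Diffusion VAE layout; your checkpoint's config may differ):
    #
    #   ddconfig = dict(
    #       double_z=True, z_channels=4, resolution=256, in_channels=3,
    #       out_ch=3, ch=128, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
    #       attn_resolutions=(), dropout=0.0,
    #   )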

    def encode(self, x, return_hidden_states=False, **kwargs):
        if return_hidden_states:
            # The dual-reference path also needs the encoder's intermediate
            # feature maps, which the video decoder consumes as references.
            h, hidden = self.encoder(x, return_hidden_states)
            moments = self.quant_conv(h)
            posterior = DiagonalGaussianDistribution(moments)
            return AutoencoderKLOutput(latent_dist=posterior), hidden
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return AutoencoderKLOutput(latent_dist=posterior)

    def decode(self, z, **kwargs):
        if len(kwargs) == 0:
            # Plain AutoencoderKL path: undo the 1x1 latent projection first.
            z = self.post_quant_conv(z)
        # Any extra kwargs (e.g. `ref_context` for the dual-reference video
        # decoder) are forwarded to the decoder unchanged.
        dec = self.decoder(z, **kwargs)
        return dec

    def forward(self, input, sample_posterior=True, **additional_decode_kwargs):
        forward_temp = partial(self._forward, sample_posterior=sample_posterior, **additional_decode_kwargs)
        if self.use_checkpoint:
            # Gradient checkpointing: recompute activations in the backward
            # pass to save memory. Note that torch.utils.checkpoint.checkpoint
            # takes the function followed by its unpacked inputs; the original
            # `checkpoint(func, inputs, params, flag)` call matched a different
            # helper and would raise a TypeError with the import used here.
            return checkpoint(forward_temp, input, use_reentrant=False)
        return forward_temp(input)

    def _forward(self, input, sample_posterior=True, **additional_decode_kwargs):
        posterior = self.encode(input).latent_dist
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z, **additional_decode_kwargs)
        # dec has the same shape as input, e.g. (16, 3, 256, 256).
        return dec, posterior

    def get_input(self, batch, k):
        x = batch[k]
        if x.dim() == 5 and self.input_dim == 4:
            # Video batch (b, c, t, h, w): fold time into the batch dimension
            # so the 2D encoder/decoder processes each frame independently.
            b, c, t, h, w = x.shape
            self.b = b
            self.t = t
            x = rearrange(x, 'b c t h w -> (b t) c h w')

        return x

    def get_last_layer(self):
        return self.decoder.conv_out.weight
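
    # Usage sketch (illustrative; names and shapes below assume the example
    # `ddconfig` above with spatial downsampling factor f=8):
    #
    #   vae = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
    #   x = torch.randn(1, 3, 256, 256)
    #   posterior = vae.encode(x).latent_dist
    #   z = posterior.sample()                 # (1, 4, 32, 32)
    #   recon = vae.decode(z)                  # (1, 3, 256, 256)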


class AutoencoderKL_Dualref(AutoencoderKL):
    @register_to_config
    def __init__(self,
                 ddconfig,
                 embed_dim,
                 image_key="image",
                 input_dim=4,
                 use_checkpoint=False,
                 ):
        super().__init__(ddconfig, embed_dim, image_key, input_dim, use_checkpoint)
        self.decoder = VideoDecoder(**ddconfig)

    def forward(self, input, batch_size, sample_posterior=True, **additional_decode_kwargs):
        # Override: `_forward` needs the clip batch size to un-flatten the
        # (b*t, c, h, w) frame batch; the base-class forward does not pass it.
        forward_temp = partial(self._forward, batch_size=batch_size,
                               sample_posterior=sample_posterior, **additional_decode_kwargs)
        if self.use_checkpoint:
            return checkpoint(forward_temp, input, use_reentrant=False)
        return forward_temp(input)

    def _forward(self, input, batch_size, sample_posterior=True, **additional_decode_kwargs):
        posterior_out, hidden_states = self.encode(input, return_hidden_states=True)

        # Keep only the first and last frame of each encoder feature map; these
        # act as the dual reference context for the video decoder.
        hidden_states_first_last = []
        for hid in hidden_states:
            hid = rearrange(hid, '(b t) c h w -> b c t h w', b=batch_size)
            hid_new = torch.cat([hid[:, :, 0:1], hid[:, :, -1:]], dim=2)
            hidden_states_first_last.append(hid_new)

        if sample_posterior:
            z = posterior_out.latent_dist.sample()
        else:
            z = posterior_out.latent_dist.mode()
        dec = self.decode(z, ref_context=hidden_states_first_last, **additional_decode_kwargs)
        # dec has the same shape as input, e.g. (16, 3, 256, 256).
        return dec, posterior_out
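

# Usage sketch for the dual-reference variant (illustrative; assumes b=2 clips
# of t=16 frames already flattened to (b*t, c, h, w) by `get_input`):
#
#   vae = AutoencoderKL_Dualref(ddconfig=ddconfig, embed_dim=4)
#   frames = torch.randn(2 * 16, 3, 256, 256)
#   dec, posterior = vae(frames, batch_size=2)   # decoder attends to first/last
#                                                # frame features via `ref_context`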