File size: 4,313 Bytes
3424266
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This code is strongly influenced by https://github.com/lucidrains/NWT-pytorch/

import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, repeat
from einops.layers.torch import EinMix as Mix


class Memcodes(nn.Module):
    """Multi-head codebook quantizer that selects one learned code per head.

    Every input vector is split into ``heads`` chunks of size ``dim // heads``.
    Each chunk is scored against the projected codes of its head (dot product
    between the scaled chunk and the keys), and exactly one code per head is
    selected: through a hard Gumbel-softmax sample while training, through a
    plain argmax otherwise. The output is the concatenation over heads of the
    value projections of the selected codes.

    Args:
        dim: Feature dimension of the input. Must be divisible by ``heads``.
        codebook_size: Number of codes held by each head.
        heads: Number of independent codebooks.
        temperature: ``tau`` passed to ``F.gumbel_softmax`` during training.
        channel_last: Only consulted when ``accept_image_fmap`` is False. If
            False, ``forward`` treats its input as ``(b, d, n)`` and transposes
            it to ``(b, n, d)`` internally; if True the input is used as
            ``(b, n, d)`` directly.
        accept_image_fmap: If True, ``forward`` treats its input as a
            ``(b, c, h, w)`` feature map and flattens the spatial axes.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        *,
        dim,
        codebook_size,
        heads = 1,
        temperature = 1.,
        channel_last = False,
        accept_image_fmap = True,
        **kwargs,
    ):
        super().__init__()
        assert (dim % heads) == 0, 'dimension must be divisible by number of heads'
        self.heads = heads
        self.dim = dim
        # Queries are multiplied by this factor before the dot product with the keys.
        self.scale = (dim // heads) ** -0.5
        self.temperature = temperature
        self.codebook_size = codebook_size
        self.accept_image_fmap = accept_image_fmap
        self.channel_last = channel_last

        num_codebooks = heads
        codebook_dim = dim // heads

        # One table of `codebook_size` codes per head, plus per-head linear maps
        # that turn the raw codes into keys (for scoring) and values (for output).
        self.codes = nn.Parameter(torch.randn(num_codebooks, codebook_size, codebook_dim))
        self.to_k = Mix('h n d -> h n c', weight_shape = 'h d c', h = heads, d = codebook_dim, c = codebook_dim)
        self.to_v = Mix('h n d -> h n c', weight_shape = 'h d c', h = heads, d = codebook_dim, c = codebook_dim)

    def indices_to_embedding(self, indices):
        """Look up the value-projected codes addressed by ``indices``.

        Args:
            indices: Integer tensor of code indices whose first axis is the
                batch. NOTE(review): the exact layout expected for the
                remaining axes is not evident from this file -- confirm
                against the callers before relying on a particular shape.

        Returns:
            ``(b, h * d, n, 1)`` when ``accept_image_fmap`` is set and the
            expanded index tensor is 4-dimensional, otherwise ``(b, n, h * d)``.
        """
        batch = indices.shape[0]

        values = self.to_v(self.codes)
        values = repeat(values, 'h n d -> b h n d', b = batch)

        # Append a trailing axis of size d so the indices can drive `gather`
        # over full code vectors. `squeeze(2)` only removes axis 2 when that
        # axis has size 1; otherwise it is a no-op.
        indices = repeat(indices, '... -> ... d', d = values.shape[-1]).squeeze(2)

        if self.accept_image_fmap and len(indices.size())==4:
            out = values.gather(2, indices)
            out = rearrange(out, 'b h n d -> b (h d) n 1')
            return out
        else:
            out = values.gather(2, indices.unsqueeze(2))
            return rearrange(out, 'b h n d -> b n (h d)')

    def forward(self, x):
        """Quantize ``x`` by replacing each vector with its selected codes.

        Args:
            x: ``(b, c, h, w)`` if ``accept_image_fmap``; otherwise
                ``(b, n, d)`` if ``channel_last`` else ``(b, d, n)``. The
                feature axis must have size ``self.dim``.

        Returns:
            A tuple ``(out, codebook_loss, codebook_indices)``:
              * ``out`` -- quantized tensor in the same layout as ``x``.
              * ``codebook_loss`` -- a constant zero tensor of shape ``(1,)``
                on ``x``'s device, returned only for interface compatibility
                with other quantizers.
              * ``codebook_indices`` -- index of the selected code per
                position; the head axis is dropped when ``heads == 1``
                (``(b, n)`` / ``(b, h, w)``), otherwise it is kept
                (``(b, heads, n)`` / ``(b, heads, h, w)``).
        """
        need_transpose = not self.channel_last and not self.accept_image_fmap

        if self.accept_image_fmap:
            height, width = x.shape[-2:]
            x = rearrange(x, 'b c h w -> b (h w) c')

        if need_transpose:
            x = rearrange(x, 'b d n -> b n d')

        assert x.shape[-1] == self.dim

        # split out heads

        q = rearrange(x, 'b n (h d) -> b h n d', h = self.heads)

        q = q * self.scale

        # get key / values of codes

        k, v = self.to_k(self.codes), self.to_v(self.codes)

        # straight through gumbel softmax

        logits = einsum('b h i d, h j d -> b h i j', q, k)

        if self.training:
            attn = F.gumbel_softmax(logits, tau = self.temperature, dim = -1, hard = True)
            codebook_indices = attn.argmax(dim = -1)
        else:
            codebook_indices = logits.argmax(dim = -1)
            # `F.one_hot` yields an integer tensor. Cast it to the dtype of the
            # logits (as the training branch implicitly does) rather than to a
            # hard-coded float32, so the einsum with `v` below also works when
            # the module runs in double / half / bfloat16 precision.
            attn = F.one_hot(codebook_indices, num_classes = self.codebook_size).to(logits.dtype)

        if self.heads == 1:
            codebook_indices = codebook_indices.squeeze(1)

        # `attn` is one-hot over the codes, so this picks one value vector per head.
        out = einsum('b h i j, h j d -> b h i d', attn, v)

        # merge heads
        out = rearrange(out, 'b h n d -> b n (h d)')

        if need_transpose:
            out = rearrange(out, 'b n d -> b d n')

        if self.accept_image_fmap:
            out = rearrange(out, 'b (h w) c -> b c h w', h = height, w = width)
            if self.heads == 1:
                codebook_indices = rearrange(codebook_indices, 'b (h w) -> b h w', h = height, w = width)
            else:
                codebook_indices = rearrange(codebook_indices, 'b n (h w) -> b n h w', h = height, w = width)

        # Dummy codebook loss for compatibility with other types of quantizers
        codebook_loss = torch.tensor([0.], device=x.device)

        return out, codebook_loss, codebook_indices