michellemoorre committed · Commit 6c4dee3 · 0 parents

Initial commit
__pycache__/app.cpython-39.pyc ADDED
Binary file (2.75 kB)

__pycache__/dist.cpython-39.pyc ADDED
Binary file (5.93 kB)

app.py ADDED
@@ -0,0 +1,146 @@
import gradio as gr
import numpy as np
import random

import spaces
from models import TVARPipeline
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "michellemoorre/var-test"


pipe = TVARPipeline.from_pretrained(model_repo_id, device=device)

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024


@spaces.GPU(duration=65)
def infer(
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    guidance_scale=4.0,
    top_k=450,
    top_p=0.95,
    re=False,
    re_max_depth=10,
    progress=gr.Progress(track_tqdm=True),
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    image = pipe(
        prompt=prompt,
        null_prompt=negative_prompt,
        cfg=guidance_scale,
        top_p=top_p,
        top_k=top_k,
        re=re,
        re_max_depth=re_max_depth,  # forward the rejection-sampling depth selected in the UI
        g_seed=seed,
    )[0]

    return image, seed


# TODO: add examples from preview
examples = [
    "A capybara wearing a suit holding a sign that reads Hello World",
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# [OpenTVAR](https://huggingface.co/stabilityai/stable-diffusion-3.5-large)")
        gr.Markdown("[Learn more](https://stability.ai/news/introducing-stable-diffusion-3-5) about OpenTVAR.")
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )

            run_button = gr.Button("Run", scale=0, variant="primary")

        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
                visible=True,
            )

            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )

            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=7.5,
                    step=0.1,
                    value=4.5,
                )
            with gr.Row():
                top_k = gr.Slider(
                    label="Sampling top k",
                    minimum=1,
                    maximum=1000,
                    step=10,
                    value=450,
                )
                top_p = gr.Slider(
                    label="Sampling top p",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.95,
                )
            with gr.Row():
                re = gr.Checkbox(label="Rejection Sampling", value=False)
                re_max_depth = gr.Slider(
                    label="Rejection Sampling Depth",
                    minimum=0,
                    maximum=20,
                    step=1,
                    value=10,
                )

        gr.Examples(examples=examples, inputs=[prompt], outputs=[result, seed], fn=infer, cache_examples=True)  # cache_mode="lazy"

    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            guidance_scale,
            top_k,
            top_p,
            re,
            re_max_depth,
        ],
        outputs=[result, seed],
    )

if __name__ == "__main__":
    demo.launch()
dist.py ADDED
@@ -0,0 +1,203 @@
"""
Helpers for distributed training.
"""
import datetime
import os
import socket
from contextlib import closing

import torch as th
import torch.distributed as dist
from torch.distributed import barrier, is_initialized, broadcast

# Change this to reflect your cluster layout.
# The GPU for a given rank is (rank % GPUS_PER_NODE).
GPUS_PER_NODE = 8

SETUP_RETRY_COUNT = 3


def find_free_port() -> int:
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        s.bind(("", 0))
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return s.getsockname()[1]


def check_if_port_open(port: int) -> bool:
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        try:
            s.bind(("", port))
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            return True
        except OSError:
            return False


def initialized():
    return dist.is_initialized()


def finalize():
    if dist.is_initialized():
        dist.destroy_process_group()


def initialize():
    is_mpirun = not (
        "RANK" in os.environ
        and "WORLD_SIZE" in os.environ
        and "MASTER_ADDR" in os.environ
        and "MASTER_PORT" in os.environ
    )

    if is_mpirun:
        from mpi4py import MPI
        import subprocess

        comm = MPI.COMM_WORLD
        rank = comm.Get_rank()
        world_size = comm.Get_size()

        master_addr = None
        master_port = None
        if rank == 0:
            hostname_cmd = ["hostname -I"]
            result = subprocess.check_output(hostname_cmd, shell=True)
            master_addr = result.decode("utf-8").split()[0]

            base_port = os.environ.get(
                "MASTER_PORT", "29500"
            )  # TORCH_DISTRIBUTED_DEFAULT_PORT
            if check_if_port_open(int(base_port)):
                master_port = base_port
            else:
                master_port = find_free_port()

        master_addr = comm.bcast(master_addr, root=0)
        master_port = comm.bcast(master_port, root=0)
        # Determine local rank by assuming hostnames are unique
        proc_name = MPI.Get_processor_name()
        all_procs = comm.allgather(proc_name)
        local_rank = sum([i == proc_name for i in all_procs[:rank]])
        uniq_proc_names = set(all_procs)
        host_rank = sorted(uniq_proc_names).index(proc_name)

        os.environ["LOCAL_RANK"] = str(local_rank)
        os.environ["HOST_RANK"] = str(host_rank)
        os.environ["NUM_HOSTS"] = str(len(uniq_proc_names))

        os.environ["RANK"] = str(rank)
        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["MASTER_ADDR"] = master_addr
        os.environ["MASTER_PORT"] = str(master_port)
        os.environ["OMP_NUM_THREADS"] = "1"

    # Initialize torch distributed
    backend = "gloo" if not th.cuda.is_available() else "nccl"
    dist.init_process_group(backend=backend, timeout=datetime.timedelta(0, 3600))
    if th.cuda.is_available():
        th.cuda.set_device(int(os.environ.get("LOCAL_RANK", "0")))

    if is_mpirun and dist.get_rank() == 0:
        print("Distributed setup")
        print("LOCAL_RANK", os.environ["LOCAL_RANK"])
        print("HOST_RANK", os.environ["HOST_RANK"])
        print("NUM_HOSTS", os.environ["NUM_HOSTS"])
        print("WORLD_SIZE", os.environ["WORLD_SIZE"])


def local_host_gather(data):
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    host_rank = os.environ["HOST_RANK"]
    all_data = comm.allgather((host_rank, data))
    return [d[1] for d in all_data if d[0] == host_rank]


def in_distributed_mode():
    return dist is not None


def is_master():
    return get_rank() == 0


def is_local_master():
    return get_local_rank() == 0


def get_rank():
    return dist.get_rank() if in_distributed_mode() else 0


def get_local_rank():
    return int(os.environ["LOCAL_RANK"])


def worker_host_idx():
    return int(os.environ["HOST_RANK"])


def num_hosts():
    return int(os.environ["NUM_HOSTS"])


def get_world_size():
    return dist.get_world_size() if in_distributed_mode() else 1


def gpu_visible_device_list():
    return str(dist.get_rank()) if in_distributed_mode() else None


def get_device():
    """
    Get the device to use for torch.distributed.
    """
    if th.cuda.is_available():
        return th.device("cuda")
    return th.device("cpu")


def sync_params(params):
    """
    Synchronize a sequence of Tensors across ranks from rank 0.
    """
    for p in params:
        with th.no_grad():
            dist.broadcast(p, 0)


def print0(*args, **kwargs):
    if get_rank() == 0:
        print(*args, **kwargs)


def allreduce(t: th.Tensor, async_op=False):
    if dist.is_initialized():
        if not t.is_cuda:
            cu = t.detach().cuda()
            ret = dist.all_reduce(cu, async_op=async_op)
            t.copy_(cu.cpu())
        else:
            ret = dist.all_reduce(t, async_op=async_op)
        return ret
    return None


def allgather(t: th.Tensor, cat=True):
    if dist.is_initialized():
        if not t.is_cuda:
            t = t.cuda()
        ls = [th.empty_like(t) for _ in range(get_world_size())]
        dist.all_gather(ls, t)
    else:
        ls = [t]
    if cat:
        ls = th.cat(ls, dim=0)
    return ls
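For reference, a minimal sketch of how these helpers would be wired into a training entry point (hypothetical usage, not part of this commit; assumes a CUDA machine and a torchrun or mpirun launch):

# Hypothetical training entry point using dist.py.
import dist
import torch as th

def main():
    dist.initialize()                       # set up torch.distributed from env vars or MPI
    dist.print0("world size:", dist.get_world_size())

    t = th.ones(1, device=dist.get_device())
    dist.allreduce(t)                       # sum the tensor across all ranks
    dist.print0("allreduce result:", t.item())

    dist.finalize()

if __name__ == "__main__":
    main()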
gradio_cached_examples/24/Result/408bda05bb7418a064b8/image.webp ADDED
gradio_cached_examples/24/log.csv ADDED
@@ -0,0 +1,2 @@
Result,Seed,flag,username,timestamp
"{""path"": ""gradio_cached_examples/24/Result/408bda05bb7418a064b8/image.webp"", ""url"": ""/file=/place/vartmp/gradio/162d5443a1136187c8b25737fbbcdc8392d218d21024c2458112af1cc6d17d66/image.webp"", ""size"": null, ""orig_name"": ""image.webp"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}",42,,,2024-10-26 02:08:31.037928
models/__init__.py ADDED
@@ -0,0 +1,90 @@
from typing import Tuple

import torch.nn as nn

from .clip import FrozenCLIPEmbedder
from .quant import VectorQuantizer2
from .var import VAR
from .vqvae import VQVAE
from .pipeline import TVARPipeline


def build_vae_var(
    # Shared args
    device,
    patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),  # 10 steps by default
    # VQVAE args
    V=4096,
    Cvae=32,
    ch=160,
    share_quant_resi=4,
    # VAR args
    depth=16,
    shared_aln=False,
    attn_l2_norm=True,
    init_adaln=0.5,
    init_adaln_gamma=1e-5,
    init_head=0.02,
    init_std=-1,  # init_std < 0: automated
    text_encoder_path=None,
    text_encoder_2_path=None,
    rope=False,
    rope_theta=100,
    rope_size=None,
    dpr=0,
    use_swiglu_ffn=False,
) -> Tuple[VQVAE, VAR, TVARPipeline]:
    heads = depth
    width = depth * 64
    if dpr > 0:
        dpr = dpr * depth / 24

    # disable built-in initialization for speed
    for clz in (
        nn.Linear,
        nn.LayerNorm,
        nn.BatchNorm2d,
        nn.SyncBatchNorm,
        nn.Conv1d,
        nn.Conv2d,
        nn.ConvTranspose1d,
        nn.ConvTranspose2d,
    ):
        setattr(clz, "reset_parameters", lambda self: None)

    # build models
    vae_local = VQVAE(
        vocab_size=V,
        z_channels=Cvae,
        ch=ch,
        test_mode=True,
        share_quant_resi=share_quant_resi,
        v_patch_nums=patch_nums,
    ).to(device)
    var_wo_ddp = VAR(
        depth=depth,
        embed_dim=width,
        num_heads=heads,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=dpr,
        norm_eps=1e-6,
        shared_aln=shared_aln,
        attn_l2_norm=attn_l2_norm,
        patch_nums=patch_nums,
        rope=rope,
        rope_theta=rope_theta,
        rope_size=rope_size,
        use_swiglu_ffn=use_swiglu_ffn,
    ).to(device)
    var_wo_ddp.init_weights(
        init_adaln=init_adaln,
        init_adaln_gamma=init_adaln_gamma,
        init_head=init_head,
        init_std=init_std,
    )
    text_encoder = FrozenCLIPEmbedder(text_encoder_path)
    text_encoder_2 = FrozenCLIPEmbedder(text_encoder_2_path)
    pipe = TVARPipeline(var_wo_ddp, vae_local, text_encoder, text_encoder_2, device)

    return vae_local, var_wo_ddp, pipe
models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (1.9 kB)

models/__pycache__/basic_vae.cpython-39.pyc ADDED
Binary file (6.99 kB)

models/__pycache__/basic_var.cpython-39.pyc ADDED
Binary file (12.1 kB)

models/__pycache__/clip.cpython-39.pyc ADDED
Binary file (1.92 kB)

models/__pycache__/helpers.cpython-39.pyc ADDED
Binary file (2.88 kB)

models/__pycache__/pipeline.cpython-39.pyc ADDED
Binary file (6.19 kB)

models/__pycache__/quant.cpython-39.pyc ADDED
Binary file (11 kB)

models/__pycache__/rope.cpython-39.pyc ADDED
Binary file (2.13 kB)

models/__pycache__/var.cpython-39.pyc ADDED
Binary file (11.1 kB)

models/__pycache__/vqvae.cpython-39.pyc ADDED
Binary file (5.76 kB)

models/basic_vae.py ADDED
@@ -0,0 +1,289 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

# this file only provides the 2 modules used in VQVAE
__all__ = ["Encoder", "Decoder"]


"""
References: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/diffusionmodules/model.py
"""


# swish
def nonlinearity(x):
    return x * torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
    return torch.nn.GroupNorm(
        num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
    )


class Upsample2x(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        return self.conv(F.interpolate(x, scale_factor=2, mode="nearest"))


class Downsample2x(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=2, padding=0
        )

    def forward(self, x):
        return self.conv(F.pad(x, pad=(0, 1, 0, 1), mode="constant", value=0))


class ResnetBlock(nn.Module):
    def __init__(
        self, *, in_channels, out_channels=None, dropout
    ):  # conv_shortcut=False, # conv_shortcut: always False in VAE
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout) if dropout > 1e-6 else nn.Identity()
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if self.in_channels != self.out_channels:
            self.nin_shortcut = torch.nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0
            )
        else:
            self.nin_shortcut = nn.Identity()

    def forward(self, x):
        h = self.conv1(F.silu(self.norm1(x), inplace=True))
        h = self.conv2(self.dropout(F.silu(self.norm2(h), inplace=True)))
        return self.nin_shortcut(x) + h


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.C = in_channels

        self.norm = Normalize(in_channels)
        self.qkv = torch.nn.Conv2d(
            in_channels, 3 * in_channels, kernel_size=1, stride=1, padding=0
        )
        self.w_ratio = int(in_channels) ** (-0.5)
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        qkv = self.qkv(self.norm(x))
        B, _, H, W = qkv.shape  # should be B,3C,H,W
        C = self.C
        q, k, v = qkv.reshape(B, 3, C, H, W).unbind(1)

        # compute attention
        q = q.view(B, C, H * W).contiguous()
        q = q.permute(0, 2, 1).contiguous()  # B,HW,C
        k = k.view(B, C, H * W).contiguous()  # B,C,HW
        w = torch.bmm(q, k).mul_(self.w_ratio)  # B,HW,HW
        # w[B,i,j]=sum_c q[B,i,C]k[B,C,j]
        w = F.softmax(w, dim=2)

        # attend to values
        v = v.view(B, C, H * W).contiguous()
        w = w.permute(0, 2, 1).contiguous()  # B,HW,HW (first HW of k, second of q)
        h = torch.bmm(v, w)  # B, C,HW (HW of q) h[B,C,j] = sum_i v[B,C,i] w[B,i,j]
        h = h.view(B, C, H, W).contiguous()

        return x + self.proj_out(h)


def make_attn(in_channels, using_sa=True):
    return AttnBlock(in_channels) if using_sa else nn.Identity()


class Encoder(nn.Module):
    def __init__(
        self,
        *,
        ch=128,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks=2,
        dropout=0.0,
        in_channels=3,
        z_channels,
        double_z=False,
        using_sa=True,
        using_mid_sa=True,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.downsample_ratio = 2 ** (self.num_resolutions - 1)
        self.num_res_blocks = num_res_blocks
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )

        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in, out_channels=block_out, dropout=dropout
                    )
                )
                block_in = block_out
                if i_level == self.num_resolutions - 1 and using_sa:
                    attn.append(make_attn(block_in, using_sa=True))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample2x(block_in)
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, dropout=dropout
        )
        self.mid.attn_1 = make_attn(block_in, using_sa=using_mid_sa)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, dropout=dropout
        )

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            (2 * z_channels if double_z else z_channels),
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        # downsampling
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](h)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = self.down[i_level].downsample(h)

        # middle
        h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(h)))

        # end
        h = self.conv_out(F.silu(self.norm_out(h), inplace=True))
        return h


class Decoder(nn.Module):
    def __init__(
        self,
        *,
        ch=128,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks=2,
        dropout=0.0,
        in_channels=3,  # in_channels: raw img channels
        z_channels,
        using_sa=True,
        using_mid_sa=True,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.in_channels = in_channels

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,) + tuple(ch_mult)
        block_in = ch * ch_mult[self.num_resolutions - 1]

        # z to block_in
        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, dropout=dropout
        )
        self.mid.attn_1 = make_attn(block_in, using_sa=using_mid_sa)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, dropout=dropout
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in, out_channels=block_out, dropout=dropout
                    )
                )
                block_in = block_out
                if i_level == self.num_resolutions - 1 and using_sa:
                    attn.append(make_attn(block_in, using_sa=True))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample2x(block_in)
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, in_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, z):
        # z to block_in
        # middle
        h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(self.conv_in(z))))

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.conv_out(F.silu(self.norm_out(h), inplace=True))
        return h
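A small shape check for the two modules above (a sketch, not part of the commit; the ch/ch_mult/z_channels values are assumptions chosen for illustration):

# Hypothetical shape check for Encoder/Decoder.
import torch
from models.basic_vae import Encoder, Decoder

enc = Encoder(ch=160, ch_mult=(1, 1, 2, 2, 4), z_channels=32, double_z=False)
dec = Decoder(ch=160, ch_mult=(1, 1, 2, 2, 4), z_channels=32)

x = torch.randn(1, 3, 256, 256)
z = enc(x)        # (1, 32, 16, 16): spatial size divided by 2**(len(ch_mult)-1)
x_rec = dec(z)    # (1, 3, 256, 256)
print(z.shape, x_rec.shape)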
models/basic_var.py ADDED
@@ -0,0 +1,466 @@
import math
import warnings

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from torch.nn.functional import scaled_dot_product_attention  # q, k, v: BHLc

from models.helpers import DropPath
from models.rope import apply_rotary_emb

try:
    from flash_attn.ops.fused_dense import fused_mlp_func
except ImportError:
    fused_mlp_func = None

# this file only provides the 4 blocks used in VAR transformer
__all__ = ["FFN", "AdaLNSelfCrossAttn", "AdaLNBeforeHead"]


try:
    from apex.normalization import FusedRMSNorm as RMSNorm
except ImportError:
    warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")

    class RMSNorm(torch.nn.Module):
        def __init__(self, dim: int, eps: float = 1e-6):
            """
            Initialize the RMSNorm normalization layer.

            Args:
                dim (int): The dimension of the input tensor.
                eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

            Attributes:
                eps (float): A small value added to the denominator for numerical stability.
                weight (nn.Parameter): Learnable scaling parameter.

            """
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(dim))

        def _norm(self, x):
            """
            Apply the RMSNorm normalization to the input tensor.

            Args:
                x (torch.Tensor): The input tensor.

            Returns:
                torch.Tensor: The normalized tensor.

            """
            return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

        def forward(self, x):
            """
            Forward pass through the RMSNorm layer.

            Args:
                x (torch.Tensor): The input tensor.

            Returns:
                torch.Tensor: The output tensor after applying RMSNorm.

            """
            output = self._norm(x.float()).type_as(x)
            return output * self.weight


class FFN(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        drop=0.0,
        fused_if_available=True,
    ):
        super().__init__()
        self.fused_mlp_func = fused_mlp_func if fused_if_available else None
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU(approximate="tanh")
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop, inplace=True) if drop > 0 else nn.Identity()

    def forward(self, x):
        if self.fused_mlp_func is not None:
            return self.drop(
                self.fused_mlp_func(
                    x=x,
                    weight1=self.fc1.weight,
                    weight2=self.fc2.weight,
                    bias1=self.fc1.bias,
                    bias2=self.fc2.bias,
                    activation="gelu_approx",
                    save_pre_act=self.training,
                    return_residual=False,
                    checkpoint_lvl=0,
                    heuristic=0,
                    process_group=None,
                )
            )
        else:
            return self.drop(self.fc2(self.act(self.fc1(x))))

    def extra_repr(self) -> str:
        return f"fused_mlp_func={self.fused_mlp_func is not None}"


class SwiGLUFFN(nn.Module):
    def __init__(
        self,
        dim: int,
        ff_mult: float = 8 / 3,
    ):
        """
        Initialize the FeedForward module.

        Args:
            dim (int): Input dimension.
            ff_mult (float, optional): Multiplier for the hidden dimension. Defaults to 8/3.
        """
        super().__init__()
        hidden_dim = int(dim * ff_mult)

        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.fused_mlp_func = None
        self._init()

    def _init(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    # @torch.compile
    def _forward_silu_gating(self, x_gate: torch.Tensor, x_up: torch.Tensor):
        return F.silu(x_gate) * x_up

    def forward(self, x: torch.Tensor):
        return self.down_proj(
            self._forward_silu_gating(self.gate_proj(x), self.up_proj(x))
        )

    def extra_repr(self) -> str:
        return f"fused_mlp_func={self.fused_mlp_func is not None}"


class CrossAttention(nn.Module):
    def __init__(
        self,
        embed_dim: int = 768,
        context_dim: int = 2048,
        num_heads: int = 12,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        qk_norm: bool = False,
    ):
        super().__init__()
        assert embed_dim % num_heads == 0
        assert attn_drop == 0.0

        self.num_heads, self.head_dim = (
            num_heads,
            embed_dim // num_heads,
        )
        self.qk_norm = qk_norm
        self.scale = 1 / math.sqrt(self.head_dim)

        self.q_norm = nn.LayerNorm(embed_dim, eps=1e-6)
        self.k_norm = nn.LayerNorm(embed_dim, eps=1e-6)

        self.to_q = nn.Linear(embed_dim, embed_dim, bias=True)
        self.to_kv = nn.Linear(context_dim, embed_dim * 2, bias=True)

        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = (
            nn.Dropout(proj_drop, inplace=True) if proj_drop > 0 else nn.Identity()
        )
        self.attn_drop = attn_drop

        # only used during inference
        self.caching, self.cached_k, self.cached_v = False, None, None

    def kv_caching(self, enable: bool):
        self.caching, self.cached_k, self.cached_v = enable, None, None

    def forward(self, x, context, context_attn_bias=None, freqs_cis=None):
        B, L, C = x.shape
        context_B, context_L, context_C = context.shape
        assert B == context_B

        q = self.to_q(x).view(B, L, -1)  # BLD
        if self.qk_norm:
            q = self.q_norm(q)

        q = q.view(B, L, self.num_heads, self.head_dim)
        q = q.permute(0, 2, 1, 3)  # BHLc

        if self.cached_k is None:
            # not using caches or first scale inference
            kv = self.to_kv(context).view(B, context_L, 2, -1)  # kv: BL2D
            k, v = kv.permute(2, 0, 1, 3).unbind(dim=0)  # k or v: BLD

            if self.qk_norm:
                k = self.k_norm(k)

            k = k.view(B, context_L, self.num_heads, self.head_dim)
            k = k.permute(0, 2, 1, 3)  # BHLc

            v = v.view(B, context_L, self.num_heads, self.head_dim)
            v = v.permute(0, 2, 1, 3)  # BHLc

            if self.caching:
                self.cached_k = k
                self.cached_v = v
        else:
            k = self.cached_k
            v = self.cached_v

        if context_attn_bias is not None:
            context_attn_bias = rearrange(context_attn_bias, "b j -> b 1 1 j")

        dropout_p = self.attn_drop if self.training else 0.0
        out = (
            scaled_dot_product_attention(
                query=q,
                key=k,
                value=v,
                scale=self.scale,
                attn_mask=context_attn_bias,
                dropout_p=dropout_p,
            )
            .transpose(1, 2)
            .reshape(B, L, C)
        )

        return self.proj_drop(self.proj(out))


class SelfAttention(nn.Module):
    def __init__(
        self,
        block_idx: int,
        embed_dim: int = 768,
        num_heads: int = 12,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        qk_norm: bool = False,
    ):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.block_idx, self.num_heads, self.head_dim = (
            block_idx,
            num_heads,
            embed_dim // num_heads,
        )
        self.qk_norm = qk_norm
        self.scale = 1 / math.sqrt(self.head_dim)

        self.q_norm = nn.LayerNorm(embed_dim, eps=1e-6)
        self.k_norm = nn.LayerNorm(embed_dim, eps=1e-6)

        self.to_qkv = nn.Linear(embed_dim, embed_dim * 3, bias=True)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = (
            nn.Dropout(proj_drop, inplace=True) if proj_drop > 0 else nn.Identity()
        )
        self.attn_drop = attn_drop

        # only used during inference
        self.caching, self.cached_k, self.cached_v = False, None, None

    def kv_caching(self, enable: bool):
        self.caching, self.cached_k, self.cached_v = enable, None, None

    # NOTE: attn_bias is None during inference because kv cache is enabled
    def forward(self, x, attn_bias, freqs_cis: torch.Tensor = None):
        B, L, C = x.shape

        qkv = self.to_qkv(x).view(B, L, 3, -1)
        q, k, v = qkv.permute(2, 0, 1, 3).unbind(dim=0)  # q or k or v: BLD

        if self.qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)

        q = q.view(B, L, self.num_heads, self.head_dim)
        q = q.permute(0, 2, 1, 3)  # BHLc
        k = k.view(B, L, self.num_heads, self.head_dim)
        k = k.permute(0, 2, 1, 3)  # BHLc
        v = v.view(B, L, self.num_heads, self.head_dim)
        v = v.permute(0, 2, 1, 3)  # BHLc
        dim_cat = 2

        if freqs_cis is not None:
            q = apply_rotary_emb(q, freqs_cis=freqs_cis)
            k = apply_rotary_emb(k, freqs_cis=freqs_cis)

        if self.caching:
            if self.cached_k is None:
                self.cached_k = k
                self.cached_v = v
            else:
                k = self.cached_k = torch.cat((self.cached_k, k), dim=dim_cat)
                v = self.cached_v = torch.cat((self.cached_v, v), dim=dim_cat)

        dropout_p = self.attn_drop if self.training else 0.0
        out = (
            scaled_dot_product_attention(
                query=q,
                key=k,
                value=v,
                scale=self.scale,
                attn_mask=attn_bias,
                dropout_p=dropout_p,
            )
            .transpose(1, 2)
            .reshape(B, L, C)
        )

        return self.proj_drop(self.proj(out))

    def extra_repr(self) -> str:
        return f"attn_l2_norm={self.qk_norm}"


class AdaLNSelfCrossAttn(nn.Module):
    def __init__(
        self,
        block_idx,
        last_drop_p,
        embed_dim,
        cond_dim,
        shared_aln: bool,
        num_heads,
        mlp_ratio=4.0,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        qk_norm=False,
        context_dim=None,
        use_swiglu_ffn=False,
        norm_eps=1e-6,
    ):
        super().__init__()
        assert attn_drop == 0.0
        assert qk_norm

        self.block_idx, self.last_drop_p, self.C = block_idx, last_drop_p, embed_dim
        self.C, self.D = embed_dim, cond_dim
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.attn = SelfAttention(
            block_idx=block_idx,
            embed_dim=embed_dim,
            num_heads=num_heads,
            attn_drop=attn_drop,
            proj_drop=drop,
            qk_norm=qk_norm,
        )

        if context_dim:
            self.cross_attn = CrossAttention(
                embed_dim=embed_dim,
                context_dim=context_dim,
                num_heads=num_heads,
                attn_drop=attn_drop,
                proj_drop=drop,
                qk_norm=qk_norm,
            )
        else:
            self.cross_attn = None

        if use_swiglu_ffn:
            self.ffn = SwiGLUFFN(dim=embed_dim)
        else:
            self.ffn = FFN(
                in_features=embed_dim,
                hidden_features=round(embed_dim * mlp_ratio),
                drop=drop,
            )

        self.self_attention_norm1 = RMSNorm(embed_dim, eps=norm_eps)
        self.self_attention_norm2 = RMSNorm(embed_dim, eps=norm_eps)
        self.cross_attention_norm1 = RMSNorm(embed_dim, eps=norm_eps)
        self.cross_attention_norm2 = RMSNorm(embed_dim, eps=norm_eps)

        self.ffn_norm1 = RMSNorm(embed_dim, eps=norm_eps)
        self.ffn_norm2 = RMSNorm(embed_dim, eps=norm_eps)

        self.attention_y_norm = RMSNorm(context_dim, eps=norm_eps)

        self.shared_aln = shared_aln
        if self.shared_aln:
            self.ada_gss = nn.Parameter(
                torch.randn(1, 1, 6, embed_dim) / embed_dim**0.5
            )
        else:
            lin = nn.Linear(cond_dim, 6 * embed_dim)
            self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin)

        self.fused_add_norm_fn = None

    # NOTE: attn_bias is None during inference because kv cache is enabled
    def forward(
        self,
        x,
        cond_BD,
        attn_bias,
        context=None,
        context_attn_bias=None,
        freqs_cis=None,
    ):  # C: embed_dim, D: cond_dim
        if self.shared_aln:
            gamma1, gamma2, scale1, scale2, shift1, shift2 = (
                self.ada_gss + cond_BD
            ).unbind(
                2
            )  # 116C + B16C =unbind(2)=> 6 B1C
        else:
            gamma1, gamma2, scale1, scale2, shift1, shift2 = (
                self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2)
            )
        x = x + self.self_attention_norm2(
            self.attn(
                self.self_attention_norm1(x).mul(scale1.add(1)).add(shift1),
                attn_bias=attn_bias,
                freqs_cis=freqs_cis,
            ).mul(gamma1)
        )
        if context is not None:
            x = x + self.cross_attention_norm2(
                self.cross_attn(
                    self.cross_attention_norm1(x),
                    self.attention_y_norm(context),
                    context_attn_bias=context_attn_bias,
                    freqs_cis=freqs_cis,
                )
            )
        x = x + self.ffn_norm2(
            self.ffn(self.ffn_norm1(x).mul(scale2.add(1)).add(shift2)).mul(gamma2)
        )
        return x

    def extra_repr(self) -> str:
        return f"shared_aln={self.shared_aln}"


class AdaLNBeforeHead(nn.Module):
    def __init__(self, C, D, norm_layer):  # C: embed_dim, D: cond_dim
        super().__init__()
        self.C, self.D = C, D
        self.ln_wo_grad = norm_layer(C, elementwise_affine=False)
        self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), nn.Linear(D, 2 * C))

    def forward(self, x_BLC: torch.Tensor, cond_BD: torch.Tensor):
        scale, shift = self.ada_lin(cond_BD).view(-1, 1, 2, self.C).unbind(2)
        return self.ln_wo_grad(x_BLC).mul(scale.add(1)).add_(shift)
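A quick shape check for one AdaLN self/cross-attention block (a sketch, not part of the commit; the dimensions, head count, and context length are assumptions chosen only to illustrate the interface):

# Hypothetical shape check for AdaLNSelfCrossAttn.
import torch
from models.basic_var import AdaLNSelfCrossAttn

blk = AdaLNSelfCrossAttn(
    block_idx=0, last_drop_p=0.0, embed_dim=1024, cond_dim=1024,
    shared_aln=False, num_heads=16, qk_norm=True, context_dim=2048,
)
x = torch.randn(2, 16, 1024)    # (B, L, C) token map
cond = torch.randn(2, 1024)     # pooled text condition (AdaLN input)
ctx = torch.randn(2, 77, 2048)  # cross-attention context (text tokens)
out = blk(x, cond, attn_bias=None, context=ctx)
print(out.shape)                # torch.Size([2, 16, 1024])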
models/clip.py ADDED
@@ -0,0 +1,53 @@
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import CLIPTextModel, CLIPTokenizer


class FrozenCLIPEmbedder(nn.Module):
    """Uses the CLIP transformer encoder for text (from huggingface)"""

    def __init__(
        self,
        version="openai/clip-vit-large-patch14",
        device="cuda",
        max_length=77,
        freeze=True,
    ):  # clip-vit-base-patch32
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version).to(device)
        self.device = device
        self.hidden_size = self.transformer.config.hidden_size
        self.max_length = max_length
        if freeze:
            self.freeze()

    def freeze(self):
        self.transformer = self.transformer.eval()
        # self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        ).to(self.device)

        with torch.cuda.amp.autocast():
            outputs = self.transformer(**batch_encoding)

        attn_bias = batch_encoding["attention_mask"].float()
        attn_bias[attn_bias == 0] = -float("inf")
        attn_bias[attn_bias == 1] = 0.0
        outputs["attn_bias"] = attn_bias
        return outputs

    def encode(self, text):
        return self(text)
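A minimal usage sketch of the frozen text encoder (not part of the commit; it downloads the CLIP checkpoint on first run, and the printed hidden size assumes the default clip-vit-large-patch14 model):

# Hypothetical usage of FrozenCLIPEmbedder.
import torch
from models.clip import FrozenCLIPEmbedder

enc = FrozenCLIPEmbedder(device="cuda" if torch.cuda.is_available() else "cpu")
out = enc.encode(["A capybara wearing a suit"])
print(out.last_hidden_state.shape)  # (1, 77, 768) for clip-vit-large-patch14
print(out.attn_bias.shape)          # (1, 77): 0.0 for real tokens, -inf for padding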
models/helpers.py ADDED
@@ -0,0 +1,93 @@
import torch
from torch import nn as nn
from torch.nn import functional as F


def sample_with_top_k_top_p_(
    logits_BlV: torch.Tensor,
    top_k: int = 0,
    top_p: float = 0.0,
    rng=None,
    num_samples=1,
) -> torch.Tensor:  # return idx, shaped (B, l)
    B, l, V = logits_BlV.shape
    if top_k > 0:
        idx_to_remove = logits_BlV < logits_BlV.topk(
            top_k, largest=True, sorted=False, dim=-1
        )[0].amin(dim=-1, keepdim=True)
        logits_BlV.masked_fill_(idx_to_remove, -torch.inf)
    if top_p > 0:
        sorted_logits, sorted_idx = logits_BlV.sort(dim=-1, descending=False)
        sorted_idx_to_remove = sorted_logits.softmax(dim=-1).cumsum_(dim=-1) <= (1 - top_p)
        sorted_idx_to_remove[..., -1:] = False
        logits_BlV.masked_fill_(
            sorted_idx_to_remove.scatter(
                sorted_idx.ndim - 1, sorted_idx, sorted_idx_to_remove
            ),
            -torch.inf,
        )
    # sample (have to squeeze cuz torch.multinomial can only be used for 2D tensor)
    replacement = num_samples >= 0
    num_samples = abs(num_samples)
    return torch.multinomial(
        logits_BlV.softmax(dim=-1).view(-1, V),
        num_samples=num_samples,
        replacement=replacement,
        generator=rng,
    ).view(B, l, num_samples)


def gumbel_softmax_with_rng(
    logits: torch.Tensor,
    tau: float = 1,
    hard: bool = False,
    eps: float = 1e-10,
    dim: int = -1,
    rng: torch.Generator = None,
) -> torch.Tensor:
    if rng is None:
        return F.gumbel_softmax(logits=logits, tau=tau, hard=hard, eps=eps, dim=dim)

    gumbels = (
        -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
        .exponential_(generator=rng)
        .log()
    )
    gumbels = (logits + gumbels) / tau
    y_soft = gumbels.softmax(dim)

    if hard:
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(
            logits, memory_format=torch.legacy_contiguous_format
        ).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        ret = y_soft
    return ret


def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):  # taken from timm
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor


class DropPath(nn.Module):  # taken from timm
    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f"(drop_prob=...)"
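A small sanity check for the sampling helper (a sketch, not part of the commit). Note that the trailing underscore in the name reflects that the logits are masked in place, hence the .clone() below:

# Hypothetical check of top-k / top-p sampling on random logits.
import torch
from models.helpers import sample_with_top_k_top_p_

logits = torch.randn(2, 5, 4096)                 # (B, l, V) token logits
rng = torch.Generator().manual_seed(42)
idx = sample_with_top_k_top_p_(logits.clone(), top_k=450, top_p=0.95, rng=rng)
print(idx.shape)                                 # torch.Size([2, 5, 1])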
models/pipeline.py ADDED
@@ -0,0 +1,226 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torchvision.transforms import ToPILImage

from models.vqvae import VQVAEHF
from models.clip import FrozenCLIPEmbedder
from models.var import TVARHF, sample_with_top_k_top_p_, gumbel_softmax_with_rng


class TVARPipeline:
    vae_path = "michellemoorre/vae-test"
    text_encoder_path = "openai/clip-vit-large-patch14"
    text_encoder_2_path = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"

    def __init__(self, var, vae, text_encoder, text_encoder_2, device):
        self.var = var
        self.vae = vae
        self.text_encoder = text_encoder
        self.text_encoder_2 = text_encoder_2

        self.var.eval()
        self.vae.eval()

        self.device = device

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, device="cuda"):
        var = TVARHF.from_pretrained(pretrained_model_name_or_path).to(device)
        vae = VQVAEHF.from_pretrained(cls.vae_path).to(device)
        text_encoder = FrozenCLIPEmbedder(cls.text_encoder_path, device=device)
        text_encoder_2 = FrozenCLIPEmbedder(cls.text_encoder_2_path, device=device)

        return cls(var, vae, text_encoder, text_encoder_2, device)

    @staticmethod
    def to_image(tensor):
        return [
            ToPILImage()((255 * img.cpu().detach()).to(torch.uint8))
            for img in tensor
        ]

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        null_prompt: str = "",
        encode_null: bool = True,
    ):
        prompt = [prompt] if isinstance(prompt, str) else prompt
        encodings = [
            self.text_encoder.encode(prompt),
            self.text_encoder_2.encode(prompt),
        ]
        prompt_embeds = torch.concat(
            [encoding.last_hidden_state for encoding in encodings], dim=-1
        )
        pooled_prompt_embeds = encodings[-1].pooler_output
        attn_bias = encodings[-1].attn_bias

        if encode_null:
            null_prompt = [null_prompt] if isinstance(null_prompt, str) else null_prompt
            null_encodings = [
                self.text_encoder.encode(null_prompt),
                self.text_encoder_2.encode(null_prompt),
            ]
            null_prompt_embeds = torch.concat(
                [encoding.last_hidden_state for encoding in null_encodings], dim=-1
            )
            null_pooled_prompt_embeds = null_encodings[-1].pooler_output
            null_attn_bias = null_encodings[-1].attn_bias

            B, L, hidden_dim = prompt_embeds.shape
            pooled_dim = pooled_prompt_embeds.shape[1]

            null_prompt_embeds = null_prompt_embeds[:, :L].expand(B, L, hidden_dim).to(prompt_embeds.device)
            null_pooled_prompt_embeds = null_pooled_prompt_embeds.expand(B, pooled_dim).to(pooled_prompt_embeds.device)
            null_attn_bias = null_attn_bias[:, :L].expand(B, L).to(attn_bias.device)

            prompt_embeds = torch.cat([prompt_embeds, null_prompt_embeds], dim=0)
            pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, null_pooled_prompt_embeds], dim=0)
            attn_bias = torch.cat([attn_bias, null_attn_bias], dim=0)

        return prompt_embeds, pooled_prompt_embeds, attn_bias

    @torch.inference_mode()
    def __call__(
        self,
        prompt=None,
        null_prompt="",
        g_seed: Optional[int] = None,
        cfg=4.0,
        top_k=450,
        top_p=0.95,
        more_smooth=False,
        re=False,
        re_max_depth=10,
        return_pil=True,
        encoded_prompt=None,
        encoded_null_prompt=None,
    ) -> torch.Tensor:  # returns reconstructed image (B, 3, H, W) in [0, 1]
        """
        Only used for inference, in autoregressive mode.
        :param prompt: text prompt(s) to condition on
        :param null_prompt: unconditional prompt used for classifier-free guidance
        :param g_seed: random seed
        :param cfg: classifier-free guidance ratio
        :param top_k: top-k sampling
        :param top_p: top-p sampling
        :param more_smooth: smoothing the pred using gumbel softmax; only used for visualization, not for FID/IS benchmarking
        :return: PIL images if return_pil, else an image tensor in [0, 1]
        """
        assert not self.var.training
        var = self.var
        vae = self.vae
        vae_quant = self.vae.quantize
        if g_seed is None:
            rng = None
        else:
            var.rng.manual_seed(g_seed)
            rng = var.rng

        if encoded_prompt is not None:
            assert encoded_null_prompt is not None
            context, cond_vector, context_attn_bias = self.var.parse_batch(
                encoded_prompt,
                encoded_null_prompt,
            )
        else:
            context, cond_vector, context_attn_bias = self.encode_prompt(prompt, null_prompt)

        B = context.shape[0] // 2

        cond_vector = var.text_pooler(cond_vector)

        sos = cond_BD = cond_vector

        lvl_pos = var.lvl_embed(var.lvl_1L)
        if not var.rope:
            lvl_pos += var.pos_1LC
        next_token_map = (
            sos.unsqueeze(1)
            + var.pos_start.expand(2 * B, var.first_l, -1)
            + lvl_pos[:, : var.first_l]
        )
        cur_L = 0
        f_hat = sos.new_zeros(B, var.Cvae, var.patch_nums[-1], var.patch_nums[-1])

        for b in var.blocks:
            b.attn.kv_caching(True)
            b.cross_attn.kv_caching(True)

        for si, pn in enumerate(var.patch_nums):  # si: i-th segment
            ratio = si / var.num_stages_minus_1
            cond_BD_or_gss = var.shared_ada_lin(cond_BD)
            x_BLC = next_token_map

            if var.rope:
                freqs_cis = var.freqs_cis[:, cur_L : cur_L + pn * pn]
            else:
                freqs_cis = var.freqs_cis

            for block in var.blocks:
                x_BLC = block(
                    x=x_BLC,
                    cond_BD=cond_BD_or_gss,
                    attn_bias=None,
                    context=context,
                    context_attn_bias=context_attn_bias,
                    freqs_cis=freqs_cis,
                )
            cur_L += pn * pn

            logits_BlV = var.get_logits(x_BLC, cond_BD)

            t = cfg * ratio
            logits_BlV = (1 + t) * logits_BlV[:B] - t * logits_BlV[B:]

            idx_Bl = sample_with_top_k_top_p_(
                logits_BlV, rng=rng, top_k=top_k, top_p=top_p, num_samples=1
            )[:, :, 0]
            if re:
                selected_logits = torch.gather(logits_BlV, -1, idx_Bl.unsqueeze(-1))[:, :, 0]
                mx = selected_logits.sum(dim=-1)[:, None]
                for _ in range(re_max_depth):
                    new_idx_Bl = sample_with_top_k_top_p_(
                        logits_BlV, rng=rng, top_k=top_k, top_p=top_p, num_samples=1
                    )[:, :, 0]
                    selected_logits = torch.gather(logits_BlV, -1, new_idx_Bl.unsqueeze(-1))[:, :, 0]

                    new_mx = selected_logits.sum(dim=-1)[:, None]
                    idx_Bl = idx_Bl * (mx >= new_mx) + new_idx_Bl * (mx < new_mx)
                    mx = mx * (mx >= new_mx) + new_mx * (mx < new_mx)
            if not more_smooth:  # this is the default case
                h_BChw = vae_quant.embedding(idx_Bl)  # B, l, Cvae
            else:  # not used when evaluating FID/IS/Precision/Recall
                gum_t = max(0.27 * (1 - ratio * 0.95), 0.005)  # refer to mask-git
                h_BChw = gumbel_softmax_with_rng(
                    logits_BlV.mul(1 + ratio), tau=gum_t, hard=False, dim=-1, rng=rng
                ) @ vae_quant.embedding.weight.unsqueeze(0)

            h_BChw = h_BChw.transpose_(1, 2).reshape(B, var.Cvae, pn, pn)
            f_hat, next_token_map = vae_quant.get_next_autoregressive_input(
                si, len(var.patch_nums), f_hat, h_BChw
            )
            if si != var.num_stages_minus_1:  # prepare for next stage
                next_token_map = next_token_map.view(B, var.Cvae, -1).transpose(1, 2)
                next_token_map = (
                    var.word_embed(next_token_map)
                    + lvl_pos[:, cur_L : cur_L + var.patch_nums[si + 1] ** 2]
                )
                next_token_map = next_token_map.repeat(
                    2, 1, 1
                )  # double the batch size for CFG

        for b in var.blocks:
            b.attn.kv_caching(False)
            b.cross_attn.kv_caching(False)

        # de-normalize, from [-1, 1] to [0, 1]
        img = vae.fhat_to_img(f_hat).add(1).mul(0.5)
        if return_pil:
            img = self.to_image(img)
        return img
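For reference, a minimal sketch of calling the pipeline directly, mirroring what app.py does (not part of the commit; assumes a GPU and network access to the model repos listed above):

# Hypothetical direct use of TVARPipeline outside the Gradio demo.
import torch
from models import TVARPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = TVARPipeline.from_pretrained("michellemoorre/var-test", device=device)

images = pipe(
    prompt="A capybara wearing a suit holding a sign that reads Hello World",
    cfg=4.0,
    top_k=450,
    top_p=0.95,
    g_seed=42,
)
images[0].save("capybara.png")  # PIL images are returned when return_pil=True (the default)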
models/quant.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List, Optional, Sequence, Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ from torch import distributed as tdist
7
+ from torch import nn as nn
8
+ from torch.nn import functional as F
9
+
10
+ import dist
11
+
12
+ # this file only provides the VectorQuantizer2 used in VQVAE
13
+ __all__ = ["VectorQuantizer2"]
14
+
15
+
16
+ class VectorQuantizer2(nn.Module):
17
+ # VQGAN originally use beta=1.0, never tried 0.25; SD seems using 0.25
18
+ def __init__(
19
+ self,
20
+ vocab_size,
21
+ Cvae,
22
+ using_znorm,
23
+ beta: float = 0.25,
24
+ default_qresi_counts=0,
25
+ v_patch_nums=None,
26
+ quant_resi=0.5,
27
+ share_quant_resi=4, # share_quant_resi: args.qsr
28
+ ):
29
+ super().__init__()
30
+ self.vocab_size: int = vocab_size
31
+ self.Cvae: int = Cvae
32
+ self.using_znorm: bool = using_znorm
33
+ self.v_patch_nums: Tuple[int] = v_patch_nums
34
+
35
+ self.quant_resi_ratio = quant_resi
36
+ if share_quant_resi == 0: # non-shared: \phi_{1 to K} for K scales
37
+ self.quant_resi = PhiNonShared(
38
+ [
39
+ (Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity())
40
+ for _ in range(default_qresi_counts or len(self.v_patch_nums))
41
+ ]
42
+ )
43
+ elif share_quant_resi == 1: # fully shared: only a single \phi for K scales
44
+ self.quant_resi = PhiShared(
45
+ Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()
46
+ )
47
+ else: # partially shared: \phi_{1 to share_quant_resi} for K scales
48
+ self.quant_resi = PhiPartiallyShared(
49
+ nn.ModuleList([(
50
+ Phi(Cvae, quant_resi)
51
+ if abs(quant_resi) > 1e-6
52
+ else nn.Identity()
53
+ ) for _ in range(share_quant_resi)])
54
+ )
55
+
56
+ self.register_buffer(
57
+ "ema_vocab_hit_SV",
58
+ torch.full((len(self.v_patch_nums), self.vocab_size), fill_value=0.0),
59
+ )
60
+ self.record_hit = 0
61
+
62
+ self.beta: float = beta
63
+ self.embedding = nn.Embedding(self.vocab_size, self.Cvae)
64
+
65
+ def eini(self, eini):
66
+ if eini > 0:
67
+ nn.init.trunc_normal_(self.embedding.weight.data, std=eini)
68
+ elif eini < 0:
69
+ self.embedding.weight.data.uniform_(
70
+ -abs(eini) / self.vocab_size, abs(eini) / self.vocab_size
71
+ )
72
+
73
+ def extra_repr(self) -> str:
74
+ return f"{self.v_patch_nums}, znorm={self.using_znorm}, beta={self.beta} | S={len(self.v_patch_nums)}, quant_resi={self.quant_resi_ratio}"
75
+
76
+ # ===================== `forward` is only used in VAE training =====================
77
+ def forward(
78
+ self, f_BChw: torch.Tensor, ret_usages=False
79
+ ) -> Tuple[torch.Tensor, List[float], torch.Tensor]:
80
+ dtype = f_BChw.dtype
81
+ if dtype != torch.float32:
82
+ f_BChw = f_BChw.float()
83
+ B, C, H, W = f_BChw.shape
84
+ f_no_grad = f_BChw.detach()
85
+
86
+ f_rest = f_no_grad.clone()
87
+ f_hat = torch.zeros_like(f_rest)
88
+
89
+ with torch.cuda.amp.autocast(enabled=False):
90
+ mean_vq_loss: torch.Tensor = 0.0
91
+ vocab_hit_V = torch.zeros(
92
+ self.vocab_size, dtype=torch.float, device=f_BChw.device
93
+ )
94
+ SN = len(self.v_patch_nums)
95
+ for si, pn in enumerate(self.v_patch_nums): # from small to large
96
+ # find the nearest embedding
97
+ if self.using_znorm:
98
+ rest_NC = (
99
+ F.interpolate(f_rest, size=(pn, pn), mode="area")
100
+ .permute(0, 2, 3, 1)
101
+ .reshape(-1, C)
102
+ if (si != SN - 1)
103
+ else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
104
+ )
105
+ rest_NC = F.normalize(rest_NC, dim=-1)
106
+ idx_N = torch.argmax(
107
+ rest_NC @ F.normalize(self.embedding.weight.data.T, dim=0),
108
+ dim=1,
109
+ )
110
+ else:
111
+ rest_NC = (
112
+ F.interpolate(f_rest, size=(pn, pn), mode="area")
113
+ .permute(0, 2, 3, 1)
114
+ .reshape(-1, C)
115
+ if (si != SN - 1)
116
+ else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
117
+ )
118
+ d_no_grad = torch.sum(
119
+ rest_NC.square(), dim=1, keepdim=True
120
+ ) + torch.sum(
121
+ self.embedding.weight.data.square(), dim=1, keepdim=False
122
+ )
123
+ d_no_grad.addmm_(
124
+ rest_NC, self.embedding.weight.data.T, alpha=-2, beta=1
125
+ ) # (B*h*w, vocab_size)
126
+ idx_N = torch.argmin(d_no_grad, dim=1)
127
+
128
+ hit_V = idx_N.bincount(minlength=self.vocab_size).float()
129
+ if self.training:
130
+ if dist.initialized():
131
+ handler = tdist.all_reduce(hit_V, async_op=True)
132
+
133
+ # calc loss
134
+ idx_Bhw = idx_N.view(B, pn, pn)
135
+ h_BChw = (
136
+ F.interpolate(
137
+ self.embedding(idx_Bhw).permute(0, 3, 1, 2),
138
+ size=(H, W),
139
+ mode="bicubic",
140
+ ).contiguous()
141
+ if (si != SN - 1)
142
+ else self.embedding(idx_Bhw).permute(0, 3, 1, 2).contiguous()
143
+ )
144
+ h_BChw = self.quant_resi[si / (SN - 1)](h_BChw)
145
+ f_hat = f_hat + h_BChw
146
+ f_rest -= h_BChw
147
+
148
+ if self.training and dist.initialized():
149
+ handler.wait()
150
+ if self.record_hit == 0:
151
+ self.ema_vocab_hit_SV[si].copy_(hit_V)
152
+ elif self.record_hit < 100:
153
+ self.ema_vocab_hit_SV[si].mul_(0.9).add_(hit_V.mul(0.1))
154
+ else:
155
+ self.ema_vocab_hit_SV[si].mul_(0.99).add_(hit_V.mul(0.01))
156
+ self.record_hit += 1
157
+ vocab_hit_V.add_(hit_V)
158
+ mean_vq_loss += F.mse_loss(f_hat.data, f_BChw).mul_(self.beta) + F.mse_loss(f_hat, f_no_grad)
159
+
160
+ mean_vq_loss *= 1.0 / SN
161
+ f_hat = (f_hat.data - f_no_grad).add_(f_BChw)
162
+
163
+ margin = (
164
+ (tdist.get_world_size() if dist.initialized() else 1)
165
+ * (f_BChw.numel() / f_BChw.shape[1])
166
+ / self.vocab_size
167
+ * 0.08
168
+ )
169
+ # margin = pn*pn / 100
170
+ if ret_usages:
171
+ usages = [
172
+ (self.ema_vocab_hit_SV[si] >= margin).float().mean().item() * 100
173
+ for si, pn in enumerate(self.v_patch_nums)
174
+ ]
175
+ else:
176
+ usages = None
177
+ return f_hat, usages, mean_vq_loss
178
+
179
+ # ===================== `forward` is only used in VAE training =====================
180
+
181
+ def embed_to_fhat(
182
+ self, ms_h_BChw: List[torch.Tensor], all_to_max_scale=True, last_one=False
183
+ ) -> Union[List[torch.Tensor], torch.Tensor]:
184
+ ls_f_hat_BChw = []
185
+ B = ms_h_BChw[0].shape[0]
186
+ H = W = self.v_patch_nums[-1]
187
+ SN = len(self.v_patch_nums)
188
+ if all_to_max_scale:
189
+ f_hat = ms_h_BChw[0].new_zeros(B, self.Cvae, H, W, dtype=torch.float32)
190
+ for si, pn in enumerate(self.v_patch_nums): # from small to large
191
+ h_BChw = ms_h_BChw[si]
192
+ if si < len(self.v_patch_nums) - 1:
193
+ h_BChw = F.interpolate(h_BChw, size=(H, W), mode="bicubic")
194
+ h_BChw = self.quant_resi[si / (SN - 1)](h_BChw)
195
+ f_hat.add_(h_BChw)
196
+ if last_one:
197
+ ls_f_hat_BChw = f_hat
198
+ else:
199
+ ls_f_hat_BChw.append(f_hat.clone())
200
+ else:
201
+ # WARNING: this is not the case in VQ-VAE training or inference (we'll interpolate every token map to the max H W, like above)
202
+ # WARNING: this should only be used for experimental purposes
203
+ f_hat = ms_h_BChw[0].new_zeros(
204
+ B,
205
+ self.Cvae,
206
+ self.v_patch_nums[0],
207
+ self.v_patch_nums[0],
208
+ dtype=torch.float32,
209
+ )
210
+ for si, pn in enumerate(self.v_patch_nums): # from small to large
211
+ f_hat = F.interpolate(f_hat, size=(pn, pn), mode="bicubic")
212
+ h_BChw = self.quant_resi[si / (SN - 1)](ms_h_BChw[si])
213
+ f_hat.add_(h_BChw)
214
+ if last_one:
215
+ ls_f_hat_BChw = f_hat
216
+ else:
217
+ ls_f_hat_BChw.append(f_hat)
218
+
219
+ return ls_f_hat_BChw
220
+
221
+ def f_to_idxBl_or_fhat(
222
+ self,
223
+ f_BChw: torch.Tensor,
224
+ to_fhat: bool,
225
+ v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None,
226
+ noise_std: Optional[float] = None,
227
+ ) -> List[Union[torch.Tensor, torch.LongTensor]]: # f_BChw is the feature from inp_img_no_grad
228
+ B, C, H, W = f_BChw.shape
229
+ f_no_grad = f_BChw.detach()
230
+ f_rest = f_no_grad.clone()
231
+ f_hat = torch.zeros_like(f_rest)
232
+
233
+ f_hat_or_idx_Bl: List[torch.Tensor] = []
234
+
235
+ patch_hws = [
236
+ (pn, pn) if isinstance(pn, int) else (pn[0], pn[1])
237
+ for pn in (v_patch_nums or self.v_patch_nums)
238
+ ] # from small to large
239
+ assert (
240
+ patch_hws[-1][0] == H and patch_hws[-1][1] == W
241
+ ), f"{patch_hws[-1]=} != ({H=}, {W=})"
242
+
243
+ SN = len(patch_hws)
244
+ for si, (ph, pw) in enumerate(patch_hws): # from small to large
245
+ # find the nearest embedding
246
+ z_NC = (
247
+ F.interpolate(f_rest, size=(ph, pw), mode="area")
248
+ .permute(0, 2, 3, 1)
249
+ .reshape(-1, C)
250
+ if (si != SN - 1)
251
+ else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
252
+ )
253
+ if noise_std is not None:
254
+ z_NC = math.sqrt(1 - noise_std ** 2) * z_NC + torch.randn_like(z_NC) * noise_std
255
+
256
+ if self.using_znorm:
257
+ z_NC = F.normalize(z_NC, dim=-1)
258
+ idx_N = torch.argmax(
259
+ z_NC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1
260
+ )
261
+ else:
262
+ d_no_grad = torch.sum(z_NC.square(), dim=1, keepdim=True) + torch.sum(
263
+ self.embedding.weight.data.square(), dim=1, keepdim=False
264
+ )
265
+ d_no_grad.addmm_(
266
+ z_NC, self.embedding.weight.data.T, alpha=-2, beta=1
267
+ ) # (B*h*w, vocab_size)
268
+ idx_N = torch.argmin(d_no_grad, dim=1)
269
+
270
+ idx_Bhw = idx_N.view(B, ph, pw)
271
+ h_BChw = (
272
+ F.interpolate(
273
+ self.embedding(idx_Bhw).permute(0, 3, 1, 2),
274
+ size=(H, W),
275
+ mode="bicubic",
276
+ ).contiguous()
277
+ if (si != SN - 1)
278
+ else self.embedding(idx_Bhw).permute(0, 3, 1, 2).contiguous()
279
+ )
280
+ h_BChw = self.quant_resi[si / (SN - 1)](h_BChw)
281
+ f_hat.add_(h_BChw)
282
+ f_rest.sub_(h_BChw)
283
+ f_hat_or_idx_Bl.append(
284
+ f_hat.clone() if to_fhat else idx_N.reshape(B, ph * pw)
285
+ )
286
+
287
+ return f_hat_or_idx_Bl
288
+
289
+ # ===================== idxBl_to_var_input: only used in VAR training, for getting teacher-forcing input =====================
290
+ def idxBl_to_var_input(self, gt_ms_idx_Bl: List[torch.Tensor]) -> torch.Tensor:
291
+ next_scales = []
292
+ B = gt_ms_idx_Bl[0].shape[0]
293
+ C = self.Cvae
294
+ H = W = self.v_patch_nums[-1]
295
+ SN = len(self.v_patch_nums)
296
+
297
+ f_hat = gt_ms_idx_Bl[0].new_zeros(B, C, H, W, dtype=torch.float32)
298
+ pn_next: int = self.v_patch_nums[0]
299
+ for si in range(SN - 1):
300
+ h_BChw = F.interpolate(
301
+ self.embedding(gt_ms_idx_Bl[si])
302
+ .transpose_(1, 2)
303
+ .view(B, C, pn_next, pn_next),
304
+ size=(H, W),
305
+ mode="bicubic",
306
+ )
307
+ f_hat.add_(self.quant_resi[si / (SN - 1)](h_BChw))
308
+ pn_next = self.v_patch_nums[si + 1]
309
+ next_scales.append(
310
+ F.interpolate(f_hat, size=(pn_next, pn_next), mode="area")
311
+ .view(B, C, -1)
312
+ .transpose(1, 2)
313
+ )
314
+ # cat BlCs to BLC, this should be float32
315
+ return torch.cat(next_scales, dim=1) if len(next_scales) else None
316
+
317
+ # ===================== get_next_autoregressive_input: only used in VAR inference, for getting next step's input =====================
318
+ def get_next_autoregressive_input(
319
+ self, si: int, SN: int, f_hat: torch.Tensor, h_BChw: torch.Tensor
320
+ ) -> Tuple[Optional[torch.Tensor], torch.Tensor]: # only used in VAR inference
321
+ HW = self.v_patch_nums[-1]
322
+ if si != SN - 1:
323
+ h = self.quant_resi[si / (SN - 1)](
324
+ F.interpolate(h_BChw, size=(HW, HW), mode="bicubic")
325
+ ) # conv after upsample
326
+ f_hat.add_(h)
327
+ return f_hat, F.interpolate(
328
+ f_hat,
329
+ size=(self.v_patch_nums[si + 1], self.v_patch_nums[si + 1]),
330
+ mode="area",
331
+ )
332
+ else:
333
+ h = self.quant_resi[si / (SN - 1)](h_BChw)
334
+ f_hat.add_(h)
335
+ return f_hat, f_hat
336
+
337
+
338
+ class Phi(nn.Conv2d):
339
+ def __init__(self, embed_dim, quant_resi):
340
+ ks = 3
341
+ super().__init__(
342
+ in_channels=embed_dim,
343
+ out_channels=embed_dim,
344
+ kernel_size=ks,
345
+ stride=1,
346
+ padding=ks // 2,
347
+ )
348
+ self.resi_ratio = abs(quant_resi)
349
+
350
+ def forward(self, h_BChw):
351
+ return h_BChw.mul(1 - self.resi_ratio) + super().forward(h_BChw).mul_(
352
+ self.resi_ratio
353
+ )
354
+
355
+
356
+ class PhiShared(nn.Module):
357
+ def __init__(self, qresi: Phi):
358
+ super().__init__()
359
+ self.qresi: Phi = qresi
360
+
361
+ def __getitem__(self, _) -> Phi:
362
+ return self.qresi
363
+
364
+
365
+ class PhiPartiallyShared(nn.Module):
366
+ def __init__(self, qresi_ls: nn.ModuleList):
367
+ super().__init__()
368
+ self.qresi_ls = qresi_ls
369
+ K = len(qresi_ls)
370
+ self.ticks = (
371
+ np.linspace(1 / 3 / K, 1 - 1 / 3 / K, K)
372
+ if K == 4
373
+ else np.linspace(1 / 2 / K, 1 - 1 / 2 / K, K)
374
+ )
375
+
376
+ def __getitem__(self, at_from_0_to_1: float) -> Phi:
377
+ return self.qresi_ls[np.argmin(np.abs(self.ticks - at_from_0_to_1)).item()]
378
+
379
+ def extra_repr(self) -> str:
380
+ return f"ticks={self.ticks}"
381
+
382
+
383
+ class PhiNonShared(nn.ModuleList):
384
+ def __init__(self, qresi: List):
385
+ super().__init__(qresi)
386
+ # self.qresi = qresi
387
+ K = len(qresi)
388
+ self.ticks = (
389
+ np.linspace(1 / 3 / K, 1 - 1 / 3 / K, K)
390
+ if K == 4
391
+ else np.linspace(1 / 2 / K, 1 - 1 / 2 / K, K)
392
+ )
393
+
394
+ def __getitem__(self, at_from_0_to_1: float) -> Phi:
395
+ return super().__getitem__(
396
+ np.argmin(np.abs(self.ticks - at_from_0_to_1)).item()
397
+ )
398
+
399
+ def extra_repr(self) -> str:
400
+ return f"ticks={self.ticks}"
models/rope.py ADDED
@@ -0,0 +1,48 @@
1
+ import torch
2
+
3
+
4
+ def init_t_xy(end_x: int, end_y: int):
5
+ t = torch.arange(end_x * end_y, dtype=torch.float32)
6
+ t_x = (t % end_x).float()
7
+ t_y = torch.div(t, end_x, rounding_mode="floor").float()
8
+ return t_x, t_y
9
+
10
+
11
+ def compute_axial_cis(
12
+ dim: int, end_x: int, end_y: int, theta: float = 100.0, norm_coeff: int = 1
13
+ ):
14
+ freqs_x = (
15
+ 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
16
+ * norm_coeff
17
+ )
18
+ freqs_y = (
19
+ 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
20
+ * norm_coeff
21
+ )
22
+
23
+ t_x, t_y = init_t_xy(end_x, end_y)
24
+ freqs_x = torch.outer(t_x, freqs_x)
25
+ freqs_y = torch.outer(t_y, freqs_y)
26
+ freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
27
+ freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
28
+ return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
29
+
30
+
31
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
32
+ ndim = x.ndim
33
+ assert 0 <= 1 < ndim
34
+ freqs_cis = freqs_cis[:, x.shape[1], ...]
35
+ if freqs_cis.shape == (x.shape[-2], x.shape[-1]):
36
+ shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
37
+ elif freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1]):
38
+ shape = [d if i >= ndim - 3 else 1 for i, d in enumerate(x.shape)]
39
+ return freqs_cis.view(*shape)
40
+
41
+
42
+ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor):
43
+ with torch.cuda.amp.autocast(enabled=False):
44
+ x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
45
+ # freqs_cis = reshape_for_broadcast(freqs_cis, x).to(x_in.device)
46
+ freqs_cis = freqs_cis[None, :, : x.shape[2], ...].to(x_in.device)
47
+ x_out = torch.view_as_real(x * freqs_cis).flatten(3)
48
+ return x_out.type_as(x_in)
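For reference, the RoPE helpers above are used roughly as follows — a shape-level sketch with made-up sizes (8 heads, head_dim=64, a 4x4 token grid); the leading [None] mirrors how the VAR constructor stores freqs_cis:

import torch
from models.rope import compute_axial_cis, apply_rotary_emb

B, heads, head_dim, pn = 2, 8, 64, 4
# (1, pn*pn, head_dim//2) complex tensor of axial rotary frequencies
freqs_cis = compute_axial_cis(dim=head_dim, end_x=pn, end_y=pn, theta=100.0)[None]
q = torch.randn(B, heads, pn * pn, head_dim)   # e.g. query states of one attention layer
q_rot = apply_rotary_emb(q, freqs_cis)         # same shape, with rotary position encoding applied
print(q_rot.shape)                             # torch.Size([2, 8, 16, 64])
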
models/var.py ADDED
@@ -0,0 +1,412 @@
1
+ import math
2
+ from functools import partial
3
+ from typing import Dict, List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from huggingface_hub import PyTorchModelHubMixin
8
+
9
+ import dist
10
+ from models.basic_var import AdaLNBeforeHead, AdaLNSelfCrossAttn
11
+ from models.clip import FrozenCLIPEmbedder
12
+ from models.helpers import gumbel_softmax_with_rng, sample_with_top_k_top_p_
13
+ from models.rope import compute_axial_cis
14
+ from models.vqvae import VQVAE, VectorQuantizer2
15
+
16
+
17
+ class SharedAdaLin(nn.Linear):
18
+ def forward(self, cond_BD):
19
+ C = self.weight.shape[0] // 6
20
+ return super().forward(cond_BD).view(-1, 1, 6, C) # B16C
21
+
22
+
23
+ class VAR(nn.Module):
24
+ def __init__(
25
+ self,
26
+ rope=False,
27
+ rope_theta=100,
28
+ rope_size=None,
29
+ depth=16,
30
+ embed_dim=1024,
31
+ num_heads=16,
32
+ mlp_ratio=4.0,
33
+ drop_rate=0.0,
34
+ attn_drop_rate=0.0,
35
+ drop_path_rate=0.0,
36
+ norm_eps=1e-6,
37
+ shared_aln=False,
38
+ attn_l2_norm=False,
39
+ patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), # 10 steps by default
40
+ fused_if_available=True,
41
+ use_swiglu_ffn=False,
42
+ Cvae=32,
43
+ V=4096
44
+ ):
45
+ super().__init__()
46
+ # 0. hyperparameters
47
+ assert embed_dim % num_heads == 0
48
+ self.depth, self.C, self.D, self.num_heads = (
49
+ depth,
50
+ embed_dim,
51
+ embed_dim,
52
+ num_heads,
53
+ )
54
+ self.Cvae, self.V = Cvae, V
55
+
56
+ self.prog_si = -1 # progressive training
57
+
58
+ self.patch_nums: Tuple[int] = patch_nums
59
+ self.L = sum(pn**2 for pn in self.patch_nums)
60
+ self.first_l = self.patch_nums[0] ** 2
61
+ self.rope = rope
62
+
63
+ self.num_stages_minus_1 = len(self.patch_nums) - 1
64
+ self.rng = torch.Generator(device=dist.get_device())
65
+
66
+ # 1. input (word) embedding
67
+ self.word_embed = nn.Linear(self.Cvae, self.C)
68
+
69
+ # 2. text embedding
70
+ self.pooled_embed_size = 1280
71
+ context_dim = 1280 + 768
72
+
73
+ self.text_pooler = nn.Linear(self.pooled_embed_size, self.D)
74
+
75
+ init_std = math.sqrt(1 / self.C / 3)
76
+ self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C))
77
+ nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std)
78
+
79
+ # 3. position embedding
80
+ if not self.rope:
81
+ # absolute position embedding
82
+ pos_1LC = []
83
+ for i, pn in enumerate(self.patch_nums):
84
+ pe = torch.empty(1, pn * pn, self.C)
85
+ nn.init.trunc_normal_(pe, mean=0, std=init_std)
86
+ pos_1LC.append(pe)
87
+ pos_1LC = torch.cat(pos_1LC, dim=1) # 1, L, C
88
+ assert tuple(pos_1LC.shape) == (1, self.L, self.C)
89
+ self.pos_1LC = nn.Parameter(pos_1LC)
90
+ self.freqs_cis = None
91
+
92
+ else:
93
+ # RoPE position embedding
94
+ assert (
95
+ self.C // self.num_heads
96
+ ) % 4 == 0, "2d rope needs head dim to be divisible by 4"
97
+ patch_nums_m1 = tuple(pn - 1 if pn > 1 else 1 for pn in self.patch_nums)
98
+ self.compute_cis = partial(compute_axial_cis, dim=self.C // self.num_heads)
99
+ freqs_cis = []
100
+ for i, pn in enumerate(self.patch_nums):
101
+ norm_coeff = rope_size / patch_nums_m1[i]
102
+ cur_freqs = self.compute_cis(
103
+ end_x=pn, end_y=pn, theta=rope_theta, norm_coeff=norm_coeff
104
+ )
105
+ freqs_cis.append(cur_freqs[None, ...])
106
+ self.freqs_cis = torch.cat(freqs_cis, dim=1) # 1, L, C // 2 -- complex
107
+
108
+ # level embedding (similar to GPT's segment embedding, used to distinguish different levels of token pyramid)
109
+ self.lvl_embed = nn.Embedding(len(self.patch_nums), self.C)
110
+ nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std)
111
+
112
+ # 4. backbone blocks
113
+ self.shared_ada_lin = (
114
+ nn.Sequential(nn.SiLU(inplace=False), SharedAdaLin(self.D, 6 * self.C))
115
+ if shared_aln
116
+ else nn.Identity()
117
+ )
118
+
119
+ norm_layer = partial(nn.LayerNorm, eps=norm_eps)
120
+ self.drop_path_rate = drop_path_rate
121
+ # stochastic depth decay rule (linearly increasing)
122
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
123
+ self.blocks = nn.ModuleList([])
124
+ for block_idx in range(depth):
125
+ self.blocks.append(
126
+ AdaLNSelfCrossAttn(
127
+ cond_dim=self.D,
128
+ shared_aln=shared_aln,
129
+ block_idx=block_idx,
130
+ embed_dim=self.C,
131
+ num_heads=num_heads,
132
+ mlp_ratio=mlp_ratio,
133
+ drop=drop_rate,
134
+ attn_drop=attn_drop_rate,
135
+ drop_path=dpr[block_idx],
136
+ last_drop_p=0 if block_idx == 0 else dpr[block_idx - 1],
137
+ qk_norm=attn_l2_norm,
138
+ context_dim=context_dim,
139
+ use_swiglu_ffn=use_swiglu_ffn,
140
+ norm_eps=norm_eps,
141
+ )
142
+ )
143
+
144
+ fused_add_norm_fns = [b.fused_add_norm_fn is not None for b in self.blocks]
145
+ self.using_fused_add_norm_fn = any(fused_add_norm_fns)
146
+ print(
147
+ f"\n[constructor] ==== fused_if_available={fused_if_available} (fusing_add_ln={sum(fused_add_norm_fns)}/{self.depth}, fusing_mlp={sum(b.ffn.fused_mlp_func is not None for b in self.blocks)}/{self.depth}) ==== \n"
148
+ f" [VAR config ] embed_dim={embed_dim}, num_heads={num_heads}, depth={depth}, mlp_ratio={mlp_ratio}\n"
149
+ f" [drop ratios ] drop_rate={drop_rate}, attn_drop_rate={attn_drop_rate}, drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})",
150
+ end="\n\n",
151
+ flush=True,
152
+ )
153
+
154
+ # 5. attention mask used in training (for masking out the future)
155
+ # it won't be used in inference, since kv cache is enabled
156
+ d: torch.Tensor = torch.cat(
157
+ [torch.full((pn * pn,), i) for i, pn in enumerate(self.patch_nums)]
158
+ ).view(1, self.L, 1)
159
+ dT = d.transpose(1, 2) # dT: 11L
160
+ lvl_1L = dT[:, 0].contiguous()
161
+ self.register_buffer("lvl_1L", lvl_1L)
162
+ attn_bias_for_masking = torch.where(d >= dT, 0.0, -torch.inf).reshape(
163
+ 1, 1, self.L, self.L
164
+ )
165
+ self.register_buffer(
166
+ "attn_bias_for_masking", attn_bias_for_masking.contiguous()
167
+ )
168
+
169
+ # 6. classifier head
170
+ self.head_nm = AdaLNBeforeHead(self.C, self.D, norm_layer=norm_layer)
171
+ self.head = nn.Linear(self.C, self.V)
172
+
173
+ # By defailt disable gradient checkpointing
174
+ self.use_gradient_checkpointing = False
175
+
176
+ def enable_gradient_checkpointing(self):
177
+ self.use_gradient_checkpointing = True
178
+
179
+ def disable_gradient_checkpointing(self):
180
+ self.use_gradient_checkpointing = False
181
+
182
+ def get_logits(
183
+ self,
184
+ h_or_h_and_residual: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
185
+ cond_BD: Optional[torch.Tensor],
186
+ ):
187
+ if not isinstance(h_or_h_and_residual, torch.Tensor):
188
+ h, resi = h_or_h_and_residual # fused_add_norm must be used
189
+ h = resi + self.blocks[-1].drop_path(h)
190
+ else: # fused_add_norm is not used
191
+ h = h_or_h_and_residual
192
+ return self.head(self.head_nm(h.float(), cond_BD).float()).float()
193
+
194
+ def parse_batch(self, batch, null_batch=None):
195
+ embedding_1 = batch["vit_l_14_text_embeddings"]
196
+ embedding_2 = batch["vit_bigg_14_text_embeddings"]
197
+ attention_mask = batch["vit_bigg_14_text_mask"]
198
+
199
+ batch_size = embedding_1.size(0)
200
+ prompt_embed = torch.concat([embedding_1, embedding_2], dim=-1)
201
+ prompt_lens = attention_mask.sum(dim=-1).to(int)
202
+ pooled_output = embedding_2[
203
+ torch.arange(batch_size, device=embedding_2.device), prompt_lens - 1
204
+ ]
205
+
206
+ attention_bias = attention_mask.clone()
207
+ attention_bias[attention_mask == 0] = -float("inf")
208
+ attention_bias[attention_mask == 1] = 0.0
209
+
210
+ if null_batch is not None:
211
+ B, L, hidden_dim = prompt_embed.shape
212
+ pooled_dim = pooled_output.shape[1]
213
+
214
+ null_context = null_batch['prompt_embed']
215
+ null_pooled_embed = null_batch['pooled_embed']
216
+ null_attn_bias = null_batch['attn_bias']
217
+
218
+ null_context = null_context[:, :L].expand(B, L, hidden_dim).to(prompt_embed.device)
219
+ null_pooled_embed = null_pooled_embed.expand(B, pooled_dim).to(pooled_output.device)
220
+ null_attn_bias = null_attn_bias[:, :L].expand(B, L).to(attention_bias.device)
221
+
222
+ prompt_embed = torch.cat([prompt_embed, null_context], dim=0)
223
+ pooled_output = torch.cat([pooled_output, null_pooled_embed], dim=0)
224
+ attention_bias = torch.cat([attention_bias, null_attn_bias], dim=0)
225
+
226
+ return (
227
+ prompt_embed.to(dist.get_device()),
228
+ pooled_output.to(dist.get_device()),
229
+ attention_bias.to(dist.get_device()),
230
+ )
231
+
232
+ def forward(
233
+ self,
234
+ x_BLCv_wo_first_l: torch.Tensor,
235
+ prompt_embeds: torch.Tensor,
236
+ pooled_prompt_embeds: torch.Tensor,
237
+ prompt_attn_bias: torch.Tensor,
238
+ ) -> torch.Tensor: # returns logits_BLV
239
+ """
240
+ :param x_BLCv_wo_first_l: teacher-forcing input (B, self.L - self.first_l, self.Cvae)
241
+ :param prompt_embeds: token-level text embeddings, the concatenation of
242
+ CLIP-ViT-L-14 and CLIP-ViT-bigG-14 features
243
+ :param pooled_prompt_embeds: pooled CLIP-ViT-bigG-14 embedding, used as the start token and AdaLN condition
244
+ :param prompt_attn_bias: additive attention bias from the text attention mask
245
+ (0 for valid tokens, -inf for padding)
246
+ :return: logits BLV, V is vocab_size
247
+ """
248
+ bg, ed = 0, self.L
249
+ B = x_BLCv_wo_first_l.shape[0]
250
+ with torch.amp.autocast('cuda', enabled=False):
251
+ pooled_prompt_embeds = self.text_pooler(pooled_prompt_embeds)
252
+
253
+ sos = cond_BD = pooled_prompt_embeds
254
+ sos = sos.unsqueeze(1).expand(B, self.first_l, -1) + self.pos_start.expand(
255
+ B, self.first_l, -1
256
+ )
257
+
258
+ x_BLC = torch.cat(
259
+ (sos, self.word_embed(x_BLCv_wo_first_l.float())), dim=1
260
+ )
261
+ x_BLC += self.lvl_embed(
262
+ self.lvl_1L[:, :ed].expand(B, -1)
263
+ ) # lvl: BLC; pos: 1LC
264
+ if not self.rope:
265
+ x_BLC += self.pos_1LC[:, :ed]
266
+ attn_bias = self.attn_bias_for_masking[:, :, :ed, :ed]
267
+ cond_BD_or_gss = self.shared_ada_lin(cond_BD)
268
+
269
+ # hack: get the dtype if mixed precision is used
270
+ temp = x_BLC.new_ones(8, 8)
271
+ main_type = torch.matmul(temp, temp).dtype
272
+
273
+ x_BLC = x_BLC.to(dtype=main_type)
274
+ cond_BD_or_gss = cond_BD_or_gss.to(dtype=main_type)
275
+ attn_bias = attn_bias.to(dtype=main_type)
276
+
277
+ for block in self.blocks:
278
+ if self.use_gradient_checkpointing:
279
+ x_BLC = torch.utils.checkpoint.checkpoint(
280
+ block,
281
+ x=x_BLC,
282
+ cond_BD=cond_BD_or_gss,
283
+ attn_bias=attn_bias,
284
+ context=prompt_embeds,
285
+ freqs_cis=self.freqs_cis,
286
+ context_attn_bias=prompt_attn_bias,
287
+ use_reentrant=False,
288
+ )
289
+ else:
290
+ x_BLC = block(
291
+ x=x_BLC,
292
+ cond_BD=cond_BD_or_gss,
293
+ attn_bias=attn_bias,
294
+ context=prompt_embeds,
295
+ freqs_cis=self.freqs_cis,
296
+ context_attn_bias=prompt_attn_bias,
297
+ )
298
+
299
+ with torch.amp.autocast('cuda', enabled=not self.training):
300
+ x_BLC = self.get_logits(x_BLC.float(), cond_BD)
301
+
302
+ return x_BLC # logits BLV, V is vocab_size
303
+
304
+ def init_weights(
305
+ self,
306
+ init_adaln=0.5,
307
+ init_adaln_gamma=1e-5,
308
+ init_head=0.02,
309
+ init_std=0.02,
310
+ ):
311
+ if init_std < 0:
312
+ init_std = (1 / self.C / 3) ** 0.5 # init_std < 0: automated
313
+
314
+ print(f"[init_weights] {type(self).__name__} with {init_std=:g}")
315
+ for m in self.modules():
316
+ with_weight = hasattr(m, "weight") and m.weight is not None
317
+ with_bias = hasattr(m, "bias") and m.bias is not None
318
+ if isinstance(m, nn.Linear):
319
+ nn.init.trunc_normal_(m.weight.data, std=init_std)
320
+ if with_bias:
321
+ m.bias.data.zero_()
322
+ elif isinstance(m, nn.Embedding):
323
+ nn.init.trunc_normal_(m.weight.data, std=init_std)
324
+ if m.padding_idx is not None:
325
+ m.weight.data[m.padding_idx].zero_()
326
+ elif isinstance(
327
+ m,
328
+ (
329
+ nn.LayerNorm,
330
+ nn.BatchNorm1d,
331
+ nn.BatchNorm2d,
332
+ nn.BatchNorm3d,
333
+ nn.SyncBatchNorm,
334
+ nn.GroupNorm,
335
+ nn.InstanceNorm1d,
336
+ nn.InstanceNorm2d,
337
+ nn.InstanceNorm3d,
338
+ ),
339
+ ):
340
+ if with_weight:
341
+ m.weight.data.fill_(1.0)
342
+ if with_bias:
343
+ m.bias.data.zero_()
344
+
345
+ if init_head >= 0:
346
+ if isinstance(self.head, nn.Linear):
347
+ self.head.weight.data.mul_(init_head)
348
+ self.head.bias.data.zero_()
349
+ elif isinstance(self.head, nn.Sequential):
350
+ self.head[-1].weight.data.mul_(init_head)
351
+ self.head[-1].bias.data.zero_()
352
+
353
+ if isinstance(self.head_nm, AdaLNBeforeHead):
354
+ self.head_nm.ada_lin[-1].weight.data.mul_(init_adaln)
355
+ if (
356
+ hasattr(self.head_nm.ada_lin[-1], "bias")
357
+ and self.head_nm.ada_lin[-1].bias is not None
358
+ ):
359
+ self.head_nm.ada_lin[-1].bias.data.zero_()
360
+
361
+ depth = len(self.blocks)
362
+ for block in self.blocks:
363
+ block.attn.proj.weight.data.div_(math.sqrt(2 * depth))
364
+ block.cross_attn.proj.weight.data.div_(math.sqrt(2 * depth))
365
+ if hasattr(block.ffn, "fc2"):
366
+ block.ffn.fc2.weight.data.div_(math.sqrt(2 * depth))
367
+
368
+ if hasattr(block, "ada_lin"):
369
+ block.ada_lin[-1].weight.data[2 * self.C :].mul_(init_adaln)
370
+ block.ada_lin[-1].weight.data[: 2 * self.C].mul_(init_adaln_gamma)
371
+ if (
372
+ hasattr(block.ada_lin[-1], "bias")
373
+ and block.ada_lin[-1].bias is not None
374
+ ):
375
+ block.ada_lin[-1].bias.data.zero_()
376
+ elif hasattr(block, "ada_gss"):
377
+ block.ada_gss.data[:, :, 2:].mul_(init_adaln)
378
+ block.ada_gss.data[:, :, :2].mul_(init_adaln_gamma)
379
+
380
+ def extra_repr(self):
381
+ return f"drop_path_rate={self.drop_path_rate:g}"
382
+
383
+
384
+ class TVARHF(VAR, PyTorchModelHubMixin):
385
+ # tags=["image-generation"]
386
+ def __init__(
387
+ self,
388
+ depth=30,
389
+ shared_aln=False,
390
+ attn_l2_norm=True,
391
+ rope=True,
392
+ rope_theta=10000,
393
+ rope_size=128,
394
+ use_swiglu_ffn=True,
395
+ ):
396
+ heads = depth
397
+ width = depth * 64
398
+ super().__init__(
399
+ depth=depth,
400
+ embed_dim=width,
401
+ num_heads=heads,
402
+ drop_rate=0.0,
403
+ attn_drop_rate=0.0,
404
+ norm_eps=1e-6,
405
+ shared_aln=shared_aln,
406
+ attn_l2_norm=attn_l2_norm,
407
+ patch_nums=(1, 2, 3, 4, 6, 9, 13, 18, 24, 32),
408
+ rope=rope,
409
+ rope_theta=rope_theta,
410
+ rope_size=rope_size,
411
+ use_swiglu_ffn=use_swiglu_ffn,
412
+ )
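One detail worth highlighting from the constructor above: the attention bias registered in step 5 is block-causal over scales, so every token may attend to all tokens at the same or coarser scales but never to finer ones. A standalone sketch of the same construction on a tiny pyramid:

import torch

patch_nums = (1, 2, 3)                      # tiny pyramid, just for illustration
L = sum(pn * pn for pn in patch_nums)       # 1 + 4 + 9 = 14 tokens
d = torch.cat([torch.full((pn * pn,), i) for i, pn in enumerate(patch_nums)]).view(1, L, 1)
dT = d.transpose(1, 2)
attn_bias = torch.where(d >= dT, 0.0, -torch.inf).reshape(1, 1, L, L)
print(attn_bias[0, 0])  # 0 where attention is allowed, -inf on tokens of finer (future) scales
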
models/vqvae.py ADDED
@@ -0,0 +1,184 @@
1
+ """
2
+ References:
3
+ - VectorQuantizer2: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/vqvae/quantize.py#L110
4
+ - GumbelQuantize: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/vqvae/quantize.py#L213
5
+ - VQVAE (VQModel): https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/models/autoencoder.py#L14
6
+ """
7
+
8
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from huggingface_hub import PyTorchModelHubMixin
13
+
14
+ from .basic_vae import Decoder, Encoder
15
+ from .quant import VectorQuantizer2
16
+
17
+
18
+
19
+ class VQVAE(nn.Module):
20
+ def __init__(
21
+ self,
22
+ vocab_size=4096,
23
+ z_channels=32,
24
+ ch=128,
25
+ dropout=0.0,
26
+ beta=0.25, # commitment loss weight
27
+ using_znorm=False, # whether to normalize when computing the nearest neighbors
28
+ quant_conv_ks=3, # quant conv kernel size
29
+ quant_resi=0.5, # 0.5 means \phi(x) = 0.5conv(x) + (1-0.5)x
30
+ share_quant_resi=4, # use 4 \phi layers for K scales: partially-shared \phi
31
+ default_qresi_counts=0, # if is 0: automatically set to len(v_patch_nums)
32
+ # number of patches for each scale, h_{1 to K} = w_{1 to K} = v_patch_nums[k]
33
+ v_patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),
34
+ test_mode=True,
35
+ ):
36
+ super().__init__()
37
+ self.test_mode = test_mode
38
+ self.V, self.Cvae = vocab_size, z_channels
39
+ # ddconfig is copied from https://github.com/CompVis/latent-diffusion/blob/e66308c7f2e64cb581c6d27ab6fbeb846828253b/models/first_stage_models/vq-f16/config.yaml
40
+ ddconfig = dict(
41
+ dropout=dropout,
42
+ ch=ch,
43
+ z_channels=z_channels,
44
+ in_channels=3,
45
+ ch_mult=(1, 1, 2, 2, 4),
46
+ num_res_blocks=2, # from vq-f16/config.yaml above
47
+ using_sa=True,
48
+ using_mid_sa=True, # from vq-f16/config.yaml above
49
+ # resamp_with_conv=True, # always True, removed.
50
+ )
51
+ ddconfig.pop("double_z", None) # only KL-VAE should use double_z=True
52
+ self.encoder = Encoder(double_z=False, **ddconfig)
53
+ self.decoder = Decoder(**ddconfig)
54
+
55
+ self.vocab_size = vocab_size
56
+ self.downsample = 2 ** (len(ddconfig["ch_mult"]) - 1)
57
+ self.quantize: VectorQuantizer2 = VectorQuantizer2(
58
+ vocab_size=vocab_size,
59
+ Cvae=self.Cvae,
60
+ using_znorm=using_znorm,
61
+ beta=beta,
62
+ default_qresi_counts=default_qresi_counts,
63
+ v_patch_nums=v_patch_nums,
64
+ quant_resi=quant_resi,
65
+ share_quant_resi=share_quant_resi,
66
+ )
67
+ self.quant_conv = torch.nn.Conv2d(
68
+ self.Cvae, self.Cvae, quant_conv_ks, stride=1, padding=quant_conv_ks // 2
69
+ )
70
+ self.post_quant_conv = torch.nn.Conv2d(
71
+ self.Cvae, self.Cvae, quant_conv_ks, stride=1, padding=quant_conv_ks // 2
72
+ )
73
+
74
+ if self.test_mode:
75
+ self.eval()
76
+ [p.requires_grad_(False) for p in self.parameters()]
77
+
78
+ # ===================== `forward` is only used in VAE training =====================
79
+ def forward(self, inp, ret_usages=False): # -> rec_B3HW, idx_N, loss
80
+ VectorQuantizer2.forward  # no-op reference, kept only as a jump-to-definition hint for the call below
81
+ f_hat, usages, vq_loss = self.quantize(
82
+ self.quant_conv(self.encoder(inp)), ret_usages=ret_usages
83
+ )
84
+ return self.decoder(self.post_quant_conv(f_hat)), usages, vq_loss
85
+
86
+ # ===================== `forward` is only used in VAE training =====================
87
+
88
+ def fhat_to_img(self, f_hat: torch.Tensor):
89
+ return self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1)
90
+
91
+ def img_to_idxBl(
92
+ self,
93
+ inp_img_no_grad: torch.Tensor,
94
+ v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None,
95
+ noise_std: Optional[float] = None,
96
+ ) -> List[torch.LongTensor]: # return List[Bl]
97
+ f = self.quant_conv(self.encoder(inp_img_no_grad))
98
+ return self.quantize.f_to_idxBl_or_fhat(
99
+ f, to_fhat=False, v_patch_nums=v_patch_nums, noise_std=noise_std,
100
+ )
101
+
102
+ def idxBl_to_img(
103
+ self, ms_idx_Bl: List[torch.Tensor], same_shape: bool, last_one=False
104
+ ) -> Union[List[torch.Tensor], torch.Tensor]:
105
+ B = ms_idx_Bl[0].shape[0]
106
+ ms_h_BChw = []
107
+ for idx_Bl in ms_idx_Bl:
108
+ l = idx_Bl.shape[1]
109
+ pn = round(l**0.5)
110
+ ms_h_BChw.append(
111
+ self.quantize.embedding(idx_Bl)
112
+ .transpose(1, 2)
113
+ .view(B, self.Cvae, pn, pn)
114
+ )
115
+ return self.embed_to_img(
116
+ ms_h_BChw=ms_h_BChw, all_to_max_scale=same_shape, last_one=last_one
117
+ )
118
+
119
+ def embed_to_img(
120
+ self, ms_h_BChw: List[torch.Tensor], all_to_max_scale: bool, last_one=False
121
+ ) -> Union[List[torch.Tensor], torch.Tensor]:
122
+ if last_one:
123
+ return self.decoder(
124
+ self.post_quant_conv(
125
+ self.quantize.embed_to_fhat(
126
+ ms_h_BChw, all_to_max_scale=all_to_max_scale, last_one=True
127
+ )
128
+ )
129
+ ).clamp_(-1, 1)
130
+ else:
131
+ return [
132
+ self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1)
133
+ for f_hat in self.quantize.embed_to_fhat(
134
+ ms_h_BChw, all_to_max_scale=all_to_max_scale, last_one=False
135
+ )
136
+ ]
137
+
138
+ def img_to_reconstructed_img(
139
+ self,
140
+ x,
141
+ v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None,
142
+ last_one=False,
143
+ ) -> List[torch.Tensor]:
144
+ f = self.quant_conv(self.encoder(x))
145
+ ls_f_hat_BChw = self.quantize.f_to_idxBl_or_fhat(
146
+ f, to_fhat=True, v_patch_nums=v_patch_nums
147
+ )
148
+ if last_one:
149
+ return self.decoder(self.post_quant_conv(ls_f_hat_BChw[-1])).clamp_(-1, 1)
150
+ else:
151
+ return [
152
+ self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1)
153
+ for f_hat in ls_f_hat_BChw
154
+ ]
155
+
156
+ def load_state_dict(self, state_dict: Dict[str, Any], strict=True, assign=False):
157
+ if (
158
+ "quantize.ema_vocab_hit_SV" in state_dict
159
+ and state_dict["quantize.ema_vocab_hit_SV"].shape[0]
160
+ != self.quantize.ema_vocab_hit_SV.shape[0]
161
+ ):
162
+ state_dict["quantize.ema_vocab_hit_SV"] = self.quantize.ema_vocab_hit_SV
163
+ return super().load_state_dict(
164
+ state_dict=state_dict, strict=strict, assign=assign
165
+ )
166
+
167
+ class VQVAEHF(VQVAE, PyTorchModelHubMixin):
168
+ def __init__(
169
+ self,
170
+ vocab_size=4096,
171
+ z_channels=32,
172
+ ch=160,
173
+ test_mode=True,
174
+ share_quant_resi=4,
175
+ v_patch_nums=(1, 2, 3, 4, 6, 9, 13, 18, 24, 32),
176
+ ):
177
+ super().__init__(
178
+ vocab_size=vocab_size,
179
+ z_channels=z_channels,
180
+ ch=ch,
181
+ test_mode=test_mode,
182
+ share_quant_resi=share_quant_resi,
183
+ v_patch_nums=v_patch_nums,
184
+ )
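Finally, a shape-level sketch of the tokenizer roundtrip implemented by VQVAE above — randomly initialized weights, no checkpoint, assuming the repository root is on PYTHONPATH and that models/basic_vae.py provides the Encoder/Decoder imported above:

import torch
from models.vqvae import VQVAE

vae = VQVAE(test_mode=True)                    # default config: 16x downsampling, 10 scales up to 16x16
img = torch.rand(1, 3, 256, 256) * 2 - 1       # images are expected in [-1, 1]
idx_Bl = vae.img_to_idxBl(img)                 # list of (B, pn*pn) LongTensors, coarse to fine
rec = vae.idxBl_to_img(idx_Bl, same_shape=True, last_one=True)  # (1, 3, 256, 256), clamped to [-1, 1]
print([t.shape for t in idx_Bl], rec.shape)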