Commit c6572a6 (verified) · 1 Parent(s): bcfa187
yeelou committed: Upload folder using huggingface_hub
config.json ADDED
@@ -0,0 +1,58 @@
+ {
+   "_name_or_path": "/content/d2c_hf_bin",
+   "architectures": [
+     "CogAgentForCausalLM"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_cogagent.CogAgentConfig",
+     "AutoModelForCausalLM": "modeling_cogagent.CogAgentForCausalLM"
+   },
+   "bos_token_id": 1,
+   "cross_compute_hidden_size": 1024,
+   "cross_hidden_size": 1024,
+   "cross_image_size": 1120,
+   "eos_token_id": 2,
+   "hidden_act": "silu",
+   "hidden_size": 4096,
+   "initializer_range": 0.02,
+   "intermediate_size": 11008,
+   "max_position_embeddings": 2048,
+   "num_attention_heads": 32,
+   "num_hidden_layers": 32,
+   "pad_token_id": 0,
+   "quantization_config": {
+     "_load_in_4bit": false,
+     "_load_in_8bit": true,
+     "bnb_4bit_compute_dtype": "float32",
+     "bnb_4bit_quant_storage": "uint8",
+     "bnb_4bit_quant_type": "fp4",
+     "bnb_4bit_use_double_quant": false,
+     "llm_int8_enable_fp32_cpu_offload": false,
+     "llm_int8_has_fp16_weight": false,
+     "llm_int8_skip_modules": null,
+     "llm_int8_threshold": 6.0,
+     "load_in_4bit": false,
+     "load_in_8bit": true,
+     "quant_method": "bitsandbytes"
+   },
+   "rms_norm_eps": 1e-05,
+   "template_version": "chat",
+   "tie_word_embeddings": false,
+   "torch_dtype": "float16",
+   "transformers_version": "4.41.0.dev0",
+   "use_cache": true,
+   "vision_config": {
+     "dropout_prob": 0.0,
+     "hidden_act": "gelu",
+     "hidden_size": 1792,
+     "image_size": 224,
+     "in_channels": 3,
+     "intermediate_size": 15360,
+     "layer_norm_eps": 1e-06,
+     "num_heads": 16,
+     "num_hidden_layers": 63,
+     "num_positions": 257,
+     "patch_size": 14
+   },
+   "vocab_size": 32000
+ }
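For reference, a config like this, with an auto_map pointing at the bundled configuration_cogagent.py / modeling_cogagent.py and a bitsandbytes 8-bit quantization_config, is normally consumed through the transformers auto classes with trust_remote_code enabled. A minimal loading sketch, assuming the files from this commit sit in a local folder (the path below is hypothetical) and that torch, bitsandbytes, timm, einops and xformers are installed:

# Loading sketch for the uploaded checkpoint (assumed local path "./cogagent-8bit").
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("./cogagent-8bit", trust_remote_code=True)
print(config.hidden_size)  # 4096, as in the config.json above

# The embedded quantization_config lets from_pretrained restore the prequantized
# 8-bit bitsandbytes weights; device_map="auto" spreads them over available GPUs.
model = AutoModelForCausalLM.from_pretrained(
    "./cogagent-8bit",
    trust_remote_code=True,
    device_map="auto",
)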
configuration_cogagent.py ADDED
@@ -0,0 +1,51 @@
+ from typing import Literal
+ from transformers import PretrainedConfig
+
+
+ class CogAgentConfig(PretrainedConfig):
+     _auto_class = "AutoConfig"
+
+     def __init__(
+         self,
+         vocab_size=32000,
+         hidden_size=4096,
+         cross_hidden_size=1024,
+         cross_compute_hidden_size=1024,
+         cross_image_size=1120,
+         intermediate_size=11008,
+         num_hidden_layers=32,
+         num_attention_heads=32,
+         hidden_act='silu',
+         max_position_embeddings=2048,
+         initializer_range=0.02,
+         rms_norm_eps=1e-06,
+         template_version: Literal["base", "chat"] = "chat",
+
+         pad_token_id=0,
+         bos_token_id=1,
+         eos_token_id=2,
+         tie_word_embeddings=False,
+         use_cache=True,
+         **kwargs,
+     ):
+         self.hidden_size = hidden_size
+         self.cross_hidden_size = cross_hidden_size
+         self.cross_compute_hidden_size = cross_compute_hidden_size
+         self.cross_image_size = cross_image_size
+         self.intermediate_size = intermediate_size
+         self.num_attention_heads = num_attention_heads
+         self.max_position_embeddings = max_position_embeddings
+         self.rms_norm_eps = rms_norm_eps
+         self.initializer_range = initializer_range
+         self.vocab_size = vocab_size
+         self.num_hidden_layers = num_hidden_layers
+         self.hidden_act = hidden_act
+         self.template_version = template_version
+         self.use_cache = use_cache
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
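Since CogAgentConfig is a thin PretrainedConfig subclass, it can also be built and inspected on its own. A small usage sketch (not part of the commit), assuming the file above is importable as a local module; note that the class default rms_norm_eps of 1e-06 is overridden to 1e-05 by the config.json shipped above:

# Usage sketch for the config class defined above.
from configuration_cogagent import CogAgentConfig

cfg = CogAgentConfig(template_version="base")
print(cfg.hidden_size)    # 4096 (class default)
print(cfg.rms_norm_eps)   # 1e-06 here; the serialized config.json uses 1e-05
cfg.save_pretrained("/tmp/cogagent_config")  # writes a config.json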
cross_visual.py ADDED
@@ -0,0 +1,797 @@
1
+ from math import pi
2
+ import torch
3
+ from torch import nn
4
+ from einops import rearrange, repeat
5
+ import logging
6
+
7
+ def broadcat(tensors, dim = -1):
8
+ num_tensors = len(tensors)
9
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
10
+ assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
11
+ shape_len = list(shape_lens)[0]
12
+ dim = (dim + shape_len) if dim < 0 else dim
13
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
14
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
15
+ assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatenation'
16
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
17
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
18
+ expanded_dims.insert(dim, (dim, dims[dim]))
19
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
20
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
21
+ return torch.cat(tensors, dim = dim)
22
+
23
+ def rotate_half(x):
24
+ x = rearrange(x, '... (d r) -> ... d r', r = 2)
25
+ x1, x2 = x.unbind(dim = -1)
26
+ x = torch.stack((-x2, x1), dim = -1)
27
+ return rearrange(x, '... d r -> ... (d r)')
28
+
29
+ class VisionRotaryEmbeddingFast(nn.Module):
30
+ def __init__(
31
+ self,
32
+ dim,
33
+ pt_seq_len,
34
+ ft_seq_len=None,
35
+ custom_freqs = None,
36
+ freqs_for = 'lang',
37
+ theta = 10000,
38
+ max_freq = 10,
39
+ num_freqs = 1,
40
+ patch_dropout = 0.
41
+ ):
42
+ super().__init__()
43
+ if custom_freqs:
44
+ freqs = custom_freqs
45
+ elif freqs_for == 'lang':
46
+ freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
47
+ elif freqs_for == 'pixel':
48
+ freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
49
+ elif freqs_for == 'constant':
50
+ freqs = torch.ones(num_freqs).float()
51
+ else:
52
+ raise ValueError(f'unknown modality {freqs_for}')
53
+
54
+ if ft_seq_len is None: ft_seq_len = pt_seq_len
55
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
56
+
57
+ freqs = torch.einsum('..., f -> ... f', t, freqs)
58
+ freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
59
+ freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)
60
+
61
+ freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
62
+ freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
63
+
64
+ self.patch_dropout = patch_dropout
65
+
66
+ self.register_buffer("freqs_cos", freqs_cos)
67
+ self.register_buffer("freqs_sin", freqs_sin)
68
+
69
+ logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
70
+
71
+ def forward(self, t, patch_indices_keep=None):
72
+ if patch_indices_keep is not None:
73
+ batch = t.size()[0]
74
+ batch_indices = torch.arange(batch)
75
+ batch_indices = batch_indices[..., None]
76
+
77
+ freqs_cos = repeat(self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
78
+ freqs_sin = repeat(self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
79
+
80
+ freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
81
+ freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j')
82
+ freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
83
+ freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j')
84
+
85
+ return t * freqs_cos + rotate_half(t) * freqs_sin
86
+
87
+ return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
88
+
89
+ import torch.nn as nn
90
+ import os
91
+ from dataclasses import dataclass
92
+ from typing import Optional, Tuple, Union
93
+ from functools import partial
94
+
95
+ import numpy as np
96
+ import torch
97
+ import torch.nn.functional as F
98
+ from torch import nn
99
+
100
+ # --------------------------------------------------------
101
+ # Adapted from https://github.com/microsoft/unilm/tree/master/beit
102
+ # --------------------------------------------------------
103
+ import math
104
+ import os
105
+ from functools import partial
106
+ import torch
107
+ import torch.nn as nn
108
+ import torch.nn.functional as F
109
+ import logging
110
+ try:
111
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
112
+ except:
113
+ from timm.layers import drop_path, to_2tuple, trunc_normal_
114
+
115
+ class PatchDropout(nn.Module):
116
+ """
117
+ https://arxiv.org/abs/2212.00794
118
+ """
119
+
120
+ def __init__(self, prob, exclude_first_token=True):
121
+ super().__init__()
122
+ assert 0 <= prob < 1.
123
+ self.prob = prob
124
+ self.exclude_first_token = exclude_first_token # exclude CLS token
125
+ logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}")
126
+
127
+ def forward(self, x):
128
+ if not self.training or self.prob == 0.:
129
+ return x
130
+
131
+ if self.exclude_first_token:
132
+ cls_tokens, x = x[:, :1], x[:, 1:]
133
+ else:
134
+ cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
135
+
136
+ batch = x.size()[0]
137
+ num_tokens = x.size()[1]
138
+
139
+ batch_indices = torch.arange(batch)
140
+ batch_indices = batch_indices[..., None]
141
+
142
+ keep_prob = 1 - self.prob
143
+ num_patches_keep = max(1, int(num_tokens * keep_prob))
144
+
145
+ rand = torch.randn(batch, num_tokens)
146
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
147
+
148
+ x = x[batch_indices, patch_indices_keep]
149
+
150
+ if self.exclude_first_token:
151
+ x = torch.cat((cls_tokens, x), dim=1)
152
+
153
+ if self.training and os.getenv('RoPE') == '1':
154
+ return x, patch_indices_keep
155
+
156
+ return x
157
+
158
+ if os.getenv('ENV_TYPE') == 'deepspeed':
159
+ try:
160
+ from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
161
+ except:
162
+ from torch.utils.checkpoint import checkpoint
163
+ else:
164
+ from torch.utils.checkpoint import checkpoint
165
+
166
+ import xformers.ops as xops
167
+
168
+ class DropPath(nn.Module):
169
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
170
+ """
171
+ def __init__(self, drop_prob=None):
172
+ super(DropPath, self).__init__()
173
+ self.drop_prob = drop_prob
174
+
175
+ def forward(self, x):
176
+ return drop_path(x, self.drop_prob, self.training)
177
+
178
+ def extra_repr(self) -> str:
179
+ return 'p={}'.format(self.drop_prob)
180
+
181
+
182
+ class Mlp(nn.Module):
183
+ def __init__(
184
+ self,
185
+ in_features,
186
+ hidden_features=None,
187
+ out_features=None,
188
+ act_layer=nn.GELU,
189
+ norm_layer=nn.LayerNorm,
190
+ drop=0.,
191
+ subln=False,
192
+
193
+ ):
194
+ super().__init__()
195
+ out_features = out_features or in_features
196
+ hidden_features = hidden_features or in_features
197
+ self.fc1 = nn.Linear(in_features, hidden_features)
198
+ self.act = act_layer()
199
+
200
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
201
+
202
+ self.fc2 = nn.Linear(hidden_features, out_features)
203
+ self.drop = nn.Dropout(drop)
204
+
205
+ def forward(self, x):
206
+ x = self.fc1(x)
207
+ x = self.act(x)
208
+ # x = self.drop(x)
209
+ # uncomment the dropout above to match the original BERT implementation
210
+ x = self.ffn_ln(x)
211
+
212
+ x = self.fc2(x)
213
+ x = self.drop(x)
214
+ return x
215
+
216
+ class SwiGLU(nn.Module):
217
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
218
+ norm_layer=nn.LayerNorm, subln=False):
219
+ super().__init__()
220
+ out_features = out_features or in_features
221
+ hidden_features = hidden_features or in_features
222
+
223
+ self.w1 = nn.Linear(in_features, hidden_features)
224
+ self.w2 = nn.Linear(in_features, hidden_features)
225
+
226
+ self.act = act_layer()
227
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
228
+ self.w3 = nn.Linear(hidden_features, out_features)
229
+
230
+ self.drop = nn.Dropout(drop)
231
+
232
+ def forward(self, x):
233
+ x1 = self.w1(x)
234
+ x2 = self.w2(x)
235
+ hidden = self.act(x1) * x2
236
+ x = self.ffn_ln(hidden)
237
+ x = self.w3(x)
238
+ x = self.drop(x)
239
+ return x
240
+
241
+ class Attention(nn.Module):
242
+ def __init__(
243
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
244
+ proj_drop=0., window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm):
245
+ super().__init__()
246
+ self.num_heads = num_heads
247
+ head_dim = dim // num_heads
248
+ if attn_head_dim is not None:
249
+ head_dim = attn_head_dim
250
+ all_head_dim = head_dim * self.num_heads
251
+ self.scale = qk_scale or head_dim ** -0.5
252
+
253
+ self.subln = subln
254
+ if self.subln:
255
+ self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
256
+ self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
257
+ self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
258
+ else:
259
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
260
+
261
+ if qkv_bias:
262
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
263
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
264
+ else:
265
+ self.q_bias = None
266
+ self.v_bias = None
267
+
268
+ if window_size:
269
+ self.window_size = window_size
270
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
271
+ self.relative_position_bias_table = nn.Parameter(
272
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
273
+ # cls to token & token 2 cls & cls to cls
274
+
275
+ # get pair-wise relative position index for each token inside the window
276
+ coords_h = torch.arange(window_size[0])
277
+ coords_w = torch.arange(window_size[1])
278
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
279
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
280
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
281
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
282
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
283
+ relative_coords[:, :, 1] += window_size[1] - 1
284
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
285
+ relative_position_index = \
286
+ torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
287
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
288
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
289
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
290
+ relative_position_index[0, 0] = self.num_relative_distance - 1
291
+
292
+ self.register_buffer("relative_position_index", relative_position_index)
293
+ else:
294
+ self.window_size = None
295
+ self.relative_position_bias_table = None
296
+ self.relative_position_index = None
297
+
298
+ self.attn_drop = nn.Dropout(attn_drop)
299
+ self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
300
+ # self.proj = nn.Linear(all_head_dim, all_head_dim)
301
+ self.proj = nn.Linear(all_head_dim, dim)
302
+ self.proj_drop = nn.Dropout(proj_drop)
303
+ self.xattn = xattn
304
+ self.xattn_drop = attn_drop
305
+
306
+ self.rope = rope
307
+
308
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
309
+ B, N, C = x.shape
310
+ if self.subln:
311
+ if self.q_proj.weight.dtype == torch.uint8:
312
+ import bitsandbytes as bnb
313
+ q = bnb.matmul_4bit(x, self.q_proj.weight.t(), bias=self.q_bias, quant_state=self.q_proj.weight.quant_state)
314
+ k = bnb.matmul_4bit(x, self.k_proj.weight.t(), bias=None, quant_state=self.k_proj.weight.quant_state)
315
+ v = bnb.matmul_4bit(x, self.v_proj.weight.t(), bias=self.v_bias, quant_state=self.v_proj.weight.quant_state)
316
+ else:
317
+ q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
318
+ k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
319
+ v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
320
+
321
+ q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C
322
+ k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
323
+ v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
324
+ else:
325
+
326
+ qkv_bias = None
327
+ if self.q_bias is not None:
328
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
329
+
330
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
331
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C
332
+ q, k, v = qkv[0], qkv[1], qkv[2]
333
+
334
+ if self.rope:
335
+ # slightly fast impl
336
+ q_t = q[:, :, 1:, :]
337
+ ro_q_t = self.rope(q_t)
338
+ q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
339
+
340
+ k_t = k[:, :, 1:, :]
341
+ ro_k_t = self.rope(k_t)
342
+ k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
343
+
344
+ if self.xattn:
345
+ q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
346
+ k = k.permute(0, 2, 1, 3)
347
+ v = v.permute(0, 2, 1, 3)
348
+
349
+ x = xops.memory_efficient_attention(
350
+ q, k, v,
351
+ p=self.xattn_drop,
352
+ scale=self.scale,
353
+ )
354
+ x = x.reshape(B, N, -1)
355
+ x = self.inner_attn_ln(x)
356
+ x = self.proj(x)
357
+ x = self.proj_drop(x)
358
+ else:
359
+ q = q * self.scale
360
+ attn = (q @ k.transpose(-2, -1))
361
+
362
+ if self.relative_position_bias_table is not None:
363
+ relative_position_bias = \
364
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
365
+ self.window_size[0] * self.window_size[1] + 1,
366
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
367
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
368
+ attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
369
+
370
+ if rel_pos_bias is not None:
371
+ attn = attn + rel_pos_bias.type_as(attn)
372
+
373
+ if attn_mask is not None:
374
+ attn_mask = attn_mask.bool()
375
+ attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
376
+
377
+ attn = attn.softmax(dim=-1)
378
+ attn = self.attn_drop(attn)
379
+
380
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
381
+ x = self.inner_attn_ln(x)
382
+ x = self.proj(x)
383
+ x = self.proj_drop(x)
384
+ return x
385
+
386
+
387
+ class Block(nn.Module):
388
+
389
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
390
+ drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
391
+ window_size=None, attn_head_dim=None, xattn=False, rope=None, postnorm=False,
392
+ subln=False, naiveswiglu=False):
393
+ super().__init__()
394
+ self.norm1 = norm_layer(dim)
395
+ self.attn = Attention(
396
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
397
+ attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim,
398
+ xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer)
399
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
400
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
401
+ self.norm2 = norm_layer(dim)
402
+ mlp_hidden_dim = int(dim * mlp_ratio)
403
+
404
+ if naiveswiglu:
405
+ self.mlp = SwiGLU(
406
+ in_features=dim,
407
+ hidden_features=mlp_hidden_dim,
408
+ subln=subln,
409
+ norm_layer=norm_layer,
410
+ )
411
+ else:
412
+ self.mlp = Mlp(
413
+ in_features=dim,
414
+ hidden_features=mlp_hidden_dim,
415
+ act_layer=act_layer,
416
+ subln=subln,
417
+ drop=drop
418
+ )
419
+
420
+ if init_values is not None and init_values > 0:
421
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
422
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
423
+ else:
424
+ self.gamma_1, self.gamma_2 = None, None
425
+
426
+ self.postnorm = postnorm
427
+
428
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
429
+ if self.gamma_1 is None:
430
+ if self.postnorm:
431
+ x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
432
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
433
+ else:
434
+ x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
435
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
436
+ else:
437
+ if self.postnorm:
438
+ x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
439
+ x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
440
+ else:
441
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
442
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
443
+ return x
444
+
445
+
446
+ class PatchEmbed(nn.Module):
447
+ """ Image to Patch Embedding
448
+ """
449
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
450
+ super().__init__()
451
+ img_size = to_2tuple(img_size)
452
+ patch_size = to_2tuple(patch_size)
453
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
454
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
455
+ self.img_size = img_size
456
+ self.patch_size = patch_size
457
+ self.num_patches = num_patches
458
+
459
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
460
+
461
+ def forward(self, x, **kwargs):
462
+ B, C, H, W = x.shape
463
+ # FIXME look at relaxing size constraints
464
+ assert H == self.img_size[0] and W == self.img_size[1], \
465
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
466
+ x = self.proj(x).flatten(2).transpose(1, 2)
467
+ return x
468
+
469
+
470
+ class RelativePositionBias(nn.Module):
471
+
472
+ def __init__(self, window_size, num_heads):
473
+ super().__init__()
474
+ self.window_size = window_size
475
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
476
+ self.relative_position_bias_table = nn.Parameter(
477
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
478
+ # cls to token & token 2 cls & cls to cls
479
+
480
+ # get pair-wise relative position index for each token inside the window
481
+ coords_h = torch.arange(window_size[0])
482
+ coords_w = torch.arange(window_size[1])
483
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
484
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
485
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
486
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
487
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
488
+ relative_coords[:, :, 1] += window_size[1] - 1
489
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
490
+ relative_position_index = \
491
+ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
492
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
493
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
494
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
495
+ relative_position_index[0, 0] = self.num_relative_distance - 1
496
+
497
+ self.register_buffer("relative_position_index", relative_position_index)
498
+
499
+ def forward(self):
500
+ relative_position_bias = \
501
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
502
+ self.window_size[0] * self.window_size[1] + 1,
503
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
504
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
505
+
506
+
507
+ class EVAVisionTransformer(nn.Module):
508
+ """ Vision Transformer with support for patch or hybrid CNN input stage
509
+ """
510
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
511
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
512
+ drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, patch_dropout=0.,
513
+ use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, rope=False,
514
+ use_mean_pooling=True, init_scale=0.001, grad_checkpointing=False, xattn=False, postnorm=False,
515
+ pt_hw_seq_len=16, intp_freq=False, naiveswiglu=False, subln=False):
516
+ super().__init__()
517
+ self.image_size = img_size
518
+ self.num_classes = num_classes
519
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
520
+
521
+ self.patch_embed = PatchEmbed(
522
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
523
+ num_patches = self.patch_embed.num_patches
524
+
525
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
526
+ # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
527
+ if use_abs_pos_emb:
528
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
529
+ else:
530
+ self.pos_embed = None
531
+ self.pos_drop = nn.Dropout(p=drop_rate)
532
+
533
+ if use_shared_rel_pos_bias:
534
+ self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
535
+ else:
536
+ self.rel_pos_bias = None
537
+
538
+ if rope:
539
+ half_head_dim = embed_dim // num_heads // 2
540
+ hw_seq_len = img_size // patch_size
541
+ self.rope = VisionRotaryEmbeddingFast(
542
+ dim=half_head_dim,
543
+ pt_seq_len=pt_hw_seq_len,
544
+ ft_seq_len=hw_seq_len if intp_freq else None,
545
+ # patch_dropout=patch_dropout
546
+ )
547
+ else:
548
+ self.rope = None
549
+
550
+ self.naiveswiglu = naiveswiglu
551
+
552
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
553
+ self.use_rel_pos_bias = use_rel_pos_bias
554
+ self.blocks = nn.ModuleList([
555
+ Block(
556
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
557
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
558
+ init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
559
+ xattn=xattn, rope=self.rope, postnorm=postnorm, subln=subln, naiveswiglu=naiveswiglu)
560
+ for i in range(depth)])
561
+ self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
562
+ self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
563
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
564
+
565
+ if self.pos_embed is not None:
566
+ trunc_normal_(self.pos_embed, std=.02)
567
+
568
+ trunc_normal_(self.cls_token, std=.02)
569
+ # trunc_normal_(self.mask_token, std=.02)
570
+
571
+ self.apply(self._init_weights)
572
+ self.fix_init_weight()
573
+
574
+ if isinstance(self.head, nn.Linear):
575
+ trunc_normal_(self.head.weight, std=.02)
576
+ self.head.weight.data.mul_(init_scale)
577
+ self.head.bias.data.mul_(init_scale)
578
+
579
+ # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
580
+ self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
581
+
582
+ self.grad_checkpointing = grad_checkpointing
583
+
584
+ def fix_init_weight(self):
585
+ def rescale(param, layer_id):
586
+ param.div_(math.sqrt(2.0 * layer_id))
587
+
588
+ for layer_id, layer in enumerate(self.blocks):
589
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
590
+ if self.naiveswiglu:
591
+ rescale(layer.mlp.w3.weight.data, layer_id + 1)
592
+ else:
593
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
594
+
595
+ def get_cast_dtype(self) -> torch.dtype:
596
+ return self.blocks[0].mlp.fc2.weight.dtype
597
+
598
+ def _init_weights(self, m):
599
+ if isinstance(m, nn.Linear):
600
+ trunc_normal_(m.weight, std=.02)
601
+ if m.bias is not None:
602
+ nn.init.constant_(m.bias, 0)
603
+ elif isinstance(m, nn.LayerNorm):
604
+ nn.init.constant_(m.bias, 0)
605
+ nn.init.constant_(m.weight, 1.0)
606
+
607
+ def get_num_layers(self):
608
+ return len(self.blocks)
609
+
610
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
611
+ assert unlocked_groups == 0, 'partial locking not currently supported for this model'
612
+ for param in self.parameters():
613
+ param.requires_grad = False
614
+
615
+ @torch.jit.ignore
616
+ def set_grad_checkpointing(self, enable=True):
617
+ self.grad_checkpointing = enable
618
+
619
+ @torch.jit.ignore
620
+ def no_weight_decay(self):
621
+ return {'pos_embed', 'cls_token'}
622
+
623
+ def get_classifier(self):
624
+ return self.head
625
+
626
+ def reset_classifier(self, num_classes, global_pool=''):
627
+ self.num_classes = num_classes
628
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
629
+
630
+ def forward_features(self, x, return_all_features=False):
631
+
632
+ x = self.patch_embed(x)
633
+ batch_size, seq_len, _ = x.size()
634
+
635
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
636
+ x = torch.cat((cls_tokens, x), dim=1)
637
+ if self.pos_embed is not None:
638
+ x = x + self.pos_embed
639
+ x = self.pos_drop(x)
640
+
641
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
642
+ if os.getenv('RoPE') == '1':
643
+ if self.training and not isinstance(self.patch_dropout, nn.Identity):
644
+ x, patch_indices_keep = self.patch_dropout(x)
645
+ self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep)
646
+ else:
647
+ self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
648
+ x = self.patch_dropout(x)
649
+ else:
650
+ x = self.patch_dropout(x)
651
+
652
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
653
+ for i, blk in enumerate(self.blocks):
654
+ if i == len(self.blocks)-1:
655
+ continue
656
+ if self.grad_checkpointing:
657
+ x = checkpoint(blk, x, (rel_pos_bias,))
658
+ else:
659
+ x = blk(x, rel_pos_bias=rel_pos_bias)
660
+
661
+ if not return_all_features:
662
+ x = self.norm(x)
663
+ if self.fc_norm is not None:
664
+ return self.fc_norm(x.mean(1))
665
+ else:
666
+ return x[:, 0]
667
+ return x
668
+
669
+ def forward(self, x, return_all_features=False):
670
+ if return_all_features:
671
+ return self.forward_features(x, return_all_features)
672
+ x = self.forward_features(x)
673
+ x = self.head(x)
674
+ return x
675
+
676
+ class LayerNorm(nn.LayerNorm):
677
+ """Subclass torch's LayerNorm (with cast back to input dtype)."""
678
+
679
+ def forward(self, x: torch.Tensor):
680
+ orig_type = x.dtype
681
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
682
+ return x.to(orig_type)
683
+
684
+ try:
685
+ from apex.normalization import FusedLayerNorm
686
+ except:
687
+ FusedLayerNorm = LayerNorm
688
+ print("Please 'pip install apex'")
689
+
690
+
691
+ @dataclass
692
+ class CLIPVisionCfg:
693
+ layers: Union[Tuple[int, int, int, int], int] = 12
694
+ width: int = 768
695
+ head_width: int = 64
696
+ mlp_ratio: float = 4.0
697
+ patch_size: int = 16
698
+ image_size: Union[Tuple[int, int], int] = 224
699
+ ls_init_value: Optional[float] = None # layer scale initial value
700
+ patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
701
+ global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
702
+ drop_path_rate: Optional[float] = None # drop path rate
703
+ timm_model_name: str = None # a valid model name overrides layers, width, patch_size
704
+ timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
705
+ timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
706
+ timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
707
+ timm_proj_bias: bool = False # enable bias final projection
708
+ eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size
709
+ qkv_bias: bool = True
710
+ fusedLN: bool = False
711
+ xattn: bool = False
712
+ postnorm: bool = False
713
+ rope: bool = False
714
+ pt_hw_seq_len: int = 16 # 224/14
715
+ intp_freq: bool = False
716
+ naiveswiglu: bool = False
717
+ subln: bool = False
718
+
719
+
720
+ def _build_vision_tower(
721
+ embed_dim: int,
722
+ vision_cfg: CLIPVisionCfg
723
+ ):
724
+ if isinstance(vision_cfg, dict):
725
+ vision_cfg = CLIPVisionCfg(**vision_cfg)
726
+
727
+ if vision_cfg.eva_model_name:
728
+ vision_heads = vision_cfg.width // vision_cfg.head_width
729
+ norm_layer = LayerNorm
730
+ visual = EVAVisionTransformer(
731
+ img_size=vision_cfg.image_size,
732
+ patch_size=vision_cfg.patch_size,
733
+ num_classes=embed_dim,
734
+ use_mean_pooling=vision_cfg.global_average_pool, #False
735
+ init_values=vision_cfg.ls_init_value,
736
+ patch_dropout=vision_cfg.patch_dropout,
737
+ embed_dim=vision_cfg.width,
738
+ depth=vision_cfg.layers,
739
+ num_heads=vision_heads,
740
+ mlp_ratio=vision_cfg.mlp_ratio,
741
+ qkv_bias=vision_cfg.qkv_bias,
742
+ drop_path_rate=vision_cfg.drop_path_rate,
743
+ norm_layer= partial(FusedLayerNorm, eps=1e-6) if vision_cfg.fusedLN else partial(norm_layer, eps=1e-6),
744
+ xattn=vision_cfg.xattn,
745
+ rope=vision_cfg.rope,
746
+ postnorm=vision_cfg.postnorm,
747
+ pt_hw_seq_len= vision_cfg.pt_hw_seq_len, # 224/14
748
+ intp_freq= vision_cfg.intp_freq,
749
+ naiveswiglu= vision_cfg.naiveswiglu,
750
+ subln= vision_cfg.subln
751
+ )
752
+
753
+ return visual
754
+
755
+ class Eva2LargeEncoder(nn.Module):
756
+ def __init__(self, image_size=224):
757
+ super(Eva2LargeEncoder, self).__init__()
758
+ self.config = {
759
+ "embed_dim": 768,
760
+ "vision_cfg": {
761
+ "image_size": 336,
762
+ "layers": 24,
763
+ "width": 1024,
764
+ "drop_path_rate": 0,
765
+ "head_width": 64,
766
+ "mlp_ratio": 2.6667,
767
+ "patch_size": 14,
768
+ "eva_model_name": "eva-clip-l-14-336",
769
+ "xattn": True,
770
+ "fusedLN": True,
771
+ "rope": True,
772
+ "pt_hw_seq_len": 16,
773
+ "intp_freq": True,
774
+ "naiveswiglu": True,
775
+ "subln": True
776
+ }
777
+ }
778
+ self.config['vision_cfg']['image_size'] = image_size
779
+
780
+ import os
781
+ os.environ['delRoPE'] = '1' # to avoid error in rope params when changing image size
782
+ self.model = _build_vision_tower(**self.config)
783
+
784
+
785
+ def forward(self, images):
786
+ encode = self.model(images, return_all_features=True)[:, 1:, :]
787
+ return encode
788
+
789
+ class CrossVisionModel(nn.Module):
790
+ def __init__(self, config):
791
+ super().__init__()
792
+ self.vit = Eva2LargeEncoder(image_size=config.cross_image_size)
793
+ self.pos_embed = nn.Parameter(torch.zeros((self.vit.config['vision_cfg']['image_size'] // self.vit.config['vision_cfg']['patch_size']) ** 2, self.vit.config['vision_cfg']['width']))
794
+
795
+ def forward(self, images):
796
+ enc = self.vit(images)
797
+ return enc + self.pos_embed.to(enc.device).unsqueeze(0)
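The rotary embedding at the top of this file is self-contained and easy to sanity-check in isolation. A shape-check sketch (hypothetical sizes; importing cross_visual requires torch, einops, timm and xformers to be installed):

# VisionRotaryEmbeddingFast shape check: dim is half the per-head dimension,
# pt_seq_len is the pretraining patch-grid side (16 for 224/14).
import torch
from cross_visual import VisionRotaryEmbeddingFast

rope = VisionRotaryEmbeddingFast(dim=32, pt_seq_len=16)   # freqs_cos/sin: (256, 64)
q = torch.randn(2, 8, 256, 64)                            # (batch, heads, patches, head_dim)
print(rope(q).shape)                                      # torch.Size([2, 8, 256, 64])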
generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 1,
+   "eos_token_id": 2,
+   "pad_token_id": 0,
+   "transformers_version": "4.41.0.dev0"
+ }
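These defaults are picked up automatically by generate(), but they can also be read and overridden explicitly. A brief sketch (assumed local path, not part of the commit):

# Read the generation defaults shipped with the checkpoint.
from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained("./cogagent-8bit")  # hypothetical path
print(gen_cfg.eos_token_id)  # 2
# outputs = model.generate(**inputs, generation_config=gen_cfg, max_new_tokens=256)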
modeling_cogagent.py ADDED
@@ -0,0 +1,974 @@
1
+ """Largely copied from LLaMA and adapted for CogAgent."""
2
+ import warnings
3
+ from typing import TYPE_CHECKING, Optional, Tuple, List, Union, Literal, Dict, Any
4
+
5
+ import math
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import CrossEntropyLoss
9
+ from torchvision import transforms
10
+ from einops import rearrange
11
+
12
+ from transformers import PreTrainedModel, PreTrainedTokenizer
13
+ from transformers.utils.logging import get_logger
14
+ from transformers.activations import ACT2FN
15
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
16
+
17
+ from .configuration_cogagent import CogAgentConfig
18
+ # from .util import FastRotaryEmbedding
19
+ from torch.nn import functional as F
20
+ from .visual import EVA2CLIPModel
21
+ from .cross_visual import CrossVisionModel
22
+
23
+ if TYPE_CHECKING:
24
+ from transformers.utils import ModelOutput
25
+
26
+ logger = get_logger(__name__)
27
+
28
+ LANGUAGE_TOKEN_TYPE = 0
29
+ VISION_TOKEN_TYPE = 1
30
+
31
+
32
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
33
+ def _make_causal_mask(
34
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
35
+ ):
36
+ """
37
+ Make causal mask used for bi-directional self-attention.
38
+ """
39
+ bsz, tgt_len = input_ids_shape
40
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
41
+ mask_cond = torch.arange(mask.size(-1), device=device)
42
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
43
+ mask = mask.to(dtype)
44
+
45
+ if past_key_values_length > 0:
46
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
47
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
48
+
49
+
50
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
51
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
52
+ """
53
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
54
+ """
55
+ bsz, src_len = mask.size()
56
+ tgt_len = tgt_len if tgt_len is not None else src_len
57
+
58
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
59
+
60
+ inverted_mask = 1.0 - expanded_mask
61
+
62
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
63
+
64
+
65
+ class RMSNorm(nn.Module):
66
+ def __init__(self, hidden_size, eps=1e-6):
67
+ super().__init__()
68
+ self.weight = nn.Parameter(torch.ones(hidden_size))
69
+ self.variance_epsilon = eps
70
+
71
+ def forward(self, hidden_states):
72
+ input_dtype = hidden_states.dtype
73
+ hidden_states = hidden_states.to(torch.float32)
74
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
75
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
76
+ return (self.weight * hidden_states).to(input_dtype)
77
+
78
+
79
+ class MLP(nn.Module):
80
+ def __init__(self, config):
81
+ super().__init__()
82
+ self.hidden_size = config.hidden_size
83
+ self.intermediate_size = config.intermediate_size
84
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
85
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
86
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
87
+ self.act_fn = ACT2FN[config.hidden_act]
88
+
89
+ def forward(self, x):
90
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
91
+ return down_proj
92
+
93
+
94
+ def get_expert_mask(token_type_ids: "torch.LongTensor(B, L)") -> "[torch.BoolTensor(B, L), torch.BoolTensor(B, L)]":
95
+ vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool)
96
+ vision_token_mask[:, :-1] = (token_type_ids[:, :-1] == VISION_TOKEN_TYPE) & (token_type_ids[:, 1:] == VISION_TOKEN_TYPE)
97
+ language_token_mask = ~vision_token_mask
98
+ return vision_token_mask, language_token_mask
99
+
100
+
101
+ class VisionExpertMLP(nn.Module):
102
+ def __init__(self, config):
103
+ super().__init__()
104
+ self.language_mlp = MLP(config)
105
+ self.vision_mlp = MLP(config)
106
+
107
+ def forward(self, hidden_states: "torch.Tensor(B, L, D)", token_type_ids: "torch.LongTensor(B, L)"):
108
+ output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device)
109
+ vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
110
+ output[vision_token_mask] = self.vision_mlp(hidden_states[vision_token_mask])
111
+ output[language_token_mask] = self.language_mlp(hidden_states[language_token_mask])
112
+ return output
113
+
114
+
115
+ def attention_fn(
116
+ query_layer: "torch.tensor(B, H, L, HD)",
117
+ key_layer: "torch.tensor(B, H, L, HD)",
118
+ value_layer: "torch.tensor(B, H, L, HD)",
119
+ attention_mask: "torch.tensor(B, H, L, HD)",
120
+ *,
121
+ scaling_attention_score: bool = True,
122
+ attention_dropout: nn.Module = None
123
+ ):
124
+ attention_mask_bool = (attention_mask == 0)
125
+ is_low_triangle = (attention_mask_bool == torch.ones_like(attention_mask_bool, dtype=torch.float).tril()).all()
126
+ is_full = (attention_mask_bool > 0).all()
127
+ if not (int(torch.__version__.split('.')[0]) >= 2):
128
+ warnings.warn("It's recommended to use torch2.0 or higher.")
129
+ if int(torch.__version__.split('.')[0]) >= 2 and scaling_attention_score and (is_full or is_low_triangle):
130
+ dropout_p = 0. if attention_dropout is None or not attention_dropout.training else attention_dropout.p
131
+ return torch.nn.functional.scaled_dot_product_attention(
132
+ query_layer, key_layer, value_layer,
133
+ attn_mask=None,
134
+ dropout_p=dropout_p,
135
+ is_causal=not is_full
136
+ )
137
+ else:
138
+ if scaling_attention_score:
139
+ query_layer = query_layer / math.sqrt(query_layer.shape[-1])
140
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
141
+ attention_scores = attention_scores + attention_mask
142
+ attention_scores = nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32).to(query_layer.dtype)
143
+ if attention_dropout is not None:
144
+ attention_scores = attention_dropout(attention_scores)
145
+ context_layer = torch.matmul(attention_scores, value_layer)
146
+ return context_layer
147
+
148
+ class RotaryEmbedding(torch.nn.Module):
149
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
150
+ super().__init__()
151
+
152
+ self.dim = dim
153
+ self.max_position_embeddings = max_position_embeddings
154
+ self.base = base
155
+ inv_freq = self._compute_inv_freq(device)
156
+ self.register_buffer("inv_freq", inv_freq)
157
+ self.max_seq_len_cached = 0
158
+
159
+ def _compute_inv_freq(self, device=None):
160
+ return 1.0 / (
161
+ self.base
162
+ ** (torch.arange(0, self.dim, 2, device=device) / self.dim)
163
+ )
164
+
165
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
166
+ self.max_seq_len_cached = seq_len
167
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
168
+
169
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
170
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
171
+ emb = torch.cat((freqs, freqs), dim=-1)
172
+ self.register_buffer("cos_cached", emb.cos()[:, None, :].to(dtype), persistent=False)
173
+ self.register_buffer("sin_cached", emb.sin()[:, None, :].to(dtype), persistent=False)
174
+
175
+ def forward(self, x, seq_len):
176
+ # x: [bs, num_attention_heads, seq_len, head_size]
177
+ if seq_len > self.max_seq_len_cached:
178
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
179
+
180
+ return (
181
+ self.cos_cached[:seq_len, ...].to(dtype=x.dtype),
182
+ self.sin_cached[:seq_len, ...].to(dtype=x.dtype),
183
+ )
184
+
185
+
186
+ def rotate_half(x):
187
+ x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
188
+ return torch.cat((-x2, x1), dim=x1.ndim - 1)
189
+
190
+
191
+ def apply_rotary_pos_emb_index_bhs(q, k, cos, sin, position_id):
192
+ # batch_size, num_head, seq_len, hidden_size
193
+ cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(1), \
194
+ F.embedding(position_id, sin.squeeze(1)).unsqueeze(1)
195
+ q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
196
+ return q, k
197
+
198
+ class VisionExpertAttention(nn.Module):
199
+ def __init__(self, config):
200
+ super().__init__()
201
+ self.config = config
202
+ self.hidden_size = config.hidden_size
203
+ self.num_heads = config.num_attention_heads
204
+ self.head_dim = self.hidden_size // self.num_heads
205
+ self.max_position_embeddings = config.max_position_embeddings
206
+
207
+ self.rotary_emb = RotaryEmbedding(self.head_dim)
208
+ self.vision_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
209
+ self.vision_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
210
+ self.language_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
211
+ self.language_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
212
+
213
+ def _transpose_for_scores(self, tensor):
214
+ """Transpose a 3D tensor [B, L, H*HD] into a 4D tensor with size [B H L HD]."""
215
+ new_tensor_shape = tensor.size()[:-1] + (self.num_heads, self.head_dim)
216
+ tensor = tensor.view(*new_tensor_shape)
217
+ return tensor.permute(0, 2, 1, 3)
218
+
219
+ def forward(
220
+ self,
221
+ hidden_states: torch.Tensor,
222
+ token_type_ids: torch.LongTensor,
223
+ position_ids: torch.LongTensor,
224
+ attention_mask: Optional[torch.Tensor] = None,
225
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
226
+ output_attentions: bool = False,
227
+ use_cache: bool = False,
228
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
229
+ bsz, q_len, _ = hidden_states.size()
230
+ vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)
231
+
232
+ shape = list(hidden_states.shape)
233
+ shape[-1] = shape[-1] * 3
234
+ mixed_raw_layer = torch.empty(shape, dtype=hidden_states.dtype, device=hidden_states.device)
235
+ mixed_raw_layer[vision_token_mask] = self.vision_expert_query_key_value(hidden_states[vision_token_mask])
236
+ mixed_raw_layer[language_token_mask] = self.language_expert_query_key_value(hidden_states[language_token_mask])
237
+
238
+ query_states, key_states, value_states = torch.split(mixed_raw_layer, self.hidden_size, dim=-1)
239
+ query_states = self._transpose_for_scores(query_states) # B, H, L, HD
240
+ key_states = self._transpose_for_scores(key_states) # B, H, L, HD
241
+ value_states = self._transpose_for_scores(value_states) # B, H, L, HD
242
+
243
+ kv_seq_len = key_states.shape[-2]
244
+ if past_key_value is not None:
245
+ kv_seq_len += past_key_value[0].shape[-2]
246
+
247
+ cos, sin = self.rotary_emb(value_states, seq_len=position_ids.max() + 1)
248
+ query_states, key_states = apply_rotary_pos_emb_index_bhs(query_states, key_states, cos, sin, position_ids)
249
+
250
+ if past_key_value is not None:
251
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
252
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
253
+
254
+ past_key_value = (key_states, value_states) if use_cache else None
255
+
256
+ context_layer = attention_fn(
257
+ query_layer=query_states, key_layer=key_states, value_layer=value_states, attention_mask=attention_mask,
258
+ scaling_attention_score=True, attention_dropout=None)
259
+ if context_layer.size() != (bsz, self.num_heads, q_len, self.head_dim):
260
+ raise ValueError(
261
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
262
+ f" {context_layer.size()}"
263
+ )
264
+ context_layer = context_layer.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
265
+
266
+ attn_output = torch.empty(context_layer.shape, dtype=hidden_states.dtype, device=hidden_states.device)
267
+ attn_output[vision_token_mask] = self.vision_expert_dense(context_layer[vision_token_mask])
268
+ attn_output[language_token_mask] = self.language_expert_dense(context_layer[language_token_mask])
269
+
270
+ if output_attentions:
271
+ warnings.warn("output_attentions is not implemented.")
272
+
273
+ return attn_output, None, past_key_value
274
+
275
+ class CrossAttention(nn.Module):
276
+ def __init__(self, config):
277
+ super().__init__()
278
+ self.config = config
279
+ self.hidden_size = config.hidden_size
280
+ self.cross_hidden_size = config.cross_hidden_size
281
+ self.cross_compute_hidden_size = config.cross_compute_hidden_size
282
+ self.num_heads = config.num_attention_heads
283
+ self.head_dim = self.hidden_size // self.num_heads
284
+ self.cross_head_dim = self.cross_compute_hidden_size // self.num_heads
285
+ self.max_position_embeddings = config.max_position_embeddings
286
+
287
+ self.query = nn.Linear(self.hidden_size, self.cross_compute_hidden_size, bias=False)
288
+ self.key_value = nn.Linear(self.cross_hidden_size, self.cross_compute_hidden_size * 2, bias=False)
289
+ self.dense = nn.Linear(self.cross_compute_hidden_size, self.hidden_size, bias=False)
290
+
291
+ def _transpose_for_scores(self, tensor):
292
+ """Transpose a 3D tensor [B, L, H*HD] into a 4D tensor with size [B H L HD]."""
293
+ new_tensor_shape = tensor.size()[:-1] + (self.num_heads, self.cross_head_dim)
294
+ tensor = tensor.view(*new_tensor_shape)
295
+ return tensor.permute(0, 2, 1, 3)
296
+
297
+ def forward(
298
+ self,
299
+ hidden_states: torch.Tensor,
300
+ encoder_outputs: torch.LongTensor,
301
+ attention_mask: Optional[torch.Tensor] = None,
302
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
303
+ output_attentions: bool = False,
304
+ use_cache: bool = False,
305
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
306
+ bsz, q_len, _ = hidden_states.size()
307
+
308
+ shape = list(hidden_states.shape)
309
+ shape[-1] = shape[-1] * 3
310
+
311
+ mixed_query_layer = self.query(hidden_states)
312
+ if past_key_value is None:
313
+ mixed_x_layer = self.key_value(encoder_outputs)
314
+ mixed_key_layer, mixed_value_layer = torch.split(mixed_x_layer, self.cross_compute_hidden_size, dim=-1)
315
+ key_states = self._transpose_for_scores(mixed_key_layer) # B, H, L, HD
316
+ value_states = self._transpose_for_scores(mixed_value_layer) # B, H, L, HD
317
+ else:
318
+ key_states, value_states = past_key_value
319
+
320
+ query_states = self._transpose_for_scores(mixed_query_layer) # B, H, L, HD
321
+
322
+ past_key_value = (key_states, value_states) if use_cache else None
323
+
324
+ context_layer = attention_fn(
325
+ query_layer=query_states, key_layer=key_states, value_layer=value_states, attention_mask=attention_mask,
326
+ scaling_attention_score=True, attention_dropout=None)
327
+ if context_layer.size() != (bsz, self.num_heads, q_len, self.cross_head_dim):
328
+ raise ValueError(
329
+ f"`cross_attn_output` should be of size {(bsz, self.num_heads, q_len, self.cross_head_dim)}, but is"
330
+ f" {context_layer.size()}"
331
+ )
332
+ context_layer = context_layer.transpose(1, 2).contiguous().reshape(bsz, q_len, self.cross_hidden_size)
333
+
334
+ attn_output = self.dense(context_layer)
335
+
336
+ if output_attentions:
337
+ warnings.warn("output_attentions is not implemented.")
338
+
339
+ return attn_output, None, past_key_value
340
+
341
+ class CogAgentDecoderLayer(nn.Module):
342
+ def __init__(self, config):
343
+ super().__init__()
344
+ self.hidden_size = config.hidden_size
345
+ self.self_attn = VisionExpertAttention(config=config)
346
+ self.cross_attn = CrossAttention(config=config)
347
+ self.mlp = VisionExpertMLP(config)
348
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
349
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
350
+ self.post_cross_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
351
+
352
+ def forward(
353
+ self,
354
+ hidden_states: torch.Tensor,
355
+ encoder_outputs: torch.Tensor,
356
+ token_type_ids: torch.LongTensor,
357
+ position_ids: torch.LongTensor,
358
+ attention_mask: Optional[torch.Tensor] = None,
359
+ cross_attention_mask: Optional[torch.Tensor] = None,
360
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
361
+ output_attentions: Optional[bool] = False,
362
+ use_cache: Optional[bool] = False,
363
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
364
+ residual = hidden_states
365
+
366
+ hidden_states = self.input_layernorm(hidden_states)
367
+
368
+ # Self Attention
369
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
370
+ hidden_states=hidden_states,
371
+ token_type_ids=token_type_ids,
372
+ position_ids=position_ids,
373
+ attention_mask=attention_mask,
374
+ past_key_value=past_key_value[:2] if past_key_value is not None else None,
375
+ output_attentions=output_attentions,
376
+ use_cache=use_cache,
377
+ )
378
+ hidden_states = residual + hidden_states
379
+
380
+ cross_input = self.post_cross_attention_layernorm(hidden_states)
381
+ # Fully Connected
382
+ attention_output, self_cross_attn_weights, present_cross_key_value = self.cross_attn(
383
+ hidden_states=cross_input,
384
+ encoder_outputs=encoder_outputs,
385
+ attention_mask=cross_attention_mask,
386
+ past_key_value=past_key_value[-2:] if past_key_value is not None else None,
387
+ output_attentions=output_attentions,
388
+ use_cache=use_cache,
389
+ )
390
+ hidden_states = hidden_states + attention_output
391
+ mlp_input = self.post_attention_layernorm(hidden_states)
392
+ mlp_output = self.mlp(mlp_input, token_type_ids=token_type_ids)
393
+ hidden_states = mlp_output + hidden_states
394
+
395
+ outputs = (hidden_states,)
396
+
397
+ if output_attentions:
398
+ outputs += (self_attn_weights,)
399
+
400
+ if use_cache:
401
+ outputs += (present_key_value+present_cross_key_value,)
402
+
403
+ return outputs # type: ignore
404
+
405
+
406
+ class CogAgentPreTrainedModel(PreTrainedModel):
407
+ config_class = CogAgentConfig
408
+ base_model_prefix = "model"
409
+ supports_gradient_checkpointing = False
410
+ _no_split_modules = ["CogAgentDecoderLayer", 'TransformerLayer', 'Block']
411
+ _skip_keys_device_placement = "past_key_values"
412
+
413
+ def _init_weights(self, module):
414
+ std = self.config.initializer_range
415
+ if isinstance(module, nn.Linear):
416
+ module.weight.data.normal_(mean=0.0, std=std)
417
+ if module.bias is not None:
418
+ module.bias.data.zero_()
419
+ elif isinstance(module, nn.Embedding):
420
+ module.weight.data.normal_(mean=0.0, std=std)
421
+ if module.padding_idx is not None:
422
+ module.weight.data[module.padding_idx].zero_()
423
+
424
+
425
+ def is_empty(images_list: Optional[List[List[torch.Tensor]]]):
426
+ if images_list is None or len(images_list) == 0:
427
+ return True
428
+ for image_list in images_list:
429
+ if len(image_list):
430
+ return False
431
+ return True
432
+
433
+
434
+ def build_position_ids(x: "torch.BoolTensor(B, L)", attention_mask: Optional["torch.BoolTensor(B, L)"] = None) -> "torch.LongTensor(B, L)":
435
+ if attention_mask is not None:
436
+ tmp = x.clone()
437
+ tmp[~(attention_mask.bool())] = -1
438
+ else:
439
+ tmp = x.clone()
440
+ # treat the image boi/eoi tokens as LANGUAGE_TOKEN_TYPE
441
+ is_boi_eoi = torch.zeros_like(x, dtype=torch.bool)
442
+ is_boi_eoi[:, 1:] |= (tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE)
443
+ is_boi_eoi[:, 0] |= (tmp[:, 0] == VISION_TOKEN_TYPE)
444
+ is_boi_eoi[:, :-1] |= (tmp[:, :-1] == VISION_TOKEN_TYPE) & (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE)
445
+ is_boi_eoi[:, -1] |= (tmp[:, -1] == VISION_TOKEN_TYPE)
446
+ tmp[is_boi_eoi] = LANGUAGE_TOKEN_TYPE
447
+ # final position ids
448
+ y = torch.zeros_like(x, dtype=torch.long)
449
+ y[:, 1:] = (tmp[:, 1:] == LANGUAGE_TOKEN_TYPE) | ((tmp[:, 1:] == VISION_TOKEN_TYPE) & (tmp[:, :-1] == LANGUAGE_TOKEN_TYPE))
450
+ y = y.cumsum(dim=-1)
451
+ return y
452
+
453
+
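+ # A worked example of build_position_ids (a sketch, assuming LANGUAGE_TOKEN_TYPE == 0 and
+ # VISION_TOKEN_TYPE == 1 as defined earlier in this file): the boi/eoi placeholders around an
+ # image advance the position counter, while all patch tokens in between share a single position.
+ #   token_type_ids = torch.tensor([[0, 1, 1, 1, 1, 1, 0, 0]])  # bos, boi, 3 patches, eoi, 2 text tokens
+ #   build_position_ids(token_type_ids)                         # -> tensor([[0, 1, 2, 2, 2, 3, 4, 5]])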
454
+ class CogAgentModel(CogAgentPreTrainedModel):
455
+ def __init__(self, config):
456
+ super().__init__(config)
457
+ self.padding_idx = config.pad_token_id
458
+ self.vocab_size = config.vocab_size
459
+
460
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
461
+ self.layers = nn.ModuleList([CogAgentDecoderLayer(config) for _ in range(config.num_hidden_layers)])
462
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
463
+
464
+ self.vision = EVA2CLIPModel(config)
465
+ self.cross_vision = CrossVisionModel(config)
466
+
467
+ self.gradient_checkpointing = False
468
+ # Initialize weights and apply final processing
469
+ self.post_init()
470
+
471
+ def encode_images(self, images: List[List[torch.Tensor]]) -> torch.Tensor:
472
+ images_list, images = images, []
473
+
474
+ # flatten the nested per-sample image lists into a single batch
475
+ for image_list in images_list:
476
+ for image in image_list:
477
+ images.append(image)
478
+
479
+ images = torch.stack(images)
480
+ images_features = self.vision(images)
481
+ return images_features
482
+
483
+ def encode_cross_images(self, images: List[List[torch.Tensor]]) -> torch.Tensor:
484
+ images_list, images = images, []
485
+
486
+ # flatten the nested per-sample image lists into a single batch
487
+ for image_list in images_list:
488
+ for image in image_list:
489
+ images.append(image)
490
+
491
+ images = torch.stack(images)
492
+ encoder_outputs = self.cross_vision(images)
493
+ return encoder_outputs
494
+
495
+ def forward(
496
+ self,
497
+ input_ids: torch.LongTensor = None,
498
+ images: List[List[torch.Tensor]] = None,
499
+ cross_images: List[List[torch.Tensor]] = None,
500
+ token_type_ids: Optional[torch.LongTensor] = None,
501
+ attention_mask: Optional[torch.Tensor] = None,
502
+ cross_attention_mask: Optional[torch.Tensor] = None,
503
+ position_ids: Optional[torch.LongTensor] = None,
504
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
505
+ inputs_embeds: Optional[torch.FloatTensor] = None,
506
+ use_cache: Optional[bool] = None,
507
+ output_attentions: Optional[bool] = None,
508
+ output_hidden_states: Optional[bool] = None,
509
+ return_dict: Optional[bool] = None,
510
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
511
+ """take care of image_encode, token_type_ids, position_ids and (attention_mask = None is fine)"""
512
+
513
+ if past_key_values is not None:
514
+ encoder_outputs = None
515
+ # generation step with past_key_values: image features were already injected in the first forward pass
516
+ else:
517
+ # inputs_embeds is not allowed here, because image features must be merged into the token embeddings
518
+ assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}"
519
+ if not is_empty(images): # multi-modality
520
+ assert token_type_ids is not None, "multi-modality requires `token_type_ids`!"
521
+ assert len(input_ids) == len(images), f"{len(input_ids)} {len(images)}"
522
+ inputs_embeds = self.embed_tokens(input_ids)
523
+ images_features = self.encode_images(images)
524
+ encoder_outputs = self.encode_cross_images(cross_images)
525
+ images_features = rearrange(images_features, 'b n d -> (b n) d')
526
+ images_features = images_features.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
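+ # scatter the projected patch features into the pad-token slots marked VISION_TOKEN_TYPE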
527
+ inputs_embeds = inputs_embeds.index_put([token_type_ids == VISION_TOKEN_TYPE], images_features)
528
+ else: # single-modality
529
+ if token_type_ids is None:
530
+ token_type_ids = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device) * LANGUAGE_TOKEN_TYPE
531
+ assert not (token_type_ids == VISION_TOKEN_TYPE).any(), f"{(token_type_ids == VISION_TOKEN_TYPE).sum()}"
532
+ inputs_embeds = self.embed_tokens(input_ids)
533
+ encoder_outputs = None
534
+
535
+ if position_ids is None:
536
+ position_ids = build_position_ids(token_type_ids, attention_mask)
537
+ input_ids = None
538
+
539
+ return self.llm_forward(
540
+ input_ids=input_ids,
541
+ encoder_outputs=encoder_outputs,
542
+ token_type_ids=token_type_ids,
543
+ attention_mask=attention_mask,
544
+ cross_attention_mask=cross_attention_mask,
545
+ position_ids=position_ids,
546
+ past_key_values=past_key_values,
547
+ inputs_embeds=inputs_embeds,
548
+ use_cache=use_cache,
549
+ output_attentions=output_attentions,
550
+ output_hidden_states=output_hidden_states,
551
+ return_dict=return_dict,
552
+ )
553
+
554
+ def llm_forward(
555
+ self,
556
+ input_ids: torch.LongTensor = None,
557
+ encoder_outputs: torch.LongTensor = None,
558
+ token_type_ids: torch.LongTensor = None,
559
+ attention_mask: Optional[torch.Tensor] = None,
560
+ cross_attention_mask: Optional[torch.Tensor] = None,
561
+ position_ids: Optional[torch.LongTensor] = None,
562
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
563
+ inputs_embeds: Optional[torch.FloatTensor] = None,
564
+ use_cache: Optional[bool] = None,
565
+ output_attentions: Optional[bool] = None,
566
+ output_hidden_states: Optional[bool] = None,
567
+ return_dict: Optional[bool] = None,
568
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
569
+ """largely copy from llama forward and adapt for CogAgent with `token_type_ids`"""
570
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
571
+ output_hidden_states = (
572
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
573
+ )
574
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
575
+
576
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
577
+
578
+ # retrieve input_ids and inputs_embeds
579
+ if input_ids is not None and inputs_embeds is not None:
580
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
581
+ elif input_ids is not None:
582
+ batch_size, seq_length = input_ids.shape
583
+ elif inputs_embeds is not None:
584
+ batch_size, seq_length, _ = inputs_embeds.shape
585
+ else:
586
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
587
+
588
+ seq_length_with_past = seq_length
589
+ past_key_values_length = 0
590
+
591
+ if past_key_values is not None:
592
+ past_key_values_length = past_key_values[0][0].shape[2]
593
+ seq_length_with_past = seq_length_with_past + past_key_values_length
594
+
595
+ if position_ids is None:
596
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
597
+ position_ids = torch.arange(
598
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
599
+ )
600
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
601
+ else:
602
+ position_ids = position_ids.view(-1, seq_length).long()
603
+
604
+ if inputs_embeds is None:
605
+ inputs_embeds = self.embed_tokens(input_ids)
606
+ # embed positions
607
+ if attention_mask is None:
608
+ attention_mask = torch.ones(
609
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
610
+ )
611
+ if cross_attention_mask is None:
612
+ cross_attention_mask = torch.ones(
613
+ (batch_size, 1), dtype=torch.bool, device=inputs_embeds.device
614
+ )
615
+ attention_mask = self._prepare_decoder_attention_mask(
616
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
617
+ )
618
+
619
+ hidden_states = inputs_embeds
620
+
621
+ # decoder layers
622
+ all_hidden_states = () if output_hidden_states else None
623
+ all_self_attns = () if output_attentions else None
624
+ next_decoder_cache = () if use_cache else None
625
+
626
+ for idx, decoder_layer in enumerate(self.layers):
627
+ if output_hidden_states:
628
+ all_hidden_states += (hidden_states,)
629
+
630
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
631
+ layer_outputs = decoder_layer(
632
+ hidden_states,
633
+ encoder_outputs=encoder_outputs,
634
+ token_type_ids=token_type_ids,
635
+ attention_mask=attention_mask,
636
+ cross_attention_mask=cross_attention_mask,
637
+ position_ids=position_ids,
638
+ past_key_value=past_key_value,
639
+ output_attentions=output_attentions,
640
+ use_cache=use_cache,
641
+ )
642
+ hidden_states = layer_outputs[0]
643
+
644
+ if use_cache:
645
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
646
+
647
+ if output_attentions:
648
+ all_self_attns += (layer_outputs[1],)
649
+
650
+ hidden_states = self.norm(hidden_states)
651
+
652
+ # add hidden states from the last decoder layer
653
+ if output_hidden_states:
654
+ all_hidden_states += (hidden_states,)
655
+
656
+ next_cache = next_decoder_cache if use_cache else None
657
+ if not return_dict:
658
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
659
+ return BaseModelOutputWithPast(
660
+ last_hidden_state=hidden_states,
661
+ past_key_values=next_cache,
662
+ hidden_states=all_hidden_states,
663
+ attentions=all_self_attns,
664
+ )
665
+
666
+ def get_input_embeddings(self):
667
+ return self.embed_tokens
668
+
669
+ def set_input_embeddings(self, value):
670
+ self.embed_tokens = value
671
+
672
+ # noinspection PyMethodMayBeStatic
673
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
674
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
675
+ # create causal mask
676
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
677
+ combined_attention_mask = None
678
+ if input_shape[-1] > 1:
679
+ combined_attention_mask = _make_causal_mask(
680
+ input_shape,
681
+ inputs_embeds.dtype,
682
+ device=inputs_embeds.device,
683
+ past_key_values_length=past_key_values_length,
684
+ )
685
+
686
+ if attention_mask is not None:
687
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
688
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
689
+ inputs_embeds.device
690
+ )
691
+ combined_attention_mask = (
692
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
693
+ )
694
+
695
+ return combined_attention_mask
696
+
697
+ def vqa_history_to_prompt(history, query):
698
+ # Only a single chat round is supported in vqa mode
699
+ prompt = "<EOI>Question: "
700
+ # for i, (old_query, response) in enumerate(history):
701
+ # prompt += old_query + " Short answer: " + response + " Question: "
702
+ prompt += query + " Short answer:"
703
+ return prompt
704
+
705
+ def chat_old_history_to_prompt(history, query):
706
+ prompt = "<EOI>Question: "
707
+ for i, (old_query, response) in enumerate(history):
708
+ prompt += old_query + " Answer: " + response + "\nQuestion: "
709
+ prompt += query + " Answer:"
710
+ return prompt
711
+
712
+ def chat_history_to_prompt(history, query):
713
+ prompt = " [INST] "
714
+ for i, (old_query, response) in enumerate(history):
715
+ prompt += old_query + " [/INST] " + response + " [INST] "
716
+ prompt += query + " [/INST] "
717
+ return prompt
718
+
719
+
720
+ def base_history_to_prompt(history, query):
721
+ prompt = query
722
+ return prompt
723
+
724
+
725
+ _history_to_prompt = {
726
+ "base": base_history_to_prompt,
727
+ "chat": chat_history_to_prompt,
728
+ "chat_old": chat_old_history_to_prompt,
729
+ "vqa": vqa_history_to_prompt
730
+ }
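+ # For illustration (a sketch with made-up strings), the "chat" template wraps turns in
+ # Llama-2 style [INST] tags:
+ #   chat_history_to_prompt([("Describe the screenshot", "It shows a login form")], "Where is the submit button?")
+ #   -> " [INST] Describe the screenshot [/INST] It shows a login form [INST] Where is the submit button? [/INST] "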
731
+
732
+
733
+ class CogAgentForCausalLM(CogAgentPreTrainedModel):
734
+ _auto_class = "AutoModelForCausalLM"
735
+
736
+ def __init__(self, config):
737
+ super().__init__(config)
738
+ self.model = CogAgentModel(config)
739
+ self.vocab_size = config.vocab_size
740
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
741
+
742
+ # Initialize weights and apply final processing
743
+ self.post_init()
744
+
745
+ def get_input_embeddings(self):
746
+ return self.model.embed_tokens
747
+
748
+ def set_input_embeddings(self, value):
749
+ self.model.embed_tokens = value
750
+
751
+ def get_output_embeddings(self):
752
+ return self.lm_head
753
+
754
+ def set_output_embeddings(self, new_embeddings):
755
+ self.lm_head = new_embeddings
756
+
757
+ def set_decoder(self, decoder):
758
+ self.model = decoder
759
+
760
+ def get_decoder(self):
761
+ return self.model
762
+
763
+ def forward(
764
+ self,
765
+ input_ids: torch.LongTensor = None,
766
+ images: List[List[torch.Tensor]] = None,
767
+ cross_images: List[List[torch.Tensor]] = None,
768
+ token_type_ids: Optional[torch.LongTensor] = None,
769
+ attention_mask: Optional[torch.Tensor] = None,
770
+ position_ids: Optional[torch.LongTensor] = None,
771
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
772
+ inputs_embeds: Optional[torch.FloatTensor] = None,
773
+ use_cache: Optional[bool] = None,
774
+ output_attentions: Optional[bool] = None,
775
+ output_hidden_states: Optional[bool] = None,
776
+ return_dict: Optional[bool] = None,
777
+ labels: Optional[torch.LongTensor] = None,
778
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
779
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
780
+ output_hidden_states = (
781
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
782
+ )
783
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
784
+
785
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
786
+ outputs = self.model(
787
+ input_ids=input_ids,
788
+ images=images,
789
+ cross_images=cross_images,
790
+ token_type_ids=token_type_ids,
791
+ attention_mask=attention_mask,
792
+ position_ids=position_ids,
793
+ past_key_values=past_key_values,
794
+ inputs_embeds=inputs_embeds,
795
+ use_cache=use_cache,
796
+ output_attentions=output_attentions,
797
+ output_hidden_states=output_hidden_states,
798
+ return_dict=return_dict,
799
+ )
800
+
801
+ hidden_states = outputs[0]
802
+ logits = self.lm_head(hidden_states)
803
+ logits = logits.float()
804
+
805
+ loss = None
806
+ if labels is not None:
807
+ # Shift so that tokens < n predict n
808
+ shift_logits = logits[..., :-1, :].contiguous()
809
+ shift_labels = labels[..., 1:].contiguous()
810
+ # Flatten the tokens
811
+ loss_fct = CrossEntropyLoss()
812
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
813
+ shift_labels = shift_labels.view(-1)
814
+ # Enable model parallelism
815
+ shift_labels = shift_labels.to(shift_logits.device)
816
+ loss = loss_fct(shift_logits, shift_labels)
817
+
818
+ if not return_dict:
819
+ output = (logits,) + outputs[1:]
820
+ return (loss,) + output if loss is not None else output
821
+
822
+ return CausalLMOutputWithPast(
823
+ loss=loss,
824
+ logits=logits,
825
+ past_key_values=outputs.past_key_values,
826
+ hidden_states=outputs.hidden_states,
827
+ attentions=outputs.attentions,
828
+ )
829
+
830
+ def _prepare_attention_mask_for_generation(
831
+ self,
832
+ inputs: torch.Tensor,
833
+ pad_token_id: Optional[int],
834
+ eos_token_id: Optional[Union[int, List[int]]],
835
+ ) -> torch.LongTensor:
836
+ return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) # type: ignore
837
+
838
+ def prepare_inputs_for_generation(
839
+ self, input_ids, token_type_ids, images=None, cross_images=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
840
+ ):
841
+ # build position_ids if needed
842
+ position_ids = kwargs.get("position_ids", None)
843
+ if position_ids is None:
844
+ position_ids = build_position_ids(token_type_ids, attention_mask)
845
+
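+ # once the cache is populated, only the newest token (and its type and position id) is fed in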
846
+ if past_key_values:
847
+ input_ids = input_ids[:, -1:]
848
+ token_type_ids = token_type_ids[:, -1:]
849
+ position_ids = position_ids[:, -1:]
850
+
851
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
852
+ if inputs_embeds is not None and past_key_values is None:
853
+ model_inputs = {"inputs_embeds": inputs_embeds}
854
+ else:
855
+ model_inputs = {"input_ids": input_ids}
856
+
857
+ model_inputs.update(
858
+ {
859
+ "token_type_ids": token_type_ids,
860
+ "images": images,
861
+ "cross_images": cross_images,
862
+ "position_ids": position_ids,
863
+ "past_key_values": past_key_values,
864
+ "use_cache": kwargs.get("use_cache"),
865
+ "attention_mask": attention_mask,
866
+ }
867
+ )
868
+ return model_inputs
869
+
870
+ def _update_model_kwargs_for_generation(
871
+ self,
872
+ outputs: "ModelOutput",
873
+ model_kwargs: Dict[str, Any],
874
+ is_encoder_decoder: bool = False,
875
+ standardize_cache_format: bool = False,
876
+ ) -> Dict[str, Any]:
877
+ # update past_key_values
878
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
879
+ outputs, standardize_cache_format=standardize_cache_format
880
+ )
881
+ if getattr(outputs, "state", None) is not None:
882
+ model_kwargs["state"] = outputs.state
883
+
884
+ # append LANGUAGE_TOKEN_TYPE for the token generated at this step
885
+ if "token_type_ids" in model_kwargs:
886
+ token_type_ids = model_kwargs["token_type_ids"]
887
+ new_token_type_ids = torch.ones(size=(token_type_ids.shape[0], 1), dtype=token_type_ids.dtype, device=token_type_ids.device) * LANGUAGE_TOKEN_TYPE
888
+ model_kwargs["token_type_ids"] = torch.cat([token_type_ids, new_token_type_ids], dim=-1)
889
+
890
+ if not is_encoder_decoder:
891
+ # update attention mask
892
+ if "attention_mask" in model_kwargs:
893
+ attention_mask = model_kwargs["attention_mask"]
894
+ model_kwargs["attention_mask"] = torch.cat(
895
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
896
+ )
897
+ else:
898
+ # update decoder attention mask
899
+ if "decoder_attention_mask" in model_kwargs:
900
+ decoder_attention_mask = model_kwargs["decoder_attention_mask"]
901
+ model_kwargs["decoder_attention_mask"] = torch.cat(
902
+ [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
903
+ dim=-1,
904
+ )
905
+
906
+ return model_kwargs
907
+
908
+ def _reorder_cache(self, past_key_values, beam_idx):
909
+ reordered_past = ()
910
+ for layer_past in past_key_values:
911
+ reordered_past += (
912
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
913
+ )
914
+ return reordered_past
915
+
916
+ def build_conversation_input_ids(
917
+ self,
918
+ tokenizer: "PreTrainedTokenizer",
919
+ *,
920
+ query: str,
921
+ history: Optional[List[Tuple[str, str]]] = None,
922
+ images: Optional[List["PIL.Image"]] = None,
923
+ template_version: Optional[Literal["base", "chat", "vqa"]] = None,
924
+ ):
925
+ image_size: int = self.config.vision_config['image_size']
926
+ cross_image_size: int = self.config.cross_image_size
927
+ patch_size: int = self.config.vision_config['patch_size']
928
+ template_version = template_version or self.config.template_version
929
+ assert images is None or len(images) <= 1, "multiple images are not supported yet."
930
+ history = history or []
931
+ text = _history_to_prompt[template_version](history, query)
932
+
933
+ input_ids = [tokenizer.bos_token_id]
934
+ token_type_ids = [LANGUAGE_TOKEN_TYPE]
935
+ if images is not None and len(images) == 1:
936
+ ori = images
937
+ # vision
938
+ transform = transforms.Compose(
939
+ [
940
+ transforms.Resize(
941
+ (image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC
942
+ ),
943
+ transforms.ToTensor(),
944
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
945
+ ]
946
+ )
947
+ images = [transform(ori[0])]
948
+ cross_transform = transforms.Compose(
949
+ [
950
+ transforms.Resize(
951
+ (cross_image_size, cross_image_size), interpolation=transforms.InterpolationMode.BICUBIC
952
+ ),
953
+ transforms.ToTensor(),
954
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
955
+ ]
956
+ )
957
+ cross_images = [cross_transform(ori[0])]
958
+ # language
959
+ vision_token_num = (image_size // patch_size) * (image_size // patch_size) + 2
960
+ input_ids += [tokenizer.pad_token_id] * vision_token_num
961
+ token_type_ids += [VISION_TOKEN_TYPE] * vision_token_num
962
+ text_ids = tokenizer.encode(text, add_special_tokens=False)
963
+
964
+ input_ids += text_ids
965
+ token_type_ids += [LANGUAGE_TOKEN_TYPE] * len(text_ids)
966
+ attention_mask = [1] * len(input_ids)
967
+
968
+ return {
969
+ 'input_ids': torch.tensor(input_ids, dtype=torch.long),
970
+ 'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
971
+ 'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
972
+ 'images': images,
973
+ 'cross_images': cross_images
974
+ }
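A minimal inference sketch for the modeling code above (illustrative only: the checkpoint path, image file name and generation settings are placeholders, and a CUDA GPU is assumed because the vision tower in visual.py relies on xformers):

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, LlamaTokenizer

MODEL_PATH = "<path-to-this-repo>"  # hypothetical local path or hub id

tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)  # the repo ships tokenizer.model alongside the weights
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.float16,
    trust_remote_code=True,
    device_map="auto",  # requires accelerate to be installed
).eval()
device = next(model.parameters()).device

image = Image.open("screenshot.png").convert("RGB")
feats = model.build_conversation_input_ids(
    tokenizer, query="Describe this image.", history=[], images=[image]
)
inputs = {
    "input_ids": feats["input_ids"].unsqueeze(0).to(device),
    "token_type_ids": feats["token_type_ids"].unsqueeze(0).to(device),
    "attention_mask": feats["attention_mask"].unsqueeze(0).to(device),
    "images": [[feats["images"][0].to(device, torch.float16)]],
    "cross_images": [[feats["cross_images"][0].to(device, torch.float16)]],
}
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=256, do_sample=False)
    out = out[:, inputs["input_ids"].shape[1]:]
print(tokenizer.decode(out[0], skip_special_tokens=True))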
pytorch_model-00001-of-00004.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d09e6d7a4f2620d2d32c839747af292cbfbdf6cc9c35ea5a6619b30ef820856
3
+ size 4979106642
pytorch_model-00002-of-00004.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6dbaddf8f16b064c0f6d55f97cd492fd44cbc3e20bb812d27f4e968d4010b63f
3
+ size 4987736182
pytorch_model-00003-of-00004.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:96aa4207396b0794637891ffee4d900af383c74c2e7b4f00c0115c0eb6630f24
3
+ size 4974066815
pytorch_model-00004-of-00004.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5d8172706e1fcf3be3e3c9f96baac7f1cf42c984a8360423d58373f8d5f3afec
3
+ size 3654473430
pytorch_model.bin.index.json ADDED
The diff for this file is too large to render. See raw diff
 
special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<unk>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<unk>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
3
+ size 499723
tokenizer_config.json ADDED
@@ -0,0 +1,43 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": true,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "<unk>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "1": {
15
+ "content": "<s>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "2": {
23
+ "content": "</s>",
24
+ "lstrip": false,
25
+ "normalized": false,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": true
29
+ }
30
+ },
31
+ "bos_token": "<s>",
32
+ "clean_up_tokenization_spaces": false,
33
+ "eos_token": "</s>",
34
+ "legacy": false,
35
+ "model_max_length": 4096,
36
+ "pad_token": "<unk>",
37
+ "padding_side": "right",
38
+ "sp_model_kwargs": {},
39
+ "spaces_between_special_tokens": false,
40
+ "tokenizer_class": "LlamaTokenizer",
41
+ "unk_token": "<unk>",
42
+ "use_default_system_prompt": false
43
+ }
visual.py ADDED
@@ -0,0 +1,136 @@
1
+ import torch
2
+ from torch import nn
3
+ from argparse import Namespace
4
+ import xformers.ops as xops
5
+ from transformers.activations import ACT2FN
6
+
7
+
8
+ class PatchEmbedding(nn.Module):
9
+ def __init__(self, config):
10
+ super().__init__()
11
+ self.proj = nn.Conv2d(config.in_channels, config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size)
12
+ self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size))
13
+ self.position_embedding = nn.Embedding(config.num_positions, config.hidden_size)
14
+
15
+ def forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L, D)":
16
+ x = self.proj(images)
17
+ x = x.flatten(2).transpose(1, 2)
18
+ cls_token = self.cls_embedding.expand(x.shape[0], -1, -1)
19
+ x = torch.cat((cls_token, x), dim=1)
20
+ x += self.position_embedding.weight.unsqueeze(0)
21
+ return x
22
+
23
+
24
+ class Attention(nn.Module):
25
+ def __init__(self, config):
26
+ super().__init__()
27
+ self.num_heads = config.num_heads
28
+ head_dim = config.hidden_size // config.num_heads
29
+ self.scale = head_dim ** -0.5
30
+ self.query_key_value = nn.Linear(config.hidden_size, config.hidden_size * 3)
31
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
32
+ self.output_dropout = torch.nn.Dropout(config.dropout_prob)
33
+
34
+ def forward(self, x: "tensor(B, L, D)") -> "tensor(B, L, D)":
35
+ B, L, _ = x.shape
36
+ qkv = self.query_key_value(x)
37
+ qkv = qkv.reshape(B, L, 3, self.num_heads, -1).permute(2, 0, 1, 3, 4) # 3, B, L, H, D
38
+ q, k, v = qkv[0], qkv[1], qkv[2]
39
+
40
+ out = xops.memory_efficient_attention(
41
+ q, k, v, scale=self.scale,
42
+ )
43
+ output = self.dense(out.view(B, L, -1))
44
+ output = self.output_dropout(output)
45
+ return output
46
+
47
+ def attention(self, q, k, v):
48
+ attn_weights = torch.matmul(q * self.scale, k.transpose(-2, -1))
49
+ attn_weights = attn_weights.softmax(dim=-1)
50
+ output = torch.matmul(attn_weights, v)
51
+ return output
52
+
53
+
54
+ class MLP(nn.Module):
55
+ def __init__(self, config):
56
+ super().__init__()
57
+ self.config = config
58
+ self.activation_fn = ACT2FN[config.hidden_act]
59
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
60
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
61
+
62
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
63
+ x = self.fc1(x)
64
+ x = self.activation_fn(x)
65
+ x = self.fc2(x)
66
+ return x
67
+
68
+
69
+ class TransformerLayer(nn.Module):
70
+ def __init__(self, config):
71
+ super().__init__()
72
+ self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
73
+ self.attention = Attention(config)
74
+ self.mlp = MLP(config)
75
+ self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
76
+
77
+ def forward(self, hidden_states):
78
+ attention_input = hidden_states
79
+ attention_output = self.input_layernorm(self.attention(attention_input))
80
+ hidden_states = attention_input + attention_output
81
+ mlp_input = hidden_states
82
+ mlp_output = self.post_attention_layernorm(self.mlp(mlp_input))
83
+ output = mlp_input + mlp_output
84
+ return output
85
+
86
+
87
+ class Transformer(nn.Module):
88
+ def __init__(self, config):
89
+ super().__init__()
90
+ self.layers = nn.ModuleList([TransformerLayer(config) for _ in range(config.num_hidden_layers)])
91
+
92
+ def forward(self, hidden_states):
93
+ for layer_module in self.layers:
94
+ hidden_states = layer_module(hidden_states)
95
+ return hidden_states
96
+
97
+
98
+ class GLU(nn.Module):
99
+ def __init__(self, config, in_features):
100
+ super().__init__()
101
+ self.linear_proj = nn.Linear(in_features, config.hidden_size, bias=False)
102
+ self.norm1 = nn.LayerNorm(config.hidden_size)
103
+ self.act1 = nn.GELU()
104
+ self.act2 = nn.functional.silu
105
+ self.dense_h_to_4h = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
106
+ self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
107
+ self.dense_4h_to_h = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
108
+
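+ # GLU maps the ViT features to the LM hidden size via linear_proj, then applies a SwiGLU-style
+ # gate: silu(gate_proj(x)) * dense_h_to_4h(x), followed by dense_4h_to_h back to hidden_size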
109
+ def forward(self, x):
110
+ x = self.linear_proj(x)
111
+ x = self.act1(self.norm1(x))
112
+ x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x)
113
+ x = self.dense_4h_to_h(x)
114
+ return x
115
+
116
+
117
+ class EVA2CLIPModel(nn.Module):
118
+ def __init__(self, config):
119
+ super().__init__()
120
+ vision_config = Namespace(**config.vision_config)
121
+ self.patch_embedding = PatchEmbedding(vision_config)
122
+ self.transformer = Transformer(vision_config)
123
+ self.linear_proj = GLU(config, in_features=vision_config.hidden_size)
124
+ self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
125
+ self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
126
+ self.pos_embed = nn.Parameter(torch.zeros((vision_config.image_size // vision_config.patch_size) ** 2, vision_config.hidden_size))
127
+
128
+ def forward(self, images: "tensor(B, C, H, W)") -> "tensor(B, L, D)":
129
+ x = self.patch_embedding(images)
130
+ x = self.transformer(x)
131
+ x = x[:, 1:]
132
+ x = self.linear_proj(x + self.pos_embed.to(x.device).unsqueeze(0))
133
+ boi = self.boi.to(x.device).expand(x.shape[0], -1, -1)
134
+ eoi = self.eoi.to(x.device).expand(x.shape[0], -1, -1)
135
+ x = torch.cat((boi, x, eoi), dim=1)
136
+ return x
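Shape note for the vision tower above: EVA2CLIPModel drops the CLS token (x[:, 1:]) and wraps the remaining patch features with the learned boi/eoi embeddings, so its output length is (image_size // patch_size) ** 2 + 2, which matches the vision_token_num that build_conversation_input_ids reserves with pad-token placeholders. A quick sanity check, with illustrative values rather than ones read from the shipped config:

image_size, patch_size = 224, 14                 # illustrative values
num_patches = (image_size // patch_size) ** 2    # 16 * 16 = 256 patch features
vision_token_num = num_patches + 2               # plus boi and eoi -> 258 placeholder tokens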