katuni4ka committed on
Commit f0b37fb · verified · 1 Parent(s): 121db74

Upload 18 files

__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ from .image_processing_vlm import VLMImageProcessor
21
+ from .modeling_vlm import MultiModalityCausalLM
22
+ from .processing_vlm import VLChatProcessor
23
+
24
+ __all__ = [
25
+ "VLMImageProcessor",
26
+ "VLChatProcessor",
27
+ "MultiModalityCausalLM",
28
+ ]
attrdict_config.py ADDED
@@ -0,0 +1,294 @@
1
+ from abc import ABCMeta, abstractmethod
2
+ try:
3
+ from collections.abc import Mapping, MutableMapping, Sequence
4
+ except ImportError:
5
+ from collections import Mapping, MutableMapping, Sequence  # Python 2 fallback
6
+ import re
7
+
8
+ import six
9
+
10
+
11
+
12
+ __all__ = ['Attr', 'MutableAttr', 'AttrDict']
13
+
14
+ def merge(left, right):
15
+ """
16
+ Merge two mapping objects together, combining overlapping Mappings,
17
+ and favoring right-values
18
+
19
+ left: The left Mapping object.
20
+ right: The right (favored) Mapping object.
21
+
22
+ NOTE: This is not commutative (merge(a,b) != merge(b,a)).
23
+ """
24
+ merged = {}
25
+
26
+ left_keys = frozenset(left)
27
+ right_keys = frozenset(right)
28
+
29
+ # Items only in the left Mapping
30
+ for key in left_keys - right_keys:
31
+ merged[key] = left[key]
32
+
33
+ # Items only in the right Mapping
34
+ for key in right_keys - left_keys:
35
+ merged[key] = right[key]
36
+
37
+ # in both
38
+ for key in left_keys & right_keys:
39
+ left_value = left[key]
40
+ right_value = right[key]
41
+
42
+ if (isinstance(left_value, Mapping) and
43
+ isinstance(right_value, Mapping)): # recursive merge
44
+ merged[key] = merge(left_value, right_value)
45
+ else: # overwrite with right value
46
+ merged[key] = right_value
47
+
48
+ return merged
49
+
50
+ @six.add_metaclass(ABCMeta)
51
+ class Attr(Mapping):
52
+ """
53
+ A mixin class for a mapping that allows for attribute-style access
54
+ of values.
55
+
56
+ A key may be used as an attribute if:
57
+ * It is a string
58
+ * It matches /^[A-Za-z][A-Za-z0-9_]*$/ (i.e., a public attribute)
59
+ * The key doesn't overlap with any class attributes (for Attr,
60
+ those would be 'get', 'items', 'keys', 'values', 'mro', and
61
+ 'register').
62
+
63
+ If a value which is accessed as an attribute is a Sequence-type
64
+ (and is not a string/bytes), it will be converted to a
65
+ _sequence_type with any mappings within it converted to Attrs.
66
+
67
+ NOTE: This means that if _sequence_type is not None, then a
68
+ sequence accessed as an attribute will be a different object
69
+ than if it is accessed as an
70
+ item.
71
+ """
72
+ @abstractmethod
73
+ def _configuration(self):
74
+ """
75
+ All required state for building a new instance with the same
76
+ settings as the current object.
77
+ """
78
+
79
+ @classmethod
80
+ def _constructor(cls, mapping, configuration):
81
+ """
82
+ A standardized constructor used internally by Attr.
83
+
84
+ mapping: A mapping of key-value pairs. It is HIGHLY recommended
85
+ that you use this as the internal key-value pair mapping, as
86
+ that will allow nested assignment (e.g., attr.foo.bar = baz)
87
+ configuration: The return value of Attr._configuration
88
+ """
89
+ raise NotImplementedError("You need to implement this")
90
+
91
+ def __call__(self, key):
92
+ """
93
+ Dynamically access a key-value pair.
94
+
95
+ key: A key associated with a value in the mapping.
96
+
97
+ This differs from __getitem__, because it returns a new instance
98
+ of an Attr (if the value is a Mapping object).
99
+ """
100
+ if key not in self:
101
+ raise AttributeError(
102
+ "'{cls}' instance has no attribute '{name}'".format(
103
+ cls=self.__class__.__name__, name=key
104
+ )
105
+ )
106
+
107
+ return self._build(self[key])
108
+
109
+ def __getattr__(self, key):
110
+ """
111
+ Access an item as an attribute.
112
+ """
113
+ if key not in self or not self._valid_name(key):
114
+ raise AttributeError(
115
+ "'{cls}' instance has no attribute '{name}'".format(
116
+ cls=self.__class__.__name__, name=key
117
+ )
118
+ )
119
+
120
+ return self._build(self[key])
121
+
122
+ def __add__(self, other):
123
+ """
124
+ Add a mapping to this Attr, creating a new, merged Attr.
125
+
126
+ other: A mapping.
127
+
128
+ NOTE: Addition is not commutative. a + b != b + a.
129
+ """
130
+ if not isinstance(other, Mapping):
131
+ return NotImplemented
132
+
133
+ return self._constructor(merge(self, other), self._configuration())
134
+
135
+ def __radd__(self, other):
136
+ """
137
+ Add this Attr to a mapping, creating a new, merged Attr.
138
+
139
+ other: A mapping.
140
+
141
+ NOTE: Addition is not commutative. a + b != b + a.
142
+ """
143
+ if not isinstance(other, Mapping):
144
+ return NotImplemented
145
+
146
+ return self._constructor(merge(other, self), self._configuration())
147
+
148
+ def _build(self, obj):
149
+ """
150
+ Conditionally convert an object to allow for recursive mapping
151
+ access.
152
+
153
+ obj: An object that was a key-value pair in the mapping. If obj
154
+ is a mapping, self._constructor(obj, self._configuration())
155
+ will be called. If obj is a non-string/bytes sequence, and
156
+ self._sequence_type is not None, the obj will be converted
157
+ to type _sequence_type and build will be called on its
158
+ elements.
159
+ """
160
+ if isinstance(obj, Mapping):
161
+ obj = self._constructor(obj, self._configuration())
162
+ elif (isinstance(obj, Sequence) and
163
+ not isinstance(obj, (six.string_types, six.binary_type))):
164
+ sequence_type = getattr(self, '_sequence_type', None)
165
+
166
+ if sequence_type:
167
+ obj = sequence_type(self._build(element) for element in obj)
168
+
169
+ return obj
170
+
171
+ @classmethod
172
+ def _valid_name(cls, key):
173
+ """
174
+ Check whether a key is a valid attribute name.
175
+
176
+ A key may be used as an attribute if:
177
+ * It is a string
178
+ * It matches /^[A-Za-z][A-Za-z0-9_]*$/ (i.e., a public attribute)
179
+ * The key doesn't overlap with any class attributes (for Attr,
180
+ those would be 'get', 'items', 'keys', 'values', 'mro', and
181
+ 'register').
182
+ """
183
+ return (
184
+ isinstance(key, six.string_types) and
185
+ re.match('^[A-Za-z][A-Za-z0-9_]*$', key) and
186
+ not hasattr(cls, key)
187
+ )
188
+
189
+
190
+ @six.add_metaclass(ABCMeta)
191
+ class MutableAttr(Attr, MutableMapping):
192
+ """
193
+ A mixin class for a mapping that allows for attribute-style access
194
+ of values.
195
+ """
196
+ def _setattr(self, key, value):
197
+ """
198
+ Add an attribute to the object, without attempting to add it as
199
+ a key to the mapping.
200
+ """
201
+ super(MutableAttr, self).__setattr__(key, value)
202
+
203
+ def __setattr__(self, key, value):
204
+ """
205
+ Add an attribute.
206
+
207
+ key: The name of the attribute
208
+ value: The attribute's contents
209
+ """
210
+ if self._valid_name(key):
211
+ self[key] = value
212
+ elif getattr(self, '_allow_invalid_attributes', True):
213
+ super(MutableAttr, self).__setattr__(key, value)
214
+ else:
215
+ raise TypeError(
216
+ "'{cls}' does not allow attribute creation.".format(
217
+ cls=self.__class__.__name__
218
+ )
219
+ )
220
+
221
+ def _delattr(self, key):
222
+ """
223
+ Delete an attribute from the object, without attempting to
224
+ remove it from the mapping.
225
+ """
226
+ super(MutableAttr, self).__delattr__(key)
227
+
228
+ def __delattr__(self, key, force=False):
229
+ """
230
+ Delete an attribute.
231
+
232
+ key: The name of the attribute
233
+ """
234
+ if self._valid_name(key):
235
+ del self[key]
236
+ elif getattr(self, '_allow_invalid_attributes', True):
237
+ super(MutableAttr, self).__delattr__(key)
238
+ else:
239
+ raise TypeError(
240
+ "'{cls}' does not allow attribute deletion.".format(
241
+ cls=self.__class__.__name__
242
+ )
243
+ )
244
+
245
+
246
+ class AttrDict(dict, MutableAttr):
247
+ """
248
+ A dict that implements MutableAttr.
249
+ """
250
+ def __init__(self, *args, **kwargs):
251
+ super(AttrDict, self).__init__(*args, **kwargs)
252
+
253
+ self._setattr('_sequence_type', tuple)
254
+ self._setattr('_allow_invalid_attributes', False)
255
+
256
+ def _configuration(self):
257
+ """
258
+ The configuration for an AttrDict instance.
259
+ """
260
+ return self._sequence_type
261
+
262
+ def __getstate__(self):
263
+ """
264
+ Serialize the object.
265
+ """
266
+ return (
267
+ self.copy(),
268
+ self._sequence_type,
269
+ self._allow_invalid_attributes
270
+ )
271
+
272
+ def __setstate__(self, state):
273
+ """
274
+ Deserialize the object.
275
+ """
276
+ mapping, sequence_type, allow_invalid_attributes = state
277
+ self.update(mapping)
278
+ self._setattr('_sequence_type', sequence_type)
279
+ self._setattr('_allow_invalid_attributes', allow_invalid_attributes)
280
+
281
+ def __repr__(self):
282
+ return six.u('AttrDict({contents})').format(
283
+ contents=super(AttrDict, self).__repr__()
284
+ )
285
+
286
+ @classmethod
287
+ def _constructor(cls, mapping, configuration):
288
+ """
289
+ A standardized constructor.
290
+ """
291
+ attr = cls(mapping)
292
+ attr._setattr('_sequence_type', configuration)
293
+
294
+ return attr
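
A minimal usage sketch of the AttrDict defined above (assuming the module is importable as attrdict_config; it also requires six to be installed):

from attrdict_config import AttrDict

cfg = AttrDict({"params": {"depth": 2, "n_embed": 32}, "projector_type": "mlp_gelu"})
print(cfg.params.depth)           # 2 -- nested mappings are rebuilt as AttrDicts on attribute access
print(cfg["projector_type"])      # "mlp_gelu" -- plain item access still works

merged = cfg + {"params": {"depth": 4}}   # __add__ delegates to merge(); right-hand values win
print(merged.params.depth)        # 4
print(merged.params.n_embed)      # 32 -- overlapping mappings are merged recursively
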
clip_encoder.py ADDED
@@ -0,0 +1,122 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ from typing import Dict, List, Literal, Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ import torchvision.transforms
25
+ from einops import rearrange
26
+
27
+ from .siglip_vit import create_siglip_vit
28
+
29
+
30
+ class CLIPVisionTower(nn.Module):
31
+ def __init__(
32
+ self,
33
+ model_name: str = "siglip_large_patch16_384",
34
+ image_size: Union[Tuple[int, int], int] = 336,
35
+ select_feature: str = "patch",
36
+ select_layer: int = -2,
37
+ select_layers: list = None,
38
+ ckpt_path: str = "",
39
+ pixel_mean: Optional[List[float]] = None,
40
+ pixel_std: Optional[List[float]] = None,
41
+ **kwargs,
42
+ ):
43
+ super().__init__()
44
+
45
+ self.model_name = model_name
46
+ self.select_feature = select_feature
47
+ self.select_layer = select_layer
48
+ self.select_layers = select_layers
49
+
50
+ vision_tower_params = {
51
+ "model_name": model_name,
52
+ "image_size": image_size,
53
+ "ckpt_path": ckpt_path,
54
+ "select_layer": select_layer,
55
+ }
56
+ vision_tower_params.update(kwargs)
57
+ self.vision_tower, self.forward_kwargs = self.build_vision_tower(
58
+ vision_tower_params
59
+ )
60
+
61
+ if pixel_mean is not None and pixel_std is not None:
62
+ image_norm = torchvision.transforms.Normalize(
63
+ mean=pixel_mean, std=pixel_std
64
+ )
65
+ else:
66
+ image_norm = None
67
+
68
+ self.image_norm = image_norm
69
+
70
+ def build_vision_tower(self, vision_tower_params):
71
+ if self.model_name.startswith("siglip"):
72
+ self.select_feature = "same"
73
+ vision_tower = create_siglip_vit(**vision_tower_params)
74
+ forward_kwargs = dict()
75
+
76
+ elif self.model_name.startswith("sam"):
77
+ vision_tower = create_sam_vit(**vision_tower_params)  # NOTE: create_sam_vit is not imported in this module
78
+ forward_kwargs = dict()
79
+
80
+ else: # huggingface
81
+ from transformers import CLIPVisionModel
82
+
83
+ vision_tower = CLIPVisionModel.from_pretrained(**vision_tower_params)
84
+ forward_kwargs = dict(output_hidden_states=True)
85
+
86
+ return vision_tower, forward_kwargs
87
+
88
+ def feature_select(self, image_forward_outs):
89
+ if isinstance(image_forward_outs, torch.Tensor):
90
+ # the output is already the self.select_layer's features
91
+ image_features = image_forward_outs
92
+ else:
93
+ image_features = image_forward_outs.hidden_states[self.select_layer]
94
+
95
+ if self.select_feature == "patch":
96
+ # if the output has cls_token
97
+ image_features = image_features[:, 1:]
98
+ elif self.select_feature == "cls_patch":
99
+ image_features = image_features
100
+ elif self.select_feature == "same":
101
+ image_features = image_features
102
+
103
+ else:
104
+ raise ValueError(f"Unexpected select feature: {self.select_feature}")
105
+ return image_features
106
+
107
+ def forward(self, images):
108
+ """
109
+
110
+ Args:
111
+ images (torch.Tensor): [b, 3, H, W]
112
+
113
+ Returns:
114
+ image_features (torch.Tensor): [b, n_patch, d]
115
+ """
116
+
117
+ if self.image_norm is not None:
118
+ images = self.image_norm(images)
119
+
120
+ image_forward_outs = self.vision_tower(images, **self.forward_kwargs)
121
+ image_features = self.feature_select(image_forward_outs)
122
+ return image_features
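
A hypothetical sketch of how CLIPVisionTower might be instantiated, assuming the module is importable as clip_encoder and that siglip_vit registers a model named "siglip_tiny_patch2_384" (the name used by vision_config in config.json below):

import torch
from clip_encoder import CLIPVisionTower

tower = CLIPVisionTower(
    model_name="siglip_tiny_patch2_384",  # assumed to be provided by create_siglip_vit
    image_size=384,
    select_feature="same",
    select_layer=-1,
    pixel_mean=(0.5, 0.5, 0.5),  # enables the optional Normalize transform in forward()
    pixel_std=(0.5, 0.5, 0.5),
)

images = torch.rand(1, 3, 384, 384)       # [b, 3, H, W]
with torch.no_grad():
    features = tower(images)              # [b, n_patch, d] per the forward() docstring
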
config.json ADDED
@@ -0,0 +1,74 @@
1
+ {
2
+ "_name_or_path": "tiny-random-janus",
3
+ "aligner_config": {
4
+ "cls": "MlpProjector",
5
+ "model_type": "aligner",
6
+ "params": {
7
+ "depth": 2,
8
+ "input_dim": 32,
9
+ "n_embed": 32,
10
+ "projector_type": "mlp_gelu"
11
+ }
12
+ },
13
+ "architectures": [
14
+ "MultiModalityCausalLM"
15
+ ],
16
+ "auto_map": {
17
+ "AutoConfig": "configuration_vlm.MultiModalityConfig",
18
+ "AutoModel": "modeling_vlm.MultiModalityCausalLM",
19
+ "AutoModelForCausalLM": "modeling_vlm.MultiModalityCausalLM"
20
+ },
21
+ "gen_aligner_config": {
22
+ "cls": "MlpProjector",
23
+ "model_type": "gen_aligner",
24
+ "params": {
25
+ "depth": 2,
26
+ "input_dim": 8,
27
+ "n_embed": 32,
28
+ "projector_type": "mlp_gelu"
29
+ }
30
+ },
31
+ "gen_head_config": {
32
+ "cls": "vision_head",
33
+ "model_type": "gen_head",
34
+ "params": {
35
+ "image_token_embed": 32,
36
+ "image_token_size": 16384,
37
+ "n_embed": 32
38
+ }
39
+ },
40
+ "gen_vision_config": {
41
+ "cls": "VQ-16",
42
+ "model_type": "gen_vision",
43
+ "params": {
44
+ "image_token_size": 16384,
45
+ "n_embed": 8
46
+ }
47
+ },
48
+ "language_config": {
49
+ "_attn_implementation_autoset": true,
50
+ "head_dim": 32,
51
+ "hidden_size": 32,
52
+ "intermediate_size": 32,
53
+ "max_position_embeddings": 16384,
54
+ "model_type": "llama",
55
+ "num_attention_heads": 1,
56
+ "num_hidden_layers": 2,
57
+ "num_key_value_heads": 1,
58
+ "torch_dtype": "bfloat16",
59
+ "vocab_size": 102400
60
+ },
61
+ "model_type": "multi_modality",
62
+ "torch_dtype": "float32",
63
+ "transformers_version": "4.47.1",
64
+ "vision_config": {
65
+ "cls": "CLIPVisionTower",
66
+ "model_type": "vision",
67
+ "params": {
68
+ "image_size": 384,
69
+ "model_name": "siglip_tiny_patch2_384",
70
+ "select_feature": "same",
71
+ "select_layer": -1
72
+ }
73
+ }
74
+ }
configuration_vlm.py ADDED
@@ -0,0 +1,119 @@
1
+ from transformers.configuration_utils import PretrainedConfig
2
+ from transformers import LlamaConfig
10
+ from .attrdict_config import AttrDict
11
+
12
+ class VisionConfig(PretrainedConfig):
13
+ model_type = "vision"
14
+ cls: str = ""
15
+ params: AttrDict = {}
16
+
17
+ def __init__(self, **kwargs):
18
+ super().__init__(**kwargs)
19
+
20
+ self.cls = kwargs.get("cls", "")
21
+ if not isinstance(self.cls, str):
22
+ self.cls = self.cls.__name__
23
+
24
+ self.params = AttrDict(kwargs.get("params", {}))
25
+
26
+
27
+ class AlignerConfig(PretrainedConfig):
28
+ model_type = "aligner"
29
+ cls: str = ""
30
+ params: AttrDict = {}
31
+
32
+ def __init__(self, **kwargs):
33
+ super().__init__(**kwargs)
34
+
35
+ self.cls = kwargs.get("cls", "")
36
+ if not isinstance(self.cls, str):
37
+ self.cls = self.cls.__name__
38
+
39
+ self.params = AttrDict(kwargs.get("params", {}))
40
+
41
+
42
+ class GenVisionConfig(PretrainedConfig):
43
+ model_type = "gen_vision"
44
+ cls: str = ""
45
+ params: AttrDict = {}
46
+
47
+ def __init__(self, **kwargs):
48
+ super().__init__(**kwargs)
49
+
50
+ self.cls = kwargs.get("cls", "")
51
+ if not isinstance(self.cls, str):
52
+ self.cls = self.cls.__name__
53
+
54
+ self.params = AttrDict(kwargs.get("params", {}))
55
+
56
+
57
+ class GenAlignerConfig(PretrainedConfig):
58
+ model_type = "gen_aligner"
59
+ cls: str = ""
60
+ params: AttrDict = {}
61
+
62
+ def __init__(self, **kwargs):
63
+ super().__init__(**kwargs)
64
+
65
+ self.cls = kwargs.get("cls", "")
66
+ if not isinstance(self.cls, str):
67
+ self.cls = self.cls.__name__
68
+
69
+ self.params = AttrDict(kwargs.get("params", {}))
70
+
71
+
72
+ class GenHeadConfig(PretrainedConfig):
73
+ model_type = "gen_head"
74
+ cls: str = ""
75
+ params: AttrDict = {}
76
+
77
+ def __init__(self, **kwargs):
78
+ super().__init__(**kwargs)
79
+
80
+ self.cls = kwargs.get("cls", "")
81
+ if not isinstance(self.cls, str):
82
+ self.cls = self.cls.__name__
83
+
84
+ self.params = AttrDict(kwargs.get("params", {}))
85
+
86
+
87
+ class MultiModalityConfig(PretrainedConfig):
88
+ model_type = "multi_modality"
89
+ vision_config: VisionConfig
90
+ aligner_config: AlignerConfig
91
+
92
+ gen_vision_config: GenVisionConfig
93
+ gen_aligner_config: GenAlignerConfig
94
+ gen_head_config: GenHeadConfig
95
+
96
+ language_config: LlamaConfig
97
+
98
+ def __init__(self, **kwargs):
99
+ super().__init__(**kwargs)
100
+ vision_config = kwargs.get("vision_config", {})
101
+ self.vision_config = VisionConfig(**vision_config)
102
+
103
+ aligner_config = kwargs.get("aligner_config", {})
104
+ self.aligner_config = AlignerConfig(**aligner_config)
105
+
106
+ gen_vision_config = kwargs.get("gen_vision_config", {})
107
+ self.gen_vision_config = GenVisionConfig(**gen_vision_config)
108
+
109
+ gen_aligner_config = kwargs.get("gen_aligner_config", {})
110
+ self.gen_aligner_config = GenAlignerConfig(**gen_aligner_config)
111
+
112
+ gen_head_config = kwargs.get("gen_head_config", {})
113
+ self.gen_head_config = GenHeadConfig(**gen_head_config)
114
+
115
+ language_config = kwargs.get("language_config", {})
116
+ if isinstance(language_config, LlamaConfig):
117
+ self.language_config = language_config
118
+ else:
119
+ self.language_config = LlamaConfig(**language_config)
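
A short sketch of how MultiModalityConfig turns nested dicts into the typed sub-configs defined above, using values that mirror the tiny config.json in this upload (assuming the module is importable as configuration_vlm):

from configuration_vlm import MultiModalityConfig

config = MultiModalityConfig(
    vision_config={"cls": "CLIPVisionTower", "params": {"model_name": "siglip_tiny_patch2_384", "image_size": 384}},
    aligner_config={"cls": "MlpProjector", "params": {"depth": 2, "input_dim": 32, "n_embed": 32}},
    gen_vision_config={"cls": "VQ-16", "params": {"image_token_size": 16384, "n_embed": 8}},
    gen_aligner_config={"cls": "MlpProjector", "params": {"depth": 2, "input_dim": 8, "n_embed": 32}},
    gen_head_config={"cls": "vision_head", "params": {"image_token_size": 16384, "image_token_embed": 32, "n_embed": 32}},
    language_config={"hidden_size": 32, "num_hidden_layers": 2, "vocab_size": 102400},
)

print(config.vision_config.cls)                # "CLIPVisionTower"
print(config.vision_config.params.image_size)  # 384 -- params is an AttrDict
print(type(config.language_config).__name__)   # "LlamaConfig"
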
conversation.py ADDED
@@ -0,0 +1,365 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ """
21
+ From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
22
+ """
23
+
24
+ import dataclasses
25
+ from enum import IntEnum, auto
26
+ from typing import Dict, List
27
+
28
+
29
+ class SeparatorStyle(IntEnum):
30
+ """Separator styles."""
31
+
32
+ ADD_COLON_SINGLE = auto()
33
+ ADD_COLON_TWO = auto()
34
+ ADD_COLON_SPACE_SINGLE = auto()
35
+ NO_COLON_SINGLE = auto()
36
+ NO_COLON_TWO = auto()
37
+ ADD_NEW_LINE_SINGLE = auto()
38
+ LLAMA2 = auto()
39
+ CHATGLM = auto()
40
+ CHATML = auto()
41
+ CHATINTERN = auto()
42
+ DOLLY = auto()
43
+ RWKV = auto()
44
+ PHOENIX = auto()
45
+ ROBIN = auto()
46
+ DeepSeek = auto()
47
+ PLAIN = auto()
48
+ ALIGNMENT = auto()
49
+
50
+
51
+ @dataclasses.dataclass
52
+ class Conversation:
53
+ """A class that manages prompt templates and keeps all conversation history."""
54
+
55
+ # The name of this template
56
+ name: str
57
+ # The template of the system prompt
58
+ system_template: str = "{system_message}"
59
+ # The system message
60
+ system_message: str = ""
61
+ # The names of two roles
62
+ roles: List[str] = (("USER", "ASSISTANT"),)
63
+ # All messages. Each item is (role, message).
64
+ messages: List[List[str]] = ()
65
+ # The number of few shot examples
66
+ offset: int = 0
67
+ # The separator style and configurations
68
+ sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
69
+ sep: str = "\n"
70
+ sep2: str = None
71
+ # Stop criteria (the default one is EOS token)
72
+ stop_str: str = None
73
+ # Stops generation if meeting any token in this list
74
+ stop_token_ids: List[int] = None
75
+
76
+ def get_prompt(self) -> str:
77
+ """Get the prompt for generation."""
78
+ system_prompt = self.system_template.format(system_message=self.system_message)
79
+
80
+ if self.sep_style == SeparatorStyle.DeepSeek:
81
+ seps = [self.sep, self.sep2]
82
+ if system_prompt == "" or system_prompt is None:
83
+ ret = ""
84
+ else:
85
+ ret = system_prompt + seps[0]
86
+ for i, (role, message) in enumerate(self.messages):
87
+ if message:
88
+ ret += role + ": " + message + seps[i % 2]
89
+ else:
90
+ ret += role + ":"
91
+ return ret
92
+ elif self.sep_style == SeparatorStyle.LLAMA2:
93
+ seps = [self.sep, self.sep2]
94
+ if self.system_message:
95
+ ret = system_prompt
96
+ else:
97
+ ret = "[INST] "
98
+ for i, (role, message) in enumerate(self.messages):
99
+ tag = self.roles[i % 2]
100
+ if message:
101
+ if type(message) is tuple: # multimodal message
102
+ message, _ = message
103
+ if i == 0:
104
+ ret += message + " "
105
+ else:
106
+ ret += tag + " " + message + seps[i % 2]
107
+ else:
108
+ ret += tag
109
+ return ret
110
+ elif self.sep_style == SeparatorStyle.PLAIN:
111
+ seps = [self.sep, self.sep2]
112
+ ret = ""
113
+ for i, (role, message) in enumerate(self.messages):
114
+ if message:
115
+ if type(message) is tuple:
116
+ message, _, _ = message
117
+ if i % 2 == 0:
118
+ ret += message + seps[i % 2]
119
+ else:
120
+ ret += message + seps[i % 2]
121
+ else:
122
+ ret += ""
123
+ return ret
124
+ elif self.sep_style == SeparatorStyle.ALIGNMENT:
125
+ seps = [self.sep, self.sep2]
126
+ ret = ""
127
+ for i, (role, message) in enumerate(self.messages):
128
+ if message:
129
+ if type(message) is tuple:
130
+ message, _, _ = message
131
+ if i % 2 == 0:
132
+ ret += "<image>\n" + seps[i % 2]
133
+ else:
134
+ ret += message + seps[i % 2]
135
+ else:
136
+ ret += ""
137
+ return ret
138
+ else:
139
+ raise ValueError(f"Invalid style: {self.sep_style}")
140
+
141
+ def get_prompt_for_current_round(self, content=None):
142
+ """Get current round formatted question prompt during sft training"""
143
+ if self.sep_style == SeparatorStyle.PLAIN:
144
+ formatted_question = "<image>\n"
145
+ elif self.sep_style == SeparatorStyle.DeepSeek:
146
+ formatted_question = (
147
+ f"{self.roles[0]}: " + content.strip() + self.sep + f"{self.roles[1]}:"
148
+ )
149
+ else:
150
+ raise ValueError(f"Unsupported sep_style: {self.sep_style}")
151
+ return formatted_question
152
+
153
+ def set_system_message(self, system_message: str):
154
+ """Set the system message."""
155
+ self.system_message = system_message
156
+
157
+ def append_message(self, role: str, message: str):
158
+ """Append a new message."""
159
+ self.messages.append([role, message])
160
+
161
+ def reset_message(self):
163
+ """Reset all messages."""
163
+ self.messages = []
164
+
165
+ def update_last_message(self, message: str):
166
+ """Update the last output.
167
+
168
+ The last message is typically set to be None when constructing the prompt,
169
+ so we need to update it in-place after getting the response from a model.
170
+ """
171
+ self.messages[-1][1] = message
172
+
173
+ def to_gradio_chatbot(self):
174
+ """Convert the conversation to gradio chatbot format."""
175
+ ret = []
176
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
177
+ if i % 2 == 0:
178
+ ret.append([msg, None])
179
+ else:
180
+ ret[-1][-1] = msg
181
+ return ret
182
+
183
+ def to_openai_api_messages(self):
184
+ """Convert the conversation to OpenAI chat completion format."""
185
+ system_prompt = self.system_template.format(system_message=self.system_message)
186
+ ret = [{"role": "system", "content": system_prompt}]
187
+
188
+ for i, (_, msg) in enumerate(self.messages[self.offset :]):
189
+ if i % 2 == 0:
190
+ ret.append({"role": "user", "content": msg})
191
+ else:
192
+ if msg is not None:
193
+ ret.append({"role": "assistant", "content": msg})
194
+ return ret
195
+
196
+ def copy(self):
197
+ return Conversation(
198
+ name=self.name,
199
+ system_template=self.system_template,
200
+ system_message=self.system_message,
201
+ roles=self.roles,
202
+ messages=[[x, y] for x, y in self.messages],
203
+ offset=self.offset,
204
+ sep_style=self.sep_style,
205
+ sep=self.sep,
206
+ sep2=self.sep2,
207
+ stop_str=self.stop_str,
208
+ stop_token_ids=self.stop_token_ids,
209
+ )
210
+
211
+ def dict(self):
212
+ return {
213
+ "template_name": self.name,
214
+ "system_message": self.system_message,
215
+ "roles": self.roles,
216
+ "messages": self.messages,
217
+ "offset": self.offset,
218
+ }
219
+
220
+
221
+ # A global registry for all conversation templates
222
+ conv_templates: Dict[str, Conversation] = {}
223
+
224
+
225
+ def register_conv_template(template: Conversation, override: bool = False):
226
+ """Register a new conversation template."""
227
+ if not override:
228
+ assert (
229
+ template.name not in conv_templates
230
+ ), f"{template.name} has been registered."
231
+
232
+ conv_templates[template.name] = template
233
+
234
+
235
+ def get_conv_template(name: str) -> Conversation:
236
+ """Get a conversation template."""
237
+ return conv_templates[name].copy()
238
+
239
+
240
+ # llava_llama2 template
241
+ register_conv_template(
242
+ Conversation(
243
+ name="llava_llama2",
244
+ system_message="You are a helpful language and vision assistant. "
245
+ "You are able to understand the visual content that the user provides, "
246
+ "and assist the user with a variety of tasks using natural language.",
247
+ system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
248
+ roles=("[INST]", "[/INST]"),
249
+ messages=(),
250
+ offset=0,
251
+ sep_style=SeparatorStyle.LLAMA2,
252
+ sep=" ",
253
+ sep2=" </s><s>",
254
+ stop_token_ids=[2],
255
+ )
256
+ )
257
+
258
+ # llama2 template
259
+ # reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
260
+ register_conv_template(
261
+ Conversation(
262
+ name="llama-2",
263
+ system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
264
+ roles=("[INST]", "[/INST]"),
265
+ messages=(),
266
+ offset=0,
267
+ sep_style=SeparatorStyle.LLAMA2,
268
+ sep=" ",
269
+ sep2=" </s><s>",
270
+ stop_token_ids=[2],
271
+ )
272
+ )
273
+
274
+
275
+ # deepseek template
276
+ register_conv_template(
277
+ Conversation(
278
+ name="deepseek_old",
279
+ system_template="{system_message}",
280
+ # system_message="You are a helpful assistant. Please answer truthfully and write out your "
281
+ # "thinking step by step to be sure you get the right answer.",
282
+ system_message="",
283
+ roles=("User", "Assistant"),
284
+ messages=(),
285
+ offset=0,
286
+ sep_style=SeparatorStyle.DeepSeek,
287
+ sep="\n\n",
288
+ sep2="<|end▁of▁sentence|>",
289
+ stop_token_ids=[100001],
290
+ stop_str=["User:", "<|end▁of▁sentence|>"],
291
+ )
292
+ )
293
+ register_conv_template(
294
+ Conversation(
295
+ name="deepseek",
296
+ system_template="{system_message}",
297
+ # system_message="You are a helpful assistant. Please answer truthfully and write out your "
298
+ # "thinking step by step to be sure you get the right answer.",
299
+ system_message="",
300
+ roles=("<|User|>", "<|Assistant|>"),
301
+ messages=(),
302
+ offset=0,
303
+ sep_style=SeparatorStyle.DeepSeek,
304
+ sep="\n\n",
305
+ sep2="<|end▁of▁sentence|>",
306
+ stop_token_ids=[100001],
307
+ stop_str=["<|User|>", "<|end▁of▁sentence|>"]
308
+ )
309
+ )
310
+
311
+ register_conv_template(
312
+ Conversation(
313
+ name="plain",
314
+ system_template="",
315
+ system_message="",
316
+ roles=("", ""),
317
+ messages=(),
318
+ offset=0,
319
+ sep_style=SeparatorStyle.PLAIN,
320
+ sep="",
321
+ sep2="",
322
+ stop_token_ids=[2],
323
+ stop_str=["</s>"],
324
+ )
325
+ )
326
+
327
+
328
+ register_conv_template(
329
+ Conversation(
330
+ name="alignment",
331
+ system_template="",
332
+ system_message="",
333
+ roles=("", ""),
334
+ messages=(),
335
+ offset=0,
336
+ sep_style=SeparatorStyle.ALIGNMENT,
337
+ sep="",
338
+ sep2="",
339
+ stop_token_ids=[2],
340
+ stop_str=["</s>"],
341
+ )
342
+ )
343
+
344
+
345
+ if __name__ == "__main__":
346
+ # print("Llama-2 template:")
347
+ # conv = get_conv_template("llama-2")
348
+ # conv.set_system_message("You are a helpful, respectful and honest assistant.")
349
+ # conv.append_message(conv.roles[0], "Hello!")
350
+ # conv.append_message(conv.roles[1], "Hi!")
351
+ # conv.append_message(conv.roles[0], "How are you?")
352
+ # conv.append_message(conv.roles[1], None)
353
+ # print(conv.get_prompt())
354
+
355
+ # print("\n")
356
+
357
+ print("deepseek template:")
358
+ conv = get_conv_template("deepseek")
359
+ conv.append_message(conv.roles[0], "Hello!")
360
+ conv.append_message(conv.roles[1], "Hi! This is Tony.")
361
+ conv.append_message(conv.roles[0], "Who are you?")
362
+ conv.append_message(conv.roles[1], "I am a helpful assistant.")
363
+ conv.append_message(conv.roles[0], "How are you?")
364
+ conv.append_message(conv.roles[1], None)
365
+ print(conv.get_prompt())
image_processing_vlm.py ADDED
@@ -0,0 +1,208 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ from typing import List, Tuple, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+ import torchvision
25
+ import torchvision.transforms.functional
26
+ from PIL import Image
27
+ from transformers import AutoImageProcessor, PretrainedConfig
28
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
29
+ from transformers.image_utils import to_numpy_array
30
+ from transformers.utils import logging
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+ ImageType = Union[np.ndarray, torch.Tensor, Image.Image]
35
+ IMAGENET_MEAN = (0.48145466, 0.4578275, 0.40821073)
36
+ IMAGENET_STD = (0.26862954, 0.26130258, 0.27577711)
37
+ IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
38
+ IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
39
+
40
+
41
+ def expand2square(pil_img, background_color):
42
+ width, height = pil_img.size
43
+ if width == height:
44
+ return pil_img
45
+ elif width > height:
46
+ result = Image.new(pil_img.mode, (width, width), background_color)
47
+ result.paste(pil_img, (0, (width - height) // 2))
48
+ return result
49
+ else:
50
+ result = Image.new(pil_img.mode, (height, height), background_color)
51
+ result.paste(pil_img, ((height - width) // 2, 0))
52
+ return result
53
+
54
+
55
+ class VLMImageProcessorConfig(PretrainedConfig):
56
+ model_type = "deepseek_vlm"
57
+ image_size: int
58
+ min_size: int
59
+ image_mean: Union[Tuple[float, float, float], List[float]]
60
+ image_std: Union[Tuple[float, float, float], List[float]]
61
+ rescale_factor: float
62
+ do_normalize: bool
63
+
64
+ def __init__(
65
+ self,
66
+ image_size: int,
67
+ min_size: int = 14,
68
+ image_mean: Union[Tuple[float, float, float], List[float]] = (
69
+ 0.48145466,
70
+ 0.4578275,
71
+ 0.40821073,
72
+ ),
73
+ image_std: Union[Tuple[float, float, float], List[float]] = (
74
+ 0.26862954,
75
+ 0.26130258,
76
+ 0.27577711,
77
+ ),
78
+ rescale_factor: float = 1.0 / 255.0,
79
+ do_normalize: bool = True,
80
+ **kwargs,
81
+ ):
82
+ self.image_size = image_size
83
+ self.min_size = min_size
84
+ self.image_mean = image_mean
85
+ self.image_std = image_std
86
+ self.rescale_factor = rescale_factor
87
+ self.do_normalize = do_normalize
88
+
89
+ super().__init__(**kwargs)
90
+
91
+
92
+ class VLMImageProcessor(BaseImageProcessor):
93
+ model_input_names = ["pixel_values"]
94
+
95
+ def __init__(
96
+ self,
97
+ image_size: int,
98
+ min_size: int = 14,
99
+ image_mean: Union[Tuple[float, float, float], List[float]] = (
100
+ 0.48145466,
101
+ 0.4578275,
102
+ 0.40821073,
103
+ ),
104
+ image_std: Union[Tuple[float, float, float], List[float]] = (
105
+ 0.26862954,
106
+ 0.26130258,
107
+ 0.27577711,
108
+ ),
109
+ rescale_factor: float = 1.0 / 255.0,
110
+ do_normalize: bool = True,
111
+ **kwargs,
112
+ ):
113
+ super().__init__(**kwargs)
114
+
115
+ self.image_size = image_size
116
+ self.rescale_factor = rescale_factor
117
+ self.image_mean = image_mean
118
+ self.image_std = image_std
119
+ self.min_size = min_size
120
+ self.do_normalize = do_normalize
121
+
122
+ if image_mean is None:
123
+ self.background_color = (127, 127, 127)
124
+ else:
125
+ self.background_color = tuple([int(x * 255) for x in image_mean])
126
+
127
+ def resize(self, pil_img: Image) -> np.ndarray:
128
+ """
129
+
130
+ Args:
131
+ pil_img (PIL.Image): an RGB PIL.Image of shape [H, W, 3]
132
+
133
+ Returns:
134
+ x (np.ndarray): [3, self.image_size, self.image_size]
135
+ """
136
+
137
+ width, height = pil_img.size
138
+ max_size = max(width, height)
139
+
140
+ size = [
141
+ max(int(height / max_size * self.image_size), self.min_size),
142
+ max(int(width / max_size * self.image_size), self.min_size),
143
+ ]
144
+
145
+ if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0:
146
+ print(f"orig size = {pil_img.size}, new size = {size}")
147
+ raise ValueError("Invalid size!")
148
+
149
+ pil_img = torchvision.transforms.functional.resize(
150
+ pil_img,
151
+ size,
152
+ interpolation=torchvision.transforms.functional.InterpolationMode.BICUBIC,
153
+ antialias=True,
154
+ )
155
+
156
+ pil_img = expand2square(pil_img, self.background_color)
157
+ x = to_numpy_array(pil_img)
158
+
159
+ # [H, W, 3] -> [3, H, W]
160
+ x = np.transpose(x, (2, 0, 1))
161
+
162
+ return x
163
+
164
+ def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature:
165
+ # resize and pad to [self.image_size, self.image_size]
166
+ # then convert from [H, W, 3] to [3, H, W]
167
+ images: List[np.ndarray] = [self.resize(image) for image in images]
168
+
169
+ # rescale from [0, 255] -> [0, 1]
170
+ images = [
171
+ self.rescale(
172
+ image=image,
173
+ scale=self.rescale_factor,
174
+ input_data_format="channels_first",
175
+ )
176
+ for image in images
177
+ ]
178
+
179
+ # normalize
180
+ if self.do_normalize:
181
+ images = [
182
+ self.normalize(
183
+ image=image,
184
+ mean=self.image_mean,
185
+ std=self.image_std,
186
+ input_data_format="channels_first",
187
+ )
188
+ for image in images
189
+ ]
190
+
191
+ data = {"pixel_values": images}
192
+ return BatchFeature(data=data, tensor_type=return_tensors)
193
+
194
+ @property
195
+ def default_shape(self):
196
+ return [3, self.image_size, self.image_size]
197
+
198
+
199
+ AutoImageProcessor.register(VLMImageProcessorConfig, VLMImageProcessor)
200
+
201
+
202
+ if __name__ == "__main__":
203
+ image_processor = VLMImageProcessor(
204
+ image_size=1024,
205
+ image_mean=IMAGENET_INCEPTION_MEAN,
206
+ image_std=IMAGENET_INCEPTION_STD,
207
+ do_normalize=True,
208
+ )
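
Building on the __main__ block above, a small sketch of the full preprocess pipeline (resize, pad to square, rescale, normalize), assuming the module is importable as image_processing_vlm:

import numpy as np
from PIL import Image
from image_processing_vlm import (
    VLMImageProcessor,
    IMAGENET_INCEPTION_MEAN,
    IMAGENET_INCEPTION_STD,
)

processor = VLMImageProcessor(
    image_size=384,
    image_mean=IMAGENET_INCEPTION_MEAN,
    image_std=IMAGENET_INCEPTION_STD,
    do_normalize=True,
)

# A dummy non-square RGB image: resize() scales the long side to image_size,
# then expand2square() pads the short side with background_color (127, 127, 127).
img = Image.fromarray(np.random.randint(0, 256, (300, 500, 3), dtype=np.uint8))
batch = processor.preprocess([img], return_tensors="pt")
print(batch["pixel_values"].shape)  # torch.Size([1, 3, 384, 384])
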
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fb1579e90ef1537a0ec49d73f910158365a9dd73ca55a1f178278dfa7d7727f0
3
+ size 317149484
modeling_vlm.py ADDED
@@ -0,0 +1,205 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ import torch
21
+ from einops import rearrange
22
+ from transformers import (
23
+ AutoConfig,
24
+ AutoModelForCausalLM,
25
+ LlamaConfig,
26
+ LlamaForCausalLM,
27
+ PreTrainedModel,
28
+ GenerationMixin
29
+ )
30
+ from transformers.configuration_utils import PretrainedConfig
31
+
32
+ from .clip_encoder import CLIPVisionTower
33
+ from .siglip_vit import create_siglip_vit
34
+ from .projector import MlpProjector
35
+ from .configuration_vlm import AttrDict, MultiModalityConfig, VisionConfig, AlignerConfig, GenVisionConfig, GenHeadConfig, GenAlignerConfig
36
+
37
+
38
+ class vision_head(torch.nn.Module):
39
+ def __init__(self, params):
40
+ super().__init__()
41
+ self.output_mlp_projector = torch.nn.Linear(
42
+ params.n_embed, params.image_token_embed
43
+ )
44
+ self.vision_activation = torch.nn.GELU()
45
+ self.vision_head = torch.nn.Linear(
46
+ params.image_token_embed, params.image_token_size
47
+ )
48
+
49
+ def forward(self, x):
50
+ x = self.output_mlp_projector(x)
51
+ x = self.vision_activation(x)
52
+ x = self.vision_head(x)
53
+ return x
54
+
55
+
56
+ def model_name_to_cls(cls_name):
57
+ if "MlpProjector" in cls_name:
58
+ cls = MlpProjector
59
+
60
+ elif "CLIPVisionTower" in cls_name:
61
+ cls = CLIPVisionTower
62
+
63
+ elif "VQ" in cls_name:
64
+ from janus.models.vq_model import VQ_models
65
+
66
+ cls = VQ_models[cls_name]
67
+ elif "vision_head" in cls_name:
68
+ cls = vision_head
69
+ else:
70
+ raise ValueError(f"class_name {cls_name} is invalid.")
71
+
72
+ return cls
73
+
74
+
75
+ class MultiModalityPreTrainedModel(PreTrainedModel):
76
+ config_class = MultiModalityConfig
77
+ base_model_prefix = "multi_modality"
78
+ _no_split_modules = []
79
+ _skip_keys_device_placement = "past_key_values"
80
+
81
+
82
+ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
83
+ def __init__(self, config: MultiModalityConfig):
84
+ super().__init__(config)
85
+
86
+ vision_config = config.vision_config
87
+ vision_cls = model_name_to_cls(vision_config.cls)
88
+ self.vision_model = vision_cls(**vision_config.params)
89
+
90
+ aligner_config = config.aligner_config
91
+ aligner_cls = model_name_to_cls(aligner_config.cls)
92
+ self.aligner = aligner_cls(aligner_config.params)
93
+
94
+ gen_vision_config = config.gen_vision_config
95
+ gen_vision_cls = model_name_to_cls(gen_vision_config.cls)
96
+ self.gen_vision_model = gen_vision_cls()
97
+
98
+ gen_aligner_config = config.gen_aligner_config
99
+ gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls)
100
+ self.gen_aligner = gen_aligner_cls(gen_aligner_config.params)
101
+
102
+ gen_head_config = config.gen_head_config
103
+ gen_head_cls = model_name_to_cls(gen_head_config.cls)
104
+ self.gen_head = gen_head_cls(gen_head_config.params)
105
+
106
+ self.gen_embed = torch.nn.Embedding(
107
+ gen_vision_config.params.image_token_size, gen_vision_config.params.n_embed
108
+ )
109
+
110
+ language_config = config.language_config
111
+ self.language_model = LlamaForCausalLM(language_config)
112
+
113
+ def prepare_inputs_embeds(
114
+ self,
115
+ input_ids: torch.LongTensor,
116
+ pixel_values: torch.FloatTensor,
117
+ images_seq_mask: torch.LongTensor,
118
+ images_emb_mask: torch.LongTensor,
119
+ **kwargs,
120
+ ):
121
+ """
122
+
123
+ Args:
124
+ input_ids (torch.LongTensor): [b, T]
125
+ pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
126
+ images_seq_mask (torch.BoolTensor): [b, T]
127
+ images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
128
+
129
+ assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)
130
+
131
+ Returns:
132
+ input_embeds (torch.Tensor): [b, T, D]
133
+ """
134
+
135
+ bs, n = pixel_values.shape[0:2]
136
+ images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
137
+ # [b x n, T2, D]
138
+ images_embeds = self.aligner(self.vision_model(images))
139
+
140
+ # [b x n, T2, D] -> [b, n x T2, D]
141
+ images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
142
+ # [b, n, T2] -> [b, n x T2]
143
+ images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
144
+
145
+ # [b, T, D]
146
+ input_ids[input_ids < 0] = 0  # image placeholder positions are overwritten with image embeddings below
147
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
148
+
149
+ # replace with the image embeddings
150
+ inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
151
+
152
+ return inputs_embeds
153
+
154
+ def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
155
+ return self.gen_aligner(self.gen_embed(image_ids))
156
+
157
+ def forward(
158
+ self,
159
+ input_ids,
160
+ pixel_values=None,
161
+ past_key_values=None,
162
+ inputs_embeds=None,
163
+ attention_mask=None,
164
+ position_ids=None,
165
+ images_seq_mask=None,
166
+ images_emb_mask=None,
167
+ **kwargs,
168
+ ):
169
+ if inputs_embeds is None:
170
+ inputs_embeds = self.prepare_inputs_embeds(input_ids, pixel_values, images_seq_mask, images_emb_mask, **kwargs)
171
+ return self.language_model.forward(
172
+ input_ids=None,
173
+ inputs_embeds=inputs_embeds,
174
+ attention_mask=attention_mask,
175
+ position_ids=position_ids,
176
+ past_key_values=past_key_values,
177
+ **kwargs,
178
+ )
179
+
180
+ def generate(
181
+ self,
182
+ input_ids=None,
183
+ pixel_values=None,
184
+ past_key_values=None,
185
+ inputs_embeds=None,
186
+ attention_mask=None,
187
+ position_ids=None,
188
+ images_seq_mask=None,
189
+ images_emb_mask=None,
190
+ **kwargs
191
+ ):
192
+ if inputs_embeds is None:
193
+ inputs_embeds = self.prepare_inputs_embeds(input_ids, pixel_values, images_seq_mask, images_emb_mask, **kwargs)
194
+ return self.language_model.generate(inputs_embeds=inputs_embeds, past_key_values=past_key_values, attention_mask=attention_mask, position_ids=position_ids, **kwargs)
195
+
196
+
197
+
198
+
199
+ AutoConfig.register("vision", VisionConfig)
200
+ AutoConfig.register("aligner", AlignerConfig)
201
+ AutoConfig.register("gen_vision", GenVisionConfig)
202
+ AutoConfig.register("gen_aligner", GenAlignerConfig)
203
+ AutoConfig.register("gen_head", GenHeadConfig)
204
+ AutoConfig.register("multi_modality", MultiModalityConfig)
205
+ AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM)
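
Because the configs and the causal LM are registered with the Auto classes above (and config.json carries matching auto_map entries), the model can be loaded with trust_remote_code. A hedged sketch; the repo id below is an assumption based on this upload:

from transformers import AutoModelForCausalLM

# Hypothetical repo id -- adjust to wherever these files are hosted.
model = AutoModelForCausalLM.from_pretrained(
    "katuni4ka/tiny-random-janus", trust_remote_code=True
)

print(type(model).__name__)                 # MultiModalityCausalLM
print(type(model.vision_model).__name__)    # CLIPVisionTower
print(type(model.language_model).__name__)  # LlamaForCausalLM
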
preprocessor_config.json ADDED
@@ -0,0 +1,26 @@
1
+ {
2
+ "background_color": [
3
+ 127,
4
+ 127,
5
+ 127
6
+ ],
7
+ "do_normalize": true,
8
+ "image_mean": [
9
+ 0.5,
10
+ 0.5,
11
+ 0.5
12
+ ],
13
+ "image_processor_type": "VLMImageProcessor",
14
+ "image_size": 384,
15
+ "image_std": [
16
+ 0.5,
17
+ 0.5,
18
+ 0.5
19
+ ],
20
+ "min_size": 14,
21
+ "processor_class": "VLChatProcessor",
22
+ "rescale_factor": 0.00392156862745098,
23
+ "auto_map": {
24
+ "AutoProcessor": "processing_vlm.VLChatProcessor"
25
+ }
26
+ }
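
The auto_map entry above wires AutoProcessor to processing_vlm.VLChatProcessor (defined in the next file). A hedged loading sketch; the repo id is again an assumption:

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(
    "katuni4ka/tiny-random-janus", trust_remote_code=True  # hypothetical repo id
)

print(type(processor).__name__)              # VLChatProcessor
print(processor.image_processor.image_size)  # 384, from this preprocessor_config.json
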
processing_vlm.py ADDED
@@ -0,0 +1,418 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ from dataclasses import dataclass
21
+ from typing import Dict, List
22
+
23
+ import torch
24
+ from PIL.Image import Image
25
+ from transformers import LlamaTokenizerFast
26
+ from transformers.processing_utils import ProcessorMixin
27
+
28
+ from .image_processing_vlm import VLMImageProcessor
29
+ from .conversation import get_conv_template
30
+
31
+
32
+ class DictOutput(object):
33
+ def keys(self):
34
+ return self.__dict__.keys()
35
+
36
+ def __getitem__(self, item):
37
+ return self.__dict__[item]
38
+
39
+ def __setitem__(self, key, value):
40
+ self.__dict__[key] = value
41
+
42
+
43
+ @dataclass
44
+ class VLChatProcessorOutput(DictOutput):
45
+ sft_format: str
46
+ input_ids: torch.Tensor
47
+ pixel_values: torch.Tensor
48
+ num_image_tokens: torch.IntTensor
49
+
50
+ def __len__(self):
51
+ return len(self.input_ids)
52
+
53
+
54
+ @dataclass
55
+ class BatchedVLChatProcessorOutput(DictOutput):
56
+ sft_format: List[str]
57
+ input_ids: torch.Tensor
58
+ pixel_values: torch.Tensor
59
+ attention_mask: torch.Tensor
60
+ images_seq_mask: torch.BoolTensor
61
+ images_emb_mask: torch.BoolTensor
62
+
63
+ def to(self, device, dtype=torch.bfloat16):
64
+ self.input_ids = self.input_ids.to(device)
65
+ self.attention_mask = self.attention_mask.to(device)
66
+ self.images_seq_mask = self.images_seq_mask.to(device)
67
+ self.images_emb_mask = self.images_emb_mask.to(device)
68
+ self.pixel_values = self.pixel_values.to(device=device, dtype=dtype)
69
+ return self
70
+
71
+
72
+ class VLChatProcessor(ProcessorMixin):
73
+ image_processor_class = "AutoImageProcessor"
74
+ tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
75
+
76
+ attributes = ["image_processor", "tokenizer"]
77
+
78
+ system_prompt = (
79
+ "You are a helpful language and vision assistant. "
80
+ "You are able to understand the visual content that the user provides, "
81
+ "and assist the user with a variety of tasks using natural language."
82
+ )
83
+
84
+ def __init__(
85
+ self,
86
+ image_processor: VLMImageProcessor,
87
+ tokenizer: LlamaTokenizerFast,
88
+ image_tag: str = "<image_placeholder>",
89
+ image_start_tag: str = "<begin_of_image>",
90
+ image_end_tag: str = "<end_of_image>",
91
+ pad_tag: str = "<|▁pad▁|>",
92
+ num_image_tokens: int = 576,
93
+ add_special_token: bool = False,
94
+ sft_format: str = "deepseek",
95
+ mask_prompt: bool = True,
96
+ ignore_id: int = -100,
97
+ **kwargs,
98
+ ):
99
+ self.image_processor = image_processor
100
+ self.tokenizer = tokenizer
101
+
102
+ image_id = self.tokenizer.vocab.get(image_tag)
103
+ if image_id is None:
104
+ special_tokens = [image_tag]
105
+ special_tokens_dict = {"additional_special_tokens": special_tokens}
106
+ self.tokenizer.add_special_tokens(special_tokens_dict)
107
+ print(f"Add image tag = {image_tag} to the tokenizer")
108
+
109
+ self.image_tag = image_tag
110
+ self.image_start_tag = image_start_tag
111
+ self.image_end_tag = image_end_tag
112
+ self.pad_tag = pad_tag
113
+
114
+ self.num_image_tokens = num_image_tokens
115
+ self.add_special_token = add_special_token
116
+ self.sft_format = sft_format
117
+ self.mask_prompt = mask_prompt
118
+ self.ignore_id = ignore_id
119
+
120
+ super().__init__(
121
+ image_processor,
122
+ tokenizer,
123
+ image_tag,
124
+ num_image_tokens,
125
+ add_special_token,
126
+ sft_format,
127
+ mask_prompt,
128
+ ignore_id,
129
+ **kwargs,
130
+ )
131
+
132
+ def new_chat_template(self):
133
+ conv = get_conv_template(self.sft_format)
134
+ conv.set_system_message(self.system_prompt)
135
+ return conv
136
+
137
+ def apply_sft_template_for_multi_turn_prompts(
138
+ self,
139
+ conversations: List[Dict[str, str]],
140
+ sft_format: str = "deepseek",
141
+ system_prompt: str = "",
142
+ ):
143
+ """
144
+ Applies the SFT template to conversation.
145
+
146
+ An example of conversation:
147
+ conversation = [
148
+ {
149
+ "role": "User",
150
+ "content": "<image_placeholder> is Figure 1.\n<image_placeholder> is Figure 2.\nWhich image is brighter?",
151
+ "images": [
152
+ "./multi-images/attribute_comparison_1.png",
153
+ "./multi-images/attribute_comparison_2.png"
154
+ ]
155
+ },
156
+ {
157
+ "role": "Assistant",
158
+ "content": ""
159
+ }
160
+ ]
161
+
162
+ Args:
163
+ conversations (List[Dict]): A conversation with a List of Dict[str, str] text.
164
+ sft_format (str, optional): The format of the SFT template to use. Defaults to "deepseek".
165
+ system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "".
166
+
167
+ Returns:
168
+ sft_prompt (str): The formatted text.
169
+ """
170
+
171
+ conv = get_conv_template(sft_format)
172
+ conv.set_system_message(system_prompt)
173
+ for message in conversations:
174
+ conv.append_message(message["role"], message["content"].strip())
175
+ sft_prompt = conv.get_prompt().strip()
176
+
177
+ return sft_prompt
178
+
179
+ @property
180
+ def image_token(self):
181
+ return self.image_tag
182
+
183
+ @property
184
+ def image_id(self):
185
+ image_id = self.tokenizer.vocab.get(self.image_tag)
186
+ return image_id
187
+
188
+ @property
189
+ def image_start_id(self):
190
+ image_start_id = self.tokenizer.vocab.get(self.image_start_tag)
191
+ return image_start_id
192
+
193
+ @property
194
+ def image_end_id(self):
195
+ image_end_id = self.tokenizer.vocab.get(self.image_end_tag)
196
+ return image_end_id
197
+
198
+ @property
199
+ def image_start_token(self):
200
+ return self.image_start_tag
201
+
202
+ @property
203
+ def image_end_token(self):
204
+ return self.image_end_tag
205
+
206
+ @property
207
+ def pad_id(self):
208
+ pad_id = self.tokenizer.vocab.get(self.pad_tag)
209
+ # pad_id = self.tokenizer.pad_token_id
210
+ # if pad_id is None:
211
+ # pad_id = self.tokenizer.eos_token_id
212
+
213
+ return pad_id
214
+
215
+ def add_image_token(
216
+ self,
217
+ image_indices: List[int],
218
+ input_ids: torch.LongTensor,
219
+ ):
220
+ """
221
+
222
+ Args:
223
+ image_indices (List[int]): [index_0, index_1, ..., index_j]
224
+ input_ids (torch.LongTensor): [N]
225
+
226
+ Returns:
227
+ input_ids (torch.LongTensor): [N + image tokens]
228
+ num_image_tokens (torch.IntTensor): [n_images]
229
+ """
230
+
231
+ input_slices = []
232
+
233
+ start = 0
234
+ for index in image_indices:
235
+ if self.add_special_token:
236
+ end = index + 1
237
+ else:
238
+ end = index
239
+
240
+ # original text tokens
241
+ input_slices.append(input_ids[start:end])
242
+
243
+ # add boi, image tokens, eoi and set the mask as False
244
+ input_slices.append(self.image_start_id * torch.ones((1), dtype=torch.long))
245
+ input_slices.append(
246
+ self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long)
247
+ )
248
+ input_slices.append(self.image_end_id * torch.ones((1), dtype=torch.long))
249
+ start = index + 1
250
+
251
+ # the left part
252
+ input_slices.append(input_ids[start:])
253
+
254
+ # concat all slices
255
+ input_ids = torch.cat(input_slices, dim=0)
256
+ num_image_tokens = torch.IntTensor([self.num_image_tokens] * len(image_indices))
257
+
258
+ return input_ids, num_image_tokens
259
+
260
+ def process_one(
261
+ self,
262
+ prompt: str = None,
263
+ conversations: List[Dict[str, str]] = None,
264
+ images: List[Image] = None,
265
+ **kwargs,
266
+ ):
267
+ """
268
+
269
+ Args:
270
+ prompt (str): the formatted prompt;
271
+ conversations (List[Dict]): conversations with a list of messages;
272
+ images (List[ImageType]): the list of images;
273
+ **kwargs:
274
+
275
+ Returns:
276
+ outputs (BaseProcessorOutput): the output of the processor,
277
+ - input_ids (torch.LongTensor): [N + image tokens]
278
+ - target_ids (torch.LongTensor): [N + image tokens]
279
+ - images (torch.FloatTensor): [n_images, 3, H, W]
280
+ - image_id (int): the id of the image token
281
+ - num_image_tokens (List[int]): the number of image tokens
282
+ """
283
+
284
+ assert (
285
+ prompt is None or conversations is None
286
+ ), "prompt and conversations cannot be used at the same time."
287
+
288
+ if prompt is None:
289
+ # apply sft format
290
+ sft_format = self.apply_sft_template_for_multi_turn_prompts(
291
+ conversations=conversations,
292
+ sft_format=self.sft_format,
293
+ system_prompt=self.system_prompt,
294
+ )
295
+ else:
296
+ sft_format = prompt
297
+
298
+ # tokenize
299
+ input_ids = self.tokenizer.encode(sft_format)
300
+ input_ids = torch.LongTensor(input_ids)
301
+
302
+ # add image tokens to the input_ids
303
+ image_token_mask: torch.BoolTensor = input_ids == self.image_id
304
+ image_indices = image_token_mask.nonzero()
305
+ input_ids, num_image_tokens = self.add_image_token(
306
+ image_indices=image_indices,
307
+ input_ids=input_ids,
308
+ )
309
+
310
+ # load images
311
+ images_outputs = self.image_processor(images, return_tensors="pt")
312
+
313
+ prepare = VLChatProcessorOutput(
314
+ sft_format=sft_format,
315
+ input_ids=input_ids,
316
+ pixel_values=images_outputs.pixel_values,
317
+ num_image_tokens=num_image_tokens,
318
+ )
319
+
320
+ return prepare
321
+
322
+ def __call__(
323
+ self,
324
+ *,
325
+ prompt: str = None,
326
+ conversations: List[Dict[str, str]] = None,
327
+ images: List[Image] = None,
328
+ force_batchify: bool = True,
329
+ **kwargs,
330
+ ):
331
+ """
332
+
333
+ Args:
334
+ prompt (str): the formatted prompt;
335
+ conversations (List[Dict]): conversations with a list of messages;
336
+ images (List[ImageType]): the list of images;
337
+ force_batchify (bool): force batchify the inputs;
338
+ **kwargs:
339
+
340
+ Returns:
341
+ outputs (BaseProcessorOutput): the output of the processor,
342
+ - input_ids (torch.LongTensor): [N + image tokens]
343
+ - images (torch.FloatTensor): [n_images, 3, H, W]
344
+ - image_id (int): the id of the image token
345
+ - num_image_tokens (List[int]): the number of image tokens
346
+ """
347
+
348
+ prepare = self.process_one(
349
+ prompt=prompt, conversations=conversations, images=images
350
+ )
351
+
352
+ if force_batchify:
353
+ prepare = self.batchify([prepare])
354
+
355
+ return prepare
356
+
357
+ def batchify(
358
+ self, prepare_list: List[VLChatProcessorOutput]
359
+ ) -> BatchedVLChatProcessorOutput:
360
+ """
361
+ Preprocesses the inputs for multimodal inference.
362
+
363
+ Args:
364
+ prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput.
365
+
366
+ Returns:
367
+ BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference.
368
+ """
369
+
370
+ batch_size = len(prepare_list)
371
+ sft_format = []
372
+ n_images = []
373
+ seq_lens = []
374
+ for prepare in prepare_list:
375
+ n_images.append(len(prepare.num_image_tokens))
376
+ seq_lens.append(len(prepare))
377
+
378
+ input_token_max_len = max(seq_lens)
379
+ max_n_images = max(1, max(n_images))
380
+
381
+ batched_input_ids = torch.full(
382
+ (batch_size, input_token_max_len), self.pad_id
383
+ ).long() # FIXME
384
+ batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long()
385
+ batched_pixel_values = torch.zeros(
386
+ (batch_size, max_n_images, *self.image_processor.default_shape)
387
+ ).float()
388
+ batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool()
389
+ batched_images_emb_mask = torch.zeros(
390
+ (batch_size, max_n_images, self.num_image_tokens)
391
+ ).bool()
392
+
393
+ for i, prepare in enumerate(prepare_list):
394
+ input_ids = prepare.input_ids
395
+ seq_len = len(prepare)
396
+ n_image = len(prepare.num_image_tokens)
397
+ # left-padding
398
+ batched_attention_mask[i, -seq_len:] = 1
399
+ batched_input_ids[i, -seq_len:] = torch.LongTensor(input_ids)
400
+ batched_images_seq_mask[i, -seq_len:] = input_ids == self.image_id
401
+
402
+ if n_image > 0:
403
+ batched_pixel_values[i, :n_image] = prepare.pixel_values
404
+ for j, n_image_tokens in enumerate(prepare.num_image_tokens):
405
+ batched_images_emb_mask[i, j, :n_image_tokens] = True
406
+
407
+ sft_format.append(prepare.sft_format)
408
+
409
+ batched_prepares = BatchedVLChatProcessorOutput(
410
+ input_ids=batched_input_ids,
411
+ attention_mask=batched_attention_mask,
412
+ pixel_values=batched_pixel_values,
413
+ images_seq_mask=batched_images_seq_mask,
414
+ images_emb_mask=batched_images_emb_mask,
415
+ sft_format=sft_format,
416
+ )
417
+
418
+ return batched_prepares
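For orientation, a minimal usage sketch of the processor defined above. This is an illustration, not part of the upload: it assumes the class is wired into transformers' ProcessorMixin machinery (so from_pretrained is available), that this repository's files sit in the working directory, and "example.png" is a placeholder image path.

from PIL import Image
from processing_vlm import VLChatProcessor

# Placeholder path: point this at the directory containing these uploaded files.
processor = VLChatProcessor.from_pretrained("./")
conversation = [
    {
        "role": "User",
        "content": "<image_placeholder> Describe this figure.",
        "images": ["example.png"],  # placeholder image path
    },
    {"role": "Assistant", "content": ""},
]
pil_images = [Image.open(p).convert("RGB") for p in conversation[0]["images"]]
# __call__ is keyword-only; force_batchify=True routes the single sample through batchify() above.
inputs = processor(conversations=conversation, images=pil_images, force_batchify=True)
print(inputs.input_ids.shape)     # [1, seq_len]; each placeholder expands to <begin_of_image> + 576 image ids + <end_of_image>
print(inputs.pixel_values.shape)  # [1, n_images, 3, H, W]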
processor_config.json ADDED
@@ -0,0 +1,14 @@
1
+ {
2
+ "add_special_token": false,
3
+ "ignore_id": -100,
4
+ "image_end_tag": "<end_of_image>",
5
+ "image_start_tag": "<begin_of_image>",
6
+ "image_tag": "<image_placeholder>",
7
+ "mask_prompt": true,
8
+ "num_image_tokens": 576,
9
+ "processor_class": "VLChatProcessor",
10
+ "sft_format": "deepseek",
11
+ "auto_map": {
12
+ "AutoProcessor": "processing_vlm.VLChatProcessor"
13
+ }
14
+ }
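Because "auto_map" above points AutoProcessor at processing_vlm.VLChatProcessor, the processor can also be loaded generically. A minimal sketch, with the repository path left as a placeholder:

from transformers import AutoProcessor

# "path-or-repo-id" is a placeholder for this repository; trust_remote_code lets
# transformers import the processing_vlm module referenced in auto_map.
processor = AutoProcessor.from_pretrained("path-or-repo-id", trust_remote_code=True)
# The remaining keys in this JSON (image_tag, num_image_tokens=576, sft_format="deepseek", ...)
# are forwarded to VLChatProcessor.__init__ as keyword arguments.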
projector.py ADDED
@@ -0,0 +1,100 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ from typing import Tuple, Union
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ from .attrdict_config import AttrDict
25
+
26
+
27
+ class MlpProjector(nn.Module):
28
+ def __init__(self, cfg):
29
+ super().__init__()
30
+
31
+ self.cfg = cfg
32
+
33
+ if cfg.projector_type == "identity":
34
+ modules = nn.Identity()
35
+
36
+ elif cfg.projector_type == "linear":
37
+ modules = nn.Linear(cfg.input_dim, cfg.n_embed)
38
+
39
+ elif cfg.projector_type == "mlp_gelu":
40
+ mlp_depth = cfg.get("depth", 1)
41
+ modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
42
+ for _ in range(1, mlp_depth):
43
+ modules.append(nn.GELU())
44
+ modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
45
+ modules = nn.Sequential(*modules)
46
+
47
+ elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
48
+ mlp_depth = cfg.get("depth", 1)
49
+ self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
50
+ self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
51
+
52
+ modules = []
53
+ for _ in range(1, mlp_depth):
54
+ modules.append(nn.GELU())
55
+ modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
56
+ modules = nn.Sequential(*modules)
57
+
58
+ else:
59
+ raise ValueError(f"Unknown projector type: {cfg.projector_type}")
60
+
61
+ self.layers = modules
62
+
63
+ def forward(
64
+ self, x_or_tuple: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]
65
+ ):
66
+ """
67
+
68
+ Args:
69
+ x_or_tuple (Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]): if it is a tuple of torch.Tensor,
70
+ then it comes from the hybrid vision encoder, and x = (high_res_x, low_res_x);
71
+ otherwise it is the feature from the single vision encoder.
72
+
73
+ Returns:
74
+ x (torch.Tensor): [b, s, c]
75
+ """
76
+
77
+ if isinstance(x_or_tuple, tuple):
78
+ # self.cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
79
+ high_x, low_x = x_or_tuple
80
+ high_x = self.high_up_proj(high_x)
81
+ low_x = self.low_up_proj(low_x)
82
+ x = torch.concat([high_x, low_x], dim=-1)
83
+ else:
84
+ x = x_or_tuple
85
+
86
+ return self.layers(x)
87
+
88
+
89
+ if __name__ == "__main__":
90
+ cfg = AttrDict(
91
+ input_dim=1024,
92
+ n_embed=2048,
93
+ depth=2,
94
+ projector_type="low_high_hybrid_split_mlp_gelu",
95
+ )
96
+ inputs = (torch.rand(4, 576, 1024), torch.rand(4, 576, 1024))
97
+
98
+ m = MlpProjector(cfg)
99
+ out = m(inputs)
100
+ print(out.shape)
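The __main__ block above only exercises the hybrid split projector; for comparison, a minimal sketch of the plain "mlp_gelu" branch with the same assumed dimensions (run inside this module, where torch, AttrDict and MlpProjector are already available):

# Sketch: single-encoder path through MlpProjector; dimensions mirror the demo above.
cfg = AttrDict(input_dim=1024, n_embed=2048, depth=2, projector_type="mlp_gelu")
proj = MlpProjector(cfg)
x = torch.rand(4, 576, 1024)   # [batch, tokens, input_dim]
print(proj(x).shape)           # torch.Size([4, 576, 2048])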
siglip_vit.py ADDED
@@ -0,0 +1,691 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ # https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
21
+ import math
22
+ import warnings
23
+ from dataclasses import dataclass
24
+ from functools import partial
25
+ from typing import (
26
+ Callable,
27
+ Dict,
28
+ Final,
29
+ List,
30
+ Literal,
31
+ Optional,
32
+ Sequence,
33
+ Set,
34
+ Tuple,
35
+ Type,
36
+ Union,
37
+ )
38
+
39
+ import torch
40
+ import torch.nn as nn
41
+ import torch.nn.functional as F
42
+ from timm.layers import (
43
+ AttentionPoolLatent,
44
+ DropPath,
45
+ LayerType,
46
+ Mlp,
47
+ PatchDropout,
48
+ PatchEmbed,
49
+ resample_abs_pos_embed,
50
+ )
51
+ from timm.models._manipulate import checkpoint_seq, named_apply
52
+
53
+
54
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
55
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
56
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
57
+ def norm_cdf(x):
58
+ # Computes standard normal cumulative distribution function
59
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
60
+
61
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
62
+ warnings.warn(
63
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
64
+ "The distribution of values may be incorrect.",
65
+ stacklevel=2,
66
+ )
67
+
68
+ with torch.no_grad():
69
+ # Values are generated by using a truncated uniform distribution and
70
+ # then using the inverse CDF for the normal distribution.
71
+ # Get upper and lower cdf values
72
+ l = norm_cdf((a - mean) / std) # noqa: E741
73
+ u = norm_cdf((b - mean) / std)
74
+
75
+ # Uniformly fill tensor with values from [l, u], then translate to
76
+ # [2l-1, 2u-1].
77
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
78
+
79
+ # Use inverse cdf transform for normal distribution to get truncated
80
+ # standard normal
81
+ tensor.erfinv_()
82
+
83
+ # Transform to proper mean, std
84
+ tensor.mul_(std * math.sqrt(2.0))
85
+ tensor.add_(mean)
86
+
87
+ # Clamp to ensure it's in the proper range
88
+ tensor.clamp_(min=a, max=b)
89
+ return tensor
90
+
91
+
92
+ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
93
+ # type: (torch.Tensor, float, float, float, float) -> torch.Tensor
94
+ r"""The original timm.models.layers.weight_init.trunc_normal_ cannot handle bfloat16 yet, so here we first
95
+ convert the tensor to float32, apply trunc_normal_() in float32, and then convert it back to its original dtype.
96
+ Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn
97
+ from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
98
+ with values outside :math:`[a, b]` redrawn until they are within
99
+ the bounds. The method used for generating the random values works
100
+ best when :math:`a \leq \text{mean} \leq b`.
101
+ Args:
102
+ tensor: an n-dimensional `torch.Tensor`
103
+ mean: the mean of the normal distribution
104
+ std: the standard deviation of the normal distribution
105
+ a: the minimum cutoff value
106
+ b: the maximum cutoff value
107
+ Examples:
108
+ >>> w = torch.empty(3, 5)
109
+ >>> nn.init.trunc_normal_(w)
110
+ """
111
+
112
+ with torch.no_grad():
113
+ dtype = tensor.dtype
114
+ tensor_fp32 = tensor.float()
115
+ tensor_fp32 = _no_grad_trunc_normal_(tensor_fp32, mean, std, a, b)
116
+ tensor_dtype = tensor_fp32.to(dtype=dtype)
117
+ tensor.copy_(tensor_dtype)
118
+
119
+
120
+ def init_weights(self):
121
+ if self.pos_embed is not None:
122
+ trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
123
+ trunc_normal_(self.latent, std=self.latent_dim**-0.5)
124
+
125
+
126
+ def init_weights_vit_timm(module: nn.Module, name: str = "") -> None:
127
+ """ViT weight initialization, original timm impl (for reproducibility)"""
128
+ if isinstance(module, nn.Linear):
129
+ trunc_normal_(module.weight, std=0.02)
130
+ if module.bias is not None:
131
+ nn.init.zeros_(module.bias)
132
+ elif hasattr(module, "init_weights"):
133
+ module.init_weights()
134
+
135
+
136
+ class Attention(nn.Module):
137
+ fused_attn: Final[bool]
138
+
139
+ def __init__(
140
+ self,
141
+ dim: int,
142
+ num_heads: int = 8,
143
+ qkv_bias: bool = False,
144
+ qk_norm: bool = False,
145
+ attn_drop: float = 0.0,
146
+ proj_drop: float = 0.0,
147
+ norm_layer: nn.Module = nn.LayerNorm,
148
+ ) -> None:
149
+ super().__init__()
150
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
151
+ self.num_heads = num_heads
152
+ self.head_dim = dim // num_heads
153
+ self.scale = self.head_dim**-0.5
154
+ # self.fused_attn = use_fused_attn()
155
+ self.fused_attn = True
156
+
157
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
158
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
159
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
160
+ self.attn_drop = nn.Dropout(attn_drop)
161
+ self.proj = nn.Linear(dim, dim)
162
+ self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity()
163
+
164
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
165
+ B, N, C = x.shape
166
+ qkv = (
167
+ self.qkv(x)
168
+ .reshape(B, N, 3, self.num_heads, self.head_dim)
169
+ .permute(2, 0, 3, 1, 4)
170
+ )
171
+ q, k, v = qkv.unbind(0)
172
+ q, k = self.q_norm(q), self.k_norm(k)
173
+
174
+ if self.fused_attn:
175
+ x = F.scaled_dot_product_attention(
176
+ q,
177
+ k,
178
+ v,
179
+ dropout_p=self.attn_drop.p if self.training else 0.0,
180
+ )
181
+ else:
182
+ q = q * self.scale
183
+ attn = q @ k.transpose(-2, -1)
184
+ attn = attn.softmax(dim=-1)
185
+ attn = self.attn_drop(attn)
186
+ x = attn @ v
187
+
188
+ x = x.transpose(1, 2).reshape(B, N, C)
189
+ x = self.proj(x)
190
+ x = self.proj_drop(x)
191
+ return x
192
+
193
+
194
+ class LayerScale(nn.Module):
195
+ def __init__(
196
+ self,
197
+ dim: int,
198
+ init_values: float = 1e-5,
199
+ inplace: bool = False,
200
+ ) -> None:
201
+ super().__init__()
202
+ self.inplace = inplace
203
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
204
+
205
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
206
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
207
+
208
+
209
+ class Block(nn.Module):
210
+ def __init__(
211
+ self,
212
+ dim: int,
213
+ num_heads: int,
214
+ mlp_ratio: float = 4.0,
215
+ qkv_bias: bool = False,
216
+ qk_norm: bool = False,
217
+ proj_drop: float = 0.0,
218
+ attn_drop: float = 0.0,
219
+ init_values: Optional[float] = None,
220
+ drop_path: float = 0.0,
221
+ act_layer: nn.Module = nn.GELU,
222
+ norm_layer: nn.Module = nn.LayerNorm,
223
+ mlp_layer: nn.Module = Mlp,
224
+ ) -> None:
225
+ super().__init__()
226
+ self.norm1 = norm_layer(dim)
227
+ self.attn = Attention(
228
+ dim,
229
+ num_heads=num_heads,
230
+ qkv_bias=qkv_bias,
231
+ qk_norm=qk_norm,
232
+ attn_drop=attn_drop,
233
+ proj_drop=proj_drop,
234
+ norm_layer=norm_layer,
235
+ )
236
+ self.ls1 = (
237
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
238
+ )
239
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
240
+
241
+ self.norm2 = norm_layer(dim)
242
+ self.mlp = mlp_layer(
243
+ in_features=dim,
244
+ hidden_features=int(dim * mlp_ratio),
245
+ act_layer=act_layer,
246
+ drop=proj_drop,
247
+ )
248
+ self.ls2 = (
249
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
250
+ )
251
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
252
+
253
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
254
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
255
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
256
+ return x
257
+
258
+
259
+ class VisionTransformer(nn.Module):
260
+ """Vision Transformer
261
+
262
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
263
+ - https://arxiv.org/abs/2010.11929
264
+ """
265
+
266
+ dynamic_img_size: Final[bool]
267
+
268
+ def __init__(
269
+ self,
270
+ img_size: Union[int, Tuple[int, int]] = 224,
271
+ patch_size: Union[int, Tuple[int, int]] = 16,
272
+ in_chans: int = 3,
273
+ num_classes: int = 1000,
274
+ global_pool: Literal["", "avg", "token", "map"] = "token",
275
+ embed_dim: int = 768,
276
+ depth: int = 12,
277
+ num_heads: int = 12,
278
+ mlp_ratio: float = 4.0,
279
+ qkv_bias: bool = True,
280
+ qk_norm: bool = False,
281
+ init_values: Optional[float] = None,
282
+ class_token: bool = True,
283
+ no_embed_class: bool = False,
284
+ reg_tokens: int = 0,
285
+ pre_norm: bool = False,
286
+ fc_norm: Optional[bool] = None,
287
+ dynamic_img_size: bool = False,
288
+ dynamic_img_pad: bool = False,
289
+ drop_rate: float = 0.0,
290
+ pos_drop_rate: float = 0.0,
291
+ patch_drop_rate: float = 0.0,
292
+ proj_drop_rate: float = 0.0,
293
+ attn_drop_rate: float = 0.0,
294
+ drop_path_rate: float = 0.0,
295
+ weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "",
296
+ embed_layer: Callable = PatchEmbed,
297
+ norm_layer: Optional[LayerType] = None,
298
+ act_layer: Optional[LayerType] = None,
299
+ block_fn: Type[nn.Module] = Block,
300
+ mlp_layer: Type[nn.Module] = Mlp,
301
+ ignore_head: bool = False,
302
+ ) -> None:
303
+ """
304
+ Args:
305
+ img_size: Input image size.
306
+ patch_size: Patch size.
307
+ in_chans: Number of image input channels.
308
+ num_classes: Number of classes for classification head.
309
+ global_pool: Type of global pooling for final sequence (default: 'token').
310
+ embed_dim: Transformer embedding dimension.
311
+ depth: Depth of transformer.
312
+ num_heads: Number of attention heads.
313
+ mlp_ratio: Ratio of mlp hidden dim to embedding dim.
314
+ qkv_bias: Enable bias for qkv projections if True.
315
+ init_values: Layer-scale init values (layer-scale enabled if not None).
316
+ class_token: Use class token.
317
+ no_embed_class: Don't include position embeddings for class (or reg) tokens.
318
+ reg_tokens: Number of register tokens.
319
+ fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
320
+ drop_rate: Head dropout rate.
321
+ pos_drop_rate: Position embedding dropout rate.
322
+ attn_drop_rate: Attention dropout rate.
323
+ drop_path_rate: Stochastic depth rate.
324
+ weight_init: Weight initialization scheme.
325
+ embed_layer: Patch embedding layer.
326
+ norm_layer: Normalization layer.
327
+ act_layer: MLP activation layer.
328
+ block_fn: Transformer block layer.
329
+ """
330
+ super().__init__()
331
+ assert global_pool in ("", "avg", "token", "map")
332
+ assert class_token or global_pool != "token"
333
+ use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
334
+ # norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
335
+ # act_layer = get_act_layer(act_layer) or nn.GELU
336
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
337
+ act_layer = nn.GELU
338
+
339
+ self.num_classes = num_classes
340
+ self.global_pool = global_pool
341
+ self.num_features = self.embed_dim = (
342
+ embed_dim # num_features for consistency with other models
343
+ )
344
+ self.num_prefix_tokens = 1 if class_token else 0
345
+ self.num_prefix_tokens += reg_tokens
346
+ self.num_reg_tokens = reg_tokens
347
+ self.has_class_token = class_token
348
+ self.no_embed_class = (
349
+ no_embed_class # don't embed prefix positions (includes reg)
350
+ )
351
+ self.dynamic_img_size = dynamic_img_size
352
+ self.grad_checkpointing = False
353
+ self.ignore_head = ignore_head
354
+
355
+ embed_args = {}
356
+ if dynamic_img_size:
357
+ # flatten deferred until after pos embed
358
+ embed_args.update(dict(strict_img_size=False, output_fmt="NHWC"))
359
+ self.patch_embed = embed_layer(
360
+ img_size=img_size,
361
+ patch_size=patch_size,
362
+ in_chans=in_chans,
363
+ embed_dim=embed_dim,
364
+ bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
365
+ dynamic_img_pad=dynamic_img_pad,
366
+ **embed_args,
367
+ )
368
+ num_patches = self.patch_embed.num_patches
369
+
370
+ self.cls_token = (
371
+ nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
372
+ )
373
+ self.reg_token = (
374
+ nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
375
+ )
376
+ embed_len = (
377
+ num_patches if no_embed_class else num_patches + self.num_prefix_tokens
378
+ )
379
+ self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
380
+ self.pos_drop = nn.Dropout(p=pos_drop_rate)
381
+ if patch_drop_rate > 0:
382
+ self.patch_drop = PatchDropout(
383
+ patch_drop_rate,
384
+ num_prefix_tokens=self.num_prefix_tokens,
385
+ )
386
+ else:
387
+ self.patch_drop = nn.Identity()
388
+ self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
389
+
390
+ dpr = [
391
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
392
+ ] # stochastic depth decay rule
393
+ self.blocks = nn.Sequential(
394
+ *[
395
+ block_fn(
396
+ dim=embed_dim,
397
+ num_heads=num_heads,
398
+ mlp_ratio=mlp_ratio,
399
+ qkv_bias=qkv_bias,
400
+ qk_norm=qk_norm,
401
+ init_values=init_values,
402
+ proj_drop=proj_drop_rate,
403
+ attn_drop=attn_drop_rate,
404
+ drop_path=dpr[i],
405
+ norm_layer=norm_layer,
406
+ act_layer=act_layer,
407
+ mlp_layer=mlp_layer,
408
+ )
409
+ for i in range(depth)
410
+ ]
411
+ )
412
+ self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
413
+
414
+ # Classifier Head
415
+ if global_pool == "map":
416
+ AttentionPoolLatent.init_weights = init_weights
417
+ self.attn_pool = AttentionPoolLatent(
418
+ self.embed_dim,
419
+ num_heads=num_heads,
420
+ mlp_ratio=mlp_ratio,
421
+ norm_layer=norm_layer,
422
+ )
423
+ else:
424
+ self.attn_pool = None
425
+ self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
426
+ self.head_drop = nn.Dropout(drop_rate)
427
+ self.head = (
428
+ nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
429
+ )
430
+
431
+ if weight_init != "skip":
432
+ self.init_weights(weight_init)
433
+
434
+ def init_weights(self, mode: Literal["jax", "jax_nlhb", "moco", ""] = "") -> None:
435
+ assert mode in ("jax", "jax_nlhb", "moco", "")
436
+ # head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0
437
+ trunc_normal_(self.pos_embed, std=0.02)
438
+ if self.cls_token is not None:
439
+ nn.init.normal_(self.cls_token, std=1e-6)
440
+ named_apply(init_weights_vit_timm, self)
441
+
442
+ @torch.jit.ignore
443
+ def no_weight_decay(self) -> Set:
444
+ return {"pos_embed", "cls_token", "dist_token"}
445
+
446
+ @torch.jit.ignore
447
+ def group_matcher(self, coarse: bool = False) -> Dict:
448
+ return dict(
449
+ stem=r"^cls_token|pos_embed|patch_embed", # stem and embed
450
+ blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))],
451
+ )
452
+
453
+ @torch.jit.ignore
454
+ def set_grad_checkpointing(self, enable: bool = True) -> None:
455
+ self.grad_checkpointing = enable
456
+
457
+ @torch.jit.ignore
458
+ def get_classifier(self) -> nn.Module:
459
+ return self.head
460
+
461
+ def reset_classifier(self, num_classes: int, global_pool=None) -> None:
462
+ self.num_classes = num_classes
463
+ if global_pool is not None:
464
+ assert global_pool in ("", "avg", "token", "map")
465
+ if global_pool == "map" and self.attn_pool is None:
466
+ assert (
467
+ False
468
+ ), "Cannot currently add attention pooling in reset_classifier()."
469
+ elif global_pool != "map" and self.attn_pool is not None:
470
+ self.attn_pool = None # remove attention pooling
471
+ self.global_pool = global_pool
472
+ self.head = (
473
+ nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
474
+ )
475
+
476
+ def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
477
+ if self.dynamic_img_size:
478
+ B, H, W, C = x.shape
479
+ pos_embed = resample_abs_pos_embed(
480
+ self.pos_embed,
481
+ (H, W),
482
+ num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
483
+ )
484
+ x = x.view(B, -1, C)
485
+ else:
486
+ pos_embed = self.pos_embed
487
+
488
+ to_cat = []
489
+ if self.cls_token is not None:
490
+ to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
491
+ if self.reg_token is not None:
492
+ to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
493
+
494
+ if self.no_embed_class:
495
+ # deit-3, updated JAX (big vision)
496
+ # position embedding does not overlap with class token, add then concat
497
+ x = x + pos_embed
498
+ if to_cat:
499
+ x = torch.cat(to_cat + [x], dim=1)
500
+ else:
501
+ # original timm, JAX, and deit vit impl
502
+ # pos_embed has entry for class token, concat then add
503
+ if to_cat:
504
+ x = torch.cat(to_cat + [x], dim=1)
505
+ x = x + pos_embed
506
+
507
+ return self.pos_drop(x)
508
+
509
+ def _intermediate_layers(
510
+ self,
511
+ x: torch.Tensor,
512
+ n: Union[int, Sequence] = 1,
513
+ ) -> List[torch.Tensor]:
514
+ outputs, num_blocks = [], len(self.blocks)
515
+ take_indices = set(
516
+ range(num_blocks - n, num_blocks) if isinstance(n, int) else n
517
+ )
518
+
519
+ # forward pass
520
+ x = self.patch_embed(x)
521
+ x = self._pos_embed(x)
522
+ x = self.patch_drop(x)
523
+ x = self.norm_pre(x)
524
+ for i, blk in enumerate(self.blocks):
525
+ x = blk(x)
526
+ if i in take_indices:
527
+ outputs.append(x)
528
+
529
+ return outputs
530
+
531
+ def get_intermediate_layers(
532
+ self,
533
+ x: torch.Tensor,
534
+ n: Union[int, Sequence] = 1,
535
+ reshape: bool = False,
536
+ return_prefix_tokens: bool = False,
537
+ norm: bool = False,
538
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
539
+ """Intermediate layer accessor (NOTE: This is a WIP experiment).
540
+ Inspired by DINO / DINOv2 interface
541
+ """
542
+ # take last n blocks if n is an int; if n is a sequence, select by matching indices
543
+ outputs = self._intermediate_layers(x, n)
544
+ if norm:
545
+ outputs = [self.norm(out) for out in outputs]
546
+ prefix_tokens = [out[:, 0 : self.num_prefix_tokens] for out in outputs]
547
+ outputs = [out[:, self.num_prefix_tokens :] for out in outputs]
548
+
549
+ if reshape:
550
+ grid_size = self.patch_embed.grid_size
551
+ outputs = [
552
+ out.reshape(x.shape[0], grid_size[0], grid_size[1], -1)
553
+ .permute(0, 3, 1, 2)
554
+ .contiguous()
555
+ for out in outputs
556
+ ]
557
+
558
+ if return_prefix_tokens:
559
+ return tuple(zip(outputs, prefix_tokens))
560
+ return tuple(outputs)
561
+
562
+ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
563
+ x = self.patch_embed(x)
564
+ x = self._pos_embed(x)
565
+ x = self.patch_drop(x)
566
+ x = self.norm_pre(x)
567
+ if self.grad_checkpointing and not torch.jit.is_scripting():
568
+ x = checkpoint_seq(self.blocks, x)
569
+ else:
570
+ x = self.blocks(x)
571
+ x = self.norm(x)
572
+ return x
573
+
574
+ def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
575
+ if self.attn_pool is not None:
576
+ x = self.attn_pool(x)
577
+ elif self.global_pool == "avg":
578
+ x = x[:, self.num_prefix_tokens :].mean(dim=1)
579
+ elif self.global_pool:
580
+ x = x[:, 0] # class token
581
+ x = self.fc_norm(x)
582
+ x = self.head_drop(x)
583
+ return x if pre_logits else self.head(x)
584
+
585
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
586
+ x = self.forward_features(x)
587
+ if not self.ignore_head:
588
+ x = self.forward_head(x)
589
+ return x
590
+
591
+
592
+ @dataclass
593
+ class SigLIPVisionCfg:
594
+ width: int = 1152
595
+ layers: Union[Tuple[int, int, int, int], int] = 27
596
+ heads: int = 16
597
+ patch_size: int = 14
598
+ image_size: Union[Tuple[int, int], int] = 336
599
+ global_pool: str = "map"
600
+ mlp_ratio: float = 3.7362
601
+ class_token: bool = False
602
+ num_classes: int = 0
603
+ use_checkpoint: bool = False
604
+
605
+
606
+ SigLIP_MODEL_CONFIG = {
607
+ "siglip_so400m_patch14_384": {
608
+ "image_size": 336,
609
+ "patch_size": 14,
610
+ "width": 1152,
611
+ "layers": 27,
612
+ "heads": 16,
613
+ "mlp_ratio": 3.7362,
614
+ "global_pool": "map",
615
+ "use_checkpoint": False,
616
+ },
617
+ "siglip_so400m_patch14_224": {
618
+ "image_size": 224,
619
+ "patch_size": 14,
620
+ "width": 1152,
621
+ "layers": 27,
622
+ "heads": 16,
623
+ "mlp_ratio": 3.7362,
624
+ "global_pool": "map",
625
+ "use_checkpoint": False,
626
+ },
627
+ "siglip_large_patch16_384": {
628
+ "image_size": 384,
629
+ "patch_size": 16,
630
+ "width": 1024,
631
+ "layers": 24,
632
+ "heads": 16,
633
+ "mlp_ratio": 4,
634
+ "global_pool": "map",
635
+ "use_checkpoint": False,
636
+ },
637
+ "siglip_tiny_patch2_384": {
638
+ "image_size": 384,
639
+ "patch_size": 16,
640
+ "width": 32,
641
+ "layers": 2,
642
+ "heads": 1,
643
+ "mlp_ratio": 4,
644
+ "global_pool": "map",
645
+ "use_checkpoint": False,
646
+ }
647
+ }
648
+
649
+
650
+ def create_siglip_vit(
651
+ model_name: str = "siglip_so400m_patch14_384",
652
+ image_size: int = 384,
653
+ select_layer: int = -1,
654
+ ckpt_path: str = "",
655
+ **kwargs,
656
+ ):
657
+ assert (
658
+ model_name in SigLIP_MODEL_CONFIG.keys()
659
+ ), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}"
660
+
661
+ vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name])
662
+
663
+ if select_layer <= 0:
664
+ layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1)
665
+ else:
666
+ layers = min(vision_cfg.layers, select_layer)
667
+
668
+ model = VisionTransformer(
669
+ img_size=image_size,
670
+ patch_size=vision_cfg.patch_size,
671
+ embed_dim=vision_cfg.width,
672
+ depth=layers,
673
+ num_heads=vision_cfg.heads,
674
+ mlp_ratio=vision_cfg.mlp_ratio,
675
+ class_token=vision_cfg.class_token,
676
+ global_pool=vision_cfg.global_pool,
677
+ ignore_head=kwargs.get("ignore_head", True),
678
+ weight_init=kwargs.get("weight_init", "skip"),
679
+ num_classes=0,
680
+ )
681
+
682
+ if ckpt_path:
683
+ state_dict = torch.load(ckpt_path, map_location="cpu")
684
+
685
+ incompatible_keys = model.load_state_dict(state_dict, strict=False)
686
+ print(
687
+ f"SigLIP-ViT restored from {ckpt_path},\n"
688
+ f"\tincompatible_keys: {incompatible_keys}."
689
+ )
690
+
691
+ return model
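For a quick shape check of the factory above, a minimal sketch using the small test configuration it defines: patch_size 16 on a 384-pixel input gives a 24x24 grid, i.e. 576 patch tokens of width 32 (illustration only, run inside this module where torch and create_siglip_vit are in scope):

# Sketch: "siglip_tiny_patch2_384" entry from SigLIP_MODEL_CONFIG; ignore_head defaults to True,
# so forward() returns patch features rather than pooled logits.
vit = create_siglip_vit("siglip_tiny_patch2_384", image_size=384)
feats = vit(torch.rand(2, 3, 384, 384))
print(feats.shape)  # torch.Size([2, 576, 32]) -- (384 // 16) ** 2 patch tokens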
special_tokens_map.json ADDED
@@ -0,0 +1,34 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<image_placeholder>",
4
+ "<patch_placeholder>",
5
+ "<|ref|>",
6
+ "<|/ref|>",
7
+ "<|det|>",
8
+ "<|/det|>",
9
+ "<|grounding|>",
10
+ "<|User|>",
11
+ "<|Assistant|>"
12
+ ],
13
+ "bos_token": {
14
+ "content": "<|begin▁of▁sentence|>",
15
+ "lstrip": false,
16
+ "normalized": true,
17
+ "rstrip": false,
18
+ "single_word": false
19
+ },
20
+ "eos_token": {
21
+ "content": "<|end▁of▁sentence|>",
22
+ "lstrip": false,
23
+ "normalized": true,
24
+ "rstrip": false,
25
+ "single_word": false
26
+ },
27
+ "pad_token": {
28
+ "content": "<|▁pad▁|>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false
33
+ }
34
+ }
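These entries line up with the lookups in processing_vlm.py (image_id reads "<image_placeholder>", pad_id reads "<|▁pad▁|>"). A minimal check, assuming the tokenizer files uploaded alongside this map load through AutoTokenizer and using a placeholder repository path:

from transformers import AutoTokenizer

# Placeholder path/repo id; the printed ids are the ones VLChatProcessor.image_id / pad_id resolve to.
tok = AutoTokenizer.from_pretrained("path-or-repo-id")
print(tok.convert_tokens_to_ids("<image_placeholder>"))
print(tok.convert_tokens_to_ids("<|▁pad▁|>"))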
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
vq_model.py ADDED
@@ -0,0 +1,527 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+
21
+ from dataclasses import dataclass, field
22
+ from typing import List
23
+
24
+ import torch
25
+ import torch.nn as nn
26
+ import torch.nn.functional as F
27
+
28
+ from functools import partial
29
+
30
+
31
+ @dataclass
32
+ class ModelArgs:
33
+ codebook_size: int = 16384
34
+ codebook_embed_dim: int = 8
35
+ codebook_l2_norm: bool = True
36
+ codebook_show_usage: bool = True
37
+ commit_loss_beta: float = 0.25
38
+ entropy_loss_ratio: float = 0.0
39
+
40
+ encoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
41
+ decoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
42
+ z_channels: int = 256
43
+ dropout_p: float = 0.0
44
+
45
+
46
+ class Encoder(nn.Module):
47
+ def __init__(
48
+ self,
49
+ in_channels=3,
50
+ ch=128,
51
+ ch_mult=(1, 1, 2, 2, 4),
52
+ num_res_blocks=2,
53
+ norm_type="group",
54
+ dropout=0.0,
55
+ resamp_with_conv=True,
56
+ z_channels=256,
57
+ ):
58
+ super().__init__()
59
+ self.num_resolutions = len(ch_mult)
60
+ self.num_res_blocks = num_res_blocks
61
+ self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1)
62
+
63
+ # downsampling
64
+ in_ch_mult = (1,) + tuple(ch_mult)
65
+ self.conv_blocks = nn.ModuleList()
66
+ for i_level in range(self.num_resolutions):
67
+ conv_block = nn.Module()
68
+ # res & attn
69
+ res_block = nn.ModuleList()
70
+ attn_block = nn.ModuleList()
71
+ block_in = ch * in_ch_mult[i_level]
72
+ block_out = ch * ch_mult[i_level]
73
+ for _ in range(self.num_res_blocks):
74
+ res_block.append(
75
+ ResnetBlock(
76
+ block_in, block_out, dropout=dropout, norm_type=norm_type
77
+ )
78
+ )
79
+ block_in = block_out
80
+ if i_level == self.num_resolutions - 1:
81
+ attn_block.append(AttnBlock(block_in, norm_type))
82
+ conv_block.res = res_block
83
+ conv_block.attn = attn_block
84
+ # downsample
85
+ if i_level != self.num_resolutions - 1:
86
+ conv_block.downsample = Downsample(block_in, resamp_with_conv)
87
+ self.conv_blocks.append(conv_block)
88
+
89
+ # middle
90
+ self.mid = nn.ModuleList()
91
+ self.mid.append(
92
+ ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
93
+ )
94
+ self.mid.append(AttnBlock(block_in, norm_type=norm_type))
95
+ self.mid.append(
96
+ ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
97
+ )
98
+
99
+ # end
100
+ self.norm_out = Normalize(block_in, norm_type)
101
+ self.conv_out = nn.Conv2d(
102
+ block_in, z_channels, kernel_size=3, stride=1, padding=1
103
+ )
104
+
105
+ def forward(self, x):
106
+ h = self.conv_in(x)
107
+ # downsampling
108
+ for i_level, block in enumerate(self.conv_blocks):
109
+ for i_block in range(self.num_res_blocks):
110
+ h = block.res[i_block](h)
111
+ if len(block.attn) > 0:
112
+ h = block.attn[i_block](h)
113
+ if i_level != self.num_resolutions - 1:
114
+ h = block.downsample(h)
115
+
116
+ # middle
117
+ for mid_block in self.mid:
118
+ h = mid_block(h)
119
+
120
+ # end
121
+ h = self.norm_out(h)
122
+ h = nonlinearity(h)
123
+ h = self.conv_out(h)
124
+ return h
125
+
126
+
127
+ class Decoder(nn.Module):
128
+ def __init__(
129
+ self,
130
+ z_channels=256,
131
+ ch=128,
132
+ ch_mult=(1, 1, 2, 2, 4),
133
+ num_res_blocks=2,
134
+ norm_type="group",
135
+ dropout=0.0,
136
+ resamp_with_conv=True,
137
+ out_channels=3,
138
+ ):
139
+ super().__init__()
140
+ self.num_resolutions = len(ch_mult)
141
+ self.num_res_blocks = num_res_blocks
142
+
143
+ block_in = ch * ch_mult[self.num_resolutions - 1]
144
+ # z to block_in
145
+ self.conv_in = nn.Conv2d(
146
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
147
+ )
148
+
149
+ # middle
150
+ self.mid = nn.ModuleList()
151
+ self.mid.append(
152
+ ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
153
+ )
154
+ self.mid.append(AttnBlock(block_in, norm_type=norm_type))
155
+ self.mid.append(
156
+ ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
157
+ )
158
+
159
+ # upsampling
160
+ self.conv_blocks = nn.ModuleList()
161
+ for i_level in reversed(range(self.num_resolutions)):
162
+ conv_block = nn.Module()
163
+ # res & attn
164
+ res_block = nn.ModuleList()
165
+ attn_block = nn.ModuleList()
166
+ block_out = ch * ch_mult[i_level]
167
+ for _ in range(self.num_res_blocks + 1):
168
+ res_block.append(
169
+ ResnetBlock(
170
+ block_in, block_out, dropout=dropout, norm_type=norm_type
171
+ )
172
+ )
173
+ block_in = block_out
174
+ if i_level == self.num_resolutions - 1:
175
+ attn_block.append(AttnBlock(block_in, norm_type))
176
+ conv_block.res = res_block
177
+ conv_block.attn = attn_block
178
+ # downsample
179
+ if i_level != 0:
180
+ conv_block.upsample = Upsample(block_in, resamp_with_conv)
181
+ self.conv_blocks.append(conv_block)
182
+
183
+ # end
184
+ self.norm_out = Normalize(block_in, norm_type)
185
+ self.conv_out = nn.Conv2d(
186
+ block_in, out_channels, kernel_size=3, stride=1, padding=1
187
+ )
188
+
189
+ @property
190
+ def last_layer(self):
191
+ return self.conv_out.weight
192
+
193
+ def forward(self, z):
194
+ # z to block_in
195
+ h = self.conv_in(z)
196
+
197
+ # middle
198
+ for mid_block in self.mid:
199
+ h = mid_block(h)
200
+
201
+ # upsampling
202
+ for i_level, block in enumerate(self.conv_blocks):
203
+ for i_block in range(self.num_res_blocks + 1):
204
+ h = block.res[i_block](h)
205
+ if len(block.attn) > 0:
206
+ h = block.attn[i_block](h)
207
+ if i_level != self.num_resolutions - 1:
208
+ h = block.upsample(h)
209
+
210
+ # end
211
+ h = self.norm_out(h)
212
+ h = nonlinearity(h)
213
+ h = self.conv_out(h)
214
+ return h
215
+
216
+
217
+ class VectorQuantizer(nn.Module):
218
+ def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage):
219
+ super().__init__()
220
+ self.n_e = n_e
221
+ self.e_dim = e_dim
222
+ self.beta = beta
223
+ self.entropy_loss_ratio = entropy_loss_ratio
224
+ self.l2_norm = l2_norm
225
+ self.show_usage = show_usage
226
+
227
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
228
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
229
+ if self.l2_norm:
230
+ self.embedding.weight.data = F.normalize(
231
+ self.embedding.weight.data, p=2, dim=-1
232
+ )
233
+ if self.show_usage:
234
+ self.register_buffer("codebook_used", nn.Parameter(torch.zeros(65536)))
235
+
236
+ def forward(self, z):
237
+ # reshape z -> (batch, height, width, channel) and flatten
238
+ z = torch.einsum("b c h w -> b h w c", z).contiguous()
239
+ z_flattened = z.view(-1, self.e_dim)
240
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
241
+
242
+ if self.l2_norm:
243
+ z = F.normalize(z, p=2, dim=-1)
244
+ z_flattened = F.normalize(z_flattened, p=2, dim=-1)
245
+ embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
246
+ else:
247
+ embedding = self.embedding.weight
248
+
249
+ d = (
250
+ torch.sum(z_flattened**2, dim=1, keepdim=True)
251
+ + torch.sum(embedding**2, dim=1)
252
+ - 2
253
+ * torch.einsum(
254
+ "bd,dn->bn", z_flattened, torch.einsum("n d -> d n", embedding)
255
+ )
256
+ )
257
+
258
+ min_encoding_indices = torch.argmin(d, dim=1)
259
+ z_q = embedding[min_encoding_indices].view(z.shape)
260
+ perplexity = None
261
+ min_encodings = None
262
+ vq_loss = None
263
+ commit_loss = None
264
+ entropy_loss = None
265
+
266
+ # compute loss for embedding
267
+ if self.training:
268
+ vq_loss = torch.mean((z_q - z.detach()) ** 2)
269
+ commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2)
270
+ entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d)
271
+
272
+ # preserve gradients
273
+ z_q = z + (z_q - z).detach()
274
+
275
+ # reshape back to match original input shape
276
+ z_q = torch.einsum("b h w c -> b c h w", z_q)
277
+
278
+ return (
279
+ z_q,
280
+ (vq_loss, commit_loss, entropy_loss),
281
+ (perplexity, min_encodings, min_encoding_indices),
282
+ )
283
+
284
+ def get_codebook_entry(self, indices, shape=None, channel_first=True):
285
+ # shape = (batch, channel, height, width) if channel_first else (batch, height, width, channel)
286
+ if self.l2_norm:
287
+ embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
288
+ else:
289
+ embedding = self.embedding.weight
290
+ z_q = embedding[indices] # (b*h*w, c)
291
+
292
+ if shape is not None:
293
+ if channel_first:
294
+ z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1])
295
+ # reshape back to match original input shape
296
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
297
+ else:
298
+ z_q = z_q.view(shape)
299
+ return z_q
300
+
301
+
302
+ class ResnetBlock(nn.Module):
303
+ def __init__(
304
+ self,
305
+ in_channels,
306
+ out_channels=None,
307
+ conv_shortcut=False,
308
+ dropout=0.0,
309
+ norm_type="group",
310
+ ):
311
+ super().__init__()
312
+ self.in_channels = in_channels
313
+ out_channels = in_channels if out_channels is None else out_channels
314
+ self.out_channels = out_channels
315
+ self.use_conv_shortcut = conv_shortcut
316
+
317
+ self.norm1 = Normalize(in_channels, norm_type)
318
+ self.conv1 = nn.Conv2d(
319
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
320
+ )
321
+ self.norm2 = Normalize(out_channels, norm_type)
322
+ self.dropout = nn.Dropout(dropout)
323
+ self.conv2 = nn.Conv2d(
324
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
325
+ )
326
+
327
+ if self.in_channels != self.out_channels:
328
+ if self.use_conv_shortcut:
329
+ self.conv_shortcut = nn.Conv2d(
330
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
331
+ )
332
+ else:
333
+ self.nin_shortcut = nn.Conv2d(
334
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
335
+ )
336
+
337
+ def forward(self, x):
338
+ h = x
339
+ h = self.norm1(h)
340
+ h = nonlinearity(h)
341
+ h = self.conv1(h)
342
+ h = self.norm2(h)
343
+ h = nonlinearity(h)
344
+ h = self.dropout(h)
345
+ h = self.conv2(h)
346
+
347
+ if self.in_channels != self.out_channels:
348
+ if self.use_conv_shortcut:
349
+ x = self.conv_shortcut(x)
350
+ else:
351
+ x = self.nin_shortcut(x)
352
+ return x + h
353
+
354
+
355
+ class AttnBlock(nn.Module):
356
+ def __init__(self, in_channels, norm_type="group"):
357
+ super().__init__()
358
+ self.norm = Normalize(in_channels, norm_type)
359
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
360
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
361
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
362
+ self.proj_out = nn.Conv2d(
363
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
364
+ )
365
+
366
+ def forward(self, x):
367
+ h_ = x
368
+ h_ = self.norm(h_)
369
+ q = self.q(h_)
370
+ k = self.k(h_)
371
+ v = self.v(h_)
372
+
373
+ # compute attention
374
+ b, c, h, w = q.shape
375
+ q = q.reshape(b, c, h * w)
376
+ q = q.permute(0, 2, 1) # b,hw,c
377
+ k = k.reshape(b, c, h * w) # b,c,hw
378
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
379
+ w_ = w_ * (int(c) ** (-0.5))
380
+ w_ = F.softmax(w_, dim=2)
381
+
382
+ # attend to values
383
+ v = v.reshape(b, c, h * w)
384
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
385
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
386
+ h_ = h_.reshape(b, c, h, w)
387
+
388
+ h_ = self.proj_out(h_)
389
+
390
+ return x + h_
391
+
392
+
393
+ def nonlinearity(x):
394
+ # swish
395
+ return x * torch.sigmoid(x)
396
+
397
+
398
+ def Normalize(in_channels, norm_type="group"):
399
+ assert norm_type in ["group", "batch"]
400
+ if norm_type == "group":
401
+ return nn.GroupNorm(
402
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
403
+ )
404
+ elif norm_type == "batch":
405
+ return nn.SyncBatchNorm(in_channels)
406
+
407
+
408
+ class Upsample(nn.Module):
409
+ def __init__(self, in_channels, with_conv):
410
+ super().__init__()
411
+ self.with_conv = with_conv
412
+ if self.with_conv:
413
+ self.conv = nn.Conv2d(
414
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
415
+ )
416
+
417
+ def forward(self, x):
418
+ if x.dtype != torch.float32:
419
+ x = F.interpolate(x.to(torch.float), scale_factor=2.0, mode="nearest").to(
420
+ torch.bfloat16
421
+ )
422
+ else:
423
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
424
+
425
+ if self.with_conv:
426
+ x = self.conv(x)
427
+ return x
428
+
429
+
430
+ class Downsample(nn.Module):
431
+ def __init__(self, in_channels, with_conv):
432
+ super().__init__()
433
+ self.with_conv = with_conv
434
+ if self.with_conv:
435
+ # no asymmetric padding in torch conv, must do it ourselves
436
+ self.conv = nn.Conv2d(
437
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
438
+ )
439
+
440
+ def forward(self, x):
441
+ if self.with_conv:
442
+ pad = (0, 1, 0, 1)
443
+ x = F.pad(x, pad, mode="constant", value=0)
444
+ x = self.conv(x)
445
+ else:
446
+ x = F.avg_pool2d(x, kernel_size=2, stride=2)
447
+ return x
448
+
449
+
450
+ def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
451
+ flat_affinity = affinity.reshape(-1, affinity.shape[-1])
452
+ flat_affinity /= temperature
453
+ probs = F.softmax(flat_affinity, dim=-1)
454
+ log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1)
455
+ if loss_type == "softmax":
456
+ target_probs = probs
457
+ else:
458
+ raise ValueError("Entropy loss {} not supported".format(loss_type))
459
+ avg_probs = torch.mean(target_probs, dim=0)
460
+ avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-5))
461
+ sample_entropy = -torch.mean(torch.sum(target_probs * log_probs, dim=-1))
462
+ loss = sample_entropy - avg_entropy
463
+ return loss
464
+
465
+
466
+ class VQModel(nn.Module):
467
+ def __init__(self, config: ModelArgs):
468
+ super().__init__()
469
+ self.config = config
470
+ self.encoder = Encoder(
471
+ ch_mult=config.encoder_ch_mult,
472
+ z_channels=config.z_channels,
473
+ dropout=config.dropout_p,
474
+ )
475
+ self.decoder = Decoder(
476
+ ch_mult=config.decoder_ch_mult,
477
+ z_channels=config.z_channels,
478
+ dropout=config.dropout_p,
479
+ )
480
+
481
+ self.quantize = VectorQuantizer(
482
+ config.codebook_size,
483
+ config.codebook_embed_dim,
484
+ config.commit_loss_beta,
485
+ config.entropy_loss_ratio,
486
+ config.codebook_l2_norm,
487
+ config.codebook_show_usage,
488
+ )
489
+ self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)
490
+ self.post_quant_conv = nn.Conv2d(
491
+ config.codebook_embed_dim, config.z_channels, 1
492
+ )
493
+
494
+ def encode(self, x):
495
+ h = self.encoder(x)
496
+ h = self.quant_conv(h)
497
+ quant, emb_loss, info = self.quantize(h)
498
+ return quant, emb_loss, info
499
+
500
+ def decode(self, quant):
501
+ quant = self.post_quant_conv(quant)
502
+ dec = self.decoder(quant)
503
+ return dec
504
+
505
+ def decode_code(self, code_b, shape=None, channel_first=True):
506
+ quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first)
507
+ dec = self.decode(quant_b)
508
+ return dec
509
+
510
+ def forward(self, input):
511
+ quant, diff, _ = self.encode(input)
512
+ dec = self.decode(quant)
513
+ return dec, diff
514
+
515
+
516
+ #################################################################################
517
+ # VQ Model Configs #
518
+ #################################################################################
519
+ def VQ_16(**kwargs):
520
+ return VQModel(
521
+ ModelArgs(
522
+ encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs
523
+ )
524
+ )
525
+
526
+
527
+ VQ_models = {"VQ-16": VQ_16}
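For orientation, a minimal round trip through the registry above. With the default ModelArgs, the five ch_mult stages give four 2x downsamples, so a 256x256 input maps to a 16x16 grid of 8-dimensional codes (illustration only, run inside this module where torch and VQ_models are in scope):

# Sketch: VQ-16 = 16x spatial compression with a codebook of 16384 entries of dim 8.
model = VQ_models["VQ-16"]().eval()
img = torch.rand(1, 3, 256, 256)
with torch.no_grad():
    quant, emb_loss, info = model.encode(img)  # quant: [1, 8, 16, 16]; loss terms are None in eval mode
    recon = model.decode(quant)                # recon: [1, 3, 256, 256]
print(quant.shape, recon.shape)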