smellslikeml committed · Commit 98aeb8d · Parent(s): 8202fdb

initial commit
config.json CHANGED
@@ -3,6 +3,10 @@
   "architectures": [
     "PrismaticForConditionalGeneration"
   ],
+  "auto_map": {
+    "AutoConfig": "configuration_prismatic.PrismaticConfig",
+    "AutoModelForVision2Seq": "modeling_prismatic.PrismaticForConditionalGeneration"
+  },
   "hf_llm_id": "meta-llama/Meta-Llama-3.1-8B",
   "image_resize_strategy": "letterbox",
   "image_sizes": [
configuration_prismatic.py ADDED
@@ -0,0 +1,156 @@
"""
configuration_prismatic.py

HuggingFace-style configuration definition for Prismatic VLMs, inheriting from `transformers.PretrainedConfig`.
Default configuration specifies `siglip-224px+7b`.
"""

from typing import Any, Dict, List, Optional

from transformers import PretrainedConfig
from transformers.models.auto import CONFIG_MAPPING

# === Utilities for Mapping Prismatic names to HF names ===
# fmt: off
VISION_BACKBONE_TO_RESOLUTION: Dict[str, List[int]] = {
    "clip-vit-l": [224], "siglip-vit-so400m": [224], "dinov2-vit-l": [224], "in1k-vit-l": [224],

    "clip-vit-l-336px": [336],
    "siglip-vit-so400m-384px": [384],

    "dinoclip-vit-l-336px": [336, 336],
    "dinosiglip-vit-so-224px": [224, 224],
    "dinosiglip-vit-so-384px": [384, 384],
}
VISION_BACKBONE_TO_TIMM_ID: Dict[str, List[str]] = {
    "clip-vit-l": ["vit_large_patch14_clip_224.openai"],
    "clip-vit-l-336px": ["vit_large_patch14_clip_336.openai"],

    "dinov2-vit-l": ["vit_large_patch14_reg4_dinov2.lvd142m"],
    "in1k-vit-l": ["vit_large_patch16_224.augreg_in21k_ft_in1k"],

    "siglip-vit-so400m": ["vit_so400m_patch14_siglip_224"],
    "siglip-vit-so400m-384px": ["vit_so400m_patch14_siglip_384"],

    "dinoclip-vit-l-336px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_large_patch14_clip_336.openai"],
    "dinosiglip-vit-so-224px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_224"],
    "dinosiglip-vit-so-384px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_384"],
}
TIMM_OVERRIDE_ACT_LAYER: Dict[str, List[Optional[str]]] = {
    "clip-vit-l": ["quick_gelu"], "clip-vit-l-336px": ["quick_gelu"],
    "dinov2-vit-l": [None], "in1k-vit-l": [None],
    "siglip-vit-so400m": [None], "siglip-vit-so400m-384px": [None],
    "dinoclip-vit-l-336px": [None, "quick_gelu"],
    "dinosiglip-vit-so-224px": [None, None], "dinosiglip-vit-so-384px": [None, None]
}

LLM_BACKBONE_TO_HF_PATH = {
    "llama2-7b-pure": "meta-llama/Llama-2-7b-hf", "llama2-13b-pure": "meta-llama/Llama-2-13b-hf",
    "llama2-7b-chat": "meta-llama/Llama-2-7b-chat-hf", "llama2-13b-chat": "meta-llama/Llama-2-13b-chat-hf",
    "llama3-1-8b-pure": "meta-llama/Meta-Llama-3.1-8B",

    "vicuna-v15-7b": "lmsys/vicuna-7b-v1.5", "vicuna-v15-13b": "lmsys/vicuna-13b-v1.5",

    "mistral-v0.1-7b-pure": "mistralai/Mistral-7B-v0.1",
    "mistral-v0.1-7b-instruct": "mistralai/Mistral-7B-Instruct-v0.1",

    "phi-2-3b": "microsoft/phi-2",
}
LLM_BACKBONE_TO_HF_METACLASS = {
    "llama2-7b-pure": "llama", "llama2-13b-pure": "llama", "llama2-7b-chat": "llama", "llama2-13b-chat": "llama",
    "vicuna-v15-7b": "llama", "vicuna-v15-13b": "llama", "llama3-1-8b-pure": "llama",

    "mistral-v0.1-7b-pure": "mistral", "mistral-v0.1-7b-instruct": "mistral",

    "phi-2-3b": "phi",
}

VALID_VISION_BACKBONES = set(VISION_BACKBONE_TO_RESOLUTION.keys())
VALID_LLM_BACKBONES = set(LLM_BACKBONE_TO_HF_PATH)
# fmt: on

class PrismaticConfig(PretrainedConfig):
    model_type: str = "prismatic"
    is_composition: bool = False

    def __init__(
        self,
        vision_backbone_id: str = "siglip-vit-so400m",
        llm_backbone_id: str = "vicuna-v15-7b",
        arch_specifier: str = "no-align+gelu-mlp",
        use_fused_vision_backbone: Optional[bool] = None,
        image_resize_strategy: str = "letterbox",
        text_config: Optional[Dict[str, Any]] = None,
        llm_max_length: int = 2048,
        pad_token_id: int = 32000,
        pad_to_multiple_of: int = 64,
        output_projector_states: bool = False,
        vocab_size: int = 32001,  # Ensure vocab_size is passed and set
        **kwargs: str,
    ) -> None:
        if vision_backbone_id not in VALID_VISION_BACKBONES:
            raise ValueError(f"Vision backbone `{vision_backbone_id}` not in {VALID_VISION_BACKBONES = }")

        if llm_backbone_id not in VALID_LLM_BACKBONES:
            raise ValueError(f"LLM backbone `{llm_backbone_id}` not in {VALID_LLM_BACKBONES = }")

        # Set Prismatic Configuration Fields
        self.vision_backbone_id = vision_backbone_id
        self.llm_backbone_id = llm_backbone_id
        self.arch_specifier = arch_specifier
        self.output_projector_states = output_projector_states
        self.vocab_size = vocab_size

        # [Contract] All vision backbone parameters are lists =>> supports fused backbones with different preprocessing
        self.use_fused_vision_backbone = (
            use_fused_vision_backbone
            if use_fused_vision_backbone is not None
            else any(self.vision_backbone_id.startswith(v) for v in ["dinoclip", "dinosiglip"])
        )

        self.timm_model_ids = VISION_BACKBONE_TO_TIMM_ID[self.vision_backbone_id]
        self.timm_override_act_layers = TIMM_OVERRIDE_ACT_LAYER[self.vision_backbone_id]
        self.image_sizes = VISION_BACKBONE_TO_RESOLUTION[self.vision_backbone_id]
        self.image_resize_strategy = image_resize_strategy

        self.hf_llm_id = LLM_BACKBONE_TO_HF_PATH[self.llm_backbone_id]
        self.llm_max_length = llm_max_length
        self.pad_token_id, self.pad_to_multiple_of = pad_token_id, pad_to_multiple_of

        # Set padding_idx if not already set
        if not hasattr(self, 'padding_idx'):
            # self.padding_idx = pad_token_id
            self.padding_idx = 0

        # [IMPORTANT] HF Utilities actually look for a `text_config` field... we need to use that specific naming!
        self.text_config = (
            CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]](**text_config)
            if text_config is not None
            else CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]]()
        )

        # Dispatch **kwargs to super() =>> note that `pad_token_id` collides, so we pass it in here as well...
        super().__init__(pad_token_id=pad_token_id, vocab_size=vocab_size, **kwargs)


class OpenVLAConfig(PrismaticConfig):
    model_type: str = "openvla"

    def __init__(
        self,
        norm_stats: Optional[Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]]] = None,
        n_action_bins: int = 256,
        vocab_size: int = 32001,  # Default vocab size, adjust if necessary
        **kwargs: str,
    ) -> None:
        self.norm_stats = norm_stats
        self.n_action_bins = n_action_bins
        self.vocab_size = vocab_size

        super().__init__(**kwargs)

        # Ensure padding_idx is within the valid range
        if not hasattr(self, 'padding_idx') or self.padding_idx >= self.vocab_size:
            print(f"Padding index {self.padding_idx} is out of range. Adjusting to {self.vocab_size - 1}.")
            self.padding_idx = self.vocab_size - 1
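The mapping tables above drive everything else in the config; below is a minimal sketch (not part of the commit) of how the backbone ids resolve into derived fields, assuming `configuration_prismatic.py` is importable from the working directory. The backbone ids shown are illustrative choices, not necessarily this checkpoint's.

# Sketch: instantiate the config and inspect derived fields (illustrative ids).
from configuration_prismatic import OpenVLAConfig

cfg = OpenVLAConfig(vision_backbone_id="dinosiglip-vit-so-384px", llm_backbone_id="llama3-1-8b-pure")

print(cfg.use_fused_vision_backbone)   # True -- inferred from the "dinosiglip" prefix
print(cfg.timm_model_ids)              # ['vit_large_patch14_reg4_dinov2.lvd142m', 'vit_so400m_patch14_siglip_384']
print(cfg.image_sizes)                 # [384, 384]
print(cfg.hf_llm_id)                   # meta-llama/Meta-Llama-3.1-8B
print(cfg.text_config.model_type)      # llama -- built via CONFIG_MAPPING when no text_config is supplied
print(cfg.padding_idx, cfg.vocab_size) # 0 32001

Because `text_config` is re-created through `CONFIG_MAPPING` when not provided, serialized checkpoints carry the full LLM sub-config under that exact key, which is what the HF utilities referenced in the comment expect.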
modeling_prismatic.py ADDED
@@ -0,0 +1,570 @@
"""
modeling_prismatic.py

Core HuggingFace-style PrismaticPreTrainedModel and PrismaticForConditionalGeneration class definitions, inheriting
from the default `transformers.PretrainedModel`. Meant to be standalone and self-contained, but exactly replicate the
logic in `prismatic.models.vlms.prismatic.py`.

Note =>> for the time being, not adding the custom HF "docstring" formatting.

References [LLaVa, IDEFICS-2]:
    => https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py
    => https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics2/modeling_idefics2.py
"""

import logging
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union

import numpy as np
import timm
import tokenizers
import torch
import torch.nn as nn
import transformers
from timm.models.vision_transformer import LayerScale
from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import ModelOutput

from .configuration_prismatic import OpenVLAConfig, PrismaticConfig

# Get Logger
logger = logging.getLogger(__name__)


# === PyTorch/HuggingFace Default IGNORE_INDEX (for CrossEntropyLoss labels)
IGNORE_INDEX = -100


# === Utility Functions for Monkey-Patching ===
def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        result = fn(*args, **kwargs)
        return result[0] if isinstance(result, tuple) else result

    return wrapper


# HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale.
#   =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109
#   =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960
def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
    return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor


def ls_apply_patch(ls_module: LayerScale):
    ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
    ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
    del ls_module.gamma


# === Prismatic Vision Backbone (nn.Module) Definitions (w/ Fused Backbone Support) ===
class PrismaticVisionBackbone(nn.Module):
    def __init__(
        self,
        use_fused_vision_backbone: bool,
        image_sizes: List[int],
        timm_model_ids: List[str],
        timm_override_act_layers: List[Optional[str]],
    ) -> None:
        super().__init__()
        self.use_fused_vision_backbone = use_fused_vision_backbone

        # [Contract] Validate number of (fused) vision backbones, create "alpha" featurizer and Instantiate
        #   =>> Note :: Monkey-Patch the `forward()` function of the backbone to ensure FSDP-compatibility
        #               Hardcodes `get_intermediate_layers` to return the **SECOND-TO-LAST** layer patches!
        assert len(timm_model_ids) <= 2, "Prismatic models only support up to 2 (fused) vision backbones!"
        self.featurizer = timm.create_model(
            timm_model_ids[0],
            pretrained=False,
            num_classes=0,
            img_size=image_sizes[0],
            act_layer=timm_override_act_layers[0],
        )
        self.featurizer.forward = unpack_tuple(
            partial(self.featurizer.get_intermediate_layers, n={len(self.featurizer.blocks) - 2})
        )
        self.embed_dim = self.featurizer.embed_dim

        # If `use_fused_vision_backbone` =>> create "beta" featurizer
        if self.use_fused_vision_backbone:
            self.fused_featurizer = timm.create_model(
                timm_model_ids[1],
                pretrained=False,
                num_classes=0,
                img_size=image_sizes[1],
                act_layer=timm_override_act_layers[1],
            )
            self.fused_featurizer.forward = unpack_tuple(
                partial(self.fused_featurizer.get_intermediate_layers, n={len(self.fused_featurizer.blocks) - 2})
            )
            self.embed_dim += self.fused_featurizer.embed_dim

        # Patch `vision_backbone.featurizer` and `vision_backbone.fused_featurizer` with HF-Compatible LayerScale
        for module in self.featurizer.modules():
            if isinstance(module, LayerScale):
                ls_apply_patch(module)

        if self.use_fused_vision_backbone:
            for module in self.fused_featurizer.modules():
                if isinstance(module, LayerScale):
                    ls_apply_patch(module)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run image (`pixel_values`) through featurizer; if channel-stacked, then dispatch and sequence stack."""
        if not self.use_fused_vision_backbone:
            return self.featurizer(pixel_values)

        # Split `pixel_values :: [bsz, 2 * 3, resolution, resolution]` =>> featurize =>> channel stack
        img, img_fused = torch.split(pixel_values, [3, 3], dim=1)
        patches, patches_fused = self.featurizer(img), self.fused_featurizer(img_fused)

        return torch.cat([patches, patches_fused], dim=2)


# === Prismatic Projector (nn.Module) Definitions ===
class PrismaticProjector(nn.Module):
    def __init__(self, use_fused_vision_backbone: bool, vision_dim: int, llm_dim: int) -> None:
        super().__init__()
        self.use_fused_vision_backbone = use_fused_vision_backbone
        self.vision_dim, self.llm_dim = vision_dim, llm_dim

        # Switch on `use_fused_vision_backbone` =>> use slightly different MLPs and projection factors!
        if not self.use_fused_vision_backbone:
            self.fc1 = nn.Linear(self.vision_dim, self.llm_dim, bias=True)
            self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
            self.act_fn1 = nn.GELU()
        else:
            initial_projection_dim = 4 * vision_dim
            self.fc1 = nn.Linear(self.vision_dim, initial_projection_dim, bias=True)
            self.fc2 = nn.Linear(initial_projection_dim, self.llm_dim, bias=True)
            self.fc3 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
            self.act_fn1 = nn.GELU()
            self.act_fn2 = nn.GELU()

    def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
        if not self.use_fused_vision_backbone:
            projected_features = self.fc1(img_patches)
            projected_features = self.act_fn1(projected_features)
            projected_features = self.fc2(projected_features)
        else:
            projected_features = self.fc1(img_patches)
            projected_features = self.act_fn1(projected_features)
            projected_features = self.fc2(projected_features)
            projected_features = self.act_fn2(projected_features)
            projected_features = self.fc3(projected_features)

        return projected_features


# === Main HF Class Definitions ===
@dataclass
class PrismaticCausalLMOutputWithPast(ModelOutput):
    """Base class for Prismatic causal (visually-conditioned) language model outputs; also exposes visual features."""

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

    # Additions for VLMs
    projector_features: Optional[torch.FloatTensor] = None


class PrismaticPreTrainedModel(PreTrainedModel):
    config_class: PretrainedConfig = PrismaticConfig
    base_model_prefix: str = "model"
    supports_gradient_checkpointing: bool = True

    _no_split_modules: ClassVar[List[str]] = ["PrismaticProjector"]
    _skip_keys_device_placement: str = "past_key_values"
    _supports_flash_attn_2: bool = True

    def _init_weights(self, module: nn.Module) -> None:
        # Important :: this HF ported version is *not* meant for training from scratch; only inference and fine-tuning!
        #   => As such, this init_weights code is not correct; if training VLMs from scratch, use the main codebase at
        #      https://github.com/TRI-ML/prismatic-vlms
        std = (
            self.config.initializer_range
            if hasattr(self.config, "initializer_range")
            else self.config.text_config.initializer_range
        )

        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=std)

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @property
    def _supports_sdpa(self) -> bool:
        """Check LLM supports SDPA Attention"""
        return self.language_model._supports_sdpa


class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
    def __init__(self, config: PrismaticConfig) -> None:
        super().__init__(config)

        # [Validation] Lightweight Validate on `config` Fields + Dependency Versions
        if config.use_fused_vision_backbone is None:
            raise ValueError("Missing config field `use_fused_vision_backbone`")

        if timm.__version__ not in {"0.9.10", "0.9.11", "0.9.12", "0.9.16"}:
            raise NotImplementedError(
                "TIMM Version must be >= 0.9.10 and < 1.0.0 (breaking); please raise a GitHub Issue "
                "if you urgently need support for latest TIMM versions."
            )

        if (transformers.__version__ != "4.40.1") or (tokenizers.__version__ != "0.19.1"):
            logger.warning(
                f"Expected `transformers==4.40.1` and `tokenizers==0.19.1` but got "
                f"`transformers=={transformers.__version__}` and `tokenizers=={tokenizers.__version__}`; "
                f"there might be inference-time regressions due to dependency changes. If in doubt, please "
                f"use the above versions."
            )

        # Instantiate PrismaticVisionBackbone (w/ Potential Fused Backbone)
        self.vision_backbone = PrismaticVisionBackbone(
            config.use_fused_vision_backbone, config.image_sizes, config.timm_model_ids, config.timm_override_act_layers
        )

        # Create Multimodal Projector
        self.projector = PrismaticProjector(
            config.use_fused_vision_backbone,
            vision_dim=self.vision_backbone.embed_dim,
            llm_dim=config.text_config.hidden_size,
        )

        print("CONFIG: ", config)
        print("CONFIG text: ", config.text_config)
        print("CONFIG attn implementation: ", config._attn_implementation)
        # Instantiate LLM Backbone
        self.language_model = AutoModelForCausalLM.from_config(
            config.text_config, attn_implementation=config._attn_implementation
        )
        print("loaded language model: ", self.language_model)
        # self.language_model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
        self.vocab_size = config.text_config.vocab_size
        self.pad_token_id = config.pad_token_id

        # HF Boilerplate =>> initializes weights via `_init_weights()` and sets gradient checkpointing
        self.post_init()

    # === `PreTrainedModel` Boilerplate ===
    def get_input_embeddings(self) -> nn.Module:
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value: nn.Module) -> None:
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self) -> nn.Module:
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
        self.language_model.set_output_embeddings(new_embeddings)

    def get_decoder(self) -> nn.Module:
        return self.language_model.get_decoder()

    def set_decoder(self, decoder: nn.Module) -> None:
        self.language_model.set_decoder(decoder)

    def tie_weights(self) -> None:
        self.language_model.tie_weights()  # Note: `Llama-2` and `Mistral` don't tie weights (no-op)

    def resize_token_embeddings(
        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
    ) -> nn.Embedding:
        updated_embeddings = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)

        # Update config/instance variables
        self.config.text_config.vocab_size = updated_embeddings.num_embeddings
        self.vocab_size = updated_embeddings.num_embeddings

        return updated_embeddings

    # === Core Prismatic VLM `forward()` Logic ===
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_projector_features: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, PrismaticCausalLMOutputWithPast]:
        """Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_projector_features = output_projector_features if output_projector_features is not None else False
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Respect `use_cache` only if not training (even if `gradient_checkpointing` is off)
        use_cache = use_cache and not self.training

        # Instantiate Placeholder for Projector Features
        projected_patch_embeddings = None

        # Note :: We only support forward passes with the following cases:
        #   => Cached Generation :: (input_ids.shape[1] == 1) and (past_key_values is not None)
        #   => Unimodal Forward :: (pixel_values is None)
        #   => Multimodal Forward :: (pixel_values is not None) and (input_ids/embeds.shape[0] == pixel_values.shape[0])

        # === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` ===
        if input_ids.shape[1] == 1:
            assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
            assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
            assert labels is None, "Unexpected key `labels` provided during cached generation!"

            language_model_output = self.language_model(
                input_ids=input_ids,
                attention_mask=None,
                position_ids=None,
                past_key_values=past_key_values,
                inputs_embeds=None,
                labels=None,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # === Handle Unimodal Forward ===
        elif pixel_values is None:
            assert (input_ids is not None) and (inputs_embeds is None), "Missing `input_ids` in language-only forward!"
            assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"

            language_model_output = self.language_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=None,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # === Handle Multimodal Forward ===
        elif (input_ids.shape[0] == pixel_values.shape[0]) or (inputs_embeds.shape[0] == pixel_values.shape[0]):
            assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"

            # Visual Feature Extraction
            patch_features = self.vision_backbone(pixel_values)

            # Projection Logic =>> Update Attention Mask
            projected_patch_embeddings = self.projector(patch_features)
            projected_patch_attention_mask = None
            if attention_mask is not None:
                projected_patch_attention_mask = torch.full(
                    (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
                    fill_value=True,
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )

            # Get Input Embeddings (from Language Model Embeddings)
            input_embeddings = self.get_input_embeddings()(input_ids)

            # Build Multimodal Embeddings & Attention Mask =>> Prismatic defaults to inserting after <BOS> token (1:)
            multimodal_embeddings = torch.cat(
                [input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1
            )
            multimodal_attention_mask = None
            if attention_mask is not None:
                multimodal_attention_mask = torch.cat(
                    [attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1
                )

            # Build Labels (if specified) =>> Ignore Labels for Patch Embeddings
            multimodal_labels = None
            if labels is not None:
                projected_patch_labels = torch.full(
                    (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
                    fill_value=IGNORE_INDEX,
                    dtype=labels.dtype,
                    device=labels.device,
                )
                multimodal_labels = torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1)

            # Dispatch to Language Model
            language_model_output = self.language_model(
                input_ids=None,
                attention_mask=multimodal_attention_mask,
                position_ids=None,
                past_key_values=None,
                inputs_embeds=multimodal_embeddings,
                labels=multimodal_labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # === Otherwise =>> Assume Invalid! ===
        elif (input_ids.shape[0] != pixel_values.shape[0]) or (inputs_embeds.shape[0] != pixel_values.shape[0]):
            raise ValueError("Non-homogenous batch of (text, image) input -- forward() does not support mixed batches!")

        else:
            raise ValueError(
                "Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\n"
                f"=> `input_ids` = {input_ids is not None}\n"
                f"=> `attention_mask` = {attention_mask is not None}\n"
                f"=> `pixel_values` = {pixel_values is not None}\n"
                f"=> `labels` = {labels is not None}\n"
                f"=> `input_embeds` = {inputs_embeds is not None}\n"
                f"=> `past_key_values` = {past_key_values is not None}\n"
                f"=> `use_cache` = {use_cache}"
            )

        # Unpack `language_model_output` and return PrismaticCausalLMOutputWithPast (or tuple if not `return_dict`)
        if not return_dict:
            if output_projector_features and (projected_patch_embeddings is not None):
                return *language_model_output, projected_patch_embeddings

            return language_model_output

        return PrismaticCausalLMOutputWithPast(
            loss=language_model_output.loss,
            logits=language_model_output.logits,
            past_key_values=language_model_output.past_key_values,
            hidden_states=language_model_output.hidden_states,
            attentions=language_model_output.attentions,
            projector_features=projected_patch_embeddings,
        )

    # === GenerationMixin Methods ===
    def prepare_inputs_for_generation(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: str,
    ) -> Dict[str, torch.Tensor]:
        """Borrowed from `LlamaForCausalLM` and simplified for batch size = 1; mirrors original PrismaticVLM logic."""
        if ((input_ids is not None) and (input_ids.shape[0] > 1)) or (
            (inputs_embeds is not None) and (inputs_embeds.shape[0] > 1)
        ):
            raise ValueError("Generation with batch size > 1 is not currently supported!")

        # Handle `past_key_values` (cache) =>> assume `input_ids` just has unprocessed tokens
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]

        # If `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        # Make sure `pixel_values` are preserved in `model_inputs`
        model_inputs.update(
            {
                "attention_mask": attention_mask,
                "pixel_values": pixel_values,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
            }
        )

        return model_inputs

    # Defer to Language Model (all handle this differently, with different return types)
    def _reorder_cache(self, *args, **kwargs) -> Any:
        return self.language_model._reorder_cache(*args, **kwargs)


class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
    config_class: PretrainedConfig = OpenVLAConfig

    def __init__(self, config: OpenVLAConfig) -> None:
        super().__init__(config)
        self.norm_stats = config.norm_stats

        # Compute action bins
        self.bins = np.linspace(-1, 1, config.n_action_bins)
        self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0

        # Compute vocab size for de-tokenization -- revert added "multiple of"
        self.vocab_size = self.config.text_config.vocab_size - self.config.pad_to_multiple_of

    def predict_action(
        self, input_ids: Optional[torch.LongTensor] = None, unnorm_key: Optional[str] = None, **kwargs: str
    ) -> np.ndarray:
        """Thin wrapper around super().generate() that decodes predicted actions and de-normalizes them."""

        # We need to add this special empty token ('') after the colon (':') token in "ASSISTANT:"
        # in order for the predictions to match the training configuration and be accurate.
        input_ids = torch.cat(
            (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
        )

        # Run VLA inference
        generated_ids = self.generate(input_ids, max_new_tokens=self.get_action_dim(unnorm_key), **kwargs)

        # Extract predicted action tokens and translate into (normalized) continuous actions
        predicted_action_token_ids = generated_ids[0, -self.get_action_dim(unnorm_key) :].cpu().numpy()
        discretized_actions = self.vocab_size - predicted_action_token_ids
        discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
        normalized_actions = self.bin_centers[discretized_actions]

        # Unnormalize actions
        action_norm_stats = self.get_action_stats(unnorm_key)
        mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool))
        action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"])
        actions = np.where(
            mask,
            0.5 * (normalized_actions + 1) * (action_high - action_low) + action_low,
            normalized_actions,
        )

        return actions

    @staticmethod
    def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
        if unnorm_key is None and len(norm_stats) != 1:
            raise ValueError(
                f"Your model was trained on more than one dataset. "
                f"Please pass a `unnorm_key` from the following options to choose the statistics used for "
                f"de-normalizing actions: {norm_stats.keys()}"
            )

        # If None, grab the (singular) dataset in `norm_stats` to use as `unnorm_key`
        unnorm_key = unnorm_key if unnorm_key is not None else next(iter(norm_stats.keys()))
        if unnorm_key not in norm_stats:
            raise ValueError(
                f"The `unnorm_key` you chose ({unnorm_key = }) is not in the available statistics. "
                f"Please choose from: {norm_stats.keys()}"
            )

        return unnorm_key

    def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
        """Get the dimensionality of the policy's action space."""
        unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
        return len(self.norm_stats[unnorm_key]["action"]["q01"])

    def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
        """Get all the logged statistics for the given dataset."""
        unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
        return self.norm_stats[unnorm_key]["action"]
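For reference, a rough end-to-end inference sketch against this commit's classes. The repository id, image path, and prompt format are placeholders (assumptions, not taken from this repo), and `trust_remote_code=True` is required so the `auto_map` entries registered in the JSON configs are honored.

# Rough sketch of batch-size-1 multimodal inference (placeholders marked below).
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

REPO_ID = "remyx/<this-model>"  # placeholder -- substitute the actual Hub repository id

processor = AutoProcessor.from_pretrained(REPO_ID, trust_remote_code=True)
model = AutoModelForVision2Seq.from_pretrained(
    REPO_ID, torch_dtype=torch.bfloat16, trust_remote_code=True
).to("cuda:0")

image = Image.open("example.jpg")                                            # placeholder image
prompt = "In: What action should the robot take to pick up the cup?\nOut:"   # assumed prompt format

inputs = processor(prompt, image).to("cuda:0", dtype=torch.bfloat16)
with torch.no_grad():
    generated_ids = model.generate(**inputs, max_new_tokens=32, do_sample=False)

print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])

With the `OpenVLAForActionPrediction` subclass, the analogous call is `model.predict_action(**inputs, unnorm_key=...)`, which generates `get_action_dim(unnorm_key)` tokens and maps them through `bin_centers` and the dataset's `q01`/`q99` statistics into continuous actions.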
preprocessor_config.json CHANGED
@@ -1,4 +1,8 @@
 {
+  "auto_map": {
+    "AutoImageProcessor": "processing_prismatic.PrismaticImageProcessor",
+    "AutoProcessor": "processing_prismatic.PrismaticProcessor"
+  },
   "image_processor_type": "PrismaticImageProcessor",
   "image_resize_strategy": "letterbox",
   "input_sizes": [
processing_prismatic.py ADDED
@@ -0,0 +1,252 @@
"""
processing_prismatic.py

HuggingFace-style preprocessor definitions for Prismatic VLMs, inheriting from `ProcessorMixin`. Default configuration
specifies `siglip-224px+7b`.
"""

from typing import Any, ClassVar, List, Optional, Tuple, Union

import timm.data
import torch
import torchvision.transforms.functional as TVF
from PIL import Image
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
from transformers import PreTrainedTokenizerBase
from transformers.image_processing_utils import BatchFeature, ImageProcessingMixin
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from transformers.utils import TensorType


# === Image Processing ===
def letterbox_pad_transform(image: Image.Image, padding_fill_value: Tuple[int, int, int]) -> Image.Image:
    """Given a PIL.Image, pad to square by adding a symmetric border around the height/width."""
    (w, h), max_wh = image.size, max(image.size)
    horizontal_pad, vertical_pad = int((max_wh - w) / 2), int((max_wh - h) / 2)
    padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad)

    return TVF.pad(image, padding, fill=padding_fill_value, padding_mode="constant")


class PrismaticImageProcessor(ImageProcessingMixin):
    model_input_names: ClassVar[List[str]] = ["pixel_values"]

    def __init__(
        self,
        use_fused_vision_backbone: bool = False,
        image_resize_strategy: str = "letterbox",
        input_sizes: Optional[List[Tuple[int, int, int]]] = None,
        interpolations: Optional[List[str]] = None,
        means: Optional[List[Tuple[float, float, float]]] = None,
        stds: Optional[List[Tuple[float, float, float]]] = None,
        **kwargs: str,
    ) -> None:
        """
        Initialize a PrismaticImageProcessor as a wrapper around a torchvision transform; this transform will be
        created by TIMM, and edited to follow our custom `image_resize_strategy` logic.
        @param use_fused_vision_backbone: Boolean indicating single or fused (dual) vision backbone
        @param image_resize_strategy: Prismatic image resize strategy in < resize-naive | resize-crop | letterbox >
        @param input_sizes: [TIMM :: `data_cfg`] Input image size as tuple (channels, width, height)
        @param interpolations: [TIMM :: `data_cfg`] Interpolation as string (default: "bicubic")
        @param means: [TIMM :: `data_cfg`] Normalization mean as float tuple (or two-tuple if `fused_backbone`)
        @param stds: [TIMM :: `data_cfg`] Normalization std as float tuple (or two-tuple if `fused_backbone`)
        """
        self.use_fused_vision_backbone = use_fused_vision_backbone
        self.image_resize_strategy = image_resize_strategy

        # Handle `None` default values
        input_sizes = [(3, 224, 224)] if input_sizes is None else input_sizes
        means = [(0.5, 0.5, 0.5)] if means is None else means
        stds = [(0.5, 0.5, 0.5)] if stds is None else stds

        # TIMM `data_cfg` Parameters
        self.input_sizes, self.interpolations, self.means, self.stds = input_sizes, interpolations, means, stds

        # Grab torchvision transforms via TIMM =>> need to parse for specific "functional" transform values!
        self.tvf_resize_params, self.tvf_crop_params, self.tvf_normalize_params = [], [], []
        self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None

        for idx in range(len(input_sizes)):
            transform = timm.data.create_transform(
                input_size=self.input_sizes[idx],
                interpolation=self.interpolations[idx],
                mean=self.means[idx],
                std=self.stds[idx],
                crop_pct=1.0,           # Set to 1.0 to ignore cropping (initial Resize sets `input_size`)
                crop_mode="center",     # Default crop mode -- no-op when `crop_pct == 1.0`
                is_training=False,      # No image augmentations when loading the transform!
            )

            # [Validation] Ensure appropriate transform structure, expected sizes
            if not (
                isinstance(transform, Compose)
                and (len(transform.transforms) == 4)
                and isinstance(transform.transforms[0], Resize)
                and isinstance(transform.transforms[1], CenterCrop)
                and isinstance(transform.transforms[2], ToTensor)
                and isinstance(transform.transforms[3], Normalize)
                and (transform.transforms[0].size == self.input_sizes[idx][-1])
                and (transform.transforms[1].size == self.input_sizes[idx][-2:])
            ):
                raise ValueError(f"Unexpected TIMM image transformation structure/sizes: `{transform}`")

            # HF Image Processors *must* be JSON-serializable; as such, cannot have torchvision. as an attribute.
            #   => Instead, we're going to parse the transform and call "torchvision.transforms.functional" (`tvf`)
            resize_t, crop_t, norm_t = transform.transforms[0], transform.transforms[1], transform.transforms[3]
            self.tvf_resize_params.append(
                {
                    "size": resize_t.size,
                    "interpolation": TVF.pil_modes_mapping[resize_t.interpolation],
                    "max_size": None,
                    "antialias": True,
                }
            )
            self.tvf_crop_params.append({"output_size": crop_t.size})
            self.tvf_normalize_params.append(
                {
                    "mean": norm_t.mean.float().numpy().tolist(),
                    "std": norm_t.std.float().numpy().tolist(),
                    "inplace": False,
                }
            )
            self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None

            # Handle Prismatic `image_resize_strategy`
            if self.image_resize_strategy == "resize-naive":
                self.tvf_resize_params[idx]["size"] = (resize_t.size, resize_t.size)
            elif self.image_resize_strategy == "letterbox":
                self.tvf_do_letterbox, self.tvf_letterbox_fill = True, tuple([int(x * 255) for x in self.means[idx]])
            elif self.image_resize_strategy == "resize-crop":
                pass
            else:
                raise ValueError(f"Image resize strategy `{self.image_resize_strategy}` is not supported!")

        # Dispatch **kwargs to super()
        super().__init__(**kwargs)

    def apply_transform(self, img: Image.Image) -> torch.Tensor:
        """Apply `functional` variant of TIMM's Transform = Compose([Resize -> CenterCrop -> ToTensor -> Normalize])"""
        if self.tvf_do_letterbox:
            img = letterbox_pad_transform(img, self.tvf_letterbox_fill)

        # [Contract] Fused Backbones expect "channel-stacked" inputs; we'll unpack on the model side!
        imgs_t = []
        for idx in range(len(self.input_sizes)):
            img_idx = TVF.resize(img, **self.tvf_resize_params[idx])
            img_idx = TVF.center_crop(img_idx, **self.tvf_crop_params[idx])
            img_idx_t = TVF.to_tensor(img_idx)
            img_idx_t = TVF.normalize(img_idx_t, **self.tvf_normalize_params[idx])
            imgs_t.append(img_idx_t)

        # [Contract] `imgs_t` is a list of Tensors of shape [3, input_size, input_size]; stack along dim = 0
        img_t = torch.vstack(imgs_t)

        return img_t

    def preprocess(
        self,
        images: Union[Image.Image, List[Image.Image]],
        return_tensors: Optional[Union[str, TensorType]] = None,
        **_: str,
    ) -> BatchFeature:
        """
        Preprocess an image (or batch of images); note that unlike the `transformers :: BaseImageProcessor` we
        explicitly only handle PIL.Image.Image instances for simplicity.
        @param images: A (batch of) PIL.Image.Image instance(s) to preprocess.
        @param return_tensors: BatchFeature default Tensor format (e.g., "pt" for torch); if None, returns np.ndarray
        @return: Instance of `transformers :: BatchFeature` with a single key "pixel_values"
        """
        if not isinstance(images, list):
            images = [images]

        # Apply `self.img_transform` to each image (will return list of torch.Tensors); stack into "batched" Tensor
        pixel_values = torch.stack([self.apply_transform(img.convert("RGB")) for img in images])

        # Return BatchFeature =>> note that for compatibility, constructor expects Dict[str, np.ndarray], so we convert
        return BatchFeature(data={"pixel_values": pixel_values.float().numpy()}, tensor_type=return_tensors)

    def __call__(self, images: Union[Image.Image, List[Image.Image]], **kwargs) -> BatchFeature:
        return self.preprocess(images, **kwargs)


# === PrismaticProcessor =>> Wraps both ImageProcessor and Tokenizer ===
#   =>> https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/processing_llava.py
class PrismaticProcessor(ProcessorMixin):
    attributes: ClassVar[List[str]] = ["image_processor", "tokenizer"]
    image_processor_class: str = "AutoImageProcessor"
    tokenizer_class: str = "AutoTokenizer"

    def __init__(
        self,
        image_processor: Optional[ImageProcessingMixin] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
    ) -> None:
        super().__init__(image_processor, tokenizer)

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
        images: Union[Image.Image, List[Image.Image]],
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Optional[Union[bool, str, TruncationStrategy]] = None,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
    ) -> BatchFeature:
        """
        Preprocess a given (batch) of text/images for a Prismatic VLM; forwards text to the underlying LLM's tokenizer,
        forwards images to PrismaticImageProcessor.
        @param text: The (batch) of text to encode; must be a string or list of strings.
        @param images: A (batch of) PIL.Image.Image instance(s) to preprocess.
        @param padding: Sequence padding strategy (if multiple specified) in < True = "longest" | "max_length" | False >
        @param truncation: Truncation strategy for the output sequences; requires `max_length` to be specified
        @param max_length: Maximum length (in tokens) to truncate
        @param return_tensors: Type of return tensors (usually "pt" or TensorType.PYTORCH)
        @return: BatchFeature with keys for `input_ids`, `attention_mask` and `pixel_values`.
        """
        pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"]
        text_inputs = self.tokenizer(
            text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
        )

        # [Validate] Need same number of images and text inputs!
        if pixel_values.shape[0] != text_inputs.input_ids.shape[0]:
            raise ValueError("Batch is malformed; expected same number of images and text inputs!")

        return BatchFeature(data={**text_inputs, "pixel_values": pixel_values})

    # === Tokenizer Dispatch Utilities =>> check `PreTrainedTokenizerBase` for documentation ===
    def batch_decode(
        self,
        sequences: Union[List[int], List[List[int]], torch.Tensor, Any],  # `Any` = np.ndarray | tf.Tensor
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = None,
        **kwargs: str,
    ) -> List[str]:
        return self.tokenizer.batch_decode(
            sequences=sequences,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

    def decode(
        self,
        token_ids: Union[int, List[int], torch.Tensor, Any],  # `Any` = np.ndarray | tf.Tensor
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = None,
        **kwargs: str,
    ) -> str:
        return self.tokenizer.decode(
            token_ids=token_ids,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

    @property
    def model_input_names(self) -> List[str]:
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names

        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
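A small standalone sketch of the preprocessing path (not part of the commit), assuming the file above is importable from the working directory and that you substitute a tokenizer you have access to. Note that `interpolations` has no `None` default in `__init__`, so it must be passed explicitly when constructing the image processor by hand; normally it comes from `preprocessor_config.json`.

# Standalone sketch: letterbox-pad a non-square image, then tokenize a prompt.
from PIL import Image
from transformers import AutoTokenizer

from processing_prismatic import PrismaticImageProcessor, PrismaticProcessor

# Defaults below mirror the siglip-224px case, not necessarily this checkpoint's preprocessor_config.json.
image_processor = PrismaticImageProcessor(image_resize_strategy="letterbox", interpolations=["bicubic"])
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")  # gated; any compatible tokenizer works
processor = PrismaticProcessor(image_processor, tokenizer)

image = Image.new("RGB", (640, 480), color=(90, 120, 150))  # dummy non-square image
batch = processor("What is in this image?", image)

print(batch["pixel_values"].shape)  # torch.Size([1, 3, 224, 224]) -- padded to square, then resized
print(batch["input_ids"].shape)     # torch.Size([1, sequence_length])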
processor_config.json ADDED
@@ -0,0 +1,6 @@
{
  "auto_map": {
    "AutoProcessor": "processing_prismatic.PrismaticProcessor"
  },
  "processor_class": "PrismaticProcessor"
}
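Together with the `auto_map` blocks registered in this commit's other JSON configs, this file lets the generic Auto classes resolve the custom implementations shipped alongside the weights. A quick resolution check, with the repository id as a placeholder:

# Sketch: verify that the Auto classes resolve to the custom Prismatic classes.
from transformers import AutoConfig, AutoImageProcessor, AutoProcessor

REPO_ID = "remyx/<this-model>"  # placeholder -- substitute the actual Hub repository id

config = AutoConfig.from_pretrained(REPO_ID, trust_remote_code=True)
image_processor = AutoImageProcessor.from_pretrained(REPO_ID, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(REPO_ID, trust_remote_code=True)

print(type(config).__name__)           # PrismaticConfig
print(type(image_processor).__name__)  # PrismaticImageProcessor
print(type(processor).__name__)        # PrismaticProcessor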
tokenizer_config.json CHANGED
@@ -2057,6 +2057,9 @@
       "special": true
     }
   },
+  "auto_map": {
+    "AutoProcessor": "processing_prismatic.PrismaticProcessor"
+  },
   "bos_token": "<|begin_of_text|>",
   "clean_up_tokenization_spaces": true,
   "eos_token": "<|end_of_text|>",