delinqu committed
Commit d8503a6 · verified · 1 Parent(s): fd1764e

Upload folder using huggingface_hub

action_tokenizer.py CHANGED
@@ -1,27 +1,16 @@
1
- # MIT License
2
- # Copyright (c) 2025 IPEC at Shanghai AI Laboratory
3
- # Permission is hereby granted, free of charge, to use, copy, modify, merge, publish,
4
- # distribute, sublicense, and/or sell copies of the Software, subject to the following conditions:
5
- # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND.
7
- # coding=utf-8
8
-
9
  """
10
  action_tokenizer.py
11
 
12
  Extension class; wraps base LLM/VLM tokenizer with logic to discretize and tokenize continuous robot actions.
13
  """
14
- from typing import List, Union, Dict, Tuple, Optional
15
  import numpy as np
16
  from transformers import PreTrainedTokenizerBase
17
- from pathlib import Path
18
- import json
19
  from scipy.stats import norm
20
  import torch
21
 
22
  ACTION_TOKEN = '<ACTION{:05d}>'
23
 
24
- """Spatial Tokenizer"""
25
  class ActionTokenizer:
26
  def __init__(
27
  self,
@@ -67,7 +56,6 @@ class ActionTokenizer:
67
  def vocab_size(self) -> int:
68
  return self._vocab_size
69
 
70
- """Spatial Tokenizer"""
71
  class TranslationTokenizer:
72
  def __init__(
73
  self,
@@ -258,7 +246,7 @@ class GripperTokenzier:
258
  def vocab_size(self) -> int:
259
  return self.num_bins
260
 
261
- class SphericalCoordinateActionTokenizer:
262
  range_bins = {
263
  "translation": {
264
  "theta_bins": (0.0, np.pi),
@@ -282,7 +270,7 @@ class SphericalCoordinateActionTokenizer:
282
  min_action: float = -1.0,
283
  max_action: float = 1.0,
284
  ):
285
- """set bin_policy if exist, otherwise, caculate bin_policy from gs_params.(unifrom if None Gaussian)
286
  gs_params: Optional[Dict],
287
  bin_policy: Optional[Dict],
288
  """
@@ -293,7 +281,6 @@ class SphericalCoordinateActionTokenizer:
293
 
294
  # set bin policy
295
  self.bin_policy = bin_policy if bin_policy else self.get_bin_policy(gs_params, self.min_sigma)
296
-
297
  self.translation_tokenizer = TranslationTokenizer(
298
  self.tokenizer,
299
  self.num_bins["translation"],
@@ -406,13 +393,11 @@ class SphericalCoordinateActionTokenizer:
406
  embeddings: tensor (S,E)
407
  """
408
  from scipy.interpolate import griddata
409
- # __import__("ipdb").set_trace()
410
-
411
  new_policy = self.get_bin_policy(gs_params, min_sigma=min_sigma)
412
  trans_grids0, rot_grids0 = self.get_norm_meshgrid(self.bin_policy)
413
  trans_grids1, rot_grids1 = self.get_norm_meshgrid(new_policy)
414
 
415
- print("🔥 overwrite bin policy and tokenizer bins ...")
416
  self.bin_policy = new_policy
417
  self.min_sigma = min_sigma
418
  self.translation_tokenizer.set_bins(new_policy["translation"])
@@ -442,5 +427,5 @@ class SphericalCoordinateActionTokenizer:
442
  device, dtype = embeddings.weight.data.device, embeddings.weight.data.dtype
443
  embeddings.weight.data[:N] = torch.Tensor(adpt_trans_emb.reshape(-1, E), device=device).to(dtype)
444
  embeddings.weight.data[N:N+M] = torch.Tensor(adpt_rot_emb.reshape(-1, E), device=device).to(dtype)
445
- print("🚀 DONE! adapt spatial embedding to new gaussian distributation finished.")
446
  print(embeddings.weight.data)
 
 
 
 
 
 
 
 
 
1
  """
2
  action_tokenizer.py
3
 
4
  Extension class; wraps base LLM/VLM tokenizer with logic to discretize and tokenize continuous robot actions.
5
  """
6
+ from typing import List, Union, Dict, Optional
7
  import numpy as np
8
  from transformers import PreTrainedTokenizerBase
 
 
9
  from scipy.stats import norm
10
  import torch
11
 
12
  ACTION_TOKEN = '<ACTION{:05d}>'
13
 
 
14
  class ActionTokenizer:
15
  def __init__(
16
  self,
 
56
  def vocab_size(self) -> int:
57
  return self._vocab_size
58
 
 
59
  class TranslationTokenizer:
60
  def __init__(
61
  self,
 
246
  def vocab_size(self) -> int:
247
  return self.num_bins
248
 
249
+ class SpatialActionTokenizer:
250
  range_bins = {
251
  "translation": {
252
  "theta_bins": (0.0, np.pi),
 
270
  min_action: float = -1.0,
271
  max_action: float = 1.0,
272
  ):
273
+ """set bin_policy if exist, otherwise, caculate bin_policy from gs_params or use uniform bin grids.
274
  gs_params: Optional[Dict],
275
  bin_policy: Optional[Dict],
276
  """
 
281
 
282
  # set bin policy
283
  self.bin_policy = bin_policy if bin_policy else self.get_bin_policy(gs_params, self.min_sigma)
 
284
  self.translation_tokenizer = TranslationTokenizer(
285
  self.tokenizer,
286
  self.num_bins["translation"],
 
393
  embeddings: tensor (S,E)
394
  """
395
  from scipy.interpolate import griddata
 
 
396
  new_policy = self.get_bin_policy(gs_params, min_sigma=min_sigma)
397
  trans_grids0, rot_grids0 = self.get_norm_meshgrid(self.bin_policy)
398
  trans_grids1, rot_grids1 = self.get_norm_meshgrid(new_policy)
399
 
400
+ print("overwrite bin policy and tokenizer bins ...")
401
  self.bin_policy = new_policy
402
  self.min_sigma = min_sigma
403
  self.translation_tokenizer.set_bins(new_policy["translation"])
 
427
  device, dtype = embeddings.weight.data.device, embeddings.weight.data.dtype
428
  embeddings.weight.data[:N] = torch.Tensor(adpt_trans_emb.reshape(-1, E), device=device).to(dtype)
429
  embeddings.weight.data[N:N+M] = torch.Tensor(adpt_rot_emb.reshape(-1, E), device=device).to(dtype)
430
+ print("DONE! adapt spatial embedding to new gaussian distributation finished.")
431
  print(embeddings.weight.data)
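
Editor's note: the rewritten docstring above says SpatialActionTokenizer either takes an explicit bin_policy, derives one from gs_params, or falls back to uniform bin grids. As a rough sketch of that distinction (not code from this commit; the actual get_bin_policy signature and gs_params layout are not shown in the diff), uniform binning splits a range into equal-width bins, while Gaussian-based binning places bin edges at equally spaced quantiles of a fitted normal distribution:

import numpy as np
from scipy.stats import norm

n_bins, lo, hi = 8, 0.0, np.pi                  # e.g. the theta range from range_bins
uniform_edges = np.linspace(lo, hi, n_bins + 1) # equal-width bins

mu, sigma = np.pi / 2, 0.4                      # assumed Gaussian statistics for the angle
qs = np.linspace(0.0, 1.0, n_bins + 1)[1:-1]    # interior quantiles
gaussian_edges = np.clip(np.concatenate([[lo], norm.ppf(qs, loc=mu, scale=sigma), [hi]]), lo, hi)

print(uniform_edges.round(3))
print(gaussian_edges.round(3))                  # denser bins near mu, coarser in the tails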
config.json CHANGED
@@ -1,5 +1,4 @@
1
  {
2
- "_name_or_path": "../pretrained/2025-01-05_09-12-37_oxe_spatial_vla_paligemma3b_zoe_gsN8194_gpu64-204k",
3
  "_vocab_size": 265347,
4
  "action_token_begin_idx": 257153,
5
  "architectures": [
@@ -317,4 +316,4 @@
317
  "use_bias_in_fusion_residual": null,
318
  "use_pretrained_backbone": false
319
  }
320
- }
 
1
  {
 
2
  "_vocab_size": 265347,
3
  "action_token_begin_idx": 257153,
4
  "architectures": [
 
316
  "use_bias_in_fusion_residual": null,
317
  "use_pretrained_backbone": false
318
  }
319
+ }
configuration_spatialvla.py CHANGED
@@ -1,12 +1,16 @@
1
- # MIT License
2
- # Copyright (c) 2025 IPEC at Shanghai AI Laboratory
3
- # Permission is hereby granted, free of charge, to use, copy, modify, merge, publish,
4
- # distribute, sublicense, and/or sell copies of the Software, subject to the following conditions:
5
- # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND.
7
- # Based on code licensed under the Apache License, Version 2.0 by Google Inc. and HuggingFace Inc. team (Copyright 2024).
8
  # coding=utf-8
9
-
 
 
 
 
 
 
 
 
 
 
 
10
  """PaliGemmamodel configuration"""
11
 
12
  import warnings
@@ -15,59 +19,9 @@ from transformers.configuration_utils import PretrainedConfig
15
  from transformers.utils import logging
16
  from transformers import CONFIG_MAPPING, AutoConfig
17
 
18
-
19
  logger = logging.get_logger(__name__)
20
 
21
-
22
  class SpatialVLAConfig(PretrainedConfig):
23
- r"""
24
- This is the configuration class to store the configuration of a [`PaliGemmaForConditionalGeneration`]. It is used to instantiate an
25
- PaliGemmamodel according to the specified arguments, defining the model architecture. Instantiating a configuration
26
- with the defaults will yield a similar configuration to that of the PaliGemma-2B.
27
-
28
- e.g. [paligemma-hf/paligemma-2b](https://huggingface.co/paligemma-hf/paligemma-2b)
29
-
30
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
31
- documentation from [`PretrainedConfig`] for more information.
32
-
33
- Args:
34
- vision_config (`PaliGemmaVisionConfig`, *optional*):
35
- Custom vision config or dict
36
- text_config (`Union[AutoConfig, dict]`, *optional*):
37
- The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
38
- ignore_index (`int`, *optional*, defaults to -100):
39
- The ignore index for the loss function.
40
- image_token_index (`int`, *optional*, defaults to 256000):
41
- The image token index to encode the image prompt.
42
- vocab_size (`int`, *optional*, defaults to 257152):
43
- Vocabulary size of the PaliGemmamodel. Defines the number of different tokens that can be represented by the
44
- `inputs_ids` passed when calling [`~PaliGemmaForConditionalGeneration`]
45
- projection_dim (`int`, *optional*, defaults to 2048):
46
- Dimension of the multimodal projection space.
47
- hidden_size (`int`, *optional*, defaults to 2048):
48
- Dimension of the hidden layer of the Language model.
49
-
50
- Example:
51
-
52
- ```python
53
- >>> from transformers import PaliGemmaForConditionalGeneration, PaliGemmaConfig, SiglipVisionConfig, GemmaConfig
54
-
55
- >>> # Initializing a Siglip-like vision config
56
- >>> vision_config = SiglipVisionConfig()
57
-
58
- >>> # Initializing a PaliGemma config
59
- >>> text_config = GemmaConfig()
60
-
61
- >>> # Initializing a PaliGemma paligemma-3b-224 style configuration
62
- >>> configuration = PaliGemmaConfig(vision_config, text_config)
63
-
64
- >>> # Initializing a model from the paligemma-3b-224 style configuration
65
- >>> model = PaliGemmaForConditionalGeneration(configuration)
66
-
67
- >>> # Accessing the model configuration
68
- >>> configuration = model.config
69
- ```"""
70
-
71
  model_type = "spatialvla"
72
  sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig, "vision_zoe_config": AutoConfig}
73
 
@@ -87,7 +41,6 @@ class SpatialVLAConfig(PretrainedConfig):
87
  ego3d_patch_reso=4,
88
  n_freqs=8,
89
  use_vision_zoe=True,
90
- # wrap_lora=False,
91
  **kwargs,
92
  ):
93
  self._ignore_index = ignore_index
@@ -138,19 +91,15 @@ class SpatialVLAConfig(PretrainedConfig):
138
  vision_zoe_config["model_type"] = vision_zoe_config["model_type"] if "model_type" in vision_zoe_config else "zoedepth"
139
  self.vision_zoe_config = CONFIG_MAPPING[vision_zoe_config["model_type"]](**vision_zoe_config)
140
  else:
141
- print(f"🔥 init from default configurations ... {self.vision_zoe_config}")
142
- # BUG: initializing zoe in default cause key error
143
- # self.vision_zoe_config = CONFIG_MAPPING["zoedepth"]()
144
  pass
145
 
146
- # NOTE: additional attributes
147
  self.action_token_begin_idx = action_token_begin_idx
148
  self.spatial_token_num = spatial_token_num
149
  self.use_spatial_token = use_spatial_token
150
  self.ego3d_patch_reso = ego3d_patch_reso
151
  self.n_freqs = n_freqs
152
  self.use_vision_zoe = use_vision_zoe
153
- # self.wrap_lora = wrap_lora
154
 
155
  super().__init__(**kwargs)
156
 
 
 
 
 
 
 
 
 
1
  # coding=utf-8
2
+ # Copyright 2024 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved.
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
  """PaliGemmamodel configuration"""
15
 
16
  import warnings
 
19
  from transformers.utils import logging
20
  from transformers import CONFIG_MAPPING, AutoConfig
21
 
 
22
  logger = logging.get_logger(__name__)
23
 
 
24
  class SpatialVLAConfig(PretrainedConfig):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  model_type = "spatialvla"
26
  sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig, "vision_zoe_config": AutoConfig}
27
 
 
41
  ego3d_patch_reso=4,
42
  n_freqs=8,
43
  use_vision_zoe=True,
 
44
  **kwargs,
45
  ):
46
  self._ignore_index = ignore_index
 
91
  vision_zoe_config["model_type"] = vision_zoe_config["model_type"] if "model_type" in vision_zoe_config else "zoedepth"
92
  self.vision_zoe_config = CONFIG_MAPPING[vision_zoe_config["model_type"]](**vision_zoe_config)
93
  else:
 
 
 
94
  pass
95
 
96
+ # additional attributes
97
  self.action_token_begin_idx = action_token_begin_idx
98
  self.spatial_token_num = spatial_token_num
99
  self.use_spatial_token = use_spatial_token
100
  self.ego3d_patch_reso = ego3d_patch_reso
101
  self.n_freqs = n_freqs
102
  self.use_vision_zoe = use_vision_zoe
 
103
 
104
  super().__init__(**kwargs)
105
 
modeling_gemma2.py CHANGED
@@ -1,4 +1,5 @@
1
- # custom gemma2 to support flash_attention_2
 
2
  # coding=utf-8
3
  # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
4
  #
@@ -205,10 +206,7 @@ def flash_attention_forward(
205
  ) -> Tuple[torch.Tensor, None]:
206
  # NOTE: a None mask causes undefined behavior, see https://github.com/huggingface/transformers/blob/c8c8dffbe45ebef0a8dba4a51024e5e5e498596b/src/transformers/models/gemma2/modeling_gemma2.py#L211
207
  seq_len = query.shape[2]
208
- # print(f"🔥 query {query.shape}, key {key.shape}, value: {value.shape}")
209
  if mask is not None:
210
- # print(f"🔥 mask {mask.shape}")
211
- # seq_len = mask.shape[1]
212
  query = query[:, :, :seq_len]
213
  value = value[:, :, :seq_len]
214
 
 
1
+ # custom gemma2 to support flash_attention_2,
2
+ # sourced from https://github.com/huggingface/transformers/blob/v4.47.0/src/transformers/models/gemma2/modeling_gemma2.py
3
  # coding=utf-8
4
  # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
5
  #
 
206
  ) -> Tuple[torch.Tensor, None]:
207
  # NOTE: a None mask causes undefined behavior, see https://github.com/huggingface/transformers/blob/c8c8dffbe45ebef0a8dba4a51024e5e5e498596b/src/transformers/models/gemma2/modeling_gemma2.py#L211
208
  seq_len = query.shape[2]
 
209
  if mask is not None:
 
 
210
  query = query[:, :, :seq_len]
211
  value = value[:, :, :seq_len]
212
 
modeling_spatialvla.py CHANGED
@@ -1,153 +1,118 @@
1
- # MIT License
2
- # Copyright (c) 2025 IPEC at Shanghai AI Laboratory
3
- # Permission is hereby granted, free of charge, to use, copy, modify, merge, publish,
4
- # distribute, sublicense, and/or sell copies of the Software, subject to the following conditions:
5
- # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND.
7
- # Based on code licensed under the Apache License, Version 2.0 by Google Inc. and HuggingFace Inc. team (Copyright 2024).
8
  # coding=utf-8
9
-
 
 
 
 
 
 
 
 
 
 
 
 
10
  """PyTorch PaliGemmamodel."""
11
 
12
  from dataclasses import dataclass
13
  from typing import List, Optional, Tuple, Union
14
 
 
15
  import torch
16
  import torch.utils.checkpoint
17
  from torch import nn
18
  from torch.linalg import inv
19
- import torchvision.transforms.functional as F
20
-
21
- import os
22
  from transformers.cache_utils import Cache, HybridCache, StaticCache
23
  from transformers.generation import GenerationMixin
24
  from transformers.modeling_utils import PreTrainedModel, PretrainedConfig
25
  from transformers.utils import (
26
  ModelOutput,
27
- add_start_docstrings,
28
- add_start_docstrings_to_model_forward,
29
- is_flash_attn_2_available,
30
  logging,
31
- replace_return_docstrings,
32
  )
33
  from .configuration_spatialvla import SpatialVLAConfig
34
- from .modeling_ego3d import Ego3DPositionEmbeddingMLP, process_zoe
35
  from .modeling_gemma2 import Gemma2ForCausalLM
 
36
 
37
- if is_flash_attn_2_available():
38
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
39
-
40
- from transformers import AutoModel, AutoModelForCausalLM, ZoeDepthForDepthEstimation
41
-
42
 
43
  logger = logging.get_logger(__name__)
44
 
45
- _CONFIG_FOR_DOC = "PaliGemmaConfig"
46
-
47
- # constant
48
- SIGLIP_MEAN, SIGLIP_STD = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
49
-
50
- # Adapted from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
51
- # But Paligemma has no causal mask on prefix
52
- def _prepare_4d_causal_attention_mask_with_cache_position(
53
- attention_mask: torch.Tensor,
54
- sequence_length: int,
55
- target_length: int,
56
- dtype: torch.dtype,
57
- device: torch.device,
58
- min_dtype: float,
59
- cache_position: torch.Tensor,
60
- batch_size: int,
61
- is_training: bool = False,
62
- token_type_ids: torch.Tensor = None,
63
- **kwargs,
64
- ):
65
  """
66
- Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
67
- `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
68
-
69
- Args:
70
- attention_mask (`torch.Tensor`):
71
- A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
72
- sequence_length (`int`):
73
- The sequence length being processed.
74
- target_length (`int`):
75
- The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
76
- dtype (`torch.dtype`):
77
- The dtype to use for the 4D attention mask.
78
- device (`torch.device`):
79
- The device to plcae the 4D attention mask on.
80
- min_dtype (`float`):
81
- The minimum value representable with the dtype `dtype`.
82
- cache_position (`torch.Tensor`):
83
- Indices depicting the position of the input sequence tokens in the sequence.
84
- batch_size (`torch.Tensor`):
85
- Batch size.
86
- is_training (`bool`):
87
- Whether the model is in training mode or in inference. The condition is checked by presence/absence of `token_type_ids/labels`
88
- """
89
- if attention_mask is not None and attention_mask.dim() == 4:
90
- # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
91
- causal_mask = attention_mask
92
- else:
93
- causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
94
- # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
95
- if sequence_length != 1:
96
- if is_training:
97
- causal_mask = torch.triu(causal_mask, diagonal=1)
98
- else:
99
- causal_mask[:, :sequence_length] = 0.0
100
 
101
- causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
102
- causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
103
- if attention_mask is not None:
104
- causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
105
- mask_length = attention_mask.shape[-1]
106
- padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
107
- padding_mask = padding_mask == 0
108
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
109
- padding_mask, min_dtype
110
- )
111
- # we are training thus we need to create a full mask on the image + prefix but causal on suffix
112
- if is_training:
113
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
114
- token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
115
- )
116
- return causal_mask
 
 
 
 
117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  @dataclass
120
  class SpatialVLACausalLMOutputWithPast(ModelOutput):
121
- """
122
- Base class for PaliGemmacausal language model (or autoregressive) outputs.
123
-
124
- Args:
125
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
126
- Language modeling loss (for next-token prediction).
127
- logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
128
- Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
129
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
130
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
131
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
132
-
133
- Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
134
- `past_key_values` input) to speed up sequential decoding.
135
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
136
- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
137
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
138
-
139
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
140
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
141
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
142
- sequence_length)`.
143
-
144
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
145
- heads.
146
- image_hidden_states (`torch.FloatTensor`, *optional*):
147
- A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
148
- image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
149
- """
150
-
151
  loss: Optional[torch.FloatTensor] = None
152
  logits: torch.FloatTensor = None
153
  past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None
@@ -155,7 +120,6 @@ class SpatialVLACausalLMOutputWithPast(ModelOutput):
155
  attentions: Optional[Tuple[torch.FloatTensor]] = None
156
  image_hidden_states: Optional[torch.FloatTensor] = None
157
 
158
-
159
  class SpatialVLAMultiModalProjector(nn.Module):
160
  def __init__(self, config: SpatialVLAConfig):
161
  super().__init__()
@@ -163,31 +127,8 @@ class SpatialVLAMultiModalProjector(nn.Module):
163
 
164
  def forward(self, image_features):
165
  hidden_states = self.linear(image_features)
166
-
167
  return hidden_states
168
 
169
-
170
- PALIGEMMA_START_DOCSTRING = r"""
171
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
172
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
173
- etc.)
174
-
175
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
176
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
177
- and behavior.
178
-
179
- Parameters:
180
- config ([`PaliGemmaConfig`] or [`PaliGemmaVisionConfig`]):
181
- Model configuration class with all the parameters of the model. Initializing with a config file does not
182
- load the weights associated with the model, only the configuration. Check out the
183
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
184
- """
185
-
186
-
187
- @add_start_docstrings(
188
- "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
189
- PALIGEMMA_START_DOCSTRING,
190
- )
191
  class SpatialVLAPreTrainedModel(PreTrainedModel):
192
  config_class = SpatialVLAConfig
193
  base_model_prefix = "model"
@@ -202,8 +143,6 @@ class SpatialVLAPreTrainedModel(PreTrainedModel):
202
  _supports_sdpa = True
203
 
204
  def _init_weights(self, module):
205
- # important: this ported version of PaliGemmaisn't meant for training from scratch - only
206
- # inference and fine-tuning
207
  std = (
208
  self.config.initializer_range
209
  if hasattr(self.config, "initializer_range")
@@ -222,99 +161,20 @@ class SpatialVLAPreTrainedModel(PreTrainedModel):
222
  if module.padding_idx is not None:
223
  module.weight.data[module.padding_idx].zero_()
224
 
225
-
226
- PALIGEMMA_INPUTS_DOCSTRING = r"""
227
- Args:
228
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
229
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
230
- it.
231
-
232
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
233
- [`PreTrainedTokenizer.__call__`] for details.
234
-
235
- [What are input IDs?](../glossary#input-ids)
236
- pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
237
- The tensors corresponding to the input images. Pixel values can be obtained using
238
- [`AutoImageProcessor`]. See [`SiglipImageProcessor.__call__`] for details ([]`PaliGemmaProcessor`] uses
239
- [`SiglipImageProcessor`] for processing images).
240
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
241
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
242
-
243
- - 1 for tokens that are **not masked**,
244
- - 0 for tokens that are **masked**.
245
-
246
- [What are attention masks?](../glossary#attention-mask)
247
-
248
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
249
- [`PreTrainedTokenizer.__call__`] for details.
250
-
251
- If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
252
- `past_key_values`).
253
-
254
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
255
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
256
- information on the default strategy.
257
-
258
- - 1 indicates the head is **not masked**,
259
- - 0 indicates the head is **masked**.
260
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
261
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
262
- config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
263
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
264
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
265
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
266
- `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
267
-
268
- Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
269
- blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
270
-
271
- If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
272
- don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
273
- `decoder_input_ids` of shape `(batch_size, sequence_length)`.
274
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
275
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
276
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
277
- model's internal embedding lookup matrix.
278
- use_cache (`bool`, *optional*):
279
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
280
- `past_key_values`).
281
- output_attentions (`bool`, *optional*):
282
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
283
- tensors for more detail.
284
- output_hidden_states (`bool`, *optional*):
285
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
286
- more detail.
287
- return_dict (`bool`, *optional*):
288
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
289
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
290
- Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
291
- this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
292
- the complete sequence length.
293
- """
294
-
295
-
296
- @add_start_docstrings(
297
- """The PALIGEMMA model which consists of a vision backbone and a language model.""",
298
- PALIGEMMA_START_DOCSTRING,
299
- )
300
  class SpatialVLAForConditionalGeneration(SpatialVLAPreTrainedModel, GenerationMixin):
301
  def __init__(self, config: SpatialVLAConfig, vision_model=None, vision_zoe_model=None, projector_model=None, language_model=None):
302
  super().__init__(config)
303
- # vision model
304
  self.vision_tower = vision_model or AutoModel.from_config(config=config.vision_config)
305
- # projector
306
  self.multi_modal_projector = projector_model or SpatialVLAMultiModalProjector(config)
307
- # language model
308
  self.vocab_size = config.text_config.vocab_size
309
  if language_model is None:
310
- language_model = Gemma2ForCausalLM(config=config.text_config) if config.text_config.model_type == "gemma2" else AutoModelForCausalLM.from_config(config=config.text_config)
311
- # set tile key
312
  if language_model._tied_weights_keys is not None:
313
  self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
314
  self.language_model = language_model
315
 
316
  if config.use_vision_zoe:
317
- # zoe model
318
  self.vision_zoe_model = vision_zoe_model or ZoeDepthForDepthEstimation(config.vision_zoe_config)
319
  self.position_embedding_3d = Ego3DPositionEmbeddingMLP(
320
  config.ego3d_patch_reso**2 * 3, num_pos_feats=config.vision_config.hidden_size, n_freqs=config.n_freqs
@@ -326,15 +186,12 @@ class SpatialVLAForConditionalGeneration(SpatialVLAPreTrainedModel, GenerationMi
326
  uv_h = torch.stack([x, y, torch.ones_like(x)], dim=0).reshape(3, -1) # (3 hw)
327
  self.register_buffer("uv_h", uv_h, persistent=False)
328
 
329
- # NOTE: add shared addtional spatial token embeddings for <ACTION> <IMG>
330
  if config.use_spatial_token:
331
  self.spatial_embed_tokens = nn.Embedding(self.config.spatial_token_num, config.text_config.hidden_size)
332
  else:
333
  self.spatial_embed_tokens = None
334
-
335
  self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
336
- # self.post_init() # BUG: cause from_pretrained failed!
337
- # self.position_embedding_3d._reset_parameters()
338
 
339
 
340
  def backproject_patch(self, K: torch.Tensor, depth: torch.Tensor, patch_size=14, reso=2) -> torch.Tensor:
@@ -343,44 +200,48 @@ class SpatialVLAForConditionalGeneration(SpatialVLAPreTrainedModel, GenerationMi
343
  Args:
344
  K: camera intrinsic matrix (b 3 3)
345
  depth: depth map (b 1 h w)
346
- pixel_offset: offset to the pixel coordinate
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
  """
348
- # __import__("ipdb").set_trace()
349
  b, c, h, w = depth.shape
350
  hp, wp = h // patch_size, w // patch_size
351
  sub_hp = sub_wp = reso
352
- patch_depth = torch.nn.functional.interpolate(depth, size=(hp * reso, wp * reso), mode="area").reshape(b, c, -1)
353
-
354
- # import torchvision; torchvision.utils.save_image(zoe_pixel_values[0], "zoe_image.png")
355
  p_cam = (inv(K.float()) @ self.uv_h.float()) * patch_depth # (b 3 3) @ (3 hw) -> (b 3 hw) * (b 1 hw) -> (b 3 hw)
356
  patch_p_cam = p_cam.reshape(b, 3, hp, sub_hp, wp, sub_wp).permute(0, 2, 4, 3, 5, 1).reshape(b, hp * wp, -1)
357
  return patch_p_cam
358
 
359
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings with Llava->PaliGemma
360
  def get_input_embeddings(self):
361
  return self.language_model.get_input_embeddings()
362
 
363
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings with Llava->PaliGemma
364
  def set_input_embeddings(self, value):
365
  self.language_model.set_input_embeddings(value)
366
 
367
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings with Llava->PaliGemma
368
  def get_output_embeddings(self):
369
  return self.language_model.get_output_embeddings()
370
 
371
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings with Llava->PaliGemma
372
  def set_output_embeddings(self, new_embeddings):
373
  self.language_model.set_output_embeddings(new_embeddings)
374
 
375
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder with Llava->PaliGemma
376
  def set_decoder(self, decoder):
377
  self.language_model.set_decoder(decoder)
378
 
379
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder with Llava->PaliGemma
380
  def get_decoder(self):
381
  return self.language_model.get_decoder()
382
 
383
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.tie_weights with Llava->PaliGemma
384
  def tie_weights(self):
385
  return self.language_model.tie_weights()
386
 
@@ -390,11 +251,7 @@ class SpatialVLAForConditionalGeneration(SpatialVLAPreTrainedModel, GenerationMi
390
  pad_to_multiple_of: Optional[int] = None,
391
  mean_resizing: bool = True,
392
  ) -> nn.Embedding:
393
- # TODO: is_deepspeed_zero3_enabled gather
394
- print(f"resize token embeddings from {self.language_model.get_output_embeddings().weight.shape} to (*,{new_num_tokens})")
395
  model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
396
-
397
- # update base model and current model config
398
  vocab_size = model_embeds.weight.shape[0]
399
  self.config.text_config.vocab_size = self.vocab_size = self.config._vocab_size = vocab_size
400
  self.tie_weights()
@@ -431,18 +288,12 @@ class SpatialVLAForConditionalGeneration(SpatialVLAPreTrainedModel, GenerationMi
431
  )
432
 
433
  if attention_mask is not None and attention_mask.dim() == 4:
434
- # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
435
  return attention_mask
436
 
437
- causal_mask = torch.full(
438
- (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
439
- )
440
- # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
441
  if sequence_length != 1:
442
- if is_training:
443
- causal_mask = torch.triu(causal_mask, diagonal=1)
444
- else:
445
- causal_mask[:, :sequence_length] = 0.0
446
 
447
  causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
448
  causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
@@ -451,29 +302,13 @@ class SpatialVLAForConditionalGeneration(SpatialVLAPreTrainedModel, GenerationMi
451
  mask_length = attention_mask.shape[-1]
452
  padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
453
  padding_mask = padding_mask == 0
454
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
455
- padding_mask, min_dtype
456
- )
457
- # we are training thus we need to create a full mask on the image + prefix but causal on suffix
458
  if is_training:
459
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
460
- token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
461
- )
462
  return causal_mask
463
 
464
  def get_image_features(self, pixel_values: torch.FloatTensor, intrinsic: torch.FloatTensor):
465
- """
466
- Obtains image last hidden states from the vision tower and apply multimodal projection.
467
-
468
- Args:
469
- pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
470
- The tensors corresponding to the input images.
471
- Returns:
472
- image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
473
- """
474
- # mintrinsic = intrinsic.reshape(-1, 3, 3)
475
- # siglip vision tower
476
- siglip_pixel_values = F.normalize(pixel_values, mean=SIGLIP_MEAN, std=SIGLIP_STD)
477
  image_outputs = self.vision_tower(siglip_pixel_values)
478
 
479
  # ego3d position encoding
@@ -482,13 +317,12 @@ class SpatialVLAForConditionalGeneration(SpatialVLAPreTrainedModel, GenerationMi
482
  with torch.no_grad():
483
  pvh, pvw = pixel_values.shape[-2:]
484
  depth = self.vision_zoe_model(pixel_values=zoe_pixel_values).predicted_depth
485
- depth = torch.nn.functional.interpolate(
486
  depth.unsqueeze(1),
487
  size=(pvh+2*ph, pvw+2*pw),
488
  mode="bicubic",
489
  align_corners=True,
490
  )[..., ph:-ph, pw:-pw]
491
- # depth = torch.clamp(depth, 0., 4.0) # NOTE: we find that depth w/o clamp performs better
492
  xyz = self.backproject_patch(
493
  intrinsic, depth, patch_size=self.config.vision_config.patch_size, reso=self.config.ego3d_patch_reso
494
  ) # (b, n, 3*4)
@@ -500,8 +334,6 @@ class SpatialVLAForConditionalGeneration(SpatialVLAPreTrainedModel, GenerationMi
500
  image_features = image_features / (self.config.text_config.hidden_size**0.5)
501
  return image_features
502
 
503
- @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
504
- @replace_return_docstrings(output_type=SpatialVLACausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
505
  def forward(
506
  self,
507
  input_ids: torch.LongTensor = None,
@@ -521,93 +353,29 @@ class SpatialVLAForConditionalGeneration(SpatialVLAPreTrainedModel, GenerationMi
521
  return_dict: Optional[bool] = None,
522
  num_logits_to_keep: int = 0,
523
  ) -> Union[Tuple, SpatialVLACausalLMOutputWithPast]:
524
- r"""
525
- Args:
526
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
527
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
528
- config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
529
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
530
-
531
- num_logits_to_keep (`int`, *optional*):
532
- Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
533
- `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
534
- token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
535
-
536
- Returns:
537
-
538
- Example:
539
-
540
- ```python
541
- >>> from PIL import Image
542
- >>> import requests
543
- >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
544
-
545
- >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/PaliGemma-test-224px-hf")
546
- >>> processor = AutoProcessor.from_pretrained("google/PaliGemma-test-224px-hf")
547
-
548
- >>> prompt = "answer en Where is the cow standing?"
549
- >>> url = "https://huggingface.co/gv-hf/PaliGemma-test-224px-hf/resolve/main/cow_beach_1.png"
550
- >>> image = Image.open(requests.get(url, stream=True).raw)
551
-
552
- >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
553
-
554
- >>> # Generate
555
- >>> generate_ids = model.generate(**inputs, max_length=30)
556
- >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
557
- "answer en Where is the cow standing?\nbeach"
558
- ```"""
559
- # print(f"**************************************\n \
560
- # input_ids {input_ids} \n \
561
- # labels {labels} \n \
562
- # token_type_ids {token_type_ids} \n \
563
- # attention_mask {attention_mask} \n \
564
- # actions {actions} \n \
565
- # **************************************"
566
- # )
567
- # print(f"model.language_model.config._attn_implementation {self.language_model.config._attn_implementation} model.config.vision_config._attn_implementation_internal {self.config.vision_config._attn_implementation_internal} \n \
568
- # model.vision_tower.config._attn_implementation {self.vision_tower.config._attn_implementation} model.config.vision_config._attn_implementation_internal {self.config.vision_config._attn_implementation_internal}")
569
- # __import__("ipdb").set_trace()
570
- if (input_ids is None) ^ (inputs_embeds is not None):
571
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
572
-
573
- if pixel_values is not None and inputs_embeds is not None:
574
- raise ValueError(
575
- "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
576
- )
577
 
578
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
579
- output_hidden_states = (
580
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
581
- )
582
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
583
 
584
  is_training = token_type_ids is not None and labels is not None
585
 
586
- if inputs_embeds is None:
587
- inputs_embeds = self.get_input_embeddings()(input_ids).clone() ## avoid checkpint grad True
588
-
589
- # NOTE: replace the fixed embeddings with trainable spatial embeddings
590
- # BUG: LoRA causes inputs_embeds requires_grad = True
591
- # peft: https://github.com/huggingface/peft/blob/ec92cdcc41fe1b141bfe1e0da69b38a7e601cc80/src/peft/peft_model.py#L687
592
- # hf: https://github.com/huggingface/transformers/blob/05260a1fc1c8571a2b421ce72b680d5f1bc3e5a4/src/transformers/modeling_utils.py#L2545
593
- # lora w/ prompt: https://discuss.huggingface.co/t/combine-between-lora-and-prompt-tunning/65151
594
  if self.config.use_spatial_token:
595
  spatial_selected = (input_ids >= self.config.action_token_begin_idx) & (input_ids < self.config.action_token_begin_idx + self.config.spatial_token_num)
596
  inputs_embeds[spatial_selected] = inputs_embeds[spatial_selected] * 0.0 + self.spatial_embed_tokens(input_ids[spatial_selected] - self.config.action_token_begin_idx)
597
 
598
  if cache_position is None:
599
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
600
- cache_position = torch.arange(
601
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
602
- )
603
 
604
  if position_ids is None:
605
  position_ids = cache_position.unsqueeze(0) + 1 # Paligemma positions are 1-indexed
606
 
607
- # Merge text and images
608
  if pixel_values is not None:
609
  image_features = self.get_image_features(pixel_values, intrinsic)
610
-
611
  special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
612
  special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
613
  if inputs_embeds[special_image_mask].numel() != image_features.numel():
@@ -647,20 +415,16 @@ class SpatialVLAForConditionalGeneration(SpatialVLAPreTrainedModel, GenerationMi
647
  logits = outputs.logits
648
  loss = None
649
  if labels is not None:
650
- # Upcast to float if we need to compute the loss to avoid potential precision issues
651
  logits = logits.float()
652
  shift_logits = logits[..., :-1, :]
653
  shift_labels = labels[..., 1:]
654
  if attention_mask is not None:
655
- # we use the input attention mask to shift the logits and labels, because it is 2D.
656
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
657
  shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
658
  shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
659
  shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
660
  else:
661
  shift_logits = shift_logits.contiguous()
662
  shift_labels = shift_labels.contiguous()
663
- # Flatten the tokens
664
  loss_fct = nn.CrossEntropyLoss()
665
 
666
  flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
@@ -679,6 +443,7 @@ class SpatialVLAForConditionalGeneration(SpatialVLAPreTrainedModel, GenerationMi
679
  image_hidden_states=image_features if pixel_values is not None else None,
680
  )
681
 
 
682
  def prepare_inputs_for_generation(
683
  self,
684
  input_ids,
@@ -695,7 +460,6 @@ class SpatialVLAForConditionalGeneration(SpatialVLAPreTrainedModel, GenerationMi
695
  labels=None,
696
  **kwargs,
697
  ):
698
- # Overwritten -- custom `position_ids` and `pixel_values` handling
699
  model_inputs = self.language_model.prepare_inputs_for_generation(
700
  input_ids,
701
  past_key_values=past_key_values,
@@ -708,19 +472,13 @@ class SpatialVLAForConditionalGeneration(SpatialVLAPreTrainedModel, GenerationMi
708
  token_type_ids=token_type_ids,
709
  **kwargs,
710
  )
711
-
712
- # position_ids in Paligemma are 1-indexed
713
  if model_inputs.get("position_ids") is not None:
714
  model_inputs["position_ids"] += 1
715
- # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
716
- # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
717
  if cache_position[0] == 0:
718
  model_inputs["pixel_values"] = pixel_values
719
  is_training = token_type_ids is not None and labels is not None
720
  if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
721
- causal_mask = self._update_causal_mask(
722
- attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training
723
- )
724
  model_inputs["attention_mask"] = causal_mask
725
  model_inputs["intrinsic"] = intrinsic
726
  return model_inputs
@@ -765,9 +523,6 @@ class SpatialVLAForConditionalGeneration(SpatialVLAPreTrainedModel, GenerationMi
765
  weights_only=weights_only,
766
  **kwargs,
767
  )
768
- # NOTE: tie the weights of the embed_tokens with lm head (donot work if un_tie_weight)
769
- # model.language_model.tie_weights()
770
- # NOTE: tie the data of spatial_embed_tokens with embed_tokens (BUG: forweight sync issue in training)
771
  if model.config.use_spatial_token:
772
  model.language_model.model.embed_tokens.weight.data[-model.config.spatial_token_num:] = model.spatial_embed_tokens.weight.data
773
  return model
 
 
 
 
 
 
 
 
1
  # coding=utf-8
2
+ # Copyright 2024 the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
  """PyTorch PaliGemmamodel."""
16
 
17
  from dataclasses import dataclass
18
  from typing import List, Optional, Tuple, Union
19
 
20
+ import os
21
  import torch
22
  import torch.utils.checkpoint
23
  from torch import nn
24
  from torch.linalg import inv
25
+ import torchvision.transforms.functional as TF
26
+ import torch.nn.functional as F
 
27
  from transformers.cache_utils import Cache, HybridCache, StaticCache
28
  from transformers.generation import GenerationMixin
29
  from transformers.modeling_utils import PreTrainedModel, PretrainedConfig
30
  from transformers.utils import (
31
  ModelOutput,
 
 
 
32
  logging,
 
33
  )
34
  from .configuration_spatialvla import SpatialVLAConfig
 
35
  from .modeling_gemma2 import Gemma2ForCausalLM
36
+ from transformers import AutoModel, ZoeDepthForDepthEstimation
37
 
38
+ SIGLIP_MEAN, SIGLIP_STD = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
39
+ ZOE_MEAN, ZOE_STD = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
 
 
 
40
 
41
  logger = logging.get_logger(__name__)
42
 
43
+ class Ego3DPositionEmbeddingMLP(nn.Module):
44
+ """Absolute pos embedding, learned.
45
+ https://github.com/kwea123/nerf_pl/blob/52aeb387da64a9ad9a0f914ea9b049ffc598b20c/models/nerf.py#L4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
+ def __init__(self, in_channels=3, num_pos_feats=768, n_freqs=8, logscale=True):
49
+ super(Ego3DPositionEmbeddingMLP, self).__init__()
50
+ self.n_freqs = n_freqs
51
+ self.freq_out_channels = in_channels * (2 * n_freqs + 1)
52
+ if logscale:
53
+ freq_bands = 2 ** torch.linspace(0, n_freqs - 1, n_freqs)
54
+ else:
55
+ freq_bands = torch.linspace(1, 2 ** (n_freqs - 1), n_freqs)
56
+
57
+ center = torch.tensor([0., 0., 2.]).repeat(in_channels // 3)
58
+ self.register_buffer("freq_bands", freq_bands, persistent=False)
59
+ self.register_buffer("center", center, persistent=False)
60
+
61
+ self.position_embedding_head = nn.Sequential(
62
+ nn.Linear(self.freq_out_channels, num_pos_feats),
63
+ nn.LayerNorm(num_pos_feats),
64
+ nn.ReLU(),
65
+ nn.Linear(num_pos_feats, num_pos_feats),
66
+ )
67
+ self._reset_parameters()
68
 
69
+ def _reset_parameters(self):
70
+ """init with small weights to maintain stable training."""
71
+ for p in self.parameters():
72
+ if p.dim() > 1:
73
+ nn.init.xavier_uniform_(p, gain=0.01)
74
+
75
+ @torch.no_grad()
76
+ def frequency_encoding(self, xyz):
77
+ """
78
+ Embeds x to (x, sin(2^k x), cos(2^k x), ...)
79
+ Different from the paper, "x" is also in the output
80
+ See https://github.com/bmild/nerf/issues/12
81
+ x \in [-2, 2]
82
+ y \in [-2, 2]
83
+ z \in [0., 4]
84
+ Inputs:
85
+ x: (b n m)
86
+ Outputs:
87
+ out: (b n o)
88
+ """
89
+ xyz_n = ((xyz - self.center) / 2.0).to(self.freq_bands.dtype)
90
+ xyz_feq = xyz_n.unsqueeze(-1) * self.freq_bands # (b n m 1)
91
+ sin_xyz, cos_xyz = torch.sin(xyz_feq), torch.cos(xyz_feq) # (b n m nf)
92
+ encoding = torch.cat([xyz_n.unsqueeze(-1), sin_xyz, cos_xyz], -1).reshape(*xyz.shape[:2], -1)
93
+ return encoding
94
+
95
+ def forward(self, xyz):
96
+ """Forward pass, xyz is (B, N, 3or6), output (B, N, F)."""
97
+ freq_encoding = self.frequency_encoding(xyz)
98
+ position_embedding = self.position_embedding_head(freq_encoding)
99
+ return position_embedding
100
+
101
+ def process_zoe(pixel_values, pad_mode="reflect", output_size=(384, 512)):
102
+ """https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/zoedepth/image_processing_zoedepth.py"""
103
+ # h, w = images.shape[-2:]
104
+ # pad
105
+ ph, pw = 31, 31 # int((h / 2)**0.5 * 3), int((w / 2)**0.5 * 3) # 32, 31
106
+ images = F.pad(pixel_values, (pw, pw, ph, ph), mode=pad_mode)
107
+ # resize
108
+ size = (384, 384) # get_resize_output_image_size
109
+ images = F.interpolate(images, size=size, mode="bicubic", align_corners=True)
110
+ # zoe: padding -> resize -> normalize. we follow `normalize -> padding -> resize` from siglip
111
+ images = TF.normalize(images, mean=ZOE_MEAN, std=ZOE_STD)
112
+ return images, ph, pw
113
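# Editor's aside (illustrative sketch, not part of this commit): a quick dimension check for the
# frequency encoding above, assuming the defaults ego3d_patch_reso=4 and n_freqs=8 from
# configuration_spatialvla.py; batch/patch sizes and variable names here are only for illustration.
import torch
n_freqs, in_channels = 8, 4**2 * 3                           # 48 channels per patch (reso**2 * 3)
freq_bands = 2 ** torch.linspace(0, n_freqs - 1, n_freqs)
center = torch.tensor([0., 0., 2.]).repeat(in_channels // 3)
xyz = torch.zeros(2, 256, in_channels)                       # (batch, patches, channels), dummy points
xyz_n = (xyz - center) / 2.0
xyz_freq = xyz_n.unsqueeze(-1) * freq_bands                  # (b, n, c, n_freqs)
enc = torch.cat([xyz_n.unsqueeze(-1), xyz_freq.sin(), xyz_freq.cos()], -1).flatten(2)
assert enc.shape == (2, 256, in_channels * (2 * n_freqs + 1))  # 48 * 17 = 816 features into the MLP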
 
114
  @dataclass
115
  class SpatialVLACausalLMOutputWithPast(ModelOutput):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  loss: Optional[torch.FloatTensor] = None
117
  logits: torch.FloatTensor = None
118
  past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None
 
120
  attentions: Optional[Tuple[torch.FloatTensor]] = None
121
  image_hidden_states: Optional[torch.FloatTensor] = None
122
 
 
123
  class SpatialVLAMultiModalProjector(nn.Module):
124
  def __init__(self, config: SpatialVLAConfig):
125
  super().__init__()
 
127
 
128
  def forward(self, image_features):
129
  hidden_states = self.linear(image_features)
 
130
  return hidden_states
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  class SpatialVLAPreTrainedModel(PreTrainedModel):
133
  config_class = SpatialVLAConfig
134
  base_model_prefix = "model"
 
143
  _supports_sdpa = True
144
 
145
  def _init_weights(self, module):
 
 
146
  std = (
147
  self.config.initializer_range
148
  if hasattr(self.config, "initializer_range")
 
161
  if module.padding_idx is not None:
162
  module.weight.data[module.padding_idx].zero_()
163
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  class SpatialVLAForConditionalGeneration(SpatialVLAPreTrainedModel, GenerationMixin):
165
  def __init__(self, config: SpatialVLAConfig, vision_model=None, vision_zoe_model=None, projector_model=None, language_model=None):
166
  super().__init__(config)
167
+
168
  self.vision_tower = vision_model or AutoModel.from_config(config=config.vision_config)
 
169
  self.multi_modal_projector = projector_model or SpatialVLAMultiModalProjector(config)
 
170
  self.vocab_size = config.text_config.vocab_size
171
  if language_model is None:
172
+ language_model = Gemma2ForCausalLM(config=config.text_config)
 
173
  if language_model._tied_weights_keys is not None:
174
  self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
175
  self.language_model = language_model
176
 
177
  if config.use_vision_zoe:
 
178
  self.vision_zoe_model = vision_zoe_model or ZoeDepthForDepthEstimation(config.vision_zoe_config)
179
  self.position_embedding_3d = Ego3DPositionEmbeddingMLP(
180
  config.ego3d_patch_reso**2 * 3, num_pos_feats=config.vision_config.hidden_size, n_freqs=config.n_freqs
 
186
  uv_h = torch.stack([x, y, torch.ones_like(x)], dim=0).reshape(3, -1) # (3 hw)
187
  self.register_buffer("uv_h", uv_h, persistent=False)
188
 
189
+ # shared spatial embeddings for <ACTION> <IMG>
190
  if config.use_spatial_token:
191
  self.spatial_embed_tokens = nn.Embedding(self.config.spatial_token_num, config.text_config.hidden_size)
192
  else:
193
  self.spatial_embed_tokens = None
 
194
  self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
 
 
195
 
196
 
197
  def backproject_patch(self, K: torch.Tensor, depth: torch.Tensor, patch_size=14, reso=2) -> torch.Tensor:
 
200
  Args:
201
  K: camera intrinsic matrix (b 3 3)
202
  depth: depth map (b 1 h w)
203
+ patch_size: patch size for siglip
204
+ reso: reso^2 -> sample points in each patch
205
+ patch sz = 14 ......
206
+ ┌────────┬────────┐
207
+ │ ─ ─ │ ─ ─ │
208
+ │ points │ ├─ ─ ─
209
+ │ ─ ─ │ ─ ─ │
210
+ ├────────┼────────┤
211
+ │ ─ ─ │ ─ ─ │
212
+ │ │ │
213
+ │ ─ ─ │ ─ ─ │
214
+ └────────┴────────┘
215
+ reso=2───►points=4
216
+
217
+
218
  """
 
219
  b, c, h, w = depth.shape
220
  hp, wp = h // patch_size, w // patch_size
221
  sub_hp = sub_wp = reso
222
+ patch_depth = F.interpolate(depth, size=(hp * reso, wp * reso), mode="area").reshape(b, c, -1)
 
 
223
  p_cam = (inv(K.float()) @ self.uv_h.float()) * patch_depth # (b 3 3) @ (3 hw) -> (b 3 hw) * (b 1 hw) -> (b 3 hw)
224
  patch_p_cam = p_cam.reshape(b, 3, hp, sub_hp, wp, sub_wp).permute(0, 2, 4, 3, 5, 1).reshape(b, hp * wp, -1)
225
  return patch_p_cam
226
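# Editor's aside (illustrative sketch, not part of this commit): backproject_patch applies the
# standard pinhole model, x_cam = depth * K^-1 @ [u, v, 1]^T, one ray per sampled sub-patch point.
# A single-pixel version of the same math, with made-up intrinsics:
import torch
from torch.linalg import inv
K = torch.tensor([[600.0, 0.0, 112.0],
                  [0.0, 600.0, 112.0],
                  [0.0, 0.0, 1.0]])            # assumed fx, fy, cx, cy for a 224x224 image
uv_h = torch.tensor([160.0, 100.0, 1.0])       # homogeneous pixel coordinate
depth = 1.5                                    # metric depth at that pixel
xyz = depth * (inv(K) @ uv_h)                  # tensor([ 0.1200, -0.0300,  1.5000]) in the camera frame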
 
 
227
  def get_input_embeddings(self):
228
  return self.language_model.get_input_embeddings()
229
 
 
230
  def set_input_embeddings(self, value):
231
  self.language_model.set_input_embeddings(value)
232
 
 
233
  def get_output_embeddings(self):
234
  return self.language_model.get_output_embeddings()
235
 
 
236
  def set_output_embeddings(self, new_embeddings):
237
  self.language_model.set_output_embeddings(new_embeddings)
238
 
 
239
  def set_decoder(self, decoder):
240
  self.language_model.set_decoder(decoder)
241
 
 
242
  def get_decoder(self):
243
  return self.language_model.get_decoder()
244
 
 
245
  def tie_weights(self):
246
  return self.language_model.tie_weights()
247
 
 
251
  pad_to_multiple_of: Optional[int] = None,
252
  mean_resizing: bool = True,
253
  ) -> nn.Embedding:
 
 
254
  model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
 
 
255
  vocab_size = model_embeds.weight.shape[0]
256
  self.config.text_config.vocab_size = self.vocab_size = self.config._vocab_size = vocab_size
257
  self.tie_weights()
 
288
  )
289
 
290
  if attention_mask is not None and attention_mask.dim() == 4:
 
291
  return attention_mask
292
 
293
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device)
 
 
 
294
  if sequence_length != 1:
295
+ if is_training: causal_mask = torch.triu(causal_mask, diagonal=1)
296
+ else: causal_mask[:, :sequence_length] = 0.0
 
 
297
 
298
  causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
299
  causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
 
302
  mask_length = attention_mask.shape[-1]
303
  padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
304
  padding_mask = padding_mask == 0
305
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype)
 
 
 
306
  if is_training:
307
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0)
 
 
308
  return causal_mask
309
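A self-contained sketch of the same prefix-LM masking idea, illustrative only and not the exact method above: prefix tokens (token_type_ids == 0, i.e. image and prompt) are mutually visible, suffix/action tokens stay causal, and padded keys are blocked.

import torch

def prefix_lm_mask(token_type_ids: torch.Tensor, attention_mask: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    """token_type_ids: (b, s), 0 = prefix / 1 = suffix; attention_mask: (b, s), 1 = real token."""
    b, s = token_type_ids.shape
    min_dtype = torch.finfo(dtype).min
    mask = torch.full((b, 1, s, s), min_dtype, dtype=dtype)
    causal = torch.triu(torch.ones(s, s), diagonal=1).bool()     # True strictly above the diagonal
    mask[:, 0].masked_fill_(~causal, 0.0)                        # causal attention allowed everywhere
    prefix_keys = (token_type_ids == 0)[:, None, None, :]
    mask.masked_fill_(prefix_keys, 0.0)                          # prefix keys visible to every query
    pad_keys = (attention_mask == 0)[:, None, None, :]
    return mask.masked_fill(pad_keys, min_dtype)                 # padded keys are never attended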
 
310
  def get_image_features(self, pixel_values: torch.FloatTensor, intrinsic: torch.FloatTensor):
311
+ siglip_pixel_values = TF.normalize(pixel_values, mean=SIGLIP_MEAN, std=SIGLIP_STD)
 
312
  image_outputs = self.vision_tower(siglip_pixel_values)
313
 
314
  # ego3d position encoding
 
317
  with torch.no_grad():
318
  pvh, pvw = pixel_values.shape[-2:]
319
  depth = self.vision_zoe_model(pixel_values=zoe_pixel_values).predicted_depth
320
+ depth = F.interpolate(
321
  depth.unsqueeze(1),
322
  size=(pvh+2*ph, pvw+2*pw),
323
  mode="bicubic",
324
  align_corners=True,
325
  )[..., ph:-ph, pw:-pw]
 
326
  xyz = self.backproject_patch(
327
  intrinsic, depth, patch_size=self.config.vision_config.patch_size, reso=self.config.ego3d_patch_reso
328
  ) # (b, n, 3*4)
 
334
  image_features = image_features / (self.config.text_config.hidden_size**0.5)
335
  return image_features
336
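As a hedged aside, the SigLIP normalization used above maps [0, 1] pixels to roughly [-1, 1]; the mean/std values below are the commonly used SigLIP constants and are an assumption here, not read from this checkpoint's config:

import torch
import torchvision.transforms.functional as TF

SIGLIP_MEAN, SIGLIP_STD = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)   # assumed constants; check the image processor config
pixel_values = torch.rand(2, 3, 224, 224)                    # images already rescaled to [0, 1]
siglip_pixel_values = TF.normalize(pixel_values, mean=SIGLIP_MEAN, std=SIGLIP_STD)
print(siglip_pixel_values.min().item(), siglip_pixel_values.max().item())  # roughly -1 and 1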
 
 
 
337
  def forward(
338
  self,
339
  input_ids: torch.LongTensor = None,
 
353
  return_dict: Optional[bool] = None,
354
  num_logits_to_keep: int = 0,
355
  ) -> Union[Tuple, SpatialVLACausalLMOutputWithPast]:
356
 
357
+ output_attentions = output_attentions or self.config.output_attentions
358
+ output_hidden_states = output_hidden_states or self.config.output_hidden_states
359
+ return_dict = return_dict or self.config.use_return_dict
 
 
360
 
361
  is_training = token_type_ids is not None and labels is not None
362
 
363
+ if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids).clone() # clone to avoid requires_grad issues with gradient checkpointing
364
+
 
 
 
 
 
 
365
  if self.config.use_spatial_token:
366
  spatial_selected = (input_ids >= self.config.action_token_begin_idx) & (input_ids < self.config.action_token_begin_idx + self.config.spatial_token_num)
367
  inputs_embeds[spatial_selected] = inputs_embeds[spatial_selected] * 0.0 + self.spatial_embed_tokens(input_ids[spatial_selected] - self.config.action_token_begin_idx)
368
 
369
  if cache_position is None:
370
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
371
+ cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device)
 
 
372
 
373
  if position_ids is None:
374
  position_ids = cache_position.unsqueeze(0) + 1 # Paligemma positions are 1-indexed
375
 
376
+ # merge image features into the text token embeddings
377
  if pixel_values is not None:
378
  image_features = self.get_image_features(pixel_values, intrinsic)
 
379
  special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
380
  special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
381
  if inputs_embeds[special_image_mask].numel() != image_features.numel():
 
415
  logits = outputs.logits
416
  loss = None
417
  if labels is not None:
 
418
  logits = logits.float()
419
  shift_logits = logits[..., :-1, :]
420
  shift_labels = labels[..., 1:]
421
  if attention_mask is not None:
 
 
422
  shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
423
  shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
424
  shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
425
  else:
426
  shift_logits = shift_logits.contiguous()
427
  shift_labels = shift_labels.contiguous()
 
428
  loss_fct = nn.CrossEntropyLoss()
429
 
430
  flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
 
443
  image_hidden_states=image_features if pixel_values is not None else None,
444
  )
445
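The loss computed above is a shifted, mask-filtered cross-entropy; a minimal self-contained sketch of the same computation (it relies on the standard -100 ignore-label convention, an assumption here):

import torch
import torch.nn as nn

def shifted_masked_ce(logits: torch.Tensor, labels: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """logits: (b, s, V); labels: (b, s) with -100 on ignored positions; attention_mask: (b, s)."""
    shift_logits = logits[..., :-1, :].float()
    shift_labels = labels[..., 1:]
    keep = attention_mask[:, -shift_logits.shape[1]:] != 0     # align the mask with the shifted sequence
    return nn.CrossEntropyLoss()(shift_logits[keep], shift_labels[keep])

# example: batch of 2, sequence of 5, toy vocabulary of 7
loss = shifted_masked_ce(torch.randn(2, 5, 7), torch.randint(0, 7, (2, 5)), torch.ones(2, 5, dtype=torch.long))
print(loss)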
 
446
+ # autoregressive (AR) inference
447
  def prepare_inputs_for_generation(
448
  self,
449
  input_ids,
 
460
  labels=None,
461
  **kwargs,
462
  ):
 
463
  model_inputs = self.language_model.prepare_inputs_for_generation(
464
  input_ids,
465
  past_key_values=past_key_values,
 
472
  token_type_ids=token_type_ids,
473
  **kwargs,
474
  )
 
 
475
  if model_inputs.get("position_ids") is not None:
476
  model_inputs["position_ids"] += 1
 
 
477
  if cache_position[0] == 0:
478
  model_inputs["pixel_values"] = pixel_values
479
  is_training = token_type_ids is not None and labels is not None
480
  if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
481
+ causal_mask = self._update_causal_mask(attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training)
 
 
482
  model_inputs["attention_mask"] = causal_mask
483
  model_inputs["intrinsic"] = intrinsic
484
  return model_inputs
 
523
  weights_only=weights_only,
524
  **kwargs,
525
  )
 
 
 
526
  if model.config.use_spatial_token:
527
  model.language_model.model.embed_tokens.weight.data[-model.config.spatial_token_num:] = model.spatial_embed_tokens.weight.data
528
  return model
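Since from_pretrained copies the spatial embedding into the tail of the LM embedding table, a quick sanity check after loading could look like the sketch below (the checkpoint path is a placeholder):

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("path/to/spatialvla-checkpoint", trust_remote_code=True)  # placeholder path
if model.config.use_spatial_token:
    n = model.config.spatial_token_num
    tail = model.language_model.model.embed_tokens.weight.data[-n:]
    assert torch.equal(tail, model.spatial_embed_tokens.weight.data), "spatial rows were not copied"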
processing_spatialvla.py CHANGED
@@ -1,142 +1,38 @@
1
- # MIT License
2
- # Copyright (c) 2025 IPEC at Shanghai AI Laboratory
3
- # Permission is hereby granted, free of charge, to use, copy, modify, merge, publish,
4
- # distribute, sublicense, and/or sell copies of the Software, subject to the following conditions:
5
- # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND.
7
- # Based on code licensed under the Apache License, Version 2.0 by Google Inc. and HuggingFace Inc. team (Copyright 2024).
8
  # coding=utf-8
9
-
10
- """
11
- Processor class for PaliGemma.
12
- """
13
-
 
 
 
 
 
 
 
 
14
  import logging
15
  from typing import List, Optional, Union, Dict
16
- import torch
17
  import numpy as np
18
-
19
  from transformers.feature_extraction_utils import BatchFeature
20
  from transformers.image_utils import ImageInput, is_valid_image
21
- from transformers.processing_utils import (
22
- ImagesKwargs,
23
- ProcessingKwargs,
24
- ProcessorMixin,
25
- TextKwargs,
26
- Unpack,
27
- _validate_images_text_input_order,
28
- )
29
- from transformers.tokenization_utils_base import (
30
- AddedToken,
31
- PreTokenizedInput,
32
- TextInput,
33
- )
34
  from transformers.utils import logging
35
- from .action_tokenizer import SphericalCoordinateActionTokenizer
36
-
 
 
 
 
 
 
 
37
  logger = logging.get_logger(__name__)
38
 
39
- IMAGE_TOKEN = "<image>"
40
- EXTRA_TOKENS = [f"<loc{i:0>4}>" for i in range(1024)] + [f"<seg{i:0>3}>" for i in range(128)]
41
-
42
-
43
- class PaliGemmaTextKwargs(TextKwargs):
44
- suffix: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]]
45
-
46
-
47
- class PaliGemmaImagesKwargs(ImagesKwargs):
48
- do_convert_rgb: Optional[bool]
49
-
50
-
51
- class PaliGemmaProcessorKwargs(ProcessingKwargs, total=False):
52
- text_kwargs: PaliGemmaTextKwargs
53
- images_kwargs: PaliGemmaImagesKwargs
54
- _defaults = {
55
- "text_kwargs": {
56
- "padding": False,
57
- },
58
- "images_kwargs": {
59
- "data_format": "channels_first",
60
- },
61
- }
62
-
63
-
64
- # Copied from transformers.models.idefics2.processing_idefics2.is_url
65
- def is_url(val) -> bool:
66
- return isinstance(val, str) and val.startswith("http")
67
-
68
-
69
- # Copied from transformers.models.idefics2.processing_idefics2.is_image_or_image_url
70
- def is_image_or_image_url(elem):
71
- return is_url(elem) or is_valid_image(elem)
72
-
73
-
74
- def _is_str_or_image(elem):
75
- return isinstance(elem, (str)) or is_image_or_image_url(elem)
76
-
77
-
78
- def build_string_from_input(prompt, bos_token, image_seq_len, image_token, num_images):
79
- """
80
- Builds a string from the input prompt and image tokens.
81
- For example, for the call:
82
- build_string_from_input(
83
- prompt="Prefix str"
84
- bos_token="<s>",
85
- image_seq_len=3,
86
- image_token="<im>",
87
- )
88
- The output will be:
89
- "<im><im><im><s>Initial str"
90
- Args:
91
- prompt (`List[Union[str, ImageInput]]`): The input prompt.
92
- bos_token (`str`): The beginning of sentence token.
93
- image_seq_len (`int`): The length of the image sequence.
94
- image_token (`str`): The image token.
95
- num_images (`int`): Number of images in the prompt.
96
- """
97
- return f"{image_token * image_seq_len * num_images}{bos_token}{prompt}\n"
98
-
99
-
100
- # Copied from transformers.models.llava_next.image_processing_llava_next.make_batched_images
101
- def make_batched_images(images) -> List[List[ImageInput]]:
102
- """
103
- Accepts images in list or nested list format, and makes a list of images for preprocessing.
104
-
105
- Args:
106
- images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
107
- The input image.
108
-
109
- Returns:
110
- list: A list of images.
111
- """
112
- if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]):
113
- return [img for img_list in images for img in img_list]
114
-
115
- elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
116
- return images
117
-
118
- elif is_valid_image(images):
119
- return [images]
120
-
121
- raise ValueError(f"Could not make batched video from {images}")
122
-
123
-
124
  class SpatialVLAProcessor(ProcessorMixin):
125
- r"""
126
- Constructs a PaliGemma processor which wraps a PaliGemma image processor and a PaliGemma tokenizer into a single processor.
127
-
128
- [`PaliGemmaProcessor`] offers all the functionalities of [`SiglipImageProcessor`] and [`LlamaTokenizerFast`]. See the
129
- [`~PaliGemmaProcessor.__call__`] and [`~PaliGemmaProcessor.decode`] for more information.
130
-
131
- Args:
132
- image_processor ([`SiglipImageProcessor`], *optional*):
133
- The image processor is a required input.
134
- tokenizer ([`LlamaTokenizerFast`], *optional*):
135
- The tokenizer is a required input.
136
- chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
137
- in a chat into a tokenizable string.
138
- """
139
-
140
  attributes = ["image_processor", "tokenizer"]
141
  valid_kwargs = ["chat_template"]
142
  image_processor_class = "SiglipImageProcessor"
@@ -192,17 +88,13 @@ class SpatialVLAProcessor(ProcessorMixin):
192
  self.dataset_intrinsics = {}
193
  height, width = image_processor.size["height"], image_processor.size["width"]
194
 
 
195
  for k, v in intrinsic_config.items():
196
  K = torch.tensor(v["intrinsic"]).float()
197
- h, w = v["height"], v["width"]
198
- K[0, 0] *= width / w
199
- K[1, 1] *= height / h
200
- K[0, 2] *= width / w
201
- K[1, 2] *= height / h
202
  self.dataset_intrinsics[k] = K
203
- print(f"scale intrinsic of {k} from {v['intrinsic']} to {K} ...")
204
 
205
- self.action_tokenizer = SphericalCoordinateActionTokenizer(
206
  tokenizer=tokenizer, num_bins=action_config["num_bins"],
207
  bin_policy=bin_policy, use_spherical=action_config["use_spherical"],
208
  min_sigma=min_sigma,
@@ -212,70 +104,10 @@ class SpatialVLAProcessor(ProcessorMixin):
212
  self,
213
  images: ImageInput = None,
214
  text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
215
- audio=None,
216
- videos=None,
217
  unnorm_key: Optional[str] = None,
218
  suffix_actions: Optional[np.array] = None, # (t e)
219
  **kwargs: Unpack[PaliGemmaProcessorKwargs],
220
  ) -> BatchFeature:
221
- """
222
- Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
223
- and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
224
- the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
225
- SiglipImageProcessor's [`~SiglipImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring
226
- of the above two methods for more information.
227
-
228
- The usage for PaliGemma fine-tuning preparation is slightly different than usual. suffix passed are suffixes to
229
- the prompt in `text`, and will be placed after the prompt. This is because attention is handled differently for
230
- the prefix and the suffix. For instance,
231
- ```python
232
- image = PIL_cow_image
233
- prompt = "answer en Where is the cow standing?"
234
- suffix = "on the beach"
235
- inputs = processor(text=prompt, images=image, suffix=suffix)
236
- ```
237
- Here `inputs` will contain the `input_ids` and `token_type_ids` that follow
238
- ```python
239
- inputs["input_ids"][:, 256:]
240
- # tensor([[ 2, 6006, 603, 573, 13910, 9980, 235336, 108, 477, 573, 8318]])
241
- inputs["token_type_ids"][:, 256:]
242
- tensor([[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1]])
243
- ```
244
- Meaning the last three tokens are of "label" ("suffix") type while the other ones are of "prefix" type.
245
-
246
-
247
- Args:
248
- images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
249
- The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
250
- tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
251
- number of channels, H and W are image height and width.
252
- text (`str`, `List[str]`, `List[List[str]]`):
253
- The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
254
- (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
255
- `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
256
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
257
- If set, will return tensors of a particular framework. Acceptable values are:
258
-
259
- - `'tf'`: Return TensorFlow `tf.constant` objects.
260
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
261
- - `'np'`: Return NumPy `np.ndarray` objects.
262
- - `'jax'`: Return JAX `jnp.ndarray` objects.
263
- suffix (`str`, `List[str]`, `List[List[str]]`):
264
- The suffixes or batch of suffixes to be encoded. Only necessary for finetuning. See https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md
265
- for more information. If your prompt is "<image> What is on the image", the suffix corresponds to the expected prediction "a cow sitting on a bench".
266
-
267
- Returns:
268
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
269
-
270
- - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. If `suffix`
271
- is provided, the `input_ids` will also contain the suffix input ids.
272
- - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
273
- `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
274
- `None`).
275
- - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
276
- - **labels** -- Labels compatible with training if `suffix` is not None
277
- """
278
- # check if images and text inputs are reversed for BC
279
  images, text = _validate_images_text_input_order(images, text)
280
 
281
  output_kwargs = self._merge_kwargs(
@@ -294,9 +126,7 @@ class SpatialVLAProcessor(ProcessorMixin):
294
  if images is None:
295
  raise ValueError("`images` are expected as arguments to a `PaliGemmaProcessor` instance.")
296
  if text is None:
297
- logger.warning_once(
298
- "You are using PaliGemma without a text prefix. It will perform as a picture-captioning model."
299
- )
300
  text = ""
301
 
302
  if _is_str_or_image(text):
@@ -306,31 +136,19 @@ class SpatialVLAProcessor(ProcessorMixin):
306
 
307
  if text is not None and images is not None:
308
  if not any(IMAGE_TOKEN in sample for sample in text):
309
- # logger.warning(
310
- # "You are passing both `text` and `images` to `PaliGemmaProcessor`. The processor expects special "
311
- # "image tokens in the text, as many tokens as there are images per each text. It is recommended to "
312
- # "add `<image>` tokens in the very beginning of your text. For this call, we will infer how many images "
313
- # "each text has and add special tokens."
314
- # )
315
  if isinstance(text, List) and isinstance(images, List):
316
  if len(images) != len(text):
317
  raise ValueError(
318
  f"Received {len(images)} images for {len(text)} prompts. Each prompt should be associated with an image or list of images."
319
  )
320
-
321
- # make a nested list of lists to be able to iterate over the images and text below
322
  if is_valid_image(images):
323
  images = [[images]]
324
  elif isinstance(images, list) and is_valid_image(images[0]):
325
  images = [[image] for image in images]
326
  elif not (isinstance(images, list) and isinstance(images[0], list) and is_valid_image(images[0][0])):
327
  raise ValueError("images must be an image, list of images or list of list of images")
328
-
329
- if suffix is not None and _is_str_or_image(suffix):
330
- suffix = [suffix]
331
- if suffix is not None:
332
- suffix = [sfx + self.tokenizer.eos_token for sfx in suffix]
333
-
334
  input_strings = [
335
  build_string_from_input(
336
  prompt=prompt,
@@ -355,7 +173,6 @@ class SpatialVLAProcessor(ProcessorMixin):
355
  input_strings = [f"{sample}\n" for sample in expanded_samples]
356
  pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]
357
 
358
- # max_length has to account for the image tokens
359
  if output_kwargs["text_kwargs"].get("max_length", None) is not None:
360
  output_kwargs["text_kwargs"]["max_length"] += self.image_seq_length
361
 
@@ -391,7 +208,6 @@ class SpatialVLAProcessor(ProcessorMixin):
391
  return self.tokenizer.decode(*args, **kwargs)
392
 
393
  @property
394
- # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->PaliGemma
395
  def model_input_names(self):
396
  tokenizer_input_names = self.tokenizer.model_input_names
397
  image_processor_input_names = self.image_processor.model_input_names
@@ -407,7 +223,7 @@ class SpatialVLAProcessor(ProcessorMixin):
407
  assert self.tokenizer.eos_token != predicted_action_token_ids[-1], "[error] actions contain EOS token, please check your truncation settings!"
408
 
409
  if predicted_action_token_ids.shape[0] < action_token_num * self.action_chunk_size: # pad with zeros
410
- print(f"[warning] Padding zero action!")
411
  predicted_action_token_ids = np.concatenate(
412
  [
413
  predicted_action_token_ids,
@@ -417,9 +233,8 @@ class SpatialVLAProcessor(ProcessorMixin):
417
  predicted_action_token_ids = predicted_action_token_ids.reshape(-1, action_token_num)
418
  normalized_action_chunks = self.action_tokenizer.decode_token_ids_to_actions(predicted_action_token_ids)
419
 
420
- # Unnormalize actions
421
  if unnorm_key is None:
422
- print(f"🔥 unnorm_key {unnorm_key} is not in statistics, use next one")
423
  unnorm_key = next(iter(self.statistics.keys()))
424
  action_norm_stats = self.statistics[unnorm_key]["action"]
425
 
1
  # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
  import logging
16
  from typing import List, Optional, Union, Dict
 
17
  import numpy as np
18
+ import torch
19
  from transformers.feature_extraction_utils import BatchFeature
20
  from transformers.image_utils import ImageInput, is_valid_image
21
+ from transformers.processing_utils import Unpack, _validate_images_text_input_order, ProcessorMixin
22
+ from transformers.tokenization_utils_base import AddedToken, PreTokenizedInput, TextInput
 
23
  from transformers.utils import logging
24
+ from transformers.models.paligemma.processing_paligemma import (
25
+ make_batched_images,
26
+ build_string_from_input,
27
+ _is_str_or_image,
28
+ PaliGemmaProcessorKwargs,
29
+ IMAGE_TOKEN,
30
+ EXTRA_TOKENS
31
+ )
32
+ from .action_tokenizer import SpatialActionTokenizer
33
  logger = logging.get_logger(__name__)
34
35
  class SpatialVLAProcessor(ProcessorMixin):
36
  attributes = ["image_processor", "tokenizer"]
37
  valid_kwargs = ["chat_template"]
38
  image_processor_class = "SiglipImageProcessor"
 
88
  self.dataset_intrinsics = {}
89
  height, width = image_processor.size["height"], image_processor.size["width"]
90
 
91
+ # rescale intrinsics from the native camera resolution to the processor input size
92
  for k, v in intrinsic_config.items():
93
  K = torch.tensor(v["intrinsic"]).float()
94
+ K[:2] *= torch.tensor([width / v["width"], height / v["height"]])[:, None]
 
 
 
 
95
  self.dataset_intrinsics[k] = K
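A standalone sketch of the intrinsic rescaling performed in the loop above, handy for sanity-checking a new intrinsic_config entry (the camera values below are made up):

import torch

def rescale_intrinsic(K: torch.Tensor, src_hw, dst_hw) -> torch.Tensor:
    """Scale the x row by the width ratio and the y row by the height ratio."""
    (h, w), (H, W) = src_hw, dst_hw
    K = K.clone().float()
    K[:2] *= torch.tensor([W / w, H / h])[:, None]   # row 0 -> fx, cx; row 1 -> fy, cy
    return K

K = torch.tensor([[600.0, 0.0, 320.0], [0.0, 600.0, 240.0], [0.0, 0.0, 1.0]])   # made-up camera
print(rescale_intrinsic(K, src_hw=(480, 640), dst_hw=(224, 224)))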
 
96
 
97
+ self.action_tokenizer = SpatialActionTokenizer(
98
  tokenizer=tokenizer, num_bins=action_config["num_bins"],
99
  bin_policy=bin_policy, use_spherical=action_config["use_spherical"],
100
  min_sigma=min_sigma,
 
104
  self,
105
  images: ImageInput = None,
106
  text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
 
 
107
  unnorm_key: Optional[str] = None,
108
  suffix_actions: Optional[np.array] = None, # (t e)
109
  **kwargs: Unpack[PaliGemmaProcessorKwargs],
110
  ) -> BatchFeature:
 
111
  images, text = _validate_images_text_input_order(images, text)
112
 
113
  output_kwargs = self._merge_kwargs(
 
126
  if images is None:
127
  raise ValueError("`images` are expected as arguments to a `PaliGemmaProcessor` instance.")
128
  if text is None:
129
+ logger.warning_once("You are using PaliGemma without a text prefix. It will perform as a picture-captioning model.")
 
 
130
  text = ""
131
 
132
  if _is_str_or_image(text):
 
136
 
137
  if text is not None and images is not None:
138
  if not any(IMAGE_TOKEN in sample for sample in text):
 
 
 
 
 
 
139
  if isinstance(text, List) and isinstance(images, List):
140
  if len(images) != len(text):
141
  raise ValueError(
142
  f"Received {len(images)} images for {len(text)} prompts. Each prompt should be associated with an image or list of images."
143
  )
 
 
144
  if is_valid_image(images):
145
  images = [[images]]
146
  elif isinstance(images, list) and is_valid_image(images[0]):
147
  images = [[image] for image in images]
148
  elif not (isinstance(images, list) and isinstance(images[0], list) and is_valid_image(images[0][0])):
149
  raise ValueError("images must be an image, list of images or list of list of images")
150
+ if suffix is not None and _is_str_or_image(suffix): suffix = [suffix]
151
+ if suffix is not None: suffix = [sfx + self.tokenizer.eos_token for sfx in suffix]
 
 
 
 
152
  input_strings = [
153
  build_string_from_input(
154
  prompt=prompt,
 
173
  input_strings = [f"{sample}\n" for sample in expanded_samples]
174
  pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]
175
 
 
176
  if output_kwargs["text_kwargs"].get("max_length", None) is not None:
177
  output_kwargs["text_kwargs"]["max_length"] += self.image_seq_length
178
 
 
208
  return self.tokenizer.decode(*args, **kwargs)
209
 
210
  @property
 
211
  def model_input_names(self):
212
  tokenizer_input_names = self.tokenizer.model_input_names
213
  image_processor_input_names = self.image_processor.model_input_names
 
223
  assert self.tokenizer.eos_token != predicted_action_token_ids[-1], "[error] actions contain EOS token, please check your truncation settings!"
224
 
225
  if predicted_action_token_ids.shape[0] < action_token_num * self.action_chunk_size: # pad with zeros
226
+ logger.warning(f"Padding zero action!")
227
  predicted_action_token_ids = np.concatenate(
228
  [
229
  predicted_action_token_ids,
 
233
  predicted_action_token_ids = predicted_action_token_ids.reshape(-1, action_token_num)
234
  normalized_action_chunks = self.action_tokenizer.decode_token_ids_to_actions(predicted_action_token_ids)
235
 
 
236
  if unnorm_key is None:
237
+ logger.warning(f"unnorm_key {unnorm_key} is not in statistics, use next one")
238
  unnorm_key = next(iter(self.statistics.keys()))
239
  action_norm_stats = self.statistics[unnorm_key]["action"]
240
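For context, decode_actions ends by mapping normalized actions back to physical units; the sketch below assumes OpenVLA-style q01/q99/mask statistics, which is an assumption about the keys rather than something shown in this diff:

import numpy as np

def unnormalize_actions(normalized: np.ndarray, stats: dict) -> np.ndarray:
    """normalized: (..., action_dim) in [-1, 1]; stats: per-dimension action statistics."""
    low, high = np.asarray(stats["q01"]), np.asarray(stats["q99"])
    mask = np.asarray(stats.get("mask", np.ones_like(low, dtype=bool)))
    raw = 0.5 * (normalized + 1.0) * (high - low) + low
    return np.where(mask, raw, normalized)   # non-normalized dims (e.g. gripper) pass through unchanged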
 
test_huggingface.py CHANGED
@@ -1,17 +1,12 @@
1
  import os
2
  import argparse
3
  from pathlib import Path
4
- import shutil
5
- import os
6
- import argparse
7
- from pathlib import Path
8
- import shutil
9
  import torch
10
  from PIL import Image
11
  from transformers import AutoModel, AutoProcessor
12
 
13
  parser = argparse.ArgumentParser("Huggingface AutoModel Testing")
14
- parser.add_argument("--model_name_or_path", default="", help="pretrained model name or path.")
15
  parser.add_argument("--num_images", type=int, default=1, help="num_images for testing.")
16
 
17
  args = parser.parse_args()
@@ -32,4 +27,4 @@ if __name__ == "__main__":
32
  print(generation_outputs, processor.batch_decode(generation_outputs))
33
 
34
  actions = processor.decode_actions(generation_outputs, unnorm_key="bridge_orig/1.0.0")
35
- print(actions)
 
1
  import os
2
  import argparse
3
  from pathlib import Path
 
 
 
 
 
4
  import torch
5
  from PIL import Image
6
  from transformers import AutoModel, AutoProcessor
7
 
8
  parser = argparse.ArgumentParser("Huggingface AutoModel Testing")
9
+ parser.add_argument("--model_name_or_path", default=".", help="pretrained model name or path.")
10
  parser.add_argument("--num_images", type=int, default=1, help="num_images for testing.")
11
 
12
  args = parser.parse_args()
 
27
  print(generation_outputs, processor.batch_decode(generation_outputs))
28
 
29
  actions = processor.decode_actions(generation_outputs, unnorm_key="bridge_orig/1.0.0")
30
+ print(actions)
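A hedged end-to-end usage sketch in the spirit of test_huggingface.py (the image path and prompt are placeholders, and it assumes the checkpoint registers the model with AutoModel and supports generate()):

import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

model_name_or_path = "."                                           # placeholder, as in the script above
processor = AutoProcessor.from_pretrained(model_name_or_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True).eval()

image = Image.open("example.png").convert("RGB")                   # placeholder image
prompt = "What action should the robot take to pick up the object?"
inputs = processor(images=[image], text=prompt, unnorm_key="bridge_orig/1.0.0", return_tensors="pt")
with torch.no_grad():
    generation_outputs = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(generation_outputs, processor.batch_decode(generation_outputs))
actions = processor.decode_actions(generation_outputs, unnorm_key="bridge_orig/1.0.0")
print(actions)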