Upload folder using huggingface_hub
- action_tokenizer.py +5 -20
- config.json +1 -2
- configuration_spatialvla.py +13 -64
- modeling_gemma2.py +2 -4
- modeling_spatialvla.py +121 -366
- processing_spatialvla.py +33 -218
- test_huggingface.py +2 -7
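The upload bundles the custom configuration, modeling, and processing code next to config.json, so the checkpoint is meant to be loaded through transformers' remote-code path. Below is a minimal loading sketch; the repo id is a placeholder, and whether `AutoModel` or a more specific Auto class applies depends on the `auto_map` entries in config.json, which are not shown in this diff.

    # Hedged sketch: load a checkpoint that bundles custom *_spatialvla.py files.
    # "your-org/spatialvla-checkpoint" is a placeholder, not a real repo id.
    from transformers import AutoConfig, AutoModel

    config = AutoConfig.from_pretrained("your-org/spatialvla-checkpoint", trust_remote_code=True)
    model = AutoModel.from_pretrained("your-org/spatialvla-checkpoint", trust_remote_code=True)
    print(type(model).__name__)  # should resolve to the SpatialVLA class registered for remote code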
action_tokenizer.py
CHANGED
@@ -1,27 +1,16 @@
-# MIT License
-# Copyright (c) 2025 IPEC at Shanghai AI Laboratory
-# Permission is hereby granted, free of charge, to use, copy, modify, merge, publish,
-# distribute, sublicense, and/or sell copies of the Software, subject to the following conditions:
-# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND.
-# coding=utf-8
-
 """
 action_tokenizer.py

 Extension class; wraps base LLM/VLM tokenizer with logic to discretize and tokenize continuous robot actions.
 """
-from typing import List, Union, Dict,
+from typing import List, Union, Dict, Optional
 import numpy as np
 from transformers import PreTrainedTokenizerBase
-from pathlib import Path
-import json
 from scipy.stats import norm
 import torch

 ACTION_TOKEN = '<ACTION{:05d}>'

-"""Spatial Tokenizer"""
 class ActionTokenizer:
     def __init__(
         self,
@@ -67,7 +56,6 @@ class ActionTokenizer:
     def vocab_size(self) -> int:
         return self._vocab_size

-"""Spatial Tokenizer"""
 class TranslationTokenizer:
     def __init__(
         self,
@@ -258,7 +246,7 @@ class GripperTokenzier:
     def vocab_size(self) -> int:
         return self.num_bins

-class SphericalCoordinateActionTokenizer:
+class SpatialActionTokenizer:
     range_bins = {
         "translation": {
             "theta_bins": (0.0, np.pi),
@@ -282,7 +270,7 @@ class SphericalCoordinateActionTokenizer:
         min_action: float = -1.0,
         max_action: float = 1.0,
     ):
-        """set bin_policy if exist, otherwise, caculate bin_policy from gs_params
+        """set bin_policy if exist, otherwise, caculate bin_policy from gs_params or use uniform bin grids.
         gs_params: Optional[Dict],
         bin_policy: Optional[Dict],
         """
@@ -293,7 +281,6 @@ class SphericalCoordinateActionTokenizer:

         # set bin policy
         self.bin_policy = bin_policy if bin_policy else self.get_bin_policy(gs_params, self.min_sigma)
-
         self.translation_tokenizer = TranslationTokenizer(
             self.tokenizer,
             self.num_bins["translation"],
@@ -406,13 +393,11 @@ class SphericalCoordinateActionTokenizer:
             embeddings: tensor (S,E)
         """
         from scipy.interpolate import griddata
-        # __import__("ipdb").set_trace()
-
         new_policy = self.get_bin_policy(gs_params, min_sigma=min_sigma)
         trans_grids0, rot_grids0 = self.get_norm_meshgrid(self.bin_policy)
         trans_grids1, rot_grids1 = self.get_norm_meshgrid(new_policy)

-        print("
+        print("overwrite bin policy and tokenizer bins ...")
         self.bin_policy = new_policy
         self.min_sigma = min_sigma
         self.translation_tokenizer.set_bins(new_policy["translation"])
@@ -442,5 +427,5 @@ class SphericalCoordinateActionTokenizer:
         device, dtype = embeddings.weight.data.device, embeddings.weight.data.dtype
         embeddings.weight.data[:N] = torch.Tensor(adpt_trans_emb.reshape(-1, E), device=device).to(dtype)
         embeddings.weight.data[N:N+M] = torch.Tensor(adpt_rot_emb.reshape(-1, E), device=device).to(dtype)
-        print("
+        print("DONE! adapt spatial embedding to new gaussian distributation finished.")
         print(embeddings.weight.data)
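For orientation, the tokenizer wraps a base tokenizer and maps continuous actions onto dedicated `<ACTION{:05d}>` tokens. The sketch below shows only the generic uniform-binning idea behind such a mapping; it is not the `SpatialActionTokenizer` from this commit, whose bin policy is derived from Gaussian statistics in spherical coordinates.

    # Illustrative uniform-bin action discretization. Assumption: this is NOT the
    # repo's spherical/Gaussian bin policy; only the token-mapping mechanism is shown.
    import numpy as np

    ACTION_TOKEN = '<ACTION{:05d}>'

    def encode_actions(actions, num_bins=256, min_action=-1.0, max_action=1.0):
        actions = np.clip(actions, min_action, max_action)
        edges = np.linspace(min_action, max_action, num_bins + 1)
        ids = np.digitize(actions, edges[1:-1])  # bin index in [0, num_bins)
        return [ACTION_TOKEN.format(i) for i in ids]

    def decode_ids(ids, num_bins=256, min_action=-1.0, max_action=1.0):
        edges = np.linspace(min_action, max_action, num_bins + 1)
        centers = (edges[:-1] + edges[1:]) / 2
        return centers[np.asarray(ids)]

    print(encode_actions(np.array([-1.0, 0.0, 0.73])))
    # ['<ACTION00000>', '<ACTION00128>', '<ACTION00221>']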
config.json
CHANGED
@@ -1,5 +1,4 @@
 {
-  "_name_or_path": "../pretrained/2025-01-05_09-12-37_oxe_spatial_vla_paligemma3b_zoe_gsN8194_gpu64-204k",
   "_vocab_size": 265347,
   "action_token_begin_idx": 257153,
   "architectures": [
@@ -317,4 +316,4 @@
     "use_bias_in_fusion_residual": null,
     "use_pretrained_backbone": false
   }
-}
+}
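The only substantive change here is dropping the machine-local `_name_or_path`; the remaining keys are untouched. As a quick sanity check (a sketch that assumes a local copy of config.json), the token bookkeeping recorded above can be read back directly:

    # Sketch: read back the token bookkeeping recorded in config.json.
    import json

    with open("config.json") as f:
        cfg = json.load(f)

    begin = cfg["action_token_begin_idx"]  # 257153 per the diff above
    vocab = cfg["_vocab_size"]             # 265347 per the diff above
    # 265347 - 257153 = 8194 ids sit above the action-token start index, which appears
    # to correspond to the extra spatial/action embeddings handled in modeling_spatialvla.py.
    print(vocab - begin)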
configuration_spatialvla.py
CHANGED
@@ -1,12 +1,16 @@
-# MIT License
-# Copyright (c) 2025 IPEC at Shanghai AI Laboratory
-# Permission is hereby granted, free of charge, to use, copy, modify, merge, publish,
-# distribute, sublicense, and/or sell copies of the Software, subject to the following conditions:
-# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND.
-# Based on code licensed under the Apache License, Version 2.0 by Google Inc. and HuggingFace Inc. team (Copyright 2024).
 # coding=utf-8
-
+# Copyright 2024 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """PaliGemmamodel configuration"""

 import warnings
@@ -15,59 +19,9 @@ from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging
 from transformers import CONFIG_MAPPING, AutoConfig

-
 logger = logging.get_logger(__name__)

-
 class SpatialVLAConfig(PretrainedConfig):
-    r"""
-    This is the configuration class to store the configuration of a [`PaliGemmaForConditionalGeneration`]. It is used to instantiate an
-    PaliGemmamodel according to the specified arguments, defining the model architecture. Instantiating a configuration
-    with the defaults will yield a similar configuration to that of the PaliGemma-2B.
-
-    e.g. [paligemma-hf/paligemma-2b](https://huggingface.co/paligemma-hf/paligemma-2b)
-
-    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
-    documentation from [`PretrainedConfig`] for more information.
-
-    Args:
-        vision_config (`PaliGemmaVisionConfig`, *optional*):
-            Custom vision config or dict
-        text_config (`Union[AutoConfig, dict]`, *optional*):
-            The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
-        ignore_index (`int`, *optional*, defaults to -100):
-            The ignore index for the loss function.
-        image_token_index (`int`, *optional*, defaults to 256000):
-            The image token index to encode the image prompt.
-        vocab_size (`int`, *optional*, defaults to 257152):
-            Vocabulary size of the PaliGemmamodel. Defines the number of different tokens that can be represented by the
-            `inputs_ids` passed when calling [`~PaliGemmaForConditionalGeneration`]
-        projection_dim (`int`, *optional*, defaults to 2048):
-            Dimension of the multimodal projection space.
-        hidden_size (`int`, *optional*, defaults to 2048):
-            Dimension of the hidden layer of the Language model.
-
-    Example:
-
-    ```python
-    >>> from transformers import PaliGemmaForConditionalGeneration, PaliGemmaConfig, SiglipVisionConfig, GemmaConfig
-
-    >>> # Initializing a Siglip-like vision config
-    >>> vision_config = SiglipVisionConfig()
-
-    >>> # Initializing a PaliGemma config
-    >>> text_config = GemmaConfig()
-
-    >>> # Initializing a PaliGemma paligemma-3b-224 style configuration
-    >>> configuration = PaliGemmaConfig(vision_config, text_config)
-
-    >>> # Initializing a model from the paligemma-3b-224 style configuration
-    >>> model = PaliGemmaForConditionalGeneration(configuration)
-
-    >>> # Accessing the model configuration
-    >>> configuration = model.config
-    ```"""
-
     model_type = "spatialvla"
     sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig, "vision_zoe_config": AutoConfig}

@@ -87,7 +41,6 @@ class SpatialVLAConfig(PretrainedConfig):
         ego3d_patch_reso=4,
         n_freqs=8,
         use_vision_zoe=True,
-        # wrap_lora=False,
         **kwargs,
     ):
         self._ignore_index = ignore_index
@@ -138,19 +91,15 @@ class SpatialVLAConfig(PretrainedConfig):
             vision_zoe_config["model_type"] = vision_zoe_config["model_type"] if "model_type" in vision_zoe_config else "zoedepth"
             self.vision_zoe_config = CONFIG_MAPPING[vision_zoe_config["model_type"]](**vision_zoe_config)
         else:
-            print(f"🔥 init from default configurations ... {self.vision_zoe_config}")
-            # BUG: initializing zoe in default cause key error
-            # self.vision_zoe_config = CONFIG_MAPPING["zoedepth"]()
             pass

-        #
+        # additional attributes
         self.action_token_begin_idx = action_token_begin_idx
         self.spatial_token_num = spatial_token_num
         self.use_spatial_token = use_spatial_token
         self.ego3d_patch_reso = ego3d_patch_reso
         self.n_freqs = n_freqs
         self.use_vision_zoe = use_vision_zoe
-        # self.wrap_lora = wrap_lora

         super().__init__(**kwargs)

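The constructor keeps the ego3d/spatial keyword arguments while the commented-out `wrap_lora` flag and the verbose default-init print are dropped. A construction sketch follows, under the assumption that the unchanged parts of `__init__` accept dict sub-configs as in PaliGemma-style configs; the sub-config values and token counts below are illustrative only.

    # Sketch: instantiate the config with dict sub-configs (values are illustrative).
    from configuration_spatialvla import SpatialVLAConfig

    config = SpatialVLAConfig(
        vision_config={"model_type": "siglip_vision_model"},
        text_config={"model_type": "gemma2"},
        vision_zoe_config={"model_type": "zoedepth"},
        use_vision_zoe=True,
        use_spatial_token=True,
        spatial_token_num=8194,
        action_token_begin_idx=257153,
        ego3d_patch_reso=4,
        n_freqs=8,
    )
    print(config.model_type, config.n_freqs)  # spatialvla 8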
modeling_gemma2.py
CHANGED
@@ -1,4 +1,5 @@
-# custom gemma2 to support flash_attention_2
+# custom gemma2 to support flash_attention_2,
+# source from https://github.com/huggingface/transformers/blob/v4.47.0/src/transformers/models/gemma2/modeling_gemma2.py
 # coding=utf-8
 # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
 #
@@ -205,10 +206,7 @@ def flash_attention_forward(
 ) -> Tuple[torch.Tensor, None]:
     # NOTE: None mask cause un defined https://github.com/huggingface/transformers/blob/c8c8dffbe45ebef0a8dba4a51024e5e5e498596b/src/transformers/models/gemma2/modeling_gemma2.py#L211
     seq_len = query.shape[2]
-    # print(f"🔥 query {query.shape}, key {key.shape}, value: {value.shape}")
     if mask is not None:
-        # print(f"🔥 mask {mask.shape}")
-        # seq_len = mask.shape[1]
         query = query[:, :, :seq_len]
         value = value[:, :, :seq_len]

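Apart from the header comment, the functional part of this hunk is unchanged: when a mask is passed, `flash_attention_forward` still trims `query` and `value` to `seq_len = query.shape[2]` before dispatching. A toy shape check of that trim (standalone, not the Gemma2 forward itself):

    # Toy check of the query/value trimming kept in flash_attention_forward.
    import torch

    query = torch.randn(1, 8, 16, 64)   # (batch, heads, q_len, head_dim)
    value = torch.randn(1, 8, 20, 64)   # e.g. padded to a longer static-cache length
    mask = torch.ones(1, 16)

    seq_len = query.shape[2]
    if mask is not None:
        query = query[:, :, :seq_len]
        value = value[:, :, :seq_len]
    print(query.shape, value.shape)     # both end at seq_len=16 along dim 2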
modeling_spatialvla.py
CHANGED
@@ -1,153 +1,118 @@
|
|
1 |
-
# MIT License
|
2 |
-
# Copyright (c) 2025 IPEC at Shanghai AI Laboratory
|
3 |
-
# Permission is hereby granted, free of charge, to use, copy, modify, merge, publish,
|
4 |
-
# distribute, sublicense, and/or sell copies of the Software, subject to the following conditions:
|
5 |
-
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
6 |
-
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND.
|
7 |
-
# Based on code licensed under the Apache License, Version 2.0 by Google Inc. and HuggingFace Inc. team (Copyright 2024).
|
8 |
# coding=utf-8
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
"""PyTorch PaliGemmamodel."""
|
11 |
|
12 |
from dataclasses import dataclass
|
13 |
from typing import List, Optional, Tuple, Union
|
14 |
|
|
|
15 |
import torch
|
16 |
import torch.utils.checkpoint
|
17 |
from torch import nn
|
18 |
from torch.linalg import inv
|
19 |
-
import torchvision.transforms.functional as
|
20 |
-
|
21 |
-
import os
|
22 |
from transformers.cache_utils import Cache, HybridCache, StaticCache
|
23 |
from transformers.generation import GenerationMixin
|
24 |
from transformers.modeling_utils import PreTrainedModel, PretrainedConfig
|
25 |
from transformers.utils import (
|
26 |
ModelOutput,
|
27 |
-
add_start_docstrings,
|
28 |
-
add_start_docstrings_to_model_forward,
|
29 |
-
is_flash_attn_2_available,
|
30 |
logging,
|
31 |
-
replace_return_docstrings,
|
32 |
)
|
33 |
from .configuration_spatialvla import SpatialVLAConfig
|
34 |
-
from .modeling_ego3d import Ego3DPositionEmbeddingMLP, process_zoe
|
35 |
from .modeling_gemma2 import Gemma2ForCausalLM
|
|
|
36 |
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
from transformers import AutoModel, AutoModelForCausalLM, ZoeDepthForDepthEstimation
|
41 |
-
|
42 |
|
43 |
logger = logging.get_logger(__name__)
|
44 |
|
45 |
-
|
46 |
-
|
47 |
-
#
|
48 |
-
SIGLIP_MEAN, SIGLIP_STD = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
|
49 |
-
|
50 |
-
# Adapted from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
|
51 |
-
# But Paligemma has no causal mask on prefix
|
52 |
-
def _prepare_4d_causal_attention_mask_with_cache_position(
|
53 |
-
attention_mask: torch.Tensor,
|
54 |
-
sequence_length: int,
|
55 |
-
target_length: int,
|
56 |
-
dtype: torch.dtype,
|
57 |
-
device: torch.device,
|
58 |
-
min_dtype: float,
|
59 |
-
cache_position: torch.Tensor,
|
60 |
-
batch_size: int,
|
61 |
-
is_training: bool = False,
|
62 |
-
token_type_ids: torch.Tensor = None,
|
63 |
-
**kwargs,
|
64 |
-
):
|
65 |
"""
|
66 |
-
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
67 |
-
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
68 |
-
|
69 |
-
Args:
|
70 |
-
attention_mask (`torch.Tensor`):
|
71 |
-
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
72 |
-
sequence_length (`int`):
|
73 |
-
The sequence length being processed.
|
74 |
-
target_length (`int`):
|
75 |
-
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
76 |
-
dtype (`torch.dtype`):
|
77 |
-
The dtype to use for the 4D attention mask.
|
78 |
-
device (`torch.device`):
|
79 |
-
The device to plcae the 4D attention mask on.
|
80 |
-
min_dtype (`float`):
|
81 |
-
The minimum value representable with the dtype `dtype`.
|
82 |
-
cache_position (`torch.Tensor`):
|
83 |
-
Indices depicting the position of the input sequence tokens in the sequence.
|
84 |
-
batch_size (`torch.Tensor`):
|
85 |
-
Batch size.
|
86 |
-
is_training (`bool`):
|
87 |
-
Whether the model is in training mode or in inference. The condition is checked by presence/absence of `token_type_ids/labels`
|
88 |
-
"""
|
89 |
-
if attention_mask is not None and attention_mask.dim() == 4:
|
90 |
-
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
91 |
-
causal_mask = attention_mask
|
92 |
-
else:
|
93 |
-
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
94 |
-
# Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
|
95 |
-
if sequence_length != 1:
|
96 |
-
if is_training:
|
97 |
-
causal_mask = torch.triu(causal_mask, diagonal=1)
|
98 |
-
else:
|
99 |
-
causal_mask[:, :sequence_length] = 0.0
|
100 |
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
|
|
|
|
|
|
|
|
117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
@dataclass
|
120 |
class SpatialVLACausalLMOutputWithPast(ModelOutput):
|
121 |
-
"""
|
122 |
-
Base class for PaliGemmacausal language model (or autoregressive) outputs.
|
123 |
-
|
124 |
-
Args:
|
125 |
-
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
126 |
-
Language modeling loss (for next-token prediction).
|
127 |
-
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
|
128 |
-
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
129 |
-
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
130 |
-
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
131 |
-
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
132 |
-
|
133 |
-
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
134 |
-
`past_key_values` input) to speed up sequential decoding.
|
135 |
-
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
136 |
-
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
137 |
-
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
138 |
-
|
139 |
-
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
140 |
-
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
141 |
-
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
142 |
-
sequence_length)`.
|
143 |
-
|
144 |
-
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
145 |
-
heads.
|
146 |
-
image_hidden_states (`torch.FloatTensor`, *optional*):
|
147 |
-
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
|
148 |
-
image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
|
149 |
-
"""
|
150 |
-
|
151 |
loss: Optional[torch.FloatTensor] = None
|
152 |
logits: torch.FloatTensor = None
|
153 |
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None
|
@@ -155,7 +120,6 @@ class SpatialVLACausalLMOutputWithPast(ModelOutput):
|
|
155 |
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
156 |
image_hidden_states: Optional[torch.FloatTensor] = None
|
157 |
|
158 |
-
|
159 |
class SpatialVLAMultiModalProjector(nn.Module):
|
160 |
def __init__(self, config: SpatialVLAConfig):
|
161 |
super().__init__()
|
@@ -163,31 +127,8 @@ class SpatialVLAMultiModalProjector(nn.Module):
|
|
163 |
|
164 |
def forward(self, image_features):
|
165 |
hidden_states = self.linear(image_features)
|
166 |
-
|
167 |
return hidden_states
|
168 |
|
169 |
-
|
170 |
-
PALIGEMMA_START_DOCSTRING = r"""
|
171 |
-
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
172 |
-
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
173 |
-
etc.)
|
174 |
-
|
175 |
-
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
176 |
-
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
177 |
-
and behavior.
|
178 |
-
|
179 |
-
Parameters:
|
180 |
-
config ([`PaliGemmaConfig`] or [`PaliGemmaVisionConfig`]):
|
181 |
-
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
182 |
-
load the weights associated with the model, only the configuration. Check out the
|
183 |
-
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
184 |
-
"""
|
185 |
-
|
186 |
-
|
187 |
-
@add_start_docstrings(
|
188 |
-
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
189 |
-
PALIGEMMA_START_DOCSTRING,
|
190 |
-
)
|
191 |
class SpatialVLAPreTrainedModel(PreTrainedModel):
|
192 |
config_class = SpatialVLAConfig
|
193 |
base_model_prefix = "model"
|
@@ -202,8 +143,6 @@ class SpatialVLAPreTrainedModel(PreTrainedModel):
|
|
202 |
_supports_sdpa = True
|
203 |
|
204 |
def _init_weights(self, module):
|
205 |
-
# important: this ported version of PaliGemmaisn't meant for training from scratch - only
|
206 |
-
# inference and fine-tuning
|
207 |
std = (
|
208 |
self.config.initializer_range
|
209 |
if hasattr(self.config, "initializer_range")
|
@@ -222,99 +161,20 @@ class SpatialVLAPreTrainedModel(PreTrainedModel):
|
|
222 |
if module.padding_idx is not None:
|
223 |
module.weight.data[module.padding_idx].zero_()
|
224 |
|
225 |
-
|
226 |
-
PALIGEMMA_INPUTS_DOCSTRING = r"""
|
227 |
-
Args:
|
228 |
-
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
229 |
-
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
230 |
-
it.
|
231 |
-
|
232 |
-
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
233 |
-
[`PreTrainedTokenizer.__call__`] for details.
|
234 |
-
|
235 |
-
[What are input IDs?](../glossary#input-ids)
|
236 |
-
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
|
237 |
-
The tensors corresponding to the input images. Pixel values can be obtained using
|
238 |
-
[`AutoImageProcessor`]. See [`SiglipImageProcessor.__call__`] for details ([]`PaliGemmaProcessor`] uses
|
239 |
-
[`SiglipImageProcessor`] for processing images).
|
240 |
-
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
241 |
-
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
242 |
-
|
243 |
-
- 1 for tokens that are **not masked**,
|
244 |
-
- 0 for tokens that are **masked**.
|
245 |
-
|
246 |
-
[What are attention masks?](../glossary#attention-mask)
|
247 |
-
|
248 |
-
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
249 |
-
[`PreTrainedTokenizer.__call__`] for details.
|
250 |
-
|
251 |
-
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
252 |
-
`past_key_values`).
|
253 |
-
|
254 |
-
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
255 |
-
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
256 |
-
information on the default strategy.
|
257 |
-
|
258 |
-
- 1 indicates the head is **not masked**,
|
259 |
-
- 0 indicates the head is **masked**.
|
260 |
-
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
261 |
-
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
262 |
-
config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
|
263 |
-
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
264 |
-
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
265 |
-
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
266 |
-
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
267 |
-
|
268 |
-
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
269 |
-
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
270 |
-
|
271 |
-
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
272 |
-
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
273 |
-
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
274 |
-
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
275 |
-
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
276 |
-
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
277 |
-
model's internal embedding lookup matrix.
|
278 |
-
use_cache (`bool`, *optional*):
|
279 |
-
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
280 |
-
`past_key_values`).
|
281 |
-
output_attentions (`bool`, *optional*):
|
282 |
-
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
283 |
-
tensors for more detail.
|
284 |
-
output_hidden_states (`bool`, *optional*):
|
285 |
-
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
286 |
-
more detail.
|
287 |
-
return_dict (`bool`, *optional*):
|
288 |
-
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
289 |
-
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
290 |
-
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
291 |
-
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
292 |
-
the complete sequence length.
|
293 |
-
"""
|
294 |
-
|
295 |
-
|
296 |
-
@add_start_docstrings(
|
297 |
-
"""The PALIGEMMA model which consists of a vision backbone and a language model.""",
|
298 |
-
PALIGEMMA_START_DOCSTRING,
|
299 |
-
)
|
300 |
class SpatialVLAForConditionalGeneration(SpatialVLAPreTrainedModel, GenerationMixin):
|
301 |
def __init__(self, config: SpatialVLAConfig, vision_model=None, vision_zoe_model=None, projector_model=None, language_model=None):
|
302 |
super().__init__(config)
|
303 |
-
|
304 |
self.vision_tower = vision_model or AutoModel.from_config(config=config.vision_config)
|
305 |
-
# projector
|
306 |
self.multi_modal_projector = projector_model or SpatialVLAMultiModalProjector(config)
|
307 |
-
# language model
|
308 |
self.vocab_size = config.text_config.vocab_size
|
309 |
if language_model is None:
|
310 |
-
language_model = Gemma2ForCausalLM(config=config.text_config)
|
311 |
-
# set tile key
|
312 |
if language_model._tied_weights_keys is not None:
|
313 |
self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
|
314 |
self.language_model = language_model
|
315 |
|
316 |
if config.use_vision_zoe:
|
317 |
-
# zoe model
|
318 |
self.vision_zoe_model = vision_zoe_model or ZoeDepthForDepthEstimation(config.vision_zoe_config)
|
319 |
self.position_embedding_3d = Ego3DPositionEmbeddingMLP(
|
320 |
config.ego3d_patch_reso**2 * 3, num_pos_feats=config.vision_config.hidden_size, n_freqs=config.n_freqs
|
@@ -326,15 +186,12 @@ class SpatialVLAForConditionalGeneration(SpatialVLAPreTrainedModel, GenerationMi
|
|
326 |
uv_h = torch.stack([x, y, torch.ones_like(x)], dim=0).reshape(3, -1) # (3 hw)
|
327 |
self.register_buffer("uv_h", uv_h, persistent=False)
|
328 |
|
329 |
-
#
|
330 |
if config.use_spatial_token:
|
331 |
self.spatial_embed_tokens = nn.Embedding(self.config.spatial_token_num, config.text_config.hidden_size)
|
332 |
else:
|
333 |
self.spatial_embed_tokens = None
|
334 |
-
|
335 |
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
|
336 |
-
# self.post_init() # BUG: cause from_pretrained failed!
|
337 |
-
# self.position_embedding_3d._reset_parameters()
|
338 |
|
339 |
|
340 |
def backproject_patch(self, K: torch.Tensor, depth: torch.Tensor, patch_size=14, reso=2) -> torch.Tensor:
|
@@ -343,44 +200,48 @@ class SpatialVLAForConditionalGeneration(SpatialVLAPreTrainedModel, GenerationMi
|
|
343 |
Args:
|
344 |
K: camera intrinsic matrix (b 3 3)
|
345 |
depth: depth map (b 1 h w)
|
346 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
347 |
"""
|
348 |
-
# __import__("ipdb").set_trace()
|
349 |
b, c, h, w = depth.shape
|
350 |
hp, wp = h // patch_size, w // patch_size
|
351 |
sub_hp = sub_wp = reso
|
352 |
-
patch_depth =
|
353 |
-
|
354 |
-
# import torchvision; torchvision.utils.save_image(zoe_pixel_values[0], "zoe_image.png")
|
355 |
p_cam = (inv(K.float()) @ self.uv_h.float()) * patch_depth # (b 3 3) @ (3 hw) -> (b 3 hw) * (b 1 hw) -> (b 3 hw)
|
356 |
patch_p_cam = p_cam.reshape(b, 3, hp, sub_hp, wp, sub_wp).permute(0, 2, 4, 3, 5, 1).reshape(b, hp * wp, -1)
|
357 |
return patch_p_cam
|
358 |
|
359 |
-
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings with Llava->PaliGemma
|
360 |
def get_input_embeddings(self):
|
361 |
return self.language_model.get_input_embeddings()
|
362 |
|
363 |
-
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings with Llava->PaliGemma
|
364 |
def set_input_embeddings(self, value):
|
365 |
self.language_model.set_input_embeddings(value)
|
366 |
|
367 |
-
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings with Llava->PaliGemma
|
368 |
def get_output_embeddings(self):
|
369 |
return self.language_model.get_output_embeddings()
|
370 |
|
371 |
-
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings with Llava->PaliGemma
|
372 |
def set_output_embeddings(self, new_embeddings):
|
373 |
self.language_model.set_output_embeddings(new_embeddings)
|
374 |
|
375 |
-
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder with Llava->PaliGemma
|
376 |
def set_decoder(self, decoder):
|
377 |
self.language_model.set_decoder(decoder)
|
378 |
|
379 |
-
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder with Llava->PaliGemma
|
380 |
def get_decoder(self):
|
381 |
return self.language_model.get_decoder()
|
382 |
|
383 |
-
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.tie_weights with Llava->PaliGemma
|
384 |
def tie_weights(self):
|
385 |
return self.language_model.tie_weights()
|
386 |
|
@@ -390,11 +251,7 @@ class SpatialVLAForConditionalGeneration(SpatialVLAPreTrainedModel, GenerationMi
|
|
390 |
pad_to_multiple_of: Optional[int] = None,
|
391 |
mean_resizing: bool = True,
|
392 |
) -> nn.Embedding:
|
393 |
-
# TODO: is_deepspeed_zero3_enabled gather
|
394 |
-
print(f"resize token embeddings from {self.language_model.get_output_embeddings().weight.shape} to (*,{new_num_tokens})")
|
395 |
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
|
396 |
-
|
397 |
-
# update base model and current model config
|
398 |
vocab_size = model_embeds.weight.shape[0]
|
399 |
self.config.text_config.vocab_size = self.vocab_size = self.config._vocab_size = vocab_size
|
400 |
self.tie_weights()
|
@@ -431,18 +288,12 @@ class SpatialVLAForConditionalGeneration(SpatialVLAPreTrainedModel, GenerationMi
|
|
431 |
)
|
432 |
|
433 |
if attention_mask is not None and attention_mask.dim() == 4:
|
434 |
-
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
435 |
return attention_mask
|
436 |
|
437 |
-
causal_mask = torch.full(
|
438 |
-
(sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
|
439 |
-
)
|
440 |
-
# Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
|
441 |
if sequence_length != 1:
|
442 |
-
if is_training:
|
443 |
-
|
444 |
-
else:
|
445 |
-
causal_mask[:, :sequence_length] = 0.0
|
446 |
|
447 |
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
448 |
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
|
@@ -451,29 +302,13 @@ class SpatialVLAForConditionalGeneration(SpatialVLAPreTrainedModel, GenerationMi
|
|
451 |
mask_length = attention_mask.shape[-1]
|
452 |
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
|
453 |
padding_mask = padding_mask == 0
|
454 |
-
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
455 |
-
padding_mask, min_dtype
|
456 |
-
)
|
457 |
-
# we are training thus we need to create a full mask on the image + prefix but causal on suffix
|
458 |
if is_training:
|
459 |
-
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
460 |
-
token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
|
461 |
-
)
|
462 |
return causal_mask
|
463 |
|
464 |
def get_image_features(self, pixel_values: torch.FloatTensor, intrinsic: torch.FloatTensor):
|
465 |
-
|
466 |
-
Obtains image last hidden states from the vision tower and apply multimodal projection.
|
467 |
-
|
468 |
-
Args:
|
469 |
-
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
|
470 |
-
The tensors corresponding to the input images.
|
471 |
-
Returns:
|
472 |
-
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
473 |
-
"""
|
474 |
-
# mintrinsic = intrinsic.reshape(-1, 3, 3)
|
475 |
-
# siglip vision tower
|
476 |
-
siglip_pixel_values = F.normalize(pixel_values, mean=SIGLIP_MEAN, std=SIGLIP_STD)
|
477 |
image_outputs = self.vision_tower(siglip_pixel_values)
|
478 |
|
479 |
# ego3d position encoding
|
@@ -482,13 +317,12 @@ class SpatialVLAForConditionalGeneration(SpatialVLAPreTrainedModel, GenerationMi
|
|
482 |
with torch.no_grad():
|
483 |
pvh, pvw = pixel_values.shape[-2:]
|
484 |
depth = self.vision_zoe_model(pixel_values=zoe_pixel_values).predicted_depth
|
485 |
-
depth =
|
486 |
depth.unsqueeze(1),
|
487 |
size=(pvh+2*ph, pvw+2*pw),
|
488 |
mode="bicubic",
|
489 |
align_corners=True,
|
490 |
)[..., ph:-ph, pw:-pw]
|
491 |
-
# depth = torch.clamp(depth, 0., 4.0) # NOTE: we find that depth w/o clamp performs better
|
492 |
xyz = self.backproject_patch(
|
493 |
intrinsic, depth, patch_size=self.config.vision_config.patch_size, reso=self.config.ego3d_patch_reso
|
494 |
) # (b, n, 3*4)
|
@@ -500,8 +334,6 @@ class SpatialVLAForConditionalGeneration(SpatialVLAPreTrainedModel, GenerationMi
|
|
500 |
image_features = image_features / (self.config.text_config.hidden_size**0.5)
|
501 |
return image_features
|
502 |
|
503 |
-
@add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
|
504 |
-
@replace_return_docstrings(output_type=SpatialVLACausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
505 |
def forward(
|
506 |
self,
|
507 |
input_ids: torch.LongTensor = None,
|
@@ -521,93 +353,29 @@ class SpatialVLAForConditionalGeneration(SpatialVLAPreTrainedModel, GenerationMi
|
|
521 |
return_dict: Optional[bool] = None,
|
522 |
num_logits_to_keep: int = 0,
|
523 |
) -> Union[Tuple, SpatialVLACausalLMOutputWithPast]:
|
524 |
-
r"""
|
525 |
-
Args:
|
526 |
-
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
527 |
-
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
528 |
-
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
529 |
-
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
|
530 |
-
|
531 |
-
num_logits_to_keep (`int`, *optional*):
|
532 |
-
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
|
533 |
-
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
534 |
-
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
535 |
-
|
536 |
-
Returns:
|
537 |
-
|
538 |
-
Example:
|
539 |
-
|
540 |
-
```python
|
541 |
-
>>> from PIL import Image
|
542 |
-
>>> import requests
|
543 |
-
>>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
|
544 |
-
|
545 |
-
>>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/PaliGemma-test-224px-hf")
|
546 |
-
>>> processor = AutoProcessor.from_pretrained("google/PaliGemma-test-224px-hf")
|
547 |
-
|
548 |
-
>>> prompt = "answer en Where is the cow standing?"
|
549 |
-
>>> url = "https://huggingface.co/gv-hf/PaliGemma-test-224px-hf/resolve/main/cow_beach_1.png"
|
550 |
-
>>> image = Image.open(requests.get(url, stream=True).raw)
|
551 |
-
|
552 |
-
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
553 |
-
|
554 |
-
>>> # Generate
|
555 |
-
>>> generate_ids = model.generate(**inputs, max_length=30)
|
556 |
-
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
557 |
-
"answer en Where is the cow standing?\nbeach"
|
558 |
-
```"""
|
559 |
-
# print(f"**************************************\n \
|
560 |
-
# input_ids {input_ids} \n \
|
561 |
-
# labels {labels} \n \
|
562 |
-
# token_type_ids {token_type_ids} \n \
|
563 |
-
# attention_mask {attention_mask} \n \
|
564 |
-
# actions {actions} \n \
|
565 |
-
# **************************************"
|
566 |
-
# )
|
567 |
-
# print(f"model.language_model.config._attn_implementation {self.language_model.config._attn_implementation} model.config.vision_config._attn_implementation_internal {self.config.vision_config._attn_implementation_internal} \n \
|
568 |
-
# model.vision_tower.config._attn_implementation {self.vision_tower.config._attn_implementation} model.config.vision_config._attn_implementation_internal {self.config.vision_config._attn_implementation_internal}")
|
569 |
-
# __import__("ipdb").set_trace()
|
570 |
-
if (input_ids is None) ^ (inputs_embeds is not None):
|
571 |
-
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
572 |
-
|
573 |
-
if pixel_values is not None and inputs_embeds is not None:
|
574 |
-
raise ValueError(
|
575 |
-
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
576 |
-
)
|
577 |
|
578 |
-
output_attentions = output_attentions
|
579 |
-
output_hidden_states =
|
580 |
-
|
581 |
-
)
|
582 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
583 |
|
584 |
is_training = token_type_ids is not None and labels is not None
|
585 |
|
586 |
-
if inputs_embeds is None:
|
587 |
-
|
588 |
-
|
589 |
-
# NOTE: replace the fixed embeddings with trainable spatial embeddings
|
590 |
-
# BUG: LoRA causes inputs_embeds requires_grad = True
|
591 |
-
# peft: https://github.com/huggingface/peft/blob/ec92cdcc41fe1b141bfe1e0da69b38a7e601cc80/src/peft/peft_model.py#L687
|
592 |
-
# hf: https://github.com/huggingface/transformers/blob/05260a1fc1c8571a2b421ce72b680d5f1bc3e5a4/src/transformers/modeling_utils.py#L2545
|
593 |
-
# lora w/ prompt: https://discuss.huggingface.co/t/combine-between-lora-and-prompt-tunning/65151
|
594 |
if self.config.use_spatial_token:
|
595 |
spatial_selected = (input_ids >= self.config.action_token_begin_idx) & (input_ids < self.config.action_token_begin_idx + self.config.spatial_token_num)
|
596 |
inputs_embeds[spatial_selected] = inputs_embeds[spatial_selected] * 0.0 + self.spatial_embed_tokens(input_ids[spatial_selected] - self.config.action_token_begin_idx)
|
597 |
|
598 |
if cache_position is None:
|
599 |
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
600 |
-
cache_position = torch.arange(
|
601 |
-
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
602 |
-
)
|
603 |
|
604 |
if position_ids is None:
|
605 |
position_ids = cache_position.unsqueeze(0) + 1 # Paligemma positions are 1-indexed
|
606 |
|
607 |
-
#
|
608 |
if pixel_values is not None:
|
609 |
image_features = self.get_image_features(pixel_values, intrinsic)
|
610 |
-
|
611 |
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
612 |
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
613 |
if inputs_embeds[special_image_mask].numel() != image_features.numel():
|
@@ -647,20 +415,16 @@ class SpatialVLAForConditionalGeneration(SpatialVLAPreTrainedModel, GenerationMi
|
|
647 |
logits = outputs.logits
|
648 |
loss = None
|
649 |
if labels is not None:
|
650 |
-
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
651 |
logits = logits.float()
|
652 |
shift_logits = logits[..., :-1, :]
|
653 |
shift_labels = labels[..., 1:]
|
654 |
if attention_mask is not None:
|
655 |
-
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
656 |
-
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
|
657 |
shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
|
658 |
shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
|
659 |
shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
|
660 |
else:
|
661 |
shift_logits = shift_logits.contiguous()
|
662 |
shift_labels = shift_labels.contiguous()
|
663 |
-
# Flatten the tokens
|
664 |
loss_fct = nn.CrossEntropyLoss()
|
665 |
|
666 |
flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
|
@@ -679,6 +443,7 @@ class SpatialVLAForConditionalGeneration(SpatialVLAPreTrainedModel, GenerationMi
|
|
679 |
image_hidden_states=image_features if pixel_values is not None else None,
|
680 |
)
|
681 |
|
|
|
682 |
def prepare_inputs_for_generation(
|
683 |
self,
|
684 |
input_ids,
|
@@ -695,7 +460,6 @@ class SpatialVLAForConditionalGeneration(SpatialVLAPreTrainedModel, GenerationMi
|
|
695 |
labels=None,
|
696 |
**kwargs,
|
697 |
):
|
698 |
-
# Overwritten -- custom `position_ids` and `pixel_values` handling
|
699 |
model_inputs = self.language_model.prepare_inputs_for_generation(
|
700 |
input_ids,
|
701 |
past_key_values=past_key_values,
|
@@ -708,19 +472,13 @@ class SpatialVLAForConditionalGeneration(SpatialVLAPreTrainedModel, GenerationMi
|
|
708 |
token_type_ids=token_type_ids,
|
709 |
**kwargs,
|
710 |
)
|
711 |
-
|
712 |
-
# position_ids in Paligemma are 1-indexed
|
713 |
if model_inputs.get("position_ids") is not None:
|
714 |
model_inputs["position_ids"] += 1
|
715 |
-
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
716 |
-
# Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
|
717 |
if cache_position[0] == 0:
|
718 |
model_inputs["pixel_values"] = pixel_values
|
719 |
is_training = token_type_ids is not None and labels is not None
|
720 |
if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
|
721 |
-
causal_mask = self._update_causal_mask(
|
722 |
-
attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training
|
723 |
-
)
|
724 |
model_inputs["attention_mask"] = causal_mask
|
725 |
model_inputs["intrinsic"] = intrinsic
|
726 |
return model_inputs
|
@@ -765,9 +523,6 @@ class SpatialVLAForConditionalGeneration(SpatialVLAPreTrainedModel, GenerationMi
|
|
765 |
weights_only=weights_only,
|
766 |
**kwargs,
|
767 |
)
|
768 |
-
# NOTE: tie the weights of the embed_tokens with lm head (donot work if un_tie_weight)
|
769 |
-
# model.language_model.tie_weights()
|
770 |
-
# NOTE: tie the data of spatial_embed_tokens with embed_tokens (BUG: forweight sync issue in training)
|
771 |
if model.config.use_spatial_token:
|
772 |
model.language_model.model.embed_tokens.weight.data[-model.config.spatial_token_num:] = model.spatial_embed_tokens.weight.data
|
773 |
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
# coding=utf-8
|
2 |
+
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
"""PyTorch PaliGemmamodel."""
|
16 |
|
17 |
from dataclasses import dataclass
|
18 |
from typing import List, Optional, Tuple, Union
|
19 |
|
20 |
+
import os
|
21 |
import torch
|
22 |
import torch.utils.checkpoint
|
23 |
from torch import nn
|
24 |
from torch.linalg import inv
|
25 |
+
import torchvision.transforms.functional as TF
|
26 |
+
import torch.nn.functional as F
|
|
|
27 |
from transformers.cache_utils import Cache, HybridCache, StaticCache
|
28 |
from transformers.generation import GenerationMixin
|
29 |
from transformers.modeling_utils import PreTrainedModel, PretrainedConfig
|
30 |
from transformers.utils import (
|
31 |
ModelOutput,
|
|
|
|
|
|
|
32 |
logging,
|
|
|
33 |
)
|
34 |
from .configuration_spatialvla import SpatialVLAConfig
|
|
|
35 |
from .modeling_gemma2 import Gemma2ForCausalLM
|
36 |
+
from transformers import AutoModel, ZoeDepthForDepthEstimation
|
37 |
|
38 |
+
SIGLIP_MEAN, SIGLIP_STD = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
|
39 |
+
ZOE_MEAN, ZOE_STD = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
|
|
|
|
|
|
|
40 |
|
41 |
logger = logging.get_logger(__name__)
|
42 |
|
43 |
+
class Ego3DPositionEmbeddingMLP(nn.Module):
|
44 |
+
"""Absolute pos embedding, learned.
|
45 |
+
https://github.com/kwea123/nerf_pl/blob/52aeb387da64a9ad9a0f914ea9b049ffc598b20c/models/nerf.py#L4
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
+
def __init__(self, in_channels=3, num_pos_feats=768, n_freqs=8, logscale=True):
|
49 |
+
super(Ego3DPositionEmbeddingMLP, self).__init__()
|
50 |
+
self.n_freqs = n_freqs
|
51 |
+
self.freq_out_channels = in_channels * (2 * n_freqs + 1)
|
52 |
+
if logscale:
|
53 |
+
freq_bands = 2 ** torch.linspace(0, n_freqs - 1, n_freqs)
|
54 |
+
else:
|
55 |
+
freq_bands = torch.linspace(1, 2 ** (n_freqs - 1), n_freqs)
|
56 |
+
|
57 |
+
center = torch.tensor([0., 0., 2.]).repeat(in_channels // 3)
|
58 |
+
self.register_buffer("freq_bands", freq_bands, persistent=False)
|
59 |
+
self.register_buffer("center", center, persistent=False)
|
60 |
+
|
61 |
+
self.position_embedding_head = nn.Sequential(
|
62 |
+
nn.Linear(self.freq_out_channels, num_pos_feats),
|
63 |
+
nn.LayerNorm(num_pos_feats),
|
64 |
+
nn.ReLU(),
|
65 |
+
nn.Linear(num_pos_feats, num_pos_feats),
|
66 |
+
)
|
67 |
+
self._reset_parameters()
|
68 |
|
69 |
+
def _reset_parameters(self):
|
70 |
+
"""init with small weights to maintain stable training."""
|
71 |
+
for p in self.parameters():
|
72 |
+
if p.dim() > 1:
|
73 |
+
nn.init.xavier_uniform_(p, gain=0.01)
|
74 |
+
|
75 |
+
@torch.no_grad()
|
76 |
+
def frequency_encoding(self, xyz):
|
77 |
+
"""
|
78 |
+
Embeds x to (x, sin(2^k x), cos(2^k x), ...)
|
79 |
+
Different from the paper, "x" is also in the output
|
80 |
+
See https://github.com/bmild/nerf/issues/12
|
81 |
+
x \in [-2, 2]
|
82 |
+
y \in [-2, 2]
|
83 |
+
z \in [0., 4]
|
84 |
+
Inputs:
|
85 |
+
x: (b n m)
|
86 |
+
Outputs:
|
87 |
+
out: (b n o)
|
88 |
+
"""
|
89 |
+
xyz_n = ((xyz - self.center) / 2.0).to(self.freq_bands.dtype)
|
90 |
+
xyz_feq = xyz_n.unsqueeze(-1) * self.freq_bands # (b n m 1)
|
91 |
+
sin_xyz, cos_xyz = torch.sin(xyz_feq), torch.cos(xyz_feq) # (b n m nf)
|
92 |
+
encoding = torch.cat([xyz_n.unsqueeze(-1), sin_xyz, cos_xyz], -1).reshape(*xyz.shape[:2], -1)
|
93 |
+
return encoding
|
94 |
+
|
95 |
+
def forward(self, xyz):
|
96 |
+
"""Forward pass, xyz is (B, N, 3or6), output (B, N, F)."""
|
97 |
+
freq_encoding = self.frequency_encoding(xyz)
|
98 |
+
position_embedding = self.position_embedding_head(freq_encoding)
|
99 |
+
return position_embedding
|
100 |
+
|
101 |
+
def process_zoe(pixel_values, pad_mode="reflect", output_size=(384, 512)):
|
102 |
+
"""https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/zoedepth/image_processing_zoedepth.py"""
|
103 |
+
# h, w = images.shape[-2:]
|
104 |
+
# pad
|
105 |
+
ph, pw = 31, 31 # int((h / 2)**0.5 * 3), int((w / 2)**0.5 * 3) # 32, 31
|
106 |
+
images = F.pad(pixel_values, (pw, pw, ph, ph), mode=pad_mode)
|
107 |
+
# resize
|
108 |
+
size = (384, 384) # get_resize_output_image_size
|
109 |
+
images = F.interpolate(images, size=size, mode="bicubic", align_corners=True)
|
110 |
+
# zoe: padding -> resize -> nomalize. we follow `nomalize -> padding -> resize` from siglip
|
111 |
+
images = TF.normalize(images, mean=ZOE_MEAN, std=ZOE_STD)
|
112 |
+
return images, ph, pw
|
113 |
|
114 |
@dataclass
|
115 |
class SpatialVLACausalLMOutputWithPast(ModelOutput):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
loss: Optional[torch.FloatTensor] = None
|
117 |
logits: torch.FloatTensor = None
|
118 |
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None
|
|
|
120 |
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
121 |
image_hidden_states: Optional[torch.FloatTensor] = None
|
122 |
|
|
|
123 |
class SpatialVLAMultiModalProjector(nn.Module):
|
124 |
def __init__(self, config: SpatialVLAConfig):
|
125 |
super().__init__()
|
|
|
127 |
|
128 |
def forward(self, image_features):
|
129 |
hidden_states = self.linear(image_features)
|
|
|
130 |
return hidden_states
|
131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
class SpatialVLAPreTrainedModel(PreTrainedModel):
|
133 |
config_class = SpatialVLAConfig
|
134 |
base_model_prefix = "model"
|
|
|
143 |
_supports_sdpa = True
|
144 |
|
145 |
def _init_weights(self, module):
|
|
|
|
|
146 |
std = (
|
147 |
self.config.initializer_range
|
148 |
if hasattr(self.config, "initializer_range")
|
|
|
161 |
if module.padding_idx is not None:
|
162 |
module.weight.data[module.padding_idx].zero_()

class SpatialVLAForConditionalGeneration(SpatialVLAPreTrainedModel, GenerationMixin):
    def __init__(self, config: SpatialVLAConfig, vision_model=None, vision_zoe_model=None, projector_model=None, language_model=None):
        super().__init__(config)
+
        self.vision_tower = vision_model or AutoModel.from_config(config=config.vision_config)
        self.multi_modal_projector = projector_model or SpatialVLAMultiModalProjector(config)
        self.vocab_size = config.text_config.vocab_size
        if language_model is None:
+            language_model = Gemma2ForCausalLM(config=config.text_config)
        if language_model._tied_weights_keys is not None:
            self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
        self.language_model = language_model

        if config.use_vision_zoe:
            self.vision_zoe_model = vision_zoe_model or ZoeDepthForDepthEstimation(config.vision_zoe_config)
            self.position_embedding_3d = Ego3DPositionEmbeddingMLP(
                config.ego3d_patch_reso**2 * 3, num_pos_feats=config.vision_config.hidden_size, n_freqs=config.n_freqs

            uv_h = torch.stack([x, y, torch.ones_like(x)], dim=0).reshape(3, -1)  # (3 hw)
            self.register_buffer("uv_h", uv_h, persistent=False)

+        # shared spatial embeddings for <ACTION> <IMG>
        if config.use_spatial_token:
            self.spatial_embed_tokens = nn.Embedding(self.config.spatial_token_num, config.text_config.hidden_size)
        else:
            self.spatial_embed_tokens = None
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1

    def backproject_patch(self, K: torch.Tensor, depth: torch.Tensor, patch_size=14, reso=2) -> torch.Tensor:
        """
        Args:
            K: camera intrinsic matrix (b 3 3)
            depth: depth map (b 1 h w)
+            patch_size: patch size for siglip
+            reso: reso^2 -> sample points in each patch
+            patch sz = 14 ......
+            ┌────────┬────────┐
+            │  ─  ─  │  ─  ─  │
+            │ points │        ├─ ─ ─
+            │  ─  ─  │  ─  ─  │
+            ├────────┼────────┤
+            │  ─  ─  │  ─  ─  │
+            │        │        │
+            │  ─  ─  │  ─  ─  │
+            └────────┴────────┘
+            reso=2───►points=4
+                │
+                │
        """
        b, c, h, w = depth.shape
        hp, wp = h // patch_size, w // patch_size
        sub_hp = sub_wp = reso
+        patch_depth = F.interpolate(depth, size=(hp * reso, wp * reso), mode="area").reshape(b, c, -1)
        p_cam = (inv(K.float()) @ self.uv_h.float()) * patch_depth  # (b 3 3) @ (3 hw) -> (b 3 hw) * (b 1 hw) -> (b 3 hw)
        patch_p_cam = p_cam.reshape(b, 3, hp, sub_hp, wp, sub_wp).permute(0, 2, 4, 3, 5, 1).reshape(b, hp * wp, -1)
        return patch_p_cam
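A standalone sketch of the same pinhole back-projection, with made-up intrinsics and a constant dummy depth (shapes follow the docstring above):

# Minimal pinhole back-projection sketch: lift per-patch depths to 3-D camera-frame
# points with p_cam = K^-1 @ [u, v, 1]^T * depth. Shapes mirror backproject_patch above.
import torch

b, h, w, patch_size, reso = 1, 224, 224, 14, 2
hp, wp = h // patch_size, w // patch_size                      # 16 x 16 patches
K = torch.tensor([[[250.0, 0.0, 112.0],
                   [0.0, 250.0, 112.0],
                   [0.0, 0.0, 1.0]]])                          # assumed intrinsics (b 3 3)
# homogeneous pixel coordinates at the sub-patch sample positions: (3, hp*reso*wp*reso)
ys, xs = torch.meshgrid(
    torch.linspace(0, h - 1, hp * reso), torch.linspace(0, w - 1, wp * reso), indexing="ij"
)
uv_h = torch.stack([xs, ys, torch.ones_like(xs)], dim=0).reshape(3, -1)
depth = torch.full((b, 1, hp * reso * wp * reso), 0.5)         # dummy 0.5 m everywhere
p_cam = (torch.linalg.inv(K) @ uv_h) * depth                   # (b, 3, n) camera-frame points
print(p_cam.shape)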

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        return self.language_model.get_decoder()

    def tie_weights(self):
        return self.language_model.tie_weights()

        pad_to_multiple_of: Optional[int] = None,
        mean_resizing: bool = True,
    ) -> nn.Embedding:
        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
        vocab_size = model_embeds.weight.shape[0]
        self.config.text_config.vocab_size = self.vocab_size = self.config._vocab_size = vocab_size
        self.tie_weights()

        )

        if attention_mask is not None and attention_mask.dim() == 4:
            return attention_mask

+        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device)
        if sequence_length != 1:
+            if is_training: causal_mask = torch.triu(causal_mask, diagonal=1)
+            else: causal_mask[:, :sequence_length] = 0.0

        causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)

            mask_length = attention_mask.shape[-1]
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
            padding_mask = padding_mask == 0
+            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype)
            if is_training:
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0)
        return causal_mask
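During training this produces a PaliGemma-style prefix-LM mask: prefix tokens (token_type_ids == 0, including image tokens) attend bidirectionally while suffix tokens stay causal. A toy illustration with assumed shapes:

# Toy prefix-LM mask: positions with token_type_ids == 0 (prefix) are fully visible,
# the rest stay causal. Mirrors the triu + masked_fill logic above on a 1x6 example.
import torch

seq_len, min_dtype = 6, torch.finfo(torch.float32).min
token_type_ids = torch.tensor([[0, 0, 0, 0, 1, 1]])                # 4 prefix + 2 suffix tokens
mask = torch.full((seq_len, seq_len), min_dtype)
mask = torch.triu(mask, diagonal=1)                                # causal part
mask = mask[None, None]                                            # (b, 1, q, k)
mask = mask.masked_fill(token_type_ids[:, None, None, :] == 0, 0)  # open up the prefix columns
print((mask == 0).int()[0, 0])                                     # 1 = attendable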

    def get_image_features(self, pixel_values: torch.FloatTensor, intrinsic: torch.FloatTensor):
+        siglip_pixel_values = TF.normalize(pixel_values, mean=SIGLIP_MEAN, std=SIGLIP_STD)
        image_outputs = self.vision_tower(siglip_pixel_values)

        # ego3d position encoding
        with torch.no_grad():
            pvh, pvw = pixel_values.shape[-2:]
            depth = self.vision_zoe_model(pixel_values=zoe_pixel_values).predicted_depth
+            depth = F.interpolate(
                depth.unsqueeze(1),
                size=(pvh + 2 * ph, pvw + 2 * pw),
                mode="bicubic",
                align_corners=True,
            )[..., ph:-ph, pw:-pw]
        xyz = self.backproject_patch(
            intrinsic, depth, patch_size=self.config.vision_config.patch_size, reso=self.config.ego3d_patch_reso
        )  # (b, n, 3*4)

        image_features = image_features / (self.config.text_config.hidden_size**0.5)
        return image_features
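The collapsed lines between the back-projection and the final scaling turn the 3-D patch coordinates into position embeddings and fold them into the SigLIP features. A rough sketch of that pattern, assumed from position_embedding_3d and multi_modal_projector defined above rather than copied from the repo:

# Hypothetical merge step (an assumption, not the verbatim code): encode per-patch
# 3-D points and add them to the vision features before projecting to the LM width.
# xyz: (b, n, reso**2 * 3) from backproject_patch; vision_feats: (b, n, vision_hidden)
def merge_ego3d(vision_feats, xyz, position_embedding_3d, multi_modal_projector):
    pos_embed_3d = position_embedding_3d(xyz)          # (b, n, vision_hidden)
    feats = vision_feats + pos_embed_3d                # inject metric 3-D structure
    return multi_modal_projector(feats)                # (b, n, text_hidden)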

    def forward(
        self,
        input_ids: torch.LongTensor = None,

        return_dict: Optional[bool] = None,
        num_logits_to_keep: int = 0,
    ) -> Union[Tuple, SpatialVLACausalLMOutputWithPast]:

+        output_attentions = output_attentions or self.config.output_attentions
+        output_hidden_states = output_hidden_states or self.config.output_hidden_states
+        return_dict = return_dict or self.config.use_return_dict

        is_training = token_type_ids is not None and labels is not None

+        if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids).clone()  # clone to avoid grad issues with checkpointing
+
        if self.config.use_spatial_token:
            spatial_selected = (input_ids >= self.config.action_token_begin_idx) & (input_ids < self.config.action_token_begin_idx + self.config.spatial_token_num)
            inputs_embeds[spatial_selected] = inputs_embeds[spatial_selected] * 0.0 + self.spatial_embed_tokens(input_ids[spatial_selected] - self.config.action_token_begin_idx)
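A toy version of this action-token override: ids inside the reserved <ACTION> range are re-embedded from the dedicated spatial table instead of the LM embedding matrix (tiny made-up sizes):

# Toy spatial-token override: rows of a dedicated embedding table replace the LM
# embeddings for ids inside the action-token range.
import torch
import torch.nn as nn

vocab, hidden, action_begin, num_action = 100, 8, 90, 10        # assumed tiny sizes
embed_tokens = nn.Embedding(vocab, hidden)                       # stand-in LM embeddings
spatial_embed_tokens = nn.Embedding(num_action, hidden)          # shared spatial table

input_ids = torch.tensor([[5, 91, 92, 7]])
inputs_embeds = embed_tokens(input_ids).clone()
sel = (input_ids >= action_begin) & (input_ids < action_begin + num_action)
inputs_embeds[sel] = spatial_embed_tokens(input_ids[sel] - action_begin)
print(sel)  # tensor([[False,  True,  True, False]])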

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device)

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0) + 1  # Paligemma positions are 1-indexed

+        # merge
        if pixel_values is not None:
            image_features = self.get_image_features(pixel_values, intrinsic)
            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
            if inputs_embeds[special_image_mask].numel() != image_features.numel():

        logits = outputs.logits
        loss = None
        if labels is not None:
            logits = logits.float()
            shift_logits = logits[..., :-1, :]
            shift_labels = labels[..., 1:]
            if attention_mask is not None:
                shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
                shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
                shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
            else:
                shift_logits = shift_logits.contiguous()
                shift_labels = shift_labels.contiguous()
            loss_fct = nn.CrossEntropyLoss()

            flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)

            image_hidden_states=image_features if pixel_values is not None else None,
        )

+    # AR inference
    def prepare_inputs_for_generation(
        self,
        input_ids,

        labels=None,
        **kwargs,
    ):
        model_inputs = self.language_model.prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,

            token_type_ids=token_type_ids,
            **kwargs,
        )
        if model_inputs.get("position_ids") is not None:
            model_inputs["position_ids"] += 1
        if cache_position[0] == 0:
            model_inputs["pixel_values"] = pixel_values
        is_training = token_type_ids is not None and labels is not None
        if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
+            causal_mask = self._update_causal_mask(attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training)
            model_inputs["attention_mask"] = causal_mask
        model_inputs["intrinsic"] = intrinsic
        return model_inputs

            weights_only=weights_only,
            **kwargs,
        )
        if model.config.use_spatial_token:
            model.language_model.model.embed_tokens.weight.data[-model.config.spatial_token_num:] = model.spatial_embed_tokens.weight.data
        return model
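This post-load hook keeps the last spatial_token_num rows of the LM embedding matrix in sync with spatial_embed_tokens; a tiny illustration of the row copy (made-up sizes):

# Toy row-copy: the dedicated spatial/action embedding table overwrites the last
# rows of the language model's embedding matrix after weights are loaded.
import torch
import torch.nn as nn

vocab, hidden, spatial_token_num = 100, 8, 10
embed_tokens = nn.Embedding(vocab, hidden)
spatial_embed_tokens = nn.Embedding(spatial_token_num, hidden)

embed_tokens.weight.data[-spatial_token_num:] = spatial_embed_tokens.weight.data
assert torch.equal(embed_tokens.weight[-1], spatial_embed_tokens.weight[-1])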
processing_spatialvla.py
CHANGED
@@ -1,142 +1,38 @@

-# MIT License
-# Copyright (c) 2025 IPEC at Shanghai AI Laboratory
-# Permission is hereby granted, free of charge, to use, copy, modify, merge, publish,
-# distribute, sublicense, and/or sell copies of the Software, subject to the following conditions:
-# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND.
-# Based on code licensed under the Apache License, Version 2.0 by Google Inc. and HuggingFace Inc. team (Copyright 2024).
# coding=utf-8
-
-
-
-
-
import logging
from typing import List, Optional, Union, Dict
-import torch
import numpy as np
-
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput, is_valid_image
-from transformers.processing_utils import (
-    ImagesKwargs,
-    ProcessingKwargs,
-    ProcessorMixin,
-    TextKwargs,
-    Unpack,
-    _validate_images_text_input_order,
-)
-from transformers.tokenization_utils_base import (
-    AddedToken,
-    PreTokenizedInput,
-    TextInput,
-)
from transformers.utils import logging
-from .
-
logger = logging.get_logger(__name__)

-IMAGE_TOKEN = "<image>"
-EXTRA_TOKENS = [f"<loc{i:0>4}>" for i in range(1024)] + [f"<seg{i:0>3}>" for i in range(128)]
-
-
-class PaliGemmaTextKwargs(TextKwargs):
-    suffix: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]]
-
-
-class PaliGemmaImagesKwargs(ImagesKwargs):
-    do_convert_rgb: Optional[bool]
-
-
-class PaliGemmaProcessorKwargs(ProcessingKwargs, total=False):
-    text_kwargs: PaliGemmaTextKwargs
-    images_kwargs: PaliGemmaImagesKwargs
-    _defaults = {
-        "text_kwargs": {
-            "padding": False,
-        },
-        "images_kwargs": {
-            "data_format": "channels_first",
-        },
-    }
-
-
-# Copied from transformers.models.idefics2.processing_idefics2.is_url
-def is_url(val) -> bool:
-    return isinstance(val, str) and val.startswith("http")
-
-
-# Copied from transformers.models.idefics2.processing_idefics2.is_image_or_image_url
-def is_image_or_image_url(elem):
-    return is_url(elem) or is_valid_image(elem)
-
-
-def _is_str_or_image(elem):
-    return isinstance(elem, (str)) or is_image_or_image_url(elem)
-
-
-def build_string_from_input(prompt, bos_token, image_seq_len, image_token, num_images):
-    """
-    Builds a string from the input prompt and image tokens.
-    For example, for the call:
-    build_string_from_input(
-        prompt="Prefix str"
-        bos_token="<s>",
-        image_seq_len=3,
-        image_token="<im>",
-    )
-    The output will be:
-    "<im><im><im><s>Initial str"
-    Args:
-        prompt (`List[Union[str, ImageInput]]`): The input prompt.
-        bos_token (`str`): The beginning of sentence token.
-        image_seq_len (`int`): The length of the image sequence.
-        image_token (`str`): The image token.
-        num_images (`int`): Number of images in the prompt.
-    """
-    return f"{image_token * image_seq_len * num_images}{bos_token}{prompt}\n"
-
-
-# Copied from transformers.models.llava_next.image_processing_llava_next.make_batched_images
-def make_batched_images(images) -> List[List[ImageInput]]:
-    """
-    Accepts images in list or nested list format, and makes a list of images for preprocessing.
-
-    Args:
-        images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
-            The input image.
-
-    Returns:
-        list: A list of images.
-    """
-    if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]):
-        return [img for img_list in images for img in img_list]
-
-    elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
-        return images
-
-    elif is_valid_image(images):
-        return [images]
-
-    raise ValueError(f"Could not make batched video from {images}")
-
-
class SpatialVLAProcessor(ProcessorMixin):
-    r"""
-    Constructs a PaliGemma processor which wraps a PaliGemma image processor and a PaliGemma tokenizer into a single processor.
-
-    [`PaliGemmaProcessor`] offers all the functionalities of [`SiglipImageProcessor`] and [`LlamaTokenizerFast`]. See the
-    [`~PaliGemmaProcessor.__call__`] and [`~PaliGemmaProcessor.decode`] for more information.
-
-    Args:
-        image_processor ([`SiglipImageProcessor`], *optional*):
-            The image processor is a required input.
-        tokenizer ([`LlamaTokenizerFast`], *optional*):
-            The tokenizer is a required input.
-        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
-            in a chat into a tokenizable string.
-    """
-
    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = ["chat_template"]
    image_processor_class = "SiglipImageProcessor"

@@ -192,17 +88,13 @@ class SpatialVLAProcessor(ProcessorMixin):

        self.dataset_intrinsics = {}
        height, width = image_processor.size["height"], image_processor.size["width"]

        for k, v in intrinsic_config.items():
            K = torch.tensor(v["intrinsic"]).float()
-
-            K[0, 0] *= width / w
-            K[1, 1] *= height / h
-            K[0, 2] *= width / w
-            K[1, 2] *= height / h
            self.dataset_intrinsics[k] = K
-            print(f"scale intrinsic of {k} from {v['intrinsic']} to {K} ...")

-        self.action_tokenizer =
            tokenizer=tokenizer, num_bins=action_config["num_bins"],
            bin_policy=bin_policy, use_spherical=action_config["use_spherical"],
            min_sigma=min_sigma,
@@ -212,70 +104,10 @@ class SpatialVLAProcessor(ProcessorMixin):

        self,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
-        audio=None,
-        videos=None,
        unnorm_key: Optional[str] = None,
        suffix_actions: Optional[np.array] = None,  # (t e)
        **kwargs: Unpack[PaliGemmaProcessorKwargs],
    ) -> BatchFeature:
-        """
-        Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
-        and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
-        the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
-        SiglipImageProcessor's [`~SiglipImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring
-        of the above two methods for more information.
-
-        The usage for PaliGemma fine-tuning preparation is slightly different than usual. suffix passed are suffixes to
-        the prompt in `text`, and will be placed after the prompt. This is because attention is handled differently for
-        the prefix and the suffix. For instance,
-        ```python
-        image = PIL_cow_image
-        prompt = "answer en Where is the cow standing?"
-        suffix = "on the beach"
-        inputs = processor(text=prompt, images=image, suffix=suffix)
-        ```
-        Here `inputs` will contain the `input_ids` and `token_type_ids` that follow
-        ```python
-        inputs["input_ids"][:, 256:]
-        # tensor([[ 2, 6006, 603, 573, 13910, 9980, 235336, 108, 477, 573, 8318]])
-        inputs["token_type_ids"][:, 256:]
-        tensor([[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1]])
-        ```
-        Meaning the last three tokens are of "label" ("suffix") type while the other ones are of "prefix" type.
-
-
-        Args:
-            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
-                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
-                number of channels, H and W are image height and width.
-            text (`str`, `List[str]`, `List[List[str]]`):
-                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
-                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
-                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
-            return_tensors (`str` or [`~utils.TensorType`], *optional*):
-                If set, will return tensors of a particular framework. Acceptable values are:
-
-                - `'tf'`: Return TensorFlow `tf.constant` objects.
-                - `'pt'`: Return PyTorch `torch.Tensor` objects.
-                - `'np'`: Return NumPy `np.ndarray` objects.
-                - `'jax'`: Return JAX `jnp.ndarray` objects.
-            suffix (`str`, `List[str]`, `List[List[str]]`):
-                The suffixes or batch of suffixes to be encoded. Only necessary for finetuning. See https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md
-                for more information. If your prompt is "<image> What is on the image", the suffix corresponds to the expected prediction "a cow sitting on a bench".
-
-        Returns:
-            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
-
-            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. If `suffix`
-              is provided, the `input_ids` will also contain the suffix input ids.
-            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
-              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
-              `None`).
-            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
-            - **labels** -- Labels compatible with training if `suffix` is not None
-        """
-        # check if images and text inputs are reversed for BC
        images, text = _validate_images_text_input_order(images, text)

        output_kwargs = self._merge_kwargs(
@@ -294,9 +126,7 @@ class SpatialVLAProcessor(ProcessorMixin):

        if images is None:
            raise ValueError("`images` are expected as arguments to a `PaliGemmaProcessor` instance.")
        if text is None:
-            logger.warning_once(
-                "You are using PaliGemma without a text prefix. It will perform as a picture-captioning model."
-            )
            text = ""

        if _is_str_or_image(text):
@@ -306,31 +136,19 @@ class SpatialVLAProcessor(ProcessorMixin):

        if text is not None and images is not None:
            if not any(IMAGE_TOKEN in sample for sample in text):
-                # logger.warning(
-                #     "You are passing both `text` and `images` to `PaliGemmaProcessor`. The processor expects special "
-                #     "image tokens in the text, as many tokens as there are images per each text. It is recommended to "
-                #     "add `<image>` tokens in the very beginning of your text. For this call, we will infer how many images "
-                #     "each text has and add special tokens."
-                # )
                if isinstance(text, List) and isinstance(images, List):
                    if len(images) != len(text):
                        raise ValueError(
                            f"Received {len(images)} images for {len(text)} prompts. Each prompt should be associated with an image or list of images."
                        )
-
-                # make a nested list of lists to be able to iterate over the images and text below
                if is_valid_image(images):
                    images = [[images]]
                elif isinstance(images, list) and is_valid_image(images[0]):
                    images = [[image] for image in images]
                elif not (isinstance(images, list) and isinstance(images[0], list) and is_valid_image(images[0][0])):
                    raise ValueError("images must be an image, list of images or list of list of images")
-
-        if suffix is not None and _is_str_or_image(suffix):
-            suffix = [suffix]
-        if suffix is not None:
-            suffix = [sfx + self.tokenizer.eos_token for sfx in suffix]
-
        input_strings = [
            build_string_from_input(
                prompt=prompt,
@@ -355,7 +173,6 @@ class SpatialVLAProcessor(ProcessorMixin):

        input_strings = [f"{sample}\n" for sample in expanded_samples]
        pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]

-        # max_length has to account for the image tokens
        if output_kwargs["text_kwargs"].get("max_length", None) is not None:
            output_kwargs["text_kwargs"]["max_length"] += self.image_seq_length

@@ -391,7 +208,6 @@ class SpatialVLAProcessor(ProcessorMixin):

        return self.tokenizer.decode(*args, **kwargs)

    @property
-    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->PaliGemma
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
@@ -407,7 +223,7 @@ class SpatialVLAProcessor(ProcessorMixin):

        assert self.tokenizer.eos_token != predicted_action_token_ids[-1], "[error] actions contain EOS token, please check your truncation settings!"

        if predicted_action_token_ids.shape[0] < action_token_num * self.action_chunk_size:  # pad with zeros
-
            predicted_action_token_ids = np.concatenate(
                [
                    predicted_action_token_ids,
@@ -417,9 +233,8 @@ class SpatialVLAProcessor(ProcessorMixin):

        predicted_action_token_ids = predicted_action_token_ids.reshape(-1, action_token_num)
        normalized_action_chunks = self.action_tokenizer.decode_token_ids_to_actions(predicted_action_token_ids)

-        # Unnormalize actions
        if unnorm_key is None:
-
            unnorm_key = next(iter(self.statistics.keys()))
        action_norm_stats = self.statistics[unnorm_key]["action"]

# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import logging
from typing import List, Optional, Union, Dict
import numpy as np
+import torch
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput, is_valid_image
+from transformers.processing_utils import Unpack, _validate_images_text_input_order, ProcessorMixin
+from transformers.tokenization_utils_base import AddedToken, PreTokenizedInput, TextInput
from transformers.utils import logging
+from transformers.models.paligemma.processing_paligemma import (
+    make_batched_images,
+    build_string_from_input,
+    _is_str_or_image,
+    PaliGemmaProcessorKwargs,
+    IMAGE_TOKEN,
+    EXTRA_TOKENS
+)
+from .action_tokenizer import SpatialActionTokenizer
logger = logging.get_logger(__name__)

class SpatialVLAProcessor(ProcessorMixin):
    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = ["chat_template"]
    image_processor_class = "SiglipImageProcessor"

        self.dataset_intrinsics = {}
        height, width = image_processor.size["height"], image_processor.size["width"]

+        # scale intrinsic matrix
        for k, v in intrinsic_config.items():
            K = torch.tensor(v["intrinsic"]).float()
+            K[:2] *= torch.tensor([width / v["width"], height / v["height"]])[:, None]
            self.dataset_intrinsics[k] = K
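The one-liner above rescales fx/cx by the width ratio and fy/cy by the height ratio when the raw camera frame is resized to the processor input size. A quick numeric check with made-up intrinsics:

# Numeric sanity check for the intrinsic rescaling above: resizing a 640x480 frame to
# 224x224 multiplies row 0 (fx, 0, cx) by 224/640 and row 1 (0, fy, cy) by 224/480.
import torch

v = {"intrinsic": [[600.0, 0.0, 320.0], [0.0, 600.0, 240.0], [0.0, 0.0, 1.0]],
     "width": 640, "height": 480}
width = height = 224
K = torch.tensor(v["intrinsic"]).float()
K[:2] *= torch.tensor([width / v["width"], height / v["height"]])[:, None]
print(K[0, 0], K[1, 1], K[0, 2], K[1, 2])  # fx=210.0, fy=280.0, cx=112.0, cy=112.0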

+        self.action_tokenizer = SpatialActionTokenizer(
            tokenizer=tokenizer, num_bins=action_config["num_bins"],
            bin_policy=bin_policy, use_spherical=action_config["use_spherical"],
            min_sigma=min_sigma,

        self,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        unnorm_key: Optional[str] = None,
        suffix_actions: Optional[np.array] = None,  # (t e)
        **kwargs: Unpack[PaliGemmaProcessorKwargs],
    ) -> BatchFeature:
        images, text = _validate_images_text_input_order(images, text)

        output_kwargs = self._merge_kwargs(

        if images is None:
            raise ValueError("`images` are expected as arguments to a `PaliGemmaProcessor` instance.")
        if text is None:
+            logger.warning_once("You are using PaliGemma without a text prefix. It will perform as a picture-captioning model.")
            text = ""

        if _is_str_or_image(text):

        if text is not None and images is not None:
            if not any(IMAGE_TOKEN in sample for sample in text):
                if isinstance(text, List) and isinstance(images, List):
                    if len(images) != len(text):
                        raise ValueError(
                            f"Received {len(images)} images for {len(text)} prompts. Each prompt should be associated with an image or list of images."
                        )
                if is_valid_image(images):
                    images = [[images]]
                elif isinstance(images, list) and is_valid_image(images[0]):
                    images = [[image] for image in images]
                elif not (isinstance(images, list) and isinstance(images[0], list) and is_valid_image(images[0][0])):
                    raise ValueError("images must be an image, list of images or list of list of images")
+        if suffix is not None and _is_str_or_image(suffix): suffix = [suffix]
+        if suffix is not None: suffix = [sfx + self.tokenizer.eos_token for sfx in suffix]
        input_strings = [
            build_string_from_input(
                prompt=prompt,

        input_strings = [f"{sample}\n" for sample in expanded_samples]
        pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]

        if output_kwargs["text_kwargs"].get("max_length", None) is not None:
            output_kwargs["text_kwargs"]["max_length"] += self.image_seq_length

        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names

        assert self.tokenizer.eos_token != predicted_action_token_ids[-1], "[error] actions contain EOS token, please check your truncation settings!"

        if predicted_action_token_ids.shape[0] < action_token_num * self.action_chunk_size:  # pad with zeros
+            logger.warning("Padding zero action!")
            predicted_action_token_ids = np.concatenate(
                [
                    predicted_action_token_ids,

        predicted_action_token_ids = predicted_action_token_ids.reshape(-1, action_token_num)
        normalized_action_chunks = self.action_tokenizer.decode_token_ids_to_actions(predicted_action_token_ids)

        if unnorm_key is None:
+            logger.warning(f"unnorm_key {unnorm_key} is not provided, falling back to the first key in statistics")
            unnorm_key = next(iter(self.statistics.keys()))
        action_norm_stats = self.statistics[unnorm_key]["action"]

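decode_actions then maps the normalized values back to metric actions using these dataset statistics; a sketch of the usual unnormalization step, assuming OXE-style q01/q99 bounds (the exact field names in statistics may differ):

# Hypothetical unnormalization sketch: map actions from [-1, 1] back to the dataset's
# value range using per-dimension 1st/99th-percentile bounds; field names are assumed.
import numpy as np

normalized_actions = np.array([[0.0, 0.5, -1.0]])          # one 3-dof chunk in [-1, 1]
action_norm_stats = {"q01": np.array([-0.1, -0.2, -0.3]),
                     "q99": np.array([0.1, 0.2, 0.3])}
low, high = action_norm_stats["q01"], action_norm_stats["q99"]
actions = 0.5 * (normalized_actions + 1.0) * (high - low) + low
print(actions)  # [[ 0.   0.1 -0.3]]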
test_huggingface.py
CHANGED
@@ -1,17 +1,12 @@

import os
import argparse
from pathlib import Path
-import shutil
-import os
-import argparse
-from pathlib import Path
-import shutil
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

parser = argparse.ArgumentParser("Huggingface AutoModel Testing")
-parser.add_argument("--model_name_or_path", default="", help="pretrained model name or path.")
parser.add_argument("--num_images", type=int, default=1, help="num_images for testing.")

args = parser.parse_args()

@@ -32,4 +27,4 @@ if __name__ == "__main__":

    print(generation_outputs, processor.batch_decode(generation_outputs))

    actions = processor.decode_actions(generation_outputs, unnorm_key="bridge_orig/1.0.0")
-    print(actions)

import os
import argparse
from pathlib import Path
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

parser = argparse.ArgumentParser("Huggingface AutoModel Testing")
+parser.add_argument("--model_name_or_path", default=".", help="pretrained model name or path.")
parser.add_argument("--num_images", type=int, default=1, help="num_images for testing.")

args = parser.parse_args()

    print(generation_outputs, processor.batch_decode(generation_outputs))

    actions = processor.decode_actions(generation_outputs, unnorm_key="bridge_orig/1.0.0")
+    print(actions)
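The collapsed middle of this script loads the checkpoint and runs generation; a rough sketch of that flow with stock transformers APIs and the imports/args shown above (the repo's actual prompt, image source, and generation settings may differ):

# Hypothetical reconstruction of the elided steps: load with trust_remote_code and
# run a short greedy generation. "example.png" and the prompt are placeholders.
model = AutoModel.from_pretrained(args.model_name_or_path, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(args.model_name_or_path, trust_remote_code=True)
image = Image.open("example.png")                               # any RGB frame
prompt = "What action should the robot take to pick up the cup?"
inputs = processor(images=image, text=prompt, return_tensors="pt")
generation_outputs = model.generate(**inputs, max_new_tokens=32, do_sample=False)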
|