Tags: Image-Text-to-Text · Safetensors · openvla · custom_code
emrys-hong committed · verified
Commit 74cbaa3 · 1 Parent(s): 3debc13

Update processing_prismatic.py

Files changed (1):
  processing_prismatic.py  +74 -17
processing_prismatic.py CHANGED
@@ -15,12 +15,21 @@ from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
 from transformers import PreTrainedTokenizerBase
 from transformers.image_processing_utils import BatchFeature, ImageProcessingMixin
 from transformers.processing_utils import ProcessorMixin
-from transformers.tokenization_utils import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
+from transformers.tokenization_utils import (
+    PaddingStrategy,
+    PreTokenizedInput,
+    TextInput,
+    TruncationStrategy,
+)
 from transformers.utils import TensorType
 
+from .gripper_position import get_gripper_pos_raw
+
 
 # === Image Processing ===
-def letterbox_pad_transform(image: Image.Image, padding_fill_value: Tuple[int, int, int]) -> Image.Image:
+def letterbox_pad_transform(
+    image: Image.Image, padding_fill_value: Tuple[int, int, int]
+) -> Image.Image:
     """Given a PIL.Image, pad to square by adding a symmetric border around the height/width."""
     (w, h), max_wh = image.size, max(image.size)
     horizontal_pad, vertical_pad = int((max_wh - w) / 2), int((max_wh - h) / 2)
@@ -62,10 +71,19 @@ class PrismaticImageProcessor(ImageProcessingMixin):
         stds = [(0.5, 0.5, 0.5)] if stds is None else stds
 
         # TIMM `data_cfg` Parameters
-        self.input_sizes, self.interpolations, self.means, self.stds = input_sizes, interpolations, means, stds
+        self.input_sizes, self.interpolations, self.means, self.stds = (
+            input_sizes,
+            interpolations,
+            means,
+            stds,
+        )
 
         # Grab torchvision transforms via TIMM =>> need to parse for specific "functional" transform values!
-        self.tvf_resize_params, self.tvf_crop_params, self.tvf_normalize_params = [], [], []
+        self.tvf_resize_params, self.tvf_crop_params, self.tvf_normalize_params = (
+            [],
+            [],
+            [],
+        )
         self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None
 
         for idx in range(len(input_sizes)):
@@ -90,11 +108,17 @@ class PrismaticImageProcessor(ImageProcessingMixin):
                 and (transform.transforms[0].size == self.input_sizes[idx][-1])
                 and (transform.transforms[1].size == self.input_sizes[idx][-2:])
             ):
-                raise ValueError(f"Unexpected TIMM image transformation structure/sizes: `{transform}`")
+                raise ValueError(
+                    f"Unexpected TIMM image transformation structure/sizes: `{transform}`"
+                )
 
             # HF Image Processors *must* be JSON-serializable; as such, cannot have torchvision. as an attribute.
             # => Instead, we're going to parse the transform and call "torchvision.transforms.functional" (`tvf`)
-            resize_t, crop_t, norm_t = transform.transforms[0], transform.transforms[1], transform.transforms[3]
+            resize_t, crop_t, norm_t = (
+                transform.transforms[0],
+                transform.transforms[1],
+                transform.transforms[3],
+            )
             self.tvf_resize_params.append(
                 {
                     "size": resize_t.size,
@@ -117,11 +141,15 @@ class PrismaticImageProcessor(ImageProcessingMixin):
             if self.image_resize_strategy == "resize-naive":
                 self.tvf_resize_params[idx]["size"] = (resize_t.size, resize_t.size)
             elif self.image_resize_strategy == "letterbox":
-                self.tvf_do_letterbox, self.tvf_letterbox_fill = True, tuple([int(x * 255) for x in self.means[idx]])
+                self.tvf_do_letterbox, self.tvf_letterbox_fill = True, tuple(
+                    [int(x * 255) for x in self.means[idx]]
+                )
             elif self.image_resize_strategy == "resize-crop":
                 pass
             else:
-                raise ValueError(f"Image resize strategy `{self.image_resize_strategy}` is not supported!")
+                raise ValueError(
+                    f"Image resize strategy `{self.image_resize_strategy}` is not supported!"
+                )
 
         # Dispatch **kwargs to super()
         super().__init__(**kwargs)
@@ -164,12 +192,19 @@ class PrismaticImageProcessor(ImageProcessingMixin):
             images = [images]
 
         # Apply `self.img_transform` to each image (will return list of torch.Tensors); stack into "batched" Tensor
-        pixel_values = torch.stack([self.apply_transform(img.convert("RGB")) for img in images])
+        pixel_values = torch.stack(
+            [self.apply_transform(img.convert("RGB")) for img in images]
+        )
 
         # Return BatchFeature =>> note that for compatibility, constructor expects Dict[str, np.ndarray], so we convert
-        return BatchFeature(data={"pixel_values": pixel_values.float().numpy()}, tensor_type=return_tensors)
+        return BatchFeature(
+            data={"pixel_values": pixel_values.float().numpy()},
+            tensor_type=return_tensors,
+        )
 
-    def __call__(self, images: Union[Image.Image, List[Image.Image]], **kwargs) -> BatchFeature:
+    def __call__(
+        self, images: Union[Image.Image, List[Image.Image]], **kwargs
+    ) -> BatchFeature:
         return self.preprocess(images, **kwargs)
 
 
@@ -189,7 +224,9 @@ class PrismaticProcessor(ProcessorMixin):
 
     def __call__(
         self,
-        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
+        text: Union[
+            TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
+        ],
         images: Union[Image.Image, List[Image.Image]],
         padding: Union[bool, str, PaddingStrategy] = False,
         truncation: Optional[Union[bool, str, TruncationStrategy]] = None,
@@ -209,21 +246,31 @@ class PrismaticProcessor(ProcessorMixin):
 
         @return: BatchFeature with keys for `input_ids`, `attention_mask` and `pixel_values`.
         """
-        pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"]
+        pixel_values = self.image_processor(images, return_tensors=return_tensors)[
+            "pixel_values"
+        ]
         text_inputs = self.tokenizer(
-            text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
+            text,
+            return_tensors=return_tensors,
+            padding=padding,
+            truncation=truncation,
+            max_length=max_length,
         )
 
         # [Validate] Need same number of images and text inputs!
         if pixel_values.shape[0] != text_inputs.input_ids.shape[0]:
-            raise ValueError("Batch is malformed; expected same number of images and text inputs!")
+            raise ValueError(
+                "Batch is malformed; expected same number of images and text inputs!"
+            )
 
         return BatchFeature(data={**text_inputs, "pixel_values": pixel_values})
 
     # === Tokenizer Dispatch Utilities =>> check `PreTrainedTokenizerBase` for documentation ===
     def batch_decode(
         self,
-        sequences: Union[List[int], List[List[int]], torch.Tensor, Any],  # `Any` = np.ndarray | tf.Tensor
+        sequences: Union[
+            List[int], List[List[int]], torch.Tensor, Any
+        ],  # `Any` = np.ndarray | tf.Tensor
         skip_special_tokens: bool = False,
         clean_up_tokenization_spaces: Optional[bool] = None,
         **kwargs: str,
@@ -237,7 +284,9 @@ class PrismaticProcessor(ProcessorMixin):
 
     def decode(
         self,
-        token_ids: Union[int, List[int], torch.Tensor, Any],  # `Any` = np.ndarray | tf.Tensor
+        token_ids: Union[
+            int, List[int], torch.Tensor, Any
+        ],  # `Any` = np.ndarray | tf.Tensor
         skip_special_tokens: bool = False,
         clean_up_tokenization_spaces: Optional[bool] = None,
         **kwargs: str,
@@ -255,3 +304,11 @@ class PrismaticProcessor(ProcessorMixin):
         image_processor_input_names = self.image_processor.model_input_names
 
         return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
+
+    def get_prompt(self, task_label, image: Image.Image):
+        image = image.convert("RGB")
+        image = image.resize((224, 224))
+        gripper_pos, mask, prediction = get_gripper_pos_raw(image)
+
+        prompt = f"In: What action should the robot take to achieve the instruction\nINSTRUCTION: \n{task_label}\nCURRENT GRIPPER: {list(gripper_pos)}\n\nOut:"
+        return prompt, image
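
The functional change in this commit is the new `PrismaticProcessor.get_prompt` helper and the `get_gripper_pos_raw` import it relies on; the remaining hunks are formatting only. `get_prompt` converts the current camera frame to RGB, resizes it to 224x224, runs the gripper detector, and splices the detected gripper position into the OpenVLA-style prompt. Below is a minimal usage sketch; the placeholder repo id, the `AutoProcessor` loading pattern, and the final `processor(...)` call are illustrative assumptions, not part of this commit.

# Minimal usage sketch for the new helper (assumptions marked in comments).
from PIL import Image
from transformers import AutoProcessor

# Assumption: the processor for this repo loads via AutoProcessor with
# trust_remote_code=True; "<this-model-repo>" is a placeholder, not a real id.
processor = AutoProcessor.from_pretrained("<this-model-repo>", trust_remote_code=True)

frame = Image.open("frame.png")           # current camera frame (any size)
task_label = "pick up the red block"      # natural-language instruction

# get_prompt: RGB-convert + resize to 224x224, detect the gripper via
# get_gripper_pos_raw, and embed the detected position in the template
# "In: What action should the robot take to achieve the instruction\n
#  INSTRUCTION: \n{task_label}\nCURRENT GRIPPER: [...]\n\nOut:"
prompt, frame = processor.get_prompt(task_label, frame)

# The (prompt, frame) pair is then tokenized/preprocessed as usual.
inputs = processor(prompt, frame, return_tensors="pt")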