Image-Text-to-Text
Safetensors
openvla
custom_code
Emrys-Hong committed on
Commit
1b2ebf2
·
1 Parent(s): ae16192
Files changed (1) hide show
  1. modeling_prismatic.py +5 -8
modeling_prismatic.py CHANGED
@@ -541,7 +541,7 @@ class EmmaxForActionPrediction(PrismaticForConditionalGeneration):
541
  return actions, generated_ids
542
 
543
  @torch.inference_mode()
544
- def generate_actions(self, image: Image, prompt_text: str, type: str, **kwargs: str) -> str:
545
  # For now, only support generation with a batch size of 1 for simplicity
546
  # image_transform, tokenizer = self.vision_backbone.image_transform, self.llm_backbone.tokenizer
547
 
@@ -557,18 +557,15 @@ class EmmaxForActionPrediction(PrismaticForConditionalGeneration):
557
  # raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}")
558
 
559
  # Invoke super().generate --> taps into `GenerationMixin` which (redirects) to `forward()`
560
- autocast_dtype = self.llm_backbone.half_precision_dtype
561
  # with torch.autocast("cuda", dtype=autocast_dtype, enabled=self.enable_mixed_precision_training):
562
- with torch.autocast("cuda", dtype=torch.float16):
563
  # fmt: off
564
  generated_ids = self.generate(
565
- input_ids=input_ids, # Shape: [1, seq]
566
- pixel_values=pixel_values, # Shape: [1, 3, res, res] or Dict[str, Shape[1, 3, res, res]]
567
- **kwargs
568
  )
569
  # fmt: on
570
 
571
- generated_text = tokenizer.decode(generated_ids[0, input_ids.shape[1] :], skip_special_tokens=True).strip()
572
 
573
  s = solver
574
  actions, reasoning = s.extract_action_policies(generated_text)
@@ -586,7 +583,7 @@ class EmmaxForActionPrediction(PrismaticForConditionalGeneration):
586
  )
587
  _actions.append(action_norm)
588
 
589
- return _actions, generated_text
590
 
591
 
592
  @staticmethod
 
541
  return actions, generated_ids
542
 
543
  @torch.inference_mode()
544
+ def generate_actions(self, inputs, tokenizer, **kwargs: str) -> str:
545
  # For now, only support generation with a batch size of 1 for simplicity
546
  # image_transform, tokenizer = self.vision_backbone.image_transform, self.llm_backbone.tokenizer
547
 
 
557
  # raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}")
558
 
559
  # Invoke super().generate --> taps into `GenerationMixin` which (redirects) to `forward()`
 
560
  # with torch.autocast("cuda", dtype=autocast_dtype, enabled=self.enable_mixed_precision_training):
561
+ with torch.autocast("cuda", dtype=torch.bfloat16):
562
  # fmt: off
563
  generated_ids = self.generate(
564
+ **inputs, **kwargs
 
 
565
  )
566
  # fmt: on
567
 
568
+ generated_text = tokenizer.decode(generated_ids[0, inputs['input_ids'].shape[1] :], skip_special_tokens=True).strip()
569
 
570
  s = solver
571
  actions, reasoning = s.extract_action_policies(generated_text)
 
583
  )
584
  _actions.append(action_norm)
585
 
586
+ return _actions[0], generated_text
587
 
588
 
589
  @staticmethod