Emrys-Hong
committed on
Commit
·
1b2ebf2
1
Parent(s):
ae16192
Update
Browse files
- modeling_prismatic.py +5 -8
modeling_prismatic.py
CHANGED
@@ -541,7 +541,7 @@ class EmmaxForActionPrediction(PrismaticForConditionalGeneration):
|
|
541 |
return actions, generated_ids
|
542 |
|
543 |
@torch.inference_mode()
|
544 |
-
def generate_actions(self,
|
545 |
# For now, only support generation with a batch size of 1 for simplicity
|
546 |
# image_transform, tokenizer = self.vision_backbone.image_transform, self.llm_backbone.tokenizer
|
547 |
|
@@ -557,18 +557,15 @@ class EmmaxForActionPrediction(PrismaticForConditionalGeneration):
|
|
557 |
# raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}")
|
558 |
|
559 |
# Invoke super().generate --> taps into `GenerationMixin` which (redirects) to `forward()`
|
560 |
-
autocast_dtype = self.llm_backbone.half_precision_dtype
|
561 |
# with torch.autocast("cuda", dtype=autocast_dtype, enabled=self.enable_mixed_precision_training):
|
562 |
-
with torch.autocast("cuda", dtype=torch.
|
563 |
# fmt: off
|
564 |
generated_ids = self.generate(
|
565 |
-
|
566 |
-
pixel_values=pixel_values, # Shape: [1, 3, res, res] or Dict[str, Shape[1, 3, res, res]]
|
567 |
-
**kwargs
|
568 |
)
|
569 |
# fmt: on
|
570 |
|
571 |
-
generated_text = tokenizer.decode(generated_ids[0, input_ids.shape[1] :], skip_special_tokens=True).strip()
|
572 |
|
573 |
s = solver
|
574 |
actions, reasoning = s.extract_action_policies(generated_text)
|
@@ -586,7 +583,7 @@ class EmmaxForActionPrediction(PrismaticForConditionalGeneration):
|
|
586 |
)
|
587 |
_actions.append(action_norm)
|
588 |
|
589 |
-
return _actions, generated_text
|
590 |
|
591 |
|
592 |
@staticmethod
|
|
|
541 |
return actions, generated_ids
|
542 |
|
543 |
@torch.inference_mode()
|
544 |
+
def generate_actions(self, inputs, tokenizer, **kwargs: str) -> str:
|
545 |
# For now, only support generation with a batch size of 1 for simplicity
|
546 |
# image_transform, tokenizer = self.vision_backbone.image_transform, self.llm_backbone.tokenizer
|
547 |
|
|
|
557 |
# raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}")
|
558 |
|
559 |
# Invoke super().generate --> taps into `GenerationMixin` which (redirects) to `forward()`
|
|
|
560 |
# with torch.autocast("cuda", dtype=autocast_dtype, enabled=self.enable_mixed_precision_training):
|
561 |
+
with torch.autocast("cuda", dtype=torch.bfloat16):
|
562 |
# fmt: off
|
563 |
generated_ids = self.generate(
|
564 |
+
**inputs, **kwargs
|
|
|
|
|
565 |
)
|
566 |
# fmt: on
|
567 |
|
568 |
+
generated_text = tokenizer.decode(generated_ids[0, inputs['input_ids'].shape[1] :], skip_special_tokens=True).strip()
|
569 |
|
570 |
s = solver
|
571 |
actions, reasoning = s.extract_action_policies(generated_text)
|
|
|
583 |
)
|
584 |
_actions.append(action_norm)
|
585 |
|
586 |
+
return _actions[0], generated_text
|
587 |
|
588 |
|
589 |
@staticmethod
|