update action inference code
Browse files
- modeling_prismatic.py +6 -5
modeling_prismatic.py
CHANGED
@@ -510,12 +510,13 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
|
|
510 |
|
511 |
# We need to add this special empty token ('') after the colon (':') token in "ASSISTANT:"
|
512 |
# in order for the predictions to match the training configuration and be accurate.
|
513 |
-
input_ids = torch.cat(
|
514 |
-
(input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
|
515 |
-
)
|
|
|
516 |
|
517 |
# Run VLA inference
|
518 |
-
generated_ids = self.generate(input_ids, max_new_tokens=self.get_action_dim(unnorm_key), **kwargs)
|
519 |
|
520 |
# Extract predicted action tokens and translate into (normalized) continuous actions
|
521 |
predicted_action_token_ids = generated_ids[0, -self.get_action_dim(unnorm_key) :].cpu().numpy()
|
@@ -533,7 +534,7 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
|
|
533 |
normalized_actions,
|
534 |
)
|
535 |
|
536 |
-
return actions
|
537 |
|
538 |
@staticmethod
|
539 |
def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
|
|
|
510 |
|
511 |
# We need to add this special empty token ('') after the colon (':') token in "ASSISTANT:"
|
512 |
# in order for the predictions to match the training configuration and be accurate.
|
513 |
+
# NOTE: This is NOT needed for ECoT
|
514 |
+
# input_ids = torch.cat(
|
515 |
+
# (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
|
516 |
+
# )
|
517 |
|
518 |
# Run VLA inference
|
519 |
+
generated_ids = self.generate(input_ids, **kwargs)
|
520 |
|
521 |
# Extract predicted action tokens and translate into (normalized) continuous actions
|
522 |
predicted_action_token_ids = generated_ids[0, -self.get_action_dim(unnorm_key) :].cpu().numpy()
|
|
|
534 |
normalized_actions,
|
535 |
)
|
536 |
|
537 |
+
return actions, generated_ids
|
538 |
|
539 |
@staticmethod
|
540 |
def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
|