verityw committed on
Commit
be477d4
·
1 Parent(s): 31c5444

update action inference code

Browse files
Files changed (1) hide show
  1. modeling_prismatic.py +6 -5
modeling_prismatic.py CHANGED
@@ -510,12 +510,13 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
510
 
511
  # We need to add this special empty token ('') after the colon (':') token in "ASSISTANT:"
512
  # in order for the predictions to match the training configuration and be accurate.
513
- input_ids = torch.cat(
514
- (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
515
- )
 
516
 
517
  # Run VLA inference
518
- generated_ids = self.generate(input_ids, max_new_tokens=self.get_action_dim(unnorm_key), **kwargs)
519
 
520
  # Extract predicted action tokens and translate into (normalized) continuous actions
521
  predicted_action_token_ids = generated_ids[0, -self.get_action_dim(unnorm_key) :].cpu().numpy()
@@ -533,7 +534,7 @@ class OpenVLAForActionPrediction(PrismaticForConditionalGeneration):
533
  normalized_actions,
534
  )
535
 
536
- return actions
537
 
538
  @staticmethod
539
  def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
 
510
 
511
  # We need to add this special empty token ('') after the colon (':') token in "ASSISTANT:"
512
  # in order for the predictions to match the training configuration and be accurate.
513
+ # NOTE: This is NOT needed for ECoT
514
+ # input_ids = torch.cat(
515
+ # (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
516
+ # )
517
 
518
  # Run VLA inference
519
+ generated_ids = self.generate(input_ids, **kwargs)
520
 
521
  # Extract predicted action tokens and translate into (normalized) continuous actions
522
  predicted_action_token_ids = generated_ids[0, -self.get_action_dim(unnorm_key) :].cpu().numpy()
 
534
  normalized_actions,
535
  )
536
 
537
+ return actions, generated_ids
538
 
539
  @staticmethod
540
  def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str: