shilinxu committed on
Commit
1b4543a
·
verified ·
1 Parent(s): f3d2bc2

Update modeling_moonvit.py

Browse files
Files changed (1) hide show
  1. modeling_moonvit.py +3 -1
modeling_moonvit.py CHANGED
@@ -180,7 +180,7 @@ class Learnable2DInterpPosEmb(nn.Module):
180
 
181
  def forward(self, x: torch.Tensor, grid_hws: torch.Tensor) -> torch.Tensor:
182
  pos_embs = []
183
- for shape in grid_hws[:, 1:].tolist():
184
  if shape == self.weight.shape[:-1]:
185
  pos_embs.append(self.weight.flatten(end_dim=1))
186
  else:
@@ -596,6 +596,8 @@ class MoonVitPretrainedModel(PreTrainedModel):
596
  Returns:
597
  torch.Tensor: The output tokens.
598
  """
 
 
599
  hidden_states = self.patch_embed(pixel_values, image_grid_hws)
600
  hidden_states = self.encoder(hidden_states, image_grid_hws)
601
  hidden_states = patch_merger(
 
180
 
181
  def forward(self, x: torch.Tensor, grid_hws: torch.Tensor) -> torch.Tensor:
182
  pos_embs = []
183
+ for shape in grid_hws.tolist():
184
  if shape == self.weight.shape[:-1]:
185
  pos_embs.append(self.weight.flatten(end_dim=1))
186
  else:
 
596
  Returns:
597
  torch.Tensor: The output tokens.
598
  """
599
+ if image_grid_hws.shape[-1] == 3:
600
+ image_grid_hws = image_grid_hws[:, 1:]
601
  hidden_states = self.patch_embed(pixel_values, image_grid_hws)
602
  hidden_states = self.encoder(hidden_states, image_grid_hws)
603
  hidden_states = patch_merger(