Update modeling_moonvit.py
Browse files- modeling_moonvit.py +3 -1
modeling_moonvit.py
CHANGED
|
@@ -180,7 +180,7 @@ class Learnable2DInterpPosEmb(nn.Module):
|
|
| 180 |
|
| 181 |
def forward(self, x: torch.Tensor, grid_hws: torch.Tensor) -> torch.Tensor:
|
| 182 |
pos_embs = []
|
| 183 |
-
for shape in grid_hws
|
| 184 |
if shape == self.weight.shape[:-1]:
|
| 185 |
pos_embs.append(self.weight.flatten(end_dim=1))
|
| 186 |
else:
|
|
@@ -596,6 +596,8 @@ class MoonVitPretrainedModel(PreTrainedModel):
|
|
| 596 |
Returns:
|
| 597 |
torch.Tensor: The output tokens.
|
| 598 |
"""
|
|
|
|
|
|
|
| 599 |
hidden_states = self.patch_embed(pixel_values, image_grid_hws)
|
| 600 |
hidden_states = self.encoder(hidden_states, image_grid_hws)
|
| 601 |
hidden_states = patch_merger(
|
|
|
|
| 180 |
|
| 181 |
def forward(self, x: torch.Tensor, grid_hws: torch.Tensor) -> torch.Tensor:
|
| 182 |
pos_embs = []
|
| 183 |
+
for shape in grid_hws.tolist():
|
| 184 |
if shape == self.weight.shape[:-1]:
|
| 185 |
pos_embs.append(self.weight.flatten(end_dim=1))
|
| 186 |
else:
|
|
|
|
| 596 |
Returns:
|
| 597 |
torch.Tensor: The output tokens.
|
| 598 |
"""
|
| 599 |
+
if image_grid_hws.shape[-1] == 3:
|
| 600 |
+
image_grid_hws = image_grid_hws[:, 1:]
|
| 601 |
hidden_states = self.patch_embed(pixel_values, image_grid_hws)
|
| 602 |
hidden_states = self.encoder(hidden_states, image_grid_hws)
|
| 603 |
hidden_states = patch_merger(
|