damerajee commited on
Commit
409b074
·
verified ·
1 Parent(s): 30504b0

Update vision_encoder.py

Browse files
Files changed (1) hide show
  1. vision_encoder.py +7 -8
vision_encoder.py CHANGED
@@ -1,11 +1,10 @@
1
- import torch.nn as nn
2
- from transformers import ViTModel
3
  from torchvision import transforms
4
- import torch
5
 
6
  import transformers
7
 
8
-
9
  transformers.logging.set_verbosity_error()
10
 
11
  class VisionEncoder(nn.Module):
@@ -18,9 +17,9 @@ class VisionEncoder(nn.Module):
18
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
19
  ])
20
 
21
- def forward(self, images,device):
22
- processed_images = torch.stack([self.image_transform(image) for image in images]).to(device)
23
  with torch.no_grad():
24
- pixel_values = self.vision_model(processed_images)
25
  image_features = pixel_values.last_hidden_state
26
- return image_features
 
1
+ import torch.nn as nn
2
+ from transformers import ViTModel
3
  from torchvision import transforms
4
+ import torch
5
 
6
  import transformers
7
 
 
8
  transformers.logging.set_verbosity_error()
9
 
10
  class VisionEncoder(nn.Module):
 
17
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
18
  ])
19
 
20
+ def forward(self, image, device):
21
+ processed_image = self.image_transform(image).unsqueeze(0).to(device)
22
  with torch.no_grad():
23
+ pixel_values = self.vision_model(processed_image)
24
  image_features = pixel_values.last_hidden_state
25
+ return image_features