damerajee commited on
Commit
530ec53
·
verified ·
1 Parent(s): 137bd3e

Create vision_encoder.py

Browse files
Files changed (1) hide show
  1. vision_encoder.py +22 -0
vision_encoder.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from transformers import ViTModel
3
+ from torchvision import transforms
4
+ import torch
5
+
6
+ class VisionEncoder(nn.Module):
7
+ def __init__(self):
8
+ super().__init__()
9
+ self.vision_model = ViTModel.from_pretrained("google/vit-base-patch16-224")
10
+ self.image_transform = transforms.Compose([
11
+ transforms.Resize((224, 224)),
12
+ transforms.ToTensor(),
13
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
14
+ ])
15
+
16
+ def forward(self, images,device):
17
+ processed_images = torch.stack([self.image_transform(image) for image in images]).to(device)
18
+ with torch.no_grad():
19
+ pixel_values = self.vision_model(processed_images)
20
+ image_features = pixel_values.last_hidden_state
21
+ image_features = image_features
22
+ return image_features