from transformers import ViTFeatureExtractor, ViTModel
from PIL import Image
import numpy as np
import torch

class VitBase:
    """Wraps a pretrained ViT-Base/16 backbone for image feature extraction."""

    def __init__(self):
        # Load the image preprocessor and ViT encoder pretrained on ImageNet-21k
        self.feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
        self.model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
    
    def extract_feature(self, imgs):
        """Return one flattened last-hidden-state vector per PIL image in imgs."""
        features = []
        for img in imgs:
            # ViT expects 3-channel input; convert RGBA (or any non-RGB mode) to RGB
            if img.mode != 'RGB':
                img = img.convert('RGB')

            # Resize, normalize, and batch the image (batch size 1)
            inputs = self.feature_extractor(images=img, return_tensors="pt")
            with torch.no_grad():
                outputs = self.model(**inputs)
            # last_hidden_state has shape (1, num_patches + 1, hidden_size);
            # squeeze the batch dimension and flatten to a 1-D feature vector
            last_hidden_states = outputs.last_hidden_state
            features.append(np.squeeze(last_hidden_states.numpy()).flatten())
        return features
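

# Minimal usage sketch (assumption: the image file paths below are hypothetical
# placeholders; only the VitBase class and model checkpoint above come from this file).
if __name__ == "__main__":
    extractor = VitBase()
    # Hypothetical example images; replace with real paths.
    imgs = [Image.open(path) for path in ["example1.png", "example2.jpg"]]
    feats = extractor.extract_feature(imgs)
    # For ViT-Base/16 at 224x224, each vector has length (196 patches + 1 CLS) * 768 = 151296
    print(len(feats), feats[0].shape)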