theArijitDas commited on
Commit
a4e45ab
·
verified ·
1 Parent(s): 755413c

Upload image_validator.py

Browse files
Files changed (1) hide show
  1. image_validator.py +64 -0
image_validator.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import CLIPProcessor, CLIPModel, ViTImageProcessor, ViTModel
2
+ from PIL import Image
3
+ from sklearn.metrics.pairwise import cosine_similarity
4
+
5
+ from warnings import filterwarnings
6
+ filterwarnings("ignore")
7
+
8
+ models = ["CLIP-ViT Base", "ViT Base", "DINO ViT-S16"]
9
+ models_info = {
10
+ "CLIP-ViT Base": {
11
+ "model_size": "386MB",
12
+ "model_url": "openai/clip-vit-base-patch32",
13
+ "efficiency": "High",
14
+ },
15
+ "ViT Base": {
16
+ "model_size": "304MB",
17
+ "model_url": "google/vit-base-patch16-224",
18
+ "efficiency": "High",
19
+ },
20
+ "DINO ViT-S16": {
21
+ "model_size": "1.34GB",
22
+ "model_url": "facebook/dino-vits16",
23
+ "efficiency": "Moderate",
24
+ },
25
+ }
26
+
27
+ class Image_Validator:
28
+ def __init__(self, model_name=None):
29
+ if model_name is None: model_name="ViT Base"
30
+
31
+ self.model_info = models_info[model_name]
32
+ model_url = self.model_info["model_url"]
33
+
34
+ if model_name == "CLIP-ViT Base":
35
+ self.model = CLIPModel.from_pretrained(model_url)
36
+ self.processor = CLIPProcessor.from_pretrained(model_url)
37
+
38
+ elif model_name == "ViT Base":
39
+ self.model = ViTModel.from_pretrained(model_url)
40
+ self.feature_extractor = ViTImageProcessor.from_pretrained(model_url)
41
+
42
+ elif model_name == "DINO ViT-S16":
43
+ self.model = ViTModel.from_pretrained(model_url)
44
+ self.feature_extractor = ViTImageProcessor.from_pretrained(model_url)
45
+
46
+ def get_image_embedding(self, image_path):
47
+ image = Image.open(image_path)
48
+
49
+ # Process image according to the model
50
+ if hasattr(self, 'processor'): # CLIP models
51
+ inputs = self.processor(images=image, return_tensors="pt")
52
+ outputs = self.model.get_image_features(**inputs)
53
+
54
+ elif hasattr(self, 'feature_extractor'): # ViT models
55
+ inputs = self.feature_extractor(images=image, return_tensors="pt")
56
+ outputs = self.model(**inputs).last_hidden_state
57
+
58
+ return outputs
59
+
60
+ def similarity_score(self, image_path_1, image_path_2):
61
+ embedding1 = self.get_image_embedding(image_path_1).reshape(1, -1)
62
+ embedding2 = self.get_image_embedding(image_path_2).reshape(1, -1)
63
+ similarity = cosine_similarity(embedding1.detach().numpy(), embedding2.detach().numpy())
64
+ return similarity[0][0]