0xnewton-superlore commited on
Commit
d26a895
·
1 Parent(s): 93cf9b4

adds handler.py for custom inference

Browse files
Files changed (1) hide show
  1. handler.py +94 -0
handler.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
+ import torch
4
+ from typing import Dict, List, Any
5
+ from transformers import CLIPProcessor, CLIPModel
6
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
7
+ from PIL import Image
8
+ from torch.nn.functional import cosine_similarity
9
+
10
+ class EndpointHandler():
11
+ def __init__(self, path: str="", image_size: int=224) -> None:
12
+ """
13
+ Initialize the EndpointHandler with a given model path and image size.
14
+
15
+ Args:
16
+ path (str, optional): Path to the pretrained model. Defaults to an empty string.
17
+ image_size (int, optional): The size of the images to be processed. Defaults to 224.
18
+ """
19
+ self.model = CLIPModel.from_pretrained("SuperloreAI/clip-vit-large-patch14")
20
+ self.processor = CLIPProcessor.from_pretrained("SuperloreAI/clip-vit-large-patch14")
21
+ self.image_transform = Compose([
22
+ Resize(image_size, interpolation=3),
23
+ CenterCrop(image_size),
24
+ ToTensor(),
25
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
26
+ ])
27
+
28
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, list]:
29
+ """
30
+ Process input data containing image and text lists, computing image and text embeddings,
31
+ and, if both image and text lists are provided, calculate similarity scores between them.
32
+
33
+ Args:
34
+ data (Dict[str, Any]): A dictionary containing the following keys:
35
+ - "image_list" (List[str]): A list of base64-encoded images.
36
+ - "text_list" (List[str]): A list of text strings.
37
+
38
+ Returns:
39
+ Dict[str, list]: A dictionary containing the following keys:
40
+ - "image_features" (List[List[float]]): A list of image embeddings.
41
+ - "text_features" (List[List[float]]): A list of text embeddings.
42
+ - "similarity_scores" (List[List[float]]): A list of similarity scores between image and text embeddings.
43
+ Empty if either "image_list" or "text_list" is empty.
44
+ """
45
+ image_list = data.get("image_list", []) # list of b64 images
46
+ text_list = data.get("text_list", []) # list of texts
47
+
48
+ image_features = self.get_image_embeddings(image_list) if len(image_list) > 0 else None
49
+ text_features = self.get_text_embeddings(text_list) if len(text_list) > 0 else None
50
+
51
+ result = {
52
+ "image_features": image_features.tolist() if image_features is not None else [],
53
+ "text_features": text_features.tolist() if text_features is not None else [],
54
+ "similarity_scores": []
55
+ }
56
+ # if image_features & text_features, compute similarity
57
+ if image_features is not None and text_features is not None:
58
+ similarity_scores = [cosine_similarity(img_feat, text_features) for img_feat in image_features]
59
+ result["similarity_scores"] = [t.tolist() for t in similarity_scores]
60
+
61
+ return result
62
+
63
+ def preprocess_images(self, base64_images: List[str]) -> torch.Tensor:
64
+ """Loads a list of images and applies preprocessing steps."""
65
+ preprocessed_images = []
66
+ for base64_image in base64_images:
67
+ # Decode the base64-encoded image and convert it to an RGB image
68
+ image_data = base64.b64decode(base64_image)
69
+ image = Image.open(io.BytesIO(image_data)).convert("RGB")
70
+ preprocessed_image = self.image_transform(image).unsqueeze(0)
71
+ preprocessed_images.append(preprocessed_image)
72
+
73
+ return torch.cat(preprocessed_images, dim=0)
74
+
75
+ def get_image_embeddings(self, base64_images: List[str]) -> torch.Tensor:
76
+ image_tensors = self.preprocess_images(base64_images)
77
+
78
+ with torch.no_grad():
79
+ self.model.eval()
80
+ image_features = self.model.get_image_features(pixel_values=image_tensors)
81
+
82
+ return image_features
83
+
84
+ def get_text_embeddings(self, text_list: List[str]) -> torch.Tensor:
85
+ with torch.no_grad():
86
+ # Tokenize the input text list
87
+ input_tokens = self.processor(text_list, return_tensors="pt", padding=True, truncation=True)
88
+
89
+ # Generate the embeddings for the text list
90
+ self.model.eval()
91
+ text_features = self.model.get_text_features(**input_tokens)
92
+ return text_features
93
+
94
+