File size: 6,536 Bytes
d26a895
 
 
426785e
d26a895
 
 
 
426785e
 
 
 
d26a895
 
 
 
 
 
 
 
 
 
7a7f986
 
d26a895
 
 
 
 
 
 
 
 
 
 
 
 
d318463
 
 
426785e
d26a895
 
 
 
 
 
 
 
426785e
 
 
d318463
426785e
 
 
 
d318463
426785e
d26a895
426785e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d26a895
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
426785e
d26a895
 
 
 
 
 
 
 
 
 
 
 
 
 
426785e
d26a895
 
 
 
 
 
 
 
 
426785e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import base64
import torch
from typing import Dict, List, Any
from io import BytesIO
from transformers import CLIPProcessor, CLIPModel
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from PIL import Image
from torch.nn.functional import cosine_similarity
from typing import Union

# Per-request size limits, checked in EndpointHandler.__call__ before any
# embedding work is done.
max_image_list_length = 20  # most base64 images accepted in one request
max_text_list_length = 30  # most text prompts accepted in one request

class EndpointHandler():
    def __init__(self, path: str="", image_size: int=224) -> None:
        """
        Initialize the EndpointHandler with a given model path and image size.

        Args:
            path (str, optional): Path to the pretrained model. Defaults to an empty string.
            image_size (int, optional): The size of the images to be processed. Defaults to 224.
        """
        self.model = CLIPModel.from_pretrained("Superlore/clip-vit-large-patch14")
        self.processor = CLIPProcessor.from_pretrained("Superlore/clip-vit-large-patch14")
        self.image_transform = Compose([
            Resize(image_size, interpolation=3),
            CenterCrop(image_size),
            ToTensor(),
            Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])

    def __call__(self, data: Dict[str, Any]) -> Dict[str, list]:
        """
        Process input data containing image and text lists, computing image and text embeddings,
        and, if both image and text lists are provided, calculate similarity scores between them.
        
        Args:
            data (Dict[str, Any]): A dictionary containing the following key:
                - "inputs" (Dict[str, list]): A dictionary containing the following keys:
                    - "image_list" (List[str]): A list of base64-encoded images.
                    - "text_list" (Union[List[str], str]): A list of text strings.

        Returns:
            Dict[str, list]: A dictionary containing the following keys:
                - "image_features" (List[List[float]]): A list of image embeddings.
                - "text_features" (List[List[float]]): A list of text embeddings.
                - "similarity_scores" (List[List[float]]): A list of similarity scores between image and text embeddings.
                                                        Empty if either "image_list" or "text_list" is empty.
        """
        if not isinstance(data, dict):
            raise ValueError("Expected input data to be a dict.")
        
        inputs = data.get("inputs", {})

        if not isinstance(inputs, dict):
            raise ValueError("Expected 'inputs' to be a dict.")
        
        image_list = inputs.get("image_list", []) # list of b64 images
        text_list = inputs.get("text_list", []) # list of texts (or just plain string)

        if not isinstance(image_list, list):
            raise ValueError("Expected 'image_list' to be a list.")
        if not isinstance(text_list, list) and not isinstance(text_list, str):
            raise ValueError("Expected 'text_list' to be a list or string.")
        if not all(isinstance(image, str) for image in image_list):
            raise ValueError("Expected 'image_list' to contain only strings.")
        if isinstance(text_list, list) and not all(isinstance(text, str) for text in text_list):
            raise ValueError("Expected 'text_list' to contain only strings.")
        
        # if text_list is a string, convert to list
        if isinstance(text_list, str):
            text_list = [text_list]

        if len(image_list) > max_image_list_length:
            raise ValueError(f"Expected 'image_list' to have a maximum length of {max_image_list_length}.")
        if len(text_list) > max_text_list_length:
            raise ValueError(f"Expected 'text_list' to have a maximum length of {max_text_list_length}.")
        if not all(is_valid_base64_image(image) for image in image_list):
            raise ValueError("Expected 'image_list' to contain only valid base64-encoded images.")
    
        image_features = self.get_image_embeddings(image_list) if len(image_list) > 0 else None
        text_features = self.get_text_embeddings(text_list) if len(text_list) > 0 else None

        result = {
            "image_features": image_features.tolist() if image_features is not None else [],
            "text_features": text_features.tolist() if text_features is not None else [],
            "similarity_scores": []
        }
        # if image_features & text_features, compute similarity
        if image_features is not None and text_features is not None:
            similarity_scores = [cosine_similarity(img_feat, text_features) for img_feat in image_features]
            result["similarity_scores"] = [t.tolist() for t in similarity_scores]
            
        return result

    def preprocess_images(self, base64_images: List[str]) -> torch.Tensor:
        """Loads a list of images and applies preprocessing steps."""
        preprocessed_images = []
        for base64_image in base64_images:
            # Decode the base64-encoded image and convert it to an RGB image
            image_data = base64.b64decode(base64_image)
            image = Image.open(BytesIO(image_data)).convert("RGB")
            preprocessed_image = self.image_transform(image).unsqueeze(0)
            preprocessed_images.append(preprocessed_image)

        return torch.cat(preprocessed_images, dim=0)

    def get_image_embeddings(self, base64_images: List[str]) -> torch.Tensor:
        image_tensors = self.preprocess_images(base64_images)

        with torch.no_grad():
            self.model.eval()
            image_features = self.model.get_image_features(pixel_values=image_tensors)

        return image_features

    def get_text_embeddings(self, text_list: Union[List[str], str]) -> torch.Tensor:
        with torch.no_grad():
            # Tokenize the input text list
            input_tokens = self.processor(text_list, return_tensors="pt", padding=True, truncation=True)

            # Generate the embeddings for the text list
            self.model.eval()
            text_features = self.model.get_text_features(**input_tokens)
        return text_features
        
    
def is_valid_base64_image(data: str) -> bool:
    """
    Check whether ``data`` is base64 that decodes to an image Pillow can open and verify.

    ``base64.b64decode`` is used in its default (non-validating) mode, so
    characters outside the base64 alphabet are discarded rather than rejected.
    Per Pillow's documentation, ``Image.verify()`` checks the file for problems
    without decoding all pixel data. An image that passes may therefore still
    fail when it is fully loaded later.

    Args:
        data: Base64-encoded image bytes.

    Returns:
        True if decoding, opening and verifying all succeed. False if any of
        them raises an ``Exception``. Non-``Exception`` ``BaseException``s
        (e.g. ``KeyboardInterrupt``, ``SystemExit``) are propagated.
    """
    try:
        img_data = base64.b64decode(data)
        # The context manager closes the image object once verification is done.
        with Image.open(BytesIO(img_data)) as img:
            img.verify()
    except Exception:
        # Any decode / identify / verify failure means "not a valid image".
        return False
    return True