File size: 5,104 Bytes
3f75720
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import torch
from typing import Dict, Any, List
from PIL import Image
import base64
from io import BytesIO
import logging
from transformers import AutoImageProcessor, AutoModel
import os
from dataclasses import dataclass


@dataclass
class ImageEncodingResult:
    """
    Embeddings produced for a single image.

    Attributes:
        image_encoded: The model's full hidden-state output for the image,
            as nested Python lists of floats.
        image_encoded_average: The average of those embeddings, as a flat
            list of floats.
    """

    image_encoded: List[List[float]]
    image_encoded_average: List[float]


class EndpointHandler:
    """
    Inference handler that turns base64-encoded images into model embeddings.

    The processor and model are loaded once with ``from_pretrained``. For each
    request the handler returns, per image, the model's ``last_hidden_state``
    (batch axis removed) together with its mean over dimension 1.

    Attributes:
        logger: Logger used for load-time and per-request diagnostics.
        device: ``"cuda"`` when ``torch.cuda.is_available()``, else ``"cpu"``.
        processor: Image processor from ``AutoImageProcessor.from_pretrained``.
        model: Model from ``AutoModel.from_pretrained``, moved to ``device``.
    """

    def __init__(self, path: str = ""):
        """
        Load the model and processor and pick the inference device.

        Args:
            path: Model directory to load from. When empty (the default), the
                current working directory is used, so ``EndpointHandler()``
                keeps its original behaviour.

        Raises:
            Exception: Anything raised while loading the processor or model;
                it is logged with a traceback and re-raised.
        """
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.logger.info("Using device: %s", self.device)

        # Inference Endpoints instantiate custom handlers with the model
        # directory (``EndpointHandler(path)``); honour it, but fall back to the
        # CWD so zero-argument construction still works.
        model_dir = path or os.getcwd()
        self.logger.info("Loading model and processor from %s.", model_dir)
        try:
            self.processor = AutoImageProcessor.from_pretrained(model_dir)
            self.model = AutoModel.from_pretrained(
                model_dir, trust_remote_code=True
            ).to(self.device)
        except Exception as e:
            self.logger.exception("Failed to load model or processor: %s", e)
            raise
        self.logger.info("Model and processor loaded successfully.")

    def _resize_image_if_large(
        self, image: Image.Image, max_size: int = 1080
    ) -> Image.Image:
        """
        Downscale ``image`` so that neither side exceeds ``max_size`` pixels.

        The aspect ratio is preserved. Images already within the limit are
        returned unchanged (the same object).

        Args:
            image: Input image.
            max_size: Maximum allowed length, in pixels, of either side.

        Returns:
            The original image, or a bilinearly resized copy.
        """
        width, height = image.size
        longest = max(width, height)
        if longest <= max_size:
            return image
        # Integer arithmetic keeps the longer side exactly ``max_size`` (a float
        # scale factor can land just below an integer and truncate one pixel
        # short). ``max(1, ...)`` stops extreme aspect ratios from collapsing a
        # side to 0 px, which PIL's resize rejects.
        new_width = max(1, width * max_size // longest)
        new_height = max(1, height * max_size // longest)
        return image.resize((new_width, new_height), resample=Image.BILINEAR)

    def _decode_image(self, img_data: str) -> Image.Image:
        """
        Decode one base64 payload into an RGB PIL image.

        Args:
            img_data: Base64 string, optionally prefixed with a
                ``data:<mime>;base64,`` header.

        Returns:
            The decoded image converted to RGB.

        Raises:
            Exception: If the payload is not valid base64 or not an image
                PIL can open.
        """
        # The standard base64 alphabet has no ':' so a plain payload can never
        # start with "data:". Without stripping, b64decode would silently drop
        # the header's punctuation and decode its letters as image bytes.
        if img_data.startswith("data:") and "," in img_data:
            img_data = img_data.partition(",")[2]
        image_bytes = base64.b64decode(img_data)
        # convert() loads the pixels into a new image, so closing the source
        # image afterwards is safe.
        with Image.open(BytesIO(image_bytes)) as source:
            return source.convert("RGB")

    def _encode_image(self, image: Image.Image) -> ImageEncodingResult:
        """
        Run one image through the processor and model and collect embeddings.

        Args:
            image: Input image. It is first downscaled if either side is
                larger than 1080 px.

        Returns:
            ImageEncodingResult holding ``last_hidden_state`` with the batch
            axis removed, and its mean over dimension 1.

        Raises:
            Exception: Any processor or model error; it is logged with a
                traceback and re-raised.
        """
        try:
            image = self._resize_image_if_large(image)
            inputs = self.processor(image, return_tensors="pt").to(self.device)
            with torch.inference_mode():
                last_hidden_state = self.model(**inputs).last_hidden_state
                # Drop only the batch axis. A bare ``squeeze()`` would also drop
                # any other size-1 axis (e.g. a single-token output), changing
                # the nesting of the returned lists.
                image_encoded = last_hidden_state.squeeze(0).tolist()
                image_encoded_average = (
                    last_hidden_state.mean(dim=1).squeeze(0).tolist()
                )
        except Exception as e:
            self.logger.exception("Error encoding image: %s", e)
            raise

        return ImageEncodingResult(
            image_encoded=image_encoded,
            image_encoded_average=image_encoded_average,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Embed every base64-encoded image found under ``data["images"]``.

        Args:
            data: Request payload. ``data["images"]`` is expected to be a list
                of base64 strings. A single string is also accepted, and each
                string may carry a ``data:<mime>;base64,`` prefix.

        Returns:
            On success, ``{"results": [...]}`` with one dict per image, keyed
            ``image_encoded`` and ``image_encoded_average``. Otherwise
            ``{"error": message}`` for the first image that fails; no partial
            results are returned.
        """
        images_data = data.get("images", [])

        if not images_data:
            return {"error": "No image data provided."}

        # A bare string would otherwise be iterated character by character.
        if isinstance(images_data, str):
            images_data = [images_data]

        results = []
        for img_data in images_data:
            if not isinstance(img_data, str):
                self.logger.error("Images should be base64-encoded strings.")
                return {"error": "Images should be base64-encoded strings."}

            try:
                image = self._decode_image(img_data)
            except Exception as e:
                self.logger.error("Invalid image data: %s", e)
                return {"error": f"Invalid image data: {e}"}

            try:
                encoded_image = self._encode_image(image)
            except Exception as e:
                # _encode_image has already logged the traceback. Report a
                # processing failure instead of blaming the client's input.
                return {"error": f"Error encoding image: {e}"}
            results.append(encoded_image)

        # Plain dicts so the response is JSON-serializable.
        return {"results": [vars(result) for result in results]}