# File size: 4,357 Bytes
# d4b7928
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119

from ts.torch_handler.base_handler import BaseHandler
from transformers import AutoModel, AutoProcessor, AutoTokenizer
import torch
from PIL import Image
import requests
from io import BytesIO

import logging
import os

import transformers
from jina_clip_implementation import modeling_clip, configuration_clip

import numpy as np
from time import time

from ts.torch_handler.base_handler import BaseHandler

# Module-level logger shared by the handler below. The installed Transformers
# version is logged once at import time (presumably to help diagnose version
# mismatches with the serialized model -- confirm).
logger = logging.getLogger(__name__)
logger.info("Transformers version %s", transformers.__version__)

class JinaClipHandler(BaseHandler):
    """TorchServe handler that embeds texts and/or images with a Jina CLIP model.

    Request payload keys read by :meth:`preprocess`:

    * ``texts`` -- a string or list of strings, embedded with ``encode_text``.
    * ``image_urls`` -- a URL or list of URLs; each is downloaded and decoded
      to an RGB PIL image.
    * ``image_base64`` -- a base64 string or list of them (an optional
      ``data:...;base64,`` prefix is stripped); each is decoded to an RGB
      PIL image.

    The response is a one-element list holding a dict with
    ``text_embeddings`` and ``image_embeddings`` (``[]`` when the
    corresponding input is absent).

    NOTE(review): only ``data[0]`` is processed and a single response is
    returned, so this assumes the TorchServe ``batch_size`` is 1 -- confirm
    in the model configuration.
    """

    # Seconds to wait for a remote image server before failing the request.
    image_download_timeout = 10

    def __init__(self):
        super().__init__()
        self.initialized = False

    def initialize(self, ctx):
        """Load the serialized model file and move it to the selected device.

        :param ctx: TorchServe context; ``ctx.manifest["model"]["serializedFile"]``
            names the model file inside ``ctx.system_properties["model_dir"]``.
        :raises RuntimeError: if the serialized model file does not exist.
        """
        self.manifest = ctx.manifest
        logger.info("ctx manifest: %s", self.manifest)

        properties = ctx.system_properties
        logger.info("ctx properties: %s", properties)
        model_dir = properties.get("model_dir")

        # Only target CUDA when a GPU id was actually assigned; otherwise the
        # device string would be the invalid "cuda:None".
        gpu_id = properties.get("gpu_id")
        if torch.cuda.is_available() and gpu_id is not None:
            self.device = torch.device(f"cuda:{gpu_id}")
        else:
            self.device = torch.device("cpu")

        # Read model serialize/pt file
        serialized_file = self.manifest["model"]["serializedFile"]
        model_pt_path = os.path.join(model_dir, serialized_file)
        if not os.path.isfile(model_pt_path):
            raise RuntimeError("Missing the model.pt or pytorch_model.bin file")

        self.model_config = configuration_clip.JinaCLIPConfig()
        # NOTE(review): this instance is replaced immediately by torch.load
        # below, but constructing it may be what makes the model's remote-code
        # modules importable for unpickling -- confirm before removing.
        self.model = modeling_clip.JinaCLIPModel(self.model_config)
        # SECURITY NOTE: torch.load unpickles a full module, which can execute
        # arbitrary code; model_dir must hold a trusted artifact. On newer
        # torch releases the weights_only default may reject full-module
        # pickles -- TODO confirm against the installed torch version.
        # map_location lets a checkpoint saved on GPU load on a CPU-only worker.
        self.model = torch.load(model_pt_path, map_location=self.device)
        self.model.to(self.device)
        self.model.eval()
        logger.debug("Transformer model from path %s loaded successfully", model_pt_path)

        self.initialized = True

    @staticmethod
    def _unwrap_request(request):
        """Return the JSON-object payload of one TorchServe request.

        A request that already carries one of the expected keys is used as-is.
        Otherwise its ``data``/``body`` entry is used, decoding bytes/str as JSON.
        """
        if any(key in request for key in ("texts", "image_urls", "image_base64")):
            return request
        payload = request.get("data") or request.get("body") or request
        if isinstance(payload, (bytes, bytearray)):
            payload = payload.decode("utf-8")
        if isinstance(payload, str):
            import json

            payload = json.loads(payload)
        if not isinstance(payload, dict):
            raise ValueError("Request payload must be a JSON object.")
        return payload

    @staticmethod
    def _as_list(value):
        """Wrap a single string in a list; ``None`` becomes ``[]``."""
        if value is None:
            return []
        return [value] if isinstance(value, str) else list(value)

    def _load_image_from_url(self, url):
        """Download ``url`` and decode it to an RGB PIL image.

        SECURITY NOTE: the URL comes from the request, so this fetches
        arbitrary client-chosen addresses (SSRF risk) -- restrict if exposed.
        """
        try:
            with requests.get(url, timeout=self.image_download_timeout) as response:
                response.raise_for_status()
                return Image.open(BytesIO(response.content)).convert("RGB")
        except Exception as e:
            raise ValueError(f"Error loading image from URL: {url}. Error: {e}") from e

    @staticmethod
    def _decode_base64_image(encoded):
        """Decode a base64 (optionally data-URI) string to an RGB PIL image."""
        import base64

        try:
            if encoded.startswith("data:") and "," in encoded:
                encoded = encoded.split(",", 1)[1]
            raw = base64.b64decode(encoded)
            return Image.open(BytesIO(raw)).convert("RGB")
        except Exception as e:
            raise ValueError(f"Error decoding base64 image. Error: {e}") from e

    def preprocess(self, data):
        """Extract texts and decoded images from the first request in ``data``.

        :param data: list of TorchServe requests; only ``data[0]`` is used.
        :return: ``(texts, images)``. ``texts`` is a list of strings and
            ``images`` is a list of RGB PIL images, URL images first, then
            base64 images. Either list may be empty, but not both.
        :raises ValueError: if no texts or images are supplied, or an image
            cannot be fetched or decoded.
        """
        request = self._unwrap_request(data[0])
        texts = self._as_list(request.get("texts", []))
        image_urls = self._as_list(request.get("image_urls", []))
        image_base64 = self._as_list(request.get("image_base64", []))

        if not texts and not image_urls and not image_base64:
            raise ValueError(
                "Missing 'texts', 'image_urls' and/or 'image_base64' in the request."
            )

        # Decode once here and hand PIL images to the model, so remote images
        # are not downloaded a second time inside encode_image.
        images = [self._load_image_from_url(url) for url in image_urls]
        images.extend(self._decode_base64_image(item) for item in image_base64)
        return texts, images

    def inference(self, model_input):
        """Embed texts and images with the loaded model.

        :param model_input: ``(texts, images)`` as returned by :meth:`preprocess`.
        :return: dict with ``text_embeddings`` / ``image_embeddings`` holding
            the results of ``encode_text`` / ``encode_image``, or ``[]`` when
            the corresponding input list is empty.
        """
        texts, images = model_input
        res = {"text_embeddings": [], "image_embeddings": []}
        with torch.no_grad():
            if texts:
                res["text_embeddings"] = self.model.encode_text(texts)
            if images:
                res["image_embeddings"] = self.model.encode_image(images)
        return res

    def postprocess(self, inference_output):
        """Convert embeddings into JSON-serializable nested lists.

        :param inference_output: dict from :meth:`inference`. Each non-empty
            value is iterated row by row and ``tolist()`` is called on every
            row (presumably numpy arrays or tensors -- TODO confirm).
        :return: one-element list holding the converted dict.
        """
        for key, value in inference_output.items():
            if len(value) > 0:
                inference_output[key] = [row.tolist() for row in value]
        return [inference_output]

    def handle(self, data, context):
        """
        Invoke by TorchServe for prediction request.
        Do pre-processing of data, prediction using model and postprocessing of prediction output
        :param data: Input data for prediction
        :param context: Initial context contains model server system properties.
        :return: prediction output
        """
        model_input = self.preprocess(data)
        model_output = self.inference(model_input)
        return self.postprocess(model_output)