import base64
import logging
import os
from io import BytesIO

import requests
import torch
import transformers
from PIL import Image
from ts.torch_handler.base_handler import BaseHandler

# modeling_clip / configuration_clip must be importable so that torch.load can
# deserialize the pickled JinaCLIPModel in initialize() below.
from jina_clip_implementation import configuration_clip, modeling_clip

logger = logging.getLogger(__name__)
logger.info("Transformers version %s", transformers.__version__)


class JinaClipHandler(BaseHandler):
    """Custom TorchServe handler for the Jina CLIP model."""

    def __init__(self):
        super().__init__()
        self.initialized = False

    def initialize(self, ctx):
        """
        Loads the serialized model file and initializes the model object.

        :param ctx: Context containing the model manifest and system properties.
        """
        self.manifest = ctx.manifest
        logger.info("ctx manifest: %s", self.manifest)
        properties = ctx.system_properties
        logger.info("ctx properties: %s", properties)

        model_dir = properties.get("model_dir")
        self.device = torch.device(
            "cuda:" + str(properties.get("gpu_id"))
            if torch.cuda.is_available()
            else "cpu"
        )

        # Read the serialized model file.
        serialized_file = self.manifest["model"]["serializedFile"]
        model_pt_path = os.path.join(model_dir, serialized_file)
        if not os.path.isfile(model_pt_path):
            raise RuntimeError("Missing the model.pt or pytorch_model.bin file")

        # Alternative: load from the config.json path instead of the pickled model:
        #   from transformers import AutoModel, AutoTokenizer
        #   self.tokenizer = AutoTokenizer.from_pretrained(model_dir, local_files_only=True)
        #   self.model = AutoModel.from_pretrained(model_dir, local_files_only=True)

        # The serialized file holds the full JinaCLIPModel object (not just a
        # state dict), so torch.load restores it directly; map_location keeps
        # CPU-only hosts working.
        self.model = torch.load(model_pt_path, map_location=self.device)
        self.model.to(self.device)
        self.model.eval()

        logger.debug("Transformer model from path %s loaded successfully", model_pt_path)
        self.initialized = True

    def preprocess(self, data):
        """
        Decodes the request into a list of texts and a list of PIL images.

        :param data: Request payload; expects "texts", "image_urls", and/or
            "image_base64" fields.
        :return: Tuple of (texts, images).
        """
        data = data[0]

        texts = data.get("texts", [])
        texts = [texts] if isinstance(texts, str) else texts

        image_urls = data.get("image_urls", [])
        image_urls = [image_urls] if isinstance(image_urls, str) else image_urls
        image_base64 = data.get("image_base64", [])
        image_base64 = [image_base64] if isinstance(image_base64, str) else image_base64

        if not texts and not image_urls and not image_base64:
            raise ValueError(
                "Missing 'texts', 'image_urls', and/or 'image_base64' in the request."
            )

        images = []
        for url in image_urls:
            try:
                response = requests.get(url, stream=True)
                response.raise_for_status()
                images.append(Image.open(BytesIO(response.content)).convert("RGB"))
            except Exception as e:
                raise ValueError(f"Error loading image from URL: {url}. Error: {e}")
        for encoded in image_base64:
            try:
                images.append(Image.open(BytesIO(base64.b64decode(encoded))).convert("RGB"))
            except Exception as e:
                raise ValueError(f"Error decoding base64 image. Error: {e}")

        return texts, images

    def inference(self, model_input):
        """Encodes texts and/or images and returns their embeddings."""
        texts, images = model_input
        res = {"text_embeddings": [], "image_embeddings": []}
        with torch.no_grad():
            if texts:
                res["text_embeddings"] = self.model.encode_text(texts)
            if images:
                res["image_embeddings"] = self.model.encode_image(images)
        return res

    def postprocess(self, inference_output):
        """Converts embedding tensors to plain lists for JSON serialization."""
        for key, value in inference_output.items():
            if len(value) > 0:
                inference_output[key] = [embedding.tolist() for embedding in value]
        return [inference_output]

    def handle(self, data, context):
        """
        Invoked by TorchServe for a prediction request. Runs pre-processing,
        model inference, and post-processing of the prediction output.

        :param data: Input data for prediction.
        :param context: Initial context containing model server system properties.
        :return: Prediction output.
        """
        model_input = self.preprocess(data)
        model_output = self.inference(model_input)
        return self.postprocess(model_output)
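

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the handler): once the model archive
# is registered with TorchServe, the handler accepts a JSON body with "texts",
# "image_urls", and/or "image_base64" fields, matching preprocess() above. The
# model name "jina_clip" below is a hypothetical example; use whatever name
# the archive was registered under.
#
#   curl -X POST http://localhost:8080/predictions/jina_clip \
#     -H "Content-Type: application/json" \
#     -d '{"texts": ["a photo of a cat"],
#          "image_urls": ["https://example.com/cat.jpg"]}'
#
# The response body carries the dict built in postprocess(), e.g.:
#   {"text_embeddings": [[...]], "image_embeddings": [[...]]}
# ---------------------------------------------------------------------------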