# jina_clip_handler.py
import logging
import os
from io import BytesIO

import requests
import torch
import transformers
from PIL import Image
from ts.torch_handler.base_handler import BaseHandler

from jina_clip_implementation import configuration_clip, modeling_clip
logger = logging.getLogger(__name__)
logger.info("Transformers version %s", transformers.__version__)
class JinaClipHandler(BaseHandler):
    """TorchServe custom handler for Jina CLIP: encodes texts and images into embeddings."""

    def __init__(self):
        super().__init__()
        self.initialized = False
    def initialize(self, ctx):
        """Loads the serialized model file and initializes the model object.

        Resolves the target device from the context and puts the model in eval mode.
        """
        self.manifest = ctx.manifest
        logger.info("ctx manifest: %s", self.manifest)
        properties = ctx.system_properties
        logger.info("ctx properties: %s", properties)
        model_dir = properties.get("model_dir")
        self.device = torch.device(
            "cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu"
        )

        # Locate the serialized model file declared in the archive manifest
        serialized_file = self.manifest["model"]["serializedFile"]
        model_pt_path = os.path.join(model_dir, serialized_file)
        if not os.path.isfile(model_pt_path):
            raise RuntimeError("Missing the model.pt or pytorch_model.bin file")

        # Alternative: load via the transformers Auto classes from model_dir, e.g.
        #   AutoTokenizer.from_pretrained(model_dir, local_files_only=True)
        #   AutoModel.from_pretrained(model_dir, local_files_only=True)
        self.model_config = configuration_clip.JinaCLIPConfig()
        self.model = modeling_clip.JinaCLIPModel(self.model_config)
        # The serialized file may hold either a full pickled model or a bare
        # state dict; handle both so the freshly built model is not discarded.
        checkpoint = torch.load(model_pt_path, map_location=self.device)
        if isinstance(checkpoint, dict):
            self.model.load_state_dict(checkpoint)
        else:
            self.model = checkpoint
        self.model.to(self.device)
        self.model.eval()

        logger.debug("Transformer model from path %s loaded successfully", model_pt_path)
        self.initialized = True
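
    # For reference, a minimal sketch of how an archive matching the manifest
    # fields read above could be built (the names and paths here are
    # illustrative assumptions, not taken from this repository):
    #
    #   torch-model-archiver \
    #       --model-name jina_clip \
    #       --version 1.0 \
    #       --serialized-file model.pt \
    #       --handler jina_clip_handler.py \
    #       --extra-files "jina_clip_implementation/" \
    #       --export-path model_store
    #
    # TorchServe then exposes "serializedFile" through ctx.manifest["model"],
    # which is what initialize() reads.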
    def preprocess(self, data):
        data = data[0]
        texts = data.get("texts", [])
        texts = [texts] if isinstance(texts, str) else texts
        image_urls = data.get("image_urls", [])
        image_urls = [image_urls] if isinstance(image_urls, str) else image_urls
        image_base64 = data.get("image_base64", [])
        if not texts and not image_urls and not image_base64:
            raise ValueError("Missing 'texts', 'image_urls' and 'image_base64' in the request.")

        images = []
        for url in image_urls:
            try:
                response = requests.get(url, stream=True, timeout=10)  # fail fast on unreachable hosts
                response.raise_for_status()
                images.append(Image.open(BytesIO(response.content)).convert("RGB"))
            except Exception as e:
                raise ValueError(f"Error loading image from URL: {url}. Error: {e}")
        if images:
            return texts, images
        # base64 payloads are handed to the model as-is
        return texts, image_base64
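
    # For illustration, a request body that preprocess() accepts could look
    # like this (the URL is a hypothetical placeholder):
    #
    #   {
    #       "texts": ["a photo of a cat", "a diagram of a network"],
    #       "image_urls": ["https://example.com/cat.jpg"]
    #   }
    #
    # Bare strings are promoted to single-element lists, so
    # "texts": "a photo of a cat" is also accepted.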
def inference(self, model_input):
res = {"text_embeddings": [], "image_embeddings": []}
texts, images = model_input
with torch.no_grad():
if texts:
res['text_embeddings'] = self.model.encode_text(texts)
if images:
res['image_embeddings'] = self.model.encode_image(images)
return res
    def postprocess(self, inference_output):
        # Convert each embedding (torch tensor or numpy array) to a plain
        # Python list so the response is JSON-serializable.
        for k, v in inference_output.items():
            if len(v) > 0:
                inference_output[k] = [i.tolist() for i in v]
        return [inference_output]
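
    # The returned list holds one dict per request, mapping each key to a list
    # of embedding vectors; schematically (values are illustrative):
    #
    #   [{"text_embeddings": [[0.01, ...], [0.02, ...]],
    #     "image_embeddings": [[0.03, ...]]}]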
    def handle(self, data, context):
        """
        Invoked by TorchServe for a prediction request.
        Runs preprocessing of the input data, model inference, and
        postprocessing of the prediction output.
        :param data: Input data for prediction
        :param context: Initial context containing model server system properties.
        :return: prediction output
        """
        model_input = self.preprocess(data)
        model_output = self.inference(model_input)
        return self.postprocess(model_output)
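

# A minimal local smoke test, sketched under the assumption that a serialized
# model file exists at model_store/jina_clip/model.pt; the _Ctx stand-in below
# is a hypothetical helper mimicking the TorchServe context, not part of the
# TorchServe API:
if __name__ == "__main__":
    class _Ctx:
        manifest = {"model": {"serializedFile": "model.pt"}}
        system_properties = {"model_dir": "model_store/jina_clip", "gpu_id": 0}

    handler = JinaClipHandler()
    handler.initialize(_Ctx())
    result = handler.handle([{"texts": ["a photo of a cat"]}], _Ctx())
    print(len(result[0]["text_embeddings"][0]))  # embedding dimensionality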