from io import BytesIO
import base64
import logging
import os

import requests
import torch
import transformers
from PIL import Image
from ts.torch_handler.base_handler import BaseHandler

from jina_clip_implementation import modeling_clip, configuration_clip

logger = logging.getLogger(__name__)
logger.info("Transformers version %s", transformers.__version__)
|
|
class JinaClipHandler(BaseHandler):
    """Custom TorchServe handler for the Jina CLIP model.

    Encodes the texts and/or images in a request into embedding vectors.
    """

    def __init__(self):
        super().__init__()
        self.initialized = False
|
    def initialize(self, ctx):
        """Loads the serialized model file and initializes the Jina CLIP model."""
        self.manifest = ctx.manifest
        logger.info("ctx manifest: %s", self.manifest)

        properties = ctx.system_properties
        logger.info("ctx properties: %s", properties)
        model_dir = properties.get("model_dir")
        self.device = torch.device(
            "cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu"
        )

        serialized_file = self.manifest["model"]["serializedFile"]
        model_pt_path = os.path.join(model_dir, serialized_file)
        if not os.path.isfile(model_pt_path):
            raise RuntimeError("Missing the model.pt or pytorch_model.bin file")

        # The checkpoint may be either a full pickled model or a bare state dict;
        # build a model from the default config so both cases can be handled.
        self.model_config = configuration_clip.JinaCLIPConfig()
        self.model = modeling_clip.JinaCLIPModel(self.model_config)
        checkpoint = torch.load(model_pt_path, map_location=self.device)
        if isinstance(checkpoint, dict):
            self.model.load_state_dict(checkpoint)
        else:
            self.model = checkpoint
        self.model.to(self.device)
        self.model.eval()
        logger.debug("Transformer model from path %s loaded successfully", model_pt_path)

        self.initialized = True
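
    # The serialized file referenced above is typically produced when packaging
    # the model archive, e.g. (hypothetical archiver invocation; the model name
    # and file paths are placeholders, not values from this repo):
    #
    #   torch-model-archiver --model-name jina_clip --version 1.0 \
    #       --serialized-file model.pt --handler handler.py \
    #       --extra-files jina_clip_implementation/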
|
    def preprocess(self, data):
        """Extracts texts from the request and loads all referenced images as PIL objects."""
        data = data[0]
        texts = data.get("texts", [])
        texts = [texts] if isinstance(texts, str) else texts
        image_urls = data.get("image_urls", [])
        image_urls = [image_urls] if isinstance(image_urls, str) else image_urls
        image_base64 = data.get("image_base64", [])
        image_base64 = [image_base64] if isinstance(image_base64, str) else image_base64

        if not texts and not image_urls and not image_base64:
            raise ValueError(
                "Request must contain at least one of 'texts', 'image_urls', or 'image_base64'."
            )

        images = []
        for url in image_urls:
            try:
                response = requests.get(url, stream=True, timeout=10)
                response.raise_for_status()
                images.append(Image.open(BytesIO(response.content)).convert("RGB"))
            except Exception as e:
                raise ValueError(f"Error loading image from URL: {url}. Error: {e}")
        for encoded in image_base64:
            try:
                images.append(Image.open(BytesIO(base64.b64decode(encoded))).convert("RGB"))
            except Exception as e:
                raise ValueError(f"Error decoding base64 image. Error: {e}")

        # Return the loaded PIL images (not the raw URLs/base64 strings) so
        # inference can pass them straight to the model's image encoder.
        return texts, images
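
    # Example request body consumed by preprocess (field names as parsed above;
    # the values are purely illustrative). Each field also accepts a bare string
    # in place of a list:
    #
    #   {
    #       "texts": ["a photo of a cat"],
    #       "image_urls": ["https://example.com/cat.jpg"],
    #       "image_base64": ["<base64-encoded image bytes>"]
    #   }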
|
    def inference(self, model_input):
        """Encodes the texts and images into embeddings with the model's encoders."""
        res = {"text_embeddings": [], "image_embeddings": []}

        texts, images = model_input
        with torch.no_grad():
            if texts:
                res["text_embeddings"] = self.model.encode_text(texts)
            if images:
                res["image_embeddings"] = self.model.encode_image(images)
        return res
|
    def postprocess(self, inference_output):
        """Converts embedding tensors/arrays to plain lists for JSON serialization."""
        for k, v in inference_output.items():
            if len(v) > 0:
                inference_output[k] = [i.tolist() for i in v]
        return [inference_output]
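
    # Example response produced by postprocess (shape only; the embedding
    # dimensionality depends on the Jina CLIP variant that was serialized):
    #
    #   [{"text_embeddings": [[0.12, -0.03, ...]], "image_embeddings": []}]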
|
    def handle(self, data, context):
        """
        Invoked by TorchServe for prediction requests.

        Preprocesses the input data, runs prediction with the model, and
        postprocesses the prediction output.
        :param data: Input data for prediction
        :param context: Initial context containing model server system properties.
        :return: prediction output
        """
        model_input = self.preprocess(data)
        model_output = self.inference(model_input)
        return self.postprocess(model_output)
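
# A hypothetical smoke test against a running TorchServe instance (assumes the
# model was registered under the name "jina_clip"; adjust host, port, and name
# to your deployment):
#
#   curl -X POST http://localhost:8080/predictions/jina_clip \
#        -H "Content-Type: application/json" \
#        -d '{"texts": ["a photo of a cat"], "image_urls": ["https://example.com/cat.jpg"]}'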