blip2-opt-2.7b-coco / handler.py
from typing import Any, Dict

import requests
import torch
from PIL import Image
from transformers import Blip2ForConditionalGeneration, Blip2Model, Blip2Processor


class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        path: local directory (or Hub repo id) containing the BLIP-2 weights and processor files.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = Blip2Processor.from_pretrained(path)
        # Model used for captioning and prompt-conditioned generation.
        self.generate_model = Blip2ForConditionalGeneration.from_pretrained(
            path, torch_dtype=torch.float16
        )
        self.generate_model.to(self.device)
        # Separate BLIP-2 model instance used for feature extraction.
        self.feature_model = Blip2Model.from_pretrained(path, torch_dtype=torch.float16)
        self.feature_model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args:
            inputs (:obj:`dict`) with keys:
                image_url (:obj:`str`): URL of the image to process.
                prompt (:obj:`str`, optional): text prompt for prompt-conditioned generation.
                extract_feature (:obj:`bool`, optional): if True, also return BLIP-2 model features.
        Return:
            A :obj:`dict` that will be serialized and returned.
        """
        result = {}
        inputs = data.pop("inputs", data)
        image_url = inputs["image_url"]
        prompt = inputs.get("prompt")
        extract_feature = inputs.get("extract_feature", False)
        # Fetch the image; converting to RGB guards against grayscale or RGBA inputs.
        image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")

        # Unconditional captioning.
        processed_image = self.processor(images=image, return_tensors="pt").to(self.device, torch.float16)
        generated_ids = self.generate_model.generate(**processed_image)
        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        result["image_caption"] = generated_text

        if extract_feature:
            # Note: the feature model returns a model-output object of tensors,
            # which is not directly JSON-serializable.
            caption_feature = self.feature_model(**processed_image)
            result["caption_feature"] = caption_feature
        if prompt:
            # Prompt-conditioned generation (e.g. visual question answering).
            prompt_image_processed = self.processor(images=image, text=prompt, return_tensors="pt").to(self.device, torch.float16)
            generated_ids = self.generate_model.generate(**prompt_image_processed)
            generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
            result["image_prompt"] = generated_text

            if extract_feature:
                prompt_feature = self.feature_model(**prompt_image_processed)
                result["prompt_feature"] = prompt_feature

        return result
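

# Minimal local smoke-test sketch, not part of the Inference Endpoints contract:
# the deployed endpoint instantiates EndpointHandler itself and calls it with the
# request payload. The model path "Salesforce/blip2-opt-2.7b-coco" and the image
# URL below are illustrative assumptions, not values taken from this repository.
if __name__ == "__main__":
    handler = EndpointHandler(path="Salesforce/blip2-opt-2.7b-coco")
    payload = {
        "inputs": {
            "image_url": "https://example.com/cat.jpg",  # replace with a real image URL
            "prompt": "Question: what animal is in the picture? Answer:",
            "extract_feature": False,
        }
    }
    print(handler(payload))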