|
import torch |
|
from transformers import AutoProcessor, LlavaForConditionalGeneration |
|
from PIL import Image |
|
import base64 |
|
import io |
|
|
|
class EndpointHandler():
    """Inference handler that captions a base64-encoded image with a LLaVA model.

    The processor and model are loaded once in ``__init__``; each request is
    served by ``__call__``.

    Accepted request shapes (the request dict is never mutated):

    * flat:    ``{"image": <base64>, "prompt": <str>, "parameters": {...}}``
    * wrapped: ``{"inputs": {"image": <base64>, "prompt": <str>}, "parameters": {...}}``

    ``image`` may also be a data URI (``data:image/png;base64,...``).
    ``parameters`` is optional; only the keys in ``_GENERATION_DEFAULTS`` can be
    overridden, so a request cannot inject arbitrary ``generate`` kwargs.
    """

    # Sampling settings used when the request does not override them.
    # NOTE(review): max_new_tokens is client-controllable; consider capping it
    # if the endpoint is exposed to untrusted callers.
    _GENERATION_DEFAULTS = {
        "max_new_tokens": 300,
        "do_sample": True,
        "temperature": 0.6,
        "top_p": 0.9,
    }

    def __init__(self, model_path=""):
        """Load the processor and model.

        Args:
            model_path: location passed to ``from_pretrained`` for both the
                processor and the model.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = LlavaForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            # On GPU let accelerate decide weight placement; on CPU load plainly.
            device_map="auto" if torch.cuda.is_available() else None,
        )
        self.model.eval()

    @staticmethod
    def _resolve_payload(data):
        """Return the dict that holds ``image`` / ``prompt`` for this request."""
        inner = data.get("inputs")
        # Unwrap the {"inputs": {...}} convention only when the flat form is not
        # used, so existing flat requests behave exactly as before.
        if "image" not in data and isinstance(inner, dict):
            return inner
        return data

    @classmethod
    def _generation_kwargs(cls, parameters):
        """Merge allow-listed overrides from *parameters* onto the defaults.

        Returns:
            ``(kwargs, None)`` on success, or ``(None, message)`` when
            *parameters* is not a dict or an override has an invalid value.
        """
        kwargs = dict(cls._GENERATION_DEFAULTS)
        if parameters is None:
            return kwargs, None
        if not isinstance(parameters, dict):
            return None, "'parameters' must be an object."
        for key, default in cls._GENERATION_DEFAULTS.items():
            if key not in parameters:
                continue
            value = parameters[key]
            # bool is checked first because bool is a subclass of int.
            if isinstance(default, bool):
                valid = isinstance(value, bool)
            elif isinstance(default, int):
                valid = isinstance(value, int) and not isinstance(value, bool) and value > 0
            else:
                valid = (
                    isinstance(value, (int, float))
                    and not isinstance(value, bool)
                    and value > 0
                )
            if not valid:
                return None, f"Invalid value for generation parameter '{key}': {value!r}"
            kwargs[key] = value
        return kwargs, None

    @staticmethod
    def _decode_image(image_b64):
        """Decode base64 text (optionally a ``data:`` URI) into an RGB PIL image."""
        if isinstance(image_b64, str) and image_b64.startswith("data:"):
            # Drop the "data:<mime>;base64," header. Non-strict b64decode only
            # discards its punctuation and would decode the letters as payload.
            _, _, image_b64 = image_b64.partition(",")
        image_bytes = base64.b64decode(image_b64)
        # convert() forces the lazy Image.open to actually read the data, so
        # corrupt images fail here rather than later in the processor.
        return Image.open(io.BytesIO(image_bytes)).convert("RGB")

    def __call__(self, data):
        """Generate a caption for the image in *data*.

        Args:
            data: request dict; see the class docstring for accepted shapes.

        Returns:
            ``{"caption": str}`` on success, or ``{"error": str}`` when the image
            is missing or undecodable, the prompt or parameters are invalid, or
            the chat template does not render to a string.
        """
        payload = self._resolve_payload(data)

        image_b64 = payload.get("image")
        if image_b64 is None:
            return {"error": "No image provided in the payload."}

        prompt = payload.get("prompt")
        if prompt is None:
            prompt = "Generate a caption for this image."
        elif not isinstance(prompt, str):
            return {"error": "'prompt' must be a string."}

        generation_kwargs, param_error = self._generation_kwargs(data.get("parameters"))
        if param_error is not None:
            return {"error": param_error}

        try:
            image = self._decode_image(image_b64)
        except Exception as e:  # request boundary: report any decode failure to the caller
            return {"error": f"Failed to decode image: {str(e)}"}

        # NOTE(review): the user turn is plain text, so this relies on the
        # checkpoint's chat template to insert the image placeholder; confirm
        # for the model being served.
        conversation = [
            {"role": "system", "content": "You are a helpful image captioner."},
            {"role": "user", "content": prompt},
        ]
        convo_string = self.processor.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True,
        )
        if not isinstance(convo_string, str):
            return {"error": "Failed to create conversation string."}

        inputs = self.processor(
            text=[convo_string],
            images=[image],
            return_tensors="pt",
        )
        # Use the model's own placement and dtype rather than guessing them.
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        if "pixel_values" in inputs:
            inputs["pixel_values"] = inputs["pixel_values"].to(self.model.dtype)

        output_ids = self.model.generate(**inputs, **generation_kwargs)[0]

        # Slice off the prompt tokens so only newly generated text is decoded.
        new_ids = output_ids[inputs["input_ids"].shape[1]:]
        caption = self.processor.tokenizer.decode(
            new_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        ).strip()

        return {"caption": caption}
|
|