thisnick's picture
Upload full model folder with custom handler
c6aa952 verified
raw
history blame
2.27 kB
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
from PIL import Image
import base64
import io
class EndpointHandler:
    """Callable request handler that captions an image with a LLaVA model.

    The handler is constructed once with the path of a model folder and is
    then called once per request with the decoded request payload (a dict).
    Every call returns a dict holding either a ``"caption"`` or an
    ``"error"`` key; request-level problems (missing or undecodable image)
    are reported through the ``"error"`` key instead of being raised.
    """

    # Prompt used when the request does not carry one.
    DEFAULT_PROMPT = "Generate a caption for this image."
    # System message placed in front of every conversation.
    SYSTEM_PROMPT = "You are a helpful image captioner."

    def __init__(self, model_path=""):
        """Load the processor and the model stored at ``model_path``.

        Args:
            model_path: Location handed to both ``from_pretrained`` calls.
        """
        has_cuda = torch.cuda.is_available()
        self.device = "cuda" if has_cuda else "cpu"
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = LlavaForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            # Let the loader place the weights on the GPU(s) when one exists.
            device_map="auto" if has_cuda else None,
        )
        self.model.eval()

    @staticmethod
    def _first_present(data, nested, key):
        """Return ``data[key]`` unless it is missing/None, else ``nested.get(key)``."""
        value = data.get(key)
        return nested.get(key) if value is None else value

    def __call__(self, data):
        """Generate a caption for the image contained in ``data``.

        The fields are read from the top level of ``data`` first and, when
        absent there, from a nested ``data["inputs"]`` dict, so both
        ``{"image": ...}`` and ``{"inputs": {"image": ...}}`` are accepted.

        Args:
            data: Request payload with the fields
                ``"image"``: base64-encoded image, optionally prefixed with a
                ``data:<mime>;base64,`` data-URI header (required);
                ``"prompt"``: instruction text (optional, defaults to
                ``DEFAULT_PROMPT`` when missing or ``None``).

        Returns:
            ``{"caption": <text>}`` on success, otherwise
            ``{"error": <message>}``.
        """
        nested = data.get("inputs")
        if not isinstance(nested, dict):
            nested = {}

        prompt = self._first_present(data, nested, "prompt")
        if prompt is None:
            prompt = self.DEFAULT_PROMPT

        image_b64 = self._first_present(data, nested, "image")
        if not image_b64:
            return {"error": "No image provided in the payload."}

        # Browsers and many clients send "data:image/png;base64,<payload>";
        # only the part after the first comma is base64.
        if isinstance(image_b64, str) and image_b64.startswith("data:"):
            image_b64 = image_b64.partition(",")[2]

        try:
            image_bytes = base64.b64decode(image_b64)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except Exception as e:
            # Deliberately broad: this is the request boundary, and any
            # failure to turn the field into an image is a client error.
            return {"error": f"Failed to decode image: {str(e)}"}

        # Build the conversation template for captioning.
        conversation = [
            {"role": "system", "content": self.SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ]
        convo_string = self.processor.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True,
        )
        if not isinstance(convo_string, str):
            return {"error": "Failed to create conversation string."}

        # Prepare the model inputs and move every tensor to the target device.
        inputs = self.processor(
            text=[convo_string],
            images=[image],
            return_tensors="pt",
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        if "pixel_values" in inputs:
            # The model weights are loaded as bfloat16, so the image tensor
            # is cast to the same dtype.
            inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)

        # Sample the caption tokens for the single sequence in the batch.
        generate_ids = self.model.generate(
            **inputs,
            max_new_tokens=300,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
        )[0]

        # Drop the leading prompt tokens so only the new tokens are decoded.
        prompt_length = inputs["input_ids"].shape[1]
        generate_ids = generate_ids[prompt_length:]

        caption = self.processor.tokenizer.decode(
            generate_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        ).strip()
        return {"caption": caption}