import base64
import io

import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration


class EndpointHandler:
    """Inference-endpoint handler that captions base64-encoded images with a LLaVA model.

    Construct once with a model path, then call with a request payload
    (see ``__call__`` for the accepted payload shapes).
    """

    # Sampling settings passed to ``generate`` on every request.
    GENERATION_KWARGS = {
        "max_new_tokens": 300,
        "do_sample": True,
        "temperature": 0.6,
        "top_p": 0.9,
    }
    DEFAULT_PROMPT = "Generate a caption for this image."
    SYSTEM_PROMPT = "You are a helpful image captioner."

    def __init__(self, model_path=""):
        """Load the processor and bfloat16 model weights from ``model_path``.

        With CUDA available the weights are placed via ``device_map="auto"``;
        otherwise they are loaded on the CPU.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = LlavaForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto" if torch.cuda.is_available() else None,
        )
        self.model.eval()

    def __call__(self, data):
        """Caption one or more images.

        Args:
            data: Request payload. ``data["inputs"]`` is normally a dict with
                ``"images"`` (one base64 string or a list of them; a
                ``data:...;base64,`` URL prefix is tolerated) and an optional
                ``"prompt"``. A bare base64 string or list of strings under
                ``"inputs"`` is also accepted and treated as the images.

        Returns:
            ``{"captions": str}`` when exactly one caption is produced,
            ``{"captions": [str, ...]}`` for several images, or
            ``{"error": str}`` when the payload cannot be used.
        """
        inputs = data.get("inputs", {})
        # Accept the shorthand where "inputs" is the image payload itself.
        if isinstance(inputs, (str, list)):
            inputs = {"images": inputs}
        elif not isinstance(inputs, dict):
            return {"error": f"Unsupported 'inputs' payload type: {type(inputs).__name__}."}

        prompt = inputs.get("prompt", self.DEFAULT_PROMPT)
        images_b64 = inputs.get("images")

        # Handle both single image and list of images
        if isinstance(images_b64, str):
            images_b64 = [images_b64]
        if not images_b64:
            return {"error": "No images provided in the payload."}
        if not isinstance(images_b64, (list, tuple)):
            return {"error": "'images' must be a base64 string or a list of base64 strings."}

        try:
            images = [self._decode_image(item) for item in images_b64]
        except Exception as e:  # request boundary: report any decode failure to the caller
            return {"error": f"Failed to decode image: {str(e)}"}

        convo_string = self._build_conversation(prompt)
        if not isinstance(convo_string, str):
            return {"error": "Failed to create conversation string."}

        captions = self._generate_captions(convo_string, images)
        return {"captions": captions if len(captions) > 1 else captions[0]}

    @staticmethod
    def _decode_image(encoded):
        """Decode one base64 image string (optionally a ``data:`` URL) into an RGB PIL image."""
        # b64decode (validate=False) would keep the '/' in "data:image/png;base64,"
        # and corrupt the payload, so drop any data-URL header first.
        if isinstance(encoded, str) and encoded.startswith("data:"):
            encoded = encoded.partition(",")[2]
        return Image.open(io.BytesIO(base64.b64decode(encoded))).convert("RGB")

    def _build_conversation(self, prompt):
        """Render the system + user captioning conversation through the processor's chat template."""
        # NOTE(review): the user turn carries no explicit image placeholder; this
        # relies on the checkpoint's chat template inserting the image token — confirm.
        conversation = [
            {"role": "system", "content": self.SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ]
        return self.processor.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True,
        )

    def _generate_captions(self, convo_string, images):
        """Run batched generation and return one decoded caption per image."""
        # generate() yields one sequence per prompt row, so the prompt is repeated
        # once per image; passing a single row would only ever produce one caption.
        # The rows are identical, hence equal length, so no padding is required.
        model_inputs = self.processor(
            text=[convo_string] * len(images),
            images=images,
            return_tensors="pt",
        )
        model_inputs = {k: v.to(self.device) for k, v in model_inputs.items()}
        if "pixel_values" in model_inputs:
            # Match the bfloat16 weights loaded in __init__.
            model_inputs["pixel_values"] = model_inputs["pixel_values"].to(torch.bfloat16)

        generate_ids = self.model.generate(**model_inputs, **self.GENERATION_KWARGS)

        # Trim off the prompt tokens so only the newly generated caption is decoded.
        prompt_len = model_inputs["input_ids"].shape[1]
        decoded = self.processor.tokenizer.batch_decode(
            generate_ids[:, prompt_len:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return [caption.strip() for caption in decoded]