from typing import Any, Dict, List

import requests
import torch
from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor


class EndpointHandler:
    def __init__(self, path: str = "./"):
        # Load the processor and model, moving the model to CUDA if available.
        self.processor = BlipProcessor.from_pretrained(path)
        self.model = BlipForConditionalGeneration.from_pretrained(path).to(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            image_url (:obj:`str`): URL of the image to caption
            prompt (:obj:`str`, optional): text prompt for conditional captioning
        Return:
            A :obj:`list` containing the caption as a :obj:`dict`
        """
        # Read inputs from the request payload.
        image_url = data.get("image_url")
        if not image_url:
            raise ValueError("`image_url` is required")
        prompt = data.get("prompt", "")  # optional prompt for conditional captioning

        # Load the image from the URL and ensure RGB format.
        image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")

        # Conditional captioning (with a prompt) or unconditional captioning (without).
        if prompt:
            inputs = self.processor(image, prompt, return_tensors="pt").to(self.model.device)
        else:
            inputs = self.processor(image, return_tensors="pt").to(self.model.device)

        # Generate and decode the caption.
        out = self.model.generate(**inputs)
        caption = self.processor.decode(out[0], skip_special_tokens=True)

        # Return the generated caption.
        return [{"caption": caption}]
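

# ---------------------------------------------------------------------------
# Minimal local smoke test (a sketch, not part of the endpoint contract):
# instantiating the handler directly and calling it with a sample payload
# mirrors what the Inference Endpoints runtime does at request time. The
# image URL below is a placeholder (an assumption, not from the source);
# substitute any publicly reachable image. `path` must point at a directory
# containing a BLIP checkpoint, e.g. a local clone of
# Salesforce/blip-image-captioning-base.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler(path="./")
    payload = {
        "image_url": "https://example.com/cat.jpg",  # placeholder URL (assumption)
        "prompt": "a photograph of",  # optional; omit for unconditional captioning
    }
    print(handler(payload))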