# Qwen2-VL-2B-Instruct / handler.py
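"""Custom handler for a Hugging Face Inference Endpoint running
Qwen2-VL-2B-Instruct: takes an image (URL or base64 string) plus an optional
text prompt and returns the model's generated description."""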
import base64
import io
from typing import Any, Dict

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the processor (tokenizer + image preprocessor) and the model,
        # letting accelerate place the weights on the available devices.
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            path, device_map="auto"
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # The payload carries an "image" (URL or base64 string) and an
        # optional "text" prompt.
        image_input = data.get("image")
        text_input = data.get("text", "Describe this image.")
        if image_input is None:
            return {"error": "No image provided."}
        try:
            # Accept either a URL or a base64-encoded image payload.
            if image_input.startswith("http"):
                image = Image.open(
                    requests.get(image_input, stream=True, timeout=30).raw
                ).convert("RGB")
            else:
                image_data = base64.b64decode(image_input)
                image = Image.open(io.BytesIO(image_data)).convert("RGB")
        except Exception as e:
            return {"error": f"Failed to process the image. Details: {str(e)}"}
        # Build a single-turn conversation; the image placeholder is filled
        # in by the processor call below.
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": text_input},
                ],
            }
        ]
        text_prompt = self.processor.apply_chat_template(
            conversation, add_generation_prompt=True
        )
        inputs = self.processor(
            text=[text_prompt],
            images=[image],
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.model.device)
        # Pure inference: generate without tracking gradients.
        with torch.inference_mode():
            output_ids = self.model.generate(**inputs, max_new_tokens=128)
        # Strip the prompt tokens so only the newly generated text is decoded.
        generated_ids = [
            output_id[len(input_id):]
            for input_id, output_id in zip(inputs.input_ids, output_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )[0]
        return {"generated_text": output_text}
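

# A minimal local smoke test (not part of the original handler): it assumes
# the model weights are available at "./Qwen2-VL-2B-Instruct" and that an
# "example.jpg" sits next to this file. The deployed endpoint receives the
# same dict as a JSON payload.
if __name__ == "__main__":
    handler = EndpointHandler("./Qwen2-VL-2B-Instruct")
    with open("example.jpg", "rb") as f:
        payload = {
            "image": base64.b64encode(f.read()).decode("utf-8"),
            "text": "Describe this image.",
        }
    print(handler(payload))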