# qwen2.5-VL-api / main.py
# Last commit a2b6d64 by danilohssantana: "fixing image loading" (file size 3.14 kB).
import base64
import urllib.request
from io import BytesIO

import torch
from fastapi import FastAPI, HTTPException, Query
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
# FastAPI application that exposes the model over HTTP.
app = FastAPI()

# Hugging Face Hub id of the vision-language model served by this API.
checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct"

# Per-image pixel budget handed to the processor, expressed in units of
# 28x28-pixel tiles (presumably one visual token each for Qwen2.5-VL —
# confirm against the model card), i.e. roughly 256 to 1280 tiles per image.
_TILE_AREA = 28 * 28
min_pixels = 256 * _TILE_AREA
max_pixels = 1280 * _TILE_AREA

processor = AutoProcessor.from_pretrained(
    checkpoint,
    min_pixels=min_pixels,
    max_pixels=max_pixels,
)

# Weights are loaded in bfloat16; device_map="auto" lets transformers choose
# device placement. Optionally add attn_implementation="flash_attention_2"
# for faster attention when the flash-attn package is installed.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    checkpoint,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
@app.get("/")
def read_root():
    """Liveness probe: confirm the service is up and point callers to /predict."""
    status_message = "API is live. Use the /predict endpoint."
    return {"message": status_message}
def encode_image(
    image_path,
    max_size=(800, 800),
    quality=85,
    timeout=10,
    max_download_bytes=20 * 1024 * 1024,
):
    """
    Load an image, shrink it, and return it as a Base64-encoded JPEG string.

    The image is converted to RGB (dropping any alpha channel, e.g. from PNGs),
    downscaled in place to fit within ``max_size`` while keeping its aspect
    ratio, and re-encoded as JPEG at the given ``quality``.

    Args:
        image_path (str | os.PathLike): Local file path, or an ``http://`` /
            ``https://`` URL from which the image is downloaded.
        max_size (tuple): The maximum width and height of the resized image.
        quality (int): JPEG compression quality (1-100, higher means better
            quality but bigger output).
        timeout (float): Seconds to wait on the network when ``image_path``
            is a URL.
        max_download_bytes (int): Largest response body accepted when
            downloading from a URL; bigger downloads are rejected.

    Returns:
        str | None: Base64-encoded JPEG data, or ``None`` if the image could
        not be read or encoded (the error is printed to stdout).
    """
    try:
        source = image_path
        if isinstance(image_path, str) and image_path.lower().startswith(
            ("http://", "https://")
        ):
            # Image.open cannot read URLs, so fetch the raw bytes first.
            with urllib.request.urlopen(image_path, timeout=timeout) as response:
                # Read one byte past the limit so an oversize body is detected
                # without pulling an unbounded response into memory.
                data = response.read(max_download_bytes + 1)
            if len(data) > max_download_bytes:
                raise ValueError(f"download exceeds {max_download_bytes} bytes")
            source = BytesIO(data)

        with Image.open(source) as img:
            # Convert to RGB (JPEG has no alpha; avoids PNG transparency issues)
            img = img.convert("RGB")
            # Resize in place while maintaining aspect ratio (never enlarges)
            img.thumbnail(max_size, Image.LANCZOS)
            # Save to an in-memory buffer as JPEG to reduce size
            buffer = BytesIO()
            img.save(buffer, format="JPEG", quality=quality)
        return base64.b64encode(buffer.getvalue()).decode("utf-8")
    except Exception as e:
        # Best-effort contract: report the problem and signal failure with None.
        print(f"❌ Error encoding image {image_path}: {e}")
        return None
@app.get("/predict")
def predict(
    image_url: str = Query(...),
    prompt: str = Query(...),
    max_new_tokens: int = Query(128, ge=1, le=1024),
):
    """
    Answer ``prompt`` about the image referenced by ``image_url``.

    Query parameters:
        image_url: Image location handed to ``encode_image``, which loads it,
            downscales it and returns Base64-encoded JPEG data.
        prompt: User text sent to the model alongside the image.
        max_new_tokens: Upper bound on the number of generated tokens
            (default 128, the previously hard-coded value).

    Returns:
        dict: ``{"response": <generated text>}``.

    Raises:
        HTTPException: 400 if the image could not be loaded or encoded.
    """
    # NOTE(review): image_url is client-controlled and is passed unchecked to
    # encode_image, which opens local paths on the server — consider
    # restricting what clients may load.
    image_b64 = encode_image(image_url)
    if image_b64 is None:
        # encode_image signals failure with None; without this check the
        # literal text "None" would be embedded as image data and the request
        # would presumably fail later in vision preprocessing with a 500.
        raise HTTPException(
            status_code=400,
            detail="Could not load or encode the image at image_url.",
        )

    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant with vision abilities.",
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": f"data:image;base64,{image_b64}"},
                {"type": "text", "text": prompt},
            ],
        },
    ]

    # Render the chat template to a prompt string; tokenization happens in
    # the processor call below together with the image inputs.
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)

    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # Each generated sequence starts with its prompt tokens; slice them off so
    # only the newly generated continuation is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_texts = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return {"response": output_texts[0]}