#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import base64
import json
import logging
import os
import time
import uuid
from io import BytesIO
import torch
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.staticfiles import StaticFiles
from PIL import Image
from pydantic import BaseModel
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
# Create the temporary folder if it doesn't exist.
TEMP_DIR = "/temp"
os.makedirs(TEMP_DIR, exist_ok=True)
app = FastAPI()
# Mount the temporary folder so annotated images can be served at /temp/<filename>
app.mount("/temp", StaticFiles(directory=TEMP_DIR), name="temp")
# Define the request model
class PredictRequest(BaseModel):
    image_base64: list[str]
    prompt: str
# Use the desired checkpoint: Qwen/Qwen2.5-VL-3B-Instruct-AWQ
checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct-AWQ"
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
# Load the processor with the image resolution settings
processor = AutoProcessor.from_pretrained(
    checkpoint, min_pixels=min_pixels, max_pixels=max_pixels
)
# Load the Qwen2.5-VL model.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    checkpoint,
    torch_dtype="auto",
    device_map="auto",
    # attn_implementation="flash_attention_2",
)


@app.get("/")
def read_root():
    return {"message": "API is live. Use the /predict endpoint."}


def encode_image(image_data: BytesIO, max_size=(800, 800), quality=85):
    """
    Converts an image from file data to a Base64-encoded string with optimized size.
    """
    try:
        with Image.open(image_data) as img:
            img = img.convert("RGB")
            img.thumbnail(max_size, Image.LANCZOS)
            buffer = BytesIO()
            img.save(buffer, format="JPEG", quality=quality)
            return base64.b64encode(buffer.getvalue()).decode("utf-8")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error encoding image: {e}")


@app.post("/encode-image/")
async def upload_and_encode_image(file: UploadFile = File(...)):
    """
    Endpoint to upload an image file and return its Base64-encoded representation.
    """
    try:
        image_data = BytesIO(await file.read())
        encoded_string = encode_image(image_data)
        return {"filename": file.filename, "encoded_image": encoded_string}
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid file: {e}")
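
# Usage sketch for /encode-image (illustrative only, never executed by the app).
# The host and port below are assumptions for local testing, not part of the API.
#
#   import requests
#   with open("example.jpg", "rb") as f:
#       r = requests.post("http://localhost:8000/encode-image/", files={"file": f})
#   print(r.json()["filename"], len(r.json()["encoded_image"]))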
@app.post("/predict")
def predict(data: PredictRequest, annotate: bool = False):
"""
Generates a description (e.g. bounding boxes with labels) for image(s) using Qwen2.5-VL-3B-Instruct-AWQ.
If 'annotate' is True (as a query parameter), the first image is annotated with the predicted bounding boxes,
stored in a temporary folder, and its URL is returned.
Request:
- image_base64: List of base64-encoded images.
- prompt: A prompt string.
Response (JSON):
{
"response": <text generated by Qwen2.5-VL>,
"annotated_image_url": "/temp/<filename>" # only if annotate=True
}
"""
logging.warning("Calling /predict endpoint...")
# Ensure image_base64 is a list.
image_list = data.image_base64 if isinstance(data.image_base64, list) else [data.image_base64]
# Create input messages: include all images and then the prompt.
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": f"data:image;base64,{image}"}
for image in image_list
] + [{"type": "text", "text": data.prompt}],
}
]
logging.info("Processing inputs... Number of images: %d", len(image_list))
# Prepare inputs for the model using the processor's chat interface.
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
).to(model.device)
logging.warning("Starting generation...")
start_time = time.time()
# Generate output using the model.
generated_ids = model.generate(**inputs, max_new_tokens=2056)
generated_ids_trimmed = [
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
generation_time = time.time() - start_time
logging.warning("Generation completed in %.2fs.", generation_time)
# The generated output text is expected to be JSON (e.g., list of detections).
result_text = output_text[0] if output_text else "No description generated."
response_data = {"response": result_text}
    if annotate:
        # Decode the first image for annotation.
        try:
            img_str = image_list[0]
            # If the image string contains a data URI prefix, remove it.
            if img_str.startswith("data:image"):
                img_str = img_str.split(",")[1]
            img_data = base64.b64decode(img_str)
            # Convert to RGB so the annotated copy can always be saved as JPEG below.
            image = Image.open(BytesIO(img_data)).convert("RGB")
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error decoding image for annotation: {e}")
        # Determine image dimensions (width, height).
        input_wh = image.size
        resolution_wh = input_wh  # Assuming no resolution change.
        # Parse the detection result from the model output.
        try:
            detection_result = json.loads(result_text)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error parsing detection result: {e}")
        # Use the supervision library to create detections and annotate the image.
        try:
            import supervision as sv
            detections = sv.Detections.from_vlm(
                vlm=sv.VLM.QWEN_2_5_VL,
                result=detection_result,
                input_wh=input_wh,
                resolution_wh=resolution_wh
            )
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error creating detections: {e}")
        try:
            box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
            label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
            annotated_image = image.copy()
            annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections)
            annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error annotating image: {e}")
        # Save the annotated image in the temporary folder.
        try:
            filename = f"{uuid.uuid4()}.jpg"
            filepath = os.path.join(TEMP_DIR, filename)
            annotated_image.save(filepath, format="JPEG")
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error saving annotated image: {e}")
        # Add the annotated image URL to the response.
        response_data["annotated_image_url"] = f"/temp/{filename}"
    return response_data
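

# Local entry point. This is a sketch, not confirmed by the original file: it
# assumes uvicorn is installed and uses port 7860 (the usual Hugging Face
# Spaces port); adjust host/port for other deployments.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)

# Request sketch for /predict (illustrative only; host/port assumed as above,
# and the prompt is just an example):
#
#   import base64, requests
#   with open("example.jpg", "rb") as f:
#       b64 = base64.b64encode(f.read()).decode("utf-8")
#   payload = {"image_base64": [b64], "prompt": "Outline each object as JSON bounding boxes."}
#   r = requests.post("http://localhost:7860/predict?annotate=true", json=payload)
#   print(r.json())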