#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import base64
import json
import logging
import os
import time
import uuid
from io import BytesIO

import torch
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.staticfiles import StaticFiles
from PIL import Image
from pydantic import BaseModel
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

# Create the temporary folder if it doesn't exist.
TEMP_DIR = "/temp"
os.makedirs(TEMP_DIR, exist_ok=True)

app = FastAPI()

# Mount the temporary folder so annotated images can be served at /temp/<filename>.
app.mount("/temp", StaticFiles(directory=TEMP_DIR), name="temp")


# Define the request model.
class PredictRequest(BaseModel):
    image_base64: list[str]
    prompt: str


# Use the desired checkpoint: Qwen/Qwen2.5-VL-3B-Instruct-AWQ
checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct-AWQ"
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28

# Load the processor with the image resolution settings.
processor = AutoProcessor.from_pretrained(
    checkpoint, min_pixels=min_pixels, max_pixels=max_pixels
)

# Load the Qwen2.5-VL model.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    checkpoint,
    torch_dtype="auto",
    device_map="auto",
    # attn_implementation="flash_attention_2",
)


@app.get("/")
def read_root():
    return {"message": "API is live. Use the /predict endpoint."}


def encode_image(image_data: BytesIO, max_size=(800, 800), quality=85):
    """
    Converts an image from file data to a Base64-encoded string with optimized size.
    """
    try:
        with Image.open(image_data) as img:
            img = img.convert("RGB")
            img.thumbnail(max_size, Image.LANCZOS)
            buffer = BytesIO()
            img.save(buffer, format="JPEG", quality=quality)
            return base64.b64encode(buffer.getvalue()).decode("utf-8")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error encoding image: {e}")


@app.post("/encode-image/")
async def upload_and_encode_image(file: UploadFile = File(...)):
    """
    Endpoint to upload an image file and return its Base64-encoded representation.
    """
    try:
        image_data = BytesIO(await file.read())
        encoded_string = encode_image(image_data)
        return {"filename": file.filename, "encoded_image": encoded_string}
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid file: {e}")


@app.post("/predict")
def predict(data: PredictRequest, annotate: bool = False):
    """
    Generates a description (e.g. bounding boxes with labels) for image(s) using
    Qwen2.5-VL-3B-Instruct-AWQ.

    If 'annotate' is True (as a query parameter), the first image is annotated with
    the predicted bounding boxes, stored in a temporary folder, and its URL is returned.

    Request:
      - image_base64: List of base64-encoded images.
      - prompt: A prompt string.

    Response (JSON):
      {
        "response": "<generated text>",
        "annotated_image_url": "/temp/<filename>"  # only if annotate=True
      }
    """
    logging.warning("Calling /predict endpoint...")

    # Ensure image_base64 is a list.
    image_list = data.image_base64 if isinstance(data.image_base64, list) else [data.image_base64]

    # Create input messages: include all images and then the prompt.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": f"data:image;base64,{image}"}
                for image in image_list
            ]
            + [{"type": "text", "text": data.prompt}],
        }
    ]

    logging.info("Processing inputs... Number of images: %d", len(image_list))

    # Prepare inputs for the model using the processor's chat interface.
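    # apply_chat_template renders the messages into the model's expected prompt
    # string, while process_vision_info extracts and preprocesses the referenced
    # images (and videos, if present) so the processor can batch them with the text.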
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)

    logging.warning("Starting generation...")
    start_time = time.time()

    # Generate output using the model.
    generated_ids = model.generate(**inputs, max_new_tokens=2056)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    generation_time = time.time() - start_time
    logging.warning("Generation completed in %.2fs.", generation_time)

    # The generated output text is expected to be JSON (e.g., a list of detections).
    result_text = output_text[0] if output_text else "No description generated."
    response_data = {"response": result_text}

    if annotate:
        # Decode the first image for annotation.
        try:
            img_str = image_list[0]
            # If the image string contains a data URI prefix, remove it.
            if img_str.startswith("data:image"):
                img_str = img_str.split(",")[1]
            img_data = base64.b64decode(img_str)
            image = Image.open(BytesIO(img_data))
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error decoding image for annotation: {e}")

        # Determine image dimensions (width, height).
        input_wh = image.size
        resolution_wh = input_wh  # Assuming no resolution change.

        # Parse the detection result from the model output.
        try:
            detection_result = json.loads(result_text)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error parsing detection result: {e}")

        # Use the supervision library to create detections and annotate the image.
        try:
            import supervision as sv

            detections = sv.Detections.from_vlm(
                vlm=sv.VLM.QWEN_2_5_VL,
                result=detection_result,
                input_wh=input_wh,
                resolution_wh=resolution_wh,
            )
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error creating detections: {e}")

        try:
            box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
            label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
            annotated_image = image.copy()
            annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections)
            annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error annotating image: {e}")

        # Save the annotated image in the temporary folder.
        try:
            filename = f"{uuid.uuid4()}.jpg"
            filepath = os.path.join(TEMP_DIR, filename)
            annotated_image.save(filepath, format="JPEG")
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error saving annotated image: {e}")

        # Add the annotated image URL to the response.
        response_data["annotated_image_url"] = f"/temp/{filename}"

    return response_data
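
# Minimal sketch for local testing, not part of the original deployment setup:
# it assumes uvicorn is installed and that this file is the module being run.
if __name__ == "__main__":
    import uvicorn

    # Serve the API on all interfaces at port 8000 (both values are assumptions;
    # adjust to match your environment).
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example client call (illustrative only; assumes the server is reachable at
# localhost:8000, an image file `example.jpg` exists on disk, and the `requests`
# package is installed):
#
#   import base64, requests
#   with open("example.jpg", "rb") as f:
#       img_b64 = base64.b64encode(f.read()).decode("utf-8")
#   resp = requests.post(
#       "http://localhost:8000/predict?annotate=true",
#       json={"image_base64": [img_b64],
#             "prompt": "Detect every object and return bounding boxes as JSON."},
#   )
#   print(resp.json())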