#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import base64
import json
import logging
import os
import time
import uuid
from io import BytesIO
import torch
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.staticfiles import StaticFiles
from PIL import Image
from pydantic import BaseModel
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
# Create the temporary folder if it doesn't exist.
TEMP_DIR = "/temp"
os.makedirs(TEMP_DIR, exist_ok=True)
app = FastAPI()
# Mount the temporary folder so annotated images can be served at /temp/<filename>
app.mount("/temp", StaticFiles(directory=TEMP_DIR), name="temp")
# Define the request model
class PredictRequest(BaseModel):
    image_base64: list[str]
    prompt: str
# Use the desired checkpoint: Qwen/Qwen2.5-VL-3B-Instruct-AWQ
checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct-AWQ"
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
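# Each visual token corresponds to a 28x28 pixel patch after resizing, so these
# bounds keep every image between roughly 256 and 1280 visual tokens (per the
# Qwen2.5-VL processor documentation).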
# Load the processor with the image resolution settings
processor = AutoProcessor.from_pretrained(
    checkpoint, min_pixels=min_pixels, max_pixels=max_pixels
)
# Load the Qwen2.5-VL model.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    checkpoint,
    torch_dtype="auto",
    device_map="auto",
    # attn_implementation="flash_attention_2",
)
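# Note: loading this AWQ checkpoint assumes the `autoawq` package is installed,
# and `device_map="auto"` relies on `accelerate`; neither is imported here, so
# treat both as environment assumptions.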
@app.get("/")
def read_root():
    return {"message": "API is live. Use the /predict endpoint."}
def encode_image(image_data: BytesIO, max_size=(800, 800), quality=85):
    """
    Converts an image from file data to a Base64-encoded string with optimized size.
    """
    try:
        with Image.open(image_data) as img:
            img = img.convert("RGB")
            img.thumbnail(max_size, Image.LANCZOS)
            buffer = BytesIO()
            img.save(buffer, format="JPEG", quality=quality)
            return base64.b64encode(buffer.getvalue()).decode("utf-8")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error encoding image: {e}")
@app.post("/encode-image/")
async def upload_and_encode_image(file: UploadFile = File(...)):
    """
    Endpoint to upload an image file and return its Base64-encoded representation.
    """
    try:
        image_data = BytesIO(await file.read())
        encoded_string = encode_image(image_data)
        return {"filename": file.filename, "encoded_image": encoded_string}
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid file: {e}")
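# Example call (hypothetical host/port; a local run on 7860 is assumed):
#   curl -X POST "http://localhost:7860/encode-image/" -F "file=@example.jpg"
# -> {"filename": "example.jpg", "encoded_image": "<base64 string>"}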
@app.post("/predict")
def predict(data: PredictRequest, annotate: bool = False):
"""
Generates a description (e.g. bounding boxes with labels) for image(s) using Qwen2.5-VL-3B-Instruct-AWQ.
If 'annotate' is True (as a query parameter), the first image is annotated with the predicted bounding boxes,
stored in a temporary folder, and its URL is returned.
Request:
- image_base64: List of base64-encoded images.
- prompt: A prompt string.
Response (JSON):
{
"response": <text generated by Qwen2.5-VL>,
"annotated_image_url": "/temp/<filename>" # only if annotate=True
}
"""
logging.warning("Calling /predict endpoint...")
# Ensure image_base64 is a list.
image_list = data.image_base64 if isinstance(data.image_base64, list) else [data.image_base64]
    # Create input messages: include all images and then the prompt.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": f"data:image;base64,{image}"}
                for image in image_list
            ] + [{"type": "text", "text": data.prompt}],
        }
    ]
    logging.info("Processing inputs... Number of images: %d", len(image_list))
    # Prepare inputs for the model using the processor's chat interface.
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)
logging.warning("Starting generation...")
start_time = time.time()
# Generate output using the model.
generated_ids = model.generate(**inputs, max_new_tokens=2056)
generated_ids_trimmed = [
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
generation_time = time.time() - start_time
logging.warning("Generation completed in %.2fs.", generation_time)
# The generated output text is expected to be JSON (e.g., list of detections).
result_text = output_text[0] if output_text else "No description generated."
response_data = {"response": result_text}
    if annotate:
        # Decode the first image for annotation.
        try:
            img_str = image_list[0]
            # If the image string contains a data URI prefix, remove it.
            if img_str.startswith("data:image"):
                img_str = img_str.split(",")[1]
            img_data = base64.b64decode(img_str)
            # Convert to RGB so the annotated result can be saved as JPEG later.
            image = Image.open(BytesIO(img_data)).convert("RGB")
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error decoding image for annotation: {e}")
        # Determine image dimensions (width, height).
        input_wh = image.size
        resolution_wh = input_wh  # Assumes the model's coordinates refer to the original resolution.
        # Parse the detection result from the model output. For grounding prompts,
        # Qwen2.5-VL typically emits JSON like [{"bbox_2d": [x1, y1, x2, y2], "label": "..."}].
        try:
            detection_result = json.loads(result_text)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error parsing detection result: {e}")
        # Use the supervision library to create detections and annotate the image.
        try:
            import supervision as sv

            detections = sv.Detections.from_vlm(
                vlm=sv.VLM.QWEN_2_5_VL,
                result=detection_result,
                input_wh=input_wh,
                resolution_wh=resolution_wh,
            )
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error creating detections: {e}")
        try:
            box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
            label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
            annotated_image = image.copy()
            annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections)
            annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error annotating image: {e}")
        # Save the annotated image in the temporary folder.
        try:
            filename = f"{uuid.uuid4()}.jpg"
            filepath = os.path.join(TEMP_DIR, filename)
            annotated_image.save(filepath, format="JPEG")
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error saving annotated image: {e}")
        # Add the annotated image URL to the response.
        response_data["annotated_image_url"] = f"/temp/{filename}"
    return response_data
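# Local entry point: a minimal sketch, assuming `uvicorn` is installed and that
# port 7860 (the Hugging Face Spaces convention) is free. On Spaces the runtime
# usually launches the app itself, so this guard only matters for direct runs.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example /predict call (hypothetical host/port and image contents):
#   curl -X POST "http://localhost:7860/predict?annotate=true" \
#        -H "Content-Type: application/json" \
#        -d '{"image_base64": ["<base64 string>"], "prompt": "Detect all objects and return bounding boxes as JSON."}'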