# qwen2.5-VL-api / main.py
import base64
from io import BytesIO

import requests
import torch
from fastapi import FastAPI, HTTPException, Query
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

app = FastAPI()

checkpoint = "Qwen/Qwen2-VL-7B-Instruct"
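# Pixel bounds for the processor: Qwen2-VL resizes each image so its area
# stays within [min_pixels, max_pixels], counted in its 28x28 vision patches,
# trading visual detail against memory and latency.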
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
processor = AutoProcessor.from_pretrained(
    checkpoint, min_pixels=min_pixels, max_pixels=max_pixels
)
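
# device_map="auto" lets accelerate place the weights on whatever devices are
# visible (falling back to CPU); FlashAttention 2 can be enabled below if the
# flash-attn package is installed.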
model = Qwen2VLForConditionalGeneration.from_pretrained(
    checkpoint,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    # attn_implementation="flash_attention_2",
)
@app.get("/")
def read_root():
return {"message": "API is live. Use the /predict endpoint."}


def encode_image(image_source, max_size=(800, 800), quality=85):
    """
    Converts an image from a local file path or HTTP(S) URL to a
    Base64-encoded string with optimized size.

    Args:
        image_source (str): Path or URL of the image file.
        max_size (tuple): The maximum width and height of the resized image.
        quality (int): JPEG quality (1-100; higher means better quality but a bigger payload).

    Returns:
        str: Base64-encoded representation of the optimized image, or None on failure.
    """
    try:
        if image_source.startswith(("http://", "https://")):
            # The /predict endpoint passes a URL, so fetch remote images
            # instead of treating the string as a local path.
            response = requests.get(image_source, timeout=30)
            response.raise_for_status()
            source = BytesIO(response.content)
        else:
            source = image_source
        with Image.open(source) as img:
            # Convert to RGB (avoids issues with PNG transparency)
            img = img.convert("RGB")
            # Resize in place while maintaining aspect ratio
            img.thumbnail(max_size, Image.LANCZOS)
            # Re-encode as JPEG to shrink the Base64 payload
            buffer = BytesIO()
            img.save(buffer, format="JPEG", quality=quality)
            return base64.b64encode(buffer.getvalue()).decode("utf-8")
    except Exception as e:
        print(f"❌ Error encoding image {image_source}: {e}")
        return None
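

# e.g. encode_image("https://example.com/cat.jpg") returns a "/9j/..."-prefixed
# JPEG Base64 string (the URL here is illustrative), or None if loading fails.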
@app.get("/predict")
def describe_image_with_qwen2_vl(image_url: str = Query(...), prompt: str = Query(...)):
"""
Generates a description for an image using the Qwen-2-VL model.
Args:
image_url (str): The URL of the image to describe.
prompt (str): The text prompt to guide the model's response.
Returns:
str: The generated description of the image.
"""
image = encode_image(image_url)
# Create the input message structure
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": f"data:image;base64,{image}"},
{"type": "text", "text": prompt},
],
}
]
# Prepare inputs for the model
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
).to("cuda:0")
# Generate the output
generated_ids = model.generate(**inputs, max_new_tokens=2056)
generated_ids_trimmed = [
out_ids[len(in_ids) :]
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
generated_ids_trimmed,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)
return {"response": output_text[0] if output_text else "No description generated."}
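

# Example request (hypothetical host/port and image URL, for illustration only):
#   curl "http://localhost:7860/predict?image_url=https://example.com/cat.jpg&prompt=Describe+this+image"

# Minimal local entry point, a sketch assuming uvicorn is installed (Hugging
# Face Spaces conventionally serve on port 7860; adjust for other hosts):
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)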