import base64
import logging
import time
from io import BytesIO
from typing import Union

import torch
from fastapi import Body, FastAPI, File, HTTPException, Query, UploadFile
from PIL import Image
from pydantic import BaseModel
from qwen_vl_utils import process_vision_info
from transformers import (
    AutoProcessor,
    Qwen2_5_VLForConditionalGeneration,
    Qwen2VLForConditionalGeneration,
)

app = FastAPI()


class PredictRequest(BaseModel):
    """Request body for the ``/predict`` endpoint.

    ``image_base64`` holds one or more base64-encoded images. Each entry may
    be raw base64 or a full ``data:<mime>;base64,<payload>`` URI. A single
    string is accepted as shorthand for a one-element list.
    """

    image_base64: Union[list[str], str]
    prompt: str


# Model configuration. To use the smaller previous-generation model instead,
# load checkpoint "Qwen/Qwen2-VL-2B-Instruct" with
# Qwen2VLForConditionalGeneration (same processor pixel bounds).
checkpoint = "Qwen/Qwen2.5-VL-7B-Instruct"
# Lower/upper bounds on image size handed to the processor's resizing logic.
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
processor = AutoProcessor.from_pretrained(
    checkpoint, min_pixels=min_pixels, max_pixels=max_pixels
)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    checkpoint,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    # attn_implementation="flash_attention_2",
)

# Upper bound on generated tokens per /predict call.
# NOTE(review): 2056 may have been intended as 2048 — confirm.
MAX_NEW_TOKENS = 2056


@app.get("/")
def read_root():
    """Health-check endpoint confirming the API is up."""
    return {"message": "API is live. Use the /predict endpoint."}


def encode_image(image_data: BytesIO, max_size=(800, 800), quality=85):
    """Re-encode an image as a size-reduced, Base64-encoded JPEG.

    The image is converted to RGB, shrunk with ``Image.thumbnail`` (aspect
    ratio preserved) to fit within ``max_size``, saved as JPEG at
    ``quality``, and the resulting bytes are Base64-encoded.

    Args:
        image_data: Binary stream containing the source image.
        max_size: (width, height) bounding box for the thumbnail.
        quality: JPEG quality passed to Pillow's encoder.

    Returns:
        str: The Base64-encoded JPEG bytes, decoded as UTF-8 text.

    Raises:
        HTTPException: Status 500 if the image cannot be opened or encoded.
    """
    try:
        with Image.open(image_data) as img:
            img = img.convert("RGB")
            img.thumbnail(max_size, Image.LANCZOS)
            buffer = BytesIO()
            img.save(buffer, format="JPEG", quality=quality)
            return base64.b64encode(buffer.getvalue()).decode("utf-8")
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error encoding image: {e}"
        ) from e


@app.post("/encode-image/")
async def upload_and_encode_image(file: UploadFile = File(...)):
    """
    Endpoint to upload an image file and return its Base64-encoded representation.

    Raises:
        HTTPException: Status 400 if the upload cannot be read or decoded as
            an image.
    """
    try:
        image_data = BytesIO(await file.read())
        encoded_string = encode_image(image_data)
    except HTTPException as e:
        # encode_image reports failures as HTTPException(500). An undecodable
        # upload is a client error, so keep the 400 status. Reuse the inner
        # detail text rather than nesting the "500: ..." string
        # representation inside this message.
        raise HTTPException(
            status_code=400, detail=f"Invalid file: {e.detail}"
        ) from e
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid file: {e}") from e
    return {"filename": file.filename, "encoded_image": encoded_string}


def _to_data_uri(image_b64: str) -> str:
    """Wrap base64 image data in the ``data:image;base64,`` URI used in messages.

    If the client already sent a data URI (e.g. ``data:image/png;base64,...``),
    its header is stripped first so the prefix is not duplicated. Raw base64
    can never contain ``:`` or ``,``, so the check cannot misfire on it.
    """
    payload = image_b64.strip()
    if payload.startswith("data:") and "base64," in payload:
        payload = payload.split("base64,", 1)[1]
    return f"data:image;base64,{payload}"


@app.post("/predict")
def predict(data: PredictRequest):
    """
    Generates a description for the image(s) using the Qwen2.5-VL model.

    Args:
        data (PredictRequest): The request containing encoded images and a prompt.

    Returns:
        dict: ``{"response": <generated text>}``, or a placeholder message if
        the model produced no output.

    Raises:
        HTTPException: Status 400 if any image cannot be decoded or
            preprocessed.
    """
    logging.warning("Calling /predict endpoint...")

    # Accept either a single base64 string or a list of them.
    image_list = (
        data.image_base64
        if isinstance(data.image_base64, list)
        else [data.image_base64]
    )

    # One user turn: every image first, followed by the text prompt.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": _to_data_uri(image)}
                for image in image_list
            ]
            + [{"type": "text", "text": data.prompt}],
        }
    ]

    logging.info("Processing %d input image(s)...", len(image_list))

    # Prepare inputs for the model.
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    try:
        image_inputs, video_inputs = process_vision_info(messages)
    except (ValueError, OSError, Image.DecompressionBombError) as e:
        # Bad base64 (binascii.Error is a ValueError), undecodable or oversized
        # images are client errors, not server failures.
        raise HTTPException(
            status_code=400, detail=f"Invalid image data: {e}"
        ) from e
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)

    logging.warning("Starting generation...")
    start_time = time.perf_counter()

    generated_ids = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
    # generate() returns prompt + completion; drop the prompt tokens so only
    # the newly generated text is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    logging.warning(
        f"Generation completed in {time.perf_counter() - start_time:.2f}s."
    )
    return {"response": output_text[0] if output_text else "No description generated."}