import base64
from io import BytesIO
from typing import List, Union

import torch
from fastapi import Body, FastAPI, File, HTTPException, Query, UploadFile
from PIL import Image
from pydantic import BaseModel
from qwen_vl_utils import process_vision_info
from transformers import (
    AutoProcessor,
    Qwen2_5_VLForConditionalGeneration,
    Qwen2VLForConditionalGeneration,
)

app = FastAPI()


class PredictRequest(BaseModel):
    """Request body accepted by the ``/predict`` endpoint."""

    # One Base64-encoded image, or a list of them. The endpoint normalizes a
    # single string to a one-element list, so both shapes are accepted here.
    image_base64: Union[str, List[str]]
    # Text instruction sent to the model together with the image(s).
    prompt: str


# Checkpoint served by this API.
# To serve the smaller "Qwen/Qwen2-VL-2B-Instruct" checkpoint instead, change
# `checkpoint` and load the model with Qwen2VLForConditionalGeneration using
# the same keyword arguments.
checkpoint = "Qwen/Qwen2.5-VL-7B-Instruct"

# Lower/upper pixel budget handed to the processor for image resizing.
# NOTE(review): per the Qwen-VL model card these correspond to 256-1280 visual
# tokens of 28x28 pixels each -- confirm against the processor version in use.
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28

# Both objects are created once at import time and shared by every request.
processor = AutoProcessor.from_pretrained(
    checkpoint, min_pixels=min_pixels, max_pixels=max_pixels
)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    checkpoint,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    # attn_implementation="flash_attention_2",
)


@app.get("/")
def read_root():
    """Liveness endpoint: return a static message pointing to /predict."""
    return {"message": "API is live. Use the /predict endpoint."}
def encode_image(image_data: BytesIO, max_size=(800, 800), quality=85):
    """
    Convert raw image file data to a size-optimized Base64-encoded JPEG string.

    Args:
        image_data (BytesIO): Binary stream containing the image file.
        max_size (tuple): Maximum (width, height); the image is shrunk to fit
            while keeping its aspect ratio.
        quality (int): JPEG quality passed to the encoder.

    Returns:
        str: Base64 text of the re-encoded JPEG.

    Raises:
        HTTPException: Status 500 if the image cannot be opened, converted
            or saved.
    """
    try:
        with Image.open(image_data) as img:
            # JPEG has no alpha channel, so normalize to RGB before saving.
            rgb = img.convert("RGB")
            # thumbnail() resizes in place and never enlarges the image.
            rgb.thumbnail(max_size, Image.LANCZOS)
            buffer = BytesIO()
            rgb.save(buffer, format="JPEG", quality=quality)
            return base64.b64encode(buffer.getvalue()).decode("utf-8")
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error encoding image: {e}"
        ) from e


@app.post("/encode-image/")
async def upload_and_encode_image(file: UploadFile = File(...)):
    """
    Upload an image file and return its Base64-encoded representation.

    Args:
        file (UploadFile): The uploaded image.

    Returns:
        dict: ``{"filename": ..., "encoded_image": ...}``.

    Raises:
        HTTPException: Status 400 if the upload cannot be read or encoded.
    """
    try:
        image_data = BytesIO(await file.read())
        encoded_string = encode_image(image_data)
    except HTTPException as e:
        # encode_image already produced a readable detail; reuse it instead of
        # str(e), which would nest the inner status code (or be empty) in the
        # message returned to the client.
        raise HTTPException(
            status_code=400, detail=f"Invalid file: {e.detail}"
        ) from e
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid file: {e}") from e
    return {"filename": file.filename, "encoded_image": encoded_string}


def _as_data_uri(image: str) -> str:
    """Return *image* as a data URI, leaving existing image data URIs untouched."""
    if image.startswith("data:image"):
        # Already a data URI; adding the prefix again would corrupt the payload.
        return image
    return f"data:image;base64,{image}"


@app.post("/predict")
def predict(data: PredictRequest):
    """
    Generate a response for one or more images using the Qwen2.5-VL model.

    Args:
        data (PredictRequest): The request containing the Base64-encoded
            image(s) and a text prompt. Each image may be raw Base64 text or
            a complete ``data:image...`` URI.

    Returns:
        dict: ``{"response": <generated text>}``.

    Raises:
        HTTPException: Status 400 if the supplied image data cannot be loaded.
    """
    # Ensure we always work with a list, even if a single image is provided.
    image_list = (
        data.image_base64
        if isinstance(data.image_base64, list)
        else [data.image_base64]
    )

    # One user turn: every image first, then the text prompt.
    content = [
        {"type": "image", "image": _as_data_uri(image)} for image in image_list
    ]
    content.append({"type": "text", "text": data.prompt})
    messages = [{"role": "user", "content": content}]

    # Prepare inputs for the model.
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    try:
        image_inputs, video_inputs = process_vision_info(messages)
    except Exception as e:
        # The images come straight from the client, so a failure to load them
        # is reported as a bad request rather than an internal server error.
        raise HTTPException(
            status_code=400, detail=f"Invalid image data: {e}"
        ) from e

    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)

    # Generate the output.
    # NOTE(review): 2056 may have been intended as 2048 -- confirm.
    generated_ids = model.generate(**inputs, max_new_tokens=2056)

    # generate() returns prompt + completion; strip the prompt tokens so only
    # the newly generated part is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    return {
        "response": output_text[0] if output_text else "No description generated."
    }