"""FastAPI service exposing a single endpoint that annotates an uploaded image.

At import time the module loads a YOLO model and a captioning model onto the
GPU; the ``/process_image`` endpoint then runs OCR, detection and captioning
on each uploaded image through the helpers in ``utils``.
"""

import base64
import io
import os
import shutil
import tempfile

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoProcessor
from ultralytics import YOLO

# Import utility functions
from utils import check_ocr_box, get_som_labeled_img

# Initialize models and processor. Failures are re-raised as RuntimeError with
# the original exception chained so the startup traceback shows the real cause.
try:
    yolo_model = YOLO("weights/icon_detect/best.pt").to("cuda")
except Exception as e:
    raise RuntimeError(f"Error loading YOLO model: {e}") from e

processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base", trust_remote_code=True
)

try:
    model = AutoModelForCausalLM.from_pretrained(
        "weights/icon_caption_florence",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    ).to("cuda")
except Exception as e:
    raise RuntimeError(f"Error loading captioning model: {e}") from e

# Bundle handed to get_som_labeled_img for captioning.
caption_model_processor = {"processor": processor, "model": model}

# FastAPI app initialization
app = FastAPI()


class ProcessResponse(BaseModel):
    """Response payload returned by the ``/process_image`` endpoint."""

    image: str  # Base64 encoded image
    parsed_content_list: str  # Parsed content entries joined with newlines
    label_coordinates: str  # str() of the coordinates returned by the labeler


def process(
    image_input: Image.Image, box_threshold: float, iou_threshold: float
) -> ProcessResponse:
    """Run OCR, detection and captioning on an image and package the result.

    Args:
        image_input: Image to analyse.
        box_threshold: Forwarded to ``get_som_labeled_img`` as ``BOX_TRESHOLD``.
        iou_threshold: Forwarded to ``get_som_labeled_img`` as ``iou_threshold``.

    Returns:
        A ``ProcessResponse`` holding the labeled image as base64-encoded PNG,
        the parsed content joined with newlines, and the stringified label
        coordinates.
    """
    # The helpers take a file path, so the image has to be written to disk.
    # A per-call temporary directory is used instead of a fixed path so the
    # call does not depend on a pre-existing "imgs/" directory and separate
    # server processes cannot overwrite each other's input.
    work_dir = tempfile.mkdtemp(prefix="process_image_")
    try:
        image_save_path = os.path.join(work_dir, "saved_image_demo.png")
        image_input.save(image_save_path)

        # Image processing and OCR
        ocr_bbox_rslt, _ = check_ocr_box(
            image_save_path,
            display_img=False,
            output_bb_format="xyxy",
            use_paddleocr=True,
        )
        text, ocr_bbox = ocr_bbox_rslt

        # Labeling the image with YOLO and captioning
        dino_labeled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
            image_save_path,
            yolo_model,
            BOX_TRESHOLD=box_threshold,
            output_coord_in_ratio=True,
            ocr_bbox=ocr_bbox,
            caption_model_processor=caption_model_processor,
            ocr_text=text,
            iou_threshold=iou_threshold,
        )
    finally:
        # Best-effort cleanup; a leftover temp file must not fail the request.
        shutil.rmtree(work_dir, ignore_errors=True)

    # The helper returns the labeled image base64-encoded; decode it and
    # re-encode as PNG so the response format is fixed.
    labeled_image = Image.open(io.BytesIO(base64.b64decode(dino_labeled_img)))
    buffered = io.BytesIO()
    labeled_image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

    # str() is a no-op for string entries and keeps the join from raising
    # TypeError should the helper return non-string entries.
    parsed_content_str = "\n".join(str(item) for item in parsed_content_list)

    return ProcessResponse(
        image=img_str,
        parsed_content_list=parsed_content_str,
        label_coordinates=str(label_coordinates),
    )


@app.post("/process_image", response_model=ProcessResponse)
async def process_image(
    image_file: UploadFile = File(...),
    box_threshold: float = 0.05,
    iou_threshold: float = 0.1,
):
    """Decode the uploaded image and return its annotated version.

    Raises:
        HTTPException: 400 when the upload cannot be read or decoded as an
            image.
    """
    try:
        contents = await image_file.read()
        image_input = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as e:
        # Chain the cause so server-side logs keep the decoding error.
        raise HTTPException(status_code=400, detail="Invalid image file") from e

    # NOTE(review): process() is synchronous and is called directly from this
    # coroutine, so the event loop is blocked while it runs. Confirm whether
    # that serialization of requests is intended before moving it to a thread.
    return process(image_input, box_threshold, iou_threshold)