# OmniPar / main.py
# banao-tech's picture
# Update main.py
# d9070a8 verified
# raw
# history blame
# 2.88 kB
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import base64
import io
from PIL import Image
import torch
from ultralytics import YOLO
from transformers import AutoProcessor, AutoModelForCausalLM
import os
# Import utility functions
from utils import check_ocr_box, get_som_labeled_img
# Initialize models and processor.
# Everything below runs once, at import time. Both models are moved to the
# "cuda" device, so importing this module fails on a machine without CUDA.
try:
    # Detection weights are read from a path relative to the working directory.
    yolo_model = YOLO("weights/icon_detect/best.pt").to("cuda")
except Exception as e:
    # Chain explicitly so the traceback marks the loader failure as the direct
    # cause of the RuntimeError rather than an incidental second error.
    raise RuntimeError(f"Error loading YOLO model: {e}") from e

# The processor is resolved from the "microsoft/Florence-2-base" identifier,
# while the captioning model weights come from a local directory.
# NOTE(review): trust_remote_code=True lets the model repository supply Python
# code that is executed during loading -- confirm both sources are trusted.
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)

try:
    model = AutoModelForCausalLM.from_pretrained(
        "weights/icon_caption_florence", torch_dtype=torch.float16, trust_remote_code=True
    ).to("cuda")
except Exception as e:
    raise RuntimeError(f"Error loading captioning model: {e}") from e

# Bundle handed unchanged to get_som_labeled_img() by process().
caption_model_processor = {"processor": processor, "model": model}

# FastAPI app initialization
app = FastAPI()
class ProcessResponse(BaseModel):
    """Response body returned by the /process_image endpoint.

    All three fields are plain strings; process() flattens its results
    into text before building this model.
    """

    # Labeled image, re-saved as PNG and base64-encoded (UTF-8 text).
    image: str  # Base64 encoded image
    # Despite the name this is one string: the parsed entries joined by "\n".
    parsed_content_list: str
    # str() of the coordinates object returned by get_som_labeled_img().
    label_coordinates: str
def process(image_input: Image.Image, box_threshold: float, iou_threshold: float) -> ProcessResponse:
    """Run OCR, detection and captioning on one image and package the result.

    The image is written to disk because ``check_ocr_box`` and
    ``get_som_labeled_img`` are called with a file path. A private temporary
    directory is used for that file so that:

    * the call does not depend on an ``imgs/`` directory existing,
    * calls running in other worker processes cannot overwrite each other's
      input, and
    * the uploaded image is removed once processing finishes.

    Args:
        image_input: Image to analyse.
        box_threshold: Passed to ``get_som_labeled_img`` as ``BOX_TRESHOLD``.
        iou_threshold: Passed to ``get_som_labeled_img`` as ``iou_threshold``.

    Returns:
        A ``ProcessResponse`` holding the labeled image as base64-encoded PNG,
        the parsed content entries joined with newlines, and the string form
        of the label coordinates.
    """
    # Imported here so the block is self-contained; nothing else in the
    # module needs tempfile.
    import tempfile

    with tempfile.TemporaryDirectory() as work_dir:
        image_save_path = os.path.join(work_dir, "saved_image_demo.png")
        image_input.save(image_save_path)

        # Image processing and OCR
        ocr_bbox_rslt, _ = check_ocr_box(
            image_save_path, display_img=False, output_bb_format="xyxy", use_paddleocr=True
        )
        text, ocr_bbox = ocr_bbox_rslt

        # Labeling the image with YOLO and captioning
        dino_labeled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
            image_save_path,
            yolo_model,
            BOX_TRESHOLD=box_threshold,
            output_coord_in_ratio=True,
            ocr_bbox=ocr_bbox,
            caption_model_processor=caption_model_processor,
            ocr_text=text,
            iou_threshold=iou_threshold,
        )

    # The helper's output is decoded as base64 image data; re-saving it
    # guarantees the payload sent to the client is PNG.
    labeled_image = Image.open(io.BytesIO(base64.b64decode(dino_labeled_img)))
    png_buffer = io.BytesIO()
    labeled_image.save(png_buffer, format="PNG")
    img_str = base64.b64encode(png_buffer.getvalue()).decode("utf-8")

    # str() is a no-op for string entries and keeps the join from raising
    # TypeError if an entry is some other type.
    parsed_content_str = "\n".join(str(entry) for entry in parsed_content_list)

    return ProcessResponse(
        image=img_str,
        parsed_content_list=parsed_content_str,
        label_coordinates=str(label_coordinates),
    )
@app.post("/process_image", response_model=ProcessResponse)
async def process_image(
    image_file: UploadFile = File(...),
    box_threshold: float = 0.05,
    iou_threshold: float = 0.1,
):
    """Accept an uploaded image and return its labeled/parsed representation.

    Args:
        image_file: Uploaded file; its bytes must be decodable as an image.
        box_threshold: Forwarded to ``process`` (default 0.05).
        iou_threshold: Forwarded to ``process`` (default 0.1).

    Returns:
        The ``ProcessResponse`` produced by ``process``.

    Raises:
        HTTPException: 400 when the upload cannot be read or decoded as an
            image.
    """
    try:
        contents = await image_file.read()
        # Normalise to RGB so later stages always receive three channels.
        image_input = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as e:
        # Keep the decoding error attached as the cause for server-side logs;
        # the client only sees the generic 400 message.
        raise HTTPException(status_code=400, detail="Invalid image file") from e

    # NOTE(review): process() is a plain synchronous function, so the event
    # loop is blocked for the duration of this call.
    return process(image_input, box_threshold, iou_threshold)