import torch
from typing import Tuple, Dict, Any
from transformers import AutoModelForCausalLM, AutoProcessor
from unittest.mock import patch
from PIL import Image

from utils.imports import fixed_get_imports

CHECKPOINTS = [
    "microsoft/Florence-2-large-ft",
    "microsoft/Florence-2-large",
    "microsoft/Florence-2-base-ft",
    "microsoft/Florence-2-base",
]


def load_models(device: torch.device) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Load every Florence-2 checkpoint and its processor onto the given device."""
    # Patch get_imports so the remote modeling code loads even when optional
    # dependencies it declares (e.g. flash_attn) are not installed.
    with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
        models = {}
        processors = {}
        for checkpoint in CHECKPOINTS:
            models[checkpoint] = AutoModelForCausalLM.from_pretrained(
                checkpoint, trust_remote_code=True).to(device).eval()
            processors[checkpoint] = AutoProcessor.from_pretrained(
                checkpoint, trust_remote_code=True)
        return models, processors


def run_inference(
    model: Any,
    processor: Any,
    device: torch.device,
    image: Image.Image,
    task: str,
    text: str = ""
) -> Tuple[str, Dict]:
    """Run a single Florence-2 task prompt on an image.

    Returns both the raw generated text and the task-specific parsed response.
    """
    prompt = task + text
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3
    )
    # Keep special tokens: post_process_generation relies on them to parse
    # region coordinates and labels out of the generated sequence.
    generated_text = processor.batch_decode(
        generated_ids, skip_special_tokens=False)[0]
    response = processor.post_process_generation(
        generated_text, task=task, image_size=image.size)
    return generated_text, response
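

# Minimal usage sketch, kept out of the import path behind a __main__ guard.
# The image path "dog.jpeg" and the choice of checkpoint are illustrative
# assumptions, not fixed by this module. "<OD>" is Florence-2's object
# detection task token; other tokens such as "<CAPTION>" or
# "<CAPTION_TO_PHRASE_GROUNDING>" (which also consumes the `text` argument)
# can be passed as `task` in the same way.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    models, processors = load_models(device)

    checkpoint = "microsoft/Florence-2-base-ft"
    image = Image.open("dog.jpeg")  # hypothetical sample image
    _, response = run_inference(
        models[checkpoint], processors[checkpoint], device, image, task="<OD>")
    # For "<OD>" the parsed response maps the task token to bboxes and labels,
    # e.g. {"<OD>": {"bboxes": [...], "labels": [...]}}.
    print(response)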