LPX55
fix: add missing import for infer_onnx_model in model_loader and update its definition in onnx_helpers
ceff40b
# Add missing import for infer_onnx_model | |
from utils.onnx_helpers import infer_onnx_model | |
# Add missing import for preprocess_onnx_input | |
from utils.onnx_helpers import preprocess_onnx_input | |
""" | |
Model loading and registration logic for OpenSight Deepfake Detection Playground. | |
Handles ONNX, HuggingFace, and Gradio API model registration and metadata. | |
""" | |
from utils.registry import register_model, MODEL_REGISTRY, ModelEntry | |
from utils.onnx_model_loader import load_onnx_model_and_preprocessor, get_onnx_model_from_cache | |
from utils.utils import preprocess_resize_256, postprocess_logits, infer_gradio_api, preprocess_gradio_api, postprocess_gradio_api | |
from transformers import AutoFeatureExtractor, AutoModelForImageClassification | |
import torch | |
import numpy as np | |
from PIL import Image | |
# Module-level cache for ONNX sessions and preprocessors; ONNXModelWrapper.load
# passes it to get_onnx_model_from_cache.
_onnx_model_cache = {} | |
def register_model_with_metadata(model_id, model, preprocess, postprocess, class_names, display_name, contributor, model_path, architecture=None, dataset=None):
    """Build a ``ModelEntry`` and store it in ``MODEL_REGISTRY`` under ``model_id``.

    ``model``, ``preprocess``, ``postprocess`` and ``class_names`` are handed
    to ``ModelEntry`` positionally; the descriptive fields (display name,
    contributor, model path, architecture, dataset) are passed by keyword.
    """
    descriptive_fields = {
        "display_name": display_name,
        "contributor": contributor,
        "model_path": model_path,
        "architecture": architecture,
        "dataset": dataset,
    }
    MODEL_REGISTRY[model_id] = ModelEntry(
        model, preprocess, postprocess, class_names, **descriptive_fields
    )
class ONNXModelWrapper:
    """Lazily loaded ONNX model exposing inference, preprocess and postprocess.

    Nothing is loaded when the wrapper is constructed. The first call to
    ``__call__``, ``preprocess`` or ``postprocess`` triggers ``load``, which
    obtains the session and its configs through ``get_onnx_model_from_cache``.
    """

    def __init__(self, hf_model_id):
        """Remember the model identifier; loading is deferred to ``load``."""
        self.hf_model_id = hf_model_id
        # All three are filled in together by load(); a None session means
        # "not loaded yet".
        self._session = None
        self._preprocessor_config = None
        self._model_config = None

    def load(self):
        """Populate session, preprocessor config and model config if missing."""
        if self._session is not None:
            return
        (
            self._session,
            self._preprocessor_config,
            self._model_config,
        ) = get_onnx_model_from_cache(
            self.hf_model_id, _onnx_model_cache, load_onnx_model_and_preprocessor
        )

    def __call__(self, image_np):
        """Run inference on an already preprocessed input via ``infer_onnx_model``."""
        self.load()
        return infer_onnx_model(self.hf_model_id, image_np, self._model_config)

    def preprocess(self, image: "Image.Image"):
        """Convert ``image`` to model input using the loaded preprocessor config."""
        self.load()
        return preprocess_onnx_input(image, self._preprocessor_config)

    def postprocess(self, onnx_output: dict, class_names_from_registry: list):
        """Post-process raw ONNX output using the loaded model config.

        ``class_names_from_registry`` is accepted for interface compatibility
        with the other registered postprocess callables; it is not used here.
        """
        self.load()
        # This module never imports postprocess_onnx_output at top level, so a
        # bare reference raised NameError. NOTE(review): assumed to live next
        # to infer_onnx_model / preprocess_onnx_input in utils.onnx_helpers --
        # confirm.
        from utils.onnx_helpers import postprocess_onnx_output

        return postprocess_onnx_output(onnx_output, self._model_config)
# Contributor, architecture and dataset for each known ONNX model key.
# "Unknown" / "TBA" are placeholders that are left out of the display name.
_ONNX_MODEL_METADATA = {
    "model_1": ("haywoodsloan", "SwinV2", "DeepFakeDetection"),
    "model_2": ("Heem2", "ViT", "DeepFakeDetection"),
    "model_3": ("Organika", "VIT", "SDXL"),
    "model_5": ("prithivMLmods", "VIT", "TBA"),
    "model_6": ("ideepankarsharma2003", "SWINv1", "SDXL, Midjourney"),
    "model_7": ("date3k2", "VIT", "TBA"),
}
_DEFAULT_ONNX_METADATA = ("Unknown", "Unknown", "TBA")


def _build_display_name(model_num, architecture, dataset, suffix=""):
    """Join the model number with any known architecture and dataset."""
    parts = [model_num]
    if architecture and architecture != "Unknown":
        parts.append(architecture)
    if dataset and dataset != "TBA":
        parts.append(dataset)
    return "-".join(parts) + suffix


def _make_hf_infer(processor, model, device):
    """Return an inference callable bound to a HuggingFace processor and model."""

    # The processor and model are bound as default arguments, matching the
    # original callable's signature.
    def custom_infer(image, processor_local=processor, model_local=model):
        inputs = processor_local(image, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model_local(**inputs)
        return outputs

    return custom_infer


# The main registration function
def register_all_models(MODEL_PATHS, CLASS_NAMES, device, infer_onnx_model, preprocess_onnx_input, postprocess_onnx_output):
    """Register every model in ``MODEL_PATHS`` together with its metadata.

    Three kinds of model are handled, checked in this order:

    * any path containing ``"ONNX"`` is wrapped in an ``ONNXModelWrapper``;
    * ``"model_8"`` is registered with the Gradio API callables;
    * ``"model_4"`` is loaded with ``transformers`` and moved to ``device``.

    Any other entry is skipped with a logged warning.

    Args:
        MODEL_PATHS: Mapping of model key (e.g. ``"model_1"``) to model path.
        CLASS_NAMES: Mapping of model key to class names; missing keys
            default to an empty list.
        device: Device the HuggingFace model and its inputs are moved to.
        infer_onnx_model, preprocess_onnx_input, postprocess_onnx_output:
            Accepted for interface compatibility; this function does not
            reference them.
    """
    import logging

    for model_key, hf_model_path in MODEL_PATHS.items():
        model_num = model_key.replace("model_", "").upper()
        current_class_names = CLASS_NAMES.get(model_key, [])

        if "ONNX" in hf_model_path:
            contributor, architecture, dataset = _ONNX_MODEL_METADATA.get(
                model_key, _DEFAULT_ONNX_METADATA
            )
            display_name = _build_display_name(
                model_num, architecture, dataset, suffix="_ONNX"
            )
            onnx_wrapper_instance = ONNXModelWrapper(hf_model_path)
            model = onnx_wrapper_instance
            preprocess = onnx_wrapper_instance.preprocess
            postprocess = onnx_wrapper_instance.postprocess
        elif model_key == "model_8":
            contributor = "aiwithoutborders-xyz"
            architecture = "ViT"
            dataset = "DeepfakeDetection"
            display_name = _build_display_name(model_num, architecture, dataset)
            model = infer_gradio_api
            preprocess = preprocess_gradio_api
            postprocess = postprocess_gradio_api
        elif model_key == "model_4":
            contributor = "cmckinle"
            architecture = "VIT"
            dataset = "SDXL, FLUX"
            display_name = _build_display_name(model_num, architecture, dataset)
            current_processor = AutoFeatureExtractor.from_pretrained(hf_model_path, device=device)
            hf_model = AutoModelForImageClassification.from_pretrained(hf_model_path).to(device)
            model = _make_hf_infer(current_processor, hf_model, device)
            preprocess = preprocess_resize_256
            postprocess = postprocess_logits
        else:
            # Unhandled models are skipped on purpose; log it so a
            # misconfigured entry does not go unnoticed.
            logging.getLogger(__name__).warning(
                "Skipping %s: no registration rule for path %r",
                model_key,
                hf_model_path,
            )
            continue

        register_model_with_metadata(
            model_id=model_key,
            model=model,
            preprocess=preprocess,
            postprocess=postprocess,
            class_names=current_class_names,
            display_name=display_name,
            contributor=contributor,
            model_path=hf_model_path,
            architecture=architecture,
            dataset=dataset,
        )