LPX55
feat: implement model registration logic for ONNX, HuggingFace, and Gradio API
f00c873
raw
history blame
6.86 kB
"""
Model loading and registration logic for OpenSight Deepfake Detection Playground.
Handles ONNX, HuggingFace, and Gradio API model registration and metadata.
"""
from utils.registry import register_model, MODEL_REGISTRY, ModelEntry
from utils.onnx_model_loader import load_onnx_model_and_preprocessor, get_onnx_model_from_cache
from utils.utils import preprocess_resize_256, postprocess_logits, infer_gradio_api, preprocess_gradio_api, postprocess_gradio_api
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import torch
import numpy as np
from PIL import Image
# Cache for ONNX sessions and preprocessors.
# This single module-level dict is handed to get_onnx_model_from_cache (together
# with the model id and the loader function) by every ONNXModelWrapper.load().
# NOTE(review): presumably keyed by HF model id — confirm in utils.onnx_model_loader.
_onnx_model_cache = {}
def register_model_with_metadata(model_id, model, preprocess, postprocess, class_names, display_name, contributor, model_path, architecture=None, dataset=None):
    """Create a ModelEntry and store it in MODEL_REGISTRY under ``model_id``.

    The four positional pieces (model, preprocess, postprocess, class_names)
    and the descriptive keyword fields are forwarded to ModelEntry unchanged.
    Any entry already stored under ``model_id`` is replaced.
    """
    # Descriptive fields are passed to ModelEntry by keyword.
    metadata = {
        "display_name": display_name,
        "contributor": contributor,
        "model_path": model_path,
        "architecture": architecture,
        "dataset": dataset,
    }
    MODEL_REGISTRY[model_id] = ModelEntry(
        model, preprocess, postprocess, class_names, **metadata
    )
class ONNXModelWrapper:
    """Lazy, callable wrapper around one ONNX model identified by its HF model id.

    The ONNX session and its two configs are fetched on first use through
    ``get_onnx_model_from_cache`` and the module-level ``_onnx_model_cache``.

    The actual inference / preprocessing / postprocessing logic lives outside
    this module, so it is injected as three callables:

    * ``infer_fn(hf_model_id, image_np, model_config)``
    * ``preprocess_fn(image, preprocessor_config)``
    * ``postprocess_fn(onnx_output, model_config)``

    For backward compatibility the callables are optional. When one is not
    supplied, the wrapper falls back to a module-level global with the legacy
    name (``infer_onnx_model``, ``preprocess_onnx_input``,
    ``postprocess_onnx_output``) and raises ``NameError`` with an explanatory
    message if that global does not exist either.
    """

    def __init__(self, hf_model_id, infer_fn=None, preprocess_fn=None, postprocess_fn=None):
        self.hf_model_id = hf_model_id
        self._infer_fn = infer_fn
        self._preprocess_fn = preprocess_fn
        self._postprocess_fn = postprocess_fn
        # Populated lazily by load().
        self._session = None
        self._preprocessor_config = None
        self._model_config = None

    @staticmethod
    def _resolve_hook(injected, legacy_global_name):
        """Return the injected callable, else the same-named module global."""
        if injected is not None:
            return injected
        try:
            return globals()[legacy_global_name]
        except KeyError:
            raise NameError(
                f"ONNXModelWrapper has no '{legacy_global_name}' callable: pass it to "
                "the constructor or define it at module level"
            ) from None

    def load(self):
        """Fetch the session and configs once; later calls are no-ops."""
        if self._session is None:
            self._session, self._preprocessor_config, self._model_config = get_onnx_model_from_cache(
                self.hf_model_id, _onnx_model_cache, load_onnx_model_and_preprocessor
            )

    def __call__(self, image_np):
        """Run inference on ``image_np`` and return whatever the infer hook returns."""
        self.load()
        infer = self._resolve_hook(self._infer_fn, "infer_onnx_model")
        return infer(self.hf_model_id, image_np, self._model_config)

    def preprocess(self, image: "Image.Image"):
        """Preprocess ``image`` with the model's preprocessor config."""
        self.load()
        preprocess_fn = self._resolve_hook(self._preprocess_fn, "preprocess_onnx_input")
        return preprocess_fn(image, self._preprocessor_config)

    def postprocess(self, onnx_output: dict, class_names_from_registry: list):
        """Postprocess ``onnx_output`` with the model config.

        ``class_names_from_registry`` is accepted for signature compatibility
        with the other registered postprocessors but is not used here.
        """
        self.load()
        postprocess_fn = self._resolve_hook(self._postprocess_fn, "postprocess_onnx_output")
        return postprocess_fn(onnx_output, self._model_config)


# Per-model metadata for ONNX models, keyed by registry key:
# (contributor, architecture, dataset). Keys that are missing here keep the
# placeholder values below.
_ONNX_MODEL_METADATA = {
    "model_1": ("haywoodsloan", "SwinV2", "DeepFakeDetection"),
    "model_2": ("Heem2", "ViT", "DeepFakeDetection"),
    "model_3": ("Organika", "VIT", "SDXL"),
    "model_5": ("prithivMLmods", "VIT", "TBA"),
    "model_6": ("ideepankarsharma2003", "SWINv1", "SDXL, Midjourney"),
    "model_7": ("date3k2", "VIT", "TBA"),
}
_PLACEHOLDER_METADATA = ("Unknown", "Unknown", "TBA")


def _build_display_name(model_num, architecture, dataset, suffix=""):
    """Join model number, architecture and dataset with '-', skipping placeholders."""
    parts = [model_num]
    if architecture and architecture != "Unknown":
        parts.append(architecture)
    if dataset and dataset != "TBA":
        parts.append(dataset)
    return "-".join(parts) + suffix


# The main registration function
def register_all_models(MODEL_PATHS, CLASS_NAMES, device, infer_onnx_model, preprocess_onnx_input, postprocess_onnx_output):
    """Register every supported model from ``MODEL_PATHS`` in the model registry.

    Args:
        MODEL_PATHS: mapping of registry key (e.g. ``"model_1"``) to model path.
        CLASS_NAMES: mapping of registry key to class-name list; a missing key
            yields an empty list.
        device: device passed to the HuggingFace processor/model for "model_4".
        infer_onnx_model: callable ``(hf_model_id, image_np, model_config)``.
        preprocess_onnx_input: callable ``(image, preprocessor_config)``.
        postprocess_onnx_output: callable ``(onnx_output, model_config)``.

    Dispatch order per entry: a path containing ``"ONNX"`` is registered as a
    lazy ONNX wrapper; otherwise ``"model_8"`` is registered against the Gradio
    API helpers and ``"model_4"`` as a HuggingFace model. Any other entry is
    skipped without registration.
    """
    for model_key, hf_model_path in MODEL_PATHS.items():
        model_num = model_key.replace("model_", "").upper()
        current_class_names = CLASS_NAMES.get(model_key, [])

        if "ONNX" in hf_model_path:
            contributor, architecture, dataset = _ONNX_MODEL_METADATA.get(
                model_key, _PLACEHOLDER_METADATA
            )
            # Hand the ONNX callables to the wrapper explicitly: they are only
            # available here as arguments, not as globals of this module.
            onnx_wrapper_instance = ONNXModelWrapper(
                hf_model_path,
                infer_fn=infer_onnx_model,
                preprocess_fn=preprocess_onnx_input,
                postprocess_fn=postprocess_onnx_output,
            )
            register_model_with_metadata(
                model_id=model_key,
                model=onnx_wrapper_instance,
                preprocess=onnx_wrapper_instance.preprocess,
                postprocess=onnx_wrapper_instance.postprocess,
                class_names=current_class_names,
                display_name=_build_display_name(model_num, architecture, dataset, "_ONNX"),
                contributor=contributor,
                model_path=hf_model_path,
                architecture=architecture,
                dataset=dataset
            )
        elif model_key == "model_8":
            contributor = "aiwithoutborders-xyz"
            architecture = "ViT"
            dataset = "DeepfakeDetection"
            register_model_with_metadata(
                model_id=model_key,
                model=infer_gradio_api,
                preprocess=preprocess_gradio_api,
                postprocess=postprocess_gradio_api,
                class_names=current_class_names,
                display_name=_build_display_name(model_num, architecture, dataset),
                contributor=contributor,
                model_path=hf_model_path,
                architecture=architecture,
                dataset=dataset
            )
        elif model_key == "model_4":
            contributor = "cmckinle"
            architecture = "VIT"
            dataset = "SDXL, FLUX"
            current_processor = AutoFeatureExtractor.from_pretrained(hf_model_path, device=device)
            model_instance = AutoModelForImageClassification.from_pretrained(hf_model_path).to(device)

            # Default arguments bind this iteration's processor and model to
            # the closure, so later loop iterations cannot rebind them.
            def custom_infer(image, processor_local=current_processor, model_local=model_instance):
                inputs = processor_local(image, return_tensors="pt").to(device)
                with torch.no_grad():
                    outputs = model_local(**inputs)
                return outputs

            register_model_with_metadata(
                model_id=model_key,
                model=custom_infer,
                preprocess=preprocess_resize_256,
                postprocess=postprocess_logits,
                class_names=current_class_names,
                display_name=_build_display_name(model_num, architecture, dataset),
                contributor=contributor,
                model_path=hf_model_path,
                architecture=architecture,
                dataset=dataset
            )
        # Any other entry is not registered.