File size: 7,897 Bytes
0adef02 3232315 ceff40b eeb602e f00c873 a4381af f00c873 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
from utils.onnx_helpers import postprocess_onnx_output
# infer_onnx_model is called by ONNXModelWrapper.__call__
from utils.onnx_helpers import infer_onnx_model
# preprocess_onnx_input is called by ONNXModelWrapper.preprocess
from utils.onnx_helpers import preprocess_onnx_input
"""
Model loading and registration logic for OpenSight Deepfake Detection Playground.
Handles ONNX, HuggingFace, and Gradio API model registration and metadata.
"""
from utils.registry import register_model, MODEL_REGISTRY, ModelEntry
from utils.onnx_model_loader import load_onnx_model_and_preprocessor, get_onnx_model_from_cache
from utils.utils import preprocess_resize_256, postprocess_logits, infer_gradio_api, preprocess_gradio_api, postprocess_gradio_api
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import torch
import numpy as np
from PIL import Image
# Hugging Face repo id for every detector, keyed by the internal model id
# (values copied from app_mcp.py).
MODEL_PATHS = dict(
    model_1="LPX55/detection-model-1-ONNX",
    model_2="LPX55/detection-model-2-ONNX",
    model_3="LPX55/detection-model-3-ONNX",
    model_4="cmckinle/sdxl-flux-detector_v1.1",
    model_5="LPX55/detection-model-5-ONNX",
    model_6="LPX55/detection-model-6-ONNX",
    model_7="LPX55/detection-model-7-ONNX",
    model_8="aiwithoutborders-xyz/CommunityForensics-DeepfakeDet-ViT",
)

# Label names for every detector, keyed by the same ids as MODEL_PATHS.
# NOTE(review): presumably ordered to match each model's output indices —
# confirm against the postprocessing helpers.
CLASS_NAMES = dict(
    model_1=["artificial", "real"],
    model_2=["AI Image", "Real Image"],
    model_3=["artificial", "human"],
    model_4=["AI", "Real"],
    model_5=["Realism", "Deepfake"],
    model_6=["ai_gen", "human"],
    model_7=["Fake", "Real"],
    model_8=["Fake", "Real"],
)

# Shared cache handed to get_onnx_model_from_cache so loaded ONNX sessions
# and their preprocessor configs can be reused.
_onnx_model_cache = dict()
def register_model_with_metadata(model_id, model, preprocess, postprocess, class_names, display_name, contributor, model_path, architecture=None, dataset=None):
    """Create a ``ModelEntry`` and store it in ``MODEL_REGISTRY`` under *model_id*.

    Args:
        model_id: Key under which the entry is stored in ``MODEL_REGISTRY``.
        model: Inference callable/object, passed positionally to ``ModelEntry``.
        preprocess: Preprocessing callable, passed positionally to ``ModelEntry``.
        postprocess: Postprocessing callable, passed positionally to ``ModelEntry``.
        class_names: Label names, passed positionally to ``ModelEntry``.
        display_name: Display label forwarded as a keyword argument.
        contributor: Contributor name forwarded as a keyword argument.
        model_path: Model location forwarded as a keyword argument.
        architecture: Optional architecture label (defaults to ``None``).
        dataset: Optional dataset label (defaults to ``None``).
    """
    descriptive_fields = dict(
        display_name=display_name,
        contributor=contributor,
        model_path=model_path,
        architecture=architecture,
        dataset=dataset,
    )
    MODEL_REGISTRY[model_id] = ModelEntry(
        model, preprocess, postprocess, class_names, **descriptive_fields
    )
class ONNXModelWrapper:
    """Lazy handle on an ONNX model identified by its Hugging Face repo id.

    Nothing is loaded at construction time. The session, preprocessor config
    and model config are fetched through ``get_onnx_model_from_cache`` the
    first time ``load`` runs, and every public entry point calls ``load``
    before doing its work.
    """

    def __init__(self, hf_model_id):
        self.hf_model_id = hf_model_id
        # All three are filled in together by load().
        self._session = self._preprocessor_config = self._model_config = None

    def load(self):
        """Fetch session and configs once; later calls are no-ops."""
        if self._session is not None:
            return
        loaded = get_onnx_model_from_cache(
            self.hf_model_id, _onnx_model_cache, load_onnx_model_and_preprocessor
        )
        self._session, self._preprocessor_config, self._model_config = loaded

    def __call__(self, image_np):
        """Run ``infer_onnx_model`` for this model id on *image_np*."""
        self.load()
        model_config = self._model_config
        return infer_onnx_model(self.hf_model_id, image_np, model_config)

    def preprocess(self, image: Image.Image):
        """Apply ``preprocess_onnx_input`` to *image* using the loaded preprocessor config."""
        self.load()
        preprocessor_config = self._preprocessor_config
        return preprocess_onnx_input(image, preprocessor_config)

    def postprocess(self, onnx_output: dict, class_names_from_registry: list):
        """Apply ``postprocess_onnx_output`` to *onnx_output* using the loaded model config.

        ``class_names_from_registry`` is accepted but not used by this method.
        """
        self.load()
        model_config = self._model_config
        return postprocess_onnx_output(onnx_output, model_config)
# Contributor / architecture / dataset for the ONNX-hosted models, keyed by
# model id. A field missing from a record (or a model missing from the table)
# keeps the "Unknown" / "TBA" defaults applied in register_all_models.
_ONNX_MODEL_METADATA = {
    "model_1": {"contributor": "haywoodsloan", "architecture": "SwinV2", "dataset": "DeepFakeDetection"},
    "model_2": {"contributor": "Heem2", "architecture": "ViT", "dataset": "DeepFakeDetection"},
    "model_3": {"contributor": "Organika", "architecture": "VIT", "dataset": "SDXL"},
    "model_5": {"contributor": "prithivMLmods", "architecture": "VIT"},
    "model_6": {"contributor": "ideepankarsharma2003", "architecture": "SWINv1", "dataset": "SDXL, Midjourney"},
    "model_7": {"contributor": "date3k2", "architecture": "VIT"},
}


def _build_display_name(model_num, architecture, dataset, suffix=""):
    """Join the model number with any known architecture/dataset, then append *suffix*."""
    parts = [model_num]
    # The placeholder values "Unknown" and "TBA" are left out of the name.
    if architecture and architecture != "Unknown":
        parts.append(architecture)
    if dataset and dataset != "TBA":
        parts.append(dataset)
    return "-".join(parts) + suffix


# The main registration function
def register_all_models(MODEL_PATHS, CLASS_NAMES, device, infer_onnx_model, preprocess_onnx_input, postprocess_onnx_output):
    """Register every supported model from *MODEL_PATHS* with its metadata.

    Each entry is routed to one of three backends:

    * paths containing ``"ONNX"`` are wrapped in ``ONNXModelWrapper``;
    * ``"model_8"`` uses the Gradio API helpers;
    * ``"model_4"`` is loaded with ``transformers`` and moved to *device*.

    Any other entry is skipped silently.

    Args:
        MODEL_PATHS: Mapping of model id to Hugging Face repo id.
        CLASS_NAMES: Mapping of model id to label names; a missing id yields
            an empty list.
        device: Device used for the ``transformers`` model and its inputs.
        infer_onnx_model: Unused here; kept so existing callers keep working.
        preprocess_onnx_input: Unused here; kept so existing callers keep working.
        postprocess_onnx_output: Unused here; kept so existing callers keep working.
    """
    for model_key, hf_model_path in MODEL_PATHS.items():
        model_num = model_key.replace("model_", "").upper()
        current_class_names = CLASS_NAMES.get(model_key, [])

        if "ONNX" in hf_model_path:
            meta = _ONNX_MODEL_METADATA.get(model_key, {})
            contributor = meta.get("contributor", "Unknown")
            architecture = meta.get("architecture", "Unknown")
            dataset = meta.get("dataset", "TBA")
            display_name = _build_display_name(model_num, architecture, dataset, suffix="_ONNX")

            onnx_wrapper_instance = ONNXModelWrapper(hf_model_path)
            model = onnx_wrapper_instance
            preprocess = onnx_wrapper_instance.preprocess
            postprocess = onnx_wrapper_instance.postprocess

        elif model_key == "model_8":
            contributor = "aiwithoutborders-xyz"
            architecture = "ViT"
            dataset = "DeepfakeDetection"
            display_name = _build_display_name(model_num, architecture, dataset)

            model = infer_gradio_api
            preprocess = preprocess_gradio_api
            postprocess = postprocess_gradio_api

        elif model_key == "model_4":
            contributor = "cmckinle"
            architecture = "VIT"
            dataset = "SDXL, FLUX"
            display_name = _build_display_name(model_num, architecture, dataset)

            # NOTE(review): `device=` is forwarded to from_pretrained exactly as
            # before; confirm the feature extractor actually honours it.
            current_processor = AutoFeatureExtractor.from_pretrained(hf_model_path, device=device)
            hf_model = AutoModelForImageClassification.from_pretrained(hf_model_path).to(device)

            # Processor and model are bound as defaults so the closure keeps
            # the objects created for this entry.
            def custom_infer(image, processor_local=current_processor, model_local=hf_model):
                inputs = processor_local(image, return_tensors="pt").to(device)
                with torch.no_grad():
                    outputs = model_local(**inputs)
                return outputs

            model = custom_infer
            preprocess = preprocess_resize_256
            postprocess = postprocess_logits

        else:
            # Fallback for any unhandled models: nothing is registered.
            continue

        register_model_with_metadata(
            model_id=model_key,
            model=model,
            preprocess=preprocess,
            postprocess=postprocess,
            class_names=current_class_names,
            display_name=display_name,
            contributor=contributor,
            model_path=hf_model_path,
            architecture=architecture,
            dataset=dataset,
        )
|