File size: 7,897 Bytes
0adef02
3232315
ceff40b
eeb602e
 
f00c873
 
 
 
 
 
 
 
 
 
 
 
a4381af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f00c873
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
from utils.onnx_helpers import postprocess_onnx_output
# Add missing import for infer_onnx_model
from utils.onnx_helpers import infer_onnx_model
# Add missing import for preprocess_onnx_input
from utils.onnx_helpers import preprocess_onnx_input
"""
Model loading and registration logic for OpenSight Deepfake Detection Playground.
Handles ONNX, HuggingFace, and Gradio API model registration and metadata.
"""
from utils.registry import register_model, MODEL_REGISTRY, ModelEntry
from utils.onnx_model_loader import load_onnx_model_and_preprocessor, get_onnx_model_from_cache
from utils.utils import preprocess_resize_256, postprocess_logits, infer_gradio_api, preprocess_gradio_api, postprocess_gradio_api
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import torch
import numpy as np
from PIL import Image


# Model paths and class names (copied from app_mcp.py)
# Registry id -> Hugging Face Hub repo id of the detector checkpoint.
# Repos whose id contains "ONNX" are served via the ONNX runtime path;
# the rest are handled by dedicated branches in register_all_models.
MODEL_PATHS = {
    "model_1": "LPX55/detection-model-1-ONNX",
    "model_2": "LPX55/detection-model-2-ONNX",
    "model_3": "LPX55/detection-model-3-ONNX",
    "model_4": "cmckinle/sdxl-flux-detector_v1.1",
    "model_5": "LPX55/detection-model-5-ONNX",
    "model_6": "LPX55/detection-model-6-ONNX",
    "model_7": "LPX55/detection-model-7-ONNX",
    "model_8": "aiwithoutborders-xyz/CommunityForensics-DeepfakeDet-ViT"
}

# Registry id -> ordered class labels for that model's output.
# Label wording varies per upstream model (e.g. 'artificial' vs 'AI Image');
# NOTE(review): label order (fake-first vs real-first) differs between models
# (model_5 is real-first) — presumably it matches each model's output index;
# confirm against each checkpoint's config before relying on position.
CLASS_NAMES = {
    "model_1": ['artificial', 'real'],
    "model_2": ['AI Image', 'Real Image'],
    "model_3": ['artificial', 'human'],
    "model_4": ['AI', 'Real'],
    "model_5": ['Realism', 'Deepfake'],
    "model_6": ['ai_gen', 'human'],
    "model_7": ['Fake', 'Real'],
    "model_8": ['Fake', 'Real'],
}


# Cache for ONNX sessions and preprocessors
# Shared by every ONNXModelWrapper in this module so each Hub repo is
# downloaded/initialized at most once per process.
_onnx_model_cache = {}

def register_model_with_metadata(model_id, model, preprocess, postprocess, class_names, display_name, contributor, model_path, architecture=None, dataset=None):
    """Create a ModelEntry from the given callables/metadata and store it in
    the global MODEL_REGISTRY under ``model_id`` (overwriting any previous
    entry for the same id)."""
    MODEL_REGISTRY[model_id] = ModelEntry(
        model,
        preprocess,
        postprocess,
        class_names,
        display_name=display_name,
        contributor=contributor,
        model_path=model_path,
        architecture=architecture,
        dataset=dataset,
    )

class ONNXModelWrapper:
    """Lazy facade over an ONNX model hosted on the Hugging Face Hub.

    The session, preprocessor config and model config are fetched through the
    module-level ``_onnx_model_cache`` on first use, so constructing a wrapper
    is cheap and no download happens until inference is actually requested.
    """

    def __init__(self, hf_model_id):
        # Hub repo id; also the cache key used by get_onnx_model_from_cache.
        self.hf_model_id = hf_model_id
        self._session = None
        self._preprocessor_config = None
        self._model_config = None

    def load(self):
        """Populate session/configs on first call; no-op afterwards."""
        if self._session is not None:
            return
        loaded = get_onnx_model_from_cache(
            self.hf_model_id, _onnx_model_cache, load_onnx_model_and_preprocessor
        )
        self._session, self._preprocessor_config, self._model_config = loaded

    def __call__(self, image_np):
        # Inference entry point used as the registry's `model` callable.
        self.load()
        return infer_onnx_model(self.hf_model_id, image_np, self._model_config)

    def preprocess(self, image: Image.Image):
        self.load()
        return preprocess_onnx_input(image, self._preprocessor_config)

    def postprocess(self, onnx_output: dict, class_names_from_registry: list):
        # `class_names_from_registry` is accepted for registry-interface
        # compatibility but unused: the ONNX path derives labels from the
        # model config instead.
        self.load()
        return postprocess_onnx_output(onnx_output, self._model_config)

# The main registration function

def register_all_models(MODEL_PATHS, CLASS_NAMES, device, infer_onnx_model, preprocess_onnx_input, postprocess_onnx_output):
    for model_key, hf_model_path in MODEL_PATHS.items():
        model_num = model_key.replace("model_", "").upper()
        contributor = "Unknown"
        architecture = "Unknown"
        dataset = "TBA"
        current_class_names = CLASS_NAMES.get(model_key, [])
        if "ONNX" in hf_model_path:
            onnx_wrapper_instance = ONNXModelWrapper(hf_model_path)
            if model_key == "model_1":
                contributor = "haywoodsloan"
                architecture = "SwinV2"
                dataset = "DeepFakeDetection"
            elif model_key == "model_2":
                contributor = "Heem2"
                architecture = "ViT"
                dataset = "DeepFakeDetection"
            elif model_key == "model_3":
                contributor = "Organika"
                architecture = "VIT"
                dataset = "SDXL"
            elif model_key == "model_5":
                contributor = "prithivMLmods"
                architecture = "VIT"
            elif model_key == "model_6":
                contributor = "ideepankarsharma2003"
                architecture = "SWINv1"
                dataset = "SDXL, Midjourney"
            elif model_key == "model_7":
                contributor = "date3k2"
                architecture = "VIT"
            display_name_parts = [model_num]
            if architecture and architecture not in ["Unknown"]:
                display_name_parts.append(architecture)
            if dataset and dataset not in ["TBA"]:
                display_name_parts.append(dataset)
            display_name = "-".join(display_name_parts) + "_ONNX"
            register_model_with_metadata(
                model_id=model_key,
                model=onnx_wrapper_instance,
                preprocess=onnx_wrapper_instance.preprocess,
                postprocess=onnx_wrapper_instance.postprocess,
                class_names=current_class_names,
                display_name=display_name,
                contributor=contributor,
                model_path=hf_model_path,
                architecture=architecture,
                dataset=dataset
            )
        elif model_key == "model_8":
            contributor = "aiwithoutborders-xyz"
            architecture = "ViT"
            dataset = "DeepfakeDetection"
            display_name_parts = [model_num]
            if architecture and architecture not in ["Unknown"]:
                display_name_parts.append(architecture)
            if dataset and dataset not in ["TBA"]:
                display_name_parts.append(dataset)
            display_name = "-".join(display_name_parts)
            register_model_with_metadata(
                model_id=model_key,
                model=infer_gradio_api,
                preprocess=preprocess_gradio_api,
                postprocess=postprocess_gradio_api,
                class_names=current_class_names,
                display_name=display_name,
                contributor=contributor,
                model_path=hf_model_path,
                architecture=architecture,
                dataset=dataset
            )
        elif model_key == "model_4":
            contributor = "cmckinle"
            architecture = "VIT"
            dataset = "SDXL, FLUX"
            display_name_parts = [model_num]
            if architecture and architecture not in ["Unknown"]:
                display_name_parts.append(architecture)
            if dataset and dataset not in ["TBA"]:
                display_name_parts.append(dataset)
            display_name = "-".join(display_name_parts)
            current_processor = AutoFeatureExtractor.from_pretrained(hf_model_path, device=device)
            model_instance = AutoModelForImageClassification.from_pretrained(hf_model_path).to(device)
            preprocess_func = preprocess_resize_256
            postprocess_func = postprocess_logits
            def custom_infer(image, processor_local=current_processor, model_local=model_instance):
                inputs = processor_local(image, return_tensors="pt").to(device)
                with torch.no_grad():
                    outputs = model_local(**inputs)
                return outputs
            model_instance = custom_infer
            register_model_with_metadata(
                model_id=model_key,
                model=model_instance,
                preprocess=preprocess_func,
                postprocess=postprocess_func,
                class_names=current_class_names,
                display_name=display_name,
                contributor=contributor,
                model_path=hf_model_path,
                architecture=architecture,
                dataset=dataset
            )
        else:
            pass # Fallback for any unhandled models