File size: 6,956 Bytes
eeb602e
 
f00c873
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# preprocess_onnx_input is used by ONNXModelWrapper.preprocess below.
from utils.onnx_helpers import preprocess_onnx_input
"""
Model loading and registration logic for OpenSight Deepfake Detection Playground.
Handles ONNX, HuggingFace, and Gradio API model registration and metadata.
"""
from utils.registry import register_model, MODEL_REGISTRY, ModelEntry
from utils.onnx_model_loader import load_onnx_model_and_preprocessor, get_onnx_model_from_cache
from utils.utils import preprocess_resize_256, postprocess_logits, infer_gradio_api, preprocess_gradio_api, postprocess_gradio_api
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import torch
import numpy as np
from PIL import Image

# Module-level cache shared by all ONNXModelWrapper instances. It is handed to
# get_onnx_model_from_cache() in ONNXModelWrapper.load() together with the
# loader function, so the ONNX session and its preprocessor/model configs can
# be reused instead of reloaded.
# NOTE(review): presumably keyed by the Hugging Face model id — confirm in
# utils.onnx_model_loader.
_onnx_model_cache: dict = {}

def register_model_with_metadata(model_id, model, preprocess, postprocess, class_names, display_name, contributor, model_path, architecture=None, dataset=None):
    """Build a ``ModelEntry`` and store it in ``MODEL_REGISTRY`` under *model_id*.

    Args:
        model_id: Key used in ``MODEL_REGISTRY``; an existing entry with the
            same key is overwritten.
        model: Inference callable (or callable object) for the model.
        preprocess: Callable applied to the input before inference.
        postprocess: Callable applied to the raw model output.
        class_names: Class labels associated with the model.
        display_name: Human-readable name stored on the entry.
        contributor: Contributor name stored on the entry.
        model_path: Path or repository id the model was loaded from.
        architecture: Optional architecture label.
        dataset: Optional dataset label.
    """
    # Descriptive fields are forwarded to ModelEntry as keyword arguments.
    metadata = {
        "display_name": display_name,
        "contributor": contributor,
        "model_path": model_path,
        "architecture": architecture,
        "dataset": dataset,
    }
    MODEL_REGISTRY[model_id] = ModelEntry(
        model, preprocess, postprocess, class_names, **metadata
    )

class ONNXModelWrapper:
    """Lazily loaded ONNX model exposing inference, pre- and post-processing.

    The ONNX session and its configs are fetched on first use through
    ``get_onnx_model_from_cache`` and the module-level ``_onnx_model_cache``.

    The inference / preprocessing / postprocessing callables can be injected
    at construction time. When a callable is not injected, the module-level
    name (``infer_onnx_model``, ``preprocess_onnx_input``,
    ``postprocess_onnx_output``) is looked up at call time instead. Of those,
    only ``preprocess_onnx_input`` is imported by this module, so callers
    should inject the other two.

    Args:
        hf_model_id: Model identifier passed to the loader and to the
            inference callable.
        infer_fn: Optional callable invoked as
            ``infer_fn(hf_model_id, image_np, model_config)``.
        preprocess_fn: Optional callable invoked as
            ``preprocess_fn(image, preprocessor_config)``.
        postprocess_fn: Optional callable invoked as
            ``postprocess_fn(onnx_output, model_config)``.
    """

    def __init__(self, hf_model_id, infer_fn=None, preprocess_fn=None, postprocess_fn=None):
        self.hf_model_id = hf_model_id
        self._infer_fn = infer_fn
        self._preprocess_fn = preprocess_fn
        self._postprocess_fn = postprocess_fn
        # Populated by load() on first use.
        self._session = None
        self._preprocessor_config = None
        self._model_config = None

    def load(self):
        """Fetch the session and configs once; later calls are no-ops."""
        if self._session is None:
            self._session, self._preprocessor_config, self._model_config = get_onnx_model_from_cache(
                self.hf_model_id, _onnx_model_cache, load_onnx_model_and_preprocessor
            )

    def __call__(self, image_np):
        """Run inference on *image_np* and return the inference callable's result."""
        self.load()
        # Fall back to the module-level name only when nothing was injected.
        infer = self._infer_fn if self._infer_fn is not None else infer_onnx_model
        return infer(self.hf_model_id, image_np, self._model_config)

    def preprocess(self, image: "Image.Image"):
        """Preprocess *image* using the loaded preprocessor config."""
        self.load()
        preprocess_fn = self._preprocess_fn if self._preprocess_fn is not None else preprocess_onnx_input
        return preprocess_fn(image, self._preprocessor_config)

    def postprocess(self, onnx_output: dict, class_names_from_registry: list):
        """Postprocess *onnx_output* using the loaded model config.

        ``class_names_from_registry`` is accepted for signature compatibility
        with the registry but is not used here.
        """
        self.load()
        postprocess_fn = self._postprocess_fn if self._postprocess_fn is not None else postprocess_onnx_output
        return postprocess_fn(onnx_output, self._model_config)


# Placeholder values that are stored on the entry but left out of display names.
_UNKNOWN_ARCHITECTURE = "Unknown"
_DATASET_TBA = "TBA"

# Metadata overrides for ONNX models, keyed by registry id. Missing fields
# keep the defaults ("Unknown" contributor/architecture, "TBA" dataset).
_ONNX_MODEL_METADATA = {
    "model_1": {"contributor": "haywoodsloan", "architecture": "SwinV2", "dataset": "DeepFakeDetection"},
    "model_2": {"contributor": "Heem2", "architecture": "ViT", "dataset": "DeepFakeDetection"},
    "model_3": {"contributor": "Organika", "architecture": "VIT", "dataset": "SDXL"},
    "model_5": {"contributor": "prithivMLmods", "architecture": "VIT"},
    "model_6": {"contributor": "ideepankarsharma2003", "architecture": "SWINv1", "dataset": "SDXL, Midjourney"},
    "model_7": {"contributor": "date3k2", "architecture": "VIT"},
}


def _build_display_name(model_num, architecture, dataset, suffix=""):
    """Join the model number with any non-placeholder architecture/dataset."""
    parts = [model_num]
    if architecture and architecture != _UNKNOWN_ARCHITECTURE:
        parts.append(architecture)
    if dataset and dataset != _DATASET_TBA:
        parts.append(dataset)
    return "-".join(parts) + suffix


# The main registration function

def register_all_models(MODEL_PATHS, CLASS_NAMES, device, infer_onnx_model, preprocess_onnx_input, postprocess_onnx_output):
    """Register every model listed in *MODEL_PATHS* in the model registry.

    Each entry is handled by the first matching rule:

    1. The path contains ``"ONNX"``: an ``ONNXModelWrapper`` is registered,
       wired to the three ONNX callables passed to this function.
    2. The key is ``"model_8"``: the Gradio API callables are registered.
    3. The key is ``"model_4"``: a transformers feature extractor and image
       classification model are loaded, the model is moved to *device*, and a
       no-grad inference closure is registered.
    4. Anything else is skipped without being registered.

    Args:
        MODEL_PATHS: Mapping of registry key (e.g. ``"model_1"``) to model path.
        CLASS_NAMES: Mapping of registry key to class labels; missing keys
            yield an empty list.
        device: Device passed to the transformers loader and to ``.to()``.
        infer_onnx_model: Callable used by ONNX wrappers for inference.
        preprocess_onnx_input: Callable used by ONNX wrappers for preprocessing.
        postprocess_onnx_output: Callable used by ONNX wrappers for
            postprocessing.
    """
    for model_key, hf_model_path in MODEL_PATHS.items():
        model_num = model_key.replace("model_", "").upper()
        current_class_names = CLASS_NAMES.get(model_key, [])

        if "ONNX" in hf_model_path:
            overrides = _ONNX_MODEL_METADATA.get(model_key, {})
            contributor = overrides.get("contributor", "Unknown")
            architecture = overrides.get("architecture", _UNKNOWN_ARCHITECTURE)
            dataset = overrides.get("dataset", _DATASET_TBA)
            display_name = _build_display_name(model_num, architecture, dataset, suffix="_ONNX")
            # Hand the callables to the wrapper: it cannot rely on module-level
            # names because this module does not define all of them.
            model = ONNXModelWrapper(
                hf_model_path,
                infer_fn=infer_onnx_model,
                preprocess_fn=preprocess_onnx_input,
                postprocess_fn=postprocess_onnx_output,
            )
            preprocess_func = model.preprocess
            postprocess_func = model.postprocess
        elif model_key == "model_8":
            contributor = "aiwithoutborders-xyz"
            architecture = "ViT"
            dataset = "DeepfakeDetection"
            display_name = _build_display_name(model_num, architecture, dataset)
            model = infer_gradio_api
            preprocess_func = preprocess_gradio_api
            postprocess_func = postprocess_gradio_api
        elif model_key == "model_4":
            contributor = "cmckinle"
            architecture = "VIT"
            dataset = "SDXL, FLUX"
            display_name = _build_display_name(model_num, architecture, dataset)
            current_processor = AutoFeatureExtractor.from_pretrained(hf_model_path, device=device)
            hf_model = AutoModelForImageClassification.from_pretrained(hf_model_path).to(device)

            # Defaults bind this iteration's processor/model to the closure.
            def custom_infer(image, processor_local=current_processor, model_local=hf_model):
                inputs = processor_local(image, return_tensors="pt").to(device)
                with torch.no_grad():
                    outputs = model_local(**inputs)
                return outputs

            model = custom_infer
            preprocess_func = preprocess_resize_256
            postprocess_func = postprocess_logits
        else:
            # No loader is defined for this entry; it is left unregistered.
            continue

        register_model_with_metadata(
            model_id=model_key,
            model=model,
            preprocess=preprocess_func,
            postprocess=postprocess_func,
            class_names=current_class_names,
            display_name=display_name,
            contributor=contributor,
            model_path=hf_model_path,
            architecture=architecture,
            dataset=dataset,
        )