LPX55
commited on
Commit
·
59627db
1
Parent(s):
f00c873
refactor: consolidate model registration logic for ONNX and Gradio API
Browse files
app.py
CHANGED
@@ -80,8 +80,6 @@ CLASS_NAMES = {
|
|
80 |
"model_8": ['Fake', 'Real'],
|
81 |
}
|
82 |
|
83 |
-
# Register all models (ONNX, HuggingFace, Gradio API)
|
84 |
-
register_all_models(MODEL_PATHS, CLASS_NAMES, device, infer_onnx_model, preprocess_onnx_input, postprocess_onnx_output)
|
85 |
|
86 |
|
87 |
|
@@ -124,167 +122,14 @@ def infer_onnx_model(hf_model_id, preprocessed_image_np, model_config: dict):
|
|
124 |
# Return a structure consistent with other model errors
|
125 |
return {"logits": np.array([]), "probabilities": np.array([])}
|
126 |
|
|
|
|
|
127 |
|
128 |
# Register the ONNX quantized model
|
129 |
# Dummy entry for ONNX model to be loaded dynamically
|
130 |
# We will now register a 'wrapper' that handles dynamic loading
|
131 |
|
132 |
-
|
133 |
-
def __init__(self, hf_model_id):
|
134 |
-
self.hf_model_id = hf_model_id
|
135 |
-
self._session = None
|
136 |
-
self._preprocessor_config = None
|
137 |
-
self._model_config = None
|
138 |
-
|
139 |
-
def load(self):
|
140 |
-
if self._session is None:
|
141 |
-
self._session, self._preprocessor_config, self._model_config = get_onnx_model_from_cache(
|
142 |
-
self.hf_model_id, _onnx_model_cache, load_onnx_model_and_preprocessor
|
143 |
-
)
|
144 |
-
logger.info(f"ONNX model {self.hf_model_id} loaded into wrapper.")
|
145 |
-
|
146 |
-
def __call__(self, image_np):
|
147 |
-
self.load() # Ensure model is loaded on first call
|
148 |
-
# Pass model_config to infer_onnx_model
|
149 |
-
return infer_onnx_model(self.hf_model_id, image_np, self._model_config)
|
150 |
-
|
151 |
-
def preprocess(self, image: Image.Image):
|
152 |
-
self.load()
|
153 |
-
return preprocess_onnx_input(image, self._preprocessor_config)
|
154 |
-
|
155 |
-
def postprocess(self, onnx_output: dict, class_names_from_registry: list): # class_names_from_registry is ignored
|
156 |
-
self.load()
|
157 |
-
return postprocess_onnx_output(onnx_output, self._model_config)
|
158 |
-
|
159 |
-
# Consolidate all model loading and registration
|
160 |
-
for model_key, hf_model_path in MODEL_PATHS.items():
|
161 |
-
logger.debug(f"Attempting to register model: {model_key} with path: {hf_model_path}")
|
162 |
-
model_num = model_key.replace("model_", "").upper()
|
163 |
-
contributor = "Unknown"
|
164 |
-
architecture = "Unknown"
|
165 |
-
dataset = "TBA"
|
166 |
-
|
167 |
-
current_class_names = CLASS_NAMES.get(model_key, [])
|
168 |
-
|
169 |
-
# Logic for ONNX models (1, 2, 3, 5, 6, 7)
|
170 |
-
if "ONNX" in hf_model_path:
|
171 |
-
logger.debug(f"Model {model_key} identified as ONNX.")
|
172 |
-
logger.info(f"Registering ONNX model: {model_key} from {hf_model_path}")
|
173 |
-
onnx_wrapper_instance = ONNXModelWrapper(hf_model_path)
|
174 |
-
|
175 |
-
# Attempt to derive contributor, architecture, dataset based on model_key
|
176 |
-
if model_key == "model_1":
|
177 |
-
contributor = "haywoodsloan"
|
178 |
-
architecture = "SwinV2"
|
179 |
-
dataset = "DeepFakeDetection"
|
180 |
-
elif model_key == "model_2":
|
181 |
-
contributor = "Heem2"
|
182 |
-
architecture = "ViT"
|
183 |
-
dataset = "DeepFakeDetection"
|
184 |
-
elif model_key == "model_3":
|
185 |
-
contributor = "Organika"
|
186 |
-
architecture = "VIT"
|
187 |
-
dataset = "SDXL"
|
188 |
-
elif model_key == "model_5":
|
189 |
-
contributor = "prithivMLmods"
|
190 |
-
architecture = "VIT"
|
191 |
-
elif model_key == "model_6":
|
192 |
-
contributor = "ideepankarsharma2003"
|
193 |
-
architecture = "SWINv1"
|
194 |
-
dataset = "SDXL, Midjourney"
|
195 |
-
elif model_key == "model_7":
|
196 |
-
contributor = "date3k2"
|
197 |
-
architecture = "VIT"
|
198 |
-
|
199 |
-
display_name_parts = [model_num]
|
200 |
-
if architecture and architecture not in ["Unknown"]:
|
201 |
-
display_name_parts.append(architecture)
|
202 |
-
if dataset and dataset not in ["TBA"]:
|
203 |
-
display_name_parts.append(dataset)
|
204 |
-
display_name = "-".join(display_name_parts)
|
205 |
-
display_name += "_ONNX" # Always append _ONNX for ONNX models
|
206 |
-
|
207 |
-
register_model_with_metadata(
|
208 |
-
model_id=model_key,
|
209 |
-
model=onnx_wrapper_instance, # The callable wrapper for the ONNX model
|
210 |
-
preprocess=onnx_wrapper_instance.preprocess,
|
211 |
-
postprocess=onnx_wrapper_instance.postprocess,
|
212 |
-
class_names=current_class_names, # Initial class names; will be overridden by model_config if available
|
213 |
-
display_name=display_name,
|
214 |
-
contributor=contributor,
|
215 |
-
model_path=hf_model_path,
|
216 |
-
architecture=architecture,
|
217 |
-
dataset=dataset
|
218 |
-
)
|
219 |
-
# Logic for Gradio API model (model_8)
|
220 |
-
elif model_key == "model_8":
|
221 |
-
logger.debug(f"Model {model_key} identified as Gradio API.")
|
222 |
-
logger.info(f"Registering Gradio API model: {model_key} from {hf_model_path}")
|
223 |
-
contributor = "aiwithoutborders-xyz"
|
224 |
-
architecture = "ViT"
|
225 |
-
dataset = "DeepfakeDetection"
|
226 |
-
|
227 |
-
display_name_parts = [model_num]
|
228 |
-
if architecture and architecture not in ["Unknown"]:
|
229 |
-
display_name_parts.append(architecture)
|
230 |
-
if dataset and dataset not in ["TBA"]:
|
231 |
-
display_name_parts.append(dataset)
|
232 |
-
display_name = "-".join(display_name_parts)
|
233 |
-
|
234 |
-
register_model_with_metadata(
|
235 |
-
model_id=model_key,
|
236 |
-
model=infer_gradio_api,
|
237 |
-
preprocess=preprocess_gradio_api,
|
238 |
-
postprocess=postprocess_gradio_api,
|
239 |
-
class_names=current_class_names,
|
240 |
-
display_name=display_name,
|
241 |
-
contributor=contributor,
|
242 |
-
model_path=hf_model_path,
|
243 |
-
architecture=architecture,
|
244 |
-
dataset=dataset
|
245 |
-
)
|
246 |
-
# Logic for PyTorch/Hugging Face pipeline models (currently only model_4)
|
247 |
-
elif model_key == "model_4": # Explicitly handle model_4
|
248 |
-
logger.debug(f"Model {model_key} identified as PyTorch/HuggingFace pipeline.")
|
249 |
-
logger.info(f"Registering HuggingFace pipeline/AutoModel: {model_key} from {hf_model_path}")
|
250 |
-
contributor = "cmckinle"
|
251 |
-
architecture = "VIT"
|
252 |
-
dataset = "SDXL, FLUX"
|
253 |
-
|
254 |
-
display_name_parts = [model_num]
|
255 |
-
if architecture and architecture not in ["Unknown"]:
|
256 |
-
display_name_parts.append(architecture)
|
257 |
-
if dataset and dataset not in ["TBA"]:
|
258 |
-
display_name_parts.append(dataset)
|
259 |
-
display_name = "-".join(display_name_parts)
|
260 |
-
|
261 |
-
current_processor = AutoFeatureExtractor.from_pretrained(hf_model_path, device=device)
|
262 |
-
model_instance = AutoModelForImageClassification.from_pretrained(hf_model_path).to(device)
|
263 |
-
|
264 |
-
preprocess_func = preprocess_resize_256
|
265 |
-
postprocess_func = postprocess_logits
|
266 |
-
|
267 |
-
def custom_infer(image, processor_local=current_processor, model_local=model_instance):
|
268 |
-
inputs = processor_local(image, return_tensors="pt").to(device)
|
269 |
-
with torch.no_grad():
|
270 |
-
outputs = model_local(**inputs)
|
271 |
-
return outputs
|
272 |
-
model_instance = custom_infer
|
273 |
-
|
274 |
-
register_model_with_metadata(
|
275 |
-
model_id=model_key,
|
276 |
-
model=model_instance,
|
277 |
-
preprocess=preprocess_func,
|
278 |
-
postprocess=postprocess_func,
|
279 |
-
class_names=current_class_names,
|
280 |
-
display_name=display_name,
|
281 |
-
contributor=contributor,
|
282 |
-
model_path=hf_model_path,
|
283 |
-
architecture=architecture,
|
284 |
-
dataset=dataset
|
285 |
-
)
|
286 |
-
else: # Fallback for any unhandled models (shouldn't happen if MODEL_PATHS is fully covered)
|
287 |
-
logger.warning(f"Could not automatically load and register model: {model_key} from {hf_model_path}. No matching registration logic found.")
|
288 |
|
289 |
|
290 |
def infer(image: Image.Image, model_id: str, confidence_threshold: float = 0.75) -> dict:
|
|
|
80 |
"model_8": ['Fake', 'Real'],
|
81 |
}
|
82 |
|
|
|
|
|
83 |
|
84 |
|
85 |
|
|
|
122 |
# Return a structure consistent with other model errors
|
123 |
return {"logits": np.array([]), "probabilities": np.array([])}
|
124 |
|
125 |
+
# Register all models (ONNX, HuggingFace, Gradio API)
|
126 |
+
register_all_models(MODEL_PATHS, CLASS_NAMES, device, infer_onnx_model, preprocess_onnx_input, postprocess_onnx_output)
|
127 |
|
128 |
# Register the ONNX quantized model
|
129 |
# Dummy entry for ONNX model to be loaded dynamically
|
130 |
# We will now register a 'wrapper' that handles dynamic loading
|
131 |
|
132 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
|
134 |
|
135 |
def infer(image: Image.Image, model_id: str, confidence_threshold: float = 0.75) -> dict:
|