LPX55 committed
Commit 59627db · 1 Parent(s): f00c873

refactor: consolidate model registration logic for ONNX and Gradio API

Files changed (1):
1. app.py (+3, -158)
app.py CHANGED
@@ -80,8 +80,6 @@ CLASS_NAMES = {
     "model_8": ['Fake', 'Real'],
 }
 
-# Register all models (ONNX, HuggingFace, Gradio API)
-register_all_models(MODEL_PATHS, CLASS_NAMES, device, infer_onnx_model, preprocess_onnx_input, postprocess_onnx_output)
 
 
 
@@ -124,167 +122,14 @@ def infer_onnx_model(hf_model_id, preprocessed_image_np, model_config: dict):
         # Return a structure consistent with other model errors
         return {"logits": np.array([]), "probabilities": np.array([])}
 
+# Register all models (ONNX, HuggingFace, Gradio API)
+register_all_models(MODEL_PATHS, CLASS_NAMES, device, infer_onnx_model, preprocess_onnx_input, postprocess_onnx_output)
 
 # Register the ONNX quantized model
 # Dummy entry for ONNX model to be loaded dynamically
 # We will now register a 'wrapper' that handles dynamic loading
 
-class ONNXModelWrapper:
-    def __init__(self, hf_model_id):
-        self.hf_model_id = hf_model_id
-        self._session = None
-        self._preprocessor_config = None
-        self._model_config = None
-
-    def load(self):
-        if self._session is None:
-            self._session, self._preprocessor_config, self._model_config = get_onnx_model_from_cache(
-                self.hf_model_id, _onnx_model_cache, load_onnx_model_and_preprocessor
-            )
-            logger.info(f"ONNX model {self.hf_model_id} loaded into wrapper.")
-
-    def __call__(self, image_np):
-        self.load()  # Ensure model is loaded on first call
-        # Pass model_config to infer_onnx_model
-        return infer_onnx_model(self.hf_model_id, image_np, self._model_config)
-
-    def preprocess(self, image: Image.Image):
-        self.load()
-        return preprocess_onnx_input(image, self._preprocessor_config)
-
-    def postprocess(self, onnx_output: dict, class_names_from_registry: list):  # class_names_from_registry is ignored
-        self.load()
-        return postprocess_onnx_output(onnx_output, self._model_config)
-
-# Consolidate all model loading and registration
-for model_key, hf_model_path in MODEL_PATHS.items():
-    logger.debug(f"Attempting to register model: {model_key} with path: {hf_model_path}")
-    model_num = model_key.replace("model_", "").upper()
-    contributor = "Unknown"
-    architecture = "Unknown"
-    dataset = "TBA"
-
-    current_class_names = CLASS_NAMES.get(model_key, [])
-
-    # Logic for ONNX models (1, 2, 3, 5, 6, 7)
-    if "ONNX" in hf_model_path:
-        logger.debug(f"Model {model_key} identified as ONNX.")
-        logger.info(f"Registering ONNX model: {model_key} from {hf_model_path}")
-        onnx_wrapper_instance = ONNXModelWrapper(hf_model_path)
-
-        # Attempt to derive contributor, architecture, dataset based on model_key
-        if model_key == "model_1":
-            contributor = "haywoodsloan"
-            architecture = "SwinV2"
-            dataset = "DeepFakeDetection"
-        elif model_key == "model_2":
-            contributor = "Heem2"
-            architecture = "ViT"
-            dataset = "DeepFakeDetection"
-        elif model_key == "model_3":
-            contributor = "Organika"
-            architecture = "VIT"
-            dataset = "SDXL"
-        elif model_key == "model_5":
-            contributor = "prithivMLmods"
-            architecture = "VIT"
-        elif model_key == "model_6":
-            contributor = "ideepankarsharma2003"
-            architecture = "SWINv1"
-            dataset = "SDXL, Midjourney"
-        elif model_key == "model_7":
-            contributor = "date3k2"
-            architecture = "VIT"
-
-        display_name_parts = [model_num]
-        if architecture and architecture not in ["Unknown"]:
-            display_name_parts.append(architecture)
-        if dataset and dataset not in ["TBA"]:
-            display_name_parts.append(dataset)
-        display_name = "-".join(display_name_parts)
-        display_name += "_ONNX"  # Always append _ONNX for ONNX models
-
-        register_model_with_metadata(
-            model_id=model_key,
-            model=onnx_wrapper_instance,  # The callable wrapper for the ONNX model
-            preprocess=onnx_wrapper_instance.preprocess,
-            postprocess=onnx_wrapper_instance.postprocess,
-            class_names=current_class_names,  # Initial class names; will be overridden by model_config if available
-            display_name=display_name,
-            contributor=contributor,
-            model_path=hf_model_path,
-            architecture=architecture,
-            dataset=dataset
-        )
-    # Logic for Gradio API model (model_8)
-    elif model_key == "model_8":
-        logger.debug(f"Model {model_key} identified as Gradio API.")
-        logger.info(f"Registering Gradio API model: {model_key} from {hf_model_path}")
-        contributor = "aiwithoutborders-xyz"
-        architecture = "ViT"
-        dataset = "DeepfakeDetection"
-
-        display_name_parts = [model_num]
-        if architecture and architecture not in ["Unknown"]:
-            display_name_parts.append(architecture)
-        if dataset and dataset not in ["TBA"]:
-            display_name_parts.append(dataset)
-        display_name = "-".join(display_name_parts)
-
-        register_model_with_metadata(
-            model_id=model_key,
-            model=infer_gradio_api,
-            preprocess=preprocess_gradio_api,
-            postprocess=postprocess_gradio_api,
-            class_names=current_class_names,
-            display_name=display_name,
-            contributor=contributor,
-            model_path=hf_model_path,
-            architecture=architecture,
-            dataset=dataset
-        )
-    # Logic for PyTorch/Hugging Face pipeline models (currently only model_4)
-    elif model_key == "model_4":  # Explicitly handle model_4
-        logger.debug(f"Model {model_key} identified as PyTorch/HuggingFace pipeline.")
-        logger.info(f"Registering HuggingFace pipeline/AutoModel: {model_key} from {hf_model_path}")
-        contributor = "cmckinle"
-        architecture = "VIT"
-        dataset = "SDXL, FLUX"
-
-        display_name_parts = [model_num]
-        if architecture and architecture not in ["Unknown"]:
-            display_name_parts.append(architecture)
-        if dataset and dataset not in ["TBA"]:
-            display_name_parts.append(dataset)
-        display_name = "-".join(display_name_parts)
-
-        current_processor = AutoFeatureExtractor.from_pretrained(hf_model_path, device=device)
-        model_instance = AutoModelForImageClassification.from_pretrained(hf_model_path).to(device)
-
-        preprocess_func = preprocess_resize_256
-        postprocess_func = postprocess_logits
-
-        def custom_infer(image, processor_local=current_processor, model_local=model_instance):
-            inputs = processor_local(image, return_tensors="pt").to(device)
-            with torch.no_grad():
-                outputs = model_local(**inputs)
-            return outputs
-        model_instance = custom_infer
-
-        register_model_with_metadata(
-            model_id=model_key,
-            model=model_instance,
-            preprocess=preprocess_func,
-            postprocess=postprocess_func,
-            class_names=current_class_names,
-            display_name=display_name,
-            contributor=contributor,
-            model_path=hf_model_path,
-            architecture=architecture,
-            dataset=dataset
-        )
-    else:  # Fallback for any unhandled models (shouldn't happen if MODEL_PATHS is fully covered)
-        logger.warning(f"Could not automatically load and register model: {model_key} from {hf_model_path}. No matching registration logic found.")
 
 
 def infer(image: Image.Image, model_id: str, confidence_threshold: float = 0.75) -> dict:
 
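Note on the refactor: the registration block removed above now runs behind the single register_all_models call added by this commit. The helper's actual implementation is not part of this diff; the sketch below is only an inference from the call site and the removed code. Every name it reuses (ONNXModelWrapper, register_model_with_metadata, infer_gradio_api, preprocess_gradio_api, postprocess_gradio_api, preprocess_resize_256, postprocess_logits, AutoFeatureExtractor, AutoModelForImageClassification, logger) is assumed to be importable wherever the real helper lives, and the per-model metadata keywords (display_name, contributor, architecture, dataset) are elided for brevity.

import torch

def register_all_models(model_paths, class_names, device,
                        infer_onnx_model, preprocess_onnx_input,
                        postprocess_onnx_output):
    """Hypothetical sketch: register ONNX, Gradio API, and HuggingFace models in one pass."""
    # The three ONNX callables passed in are presumably threaded into the ONNX
    # wrapper by the real helper; they are unused in this simplified sketch.
    # Metadata kwargs (display_name, contributor, architecture, dataset) are
    # omitted here; the real helper supplies them as in the removed block.
    for model_key, hf_model_path in model_paths.items():
        current_class_names = class_names.get(model_key, [])

        if "ONNX" in hf_model_path:
            # Lazy-loading wrapper, mirroring the ONNXModelWrapper removed from app.py.
            wrapper = ONNXModelWrapper(hf_model_path)
            register_model_with_metadata(
                model_id=model_key,
                model=wrapper,
                preprocess=wrapper.preprocess,
                postprocess=wrapper.postprocess,
                class_names=current_class_names,
                model_path=hf_model_path,
            )
        elif model_key == "model_8":
            # Gradio API-backed model.
            register_model_with_metadata(
                model_id=model_key,
                model=infer_gradio_api,
                preprocess=preprocess_gradio_api,
                postprocess=postprocess_gradio_api,
                class_names=current_class_names,
                model_path=hf_model_path,
            )
        elif model_key == "model_4":
            # HuggingFace AutoModel path, loaded onto `device`.
            processor = AutoFeatureExtractor.from_pretrained(hf_model_path)
            hf_model = AutoModelForImageClassification.from_pretrained(hf_model_path).to(device)

            def custom_infer(image, processor_local=processor, model_local=hf_model):
                inputs = processor_local(image, return_tensors="pt").to(device)
                with torch.no_grad():
                    return model_local(**inputs)

            register_model_with_metadata(
                model_id=model_key,
                model=custom_infer,
                preprocess=preprocess_resize_256,
                postprocess=postprocess_logits,
                class_names=current_class_names,
                model_path=hf_model_path,
            )
        else:
            logger.warning(f"No registration logic found for {model_key} ({hf_model_path})")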