LPX55 commited on
Commit
1146644
·
1 Parent(s): 39558cb

feat: enhance image handling in predictions and dynamically configure model weights based on MODEL_REGISTRY

Browse files
Files changed (2) hide show
  1. agents/ensemble_weights.py +22 -11
  2. app.py +9 -2
agents/ensemble_weights.py CHANGED
@@ -1,5 +1,6 @@
1
  import logging
2
  import torch
 
3
 
4
  logger = logging.getLogger(__name__)
5
 
@@ -35,17 +36,27 @@ class ContextualWeightOverrideAgent:
35
 
36
 
37
  class ModelWeightManager:
38
- def __init__(self):
39
- self.base_weights = {
40
- "model_1": 0.15, # SwinV2 Based
41
- "model_2": 0.15, # ViT Based
42
- "model_3": 0.15, # SDXL Dataset
43
- "model_4": 0.15, # SDXL + FLUX
44
- "model_5": 0.15, # ViT Based
45
- "model_5b": 0.10, # ViT Based, Newer Dataset
46
- "model_6": 0.10, # Swin, Midj + SDXL
47
- "model_7": 0.05 # ViT
48
- }
 
 
 
 
 
 
 
 
 
 
49
  self.situation_weights = {
50
  "high_confidence": 1.2, # Boost weights for high confidence predictions
51
  "low_confidence": 0.8, # Reduce weights for low confidence
 
1
  import logging
2
  import torch
3
+ from utils.registry import MODEL_REGISTRY # Import MODEL_REGISTRY
4
 
5
  logger = logging.getLogger(__name__)
6
 
 
36
 
37
 
38
  class ModelWeightManager:
39
+ def __init__(self, strongest_model_id: str = None):
40
+ # Dynamically initialize base_weights from MODEL_REGISTRY
41
+ num_models = len(MODEL_REGISTRY)
42
+ if num_models > 0:
43
+ if strongest_model_id and strongest_model_id in MODEL_REGISTRY:
44
+ # Assign a high weight to the strongest model (e.g., 50%)
45
+ strongest_weight_share = 0.5
46
+ self.base_weights = {strongest_model_id: strongest_weight_share}
47
+ remaining_models = [mid for mid in MODEL_REGISTRY.keys() if mid != strongest_model_id]
48
+ if remaining_models:
49
+ other_models_weight_share = (1.0 - strongest_weight_share) / len(remaining_models)
50
+ for model_id in remaining_models:
51
+ self.base_weights[model_id] = other_models_weight_share
52
+ else: # Only one model, which is the strongest
53
+ self.base_weights[strongest_model_id] = 1.0
54
+ else:
55
+ initial_weight = 1.0 / num_models
56
+ self.base_weights = {model_id: initial_weight for model_id in MODEL_REGISTRY.keys()}
57
+ else:
58
+ self.base_weights = {} # Handle case with no registered models
59
+
60
  self.situation_weights = {
61
  "high_confidence": 1.2, # Boost weights for high confidence predictions
62
  "low_confidence": 0.8, # Reduce weights for low confidence
app.py CHANGED
@@ -5,6 +5,7 @@ import numpy as np
5
  import os
6
  import time
7
  import logging
 
8
 
9
  # Assuming these are available from your utils and agents directories
10
  # You might need to adjust paths or copy these functions/classes if they are not directly importable.
@@ -185,8 +186,14 @@ def postprocess_simple_prediction(result, class_names):
185
 
186
  def simple_prediction(img):
187
  client = Client("aiwithoutborders-xyz/OpenSight-Community-Forensics-Preview")
 
 
 
 
 
 
188
  result = client.predict(
189
- input_image=handle_file(img),
190
  api_name="/simple_predict"
191
  )
192
  return result
@@ -251,7 +258,7 @@ def ensemble_prediction_stream(img, confidence_threshold, augment_methods, rotat
251
  raise ValueError("Input image could not be converted to PIL Image.")
252
 
253
  monitor_agent = EnsembleMonitorAgent()
254
- weight_manager = ModelWeightManager()
255
  optimization_agent = WeightOptimizationAgent(weight_manager)
256
  health_agent = SystemHealthAgent()
257
  context_agent = ContextualIntelligenceAgent()
 
5
  import os
6
  import time
7
  import logging
8
+ import io
9
 
10
  # Assuming these are available from your utils and agents directories
11
  # You might need to adjust paths or copy these functions/classes if they are not directly importable.
 
186
 
187
  def simple_prediction(img):
188
  client = Client("aiwithoutborders-xyz/OpenSight-Community-Forensics-Preview")
189
+
190
+ # Convert PIL Image to a file-like object in memory
191
+ img_byte_arr = io.BytesIO()
192
+ img.save(img_byte_arr, format='PNG') # Using PNG for lossless conversion, can be JPEG if preferred
193
+ img_byte_arr.seek(0) # Rewind to the beginning of the stream
194
+
195
  result = client.predict(
196
+ input_image=handle_file(img_byte_arr),
197
  api_name="/simple_predict"
198
  )
199
  return result
 
258
  raise ValueError("Input image could not be converted to PIL Image.")
259
 
260
  monitor_agent = EnsembleMonitorAgent()
261
+ weight_manager = ModelWeightManager(strongest_model_id="simple_prediction")
262
  optimization_agent = WeightOptimizationAgent(weight_manager)
263
  health_agent = SystemHealthAgent()
264
  context_agent = ContextualIntelligenceAgent()