feat: enhance image handling in predictions and dynamically configure model weights based on MODEL_REGISTRY
Browse files- agents/ensemble_weights.py +22 -11
- app.py +9 -2
agents/ensemble_weights.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import logging
|
2 |
import torch
|
|
|
3 |
|
4 |
logger = logging.getLogger(__name__)
|
5 |
|
@@ -35,17 +36,27 @@ class ContextualWeightOverrideAgent:
|
|
35 |
|
36 |
|
37 |
class ModelWeightManager:
|
38 |
-
def __init__(self):
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
self.situation_weights = {
|
50 |
"high_confidence": 1.2, # Boost weights for high confidence predictions
|
51 |
"low_confidence": 0.8, # Reduce weights for low confidence
|
|
|
1 |
import logging
|
2 |
import torch
|
3 |
+
from utils.registry import MODEL_REGISTRY # Import MODEL_REGISTRY
|
4 |
|
5 |
logger = logging.getLogger(__name__)
|
6 |
|
|
|
36 |
|
37 |
|
38 |
class ModelWeightManager:
|
39 |
+
def __init__(self, strongest_model_id: str = None):
|
40 |
+
# Dynamically initialize base_weights from MODEL_REGISTRY
|
41 |
+
num_models = len(MODEL_REGISTRY)
|
42 |
+
if num_models > 0:
|
43 |
+
if strongest_model_id and strongest_model_id in MODEL_REGISTRY:
|
44 |
+
# Assign a high weight to the strongest model (e.g., 50%)
|
45 |
+
strongest_weight_share = 0.5
|
46 |
+
self.base_weights = {strongest_model_id: strongest_weight_share}
|
47 |
+
remaining_models = [mid for mid in MODEL_REGISTRY.keys() if mid != strongest_model_id]
|
48 |
+
if remaining_models:
|
49 |
+
other_models_weight_share = (1.0 - strongest_weight_share) / len(remaining_models)
|
50 |
+
for model_id in remaining_models:
|
51 |
+
self.base_weights[model_id] = other_models_weight_share
|
52 |
+
else: # Only one model, which is the strongest
|
53 |
+
self.base_weights[strongest_model_id] = 1.0
|
54 |
+
else:
|
55 |
+
initial_weight = 1.0 / num_models
|
56 |
+
self.base_weights = {model_id: initial_weight for model_id in MODEL_REGISTRY.keys()}
|
57 |
+
else:
|
58 |
+
self.base_weights = {} # Handle case with no registered models
|
59 |
+
|
60 |
self.situation_weights = {
|
61 |
"high_confidence": 1.2, # Boost weights for high confidence predictions
|
62 |
"low_confidence": 0.8, # Reduce weights for low confidence
|
app.py
CHANGED
@@ -5,6 +5,7 @@ import numpy as np
|
|
5 |
import os
|
6 |
import time
|
7 |
import logging
|
|
|
8 |
|
9 |
# Assuming these are available from your utils and agents directories
|
10 |
# You might need to adjust paths or copy these functions/classes if they are not directly importable.
|
@@ -185,8 +186,14 @@ def postprocess_simple_prediction(result, class_names):
|
|
185 |
|
186 |
def simple_prediction(img):
|
187 |
client = Client("aiwithoutborders-xyz/OpenSight-Community-Forensics-Preview")
|
|
|
|
|
|
|
|
|
|
|
|
|
188 |
result = client.predict(
|
189 |
-
input_image=handle_file(
|
190 |
api_name="/simple_predict"
|
191 |
)
|
192 |
return result
|
@@ -251,7 +258,7 @@ def ensemble_prediction_stream(img, confidence_threshold, augment_methods, rotat
|
|
251 |
raise ValueError("Input image could not be converted to PIL Image.")
|
252 |
|
253 |
monitor_agent = EnsembleMonitorAgent()
|
254 |
-
weight_manager = ModelWeightManager()
|
255 |
optimization_agent = WeightOptimizationAgent(weight_manager)
|
256 |
health_agent = SystemHealthAgent()
|
257 |
context_agent = ContextualIntelligenceAgent()
|
|
|
5 |
import os
|
6 |
import time
|
7 |
import logging
|
8 |
+
import io
|
9 |
|
10 |
# Assuming these are available from your utils and agents directories
|
11 |
# You might need to adjust paths or copy these functions/classes if they are not directly importable.
|
|
|
186 |
|
187 |
def simple_prediction(img):
|
188 |
client = Client("aiwithoutborders-xyz/OpenSight-Community-Forensics-Preview")
|
189 |
+
|
190 |
+
# Convert PIL Image to a file-like object in memory
|
191 |
+
img_byte_arr = io.BytesIO()
|
192 |
+
img.save(img_byte_arr, format='PNG') # Using PNG for lossless conversion, can be JPEG if preferred
|
193 |
+
img_byte_arr.seek(0) # Rewind to the beginning of the stream
|
194 |
+
|
195 |
result = client.predict(
|
196 |
+
input_image=handle_file(img_byte_arr),
|
197 |
api_name="/simple_predict"
|
198 |
)
|
199 |
return result
|
|
|
258 |
raise ValueError("Input image could not be converted to PIL Image.")
|
259 |
|
260 |
monitor_agent = EnsembleMonitorAgent()
|
261 |
+
weight_manager = ModelWeightManager(strongest_model_id="simple_prediction")
|
262 |
optimization_agent = WeightOptimizationAgent(weight_manager)
|
263 |
health_agent = SystemHealthAgent()
|
264 |
context_agent = ContextualIntelligenceAgent()
|