LPX55 committed on
Commit ce372d3 · Parent(s): c378b41

refactor(weights): update strongest model ID and adjust queue concurrency limit

Files changed (2):
  1. agents/ensemble_weights.py +5 -7
  2. app.py +2 -2
agents/ensemble_weights.py CHANGED
@@ -38,8 +38,6 @@ class ContextualWeightOverrideAgent:
         agent_logger.log("weight_optimization", "info", f"Combined context overrides: {combined_overrides}")
         return combined_overrides
 
-
-
 class ModelWeightManager:
     def __init__(self, strongest_model_id: str = None):
         agent_logger = AgentLogger()
@@ -49,8 +47,8 @@ class ModelWeightManager:
         if num_models > 0:
             if strongest_model_id and strongest_model_id in MODEL_REGISTRY:
                 agent_logger.log("weight_optimization", "info", f"Designating '{strongest_model_id}' as the strongest model.")
-                # Assign a high weight to the strongest model (e.g., 50%)
-                strongest_weight_share = 0.5
+                # Assign a high weight to the strongest model (e.g., 40%)
+                strongest_weight_share = 0.4
                 self.base_weights = {strongest_model_id: strongest_weight_share}
                 remaining_models = [mid for mid in MODEL_REGISTRY.keys() if mid != strongest_model_id]
                 if remaining_models:
@@ -126,7 +124,7 @@ class ModelWeightManager:
         """Check if models agree on prediction"""
         agent_logger.log("weight_optimization", "info", "Checking for consensus among model predictions.")
         non_none_predictions = [p.get("Label") for p in predictions.values() if p is not None and isinstance(p, dict) and p.get("Label") is not None and p.get("Label") != "Error"]
-        agent_logger.log("weight_optimization", "debug", f"Non-none predictions for consensus check: {non_none_predictions}")
+        agent_logger.debug("weight_optimization", "info", f"Non-none predictions for consensus check: {non_none_predictions}")
         result = len(non_none_predictions) > 0 and len(set(non_none_predictions)) == 1
         agent_logger.log("weight_optimization", "info", f"Consensus detected: {result}")
         return result
@@ -135,7 +133,7 @@ class ModelWeightManager:
         """Check if models have conflicting predictions"""
         agent_logger.log("weight_optimization", "info", "Checking for conflicts among model predictions.")
         non_none_predictions = [p.get("Label") for p in predictions.values() if p is not None and isinstance(p, dict) and p.get("Label") is not None and p.get("Label") != "Error"]
-        agent_logger.log("weight_optimization", "debug", f"Non-none predictions for conflict check: {non_none_predictions}")
+        agent_logger.debug("weight_optimization", "info", f"Non-none predictions for conflict check: {non_none_predictions}")
         result = len(non_none_predictions) > 1 and len(set(non_none_predictions)) > 1
         agent_logger.log("weight_optimization", "info", f"Conflicts detected: {result}")
         return result
@@ -154,4 +152,4 @@ class ModelWeightManager:
             return {} # No models registered
         normalized = {k: v/total for k, v in weights.items()}
         agent_logger.log("weight_optimization", "info", f"Weights normalized. Total sum: {sum(normalized.values()):.2f}")
-        return normalized
+        return normalized
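
For context: with this change the designated strongest model's base share drops from 0.5 to 0.4 before normalization. The sketch below illustrates that allocation; it is not code from this repository, and the contents of MODEL_REGISTRY, the helper name build_base_weights, and the even split of the remaining share among the other models are assumptions.

    # Minimal sketch (assumed registry and helper name) of the base-weight
    # assignment touched above: the strongest model gets a 0.4 share
    # (previously 0.5) and the rest is assumed to be split evenly.
    MODEL_REGISTRY = {"model_1": None, "model_2": None, "model_5": None, "model_8": None}

    def build_base_weights(strongest_model_id: str, strongest_weight_share: float = 0.4) -> dict:
        weights = {strongest_model_id: strongest_weight_share}
        remaining_models = [mid for mid in MODEL_REGISTRY if mid != strongest_model_id]
        if remaining_models:
            per_model = (1.0 - strongest_weight_share) / len(remaining_models)
            for mid in remaining_models:
                weights[mid] = per_model
        # Normalize so the shares sum to 1, mirroring the normalization step in the diff.
        total = sum(weights.values())
        return {k: v / total for k, v in weights.items()}

    print({k: round(v, 2) for k, v in build_base_weights("model_8").items()})
    # -> {'model_8': 0.4, 'model_1': 0.2, 'model_2': 0.2, 'model_5': 0.2}

With four registered models this gives 0.4 to the strongest and 0.2 to each of the others; under the previous 0.5 share the others would have received roughly 0.17 each.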
app.py CHANGED
@@ -188,7 +188,7 @@ def full_prediction(img, confidence_threshold, rotate_degrees, noise_level, shar
     img = img.convert('RGB')
 
     monitor_agent = EnsembleMonitorAgent()
-    weight_manager = ModelWeightManager(strongest_model_id="simple_prediction")
+    weight_manager = ModelWeightManager(strongest_model_id="model_8")
     optimization_agent = WeightOptimizationAgent(weight_manager)
     health_agent = SystemHealthAgent()
     context_agent = ContextualIntelligenceAgent()
@@ -679,4 +679,4 @@ with gr.Blocks() as app:
     footer.render()
 
 
-app.queue(max_size=10, default_concurrency_limit=2).launch(mcp_server=True)
+app.queue(max_size=10, default_concurrency_limit=1).launch(mcp_server=True)
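
The app.py changes point ModelWeightManager at "model_8" instead of "simple_prediction" as the strongest model and halve the Gradio queue's default concurrency. A minimal runnable sketch of the queue setting follows, assuming Gradio 4+ semantics for default_concurrency_limit; the stand-in UI and handler are not the app's, and the mcp_server=True flag used in the repository's launch() call is omitted.

    # Stand-in sketch of the queue change: max_size=10 caps how many requests
    # may wait, and lowering default_concurrency_limit from 2 to 1 makes each
    # event handler process one request at a time unless a handler sets its
    # own concurrency_limit.
    import gradio as gr

    def echo(text: str) -> str:
        return text

    with gr.Blocks() as app:
        inp = gr.Textbox(label="input")
        out = gr.Textbox(label="output")
        inp.submit(echo, inp, out)

    # The repository additionally passes mcp_server=True to launch().
    app.queue(max_size=10, default_concurrency_limit=1).launch()

Lowering default_concurrency_limit from 2 to 1 serializes requests per event, which typically reduces peak memory and GPU contention at the cost of throughput.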