|
import logging
|
|
import torch
|
|
|
|
# Module-wide logger, named after this module, used for diagnostics emitted below.
logger = logging.getLogger(name=__name__)
|
|
|
|
class ContextualWeightOverrideAgent:
    """Map context tags (e.g. "outdoor", "low_light") to per-model weight multipliers."""

    # Default table: context tag -> {model_id: multiplier}. Copied per instance in
    # __init__, so mutating one agent's table never affects another agent.
    _DEFAULT_CONTEXT_OVERRIDES = {
        "outdoor": {
            "model_1": 0.8,
            "model_5": 1.2,
        },
        "low_light": {
            "model_2": 0.7,
            "model_7": 1.3,
        },
        "sunny": {
            "model_3": 0.9,
            "model_4": 1.1,
        },
    }

    def __init__(self, context_overrides: dict[str, dict[str, float]] = None):
        """Create an agent.

        Args:
            context_overrides: Optional mapping of context tag -> {model_id: multiplier}.
                When omitted, the built-in table is used.
        """
        source = (
            self._DEFAULT_CONTEXT_OVERRIDES
            if context_overrides is None
            else context_overrides
        )
        # Copy each inner mapping so instances (and callers) never share the same dicts.
        self.context_overrides = {
            tag: dict(multipliers) for tag, multipliers in source.items()
        }

    def get_overrides(self, context_tags: list[str]) -> dict:
        """Returns combined weight overrides for given context tags.

        Multipliers from every recognised tag are multiplied together per model.
        Unknown tags are ignored. A tag listed more than once is applied only once.

        Args:
            context_tags: Iterable of context tag strings; None or empty yields {}.

        Returns:
            dict of model_id -> combined multiplier (only models touched by a tag).
        """
        combined_overrides = {}
        if not context_tags:
            return combined_overrides

        # dict.fromkeys de-duplicates while preserving first-seen order, so a
        # repeated tag does not compound its multiplier.
        for tag in dict.fromkeys(context_tags):
            for model_id, multiplier in self.context_overrides.get(tag, {}).items():
                combined_overrides[model_id] = (
                    combined_overrides.get(model_id, 1.0) * multiplier
                )

        return combined_overrides
|
|
|
|
|
|
class ModelWeightManager:
    """Compute normalized per-model ensemble weights.

    Starting from ``base_weights``, ``adjust_weights`` applies, in order:
    optional context-tag multipliers, agreement-based multipliers and per-model
    confidence multipliers. It then rescales the result so the weights sum to 1.
    """

    # Prior weight per model id; these defaults sum to 1.0.
    _DEFAULT_BASE_WEIGHTS = {
        "model_1": 0.15,
        "model_2": 0.15,
        "model_3": 0.15,
        "model_4": 0.15,
        "model_5": 0.15,
        "model_5b": 0.10,
        "model_6": 0.10,
        "model_7": 0.05,
    }

    # Multipliers keyed by situation name; adjust_weights looks up all four keys.
    _DEFAULT_SITUATION_WEIGHTS = {
        "high_confidence": 1.2,
        "low_confidence": 0.8,
        "conflict": 0.5,
        "consensus": 1.5,
    }

    # A confidence strictly above this gets the "high_confidence" multiplier.
    HIGH_CONFIDENCE_THRESHOLD = 0.8
    # A confidence strictly below this gets the "low_confidence" multiplier.
    LOW_CONFIDENCE_THRESHOLD = 0.5

    def __init__(self, base_weights=None, situation_weights=None, context_override_agent=None):
        """Create a manager.

        Args:
            base_weights: Optional mapping of model id -> prior weight. Defaults to
                the built-in eight-model table.
            situation_weights: Optional mapping overriding any of the
                "high_confidence", "low_confidence", "conflict" and "consensus"
                multipliers. Keys it omits keep their defaults.
            context_override_agent: Optional object exposing
                ``get_overrides(context_tags) -> dict``. Defaults to a new
                ``ContextualWeightOverrideAgent``.
        """
        # Copy so later mutation of a caller's (or the class default) dict cannot leak in.
        self.base_weights = dict(
            self._DEFAULT_BASE_WEIGHTS if base_weights is None else base_weights
        )
        # Merge over the defaults so a partial mapping still supplies every looked-up key.
        self.situation_weights = {
            **self._DEFAULT_SITUATION_WEIGHTS,
            **(situation_weights or {}),
        }
        if context_override_agent is None:
            context_override_agent = ContextualWeightOverrideAgent()
        self.context_override_agent = context_override_agent

    def adjust_weights(self, predictions, confidence_scores, context_tags: list[str] = None):
        """Dynamically adjust weights based on prediction patterns and optional context.

        Args:
            predictions: Mapping whose values are per-model prediction dicts
                (keys presumably model ids; confirm against callers). Each value's
                "Label" feeds the consensus/conflict checks. May be None or empty.
            confidence_scores: Mapping of model id -> numeric confidence. Model ids
                absent from ``base_weights`` and None scores are ignored. May be
                None or empty.
            context_tags: Optional context tags forwarded to the override agent.

        Returns:
            dict of model id -> weight, normalized to sum to 1 (see _normalize_weights).
        """
        adjusted_weights = self.base_weights.copy()

        # 1. Context-specific multipliers. Ids without a base weight are ignored
        #    rather than inserted with a meaningless 0.0 weight.
        if context_tags:
            overrides = self.context_override_agent.get_overrides(context_tags)
            for model_id, multiplier in overrides.items():
                if model_id in adjusted_weights:
                    adjusted_weights[model_id] *= multiplier

        # 2. Agreement-based multipliers.
        # NOTE(review): both of these scale every weight by the same factor, so
        # they cancel out in _normalize_weights and do not change the returned
        # weights. Kept as-is pending confirmation of the intended behavior
        # (e.g. down-weighting only the models that disagree).
        if self._has_consensus(predictions):
            for model_id in adjusted_weights:
                adjusted_weights[model_id] *= self.situation_weights["consensus"]

        if self._has_conflicts(predictions):
            for model_id in adjusted_weights:
                adjusted_weights[model_id] *= self.situation_weights["conflict"]

        # 3. Per-model confidence multipliers.
        for model_id, confidence in (confidence_scores or {}).items():
            # Unknown models have no weight to scale; None means "no score available".
            if model_id not in adjusted_weights or confidence is None:
                continue
            if confidence > self.HIGH_CONFIDENCE_THRESHOLD:
                adjusted_weights[model_id] *= self.situation_weights["high_confidence"]
            elif confidence < self.LOW_CONFIDENCE_THRESHOLD:
                adjusted_weights[model_id] *= self.situation_weights["low_confidence"]

        return self._normalize_weights(adjusted_weights)

    @staticmethod
    def _valid_labels(predictions):
        """Return the usable "Label" values from the prediction mapping.

        Values that are not dicts, whose "Label" is missing/None, or whose label is
        "Error" are skipped. None or empty input yields [].
        """
        if not predictions:
            return []
        candidate_labels = (
            p.get("Label") for p in predictions.values() if isinstance(p, dict)
        )
        return [
            label for label in candidate_labels
            if label is not None and label != "Error"
        ]

    def _has_consensus(self, predictions):
        """Check if models agree on prediction (at least one valid label, all identical)."""
        return len(set(self._valid_labels(predictions))) == 1

    def _has_conflicts(self, predictions):
        """Check if models have conflicting predictions (two or more distinct valid labels)."""
        return len(set(self._valid_labels(predictions))) > 1

    def _normalize_weights(self, weights):
        """Normalize weights to sum to 1.

        If the adjusted weights sum to zero, fall back to the normalized base
        weights, as the warning says. If the base weights also sum to zero, fall
        back to a uniform distribution over the base models.
        """
        total = sum(weights.values())
        if total == 0:
            logger.warning("All weights became zero after adjustments. Reverting to base weights.")
            base_total = sum(self.base_weights.values())
            if base_total == 0:
                return {k: 1.0 / len(self.base_weights) for k in self.base_weights}
            return {k: v / base_total for k, v in self.base_weights.items()}
        return {k: v / total for k, v in weights.items()}