LPX55
commited on
Commit
·
24db732
1
Parent(s):
9310be4
refactor(logging): replace logger with agent_logger for consistent logging in ModelWeightManager
Browse files- agents/ensemble_weights.py +12 -13
agents/ensemble_weights.py
CHANGED
@@ -4,7 +4,6 @@ from utils.registry import MODEL_REGISTRY # Import MODEL_REGISTRY
|
|
4 |
from utils.agent_logger import AgentLogger
|
5 |
|
6 |
agent_logger = AgentLogger()
|
7 |
-
logger = logging.getLogger(__name__)
|
8 |
|
9 |
class ContextualWeightOverrideAgent:
|
10 |
def __init__(self):
|
@@ -49,7 +48,7 @@ class ModelWeightManager:
|
|
49 |
num_models = len(MODEL_REGISTRY)
|
50 |
if num_models > 0:
|
51 |
if strongest_model_id and strongest_model_id in MODEL_REGISTRY:
|
52 |
-
|
53 |
# Assign a high weight to the strongest model (e.g., 50%)
|
54 |
strongest_weight_share = 0.5
|
55 |
self.base_weights = {strongest_model_id: strongest_weight_share}
|
@@ -62,12 +61,12 @@ class ModelWeightManager:
|
|
62 |
self.base_weights[strongest_model_id] = 1.0
|
63 |
else:
|
64 |
if strongest_model_id and strongest_model_id not in MODEL_REGISTRY:
|
65 |
-
|
66 |
initial_weight = 1.0 / num_models
|
67 |
self.base_weights = {model_id: initial_weight for model_id in MODEL_REGISTRY.keys()}
|
68 |
else:
|
69 |
self.base_weights = {} # Handle case with no registered models
|
70 |
-
|
71 |
|
72 |
self.situation_weights = {
|
73 |
"high_confidence": 1.2, # Boost weights for high confidence predictions
|
@@ -85,7 +84,7 @@ class ModelWeightManager:
|
|
85 |
|
86 |
# 1. Apply contextual overrides first
|
87 |
if context_tags:
|
88 |
-
|
89 |
overrides = self.context_override_agent.get_overrides(context_tags)
|
90 |
for model_id, multiplier in overrides.items():
|
91 |
adjusted_weights[model_id] = adjusted_weights.get(model_id, 0.0) * multiplier
|
@@ -109,7 +108,7 @@ class ModelWeightManager:
|
|
109 |
agent_logger.log("weight_optimization", "info", f"Adjusted weights after conflict reduction: {adjusted_weights}")
|
110 |
|
111 |
# Adjust based on confidence
|
112 |
-
|
113 |
for model, confidence in confidence_scores.items():
|
114 |
if confidence > 0.8:
|
115 |
adjusted_weights[model] *= self.situation_weights["high_confidence"]
|
@@ -117,28 +116,28 @@ class ModelWeightManager:
|
|
117 |
elif confidence < 0.5:
|
118 |
adjusted_weights[model] *= self.situation_weights["low_confidence"]
|
119 |
agent_logger.log("weight_optimization", "info", f"Model '{model}' has low confidence ({confidence:.2f}). Weight reduced.")
|
120 |
-
|
121 |
|
122 |
normalized_weights = self._normalize_weights(adjusted_weights)
|
123 |
-
|
124 |
return normalized_weights
|
125 |
|
126 |
def _has_consensus(self, predictions):
|
127 |
"""Check if models agree on prediction"""
|
128 |
agent_logger.log("weight_optimization", "info", "Checking for consensus among model predictions.")
|
129 |
non_none_predictions = [p.get("Label") for p in predictions.values() if p is not None and isinstance(p, dict) and p.get("Label") is not None and p.get("Label") != "Error"]
|
130 |
-
|
131 |
result = len(non_none_predictions) > 0 and len(set(non_none_predictions)) == 1
|
132 |
-
|
133 |
return result
|
134 |
|
135 |
def _has_conflicts(self, predictions):
|
136 |
"""Check if models have conflicting predictions"""
|
137 |
agent_logger.log("weight_optimization", "info", "Checking for conflicts among model predictions.")
|
138 |
non_none_predictions = [p.get("Label") for p in predictions.values() if p is not None and isinstance(p, dict) and p.get("Label") is not None and p.get("Label") != "Error"]
|
139 |
-
|
140 |
result = len(non_none_predictions) > 1 and len(set(non_none_predictions)) > 1
|
141 |
-
|
142 |
return result
|
143 |
|
144 |
def _normalize_weights(self, weights):
|
@@ -155,4 +154,4 @@ class ModelWeightManager:
|
|
155 |
return {} # No models registered
|
156 |
normalized = {k: v/total for k, v in weights.items()}
|
157 |
agent_logger.log("weight_optimization", "info", f"Weights normalized. Total sum: {sum(normalized.values()):.2f}")
|
158 |
-
return normalized
|
|
|
4 |
from utils.agent_logger import AgentLogger
|
5 |
|
6 |
agent_logger = AgentLogger()
|
|
|
7 |
|
8 |
class ContextualWeightOverrideAgent:
|
9 |
def __init__(self):
|
|
|
48 |
num_models = len(MODEL_REGISTRY)
|
49 |
if num_models > 0:
|
50 |
if strongest_model_id and strongest_model_id in MODEL_REGISTRY:
|
51 |
+
agent_logger.log("weight_optimization", "info", f"Designating '{strongest_model_id}' as the strongest model.")
|
52 |
# Assign a high weight to the strongest model (e.g., 50%)
|
53 |
strongest_weight_share = 0.5
|
54 |
self.base_weights = {strongest_model_id: strongest_weight_share}
|
|
|
61 |
self.base_weights[strongest_model_id] = 1.0
|
62 |
else:
|
63 |
if strongest_model_id and strongest_model_id not in MODEL_REGISTRY:
|
64 |
+
agent_logger.log("weight_optimization", "warning", f"Strongest model ID '{strongest_model_id}' not found in MODEL_REGISTRY. Distributing weights equally.")
|
65 |
initial_weight = 1.0 / num_models
|
66 |
self.base_weights = {model_id: initial_weight for model_id in MODEL_REGISTRY.keys()}
|
67 |
else:
|
68 |
self.base_weights = {} # Handle case with no registered models
|
69 |
+
agent_logger.log("weight_optimization", "info", f"Base weights initialized: {self.base_weights}")
|
70 |
|
71 |
self.situation_weights = {
|
72 |
"high_confidence": 1.2, # Boost weights for high confidence predictions
|
|
|
84 |
|
85 |
# 1. Apply contextual overrides first
|
86 |
if context_tags:
|
87 |
+
agent_logger.log("weight_optimization", "info", f"Applying contextual overrides for tags: {context_tags}")
|
88 |
overrides = self.context_override_agent.get_overrides(context_tags)
|
89 |
for model_id, multiplier in overrides.items():
|
90 |
adjusted_weights[model_id] = adjusted_weights.get(model_id, 0.0) * multiplier
|
|
|
108 |
agent_logger.log("weight_optimization", "info", f"Adjusted weights after conflict reduction: {adjusted_weights}")
|
109 |
|
110 |
# Adjust based on confidence
|
111 |
+
agent_logger.log("weight_optimization", "info", "Adjusting weights based on model confidence scores.")
|
112 |
for model, confidence in confidence_scores.items():
|
113 |
if confidence > 0.8:
|
114 |
adjusted_weights[model] *= self.situation_weights["high_confidence"]
|
|
|
116 |
elif confidence < 0.5:
|
117 |
adjusted_weights[model] *= self.situation_weights["low_confidence"]
|
118 |
agent_logger.log("weight_optimization", "info", f"Model '{model}' has low confidence ({confidence:.2f}). Weight reduced.")
|
119 |
+
agent_logger.log("weight_optimization", "info", f"Adjusted weights before normalization: {adjusted_weights}")
|
120 |
|
121 |
normalized_weights = self._normalize_weights(adjusted_weights)
|
122 |
+
agent_logger.log("weight_optimization", "info", f"Final normalized adjusted weights: {normalized_weights}")
|
123 |
return normalized_weights
|
124 |
|
125 |
def _has_consensus(self, predictions):
|
126 |
"""Check if models agree on prediction"""
|
127 |
agent_logger.log("weight_optimization", "info", "Checking for consensus among model predictions.")
|
128 |
non_none_predictions = [p.get("Label") for p in predictions.values() if p is not None and isinstance(p, dict) and p.get("Label") is not None and p.get("Label") != "Error"]
|
129 |
+
agent_logger.log("weight_optimization", "debug", f"Non-none predictions for consensus check: {non_none_predictions}")
|
130 |
result = len(non_none_predictions) > 0 and len(set(non_none_predictions)) == 1
|
131 |
+
agent_logger.log("weight_optimization", "info", f"Consensus detected: {result}")
|
132 |
return result
|
133 |
|
134 |
def _has_conflicts(self, predictions):
|
135 |
"""Check if models have conflicting predictions"""
|
136 |
agent_logger.log("weight_optimization", "info", "Checking for conflicts among model predictions.")
|
137 |
non_none_predictions = [p.get("Label") for p in predictions.values() if p is not None and isinstance(p, dict) and p.get("Label") is not None and p.get("Label") != "Error"]
|
138 |
+
agent_logger.log("weight_optimization", "debug", f"Non-none predictions for conflict check: {non_none_predictions}")
|
139 |
result = len(non_none_predictions) > 1 and len(set(non_none_predictions)) > 1
|
140 |
+
agent_logger.log("weight_optimization", "info", f"Conflicts detected: {result}")
|
141 |
return result
|
142 |
|
143 |
def _normalize_weights(self, weights):
|
|
|
154 |
return {} # No models registered
|
155 |
normalized = {k: v/total for k, v in weights.items()}
|
156 |
agent_logger.log("weight_optimization", "info", f"Weights normalized. Total sum: {sum(normalized.values()):.2f}")
|
157 |
+
return normalized
|