LPX55 committed
Commit 24db732 · 1 Parent(s): 9310be4

refactor(logging): replace logger with agent_logger for consistent logging in ModelWeightManager
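The change is mechanical: every call on the module-level stdlib `logger` inside `ModelWeightManager` becomes a call to the shared `agent_logger`, with the subsystem name (`"weight_optimization"`) and the level passed as strings. `utils/agent_logger.py` is not touched by this commit, so the following is only a sketch of the interface the diff assumes, inferred from the `log(process, level, message)` call shape; the internals are hypothetical:

```python
# Hypothetical sketch of the AgentLogger interface this diff relies on.
# Only the log(process, level, message) signature is taken from the call
# sites in the diff; the body is an assumption, not the project's code.
import logging

class AgentLogger:
    """Routes agent events to a per-process stdlib logger."""

    def log(self, process: str, level: str, message: str) -> None:
        # Map the string level ("debug", "info", "warning", ...) to the
        # stdlib numeric level, defaulting to INFO for unknown strings.
        logging.getLogger(process).log(
            getattr(logging, level.upper(), logging.INFO), message
        )
```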

Files changed (1)
  1. agents/ensemble_weights.py +12 -13
agents/ensemble_weights.py CHANGED

```diff
@@ -4,7 +4,6 @@ from utils.registry import MODEL_REGISTRY # Import MODEL_REGISTRY
 from utils.agent_logger import AgentLogger
 
 agent_logger = AgentLogger()
-logger = logging.getLogger(__name__)
 
 class ContextualWeightOverrideAgent:
     def __init__(self):
@@ -49,7 +48,7 @@ class ModelWeightManager:
         num_models = len(MODEL_REGISTRY)
         if num_models > 0:
             if strongest_model_id and strongest_model_id in MODEL_REGISTRY:
-                logger.info(f"Designating '{strongest_model_id}' as the strongest model.")
+                agent_logger.log("weight_optimization", "info", f"Designating '{strongest_model_id}' as the strongest model.")
                 # Assign a high weight to the strongest model (e.g., 50%)
                 strongest_weight_share = 0.5
                 self.base_weights = {strongest_model_id: strongest_weight_share}
@@ -62,12 +61,12 @@ class ModelWeightManager:
                     self.base_weights[strongest_model_id] = 1.0
             else:
                 if strongest_model_id and strongest_model_id not in MODEL_REGISTRY:
-                    logger.warning(f"Strongest model ID '{strongest_model_id}' not found in MODEL_REGISTRY. Distributing weights equally.")
+                    agent_logger.log("weight_optimization", "warning", f"Strongest model ID '{strongest_model_id}' not found in MODEL_REGISTRY. Distributing weights equally.")
                 initial_weight = 1.0 / num_models
                 self.base_weights = {model_id: initial_weight for model_id in MODEL_REGISTRY.keys()}
         else:
             self.base_weights = {} # Handle case with no registered models
-        logger.info(f"Base weights initialized: {self.base_weights}")
+        agent_logger.log("weight_optimization", "info", f"Base weights initialized: {self.base_weights}")
 
         self.situation_weights = {
             "high_confidence": 1.2, # Boost weights for high confidence predictions
@@ -85,7 +84,7 @@ class ModelWeightManager:
 
         # 1. Apply contextual overrides first
         if context_tags:
-            logger.info(f"Applying contextual overrides for tags: {context_tags}")
+            agent_logger.log("weight_optimization", "info", f"Applying contextual overrides for tags: {context_tags}")
             overrides = self.context_override_agent.get_overrides(context_tags)
             for model_id, multiplier in overrides.items():
                 adjusted_weights[model_id] = adjusted_weights.get(model_id, 0.0) * multiplier
@@ -109,7 +108,7 @@ class ModelWeightManager:
         agent_logger.log("weight_optimization", "info", f"Adjusted weights after conflict reduction: {adjusted_weights}")
 
         # Adjust based on confidence
-        logger.info("Adjusting weights based on model confidence scores.")
+        agent_logger.log("weight_optimization", "info", "Adjusting weights based on model confidence scores.")
         for model, confidence in confidence_scores.items():
             if confidence > 0.8:
                 adjusted_weights[model] *= self.situation_weights["high_confidence"]
@@ -117,28 +116,28 @@ class ModelWeightManager:
             elif confidence < 0.5:
                 adjusted_weights[model] *= self.situation_weights["low_confidence"]
                 agent_logger.log("weight_optimization", "info", f"Model '{model}' has low confidence ({confidence:.2f}). Weight reduced.")
-        logger.info(f"Adjusted weights before normalization: {adjusted_weights}")
+        agent_logger.log("weight_optimization", "info", f"Adjusted weights before normalization: {adjusted_weights}")
 
         normalized_weights = self._normalize_weights(adjusted_weights)
-        logger.info(f"Final normalized adjusted weights: {normalized_weights}")
+        agent_logger.log("weight_optimization", "info", f"Final normalized adjusted weights: {normalized_weights}")
         return normalized_weights
 
     def _has_consensus(self, predictions):
         """Check if models agree on prediction"""
         agent_logger.log("weight_optimization", "info", "Checking for consensus among model predictions.")
         non_none_predictions = [p.get("Label") for p in predictions.values() if p is not None and isinstance(p, dict) and p.get("Label") is not None and p.get("Label") != "Error"]
-        logger.debug(f"Non-none predictions for consensus check: {non_none_predictions}")
+        agent_logger.log("weight_optimization", "debug", f"Non-none predictions for consensus check: {non_none_predictions}")
         result = len(non_none_predictions) > 0 and len(set(non_none_predictions)) == 1
-        logger.info(f"Consensus detected: {result}")
+        agent_logger.log("weight_optimization", "info", f"Consensus detected: {result}")
         return result
 
     def _has_conflicts(self, predictions):
         """Check if models have conflicting predictions"""
         agent_logger.log("weight_optimization", "info", "Checking for conflicts among model predictions.")
         non_none_predictions = [p.get("Label") for p in predictions.values() if p is not None and isinstance(p, dict) and p.get("Label") is not None and p.get("Label") != "Error"]
-        logger.debug(f"Non-none predictions for conflict check: {non_none_predictions}")
+        agent_logger.log("weight_optimization", "debug", f"Non-none predictions for conflict check: {non_none_predictions}")
         result = len(non_none_predictions) > 1 and len(set(non_none_predictions)) > 1
-        logger.info(f"Conflicts detected: {result}")
+        agent_logger.log("weight_optimization", "info", f"Conflicts detected: {result}")
         return result
 
     def _normalize_weights(self, weights):
@@ -155,4 +154,4 @@ class ModelWeightManager:
             return {} # No models registered
         normalized = {k: v/total for k, v in weights.items()}
         agent_logger.log("weight_optimization", "info", f"Weights normalized. Total sum: {sum(normalized.values()):.2f}")
-        return normalized
+        return normalized
```
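Every hunk applies the same one-line substitution, so the whole commit reads as a single rewrite pattern. For illustration, the pair below is lifted from the first `ModelWeightManager` hunk (before on top, after underneath):

```python
# Before: module-level stdlib logger, level encoded in the method name.
logger.info(f"Designating '{strongest_model_id}' as the strongest model.")

# After: shared AgentLogger, with subsystem and level passed as strings.
agent_logger.log("weight_optimization", "info",
                 f"Designating '{strongest_model_id}' as the strongest model.")
```

One practical consequence of the string-based level is that a typo such as "warn" instead of "warning" is no longer caught at the call site the way a misspelled `logger.warning` attribute would be, so the level strings in this file are worth keeping consistent.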