LPX55 commited on
Commit
401d88f
·
1 Parent(s): 25ba4f3

refactor(logging): replace logger.info with agent_logger.log for weight optimization messages

Browse files
Files changed (1) hide show
  1. agents/ensemble_weights.py +6 -3
agents/ensemble_weights.py CHANGED
@@ -3,11 +3,13 @@ import torch
3
  from utils.registry import MODEL_REGISTRY # Import MODEL_REGISTRY
4
  from utils.agent_logger import AgentLogger
5
 
 
6
  logger = logging.getLogger(__name__)
7
 
8
  class ContextualWeightOverrideAgent:
9
  def __init__(self):
10
- logger.info("Initializing ContextualWeightOverrideAgent.")
 
11
  self.context_overrides = {
12
  # Example: when image is outdoor, model_X is penalized, model_Y is boosted
13
  "outdoor": {
@@ -26,7 +28,7 @@ class ContextualWeightOverrideAgent:
26
  }
27
 
28
  def get_overrides(self, context_tags: list[str]) -> dict:
29
- logger.info(f"Getting weight overrides for context tags: {context_tags}")
30
  combined_overrides = {}
31
  for tag in context_tags:
32
  if tag in self.context_overrides:
@@ -34,10 +36,11 @@ class ContextualWeightOverrideAgent:
34
  # If a model appears in multiple contexts, we can decide how to combine (e.g., multiply, average, take max)
35
  # For now, let's just take the last one if there are conflicts, or multiply for simple cumulative effect.
36
  combined_overrides[model_id] = combined_overrides.get(model_id, 1.0) * multiplier
37
- logger.info(f"Combined context overrides: {combined_overrides}")
38
  return combined_overrides
39
 
40
 
 
41
  class ModelWeightManager:
42
  def __init__(self, strongest_model_id: str = None):
43
  agent_logger = AgentLogger()
 
3
  from utils.registry import MODEL_REGISTRY # Import MODEL_REGISTRY
4
  from utils.agent_logger import AgentLogger
5
 
6
+
7
  logger = logging.getLogger(__name__)
8
 
9
  class ContextualWeightOverrideAgent:
10
  def __init__(self):
11
+ agent_logger = AgentLogger()
12
+ agent_logger.log("weight_optimization", "info", "Initializing ContextualWeightOverrideAgent.")
13
  self.context_overrides = {
14
  # Example: when image is outdoor, model_X is penalized, model_Y is boosted
15
  "outdoor": {
 
28
  }
29
 
30
  def get_overrides(self, context_tags: list[str]) -> dict:
31
+ agent_logger.log("weight_optimization", "info", f"Getting weight overrides for context tags: {context_tags}")
32
  combined_overrides = {}
33
  for tag in context_tags:
34
  if tag in self.context_overrides:
 
36
  # If a model appears in multiple contexts, we can decide how to combine (e.g., multiply, average, take max)
37
  # For now, let's just take the last one if there are conflicts, or multiply for simple cumulative effect.
38
  combined_overrides[model_id] = combined_overrides.get(model_id, 1.0) * multiplier
39
+ agent_logger.log("weight_optimization", "info", f"Combined context overrides: {combined_overrides}")
40
  return combined_overrides
41
 
42
 
43
+
44
  class ModelWeightManager:
45
  def __init__(self, strongest_model_id: str = None):
46
  agent_logger = AgentLogger()