LPX55
refactor(logging): replace logger with agent_logger for consistent logging in ModelWeightManager
24db732
import logging | |
import torch | |
from utils.registry import MODEL_REGISTRY # Import MODEL_REGISTRY | |
from utils.agent_logger import AgentLogger | |
# Module-level AgentLogger instance. The methods of the classes below log through
# this shared name (e.g. get_overrides and adjust_weights call agent_logger.log).
agent_logger = AgentLogger()
class ContextualWeightOverrideAgent:
    """Provide per-model weight multipliers keyed by scene-context tags.

    The override table maps a context tag (e.g. ``"outdoor"``) to a dict of
    ``{model_id: multiplier}``. ``get_overrides`` merges the multipliers of all
    known tags in a request into a single ``{model_id: multiplier}`` mapping.
    """

    # Built-in example table used when no custom table is supplied.
    # Copied per instance in __init__, so instances never share (and mutate)
    # these nested dicts.
    _DEFAULT_CONTEXT_OVERRIDES = {
        # Example: when image is outdoor, model_X is penalized, model_Y is boosted
        "outdoor": {
            "model_1": 0.8,  # Example: Reduce weight of model_1 by 20% for outdoor scenes
            "model_5": 1.2,  # Example: Boost weight of model_5 by 20% for outdoor scenes
        },
        "low_light": {
            "model_2": 0.7,
            "model_7": 1.3,
        },
        "sunny": {
            "model_3": 0.9,
            "model_4": 1.1,
        },
        # Add more contexts and their specific model weight adjustments here
    }

    def __init__(self, context_overrides: "dict[str, dict[str, float]] | None" = None):
        """Create the agent.

        Args:
            context_overrides: Optional mapping of context tag ->
                ``{model_id: multiplier}``. When omitted, the built-in example
                table is used. The mapping is copied two levels deep, so later
                changes made by the caller do not affect this agent.
        """
        # Use the shared module-level logger. Instantiating a separate local
        # AgentLogger here would only shadow it.
        agent_logger.log("weight_optimization", "info", "Initializing ContextualWeightOverrideAgent.")
        source = self._DEFAULT_CONTEXT_OVERRIDES if context_overrides is None else context_overrides
        self.context_overrides = {tag: dict(multipliers) for tag, multipliers in source.items()}

    def get_overrides(self, context_tags: list[str]) -> dict:
        """Combine the multipliers of every known tag in ``context_tags``.

        Unknown tags are ignored. If a model appears under several matching
        tags, its multipliers are multiplied together (cumulative effect).

        Args:
            context_tags: Context tags describing the input. ``None`` or an
                empty list yields no overrides.

        Returns:
            dict mapping model_id -> combined multiplier. It is empty when no
            tag matched.
        """
        agent_logger.log("weight_optimization", "info", f"Getting weight overrides for context tags: {context_tags}")
        combined_overrides = {}
        for tag in context_tags or ():
            for model_id, multiplier in self.context_overrides.get(tag, {}).items():
                # Multipliers from several matching contexts compound multiplicatively.
                combined_overrides[model_id] = combined_overrides.get(model_id, 1.0) * multiplier
        agent_logger.log("weight_optimization", "info", f"Combined context overrides: {combined_overrides}")
        return combined_overrides
class ModelWeightManager:
    """Compute normalized per-model ensemble weights.

    Base weights are derived from ``MODEL_REGISTRY`` when the manager is
    created. ``adjust_weights`` then scales a copy of them by contextual
    overrides, prediction agreement and per-model confidence. It normalizes
    the result so the weights sum to 1.
    """

    def __init__(self, strongest_model_id: str = None, strongest_weight_share: float = 0.5):
        """Initialize base weights, situation multipliers and the context agent.

        Args:
            strongest_model_id: Optional ``MODEL_REGISTRY`` key of the model that
                should dominate the ensemble. If it is missing or not
                registered, every registered model gets the same base weight.
            strongest_weight_share: Fraction of the total base weight given to
                ``strongest_model_id`` (default 0.5). The remainder is split
                equally among the other registered models. If the strongest
                model is the only registered one, it gets 1.0 instead.

        Raises:
            ValueError: If ``strongest_weight_share`` is not within [0, 1].
        """
        if not 0.0 <= strongest_weight_share <= 1.0:
            raise ValueError(
                f"strongest_weight_share must be within [0, 1], got {strongest_weight_share!r}"
            )
        # Use the shared module-level logger. A local AgentLogger() here would
        # only shadow it.
        agent_logger.log("weight_optimization", "info", f"Initializing ModelWeightManager. Strongest model: {strongest_model_id}")
        self.base_weights = self._build_base_weights(strongest_model_id, strongest_weight_share)
        agent_logger.log("weight_optimization", "info", f"Base weights initialized: {self.base_weights}")
        self.situation_weights = {
            "high_confidence": 1.2,  # Boost weights for high confidence predictions
            "low_confidence": 0.8,  # Reduce weights for low confidence
            "conflict": 0.5,  # Reduce weights when models disagree
            "consensus": 1.5,  # Boost weights when models agree
        }
        self.context_override_agent = ContextualWeightOverrideAgent()

    def _build_base_weights(self, strongest_model_id, strongest_weight_share):
        """Return registry-derived base weights. The result is empty if no model is registered."""
        registered = list(MODEL_REGISTRY.keys())
        if not registered:
            return {}  # Handle case with no registered models
        if strongest_model_id and strongest_model_id in MODEL_REGISTRY:
            agent_logger.log("weight_optimization", "info", f"Designating '{strongest_model_id}' as the strongest model.")
            others = [mid for mid in registered if mid != strongest_model_id]
            if not others:
                # Only one model, which is the strongest.
                return {strongest_model_id: 1.0}
            weights = {strongest_model_id: strongest_weight_share}
            other_share = (1.0 - strongest_weight_share) / len(others)
            weights.update({mid: other_share for mid in others})
            return weights
        if strongest_model_id:
            agent_logger.log("weight_optimization", "warning", f"Strongest model ID '{strongest_model_id}' not found in MODEL_REGISTRY. Distributing weights equally.")
        equal_share = 1.0 / len(registered)
        return {mid: equal_share for mid in registered}

    def adjust_weights(self, predictions, confidence_scores, context_tags: list[str] = None):
        """Dynamically adjust weights based on prediction patterns and optional context.

        Args:
            predictions: Mapping model_id -> prediction. Only dict entries with
                a ``"Label"`` that is neither ``None`` nor ``"Error"`` take part
                in the consensus/conflict checks. ``None`` is treated as empty.
            confidence_scores: Mapping model_id -> confidence. A score > 0.8
                multiplies the model's weight by ``situation_weights
                ["high_confidence"]``. A score < 0.5 multiplies it by
                ``situation_weights["low_confidence"]``. ``None`` scores, and
                models that have no base weight, are skipped. ``None`` is
                treated as empty.
            context_tags: Optional tags forwarded to the contextual override agent.

        Returns:
            dict mapping model_id -> weight, normalized by ``_normalize_weights``.
        """
        agent_logger.log("weight_optimization", "info", "Adjusting model weights.")
        adjusted_weights = self.base_weights.copy()
        agent_logger.log("weight_optimization", "info", f"Initial adjusted weights (copy of base): {adjusted_weights}")

        # 1. Apply contextual overrides first.
        if context_tags:
            agent_logger.log("weight_optimization", "info", f"Applying contextual overrides for tags: {context_tags}")
            overrides = self.context_override_agent.get_overrides(context_tags)
            for model_id, multiplier in overrides.items():
                # Overrides may name models that are not weighted here. Scaling
                # only known models keeps unregistered ids out of the result.
                if model_id in adjusted_weights:
                    adjusted_weights[model_id] *= multiplier
            agent_logger.log("weight_optimization", "info", f"Adjusted weights after context overrides: {adjusted_weights}")

        # 2. Apply situation-based adjustments (consensus, conflict, confidence).
        # NOTE(review): consensus/conflict multiply every weight by the same
        # factor, so they cancel out in _normalize_weights. As written, they
        # only change the logged intermediate values, not the returned weights.
        if self._has_consensus(predictions):
            agent_logger.log("weight_optimization", "info", "Consensus detected. Boosting weights for consensus.")
            for model in adjusted_weights:
                adjusted_weights[model] *= self.situation_weights["consensus"]
            agent_logger.log("weight_optimization", "info", f"Adjusted weights after consensus boost: {adjusted_weights}")

        if self._has_conflicts(predictions):
            agent_logger.log("weight_optimization", "info", "Conflicts detected. Reducing weights for conflict.")
            for model in adjusted_weights:
                adjusted_weights[model] *= self.situation_weights["conflict"]
            agent_logger.log("weight_optimization", "info", f"Adjusted weights after conflict reduction: {adjusted_weights}")

        # Adjust based on per-model confidence.
        agent_logger.log("weight_optimization", "info", "Adjusting weights based on model confidence scores.")
        for model, confidence in (confidence_scores or {}).items():
            if confidence is None:
                continue  # No score available (e.g. the model failed); leave its weight alone.
            if model not in adjusted_weights:
                agent_logger.log("weight_optimization", "debug", f"Ignoring confidence for model '{model}' without a base weight.")
                continue
            if confidence > 0.8:
                adjusted_weights[model] *= self.situation_weights["high_confidence"]
                agent_logger.log("weight_optimization", "info", f"Model '{model}' has high confidence ({confidence:.2f}). Weight boosted.")
            elif confidence < 0.5:
                adjusted_weights[model] *= self.situation_weights["low_confidence"]
                agent_logger.log("weight_optimization", "info", f"Model '{model}' has low confidence ({confidence:.2f}). Weight reduced.")

        agent_logger.log("weight_optimization", "info", f"Adjusted weights before normalization: {adjusted_weights}")
        normalized_weights = self._normalize_weights(adjusted_weights)
        agent_logger.log("weight_optimization", "info", f"Final normalized adjusted weights: {normalized_weights}")
        return normalized_weights

    @staticmethod
    def _valid_labels(predictions):
        """Return the usable ``"Label"`` values from a {model_id: prediction} mapping."""
        labels = []
        for prediction in (predictions or {}).values():
            if not isinstance(prediction, dict):
                continue  # Covers None and any non-dict result.
            label = prediction.get("Label")
            if label is not None and label != "Error":
                labels.append(label)
        return labels

    def _has_consensus(self, predictions):
        """Return True if at least one usable label exists and all usable labels are equal."""
        agent_logger.log("weight_optimization", "info", "Checking for consensus among model predictions.")
        labels = self._valid_labels(predictions)
        agent_logger.log("weight_optimization", "debug", f"Non-none predictions for consensus check: {labels}")
        result = len(labels) > 0 and len(set(labels)) == 1
        agent_logger.log("weight_optimization", "info", f"Consensus detected: {result}")
        return result

    def _has_conflicts(self, predictions):
        """Return True if the usable labels contain at least two distinct values."""
        agent_logger.log("weight_optimization", "info", "Checking for conflicts among model predictions.")
        labels = self._valid_labels(predictions)
        agent_logger.log("weight_optimization", "debug", f"Non-none predictions for conflict check: {labels}")
        result = len(labels) > 1 and len(set(labels)) > 1
        agent_logger.log("weight_optimization", "info", f"Conflicts detected: {result}")
        return result

    def _normalize_weights(self, weights):
        """Normalize weights to sum to 1.

        If the weights sum to zero or less, nothing meaningful can be
        normalized. The method then falls back to equal weights over every
        model in ``MODEL_REGISTRY``, or to ``{}`` when none is registered.
        """
        agent_logger.log("weight_optimization", "info", "Normalizing weights.")
        total = sum(weights.values())
        # `<= 0` rather than `== 0`: dividing by a negative total would flip every sign.
        if total <= 0:
            agent_logger.log("weight_optimization", "warning", "All weights became zero after adjustments. Reverting to equal base weights for registered models.")
            registered = list(MODEL_REGISTRY.keys())
            if not registered:
                return {}  # No models registered
            equal_share = 1.0 / len(registered)
            return {k: equal_share for k in registered}
        normalized = {k: v / total for k, v in weights.items()}
        agent_logger.log("weight_optimization", "info", f"Weights normalized. Total sum: {sum(normalized.values()):.2f}")
        return normalized