LPX committed
Commit 8f7f87a · 1 Parent(s): 1af0cb5

make a clear distinction between ensemble agents and smart agents

utils/monitoring_agents.py ADDED
@@ -0,0 +1,135 @@
+ import logging
+ import torch
+
+ logger = logging.getLogger(__name__)
+
+ class EnsembleMonitorAgent:
+     def __init__(self):
+         self.performance_metrics = {
+             "model_accuracy": {},
+             "response_times": {},
+             "confidence_distribution": {},
+             "consensus_rate": 0.0
+         }
+         self.alerts = []
+
+     def monitor_prediction(self, model_id, prediction, confidence, response_time):
+         """Monitor individual model performance"""
+         if model_id not in self.performance_metrics["model_accuracy"]:
+             self.performance_metrics["model_accuracy"][model_id] = []
+             self.performance_metrics["response_times"][model_id] = []
+             self.performance_metrics["confidence_distribution"][model_id] = []
+
+         self.performance_metrics["response_times"][model_id].append(response_time)
+         self.performance_metrics["confidence_distribution"][model_id].append(confidence)
+
+         # Check for performance issues after each recorded prediction
+         self._check_performance_issues(model_id)
+
+     def _check_performance_issues(self, model_id):
+         """Check for performance anomalies such as rising latency"""
+         response_times = self.performance_metrics["response_times"][model_id]
+         if len(response_times) > 10:
+             avg_time = sum(response_times[-10:]) / 10
+             if avg_time > 2.0:  # More than 2 seconds on average
+                 self.alerts.append(f"High latency detected for {model_id}: {avg_time:.2f}s")
+
+
+ class WeightOptimizationAgent:
+     def __init__(self, weight_manager):
+         self.weight_manager = weight_manager
+         self.prediction_history = []  # Stores (ensemble_prediction_label, assumed_actual_label)
+         self.optimization_threshold = 0.05  # 5% relative drop in accuracy triggers optimization
+         self.min_history_for_optimization = 20  # Minimum samples before optimizing
+
+     def analyze_performance(self, ensemble_prediction_label, actual_label=None):
+         """Analyze ensemble performance and record it for optimization"""
+         # If actual_label is not provided, assume the ensemble is correct unless it is UNCERTAIN
+         assumed_actual_label = actual_label
+         if assumed_actual_label is None and ensemble_prediction_label != "UNCERTAIN":
+             assumed_actual_label = ensemble_prediction_label
+
+         self.prediction_history.append((ensemble_prediction_label, assumed_actual_label))
+
+         if len(self.prediction_history) >= self.min_history_for_optimization and self._should_optimize():
+             self._optimize_weights()
+
+     def _calculate_accuracy(self, history_subset):
+         """Calculate accuracy over the subset of history where the actual label is known."""
+         correct_predictions = 0
+         total_known = 0
+         for ensemble_pred, actual_label in history_subset:
+             if actual_label is not None:
+                 total_known += 1
+                 if ensemble_pred == actual_label:
+                     correct_predictions += 1
+         return correct_predictions / total_known if total_known > 0 else 0.0
+
+     def _should_optimize(self):
+         """Determine whether weights should be optimized based on recent performance change."""
+         if len(self.prediction_history) < self.min_history_for_optimization * 2:  # Need two full batches to compare
+             return False
+
+         # Compare the accuracy of the most recent batch with the previous batch
+         recent_batch = self.prediction_history[-self.min_history_for_optimization:]
+         previous_batch = self.prediction_history[-self.min_history_for_optimization*2:-self.min_history_for_optimization]
+
+         recent_accuracy = self._calculate_accuracy(recent_batch)
+         previous_accuracy = self._calculate_accuracy(previous_batch)
+
+         # Trigger optimization on a significant relative drop in accuracy
+         if previous_accuracy > 0 and (previous_accuracy - recent_accuracy) / previous_accuracy > self.optimization_threshold:
+             logger.warning(f"Performance degradation detected (from {previous_accuracy:.2f} to {recent_accuracy:.2f}). Triggering weight optimization.")
+             return True
+         return False
+
+     def _optimize_weights(self):
+         """Optimize model weights based on performance."""
+         logger.info("Optimizing model weights based on recent performance.")
+         # Placeholder for more sophisticated optimization logic.
+         # This is where self.weight_manager.base_weights would be adjusted
+         # based on which models contributed more to correct predictions or errors.
+         # For now, it only logs a message.
+
+
+ class SystemHealthAgent:
+     def __init__(self):
+         self.health_metrics = {
+             "memory_usage": [],
+             "gpu_utilization": [],
+             "model_load_times": {},
+             "error_rates": {}
+         }
+
+     def monitor_system_health(self):
+         """Monitor overall system health"""
+         self._check_memory_usage()
+         self._check_gpu_utilization()
+         # A _check_model_health() step could be added here later
+
+     def _check_memory_usage(self):
+         """Monitor memory usage"""
+         try:
+             import psutil  # Imported locally so missing psutil degrades gracefully
+             memory = psutil.virtual_memory()
+             self.health_metrics["memory_usage"].append(memory.percent)
+
+             if memory.percent > 90:
+                 logger.warning(f"High memory usage detected: {memory.percent}%")
+         except ImportError:
+             logger.warning("psutil not installed. Cannot monitor memory usage.")
+
+     def _check_gpu_utilization(self):
+         """Monitor GPU memory utilization if CUDA is available"""
+         if torch.cuda.is_available():
+             try:
+                 max_allocated = torch.cuda.max_memory_allocated()
+                 if max_allocated == 0:
+                     return  # No allocations recorded yet; avoid division by zero
+                 gpu_util = torch.cuda.memory_allocated() / max_allocated
+                 self.health_metrics["gpu_utilization"].append(gpu_util)
+
+                 if gpu_util > 0.9:
+                     logger.warning(f"High GPU utilization detected: {gpu_util*100:.2f}%")
+             except Exception as e:
+                 logger.warning(f"Error monitoring GPU utilization: {e}")
+         else:
+             logger.info("CUDA not available. Skipping GPU utilization monitoring.")
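
For context, here is a minimal sketch of how these monitoring agents might be driven from an inference loop. The run_model helper, the image argument, and the model IDs are hypothetical placeholders, not part of this commit:

import time

from utils.monitoring_agents import EnsembleMonitorAgent, SystemHealthAgent

monitor = EnsembleMonitorAgent()
health = SystemHealthAgent()

def run_model(model_id, image):
    # Hypothetical stand-in for a real model call; returns (label, confidence).
    return "REAL", 0.87

for model_id in ["model_1", "model_2"]:
    start = time.time()
    prediction, confidence = run_model(model_id, image=None)
    monitor.monitor_prediction(model_id, prediction, confidence, time.time() - start)

health.monitor_system_health()
print(monitor.alerts)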
utils/weight_management.py ADDED
@@ -0,0 +1,107 @@
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+ class ContextualWeightOverrideAgent:
+     def __init__(self):
+         self.context_overrides = {
+             # Example: for outdoor scenes, model_1 is penalized and model_5 is boosted
+             "outdoor": {
+                 "model_1": 0.8,  # Reduce weight of model_1 by 20% for outdoor scenes
+                 "model_5": 1.2,  # Boost weight of model_5 by 20% for outdoor scenes
+             },
+             "low_light": {
+                 "model_2": 0.7,
+                 "model_7": 1.3,
+             },
+             "sunny": {
+                 "model_3": 0.9,
+                 "model_4": 1.1,
+             }
+             # Add more contexts and their specific model weight adjustments here
+         }
+
+     def get_overrides(self, context_tags: list[str]) -> dict:
+         """Returns combined weight overrides for given context tags."""
+         combined_overrides = {}
+         for tag in context_tags:
+             if tag in self.context_overrides:
+                 for model_id, multiplier in self.context_overrides[tag].items():
+                     # If a model appears in multiple contexts, multiply the
+                     # multipliers for a simple cumulative effect.
+                     combined_overrides[model_id] = combined_overrides.get(model_id, 1.0) * multiplier
+         return combined_overrides
+
+
+ class ModelWeightManager:
+     def __init__(self):
+         self.base_weights = {
+             "model_1": 0.15,   # SwinV2 based
+             "model_2": 0.15,   # ViT based
+             "model_3": 0.15,   # SDXL dataset
+             "model_4": 0.15,   # SDXL + FLUX
+             "model_5": 0.15,   # ViT based
+             "model_5b": 0.10,  # ViT based, newer dataset
+             "model_6": 0.10,   # Swin, Midj + SDXL
+             "model_7": 0.05    # ViT
+         }
+         self.situation_weights = {
+             "high_confidence": 1.2,  # Boost weights for high-confidence predictions
+             "low_confidence": 0.8,   # Reduce weights for low confidence
+             "conflict": 0.5,         # Reduce weights when models disagree
+             "consensus": 1.5         # Boost weights when models agree
+         }
+         self.context_override_agent = ContextualWeightOverrideAgent()
+
+     def adjust_weights(self, predictions, confidence_scores, context_tags: list[str] = None):
+         """Dynamically adjust weights based on prediction patterns and optional context."""
+         adjusted_weights = self.base_weights.copy()
+
+         # 1. Apply contextual overrides first
+         if context_tags:
+             overrides = self.context_override_agent.get_overrides(context_tags)
+             for model_id, multiplier in overrides.items():
+                 adjusted_weights[model_id] = adjusted_weights.get(model_id, 0.0) * multiplier
+
+         # 2. Apply situation-based adjustments (consensus, conflict, confidence)
+         # Check for consensus
+         if self._has_consensus(predictions):
+             for model in adjusted_weights:
+                 adjusted_weights[model] *= self.situation_weights["consensus"]
+
+         # Check for conflicts
+         if self._has_conflicts(predictions):
+             for model in adjusted_weights:
+                 adjusted_weights[model] *= self.situation_weights["conflict"]
+
+         # Adjust based on confidence
+         for model, confidence in confidence_scores.items():
+             if confidence > 0.8:
+                 adjusted_weights[model] *= self.situation_weights["high_confidence"]
+             elif confidence < 0.5:
+                 adjusted_weights[model] *= self.situation_weights["low_confidence"]
+
+         return self._normalize_weights(adjusted_weights)
+
+     def _has_consensus(self, predictions):
+         """Check whether all valid predictions agree"""
+         # Ignore None and "Error" entries before checking for consensus
+         non_none_predictions = [p for p in predictions.values() if p is not None and p != "Error"]
+         return len(non_none_predictions) > 0 and len(set(non_none_predictions)) == 1
+
+     def _has_conflicts(self, predictions):
+         """Check whether any valid predictions disagree"""
+         # Ignore None and "Error" entries before checking for conflicts
+         non_none_predictions = [p for p in predictions.values() if p is not None and p != "Error"]
+         return len(non_none_predictions) > 1 and len(set(non_none_predictions)) > 1
+
+     def _normalize_weights(self, weights):
+         """Normalize weights to sum to 1"""
+         total = sum(weights.values())
+         if total == 0:
+             # All weights became zero after aggressive multipliers;
+             # fall back to equal weights across the base models.
+             logger.warning("All weights became zero after adjustments. Reverting to equal base weights.")
+             return {k: 1.0/len(self.base_weights) for k in self.base_weights}
+         return {k: v/total for k, v in weights.items()}
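
As a usage illustration, a minimal sketch of calling ModelWeightManager.adjust_weights; the prediction labels and confidence values below are made up:

from utils.weight_management import ModelWeightManager

manager = ModelWeightManager()
predictions = {"model_1": "FAKE", "model_2": "FAKE", "model_3": "REAL"}
confidences = {"model_1": 0.92, "model_2": 0.85, "model_3": 0.45}

# The label conflict halves all weights before confidence scaling;
# normalization then rescales the result to sum to 1.
weights = manager.adjust_weights(predictions, confidences, context_tags=["outdoor"])
assert abs(sum(weights.values()) - 1.0) < 1e-9
print(weights)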