LPX committed · Commit 8f7f87a · Parent(s): 1af0cb5

make clear distinction between ensemble agents and smart agents
Files changed:
- utils/monitoring_agents.py  +135 -0
- utils/weight_management.py  +107 -0
utils/monitoring_agents.py
ADDED
@@ -0,0 +1,135 @@
import logging

import torch

logger = logging.getLogger(__name__)


class EnsembleMonitorAgent:
    def __init__(self):
        self.performance_metrics = {
            "model_accuracy": {},
            "response_times": {},
            "confidence_distribution": {},
            "consensus_rate": 0.0
        }
        self.alerts = []

    def monitor_prediction(self, model_id, prediction, confidence, response_time):
        """Monitor individual model performance"""
        if model_id not in self.performance_metrics["model_accuracy"]:
            self.performance_metrics["model_accuracy"][model_id] = []
            self.performance_metrics["response_times"][model_id] = []
            self.performance_metrics["confidence_distribution"][model_id] = []

        self.performance_metrics["response_times"][model_id].append(response_time)
        self.performance_metrics["confidence_distribution"][model_id].append(confidence)

        # Check for performance issues
        self._check_performance_issues(model_id)

    def _check_performance_issues(self, model_id):
        """Check for any performance anomalies"""
        response_times = self.performance_metrics["response_times"][model_id]
        if len(response_times) > 10:
            avg_time = sum(response_times[-10:]) / 10
            if avg_time > 2.0:  # More than 2 seconds
                self.alerts.append(f"High latency detected for {model_id}: {avg_time:.2f}s")


class WeightOptimizationAgent:
    def __init__(self, weight_manager):
        self.weight_manager = weight_manager
        self.prediction_history = []  # Stores (ensemble_prediction_label, assumed_actual_label)
        self.optimization_threshold = 0.05  # 5% relative drop in accuracy triggers optimization
        self.min_history_for_optimization = 20  # Minimum samples before optimizing

    def analyze_performance(self, ensemble_prediction_label, actual_label=None):
        """Analyze ensemble performance and record it for optimization"""
        # If actual_label is not provided, assume the ensemble is correct unless it is UNCERTAIN
        assumed_actual_label = actual_label
        if assumed_actual_label is None and ensemble_prediction_label != "UNCERTAIN":
            assumed_actual_label = ensemble_prediction_label

        self.prediction_history.append((ensemble_prediction_label, assumed_actual_label))

        if len(self.prediction_history) >= self.min_history_for_optimization and self._should_optimize():
            self._optimize_weights()

    def _calculate_accuracy(self, history_subset):
        """Calculate accuracy over the history entries where the actual label is known."""
        correct_predictions = 0
        total_known = 0
        for ensemble_pred, actual_label in history_subset:
            if actual_label is not None:
                total_known += 1
                if ensemble_pred == actual_label:
                    correct_predictions += 1
        return correct_predictions / total_known if total_known > 0 else 0.0

    def _should_optimize(self):
        """Determine if weights should be optimized based on recent performance change."""
        if len(self.prediction_history) < self.min_history_for_optimization * 2:  # Need two full batches to compare
            return False

        # Compare accuracy of the most recent batch with the previous batch
        recent_batch = self.prediction_history[-self.min_history_for_optimization:]
        previous_batch = self.prediction_history[-self.min_history_for_optimization * 2:-self.min_history_for_optimization]

        recent_accuracy = self._calculate_accuracy(recent_batch)
        previous_accuracy = self._calculate_accuracy(previous_batch)

        # Trigger optimization on a significant relative drop in accuracy
        if previous_accuracy > 0 and (previous_accuracy - recent_accuracy) / previous_accuracy > self.optimization_threshold:
            logger.warning(f"Performance degradation detected (from {previous_accuracy:.2f} to {recent_accuracy:.2f}). Triggering weight optimization.")
            return True
        return False

    def _optimize_weights(self):
        """Optimize model weights based on performance."""
        logger.info("Optimizing model weights based on recent performance.")
        # Placeholder for sophisticated optimization logic.
        # This is where you would adjust self.weight_manager.base_weights
        # based on which models contributed more to correct predictions or errors.
        # For now, it's just a log message.


class SystemHealthAgent:
    def __init__(self):
        self.health_metrics = {
            "memory_usage": [],
            "gpu_utilization": [],
            "model_load_times": {},
            "error_rates": {}
        }

    def monitor_system_health(self):
        """Monitor overall system health"""
        self._check_memory_usage()
        self._check_gpu_utilization()
        # You might add _check_model_health() here later

    def _check_memory_usage(self):
        """Monitor memory usage"""
        try:
            # Import locally so a missing psutil degrades gracefully
            # instead of breaking the whole module at import time.
            import psutil
            memory = psutil.virtual_memory()
            self.health_metrics["memory_usage"].append(memory.percent)

            if memory.percent > 90:
                logger.warning(f"High memory usage detected: {memory.percent}%")
        except ImportError:
            logger.warning("psutil not installed. Cannot monitor memory usage.")

    def _check_gpu_utilization(self):
        """Monitor GPU utilization if available"""
        if torch.cuda.is_available():
            try:
                max_allocated = torch.cuda.max_memory_allocated()
                if max_allocated == 0:  # Nothing allocated yet; avoid division by zero
                    return
                gpu_util = torch.cuda.memory_allocated() / max_allocated
                self.health_metrics["gpu_utilization"].append(gpu_util)

                if gpu_util > 0.9:
                    logger.warning(f"High GPU utilization detected: {gpu_util*100:.2f}%")
            except Exception as e:
                logger.warning(f"Error monitoring GPU utilization: {e}")
        else:
            logger.info("CUDA not available. Skipping GPU utilization monitoring.")
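
For context, a minimal sketch of how these agents could be driven from an inference loop. The `models` dict of stand-in predictors and the labels are illustrative assumptions, not part of this commit:

    import random
    import time

    from utils.monitoring_agents import EnsembleMonitorAgent, SystemHealthAgent

    # Hypothetical stand-in predictors; a real caller would wrap actual models.
    models = {
        "model_1": lambda: ("REAL", random.uniform(0.5, 1.0)),
        "model_2": lambda: ("FAKE", random.uniform(0.5, 1.0)),
    }

    monitor = EnsembleMonitorAgent()
    health = SystemHealthAgent()

    for model_id, predict in models.items():
        start = time.time()
        label, confidence = predict()
        monitor.monitor_prediction(model_id, label, confidence, time.time() - start)

    health.monitor_system_health()
    print(monitor.performance_metrics["response_times"])
    print(monitor.alerts)  # Empty unless a model averages >2s over its last 10 calls

Per-model latency and confidence histories accumulate in `performance_metrics`; any model averaging over two seconds across its last ten calls is recorded in `alerts`.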
utils/weight_management.py
ADDED
@@ -0,0 +1,107 @@
import logging

logger = logging.getLogger(__name__)


class ContextualWeightOverrideAgent:
    def __init__(self):
        self.context_overrides = {
            # Example: for outdoor scenes, model_1 is penalized and model_5 is boosted
            "outdoor": {
                "model_1": 0.8,  # Reduce weight of model_1 by 20% for outdoor scenes
                "model_5": 1.2,  # Boost weight of model_5 by 20% for outdoor scenes
            },
            "low_light": {
                "model_2": 0.7,
                "model_7": 1.3,
            },
            "sunny": {
                "model_3": 0.9,
                "model_4": 1.1,
            }
            # Add more contexts and their specific model weight adjustments here
        }

    def get_overrides(self, context_tags: list[str]) -> dict:
        """Return combined weight overrides for the given context tags."""
        combined_overrides = {}
        for tag in context_tags:
            if tag in self.context_overrides:
                for model_id, multiplier in self.context_overrides[tag].items():
                    # If a model appears in multiple contexts, multiply the
                    # multipliers for a simple cumulative effect.
                    combined_overrides[model_id] = combined_overrides.get(model_id, 1.0) * multiplier
        return combined_overrides


class ModelWeightManager:
    def __init__(self):
        self.base_weights = {
            "model_1": 0.15,   # SwinV2 Based
            "model_2": 0.15,   # ViT Based
            "model_3": 0.15,   # SDXL Dataset
            "model_4": 0.15,   # SDXL + FLUX
            "model_5": 0.15,   # ViT Based
            "model_5b": 0.10,  # ViT Based, Newer Dataset
            "model_6": 0.10,   # Swin, Midj + SDXL
            "model_7": 0.05    # ViT
        }
        self.situation_weights = {
            "high_confidence": 1.2,  # Boost weights for high confidence predictions
            "low_confidence": 0.8,   # Reduce weights for low confidence
            "conflict": 0.5,         # Reduce weights when models disagree
            "consensus": 1.5         # Boost weights when models agree
        }
        self.context_override_agent = ContextualWeightOverrideAgent()

    def adjust_weights(self, predictions, confidence_scores, context_tags: list[str] = None):
        """Dynamically adjust weights based on prediction patterns and optional context."""
        adjusted_weights = self.base_weights.copy()

        # 1. Apply contextual overrides first
        if context_tags:
            overrides = self.context_override_agent.get_overrides(context_tags)
            for model_id, multiplier in overrides.items():
                adjusted_weights[model_id] = adjusted_weights.get(model_id, 0.0) * multiplier

        # 2. Apply situation-based adjustments (consensus, conflict, confidence)
        # Check for consensus
        if self._has_consensus(predictions):
            for model in adjusted_weights:
                adjusted_weights[model] *= self.situation_weights["consensus"]

        # Check for conflicts
        if self._has_conflicts(predictions):
            for model in adjusted_weights:
                adjusted_weights[model] *= self.situation_weights["conflict"]

        # Adjust based on confidence
        for model, confidence in confidence_scores.items():
            if model not in adjusted_weights:  # Skip models without a weight entry
                continue
            if confidence > 0.8:
                adjusted_weights[model] *= self.situation_weights["high_confidence"]
            elif confidence < 0.5:
                adjusted_weights[model] *= self.situation_weights["low_confidence"]

        return self._normalize_weights(adjusted_weights)

    def _has_consensus(self, predictions):
        """Check if models agree on a prediction"""
        # Ignore missing or errored predictions before checking for consensus
        non_none_predictions = [p for p in predictions.values() if p is not None and p != "Error"]
        return len(non_none_predictions) > 0 and len(set(non_none_predictions)) == 1

    def _has_conflicts(self, predictions):
        """Check if models have conflicting predictions"""
        # Ignore missing or errored predictions before checking for conflicts
        non_none_predictions = [p for p in predictions.values() if p is not None and p != "Error"]
        return len(non_none_predictions) > 1 and len(set(non_none_predictions)) > 1

    def _normalize_weights(self, weights):
        """Normalize weights to sum to 1"""
        total = sum(weights.values())
        if total == 0:
            # All weights were zeroed out by aggressive multipliers;
            # fall back to equal weights across the base models.
            logger.warning("All weights became zero after adjustments. Reverting to equal base weights.")
            return {k: 1.0 / len(self.base_weights) for k in self.base_weights}
        return {k: v / total for k, v in weights.items()}
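
A companion sketch for the weight manager, with illustrative predictions, confidences, and a context tag (none of these values come from this commit):

    from utils.weight_management import ModelWeightManager

    manager = ModelWeightManager()

    # Illustrative inputs: two models agree, one disagrees, with mixed confidence.
    predictions = {"model_1": "FAKE", "model_2": "FAKE", "model_3": "REAL"}
    confidences = {"model_1": 0.92, "model_2": 0.85, "model_3": 0.40}

    weights = manager.adjust_weights(predictions, confidences, context_tags=["outdoor"])
    print(weights)
    assert abs(sum(weights.values()) - 1.0) < 1e-6  # Always renormalized

`adjust_weights` applies the contextual overrides first, then the consensus/conflict and confidence multipliers, and finally renormalizes, so the returned weights always sum to 1 (or fall back to equal weights if every weight was zeroed out).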