LPX committed · Commit 19d34ca · 1 Parent(s): 9fe5b90

major: testing out first implementation of an ensemble team

Files changed (1):
  1. app_mcp.py +225 -6
app_mcp.py CHANGED
@@ -1,4 +1,5 @@
 import os
+import time
 from typing import Literal
 import spaces
 import gradio as gr
@@ -301,17 +302,232 @@ def get_consensus_label(results):
 
 # Update predict_image_with_json to return consensus label
 
+class ModelWeightManager:
+    def __init__(self):
+        self.base_weights = {
+            "model_1": 0.15,   # SwinV2 Based
+            "model_2": 0.15,   # ViT Based
+            "model_3": 0.15,   # SDXL Dataset
+            "model_4": 0.15,   # SDXL + FLUX
+            "model_5": 0.15,   # ViT Based
+            "model_5b": 0.10,  # ViT Based, Newer Dataset
+            "model_6": 0.10,   # Swin, Midj + SDXL
+            "model_7": 0.05    # ViT
+        }
+        self.situation_weights = {
+            "high_confidence": 1.2,  # Boost weights for high confidence predictions
+            "low_confidence": 0.8,   # Reduce weights for low confidence
+            "conflict": 0.5,         # Reduce weights when models disagree
+            "consensus": 1.5         # Boost weights when models agree
+        }
+
+    def adjust_weights(self, predictions, confidence_scores):
+        """Dynamically adjust weights based on prediction patterns"""
+        adjusted_weights = self.base_weights.copy()
+
+        # Check for consensus
+        if self._has_consensus(predictions):
+            for model in adjusted_weights:
+                adjusted_weights[model] *= self.situation_weights["consensus"]
+
+        # Check for conflicts
+        if self._has_conflicts(predictions):
+            for model in adjusted_weights:
+                adjusted_weights[model] *= self.situation_weights["conflict"]
+
+        # Adjust based on confidence
+        for model, confidence in confidence_scores.items():
+            if confidence > 0.8:
+                adjusted_weights[model] *= self.situation_weights["high_confidence"]
+            elif confidence < 0.5:
+                adjusted_weights[model] *= self.situation_weights["low_confidence"]
+
+        return self._normalize_weights(adjusted_weights)
+
+    def _has_consensus(self, predictions):
+        """Check if models agree on prediction"""
+        return len(set(predictions.values())) == 1
+
+    def _has_conflicts(self, predictions):
+        """Check if models have conflicting predictions"""
+        return len(set(predictions.values())) > 2
+
+    def _normalize_weights(self, weights):
+        """Normalize weights to sum to 1"""
+        total = sum(weights.values())
+        return {k: v/total for k, v in weights.items()}
+
+class EnsembleMonitorAgent:
+    def __init__(self):
+        self.performance_metrics = {
+            "model_accuracy": {},
+            "response_times": {},
+            "confidence_distribution": {},
+            "consensus_rate": 0.0
+        }
+        self.alerts = []
+
+    def monitor_prediction(self, model_id, prediction, confidence, response_time):
+        """Monitor individual model performance"""
+        if model_id not in self.performance_metrics["model_accuracy"]:
+            self.performance_metrics["model_accuracy"][model_id] = []
+            self.performance_metrics["response_times"][model_id] = []
+            self.performance_metrics["confidence_distribution"][model_id] = []
+
+        self.performance_metrics["response_times"][model_id].append(response_time)
+        self.performance_metrics["confidence_distribution"][model_id].append(confidence)
+
+        # Check for performance issues
+        self._check_performance_issues(model_id)
+
+    def _check_performance_issues(self, model_id):
+        """Check for any performance anomalies"""
+        response_times = self.performance_metrics["response_times"][model_id]
+        if len(response_times) > 10:
+            avg_time = sum(response_times[-10:]) / 10
+            if avg_time > 2.0:  # More than 2 seconds
+                self.alerts.append(f"High latency detected for {model_id}: {avg_time:.2f}s")
+
+class WeightOptimizationAgent:
+    def __init__(self, weight_manager):
+        self.weight_manager = weight_manager
+        self.performance_history = []
+        self.optimization_threshold = 0.1  # 10% performance change triggers optimization
+
+    def analyze_performance(self, predictions, actual_results):
+        """Analyze model performance and suggest weight adjustments"""
+        # Placeholder for actual_results. In a real scenario, this would come from a validation set.
+        # For now, we'll just track predictions.
+        self.performance_history.append(predictions)
+
+        if self._should_optimize():
+            self._optimize_weights()
+
+    def _should_optimize(self):
+        """Determine if weights should be optimized"""
+        if len(self.performance_history) < 10:
+            return False
+
+        # Placeholder for actual performance calculation
+        # For demonstration, let's say we optimize every 10 runs
+        return len(self.performance_history) % 10 == 0
+
+    def _optimize_weights(self):
+        """Optimize model weights based on performance"""
+        logger.info("Optimizing model weights based on recent performance.")
+        # This is where more sophisticated optimization logic would go.
+        # For example, you could slightly adjust weights of models that consistently predict correctly.
+        pass
+
+class SystemHealthAgent:
+    def __init__(self):
+        self.health_metrics = {
+            "memory_usage": [],
+            "gpu_utilization": [],
+            "model_load_times": {},
+            "error_rates": {}
+        }
+
+    def monitor_system_health(self):
+        """Monitor overall system health"""
+        self._check_memory_usage()
+        self._check_gpu_utilization()
+        # You might add _check_model_health() here later
+
+    def _check_memory_usage(self):
+        """Monitor memory usage"""
+        try:
+            import psutil
+            memory = psutil.virtual_memory()
+            self.health_metrics["memory_usage"].append(memory.percent)
+
+            if memory.percent > 90:
+                logger.warning(f"High memory usage detected: {memory.percent}%")
+        except ImportError:
+            logger.warning("psutil not installed. Cannot monitor memory usage.")
+
+    def _check_gpu_utilization(self):
+        """Monitor GPU utilization if available"""
+        if torch.cuda.is_available():
+            try:
+                gpu_util = torch.cuda.memory_allocated() / torch.cuda.max_memory_allocated()
+                self.health_metrics["gpu_utilization"].append(gpu_util)
+
+                if gpu_util > 0.9:
+                    logger.warning(f"High GPU utilization detected: {gpu_util*100:.2f}%")
+            except Exception as e:
+                logger.warning(f"Error monitoring GPU utilization: {e}")
+        else:
+            logger.info("CUDA not available. Skipping GPU utilization monitoring.")
+
 def predict_image_with_json(img, confidence_threshold, augment_methods, rotate_degrees, noise_level, sharpen_strength):
+    # Initialize agents
+    monitor_agent = EnsembleMonitorAgent()
+    weight_manager = ModelWeightManager()
+    optimization_agent = WeightOptimizationAgent(weight_manager)
+    health_agent = SystemHealthAgent()
+
+    # Monitor system health
+    health_agent.monitor_system_health()
+
     if augment_methods:
         img_pil, _ = augment_image(img, augment_methods, rotate_degrees, noise_level, sharpen_strength)
     else:
         img_pil = img
-    img_pil, results = predict_image(img_pil, confidence_threshold)
-    img_np = np.array(img_pil)  # Convert PIL Image to NumPy array
     img_np_og = np.array(img)  # Convert PIL Image to NumPy array
 
-    gradient_image = gradient_processing(img_np)  # Added gradient processing
-    minmax_image = minmax_preprocess(img_np)  # Added MinMax processing
+    # Get predictions with timing
+    model_predictions = {}
+    confidence_scores = {}
+    results = []  # To store the results for the DataFrame
+
+    for model_id in MODEL_REGISTRY:
+        model_start = time.time()
+        result = infer(img_pil, model_id, confidence_threshold)
+        model_end = time.time()
+
+        # Monitor individual model performance
+        monitor_agent.monitor_prediction(
+            model_id,
+            result["Label"],
+            max(result.get("AI Score", 0.0), result.get("Real Score", 0.0)),
+            model_end - model_start
+        )
+
+        model_predictions[model_id] = result["Label"]
+        confidence_scores[model_id] = max(result.get("AI Score", 0.0), result.get("Real Score", 0.0))
+        results.append(result)  # Add individual model result to the list
+
+    # Get adjusted weights
+    adjusted_weights = weight_manager.adjust_weights(model_predictions, confidence_scores)
+
+    # Optimize weights if needed
+    optimization_agent.analyze_performance(model_predictions, None)  # Placeholder for actual results
+
+    # Calculate weighted consensus
+    weighted_predictions = {
+        "AI": 0.0,
+        "REAL": 0.0,
+        "UNCERTAIN": 0.0
+    }
+
+    for model_id, prediction in model_predictions.items():
+        # Ensure the prediction label is valid for weighted_predictions
+        if prediction in weighted_predictions:
+            weighted_predictions[prediction] += adjusted_weights[model_id]
+        else:
+            # Handle cases where prediction might be an error or unexpected label
+            logger.warning(f"Unexpected prediction label '{prediction}' from model '{model_id}'. Skipping its weight in consensus.")
+
+    final_prediction_label = "UNCERTAIN"
+    if weighted_predictions["AI"] > weighted_predictions["REAL"] and weighted_predictions["AI"] > weighted_predictions["UNCERTAIN"]:
+        final_prediction_label = "AI"
+    elif weighted_predictions["REAL"] > weighted_predictions["AI"] and weighted_predictions["REAL"] > weighted_predictions["UNCERTAIN"]:
+        final_prediction_label = "REAL"
+
+    # Rest of your existing code remains the same after this point
+    gradient_image = gradient_processing(img_np_og)  # Added gradient processing
+    minmax_image = minmax_preprocess(img_np_og)  # Added MinMax processing
 
     # First pass - standard analysis
     ela1 = ELA(img_np_og, quality=75, scale=50, contrast=20, linear=False, grayscale=True)
@@ -330,8 +546,11 @@ def predict_image_with_json(img, confidence_threshold, augment_methods, rotate_d
         r.get("Real Score", ""),
         r.get("Label", "")
     ] for r in results]
-    consensus = get_consensus_label(table_rows)
-    return img_pil, forensics_images, table_rows, results, consensus
+
+    # The get_consensus_label function is now replaced by final_prediction_label from weighted consensus
+    consensus_html = f"<b><span style='color:{'red' if final_prediction_label == 'AI' else ('green' if final_prediction_label == 'REAL' else 'orange')}'>{final_prediction_label}</span></b>"
+
+    return img_pil, forensics_images, table_rows, results, consensus_html
 
 with gr.Blocks(css="#post-gallery { overflow: hidden !important;} .grid-wrap{ overflow-y: hidden !important;} .ms-gr-ant-welcome-icon{ height:unset !important;} .tabs{margin-top:10px;}") as demo:
     with ms.Application() as app:
 
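A few reviewer-oriented sketches of the new pieces follow. They are illustrative only, not part of the diff: they assume this commit's app_mcp is importable (so logger, torch, and the new classes are in scope) and use made-up values. First, ModelWeightManager's dynamic reweighting. Because the situational multipliers (consensus, conflict) apply uniformly to every model, they cancel out under normalization; only the per-model confidence boosts shift the final distribution:

# Hypothetical usage of ModelWeightManager (illustrative values only).
from app_mcp import ModelWeightManager  # assumes this commit's module is importable

wm = ModelWeightManager()

# All eight models agree on "AI"; two of them report high confidence (> 0.8).
predictions = {model_id: "AI" for model_id in wm.base_weights}
confidences = {model_id: 0.6 for model_id in wm.base_weights}
confidences["model_1"] = 0.9
confidences["model_2"] = 0.85

weights = wm.adjust_weights(predictions, confidences)
assert abs(sum(weights.values()) - 1.0) < 1e-9  # normalized to sum to 1
assert weights["model_1"] > weights["model_7"]  # high-confidence model gets boosted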
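Next, the timing hook that predict_image_with_json feeds into EnsembleMonitorAgent for each registry entry; infer() is stubbed here with a fixed result dict:

import time
from app_mcp import EnsembleMonitorAgent  # assumed importable, as above

monitor_agent = EnsembleMonitorAgent()

model_start = time.time()
result = {"Label": "AI", "AI Score": 0.92, "Real Score": 0.08}  # stand-in for infer(img_pil, model_id, confidence_threshold)
model_end = time.time()

monitor_agent.monitor_prediction(
    "model_1",
    result["Label"],
    max(result.get("AI Score", 0.0), result.get("Real Score", 0.0)),
    model_end - model_start
)
# An alert is appended only once a model has more than 10 recorded calls
# and its last-10 average response time exceeds 2.0s.
print(monitor_agent.alerts)  # [] after a single fast call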
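SystemHealthAgent is deliberately tolerant of missing optional dependencies (psutil, CUDA); a minimal standalone check:

from app_mcp import SystemHealthAgent  # assumed importable, as above

health_agent = SystemHealthAgent()
health_agent.monitor_system_health()  # logs warnings instead of raising when psutil/CUDA are absent
print(health_agent.health_metrics["memory_usage"])  # one sample per call when psutil is installed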
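Finally, the weighted consensus is just a per-label sum of adjusted weights, falling back to UNCERTAIN when nothing wins outright; a worked tally with made-up weights:

# Worked example of the weighted-vote tally (made-up three-model weights).
model_predictions = {"model_1": "AI", "model_2": "REAL", "model_3": "AI"}
adjusted_weights = {"model_1": 0.5, "model_2": 0.3, "model_3": 0.2}

weighted_predictions = {"AI": 0.0, "REAL": 0.0, "UNCERTAIN": 0.0}
for model_id, prediction in model_predictions.items():
    if prediction in weighted_predictions:
        weighted_predictions[prediction] += adjusted_weights[model_id]

# AI = 0.5 + 0.2 = 0.7 beats REAL = 0.3 and UNCERTAIN = 0.0 -> "AI"
final_prediction_label = "UNCERTAIN"
if weighted_predictions["AI"] > weighted_predictions["REAL"] and weighted_predictions["AI"] > weighted_predictions["UNCERTAIN"]:
    final_prediction_label = "AI"
elif weighted_predictions["REAL"] > weighted_predictions["AI"] and weighted_predictions["REAL"] > weighted_predictions["UNCERTAIN"]:
    final_prediction_label = "REAL"
assert final_prediction_label == "AI"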