Spaces:
Running
Running
# Create an improved contrast detection file: utils/improved_contrast_analyzer.py | |
import numpy as np | |
import cv2 | |
import colorsys | |
from scipy import ndimage, spatial | |
from sklearn.cluster import DBSCAN | |
class ImprovedContrastAnalyzer: | |
""" | |
Advanced contrast analyzer focused on Alzheimer's-friendly environments | |
""" | |
def __init__(self, wcag_threshold=4.5): | |
self.wcag_threshold = wcag_threshold | |
# ADE20K class mappings for important objects | |
self.important_classes = { | |
'floor': [3, 4], # floor, wood floor | |
'wall': [0, 1], # wall, building | |
'ceiling': [5], # ceiling | |
'sofa': [10], # sofa | |
'chair': [19], # chair | |
'table': [15], # table | |
'door': [25], # door | |
'window': [8], # window | |
'stairs': [53], # stairs | |
'bed': [7], # bed | |
} | |
# Priority relationships (high priority = more important for safety) | |
self.priority_relationships = { | |
('floor', 'sofa'): 'high', | |
('floor', 'chair'): 'high', | |
('floor', 'table'): 'high', | |
('wall', 'sofa'): 'medium', | |
('wall', 'chair'): 'medium', | |
('wall', 'door'): 'high', | |
('floor', 'stairs'): 'critical', | |
('floor', 'bed'): 'medium', | |
('wall', 'window'): 'low', | |
('ceiling', 'wall'): 'low', | |
} | |
def get_object_category(self, class_id): | |
"""Map segmentation class to object category""" | |
for category, class_ids in self.important_classes.items(): | |
if class_id in class_ids: | |
return category | |
return 'other' | |
def calculate_wcag_contrast(self, color1, color2): | |
"""Calculate WCAG contrast ratio""" | |
def relative_luminance(color): | |
rgb = [c / 255.0 for c in color] | |
return sum(c / 12.92 if c <= 0.03928 else ((c + 0.055) / 1.055) ** 2.4 | |
for c in rgb) * [0.2126, 0.7152, 0.0722][i] for i, c in enumerate(rgb)) | |
lum1 = sum(self.relative_luminance_component(color1)) | |
lum2 = sum(self.relative_luminance_component(color2)) | |
lighter = max(lum1, lum2) | |
darker = min(lum1, lum2) | |
return (lighter + 0.05) / (darker + 0.05) | |
def relative_luminance_component(self, color): | |
"""Calculate relative luminance components""" | |
rgb = [c / 255.0 for c in color] | |
components = [] | |
factors = [0.2126, 0.7152, 0.0722] | |
for i, c in enumerate(rgb): | |
if c <= 0.03928: | |
components.append((c / 12.92) * factors[i]) | |
else: | |
components.append(((c + 0.055) / 1.055) ** 2.4 * factors[i]) | |
return components | |
def calculate_perceptual_contrast(self, color1, color2): | |
"""Calculate perceptual contrast including hue and saturation differences""" | |
# Convert to HSV for better perceptual analysis | |
hsv1 = cv2.cvtColor(np.uint8([[color1]]), cv2.COLOR_RGB2HSV)[0][0] / 255.0 | |
hsv2 = cv2.cvtColor(np.uint8([[color2]]), cv2.COLOR_RGB2HSV)[0][0] / 255.0 | |
# Hue difference (circular) | |
hue_diff = abs(hsv1[0] - hsv2[0]) | |
if hue_diff > 0.5: | |
hue_diff = 1 - hue_diff | |
# Saturation difference | |
sat_diff = abs(hsv1[1] - hsv2[1]) | |
# Value (brightness) difference | |
val_diff = abs(hsv1[2] - hsv2[2]) | |
# Combined perceptual score (0-1, higher is more different) | |
perceptual_contrast = np.sqrt(hue_diff**2 + sat_diff**2 + val_diff**2) / np.sqrt(3) | |
return perceptual_contrast | |
def find_clean_boundaries(self, mask1, mask2, min_boundary_length=50): | |
"""Find clean boundaries between two segments""" | |
# Dilate both masks slightly | |
kernel = np.ones((3, 3), np.uint8) | |
dilated1 = cv2.dilate(mask1.astype(np.uint8), kernel, iterations=1) | |
dilated2 = cv2.dilate(mask2.astype(np.uint8), kernel, iterations=1) | |
# Find intersection (boundary area) | |
boundary = (dilated1 & dilated2).astype(bool) | |
# Remove small disconnected boundary pieces | |
labeled_boundary = ndimage.label(boundary)[0] | |
for region_id in range(1, labeled_boundary.max() + 1): | |
region_mask = labeled_boundary == region_id | |
if np.sum(region_mask) < min_boundary_length: | |
boundary[region_mask] = False | |
return boundary | |
def get_representative_colors(self, image, mask, n_samples=1000): | |
"""Get representative colors from a masked region using clustering""" | |
if not np.any(mask): | |
return np.array([0, 0, 0]) | |
# Sample pixels from the mask | |
y_coords, x_coords = np.where(mask) | |
if len(y_coords) > n_samples: | |
indices = np.random.choice(len(y_coords), n_samples, replace=False) | |
y_coords = y_coords[indices] | |
x_coords = x_coords[indices] | |
colors = image[y_coords, x_coords] | |
# Use DBSCAN clustering to find dominant colors | |
if len(colors) > 10: | |
clustering = DBSCAN(eps=30, min_samples=5).fit(colors) | |
labels = clustering.labels_ | |
# Get the largest cluster | |
unique_labels, counts = np.unique(labels[labels >= 0], return_counts=True) | |
if len(unique_labels) > 0: | |
dominant_label = unique_labels[np.argmax(counts)] | |
dominant_colors = colors[labels == dominant_label] | |
return np.mean(dominant_colors, axis=0).astype(int) | |
# Fallback to mean color | |
return np.mean(colors, axis=0).astype(int) | |
def analyze_improved_contrast(self, image, segmentation): | |
""" | |
Perform improved contrast analysis focused on important relationships | |
""" | |
h, w = segmentation.shape | |
results = { | |
'critical_issues': [], | |
'high_priority_issues': [], | |
'medium_priority_issues': [], | |
'statistics': {}, | |
'visualization': image.copy() | |
} | |
# Get unique segments and their categories | |
unique_segments = np.unique(segmentation) | |
segment_categories = {} | |
segment_colors = {} | |
for seg_id in unique_segments: | |
if seg_id == 0: # Skip background | |
continue | |
mask = segmentation == seg_id | |
category = self.get_object_category(seg_id) | |
segment_categories[seg_id] = category | |
segment_colors[seg_id] = self.get_representative_colors(image, mask) | |
# Analyze important relationships | |
total_issues = 0 | |
critical_count = 0 | |
high_count = 0 | |
medium_count = 0 | |
for i, seg_id1 in enumerate(unique_segments): | |
if seg_id1 == 0: | |
continue | |
category1 = segment_categories.get(seg_id1, 'other') | |
if category1 == 'other': | |
continue | |
for seg_id2 in unique_segments[i+1:]: | |
if seg_id2 == 0: | |
continue | |
category2 = segment_categories.get(seg_id2, 'other') | |
if category2 == 'other': | |
continue | |
# Check if this is an important relationship | |
relationship = tuple(sorted([category1, category2])) | |
priority = self.priority_relationships.get(relationship) | |
if priority is None: | |
continue | |
# Check if segments are adjacent | |
mask1 = segmentation == seg_id1 | |
mask2 = segmentation == seg_id2 | |
boundary = self.find_clean_boundaries(mask1, mask2) | |
if not np.any(boundary): | |
continue | |
# Calculate contrasts | |
color1 = segment_colors[seg_id1] | |
color2 = segment_colors[seg_id2] | |
wcag_contrast = self.calculate_wcag_contrast(color1, color2) | |
perceptual_contrast = self.calculate_perceptual_contrast(color1, color2) | |
# Determine if there's an issue | |
wcag_issue = wcag_contrast < self.wcag_threshold | |
perceptual_issue = perceptual_contrast < 0.3 # Threshold for perceptual difference | |
if wcag_issue or perceptual_issue: | |
issue = { | |
'categories': (category1, category2), | |
'segment_ids': (seg_id1, seg_id2), | |
'wcag_contrast': wcag_contrast, | |
'perceptual_contrast': perceptual_contrast, | |
'boundary_pixels': np.sum(boundary), | |
'priority': priority | |
} | |
# Color-code the boundary based on priority | |
if priority == 'critical': | |
results['critical_issues'].append(issue) | |
results['visualization'][boundary] = [255, 0, 0] # Red | |
critical_count += 1 | |
elif priority == 'high': | |
results['high_priority_issues'].append(issue) | |
results['visualization'][boundary] = [255, 128, 0] # Orange | |
high_count += 1 | |
elif priority == 'medium': | |
results['medium_priority_issues'].append(issue) | |
results['visualization'][boundary] = [255, 255, 0] # Yellow | |
medium_count += 1 | |
total_issues += 1 | |
# Calculate statistics | |
results['statistics'] = { | |
'total_issues': total_issues, | |
'critical_issues': critical_count, | |
'high_priority_issues': high_count, | |
'medium_priority_issues': medium_count, | |
'segments_analyzed': len([cat for cat in segment_categories.values() if cat != 'other']) | |
} | |
return results | |
# Update your contrast detection imports and usage | |
class PrioritizedContrastDetector: | |
"""Wrapper for the improved contrast analyzer""" | |
def __init__(self, threshold=4.5): | |
self.analyzer = ImprovedContrastAnalyzer(wcag_threshold=threshold) | |
def analyze(self, image, segmentation, threshold, highlight_color=(255, 0, 0)): | |
"""Analyze with improved logic""" | |
results = self.analyzer.analyze_improved_contrast(image, segmentation) | |
# Convert to format expected by original interface | |
contrast_image = results['visualization'] | |
# Create a simple problem areas mask for compatibility | |
problem_areas = np.any([ | |
contrast_image[:, :, 0] == 255, # Any red channel highlighting | |
], axis=0) | |
# Format statistics | |
stats = results['statistics'].copy() | |
stats['threshold'] = threshold | |
stats['problem_count'] = stats['total_issues'] | |
# Add detailed breakdown | |
if results['critical_issues']: | |
stats['critical_details'] = [ | |
f"{issue['categories'][0]}-{issue['categories'][1]}: WCAG {issue['wcag_contrast']:.1f}:1" | |
for issue in results['critical_issues'] | |
] | |
if results['high_priority_issues']: | |
stats['high_priority_details'] = [ | |
f"{issue['categories'][0]}-{issue['categories'][1]}: WCAG {issue['wcag_contrast']:.1f}:1" | |
for issue in results['high_priority_issues'] | |
] | |
return contrast_image, problem_areas, stats | |