lolout1 committed
Commit c5e56a3 · 1 Parent(s): 361f6c2

fixing imports

Files changed (1)
  1. universal_contrast_analyzer.py +455 -0
universal_contrast_analyzer.py ADDED
@@ -0,0 +1,455 @@
"""
Universal Contrast Analyzer for detecting low contrast between ALL adjacent objects.
Optimized for Alzheimer's/dementia care environments.
"""

import logging
from typing import Dict, Optional, Tuple

import cv2
import numpy as np
from skimage.segmentation import find_boundaries
from sklearn.cluster import DBSCAN

logger = logging.getLogger(__name__)


class UniversalContrastAnalyzer:
    """
    Analyzes contrast between ALL adjacent objects in a room.
    Ensures proper visibility for elderly individuals with Alzheimer's or dementia.
    """

    def __init__(self, wcag_threshold: float = 4.5):
        self.wcag_threshold = wcag_threshold

        # Comprehensive ADE20K semantic class mappings
        self.semantic_classes = {
            # Floors and ground surfaces
            'floor': [3, 4, 13, 28, 78],  # floor, wood floor, rug, carpet, mat

            # Walls and vertical surfaces
            'wall': [0, 1, 9, 21],  # wall, building, brick, house

            # Ceiling
            'ceiling': [5, 16],  # ceiling, sky (for rooms with skylights)

            # Furniture - expanded list
            'furniture': [
                10, 19, 15, 7, 18, 23, 30, 33, 34, 36, 44, 45, 57, 63, 64, 65, 75,
                # sofa, chair, table, bed, armchair, cabinet, desk, counter, stool,
                # bench, nightstand, coffee table, ottoman, wardrobe, dresser, shelf,
                # chest of drawers
            ],

            # Doors and openings
            'door': [25, 14, 79],  # door, windowpane, screen door

            # Windows
            'window': [8, 14],  # window, windowpane

            # Stairs and steps
            'stairs': [53, 59],  # stairs, step

            # Small objects that might be on floors/furniture
            'objects': [
                17, 20, 24, 37, 38, 39, 42, 62, 68, 71, 73, 80, 82, 84, 89, 90, 92, 93,
                # curtain, book, picture, towel, clothes, pillow, box, bag, lamp, fan,
                # cushion, basket, bottle, plate, clock, vase, tray, bowl
            ],

            # Kitchen/bathroom fixtures
            'fixtures': [
                32, 46, 49, 50, 54, 66, 69, 70, 77, 94, 97, 98, 99, 117, 118, 119, 120,
                # sink, toilet, bathtub, shower, dishwasher, oven, microwave,
                # refrigerator, stove, washer, dryer, range hood, kitchen island
            ],

            # Decorative elements
            'decorative': [
                6, 12, 56, 60, 61, 72, 83, 91, 96, 100, 102, 104, 106, 110, 112,
                # painting, mirror, sculpture, chandelier, sconce, poster, tapestry
            ]
        }

        # Create reverse mapping for quick lookup
        self.class_to_category = {}
        for category, class_ids in self.semantic_classes.items():
            for class_id in class_ids:
                self.class_to_category[class_id] = category

    def calculate_wcag_contrast(self, color1: np.ndarray, color2: np.ndarray) -> float:
        """Calculate the WCAG 2.0 contrast ratio between two colors"""
        def relative_luminance(rgb):
            # Normalize to 0-1
            rgb_norm = rgb / 255.0

            # Apply gamma correction (linearize sRGB)
            rgb_linear = np.where(
                rgb_norm <= 0.03928,
                rgb_norm / 12.92,
                ((rgb_norm + 0.055) / 1.055) ** 2.4
            )

            # Calculate luminance using ITU-R BT.709 coefficients
            return np.dot(rgb_linear, [0.2126, 0.7152, 0.0722])

        lum1 = relative_luminance(color1)
        lum2 = relative_luminance(color2)

        # Ensure the lighter color is in the numerator
        lighter = max(lum1, lum2)
        darker = min(lum1, lum2)

        return (lighter + 0.05) / (darker + 0.05)

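    # Worked example (illustrative comment, not part of the original file):
    # pure white vs. pure black linearizes to luminances 1.0 and 0.0, giving
    # (1.0 + 0.05) / (0.0 + 0.05) = 21:1, the maximum WCAG ratio; identical
    # colors give (L + 0.05) / (L + 0.05) = 1:1, the minimum.
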
    def calculate_hue_difference(self, color1: np.ndarray, color2: np.ndarray) -> float:
        """Calculate the circular hue difference in degrees (0-90 on OpenCV's 0-180 hue scale)"""
        # Convert RGB to HSV
        hsv1 = cv2.cvtColor(color1.reshape(1, 1, 3).astype(np.uint8), cv2.COLOR_RGB2HSV)[0, 0]
        hsv2 = cv2.cvtColor(color2.reshape(1, 1, 3).astype(np.uint8), cv2.COLOR_RGB2HSV)[0, 0]

        # Calculate circular hue difference; cast to int first so the uint8
        # subtraction cannot wrap around (OpenCV hue wraps at 180)
        hue_diff = abs(int(hsv1[0]) - int(hsv2[0]))
        if hue_diff > 90:
            hue_diff = 180 - hue_diff

        return hue_diff

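    # Example (illustrative comment, not part of the original file): OpenCV
    # maps red to hue 0 and blue to hue ~120, so the raw difference is 120;
    # since 120 > 90, the circular distance is 180 - 120 = 60.
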
    def calculate_saturation_difference(self, color1: np.ndarray, color2: np.ndarray) -> float:
        """Calculate saturation difference (0-255)"""
        hsv1 = cv2.cvtColor(color1.reshape(1, 1, 3).astype(np.uint8), cv2.COLOR_RGB2HSV)[0, 0]
        hsv2 = cv2.cvtColor(color2.reshape(1, 1, 3).astype(np.uint8), cv2.COLOR_RGB2HSV)[0, 0]

        return abs(int(hsv1[1]) - int(hsv2[1]))

    def extract_dominant_color(self, image: np.ndarray, mask: np.ndarray,
                               sample_size: int = 1000) -> np.ndarray:
        """Extract the dominant color from a masked region using robust statistics"""
        if not np.any(mask):
            return np.array([128, 128, 128])  # Default gray

        # Get masked pixels
        masked_pixels = image[mask]
        if len(masked_pixels) == 0:
            return np.array([128, 128, 128])

        # Sample if there are too many pixels (for efficiency)
        if len(masked_pixels) > sample_size:
            indices = np.random.choice(len(masked_pixels), sample_size, replace=False)
            masked_pixels = masked_pixels[indices]

        # Use DBSCAN clustering to find the dominant color cluster
        if len(masked_pixels) > 50:
            try:
                clustering = DBSCAN(eps=30, min_samples=10).fit(masked_pixels)
                labels = clustering.labels_

                # Take the median of the largest cluster (label -1 is noise)
                unique_labels, counts = np.unique(labels[labels >= 0], return_counts=True)
                if len(unique_labels) > 0:
                    dominant_label = unique_labels[np.argmax(counts)]
                    dominant_colors = masked_pixels[labels == dominant_label]
                    return np.median(dominant_colors, axis=0).astype(int)
            except Exception:
                pass

        # Fall back to the median of all sampled pixels
        return np.median(masked_pixels, axis=0).astype(int)

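    # Design note (added comment, not in the original): DBSCAN with eps=30
    # groups pixels within a Euclidean distance of 30 in RGB space, so the
    # largest cluster tracks the dominant surface color while shadows and
    # specular highlights fall into other clusters or noise; the median then
    # resists any remaining outliers.
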
    def find_adjacent_segments(self, segmentation: np.ndarray) -> Dict[Tuple[int, int], np.ndarray]:
        """
        Find all pairs of adjacent segments and their boundaries.
        Returns a dict mapping (seg1_id, seg2_id) to a boundary mask.
        """
        adjacencies = {}

        # Mark pixels that sit on a segment boundary
        boundaries = find_boundaries(segmentation, mode='inner')

        # For each boundary pixel, check its neighbors
        h, w = segmentation.shape
        for y in range(1, h - 1):
            for x in range(1, w - 1):
                if boundaries[y, x]:
                    center_id = segmentation[y, x]

                    # Check 8-connected neighbors for more complete boundaries
                    neighbors = [
                        segmentation[y - 1, x],      # top
                        segmentation[y + 1, x],      # bottom
                        segmentation[y, x - 1],      # left
                        segmentation[y, x + 1],      # right
                        segmentation[y - 1, x - 1],  # top-left
                        segmentation[y - 1, x + 1],  # top-right
                        segmentation[y + 1, x - 1],  # bottom-left
                        segmentation[y + 1, x + 1]   # bottom-right
                    ]

                    for neighbor_id in neighbors:
                        if neighbor_id != center_id and neighbor_id != 0:  # Different segment, not background
                            # Create an ordered pair (smaller id first)
                            pair = tuple(sorted([center_id, neighbor_id]))

                            # Add this boundary pixel to the adjacency map
                            if pair not in adjacencies:
                                adjacencies[pair] = np.zeros((h, w), dtype=bool)
                            adjacencies[pair][y, x] = True

        # Filter out small boundaries (noise)
        min_boundary_pixels = 20  # Reduced threshold for better detection
        filtered_adjacencies = {}
        for pair, boundary in adjacencies.items():
            if np.sum(boundary) >= min_boundary_pixels:
                filtered_adjacencies[pair] = boundary

        return filtered_adjacencies

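    # Performance note (added comment, not in the original): the pixel loop
    # above is O(h*w*8) in pure Python. An equivalent vectorized formulation
    # would compare `segmentation` against np.roll-shifted copies for each of
    # the eight offsets and OR the per-pair matches into boundary masks; worth
    # considering if this method becomes a bottleneck on large images.
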
    def is_contrast_sufficient(self, color1: np.ndarray, color2: np.ndarray,
                               category1: str, category2: str) -> Tuple[bool, Optional[str]]:
        """
        Determine if contrast is sufficient based on WCAG and perceptual guidelines.
        Returns (is_sufficient, severity_if_not).
        """
        wcag_ratio = self.calculate_wcag_contrast(color1, color2)
        hue_diff = self.calculate_hue_difference(color1, color2)
        sat_diff = self.calculate_saturation_difference(color1, color2)

        # Critical relationships requiring the highest contrast.
        # Pairs are stored in sorted order so they match the sorted lookup below.
        critical_pairs = [
            ('floor', 'stairs'),
            ('door', 'floor'),
            ('stairs', 'wall')
        ]

        # High-priority relationships (also stored sorted)
        high_priority_pairs = [
            ('floor', 'furniture'),
            ('door', 'wall'),
            ('furniture', 'wall'),
            ('floor', 'objects')
        ]

        # Check the relationship type
        relationship = tuple(sorted([category1, category2]))

        # Determine thresholds based on the relationship
        if relationship in critical_pairs:
            # Critical: require a 7:1 contrast ratio
            if wcag_ratio < 7.0:
                return False, 'critical'
            if hue_diff < 30 and sat_diff < 50:
                return False, 'critical'

        elif relationship in high_priority_pairs:
            # High priority: require the configured WCAG ratio (default 4.5:1)
            if wcag_ratio < self.wcag_threshold:
                return False, 'high'
            if wcag_ratio < 7.0 and hue_diff < 20 and sat_diff < 40:
                return False, 'high'

        else:
            # Standard: require a 3:1 contrast ratio minimum
            if wcag_ratio < 3.0:
                return False, 'medium'
            if wcag_ratio < 4.5 and hue_diff < 15 and sat_diff < 30:
                return False, 'medium'

        return True, None

    def analyze_contrast(self, image: np.ndarray, segmentation: np.ndarray) -> Dict:
        """
        Perform comprehensive contrast analysis between ALL adjacent objects.

        Args:
            image: RGB image
            segmentation: Segmentation mask with class IDs

        Returns:
            Dictionary containing analysis results and visualizations
        """
        results = {
            'issues': [],
            'visualization': image.copy(),
            'statistics': {
                'total_segments': 0,
                'analyzed_pairs': 0,
                'low_contrast_pairs': 0,
                'critical_issues': 0,
                'high_priority_issues': 0,
                'medium_priority_issues': 0,
                'floor_object_issues': 0
            }
        }

        # Get unique segments
        unique_segments = np.unique(segmentation)
        unique_segments = unique_segments[unique_segments != 0]  # Remove background
        results['statistics']['total_segments'] = len(unique_segments)

        # Build segment information
        segment_info = {}

        logger.info(f"Building segment information for {len(unique_segments)} segments...")

        for seg_id in unique_segments:
            mask = segmentation == seg_id
            area = np.sum(mask)

            if area < 50:  # Skip very small segments
                continue

            category = self.class_to_category.get(seg_id, 'unknown')
            color = self.extract_dominant_color(image, mask)

            segment_info[seg_id] = {
                'category': category,
                'mask': mask,
                'color': color,
                'area': area,
                'class_id': seg_id
            }

        # Find all adjacent segment pairs
        logger.info("Finding adjacent segments...")
        adjacencies = self.find_adjacent_segments(segmentation)
        logger.info(f"Found {len(adjacencies)} adjacent segment pairs")

        # Analyze each adjacent pair
        for (seg1_id, seg2_id), boundary in adjacencies.items():
            if seg1_id not in segment_info or seg2_id not in segment_info:
                continue

            info1 = segment_info[seg1_id]
            info2 = segment_info[seg2_id]

            # Skip if both categories are unknown
            if info1['category'] == 'unknown' and info2['category'] == 'unknown':
                continue

            results['statistics']['analyzed_pairs'] += 1

            # Check contrast sufficiency
            is_sufficient, severity = self.is_contrast_sufficient(
                info1['color'], info2['color'],
                info1['category'], info2['category']
            )

            if not is_sufficient:
                results['statistics']['low_contrast_pairs'] += 1

                # Calculate detailed metrics
                wcag_ratio = self.calculate_wcag_contrast(info1['color'], info2['color'])
                hue_diff = self.calculate_hue_difference(info1['color'], info2['color'])
                sat_diff = self.calculate_saturation_difference(info1['color'], info2['color'])

                # Check whether it is a floor-object issue
                is_floor_object = (
                    (info1['category'] == 'floor' and info2['category'] in ['furniture', 'objects']) or
                    (info2['category'] == 'floor' and info1['category'] in ['furniture', 'objects'])
                )

                if is_floor_object:
                    results['statistics']['floor_object_issues'] += 1

                # Count by severity
                if severity == 'critical':
                    results['statistics']['critical_issues'] += 1
                elif severity == 'high':
                    results['statistics']['high_priority_issues'] += 1
                elif severity == 'medium':
                    results['statistics']['medium_priority_issues'] += 1

                # Record the issue
                issue = {
                    'segment_ids': (seg1_id, seg2_id),
                    'categories': (info1['category'], info2['category']),
                    'colors': (info1['color'].tolist(), info2['color'].tolist()),
                    'wcag_ratio': float(wcag_ratio),
                    'hue_difference': float(hue_diff),
                    'saturation_difference': float(sat_diff),
                    'boundary_pixels': int(np.sum(boundary)),
                    'severity': severity,
                    'is_floor_object': is_floor_object,
                    'boundary_mask': boundary
                }

                results['issues'].append(issue)

                # Visualize on the output image
                self._visualize_issue(results['visualization'], boundary, severity)

        # Sort issues by severity
        severity_order = {'critical': 0, 'high': 1, 'medium': 2}
        results['issues'].sort(key=lambda x: severity_order.get(x['severity'], 3))

        logger.info(f"Contrast analysis complete: {results['statistics']['low_contrast_pairs']} issues found")

        return results

    def _visualize_issue(self, image: np.ndarray, boundary: np.ndarray, severity: str):
        """Add visual indicators for contrast issues (modifies image in place)"""
        # Color coding by severity (RGB)
        colors = {
            'critical': (255, 0, 0),  # Red
            'high': (255, 128, 0),    # Orange
            'medium': (255, 255, 0),  # Yellow
        }

        color = colors.get(severity, (255, 255, 255))

        # Dilate the boundary for better visibility
        kernel = np.ones((3, 3), np.uint8)
        dilated = cv2.dilate(boundary.astype(np.uint8), kernel, iterations=2)

        # Apply the color overlay with transparency
        overlay = image.copy()
        overlay[dilated > 0] = color
        cv2.addWeighted(overlay, 0.5, image, 0.5, 0, image)

        return image

    def generate_report(self, results: Dict) -> str:
        """Generate a detailed text report of the contrast analysis"""
        stats = results['statistics']
        issues = results['issues']

        report = []
        report.append("=== Universal Contrast Analysis Report ===\n")

        # Summary statistics
        report.append(f"Total segments analyzed: {stats['total_segments']}")
        report.append(f"Adjacent pairs analyzed: {stats['analyzed_pairs']}")
        report.append(f"Low contrast pairs found: {stats['low_contrast_pairs']}")
        report.append(f"- Critical issues: {stats['critical_issues']}")
        report.append(f"- High priority issues: {stats['high_priority_issues']}")
        report.append(f"- Medium priority issues: {stats['medium_priority_issues']}")
        report.append(f"Floor-object contrast issues: {stats['floor_object_issues']}\n")

        # Detailed issues
        if issues:
            report.append("=== Contrast Issues (sorted by severity) ===\n")

            for i, issue in enumerate(issues[:10], 1):  # Show the top 10 issues
                cat1, cat2 = issue['categories']
                wcag = issue['wcag_ratio']
                hue_diff = issue['hue_difference']
                sat_diff = issue['saturation_difference']
                severity = issue['severity'].upper()

                report.append(f"{i}. [{severity}] {cat1} ↔ {cat2}")
                report.append(f"   - WCAG Contrast Ratio: {wcag:.2f}:1 (minimum: {self.wcag_threshold}:1)")
                report.append(f"   - Hue Difference: {hue_diff:.1f}° (recommended: >30°)")
                report.append(f"   - Saturation Difference: {sat_diff} (recommended: >50)")

                if issue['is_floor_object']:
                    report.append("   - ⚠️ Object on floor - requires high visibility!")

                report.append(f"   - Boundary size: {issue['boundary_pixels']} pixels")
                report.append("")
        else:
            report.append("✅ No contrast issues detected!")

        return "\n".join(report)
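

# ---------------------------------------------------------------------------
# Minimal usage sketch (added illustration, not part of the original file).
# It feeds the analyzer a tiny synthetic scene; in practice the inputs would
# be an RGB photo and an ADE20K-style segmentation mask from a model.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Synthetic 200x200 scene: left half labeled wall-like (class 1),
    # right half labeled floor-like (class 3).
    segmentation = np.zeros((200, 200), dtype=np.int32)
    segmentation[:, :100] = 1
    segmentation[:, 100:] = 3

    # Two deliberately similar grays, so a low-contrast issue is expected.
    image = np.zeros((200, 200, 3), dtype=np.uint8)
    image[:, :100] = (120, 120, 120)
    image[:, 100:] = (140, 140, 140)

    analyzer = UniversalContrastAnalyzer()
    results = analyzer.analyze_contrast(image, segmentation)
    print(analyzer.generate_report(results))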