# Create an improved contrast detection file: utils/improved_contrast_analyzer.py

import numpy as np
import cv2
from scipy import ndimage
from sklearn.cluster import DBSCAN

class ImprovedContrastAnalyzer:
    """
    Advanced contrast analyzer focused on Alzheimer's-friendly environments
    """
    
    def __init__(self, wcag_threshold=4.5):
        self.wcag_threshold = wcag_threshold
        
        # ADE20K class mappings for important objects
        self.important_classes = {
            'floor': [3, 4],  # floor, wood floor
            'wall': [0, 1],   # wall, building
            'ceiling': [5],   # ceiling
            'sofa': [10],     # sofa
            'chair': [19],    # chair
            'table': [15],    # table
            'door': [25],     # door
            'window': [8],    # window
            'stairs': [53],   # stairs
            'bed': [7],       # bed
        }
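        # NOTE: these ADE20K class indices depend on the label mapping of the
        # segmentation model in use and should be verified against its
        # id-to-label file. The analysis below also treats class 0 as
        # background, so a model that labels walls as 0 needs this mapping
        # and that check reconciled.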
        
        # Priority relationships (higher priority = more important for safety).
        # Keys are stored alphabetically sorted so they match the
        # tuple(sorted(...)) lookup performed during analysis.
        self.priority_relationships = {
            ('floor', 'sofa'): 'high',
            ('chair', 'floor'): 'high',
            ('floor', 'table'): 'high',
            ('sofa', 'wall'): 'medium',
            ('chair', 'wall'): 'medium',
            ('door', 'wall'): 'high',
            ('floor', 'stairs'): 'critical',
            ('bed', 'floor'): 'medium',
            ('wall', 'window'): 'low',
            ('ceiling', 'wall'): 'low',
        }
    
    def get_object_category(self, class_id):
        """Map segmentation class to object category"""
        for category, class_ids in self.important_classes.items():
            if class_id in class_ids:
                return category
        return 'other'
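
    # e.g. get_object_category(3) -> 'floor'; get_object_category(99) -> 'other'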
    
    def calculate_wcag_contrast(self, color1, color2):
        """Calculate the WCAG contrast ratio between two RGB colors"""
        lum1 = sum(self.relative_luminance_component(color1))
        lum2 = sum(self.relative_luminance_component(color2))
        
        lighter = max(lum1, lum2)
        darker = min(lum1, lum2)
        
        return (lighter + 0.05) / (darker + 0.05)
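
    # Sanity check for the formula above: pure white (255, 255, 255) has
    # relative luminance 1.0 and pure black (0, 0, 0) has 0.0, so their ratio
    # is (1.0 + 0.05) / (0.0 + 0.05) = 21:1, the maximum WCAG contrast.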
    
    def relative_luminance_component(self, color):
        """Calculate relative luminance components"""
        rgb = [c / 255.0 for c in color]
        components = []
        factors = [0.2126, 0.7152, 0.0722]
        
        for i, c in enumerate(rgb):
            if c <= 0.03928:
                components.append((c / 12.92) * factors[i])
            else:
                components.append(((c + 0.055) / 1.055) ** 2.4 * factors[i])
        
        return components
    
    def calculate_perceptual_contrast(self, color1, color2):
        """Calculate perceptual contrast including hue and saturation differences"""
        # Convert to HSV for better perceptual analysis. OpenCV stores hue in
        # [0, 179] and saturation/value in [0, 255] for uint8 input, so each
        # channel is normalized by its own range.
        hsv_scale = np.array([179.0, 255.0, 255.0])
        hsv1 = cv2.cvtColor(np.uint8([[color1]]), cv2.COLOR_RGB2HSV)[0][0] / hsv_scale
        hsv2 = cv2.cvtColor(np.uint8([[color2]]), cv2.COLOR_RGB2HSV)[0][0] / hsv_scale
        
        # Hue difference (circular, so wrap around)
        hue_diff = abs(hsv1[0] - hsv2[0])
        if hue_diff > 0.5:
            hue_diff = 1 - hue_diff
        
        # Saturation difference
        sat_diff = abs(hsv1[1] - hsv2[1])
        
        # Value (brightness) difference  
        val_diff = abs(hsv1[2] - hsv2[2])
        
        # Combined perceptual score (higher is more different); after circular
        # wrapping hue_diff is at most 0.5, so the score tops out near 0.87
        perceptual_contrast = np.sqrt(hue_diff**2 + sat_diff**2 + val_diff**2) / np.sqrt(3)
        
        return perceptual_contrast
    
    def find_clean_boundaries(self, mask1, mask2, min_boundary_length=50):
        """Find clean boundaries between two segments"""
        # Dilate both masks slightly
        kernel = np.ones((3, 3), np.uint8)
        dilated1 = cv2.dilate(mask1.astype(np.uint8), kernel, iterations=1)
        dilated2 = cv2.dilate(mask2.astype(np.uint8), kernel, iterations=1)
        
        # Find intersection (boundary area)
        boundary = (dilated1 & dilated2).astype(bool)
        
        # Remove small disconnected boundary pieces (threshold is a pixel count)
        labeled_boundary = ndimage.label(boundary)[0]
        for region_id in range(1, labeled_boundary.max() + 1):
            region_mask = labeled_boundary == region_id
            if np.sum(region_mask) < min_boundary_length:
                boundary[region_mask] = False
        
        return boundary
    
    def get_representative_colors(self, image, mask, n_samples=1000):
        """Get representative colors from a masked region using clustering"""
        if not np.any(mask):
            return np.array([0, 0, 0])
        
        # Sample pixels from the mask
        y_coords, x_coords = np.where(mask)
        if len(y_coords) > n_samples:
            indices = np.random.choice(len(y_coords), n_samples, replace=False)
            y_coords = y_coords[indices]
            x_coords = x_coords[indices]
        
        colors = image[y_coords, x_coords]
        
        # Use DBSCAN clustering to find dominant colors
        if len(colors) > 10:
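            # eps=30 groups colors within ~30 RGB units (Euclidean distance)
            # into one cluster; min_samples=5 discards isolated noise pixels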
            clustering = DBSCAN(eps=30, min_samples=5).fit(colors)
            labels = clustering.labels_
            
            # Get the largest cluster
            unique_labels, counts = np.unique(labels[labels >= 0], return_counts=True)
            if len(unique_labels) > 0:
                dominant_label = unique_labels[np.argmax(counts)]
                dominant_colors = colors[labels == dominant_label]
                return np.mean(dominant_colors, axis=0).astype(int)
        
        # Fallback to mean color
        return np.mean(colors, axis=0).astype(int)
    
    def analyze_improved_contrast(self, image, segmentation):
        """
        Perform improved contrast analysis focused on important relationships
        """
        h, w = segmentation.shape
        results = {
            'critical_issues': [],
            'high_priority_issues': [],
            'medium_priority_issues': [],
            'statistics': {},
            'visualization': image.copy()
        }
        
        # Get unique segments and their categories
        unique_segments = np.unique(segmentation)
        segment_categories = {}
        segment_colors = {}
        
        for seg_id in unique_segments:
            if seg_id == 0:  # Skip background
                continue
            
            mask = segmentation == seg_id
            category = self.get_object_category(seg_id)
            segment_categories[seg_id] = category
            segment_colors[seg_id] = self.get_representative_colors(image, mask)
        
        # Analyze important relationships
        total_issues = 0
        critical_count = 0
        high_count = 0
        medium_count = 0
        
        for i, seg_id1 in enumerate(unique_segments):
            if seg_id1 == 0:
                continue
                
            category1 = segment_categories.get(seg_id1, 'other')
            if category1 == 'other':
                continue
                
            for seg_id2 in unique_segments[i+1:]:
                if seg_id2 == 0:
                    continue
                    
                category2 = segment_categories.get(seg_id2, 'other')
                if category2 == 'other':
                    continue
                
                # Check if this is an important relationship
                relationship = tuple(sorted([category1, category2]))
                priority = self.priority_relationships.get(relationship)
                
                if priority is None:
                    continue
                
                # Check if segments are adjacent
                mask1 = segmentation == seg_id1
                mask2 = segmentation == seg_id2
                boundary = self.find_clean_boundaries(mask1, mask2)
                
                if not np.any(boundary):
                    continue
                
                # Calculate contrasts
                color1 = segment_colors[seg_id1]
                color2 = segment_colors[seg_id2]
                
                wcag_contrast = self.calculate_wcag_contrast(color1, color2)
                perceptual_contrast = self.calculate_perceptual_contrast(color1, color2)
                
                # Determine if there's an issue
                wcag_issue = wcag_contrast < self.wcag_threshold
                perceptual_issue = perceptual_contrast < 0.3  # Threshold for perceptual difference
                
                if wcag_issue or perceptual_issue:
                    issue = {
                        'categories': (category1, category2),
                        'segment_ids': (seg_id1, seg_id2),
                        'wcag_contrast': wcag_contrast,
                        'perceptual_contrast': perceptual_contrast,
                        'boundary_pixels': np.sum(boundary),
                        'priority': priority
                    }
                    
                    # Color-code the boundary based on priority
                    if priority == 'critical':
                        results['critical_issues'].append(issue)
                        results['visualization'][boundary] = [255, 0, 0]  # Red
                        critical_count += 1
                    elif priority == 'high':
                        results['high_priority_issues'].append(issue)
                        results['visualization'][boundary] = [255, 128, 0]  # Orange
                        high_count += 1
                    elif priority == 'medium':
                        results['medium_priority_issues'].append(issue)
                        results['visualization'][boundary] = [255, 255, 0]  # Yellow
                        medium_count += 1
                    
                    total_issues += 1
        
        # Calculate statistics
        results['statistics'] = {
            'total_issues': total_issues,
            'critical_issues': critical_count,
            'high_priority_issues': high_count,
            'medium_priority_issues': medium_count,
            'segments_analyzed': len([cat for cat in segment_categories.values() if cat != 'other'])
        }
        
        return results

# Update your contrast detection imports and usage
class PrioritizedContrastDetector:
    """Wrapper for the improved contrast analyzer"""
    
    def __init__(self, threshold=4.5):
        self.analyzer = ImprovedContrastAnalyzer(wcag_threshold=threshold)
    
    def analyze(self, image, segmentation, threshold, highlight_color=(255, 0, 0)):
        """Analyze with improved logic (highlight_color is kept for interface
        compatibility; boundary colors are fixed by issue priority)"""
        # Honor the per-call threshold so results match the reported stats
        self.analyzer.wcag_threshold = threshold
        results = self.analyzer.analyze_improved_contrast(image, segmentation)
        
        # Convert to format expected by original interface
        contrast_image = results['visualization']
        
        # Create a problem-areas mask for compatibility by matching the exact
        # highlight colors painted by the analyzer (red, orange, yellow).
        # Checking only the red channel would also flag unmodified bright pixels.
        highlight_colors = [(255, 0, 0), (255, 128, 0), (255, 255, 0)]
        problem_areas = np.any(
            [np.all(contrast_image == color, axis=-1) for color in highlight_colors],
            axis=0,
        )
        
        # Format statistics
        stats = results['statistics'].copy()
        stats['threshold'] = threshold
        stats['problem_count'] = stats['total_issues']
        
        # Add detailed breakdown
        if results['critical_issues']:
            stats['critical_details'] = [
                f"{issue['categories'][0]}-{issue['categories'][1]}: WCAG {issue['wcag_contrast']:.1f}:1"
                for issue in results['critical_issues']
            ]
        
        if results['high_priority_issues']:
            stats['high_priority_details'] = [
                f"{issue['categories'][0]}-{issue['categories'][1]}: WCAG {issue['wcag_contrast']:.1f}:1"
                for issue in results['high_priority_issues']
            ]
        
        return contrast_image, problem_areas, stats
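
# Example usage (a minimal sketch): `image` is assumed to be an RGB uint8
# array and `segmentation` a 2D integer class map of the same height and
# width, e.g. the per-pixel argmax of an ADE20K-trained segmentation model.
# The random arrays below are stand-ins so the snippet runs on its own.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    image = rng.integers(0, 256, size=(240, 320, 3), dtype=np.uint8)
    segmentation = rng.integers(0, 26, size=(240, 320))
    
    detector = PrioritizedContrastDetector(threshold=4.5)
    contrast_image, problem_areas, stats = detector.analyze(
        image, segmentation, threshold=4.5
    )
    print(stats)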