import torch
import numpy as np
from PIL import Image
import cv2
import imutils
import os
import sys
import time
from detectron2.config import get_cfg
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.data import MetadataCatalog
from scipy import ndimage
import colorsys
import math

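# Cap PyTorch's intra- and inter-op CPU thread pools; 16 is this deployment's choice and
# should be tuned to the cores available on the host.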
torch.set_num_threads(16)
torch.set_num_interop_threads(16)

from oneformer import (
    add_oneformer_config,
    add_common_config,
    add_swin_config,
    add_dinat_config,
)

from demo.defaults import DefaultPredictor
from demo.visualizer import Visualizer, ColorMode

import gradio as gr
from huggingface_hub import hf_hub_download

# NeuroNest specific imports
from utils.contrast_detector import ContrastDetector
from utils.luminance_contrast import LuminanceContrastDetector
from utils.hue_contrast import HueContrastDetector
from utils.saturation_contrast import SaturationContrastDetector
from utils.combined_contrast import CombinedContrastDetector

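# Map the dataset's UI display name to the short key used by the lookup tables below.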
KEY_DICT = {
    "ADE20K (150 classes)": "ade20k",
}

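# Per-backbone config files and pretrained checkpoints; the weights are downloaded
# (and cached) from the Hugging Face Hub at import time.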
SWIN_CFG_DICT = {
    "ade20k": "configs/ade20k/oneformer_swin_large_IN21k_384_bs16_160k.yaml",
}

SWIN_MODEL_DICT = {
    "ade20k": hf_hub_download(
        repo_id="shi-labs/oneformer_ade20k_swin_large",
        filename="250_16_swin_l_oneformer_ade20k_160k.pth"
    )
}

DINAT_CFG_DICT = {
    "ade20k": "configs/ade20k/oneformer_dinat_large_IN21k_384_bs16_160k.yaml",
}

DINAT_MODEL_DICT = {
    "ade20k": hf_hub_download(
        repo_id="shi-labs/oneformer_ade20k_dinat_large",
        filename="250_16_dinat_l_oneformer_ade20k_160k.pth"
    )
}

MODEL_DICT = {"DiNAT-L": DINAT_MODEL_DICT, "Swin-L": SWIN_MODEL_DICT}
CFG_DICT = {"DiNAT-L": DINAT_CFG_DICT, "Swin-L": SWIN_CFG_DICT}
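# Target inference width in pixels; input images are resized to this width before prediction.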
WIDTH_DICT = {"ade20k": 640}

cpu_device = torch.device("cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

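# Per-backbone predictor and metadata caches, populated once at startup by setup_modules().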
PREDICTORS = {
    "DiNAT-L": {"ADE20K (150 classes)": None},
    "Swin-L": {"ADE20K (150 classes)": None}
}

METADATA = {
    "DiNAT-L": {"ADE20K (150 classes)": None},
    "Swin-L": {"ADE20K (150 classes)": None}
}

# Contrast detector mapping
CONTRAST_DETECTORS = {
    "Luminance (WCAG)": LuminanceContrastDetector(),
    "Hue": HueContrastDetector(),
    "Saturation": SaturationContrastDetector(),
    "Combined": CombinedContrastDetector()
}

def setup_modules():
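    """Build a DefaultPredictor and metadata entry for every supported (backbone, dataset) pair."""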
    for dataset in ["ADE20K (150 classes)"]:
        for backbone in ["DiNAT-L", "Swin-L"]:
            cfg = setup_cfg(dataset, backbone)
            metadata = MetadataCatalog.get(
                cfg.DATASETS.TEST_PANOPTIC[0] if len(cfg.DATASETS.TEST_PANOPTIC) else "__unused"
            )
            PREDICTORS[backbone][dataset] = DefaultPredictor(cfg)
            METADATA[backbone][dataset] = metadata

def setup_cfg(dataset, backbone):
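    """Assemble and freeze a detectron2 config for the requested dataset and backbone."""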
    cfg = get_cfg()
    add_deeplab_config(cfg)
    add_common_config(cfg)
    add_swin_config(cfg)
    add_oneformer_config(cfg)
    add_dinat_config(cfg)
    dataset = KEY_DICT[dataset]
    cfg_path = CFG_DICT[backbone][dataset]
    cfg.merge_from_file(cfg_path)
    cfg.MODEL.DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    cfg.MODEL.WEIGHTS = MODEL_DICT[backbone][dataset]
    cfg.freeze()
    return cfg

def semantic_run(img, predictor, metadata):
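    """Run semantic segmentation on a BGR image; return the drawn overlay and the raw class-index mask."""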
    visualizer = Visualizer(img[:, :, ::-1], metadata=metadata, instance_mode=ColorMode.IMAGE)
    predictions = predictor(img, "semantic")
    sem_seg = predictions["sem_seg"].argmax(dim=0).to(cpu_device)
    out = visualizer.draw_sem_seg(sem_seg, alpha=0.5)
    return out, sem_seg.numpy()

def analyze_contrast(image, segmentation, contrast_method, threshold):
    """Analyze contrast between segments using selected method"""
    detector = CONTRAST_DETECTORS[contrast_method]
    
    # Perform contrast analysis
    contrast_image, problem_areas, stats = detector.analyze(
        image, segmentation, threshold
    )
    
    return contrast_image, problem_areas, stats

def segment_and_analyze_contrast(path, backbone, contrast_method, threshold):
    """Main function to segment and analyze contrast"""
    dataset = "ADE20K (150 classes)"
    predictor = PREDICTORS[backbone][dataset]
    metadata = METADATA[backbone][dataset]
    
    # Read and resize image
    img = cv2.imread(path)
    if img is None:
        return None, None, "Error: Could not load image"
    
    width = WIDTH_DICT[KEY_DICT[dataset]]
    img = imutils.resize(img, width=width)
    
    # Get segmentation
    out, seg_mask = semantic_run(img, predictor, metadata)
    out_img = Image.fromarray(out.get_image())
    
    # Analyze contrast
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    contrast_img, problem_areas, stats = analyze_contrast(
        img_rgb, seg_mask, contrast_method, threshold
    )
    
    # Create stats text
    stats_text = f"### Contrast Analysis Results\n\n"
    stats_text += f"**Method:** {contrast_method}\n"
    stats_text += f"**Threshold:** {threshold:.2f}\n"
    stats_text += f"**Problem Areas:** {stats['problem_count']}\n"
    
    if 'min_contrast' in stats:
        stats_text += f"**Min Contrast:** {stats['min_contrast']:.2f}\n"
    if 'max_contrast' in stats:
        stats_text += f"**Max Contrast:** {stats['max_contrast']:.2f}\n"
    if 'average_contrast' in stats:
        stats_text += f"**Average Contrast:** {stats['average_contrast']:.2f}\n"
    
    # Convert contrast image to PIL
    contrast_pil = Image.fromarray(contrast_img)
    
    return out_img, contrast_pil, stats_text

# Initialize models
setup_modules()

# Gradio Interface
title = "<h1 style='text-align: center'>NeuroNest: Abheek Pradhan - Contrast Model</h1>"
description = "<p style='font-size: 16px; margin: 5px; font-weight: 600; text-align: center'> " \
              "<a href='https://github.com/lolout1/sam2Contrast' target='_blank'>Github Repo</a></p>" \
              "<p style='text-align: center; margin: 5px; font-size: 14px; font-weight: 300;'>" \
              "I am developing NeuroNest, a contrast detection system designed to identify areas with insufficient contrast " \
              "for individuals with Alzheimer's disease. This tool leverages OneFormer's state-of-the-art segmentation " \
              "capabilities, trained on the ADE20K dataset, to detect indoor objects such as floors, furniture, walls, and ceilings. " \
              "By analyzing contrast ratios between adjacent segments, NeuroNest flags potential visual accessibility issues " \
              "that may trigger confusion or disorientation in elderly individuals with cognitive impairments.</p>" \
              "<p style='text-align: center; font-size: 14px; margin: 5px; font-weight: 300;'>" \
              "[Note: When running on my Linux cluster, please request a GPU node for optimal performance. " \
              "CUDA may not be available on login nodes.]</p>"

gradio_inputs = [
    gr.Image(label="Input Image", type="filepath"),
    gr.Radio(choices=["Swin-L", "DiNAT-L"], value="Swin-L", label="Backbone"),
    gr.Radio(
        choices=["Luminance (WCAG)", "Hue", "Saturation", "Combined"],
        value="Luminance (WCAG)",
        label="Contrast Detection Method"
    ),
    gr.Slider(
        minimum=1.0,
        maximum=10.0,
        value=4.5,
        step=0.1,
        label="Contrast Threshold (Lower = More Strict)"
    )
]

gradio_outputs = [
    gr.Image(type="pil", label="Segmentation Result"),
    gr.Image(type="pil", label="Contrast Analysis"),
    gr.Markdown(label="Analysis Results")
]

examples = [
    ["examples/indoor_room.jpg", "Swin-L", "Luminance (WCAG)", 4.5],
    ["examples/living_room.jpg", "DiNAT-L", "Combined", 3.0],
]

iface = gr.Interface(
    fn=segment_and_analyze_contrast,
    inputs=gradio_inputs,
    outputs=gradio_outputs,
    examples_per_page=5,
    allow_flagging="never",
    examples=examples,
    title=title,
    description=description
)

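# Minimal sketch of calling the pipeline directly (without the Gradio UI), assuming the
# example image exists locally and setup_modules() has already run; the output filenames
# below are illustrative only:
#
#   seg_img, contrast_img, report = segment_and_analyze_contrast(
#       "examples/indoor_room.jpg", "Swin-L", "Luminance (WCAG)", 4.5
#   )
#   seg_img.save("segmentation.png")
#   contrast_img.save("contrast_analysis.png")
#   print(report)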
if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", share=True)