lolout1 committed
Commit 6524e7a · 0 Parent(s)

Initial deployment to Hugging Face

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full set.

Files changed (50)
  1. .gitattributes +10 -0
  2. .gitignore +17 -0
  3. CLAUDE.md +76 -0
  4. Dockerfile +61 -0
  5. NeuroNest/.gitattributes +35 -0
  6. NeuroNest/README.md +12 -0
  7. app.py +38 -0
  8. configs/.DS_Store +0 -0
  9. configs/ade20k/Base-ADE20K-UnifiedSegmentation.yaml +68 -0
  10. configs/ade20k/oneformer_R50_bs16_160k.yaml +58 -0
  11. configs/ade20k/oneformer_dinat_large_IN21k_384_bs16_160k.yaml +42 -0
  12. configs/ade20k/oneformer_swin_large_IN21k_384_bs16_160k.yaml +40 -0
  13. configs/cityscapes/.DS_Store +0 -0
  14. configs/cityscapes/Base-Cityscapes-UnifiedSegmentation.yaml +68 -0
  15. configs/cityscapes/oneformer_R50_bs16_90k.yaml +59 -0
  16. configs/cityscapes/oneformer_dinat_large_bs16_90k.yaml +22 -0
  17. configs/cityscapes/oneformer_swin_large_IN21k_384_bs16_90k.yaml +20 -0
  18. configs/coco/Base-COCO-UnifiedSegmentation.yaml +54 -0
  19. configs/coco/oneformer_R50_bs16_50ep.yaml +59 -0
  20. configs/coco/oneformer_dinat_large_bs16_100ep.yaml +22 -0
  21. configs/coco/oneformer_swin_large_IN21k_384_bs16_100ep.yaml +25 -0
  22. deform_setup.sh +21 -0
  23. demo/colormap.py +170 -0
  24. demo/defaults.py +77 -0
  25. demo/predictor.py +190 -0
  26. demo/visualizer.py +1350 -0
  27. gradio_app.py +222 -0
  28. gradio_combine.py +384 -0
  29. gradio_contrast.py +455 -0
  30. gradio_gpu.py +234 -0
  31. gradio_test.py +1080 -0
  32. oneformer/.DS_Store +0 -0
  33. oneformer/__init__.py +9 -0
  34. oneformer/config.py +239 -0
  35. oneformer/data/__init__.py +2 -0
  36. oneformer/data/bpe_simple_vocab_16e6.txt +0 -0
  37. oneformer/data/bpe_simple_vocab_16e6.txt.gz +3 -0
  38. oneformer/data/build.py +117 -0
  39. oneformer/data/dataset_mappers/__init__.py +1 -0
  40. oneformer/data/dataset_mappers/coco_unified_new_baseline_dataset_mapper.py +341 -0
  41. oneformer/data/dataset_mappers/dataset_mapper.py +203 -0
  42. oneformer/data/dataset_mappers/oneformer_unified_dataset_mapper.py +375 -0
  43. oneformer/data/datasets/__init__.py +7 -0
  44. oneformer/data/datasets/register_ade20k_instance.py +56 -0
  45. oneformer/data/datasets/register_ade20k_panoptic.py +394 -0
  46. oneformer/data/datasets/register_cityscapes_panoptic.py +199 -0
  47. oneformer/data/datasets/register_coco_panoptic2instance.py +44 -0
  48. oneformer/data/datasets/register_coco_panoptic_annos_semseg.py +367 -0
  49. oneformer/data/tokenizer.py +200 -0
  50. oneformer/evaluation/__init__.py +3 -0
.gitattributes ADDED
@@ -0,0 +1,10 @@
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.so filter=lfs diff=lfs merge=lfs -text
+ *.o filter=lfs diff=lfs merge=lfs -text
+ *.egg filter=lfs diff=lfs merge=lfs -text
+ *.tar.gz filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,17 @@
+ *.out
+ __pycache__/
+ *.pyc
+ *.pyo
+ *.pyd
+ slurm-*.out
+ output_*/
+ build/
+ dist/
+ *.egg-info/
+ .env
+ .claude/
+ oneformer/modeling/pixel_decoder/ops/build/
+ oneformer/modeling/pixel_decoder/ops/dist/
+ oneformer/modeling/pixel_decoder/ops/*.egg-info/
+ output_floor_blackspot/
+ environment.yml
CLAUDE.md ADDED
@@ -0,0 +1,76 @@
+ # CLAUDE.md
+
+ This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
+
+ ## Project Overview
+ This is a OneFormer-based universal image segmentation application with integrated contrast detection capabilities. It supports panoptic, instance, and semantic segmentation using transformer-based models (Swin-L and DiNAT-L backbones).
+
+ ## Essential Commands
+
+ ### Setup and Installation
+ ```bash
+ # Download model checkpoints from HuggingFace
+ python setup.py --download
+
+ # Build deformable attention CUDA operations (required)
+ bash deform_setup.sh
+
+ # Install NATTEN if needed
+ cd NATTEN && pip install -e .
+ ```
+
+ ### Running the Application
+ ```bash
+ # Main application launcher
+ bash t.sh
+
+ # Or run directly
+ python gradio_test.py
+ ```
+
+ ### Testing
+ ```bash
+ # Run NATTEN tests
+ cd NATTEN && python -m unittest discover -v -s ./tests
+ ```
+
+ ## Architecture Overview
+
+ ### Core Components
+ 1. **OneFormer Model** (`oneformer/`): Universal segmentation transformer implementation
+    - `oneformer_model.py`: Main model class
+    - `modeling/`: Model components including backbones, decoders, and criterion
+    - `data/`: Dataset mappers and tokenizers for text queries
+
+ 2. **Gradio Applications** (multiple entry points):
+    - `gradio_test.py`: Main combined interface with pre/post-processing logic
+    - `gradio_app.py`: Basic segmentation interface
+    - `gradio_contrast.py`: Contrast detection focused interface
+
+ 3. **Contrast Detection** (`utils/`): Multiple contrast analysis algorithms
+    - `improved_contrast_analyzer.py`: Main analyzer combining all methods
+    - Individual analyzers: luminance, saturation, hue contrast detection
+
+ 4. **Model Configurations** (`configs/`): Pre-defined configs for:
+    - ADE20K (150 classes)
+    - Cityscapes (19 classes)
+    - COCO (133 classes)
+
+ ### Key Implementation Details
+ - Models are loaded from HuggingFace hub (shi-labs organization)
+ - Supports both CPU and GPU inference (auto-detection)
+ - Custom CUDA kernels for multi-scale deformable attention
+ - NATTEN library provides neighborhood attention operations
+
+ ### Model Checkpoints
+ The application uses pre-trained models from:
+ - `shi-labs/oneformer_cityscapes_swin_large`
+ - `shi-labs/oneformer_coco_swin_large`
+ - `shi-labs/oneformer_ade20k_swin_large`
+ - DiNAT variants also available
+
+ ### Important Notes
+ - Always run `deform_setup.sh` after environment setup for CUDA operations
+ - The application expects model checkpoints to be downloaded before running
+ - Gradio interface runs on port 7860 by default
+ - GPU memory requirements vary by model (Swin-L models need ~12GB)
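The "Model Checkpoints" note above only names the shi-labs checkpoints. As a hedged sketch (this repo builds its models through detectron2 configs, not through this route), the same checkpoints can be sanity-checked with the Hugging Face transformers OneFormer classes; the checkpoint ID and example image path come from this commit, everything else here is illustrative:

```python
# Hypothetical check that one of the listed shi-labs checkpoints is reachable.
# This is NOT how this repository loads weights (it goes through detectron2);
# it is only an alternative route via the transformers library.
from PIL import Image
from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation

checkpoint = "shi-labs/oneformer_ade20k_swin_large"  # one of the checkpoints listed above
processor = OneFormerProcessor.from_pretrained(checkpoint)
model = OneFormerForUniversalSegmentation.from_pretrained(checkpoint)

image = Image.open("examples/ade20k.jpeg")  # example image fetched in the Dockerfile
inputs = processor(images=image, task_inputs=["panoptic"], return_tensors="pt")
outputs = model(**inputs)
panoptic = processor.post_process_panoptic_segmentation(
    outputs, target_sizes=[image.size[::-1]]
)[0]
```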
Dockerfile ADDED
@@ -0,0 +1,61 @@
+ FROM nvidia/cuda:11.3.1-cudnn8-devel-ubuntu18.04
+ CMD nvidia-smi
+
+ ENV DEBIAN_FRONTEND noninteractive
+ RUN apt-get update && apt-get install -y \
+     git \
+     make build-essential libssl-dev zlib1g-dev \
+     libbz2-dev libreadline-dev libsqlite3-dev wget curl llvm \
+     libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev libffi-dev liblzma-dev \
+     ffmpeg libsm6 libxext6 cmake libgl1-mesa-glx \
+     && rm -rf /var/lib/apt/lists/*
+
+ RUN useradd -ms /bin/bash user
+ USER user
+
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH
+
+ RUN curl https://pyenv.run | bash
+ ENV PATH=$HOME/.pyenv/shims:$HOME/.pyenv/bin:$PATH
+ RUN pyenv install 3.8.15 && \
+     pyenv global 3.8.15 && \
+     pyenv rehash && \
+     pip install --no-cache-dir --upgrade pip setuptools wheel
+
+ ENV WORKDIR=/code
+ WORKDIR $WORKDIR
+ RUN chown -R user:user $WORKDIR
+ RUN chmod -R 777 $WORKDIR
+
+ COPY requirements.txt $WORKDIR/requirements.txt
+ RUN pip install --no-cache-dir --upgrade -r $WORKDIR/requirements.txt
+ RUN pip install ninja
+
+ COPY . .
+
+ ARG TORCH_CUDA_ARCH_LIST=7.5+PTX
+
+ USER root
+ RUN chown -R user:user $HOME
+ RUN chmod -R 777 $HOME
+ RUN chown -R user:user $WORKDIR
+ RUN chmod -R 777 $WORKDIR
+
+ USER user
+ RUN ln -s $WORKDIR/oneformer/modeling/pixel_decoder/ops/ $WORKDIR/ && ls && cd ops/ && FORCE_CUDA=1 python setup.py build --build-base=$WORKDIR/ install --user && cd ..
+ RUN sh deform_setup.sh
+
+ USER user
+ RUN sh deform_setup.sh
+
+ RUN mkdir -p examples
+ RUN wget https://praeclarumjj3.github.io/files/ade20k.jpeg -P $WORKDIR/examples/
+ RUN wget https://praeclarumjj3.github.io/files/cityscapes.png -P $WORKDIR/examples/
+ RUN wget https://praeclarumjj3.github.io/files/coco.jpeg -P $WORKDIR/examples/
+
+ USER user
+
+ EXPOSE 7860
+
+ ENTRYPOINT ["python", "gradio_app.py"]
NeuroNest/.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
NeuroNest/README.md ADDED
@@ -0,0 +1,12 @@
+ ---
+ title: NeuroNest
+ emoji: 🏆
+ colorFrom: indigo
+ colorTo: green
+ sdk: gradio
+ sdk_version: 5.30.0
+ app_file: app.py
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,38 @@
+ #!/usr/bin/env python3
+ """
+ Hugging Face Spaces entry point for OneFormer application
+ """
+
+ import os
+ import sys
+ import subprocess
+
+ # Set up environment variables for HF Spaces
+ os.environ['CUDA_HOME'] = '/usr/local/cuda' if os.path.exists('/usr/local/cuda') else ''
+
+ # Install deformable attention ops if not already installed
+ def setup_deformable_attention():
+     ops_dir = os.path.join(os.path.dirname(__file__), 'oneformer/modeling/pixel_decoder/ops')
+     if os.path.exists(ops_dir):
+         try:
+             subprocess.run(['bash', 'deform_setup.sh'], check=True, cwd=os.path.dirname(__file__))
+             print("Deformable attention ops installed successfully")
+         except Exception as e:
+             print(f"Warning: Could not install deformable attention ops: {e}")
+             print("Continuing without custom CUDA kernels...")
+
+ # Run setup on first launch
+ if not os.path.exists('oneformer/modeling/pixel_decoder/ops/build'):
+     setup_deformable_attention()
+
+ # Import and run the main gradio app
+ from gradio_test import demo
+
+ if __name__ == "__main__":
+     # Launch with HF Spaces compatible settings
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=False,  # Disabled on HF Spaces
+         debug=False
+     )
configs/.DS_Store ADDED
Binary file (6.15 kB).
configs/ade20k/Base-ADE20K-UnifiedSegmentation.yaml ADDED
@@ -0,0 +1,68 @@
1
+ MODEL:
2
+ BACKBONE:
3
+ FREEZE_AT: 0
4
+ NAME: "build_resnet_backbone"
5
+ WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl"
6
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
7
+ PIXEL_STD: [58.395, 57.120, 57.375]
8
+ RESNETS:
9
+ DEPTH: 50
10
+ STEM_TYPE: "basic" # not used
11
+ STEM_OUT_CHANNELS: 64
12
+ STRIDE_IN_1X1: False
13
+ OUT_FEATURES: ["res2", "res3", "res4", "res5"]
14
+ # NORM: "SyncBN"
15
+ RES5_MULTI_GRID: [1, 1, 1] # not used
16
+ DATASETS:
17
+ TRAIN: ("ade20k_panoptic_train",)
18
+ TEST_PANOPTIC: ("ade20k_panoptic_val",)
19
+ TEST_INSTANCE: ("ade20k_instance_val",)
20
+ TEST_SEMANTIC: ("ade20k_sem_seg_val",)
21
+ SOLVER:
22
+ IMS_PER_BATCH: 16
23
+ BASE_LR: 0.0001
24
+ MAX_ITER: 160000
25
+ WARMUP_FACTOR: 1.0
26
+ WARMUP_ITERS: 0
27
+ WEIGHT_DECAY: 0.05
28
+ OPTIMIZER: "ADAMW"
29
+ LR_SCHEDULER_NAME: "WarmupPolyLR"
30
+ BACKBONE_MULTIPLIER: 0.1
31
+ CLIP_GRADIENTS:
32
+ ENABLED: True
33
+ CLIP_TYPE: "full_model"
34
+ CLIP_VALUE: 0.01
35
+ NORM_TYPE: 2.0
36
+ AMP:
37
+ ENABLED: True
38
+ INPUT:
39
+ MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 512) for x in range(5, 21)]"]
40
+ MIN_SIZE_TRAIN_SAMPLING: "choice"
41
+ MIN_SIZE_TEST: 512
42
+ MAX_SIZE_TRAIN: 2048
43
+ MAX_SIZE_TEST: 2048
44
+ CROP:
45
+ ENABLED: True
46
+ TYPE: "absolute"
47
+ SIZE: (512, 512)
48
+ SINGLE_CATEGORY_MAX_AREA: 1.0
49
+ COLOR_AUG_SSD: True
50
+ SIZE_DIVISIBILITY: 512 # used in dataset mapper
51
+ FORMAT: "RGB"
52
+ DATASET_MAPPER_NAME: "oneformer_unified"
53
+ MAX_SEQ_LEN: 77
54
+ TASK_SEQ_LEN: 77
55
+ TASK_PROB:
56
+ SEMANTIC: 0.33
57
+ INSTANCE: 0.66
58
+ TEST:
59
+ EVAL_PERIOD: 5000
60
+ AUG:
61
+ ENABLED: False
62
+ MIN_SIZES: [256, 384, 512, 640, 768, 896]
63
+ MAX_SIZE: 3584
64
+ FLIP: True
65
+ DATALOADER:
66
+ FILTER_EMPTY_ANNOTATIONS: True
67
+ NUM_WORKERS: 4
68
+ VERSION: 2
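The `MIN_SIZE_TRAIN: !!python/object/apply:eval` entry in the config above is easy to misread; it simply evaluates a Python list comprehension when the config is loaded. A minimal sketch of what it expands to for the 512-pixel ADE20K base config (the helper name is made up for illustration):

```python
# Sketch of what the !!python/object/apply:eval entry in the ADE20K base config
# evaluates to: shortest-edge sizes from 0.5x to 2.0x of the 512 crop size.
def expand_min_size_train(base_size=512):
    # mirrors "[int(x * 0.1 * 512) for x in range(5, 21)]" from the YAML above
    return [int(x * 0.1 * base_size) for x in range(5, 21)]

print(expand_min_size_train())
# [256, 307, 358, 409, 460, 512, 563, 614, 665, 716, 768, 819, 870, 921, 972, 1024]
```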
configs/ade20k/oneformer_R50_bs16_160k.yaml ADDED
@@ -0,0 +1,58 @@
1
+ _BASE_: Base-ADE20K-UnifiedSegmentation.yaml
2
+ MODEL:
3
+ META_ARCHITECTURE: "OneFormer"
4
+ SEM_SEG_HEAD:
5
+ NAME: "OneFormerHead"
6
+ IGNORE_VALUE: 255
7
+ NUM_CLASSES: 150
8
+ LOSS_WEIGHT: 1.0
9
+ CONVS_DIM: 256
10
+ MASK_DIM: 256
11
+ NORM: "GN"
12
+ # pixel decoder
13
+ PIXEL_DECODER_NAME: "MSDeformAttnPixelDecoder"
14
+ IN_FEATURES: ["res2", "res3", "res4", "res5"]
15
+ DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"]
16
+ COMMON_STRIDE: 4
17
+ TRANSFORMER_ENC_LAYERS: 6
18
+ ONE_FORMER:
19
+ TRANSFORMER_DECODER_NAME: "ContrastiveMultiScaleMaskedTransformerDecoder"
20
+ TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder"
21
+ DEEP_SUPERVISION: True
22
+ NO_OBJECT_WEIGHT: 0.1
23
+ CLASS_WEIGHT: 2.0
24
+ MASK_WEIGHT: 5.0
25
+ DICE_WEIGHT: 5.0
26
+ CONTRASTIVE_WEIGHT: 0.5
27
+ CONTRASTIVE_TEMPERATURE: 0.07
28
+ HIDDEN_DIM: 256
29
+ NUM_OBJECT_QUERIES: 150
30
+ USE_TASK_NORM: True
31
+ NHEADS: 8
32
+ DROPOUT: 0.1
33
+ DIM_FEEDFORWARD: 2048
34
+ ENC_LAYERS: 0
35
+ PRE_NORM: False
36
+ ENFORCE_INPUT_PROJ: False
37
+ SIZE_DIVISIBILITY: 32
38
+ CLASS_DEC_LAYERS: 2
39
+ DEC_LAYERS: 10 # 9 decoder layers, add one for the loss on learnable query
40
+ TRAIN_NUM_POINTS: 12544
41
+ OVERSAMPLE_RATIO: 3.0
42
+ IMPORTANCE_SAMPLE_RATIO: 0.75
43
+ TEXT_ENCODER:
44
+ WIDTH: 256
45
+ CONTEXT_LENGTH: 77
46
+ NUM_LAYERS: 6
47
+ VOCAB_SIZE: 49408
48
+ PROJ_NUM_LAYERS: 2
49
+ N_CTX: 16
50
+ TEST:
51
+ SEMANTIC_ON: True
52
+ INSTANCE_ON: True
53
+ PANOPTIC_ON: True
54
+ OVERLAP_THRESHOLD: 0.8
55
+ OBJECT_MASK_THRESHOLD: 0.8
56
+ TASK: "panoptic"
57
+ TEST:
58
+ DETECTIONS_PER_IMAGE: 150
configs/ade20k/oneformer_dinat_large_IN21k_384_bs16_160k.yaml ADDED
@@ -0,0 +1,42 @@
1
+ _BASE_: oneformer_R50_bs16_160k.yaml
2
+ MODEL:
3
+ BACKBONE:
4
+ NAME: "D2DiNAT"
5
+ DiNAT:
6
+ EMBED_DIM: 192
7
+ MLP_RATIO: 2.0
8
+ DEPTHS: [3, 4, 18, 5]
9
+ NUM_HEADS: [6, 12, 24, 48]
10
+ KERNEL_SIZE: 11
11
+ DROP_PATH_RATE: 0.3
12
+ DILATIONS: [[1, 20, 1], [1, 5, 1, 10], [1, 2, 1, 3, 1, 4, 1, 5, 1, 2, 1, 3, 1, 4, 1, 5, 1, 5], [1, 2, 1, 2, 1]]
13
+ WEIGHTS: "dinat_large_in22k_in1k_384_11x11.pkl"
14
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
15
+ PIXEL_STD: [58.395, 57.120, 57.375]
16
+ ONE_FORMER:
17
+ NUM_OBJECT_QUERIES: 250
18
+ SOLVER:
19
+ AMP:
20
+ ENABLED: False
21
+ INPUT:
22
+ MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 640) for x in range(5, 21)]"]
23
+ MIN_SIZE_TRAIN_SAMPLING: "choice"
24
+ MIN_SIZE_TEST: 640
25
+ MAX_SIZE_TRAIN: 2560
26
+ MAX_SIZE_TEST: 2560
27
+ CROP:
28
+ ENABLED: True
29
+ TYPE: "absolute"
30
+ SIZE: (640, 640)
31
+ SINGLE_CATEGORY_MAX_AREA: 1.0
32
+ COLOR_AUG_SSD: True
33
+ SIZE_DIVISIBILITY: 640 # used in dataset mapper
34
+ FORMAT: "RGB"
35
+ TEST:
36
+ DETECTIONS_PER_IMAGE: 250
37
+ EVAL_PERIOD: 5000
38
+ AUG:
39
+ ENABLED: False
40
+ MIN_SIZES: [320, 480, 640, 800, 960, 1120]
41
+ MAX_SIZE: 4480
42
+ FLIP: True
configs/ade20k/oneformer_swin_large_IN21k_384_bs16_160k.yaml ADDED
@@ -0,0 +1,40 @@
1
+ _BASE_: oneformer_R50_bs16_160k.yaml
2
+ MODEL:
3
+ BACKBONE:
4
+ NAME: "D2SwinTransformer"
5
+ SWIN:
6
+ EMBED_DIM: 192
7
+ DEPTHS: [2, 2, 18, 2]
8
+ NUM_HEADS: [6, 12, 24, 48]
9
+ WINDOW_SIZE: 12
10
+ APE: False
11
+ DROP_PATH_RATE: 0.3
12
+ PATCH_NORM: True
13
+ PRETRAIN_IMG_SIZE: 384
14
+ WEIGHTS: "swin_large_patch4_window12_384_22k.pkl"
15
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
16
+ PIXEL_STD: [58.395, 57.120, 57.375]
17
+ ONE_FORMER:
18
+ NUM_OBJECT_QUERIES: 250
19
+ INPUT:
20
+ MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 640) for x in range(5, 21)]"]
21
+ MIN_SIZE_TRAIN_SAMPLING: "choice"
22
+ MIN_SIZE_TEST: 640
23
+ MAX_SIZE_TRAIN: 2560
24
+ MAX_SIZE_TEST: 2560
25
+ CROP:
26
+ ENABLED: True
27
+ TYPE: "absolute"
28
+ SIZE: (640, 640)
29
+ SINGLE_CATEGORY_MAX_AREA: 1.0
30
+ COLOR_AUG_SSD: True
31
+ SIZE_DIVISIBILITY: 640 # used in dataset mapper
32
+ FORMAT: "RGB"
33
+ TEST:
34
+ DETECTIONS_PER_IMAGE: 250
35
+ EVAL_PERIOD: 5000
36
+ AUG:
37
+ ENABLED: False
38
+ MIN_SIZES: [320, 480, 640, 800, 960, 1120]
39
+ MAX_SIZE: 4480
40
+ FLIP: True
configs/cityscapes/.DS_Store ADDED
Binary file (6.15 kB).
configs/cityscapes/Base-Cityscapes-UnifiedSegmentation.yaml ADDED
@@ -0,0 +1,68 @@
1
+ MODEL:
2
+ BACKBONE:
3
+ FREEZE_AT: 0
4
+ NAME: "build_resnet_backbone"
5
+ WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl"
6
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
7
+ PIXEL_STD: [58.395, 57.120, 57.375]
8
+ RESNETS:
9
+ DEPTH: 50
10
+ STEM_TYPE: "basic" # not used
11
+ STEM_OUT_CHANNELS: 64
12
+ STRIDE_IN_1X1: False
13
+ OUT_FEATURES: ["res2", "res3", "res4", "res5"]
14
+ NORM: "SyncBN" # use syncbn for cityscapes dataset
15
+ RES5_MULTI_GRID: [1, 1, 1] # not used
16
+ DATASETS:
17
+ TRAIN: ("cityscapes_fine_panoptic_train",)
18
+ TEST_PANOPTIC: ("cityscapes_fine_panoptic_val",)
19
+ TEST_INSTANCE: ("cityscapes_fine_instance_seg_val",)
20
+ TEST_SEMANTIC: ("cityscapes_fine_sem_seg_val",)
21
+ SOLVER:
22
+ IMS_PER_BATCH: 16
23
+ BASE_LR: 0.0001
24
+ MAX_ITER: 90000
25
+ WARMUP_FACTOR: 1.0
26
+ WARMUP_ITERS: 0
27
+ WEIGHT_DECAY: 0.05
28
+ OPTIMIZER: "ADAMW"
29
+ LR_SCHEDULER_NAME: "WarmupPolyLR"
30
+ BACKBONE_MULTIPLIER: 0.1
31
+ CLIP_GRADIENTS:
32
+ ENABLED: True
33
+ CLIP_TYPE: "full_model"
34
+ CLIP_VALUE: 0.01
35
+ NORM_TYPE: 2.0
36
+ AMP:
37
+ ENABLED: True
38
+ INPUT:
39
+ MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 1024) for x in range(5, 21)]"]
40
+ MIN_SIZE_TRAIN_SAMPLING: "choice"
41
+ MIN_SIZE_TEST: 1024
42
+ MAX_SIZE_TRAIN: 4096
43
+ MAX_SIZE_TEST: 2048
44
+ CROP:
45
+ ENABLED: True
46
+ TYPE: "absolute"
47
+ SIZE: (512, 1024)
48
+ SINGLE_CATEGORY_MAX_AREA: 1.0
49
+ COLOR_AUG_SSD: True
50
+ SIZE_DIVISIBILITY: -1
51
+ FORMAT: "RGB"
52
+ DATASET_MAPPER_NAME: "oneformer_unified"
53
+ MAX_SEQ_LEN: 77
54
+ TASK_SEQ_LEN: 77
55
+ TASK_PROB:
56
+ SEMANTIC: 0.33
57
+ INSTANCE: 0.66
58
+ TEST:
59
+ EVAL_PERIOD: 5000
60
+ AUG:
61
+ ENABLED: False
62
+ MIN_SIZES: [512, 768, 1024, 1280, 1536, 1792]
63
+ MAX_SIZE: 4096
64
+ FLIP: True
65
+ DATALOADER:
66
+ FILTER_EMPTY_ANNOTATIONS: True
67
+ NUM_WORKERS: 4
68
+ VERSION: 2
configs/cityscapes/oneformer_R50_bs16_90k.yaml ADDED
@@ -0,0 +1,59 @@
1
+ _BASE_: Base-Cityscapes-UnifiedSegmentation.yaml
2
+ MODEL:
3
+ META_ARCHITECTURE: "OneFormer"
4
+ SEM_SEG_HEAD:
5
+ NAME: "OneFormerHead"
6
+ IGNORE_VALUE: 255
7
+ NUM_CLASSES: 19
8
+ LOSS_WEIGHT: 1.0
9
+ CONVS_DIM: 256
10
+ MASK_DIM: 256
11
+ NORM: "GN"
12
+ # pixel decoder
13
+ PIXEL_DECODER_NAME: "MSDeformAttnPixelDecoder"
14
+ IN_FEATURES: ["res2", "res3", "res4", "res5"]
15
+ DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"]
16
+ COMMON_STRIDE: 4
17
+ TRANSFORMER_ENC_LAYERS: 6
18
+ ONE_FORMER:
19
+ TRANSFORMER_DECODER_NAME: "ContrastiveMultiScaleMaskedTransformerDecoder"
20
+ TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder"
21
+ DEEP_SUPERVISION: True
22
+ NO_OBJECT_WEIGHT: 0.1
23
+ CLASS_WEIGHT: 2.0
24
+ MASK_WEIGHT: 5.0
25
+ DICE_WEIGHT: 5.0
26
+ CONTRASTIVE_WEIGHT: 0.5
27
+ CONTRASTIVE_TEMPERATURE: 0.07
28
+ HIDDEN_DIM: 256
29
+ NUM_OBJECT_QUERIES: 150
30
+ USE_TASK_NORM: True
31
+ NHEADS: 8
32
+ DROPOUT: 0.1
33
+ DIM_FEEDFORWARD: 2048
34
+ ENC_LAYERS: 0
35
+ PRE_NORM: False
36
+ ENFORCE_INPUT_PROJ: False
37
+ SIZE_DIVISIBILITY: 32
38
+ ENC_LAYERS: 0
39
+ CLASS_DEC_LAYERS: 2
40
+ DEC_LAYERS: 10 # 9 decoder layers, add one for the loss on learnable query
41
+ TRAIN_NUM_POINTS: 12544
42
+ OVERSAMPLE_RATIO: 3.0
43
+ IMPORTANCE_SAMPLE_RATIO: 0.75
44
+ TEXT_ENCODER:
45
+ WIDTH: 256
46
+ CONTEXT_LENGTH: 77
47
+ NUM_LAYERS: 6
48
+ VOCAB_SIZE: 49408
49
+ PROJ_NUM_LAYERS: 2
50
+ N_CTX: 16
51
+ TEST:
52
+ SEMANTIC_ON: True
53
+ INSTANCE_ON: True
54
+ PANOPTIC_ON: True
55
+ OVERLAP_THRESHOLD: 0.8
56
+ OBJECT_MASK_THRESHOLD: 0.8
57
+ TASK: "panoptic"
58
+ TEST:
59
+ DETECTIONS_PER_IMAGE: 150
configs/cityscapes/oneformer_dinat_large_bs16_90k.yaml ADDED
@@ -0,0 +1,22 @@
1
+ _BASE_: oneformer_R50_bs16_90k.yaml
2
+ MODEL:
3
+ BACKBONE:
4
+ NAME: "D2DiNAT"
5
+ DiNAT:
6
+ EMBED_DIM: 192
7
+ MLP_RATIO: 2.0
8
+ DEPTHS: [3, 4, 18, 5]
9
+ NUM_HEADS: [6, 12, 24, 48]
10
+ KERNEL_SIZE: 7
11
+ DROP_PATH_RATE: 0.3
12
+ DILATIONS: [[1, 18, 1], [1, 5, 1, 9], [1, 2, 1, 3, 1, 4, 1, 2, 1, 3, 1, 4, 1, 2, 1, 3, 1, 4], [1, 2, 1, 2, 1]]
13
+ WEIGHTS: "dinat_large_in22k_224.pkl"
14
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
15
+ PIXEL_STD: [58.395, 57.120, 57.375]
16
+ ONE_FORMER:
17
+ NUM_OBJECT_QUERIES: 250
18
+ SOLVER:
19
+ AMP:
20
+ ENABLED: False
21
+ TEST:
22
+ DETECTIONS_PER_IMAGE: 250
configs/cityscapes/oneformer_swin_large_IN21k_384_bs16_90k.yaml ADDED
@@ -0,0 +1,20 @@
1
+ _BASE_: oneformer_R50_bs16_90k.yaml
2
+ MODEL:
3
+ BACKBONE:
4
+ NAME: "D2SwinTransformer"
5
+ SWIN:
6
+ EMBED_DIM: 192
7
+ DEPTHS: [2, 2, 18, 2]
8
+ NUM_HEADS: [6, 12, 24, 48]
9
+ WINDOW_SIZE: 12
10
+ APE: False
11
+ DROP_PATH_RATE: 0.3
12
+ PATCH_NORM: True
13
+ PRETRAIN_IMG_SIZE: 384
14
+ WEIGHTS: "swin_large_patch4_window12_384_22k.pkl"
15
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
16
+ PIXEL_STD: [58.395, 57.120, 57.375]
17
+ ONE_FORMER:
18
+ NUM_OBJECT_QUERIES: 250
19
+ TEST:
20
+ DETECTIONS_PER_IMAGE: 250
configs/coco/Base-COCO-UnifiedSegmentation.yaml ADDED
@@ -0,0 +1,54 @@
1
+ MODEL:
2
+ BACKBONE:
3
+ FREEZE_AT: 0
4
+ NAME: "build_resnet_backbone"
5
+ WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl"
6
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
7
+ PIXEL_STD: [58.395, 57.120, 57.375]
8
+ RESNETS:
9
+ DEPTH: 50
10
+ STEM_TYPE: "basic" # not used
11
+ STEM_OUT_CHANNELS: 64
12
+ STRIDE_IN_1X1: False
13
+ OUT_FEATURES: ["res2", "res3", "res4", "res5"]
14
+ # NORM: "SyncBN"
15
+ RES5_MULTI_GRID: [1, 1, 1] # not used
16
+ DATASETS:
17
+ TRAIN: ("coco_2017_train_panoptic_with_sem_seg",)
18
+ TEST_PANOPTIC: ("coco_2017_val_panoptic_with_sem_seg",) # to evaluate instance and semantic performance as well
19
+ TEST_INSTANCE: ("coco_2017_val",)
20
+ TEST_SEMANTIC: ("coco_2017_val_panoptic_with_sem_seg",)
21
+ SOLVER:
22
+ IMS_PER_BATCH: 16
23
+ BASE_LR: 0.0001
24
+ STEPS: (327778, 355092)
25
+ MAX_ITER: 368750
26
+ WARMUP_FACTOR: 1.0
27
+ WARMUP_ITERS: 10
28
+ WEIGHT_DECAY: 0.05
29
+ OPTIMIZER: "ADAMW"
30
+ BACKBONE_MULTIPLIER: 0.1
31
+ CLIP_GRADIENTS:
32
+ ENABLED: True
33
+ CLIP_TYPE: "full_model"
34
+ CLIP_VALUE: 0.01
35
+ NORM_TYPE: 2.0
36
+ AMP:
37
+ ENABLED: True
38
+ INPUT:
39
+ IMAGE_SIZE: 1024
40
+ MIN_SCALE: 0.1
41
+ MAX_SCALE: 2.0
42
+ FORMAT: "RGB"
43
+ DATASET_MAPPER_NAME: "coco_unified_lsj"
44
+ MAX_SEQ_LEN: 77
45
+ TASK_SEQ_LEN: 77
46
+ TASK_PROB:
47
+ SEMANTIC: 0.33
48
+ INSTANCE: 0.66
49
+ TEST:
50
+ EVAL_PERIOD: 5000
51
+ DATALOADER:
52
+ FILTER_EMPTY_ANNOTATIONS: True
53
+ NUM_WORKERS: 4
54
+ VERSION: 2
configs/coco/oneformer_R50_bs16_50ep.yaml ADDED
@@ -0,0 +1,59 @@
1
+ _BASE_: Base-COCO-UnifiedSegmentation.yaml
2
+ MODEL:
3
+ META_ARCHITECTURE: "OneFormer"
4
+ SEM_SEG_HEAD:
5
+ NAME: "OneFormerHead"
6
+ IGNORE_VALUE: 255
7
+ NUM_CLASSES: 133
8
+ LOSS_WEIGHT: 1.0
9
+ CONVS_DIM: 256
10
+ MASK_DIM: 256
11
+ NORM: "GN"
12
+ # pixel decoder
13
+ PIXEL_DECODER_NAME: "MSDeformAttnPixelDecoder"
14
+ IN_FEATURES: ["res2", "res3", "res4", "res5"]
15
+ DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"]
16
+ COMMON_STRIDE: 4
17
+ TRANSFORMER_ENC_LAYERS: 6
18
+ ONE_FORMER:
19
+ TRANSFORMER_DECODER_NAME: "ContrastiveMultiScaleMaskedTransformerDecoder"
20
+ TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder"
21
+ DEEP_SUPERVISION: True
22
+ NO_OBJECT_WEIGHT: 0.1
23
+ CLASS_WEIGHT: 2.0
24
+ MASK_WEIGHT: 5.0
25
+ DICE_WEIGHT: 5.0
26
+ CONTRASTIVE_WEIGHT: 0.5
27
+ CONTRASTIVE_TEMPERATURE: 0.07
28
+ HIDDEN_DIM: 256
29
+ NUM_OBJECT_QUERIES: 150
30
+ USE_TASK_NORM: True
31
+ NHEADS: 8
32
+ DROPOUT: 0.1
33
+ DIM_FEEDFORWARD: 2048
34
+ ENC_LAYERS: 0
35
+ PRE_NORM: False
36
+ ENFORCE_INPUT_PROJ: False
37
+ SIZE_DIVISIBILITY: 32
38
+ CLASS_DEC_LAYERS: 2
39
+ DEC_LAYERS: 10 # 9 decoder layers, add one for the loss on learnable query
40
+ TRAIN_NUM_POINTS: 12544
41
+ OVERSAMPLE_RATIO: 3.0
42
+ IMPORTANCE_SAMPLE_RATIO: 0.75
43
+ TEXT_ENCODER:
44
+ WIDTH: 256
45
+ CONTEXT_LENGTH: 77
46
+ NUM_LAYERS: 6
47
+ VOCAB_SIZE: 49408
48
+ PROJ_NUM_LAYERS: 2
49
+ N_CTX: 16
50
+ TEST:
51
+ SEMANTIC_ON: True
52
+ INSTANCE_ON: True
53
+ PANOPTIC_ON: True
54
+ DETECTION_ON: False
55
+ OVERLAP_THRESHOLD: 0.8
56
+ OBJECT_MASK_THRESHOLD: 0.8
57
+ TASK: "panoptic"
58
+ TEST:
59
+ DETECTIONS_PER_IMAGE: 150
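A quick arithmetic check of the "50ep" in this file name, using IMS_PER_BATCH from the COCO base config above and the approximate size of COCO train2017 (the image count is an assumption, not something stated in this commit):

```python
# Back-of-the-envelope check: with IMS_PER_BATCH = 16 and MAX_ITER = 368750,
# the schedule covers roughly 50 passes over the ~118k COCO train2017 images.
ims_per_batch = 16
max_iter = 368_750
coco_train_images = 118_287  # approximate COCO train2017 size (assumption)

epochs = max_iter * ims_per_batch / coco_train_images
print(f"{epochs:.1f} epochs")  # ~49.9 epochs
```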
configs/coco/oneformer_dinat_large_bs16_100ep.yaml ADDED
@@ -0,0 +1,22 @@
1
+ _BASE_: oneformer_R50_bs16_50ep.yaml
2
+ MODEL:
3
+ BACKBONE:
4
+ NAME: "D2DiNAT"
5
+ DiNAT:
6
+ EMBED_DIM: 192
7
+ MLP_RATIO: 2.0
8
+ DEPTHS: [3, 4, 18, 5]
9
+ NUM_HEADS: [6, 12, 24, 48]
10
+ KERNEL_SIZE: 11
11
+ DROP_PATH_RATE: 0.3
12
+ DILATIONS: [[1, 20, 1], [1, 5, 1, 10], [1, 2, 1, 3, 1, 4, 1, 5, 1, 2, 1, 3, 1, 4, 1, 5, 1, 5], [1, 2, 1, 2, 1]]
13
+ WEIGHTS: "dinat_large_in22k_in1k_384_11x11.pkl"
14
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
15
+ PIXEL_STD: [58.395, 57.120, 57.375]
16
+ ONE_FORMER:
17
+ NUM_OBJECT_QUERIES: 150
18
+ SOLVER:
19
+ STEPS: (655556, 710184)
20
+ MAX_ITER: 737500
21
+ TEST:
22
+ DETECTIONS_PER_IMAGE: 150
configs/coco/oneformer_swin_large_IN21k_384_bs16_100ep.yaml ADDED
@@ -0,0 +1,25 @@
1
+ _BASE_: oneformer_R50_bs16_50ep.yaml
2
+ MODEL:
3
+ BACKBONE:
4
+ NAME: "D2SwinTransformer"
5
+ SWIN:
6
+ EMBED_DIM: 192
7
+ DEPTHS: [2, 2, 18, 2]
8
+ NUM_HEADS: [6, 12, 24, 48]
9
+ WINDOW_SIZE: 12
10
+ APE: False
11
+ DROP_PATH_RATE: 0.3
12
+ PATCH_NORM: True
13
+ PRETRAIN_IMG_SIZE: 384
14
+ WEIGHTS: "swin_large_patch4_window12_384_22k.pkl"
15
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
16
+ PIXEL_STD: [58.395, 57.120, 57.375]
17
+ ONE_FORMER:
18
+ NUM_OBJECT_QUERIES: 150
19
+ SOLVER:
20
+ STEPS: (655556, 735184)
21
+ MAX_ITER: 737500
22
+ AMP:
23
+ ENABLED: False
24
+ TEST:
25
+ DETECTIONS_PER_IMAGE: 150
deform_setup.sh ADDED
@@ -0,0 +1,21 @@
+ #!/usr/bin/env bash
+
+ # ln -s ./oneformer/modeling/pixel_decoder/ops/ ./
+ # ls
+ # cd ops/ && bash make.sh && cd ..
+ echo '----------------------------------------------------------------'
+ echo '----------------------------------------------------------------'
+ pip3 freeze | grep MultiScaleDeformableAttention
+ pip3 freeze | grep torch
+ pip3 freeze | grep detectron2
+ pip3 freeze | grep natten
+ echo '----------------------------------------------------------------'
+ echo '----------------------------------------------------------------'
+
+ # echo '----------------------------------------------------------------'
+ # echo '----------------------------------------------------------------'
+ # cd /home/user/.pyenv/versions/3.8.15/lib/python3.8/site-packages
+ # ls
+ # ls | grep MultiScale
+ # echo '----------------------------------------------------------------'
+ # echo '----------------------------------------------------------------'
demo/colormap.py ADDED
@@ -0,0 +1,170 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ """
4
+ An awesome colormap for really neat visualizations.
5
+ Copied from Detectron, and removed gray colors.
6
+ """
7
+
8
+ import numpy as np
9
+ import random
10
+ random.seed(0)
11
+
12
+ __all__ = ["colormap", "random_color", "random_colors"]
13
+
14
+ # fmt: off
15
+ # RGB:
16
+ # _COLORS = np.array(
17
+ # [
18
+ # 0.000, 0.447, 0.741,
19
+ # 0.850, 0.325, 0.098,
20
+ # 0.929, 0.694, 0.125,
21
+ # 0.494, 0.184, 0.556,
22
+ # 0.466, 0.674, 0.188,
23
+ # 0.301, 0.745, 0.933,
24
+ # 0.635, 0.078, 0.184,
25
+ # 0.300, 0.300, 0.300,
26
+ # 0.600, 0.600, 0.600,
27
+ # 1.000, 0.000, 0.000,
28
+ # 1.000, 0.500, 0.000,
29
+ # 0.749, 0.749, 0.000,
30
+ # 0.000, 1.000, 0.000,
31
+ # 0.000, 0.000, 1.000,
32
+ # 0.667, 0.000, 1.000,
33
+ # 0.333, 0.333, 0.000,
34
+ # 0.333, 0.667, 0.000,
35
+ # 0.333, 1.000, 0.000,
36
+ # 0.667, 0.333, 0.000,
37
+ # 0.667, 0.667, 0.000,
38
+ # 0.667, 1.000, 0.000,
39
+ # 1.000, 0.333, 0.000,
40
+ # 1.000, 0.667, 0.000,
41
+ # 1.000, 1.000, 0.000,
42
+ # 0.000, 0.333, 0.500,
43
+ # 0.000, 0.667, 0.500,
44
+ # 0.000, 1.000, 0.500,
45
+ # 0.333, 0.000, 0.500,
46
+ # 0.333, 0.333, 0.500,
47
+ # 0.333, 0.667, 0.500,
48
+ # 0.333, 1.000, 0.500,
49
+ # 0.667, 0.000, 0.500,
50
+ # 0.667, 0.333, 0.500,
51
+ # 0.667, 0.667, 0.500,
52
+ # 0.667, 1.000, 0.500,
53
+ # 1.000, 0.000, 0.500,
54
+ # 1.000, 0.333, 0.500,
55
+ # 1.000, 0.667, 0.500,
56
+ # 1.000, 1.000, 0.500,
57
+ # 0.000, 0.333, 1.000,
58
+ # 0.000, 0.667, 1.000,
59
+ # 0.000, 1.000, 1.000,
60
+ # 0.333, 0.000, 1.000,
61
+ # 0.333, 0.333, 1.000,
62
+ # 0.333, 0.667, 1.000,
63
+ # 0.333, 1.000, 1.000,
64
+ # 0.667, 0.000, 1.000,
65
+ # 0.667, 0.333, 1.000,
66
+ # 0.667, 0.667, 1.000,
67
+ # 0.667, 1.000, 1.000,
68
+ # 1.000, 0.000, 1.000,
69
+ # 1.000, 0.333, 1.000,
70
+ # 1.000, 0.667, 1.000,
71
+ # 0.333, 0.000, 0.000,
72
+ # 0.500, 0.000, 0.000,
73
+ # 0.667, 0.000, 0.000,
74
+ # 0.833, 0.000, 0.000,
75
+ # 1.000, 0.000, 0.000,
76
+ # 0.000, 0.167, 0.000,
77
+ # 0.000, 0.333, 0.000,
78
+ # 0.000, 0.500, 0.000,
79
+ # 0.000, 0.667, 0.000,
80
+ # 0.000, 0.833, 0.000,
81
+ # 0.000, 1.000, 0.000,
82
+ # 0.000, 0.000, 0.167,
83
+ # 0.000, 0.000, 0.333,
84
+ # 0.000, 0.000, 0.500,
85
+ # 0.000, 0.000, 0.667,
86
+ # 0.000, 0.000, 0.833,
87
+ # 0.000, 0.000, 1.000,
88
+ # 0.000, 0.000, 0.000,
89
+ # 0.143, 0.143, 0.143,
90
+ # 0.857, 0.857, 0.857,
91
+ # 1.000, 1.000, 1.000
92
+ # ]
93
+ # ).astype(np.float32).reshape(-1, 3)
94
+ # fmt: on
95
+
96
+ _COLORS = []
97
+
98
+
99
+ def gen_color():
100
+ color = tuple(np.round(np.random.choice(range(256), size=3)/255, 3))
101
+ if color not in _COLORS and np.mean(color) != 0.0:
102
+ _COLORS.append(color)
103
+ else:
104
+ gen_color()
105
+
106
+
107
+ for _ in range(300):
108
+ gen_color()
109
+
110
+
111
+ def colormap(rgb=False, maximum=255):
112
+ """
113
+ Args:
114
+ rgb (bool): whether to return RGB colors or BGR colors.
115
+ maximum (int): either 255 or 1
116
+ Returns:
117
+ ndarray: a float32 array of Nx3 colors, in range [0, 255] or [0, 1]
118
+ """
119
+ assert maximum in [255, 1], maximum
120
+ c = _COLORS * maximum
121
+ if not rgb:
122
+ c = c[:, ::-1]
123
+ return c
124
+
125
+
126
+ def random_color(rgb=False, maximum=255):
127
+ """
128
+ Args:
129
+ rgb (bool): whether to return RGB colors or BGR colors.
130
+ maximum (int): either 255 or 1
131
+ Returns:
132
+ ndarray: a vector of 3 numbers
133
+ """
134
+ idx = np.random.randint(0, len(_COLORS))
135
+ ret = _COLORS[idx] * maximum
136
+ if not rgb:
137
+ ret = ret[::-1]
138
+ return ret
139
+
140
+
141
+ def random_colors(N, rgb=False, maximum=255):
142
+ """
143
+ Args:
144
+ N (int): number of unique colors needed
145
+ rgb (bool): whether to return RGB colors or BGR colors.
146
+ maximum (int): either 255 or 1
147
+ Returns:
148
+ ndarray: a list of random_color
149
+ """
150
+ indices = random.sample(range(len(_COLORS)), N)
151
+ ret = [_COLORS[i] * maximum for i in indices]
152
+ if not rgb:
153
+ ret = [x[::-1] for x in ret]
154
+ return ret
155
+
156
+
157
+ if __name__ == "__main__":
158
+ import cv2
159
+
160
+ size = 100
161
+ H, W = 10, 10
162
+ canvas = np.random.rand(H * size, W * size, 3).astype("float32")
163
+ for h in range(H):
164
+ for w in range(W):
165
+ idx = h * W + w
166
+ if idx >= len(_COLORS):
167
+ break
168
+ canvas[h * size : (h + 1) * size, w * size : (w + 1) * size] = _COLORS[idx]
169
+ cv2.imshow("a", canvas)
170
+ cv2.waitKey(0)
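`gen_color` above retries by calling itself whenever it draws a duplicate or an all-black color, so an unlucky streak deepens the Python call stack. A minimal iterative sketch with the same acceptance test, offered as an illustration rather than a change to the committed file:

```python
import numpy as np

_COLORS = []

def gen_color_iterative():
    # Same acceptance test as gen_color above (unique, not all-black),
    # but retries with a loop instead of recursive calls.
    while True:
        color = tuple(np.round(np.random.choice(range(256), size=3) / 255, 3))
        if color not in _COLORS and np.mean(color) != 0.0:
            _COLORS.append(color)
            return color

for _ in range(300):
    gen_color_iterative()
```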
demo/defaults.py ADDED
@@ -0,0 +1,77 @@
1
+ import torch
2
+ import detectron2.data.transforms as T
3
+ from detectron2.checkpoint import DetectionCheckpointer
4
+ from detectron2.data import (
5
+ MetadataCatalog,
6
+ )
7
+ from detectron2.modeling import build_model
8
+
9
+
10
+ __all__ = [
11
+ "DefaultPredictor",
12
+ ]
13
+
14
+
15
+ class DefaultPredictor:
16
+ """
17
+ Create a simple end-to-end predictor with the given config that runs on
18
+ single device for a single input image.
19
+ Compared to using the model directly, this class does the following additions:
20
+ 1. Load checkpoint from `cfg.MODEL.WEIGHTS`.
21
+ 2. Always take BGR image as the input and apply conversion defined by `cfg.INPUT.FORMAT`.
22
+ 3. Apply resizing defined by `cfg.INPUT.{MIN,MAX}_SIZE_TEST`.
23
+ 4. Take one input image and produce a single output, instead of a batch.
24
+ This is meant for simple demo purposes, so it does the above steps automatically.
25
+ This is not meant for benchmarks or running complicated inference logic.
26
+ If you'd like to do anything more complicated, please refer to its source code as
27
+ examples to build and use the model manually.
28
+ Attributes:
29
+ metadata (Metadata): the metadata of the underlying dataset, obtained from
30
+ cfg.DATASETS.TEST.
31
+ Examples:
32
+ ::
33
+ pred = DefaultPredictor(cfg)
34
+ inputs = cv2.imread("input.jpg")
35
+ outputs = pred(inputs)
36
+ """
37
+
38
+ def __init__(self, cfg):
39
+ self.cfg = cfg.clone() # cfg can be modified by model
40
+ self.model = build_model(self.cfg)
41
+ self.model.eval()
42
+ if len(cfg.DATASETS.TEST):
43
+ self.metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0])
44
+
45
+ checkpointer = DetectionCheckpointer(self.model)
46
+ checkpointer.load(cfg.MODEL.WEIGHTS)
47
+
48
+ self.aug = T.ResizeShortestEdge(
49
+ [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
50
+ )
51
+
52
+ self.input_format = cfg.INPUT.FORMAT
53
+ assert self.input_format in ["RGB", "BGR"], self.input_format
54
+
55
+ def __call__(self, original_image, task):
56
+ """
57
+ Args:
58
+ original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).
59
+ Returns:
60
+ predictions (dict):
61
+ the output of the model for one image only.
62
+ See :doc:`/tutorials/models` for details about the format.
63
+ """
64
+ with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258
65
+ # Apply pre-processing to image.
66
+ if self.input_format == "RGB":
67
+ # whether the model expects BGR inputs or RGB
68
+ original_image = original_image[:, :, ::-1]
69
+ height, width = original_image.shape[:2]
70
+ image = self.aug.get_transform(original_image).apply_image(original_image)
71
+ image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
72
+
73
+ task = f"The task is {task}"
74
+
75
+ inputs = {"image": image, "height": height, "width": width, "task": task}
76
+ predictions = self.model([inputs])[0]
77
+ return predictions
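The class docstring above shows `pred(inputs)`, but this variant of `__call__` also expects a task string. A minimal usage sketch, assuming a detectron2 `cfg` prepared elsewhere (as the gradio apps do); the helper name and image path are illustrative:

```python
import cv2
from defaults import DefaultPredictor  # the class added above, imported as demo/predictor.py does


def run_panoptic(cfg, image_path="examples/ade20k.jpeg"):
    """Hypothetical helper: cfg is a detectron2 CfgNode populated elsewhere
    (MODEL.WEIGHTS, INPUT sizes, DATASETS.TEST, ...)."""
    predictor = DefaultPredictor(cfg)
    image_bgr = cv2.imread(image_path)              # BGR, as the class docstring requires
    predictions = predictor(image_bgr, "panoptic")  # task becomes "The task is panoptic"
    return predictions["panoptic_seg"]              # (panoptic_seg tensor, segments_info)
```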
demo/predictor.py ADDED
@@ -0,0 +1,190 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Copied from: https://github.com/facebookresearch/detectron2/blob/master/demo/predictor.py
3
+ import atexit
4
+ import bisect
5
+ import multiprocessing as mp
6
+ from collections import deque
7
+
8
+ import cv2
9
+ import torch
10
+
11
+ from detectron2.data import MetadataCatalog
12
+ from defaults import DefaultPredictor
13
+ from detectron2.utils.video_visualizer import VideoVisualizer
14
+ from visualizer import ColorMode, Visualizer
15
+
16
+
17
+ class VisualizationDemo(object):
18
+ def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False):
19
+ """
20
+ Args:
21
+ cfg (CfgNode):
22
+ instance_mode (ColorMode):
23
+ parallel (bool): whether to run the model in different processes from visualization.
24
+ Useful since the visualization logic can be slow.
25
+ """
26
+ self.metadata = MetadataCatalog.get(
27
+ cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
28
+ )
29
+ if 'cityscapes_fine_sem_seg_val' in cfg.DATASETS.TEST[0]:
30
+ from cityscapesscripts.helpers.labels import labels
31
+ stuff_colors = [k.color for k in labels if k.trainId != 255]
32
+ self.metadata = self.metadata.set(stuff_colors=stuff_colors)
33
+ self.cpu_device = torch.device("cpu")
34
+ self.instance_mode = instance_mode
35
+
36
+ self.parallel = parallel
37
+ if parallel:
38
+ num_gpu = torch.cuda.device_count()
39
+ self.predictor = AsyncPredictor(cfg, num_gpus=num_gpu)
40
+ else:
41
+ self.predictor = DefaultPredictor(cfg)
42
+
43
+ def run_on_image(self, image, task, sem_gt, pan_gt, ins_gt, box_gt):
44
+ """
45
+ Args:
46
+ image (np.ndarray): an image of shape (H, W, C) (in BGR order).
47
+ This is the format used by OpenCV.
48
+ Returns:
49
+ predictions (dict): the output of the model.
50
+ vis_output (VisImage): the visualized image output.
51
+ """
52
+ vis_output = None
53
+ # Convert image from OpenCV BGR format to Matplotlib RGB format.
54
+ image = image[:, :, ::-1]
55
+ vis_output = {}
56
+
57
+ if task == 'panoptic':
58
+ visualizer = Visualizer(image, metadata=self.metadata, instance_mode=0)
59
+ predictions = self.predictor(image, "panoptic")
60
+ panoptic_seg, segments_info = predictions["panoptic_seg"]
61
+ vis_output['panoptic'] = visualizer.draw_panoptic_seg_predictions(
62
+ panoptic_seg.to(self.cpu_device), segments_info, alpha=1
63
+ )
64
+
65
+ # visualizer = Visualizer(image, metadata=self.metadata, instance_mode=0)
66
+ # vis_output['pan_gt'] = visualizer.draw_panoptic_seg(
67
+ # pan_gt[0].to(self.cpu_device), pan_gt[1], alpha=1
68
+ # )
69
+
70
+ if task == 'panoptic' or task == 'semantic':
71
+ visualizer = Visualizer(image, metadata=self.metadata, instance_mode=1)
72
+ predictions = self.predictor(image, "semantic")
73
+ vis_output['semantic'] = visualizer.draw_sem_seg(
74
+ predictions["sem_seg"].argmax(dim=0).to(self.cpu_device), alpha=1
75
+ )
76
+
77
+ # visualizer = Visualizer(image, metadata=self.metadata, instance_mode=1)
78
+ # vis_output['gt_sem'] = visualizer.draw_sem_seg(
79
+ # sem_gt.to(self.cpu_device), alpha=1
80
+ # )
81
+
82
+ if task == 'panoptic' or task == 'instance':
83
+ visualizer = Visualizer(image, metadata=self.metadata, instance_mode=2)
84
+ predictions = self.predictor(image, "instance")
85
+ instances = predictions["instances"].to(self.cpu_device)
86
+ vis_output['instance'] = visualizer.draw_instance_predictions(predictions=instances, alpha=1)
87
+
88
+ if 'boxes' in predictions:
89
+ boxes, labels, scores = predictions["boxes"]
90
+ visualizer = Visualizer(image, False, metadata=self.metadata, instance_mode=0)
91
+ vis_output['boxes'] = visualizer.draw_box_predictions(
92
+ boxes.to(self.cpu_device), labels.to(self.cpu_device), scores.to(self.cpu_device))
93
+
94
+
95
+ # visualizer = Visualizer(image, metadata=self.metadata, instance_mode=2)
96
+ # vis_output['ins_gt'] = visualizer.draw_instance_predictions(predictions=ins_gt.to(self.cpu_device), alpha=1)
97
+ # vis_output['input'] = visualizer.get_image(image)
98
+
99
+ return predictions, vis_output
100
+
101
+
102
+ class AsyncPredictor:
103
+ """
104
+ A predictor that runs the model asynchronously, possibly on >1 GPUs.
105
+ Because rendering the visualization takes considerably amount of time,
106
+ this helps improve throughput a little bit when rendering videos.
107
+ """
108
+
109
+ class _StopToken:
110
+ pass
111
+
112
+ class _PredictWorker(mp.Process):
113
+ def __init__(self, cfg, task_queue, result_queue):
114
+ self.cfg = cfg
115
+ self.task_queue = task_queue
116
+ self.result_queue = result_queue
117
+ super().__init__()
118
+
119
+ def run(self):
120
+ predictor = DefaultPredictor(self.cfg)
121
+
122
+ while True:
123
+ task = self.task_queue.get()
124
+ if isinstance(task, AsyncPredictor._StopToken):
125
+ break
126
+ idx, data = task
127
+ result = predictor(data)
128
+ self.result_queue.put((idx, result))
129
+
130
+ def __init__(self, cfg, num_gpus: int = 1):
131
+ """
132
+ Args:
133
+ cfg (CfgNode):
134
+ num_gpus (int): if 0, will run on CPU
135
+ """
136
+ num_workers = max(num_gpus, 1)
137
+ self.task_queue = mp.Queue(maxsize=num_workers * 3)
138
+ self.result_queue = mp.Queue(maxsize=num_workers * 3)
139
+ self.procs = []
140
+ for gpuid in range(max(num_gpus, 1)):
141
+ cfg = cfg.clone()
142
+ cfg.defrost()
143
+ cfg.MODEL.DEVICE = "cuda:{}".format(gpuid) if num_gpus > 0 else "cpu"
144
+ self.procs.append(
145
+ AsyncPredictor._PredictWorker(cfg, self.task_queue, self.result_queue)
146
+ )
147
+
148
+ self.put_idx = 0
149
+ self.get_idx = 0
150
+ self.result_rank = []
151
+ self.result_data = []
152
+
153
+ for p in self.procs:
154
+ p.start()
155
+ atexit.register(self.shutdown)
156
+
157
+ def put(self, image):
158
+ self.put_idx += 1
159
+ self.task_queue.put((self.put_idx, image))
160
+
161
+ def get(self):
162
+ self.get_idx += 1 # the index needed for this request
163
+ if len(self.result_rank) and self.result_rank[0] == self.get_idx:
164
+ res = self.result_data[0]
165
+ del self.result_data[0], self.result_rank[0]
166
+ return res
167
+
168
+ while True:
169
+ # make sure the results are returned in the correct order
170
+ idx, res = self.result_queue.get()
171
+ if idx == self.get_idx:
172
+ return res
173
+ insert = bisect.bisect(self.result_rank, idx)
174
+ self.result_rank.insert(insert, idx)
175
+ self.result_data.insert(insert, res)
176
+
177
+ def __len__(self):
178
+ return self.put_idx - self.get_idx
179
+
180
+ def __call__(self, image):
181
+ self.put(image)
182
+ return self.get()
183
+
184
+ def shutdown(self):
185
+ for _ in self.procs:
186
+ self.task_queue.put(AsyncPredictor._StopToken())
187
+
188
+ @property
189
+ def default_buffer_size(self):
190
+ return len(self.procs) * 5
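`AsyncPredictor.get()` above uses `bisect` to hand results back in submission order even when workers finish out of order. The self-contained sketch below reproduces just that reordering bookkeeping with plain integers, so the idea can be checked without GPUs or worker processes:

```python
import bisect

def reorder(results_in_completion_order):
    """Yield (idx, result) pairs in submission order, buffering early arrivals,
    mirroring the result_rank/result_data bookkeeping in AsyncPredictor.get()."""
    next_idx = 1
    rank, data = [], []
    for idx, res in results_in_completion_order:
        if idx == next_idx:
            yield idx, res
            next_idx += 1
            # flush any buffered results that are now in order
            while rank and rank[0] == next_idx:
                yield rank.pop(0), data.pop(0)
                next_idx += 1
        else:
            pos = bisect.bisect(rank, idx)
            rank.insert(pos, idx)
            data.insert(pos, res)

print(list(reorder([(2, "b"), (1, "a"), (3, "c")])))
# [(1, 'a'), (2, 'b'), (3, 'c')]
```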
demo/visualizer.py ADDED
@@ -0,0 +1,1350 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import colorsys
3
+ import logging
4
+ import math
5
+ import numpy as np
6
+ from enum import Enum, unique
7
+ import cv2
8
+ import matplotlib as mpl
9
+ import matplotlib.colors as mplc
10
+ import matplotlib.figure as mplfigure
11
+ import pycocotools.mask as mask_util
12
+ import torch
13
+ from matplotlib.backends.backend_agg import FigureCanvasAgg
14
+ from PIL import Image
15
+
16
+ from detectron2.data import MetadataCatalog
17
+ from detectron2.structures import BitMasks, Boxes, BoxMode, Keypoints, PolygonMasks, RotatedBoxes
18
+ from detectron2.utils.file_io import PathManager
19
+ import random
20
+ random.seed(0)
21
+ from .colormap import random_color, _COLORS
22
+ logger = logging.getLogger(__name__)
23
+
24
+ __all__ = ["ColorMode", "VisImage", "Visualizer"]
25
+
26
+
27
+ _SMALL_OBJECT_AREA_THRESH = 1000
28
+ _LARGE_MASK_AREA_THRESH = 120000
29
+ _OFF_WHITE = (1.0, 1.0, 240.0 / 255)
30
+ _BLACK = (0, 0, 0)
31
+ _RED = (1.0, 0, 0)
32
+
33
+ _KEYPOINT_THRESHOLD = 0.05
34
+
35
+
36
+ def instance_color(rgb=False, idx=1, maximum=255):
37
+ """
38
+ Args:
39
+ rgb (bool): whether to return RGB colors or BGR colors.
40
+ maximum (int): either 255 or 1
41
+ Returns:
42
+ ndarray: a vector of 3 numbers
43
+ """
44
+ ret = _COLORS[idx] * maximum
45
+ if not rgb:
46
+ ret = ret[::-1]
47
+ return ret
48
+
49
+ @unique
50
+ class ColorMode(Enum):
51
+ """
52
+ Enum of different color modes to use for instance visualizations.
53
+ """
54
+
55
+ IMAGE = 0
56
+ """
57
+ Picks a random color for every instance and overlay segmentations with low opacity.
58
+ """
59
+ SEGMENTATION = 1
60
+ """
61
+ Let instances of the same category have similar colors
62
+ (from metadata.thing_colors), and overlay them with
63
+ high opacity. This provides more attention on the quality of segmentation.
64
+ """
65
+ IMAGE_BW = 2
66
+ """
67
+ Same as IMAGE, but convert all areas without masks to gray-scale.
68
+ Only available for drawing per-instance mask predictions.
69
+ """
70
+
71
+
72
+ class GenericMask:
73
+ """
74
+ Attribute:
75
+ polygons (list[ndarray]): list[ndarray]: polygons for this mask.
76
+ Each ndarray has format [x, y, x, y, ...]
77
+ mask (ndarray): a binary mask
78
+ """
79
+
80
+ def __init__(self, mask_or_polygons, height, width):
81
+ self._mask = self._polygons = self._has_holes = None
82
+ self.height = height
83
+ self.width = width
84
+
85
+ m = mask_or_polygons
86
+ if isinstance(m, dict):
87
+ # RLEs
88
+ assert "counts" in m and "size" in m
89
+ if isinstance(m["counts"], list): # uncompressed RLEs
90
+ h, w = m["size"]
91
+ assert h == height and w == width
92
+ m = mask_util.frPyObjects(m, h, w)
93
+ self._mask = mask_util.decode(m)[:, :]
94
+ return
95
+
96
+ if isinstance(m, list): # list[ndarray]
97
+ self._polygons = [np.asarray(x).reshape(-1) for x in m]
98
+ return
99
+
100
+ if isinstance(m, np.ndarray): # assumed to be a binary mask
101
+ assert m.shape[1] != 2, m.shape
102
+ assert m.shape == (
103
+ height,
104
+ width,
105
+ ), f"mask shape: {m.shape}, target dims: {height}, {width}"
106
+ self._mask = m.astype("uint8")
107
+ return
108
+
109
+ raise ValueError("GenericMask cannot handle object {} of type '{}'".format(m, type(m)))
110
+
111
+ @property
112
+ def mask(self):
113
+ if self._mask is None:
114
+ self._mask = self.polygons_to_mask(self._polygons)
115
+ return self._mask
116
+
117
+ @property
118
+ def polygons(self):
119
+ if self._polygons is None:
120
+ self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
121
+ return self._polygons
122
+
123
+ @property
124
+ def has_holes(self):
125
+ if self._has_holes is None:
126
+ if self._mask is not None:
127
+ self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
128
+ else:
129
+ self._has_holes = False # if original format is polygon, does not have holes
130
+ return self._has_holes
131
+
132
+ def mask_to_polygons(self, mask):
133
+ # cv2.RETR_CCOMP flag retrieves all the contours and arranges them to a 2-level
134
+ # hierarchy. External contours (boundary) of the object are placed in hierarchy-1.
135
+ # Internal contours (holes) are placed in hierarchy-2.
136
+ # cv2.CHAIN_APPROX_NONE flag gets vertices of polygons from contours.
137
+ mask = np.ascontiguousarray(mask) # some versions of cv2 does not support incontiguous arr
138
+ res = cv2.findContours(mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
139
+ hierarchy = res[-1]
140
+ if hierarchy is None: # empty mask
141
+ return [], False
142
+ has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0
143
+ res = res[-2]
144
+ res = [x.flatten() for x in res]
145
+ # These coordinates from OpenCV are integers in range [0, W-1 or H-1].
146
+ # We add 0.5 to turn them into real-value coordinate space. A better solution
147
+ # would be to first +0.5 and then dilate the returned polygon by 0.5.
148
+ res = [x + 0.5 for x in res if len(x) >= 6]
149
+ return res, has_holes
150
+
151
+ def polygons_to_mask(self, polygons):
152
+ rle = mask_util.frPyObjects(polygons, self.height, self.width)
153
+ rle = mask_util.merge(rle)
154
+ return mask_util.decode(rle)[:, :]
155
+
156
+ def area(self):
157
+ return self.mask.sum()
158
+
159
+ def bbox(self):
160
+ p = mask_util.frPyObjects(self.polygons, self.height, self.width)
161
+ p = mask_util.merge(p)
162
+ bbox = mask_util.toBbox(p)
163
+ bbox[2] += bbox[0]
164
+ bbox[3] += bbox[1]
165
+ return bbox
166
+
167
+
168
+ class _PanopticPrediction:
169
+ """
170
+ Unify different panoptic annotation/prediction formats
171
+ """
172
+
173
+ def __init__(self, panoptic_seg, segments_info, metadata=None):
174
+ if segments_info is None:
175
+ assert metadata is not None
176
+ # If "segments_info" is None, we assume "panoptic_img" is a
177
+ # H*W int32 image storing the panoptic_id in the format of
178
+ # category_id * label_divisor + instance_id. We reserve -1 for
179
+ # VOID label.
180
+ label_divisor = metadata.label_divisor
181
+ segments_info = []
182
+ for panoptic_label in np.unique(panoptic_seg.numpy()):
183
+ if panoptic_label == -1:
184
+ # VOID region.
185
+ continue
186
+ pred_class = panoptic_label // label_divisor
187
+ isthing = pred_class in metadata.thing_dataset_id_to_contiguous_id.values()
188
+ segments_info.append(
189
+ {
190
+ "id": int(panoptic_label),
191
+ "category_id": int(pred_class),
192
+ "isthing": bool(isthing),
193
+ }
194
+ )
195
+ del metadata
196
+
197
+ self._seg = panoptic_seg
198
+
199
+ self._sinfo = {s["id"]: s for s in segments_info} # seg id -> seg info
200
+ segment_ids, areas = torch.unique(panoptic_seg, sorted=True, return_counts=True)
201
+ areas = areas.numpy()
202
+ sorted_idxs = np.argsort(-areas)
203
+ self._seg_ids, self._seg_areas = segment_ids[sorted_idxs], areas[sorted_idxs]
204
+ self._seg_ids = self._seg_ids.tolist()
205
+ for sid, area in zip(self._seg_ids, self._seg_areas):
206
+ if sid in self._sinfo:
207
+ self._sinfo[sid]["area"] = float(area)
208
+
209
+ def non_empty_mask(self):
210
+ """
211
+ Returns:
212
+ (H, W) array, a mask for all pixels that have a prediction
213
+ """
214
+ empty_ids = []
215
+ for id in self._seg_ids:
216
+ if id not in self._sinfo:
217
+ empty_ids.append(id)
218
+ if len(empty_ids) == 0:
219
+ return np.zeros(self._seg.shape, dtype=np.uint8)
220
+ assert (
221
+ len(empty_ids) == 1
222
+ ), ">1 ids corresponds to no labels. This is currently not supported"
223
+ return (self._seg != empty_ids[0]).numpy().astype(np.bool)
224
+
225
+ def semantic_masks(self):
226
+ for sid in self._seg_ids:
227
+ sinfo = self._sinfo.get(sid)
228
+ if sinfo is None or sinfo["isthing"]:
229
+ # Some pixels (e.g. id 0 in PanopticFPN) have no instance or semantic predictions.
230
+ continue
231
+ yield (self._seg == sid).numpy().astype(np.bool), sinfo
232
+
233
+ def instance_masks(self):
234
+ for sid in self._seg_ids:
235
+ sinfo = self._sinfo.get(sid)
236
+ if sinfo is None or not sinfo["isthing"]:
237
+ continue
238
+ mask = (self._seg == sid).numpy().astype(np.bool)
239
+ if mask.sum() > 0:
240
+ yield mask, sinfo
241
+
242
+
243
+ def _create_text_labels(classes, scores, class_names, is_crowd=None):
244
+ """
245
+ Args:
246
+ classes (list[int] or None):
247
+ scores (list[float] or None):
248
+ class_names (list[str] or None):
249
+ is_crowd (list[bool] or None):
250
+ Returns:
251
+ list[str] or None
252
+ """
253
+ labels = None
254
+ if classes is not None:
255
+ if class_names is not None and len(class_names) > 0:
256
+ labels = [class_names[i] for i in classes]
257
+ else:
258
+ labels = [str(i) for i in classes]
259
+ if scores is not None:
260
+ if labels is None:
261
+ labels = ["{:.0f}%".format(s * 100) for s in scores]
262
+ else:
263
+ labels = ["{} {:.0f}%".format(l, s * 100) for l, s in zip(labels, scores)]
264
+ if labels is not None and is_crowd is not None:
265
+ labels = [l + ("|crowd" if crowd else "") for l, crowd in zip(labels, is_crowd)]
266
+ return labels
267
+
268
+
269
+ class VisImage:
270
+ def __init__(self, img, scale=1.0):
271
+ """
272
+ Args:
273
+ img (ndarray): an RGB image of shape (H, W, 3) in range [0, 255].
274
+ scale (float): scale the input image
275
+ """
276
+ self.img = img
277
+ self.scale = scale
278
+ self.width, self.height = img.shape[1], img.shape[0]
279
+ self._setup_figure(img)
280
+
281
+ def _setup_figure(self, img):
282
+ """
283
+ Args:
284
+ Same as in :meth:`__init__()`.
285
+ Returns:
286
+ fig (matplotlib.pyplot.figure): top level container for all the image plot elements.
287
+ ax (matplotlib.pyplot.Axes): contains figure elements and sets the coordinate system.
288
+ """
289
+ fig = mplfigure.Figure(frameon=False)
290
+ self.dpi = fig.get_dpi()
291
+ # add a small 1e-2 to avoid precision lost due to matplotlib's truncation
292
+ # (https://github.com/matplotlib/matplotlib/issues/15363)
293
+ fig.set_size_inches(
294
+ (self.width * self.scale + 1e-2) / self.dpi,
295
+ (self.height * self.scale + 1e-2) / self.dpi,
296
+ )
297
+ self.canvas = FigureCanvasAgg(fig)
298
+ # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig)
299
+ ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
300
+ ax.axis("off")
301
+ self.fig = fig
302
+ self.ax = ax
303
+ self.reset_image(img)
304
+
305
+ def reset_image(self, img):
306
+ """
307
+ Args:
308
+ img: same as in __init__
309
+ """
310
+ img = img.astype("uint8")
311
+ self.ax.imshow(img, extent=(0, self.width, self.height, 0), interpolation="nearest")
312
+
313
+ def save(self, filepath):
314
+ """
315
+ Args:
316
+ filepath (str): a string that contains the absolute path, including the file name, where
317
+ the visualized image will be saved.
318
+ """
319
+ self.fig.savefig(filepath)
320
+
321
+ def get_image(self):
322
+ """
323
+ Returns:
324
+ ndarray:
325
+ the visualized image of shape (H, W, 3) (RGB) in uint8 type.
326
+ The shape is scaled w.r.t the input image using the given `scale` argument.
327
+ """
328
+ canvas = self.canvas
329
+ s, (width, height) = canvas.print_to_buffer()
330
+ # buf = io.BytesIO() # works for cairo backend
331
+ # canvas.print_rgba(buf)
332
+ # width, height = self.width, self.height
333
+ # s = buf.getvalue()
334
+
335
+ buffer = np.frombuffer(s, dtype="uint8")
336
+
337
+ img_rgba = buffer.reshape(height, width, 4)
338
+ rgb, alpha = np.split(img_rgba, [3], axis=2)
339
+ return rgb.astype("uint8")
340
+
341
+
342
+ class Visualizer:
343
+ """
344
+ Visualizer that draws data about detection/segmentation on images.
345
+ It contains methods like `draw_{text,box,circle,line,binary_mask,polygon}`
346
+ that draw primitive objects to images, as well as high-level wrappers like
347
+ `draw_{instance_predictions,sem_seg,panoptic_seg_predictions,dataset_dict}`
348
+ that draw composite data in some pre-defined style.
349
+ Note that the exact visualization style for the high-level wrappers are subject to change.
350
+ Style such as color, opacity, label contents, visibility of labels, or even the visibility
351
+ of objects themselves (e.g. when the object is too small) may change according
352
+ to different heuristics, as long as the results still look visually reasonable.
353
+ To obtain a consistent style, you can implement custom drawing functions with the
354
+ abovementioned primitive methods instead. If you need more customized visualization
355
+ styles, you can process the data yourself following their format documented in
356
+ tutorials (:doc:`/tutorials/models`, :doc:`/tutorials/datasets`). This class does not
357
+ intend to satisfy everyone's preference on drawing styles.
358
+ This visualizer focuses on high rendering quality rather than performance. It is not
359
+ designed to be used for real-time applications.
360
+ """
361
+
362
+ # TODO implement a fast, rasterized version using OpenCV
363
+
364
+ def __init__(self, img_rgb, is_img=True, metadata=None, scale=1.0, instance_mode=ColorMode.IMAGE):
365
+ """
366
+ Args:
367
+ img_rgb: a numpy array of shape (H, W, C), where H and W correspond to
368
+ the height and width of the image respectively. C is the number of
369
+ color channels. The image is required to be in RGB format since that
370
+ is a requirement of the Matplotlib library. The image is also expected
371
+ to be in the range [0, 255].
372
+ metadata (Metadata): dataset metadata (e.g. class names and colors)
373
+ instance_mode (ColorMode): defines one of the pre-defined style for drawing
374
+ instances on an image.
375
+ """
376
+ if is_img:
377
+ self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)
378
+ else:
379
+ self.img = np.zeros_like(img_rgb).clip(0, 255).astype(np.uint8)
380
+ if metadata is None:
381
+ metadata = MetadataCatalog.get("__nonexist__")
382
+ self.metadata = metadata
383
+ self.output = VisImage(self.img, scale=scale)
384
+ self.cpu_device = torch.device("cpu")
385
+
386
+ # text that is too small is useless, so clamp the default font size to a readable minimum
387
+ self._default_font_size = max(
388
+ np.sqrt(self.output.height * self.output.width) // 90, 10 // scale
389
+ )
390
+ self._instance_mode = instance_mode
391
+ self.keypoint_threshold = _KEYPOINT_THRESHOLD
392
+
393
+ def get_image(self, img):
394
+ img = np.asarray(img).clip(0, 255).astype(np.uint8)
395
+ return VisImage(img, scale=1.0)
396
+
397
+ def draw_box_predictions(
398
+ self,
399
+ boxes=None,
400
+ labels=None,
401
+ scores=None,
402
+ assigned_colors=None
403
+ ):
404
+ """
405
+ Args:
406
+ boxes (Boxes, RotatedBoxes or ndarray): either a :class:`Boxes`,
407
+ or an Nx4 numpy array of XYXY_ABS format for the N objects in a single image,
408
+ or a :class:`RotatedBoxes`,
409
+ or an Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format
410
+ for the N objects in a single image,
411
+ labels (Tensor): class indices for each instance; the displayed text is derived from the metadata's stuff_classes and the scores.
412
+ assigned_colors (list[matplotlib.colors]): a list of colors, where each color
413
+ corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
414
+ for full list of formats that the colors are accepted in.
415
+ Returns:
416
+ output (VisImage): image object with visualizations.
417
+ """
418
+ num_instances = 0
419
+ boxes = self._convert_boxes(boxes)
420
+ classes = labels.tolist()
421
+ scores = scores.tolist()
422
+ labels = _create_text_labels(classes, scores, self.metadata.get("stuff_classes", None))
423
+ num_instances = len(boxes)
424
+ assert len(labels) == num_instances
425
+ if assigned_colors is None:
426
+ # assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]
427
+ assigned_colors = [instance_color(rgb=True, idx=i, maximum=1) for i in range(num_instances)]
428
+ if num_instances == 0:
429
+ return self.output
430
+
431
+ # Display in largest to smallest order to reduce occlusion.
432
+ areas = None
433
+ areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1)
434
+
435
+ if areas is not None:
436
+ sorted_idxs = np.argsort(-areas).tolist()
437
+ # Re-order overlapped instances in descending order.
438
+ boxes = boxes[sorted_idxs] if boxes is not None else None
439
+ labels = [labels[k] for k in sorted_idxs] if labels is not None else None
440
+ assigned_colors = [assigned_colors[idx] for idx in sorted_idxs]
441
+
442
+ for i in range(num_instances):
443
+ color = assigned_colors[i]
444
+ if boxes is not None:
445
+ self.draw_box(boxes[i], edge_color=color)
446
+
447
+ if labels is not None:
448
+ # first get a box
449
+ if boxes is not None:
450
+ x0, y0, x1, y1 = boxes[i]
451
+ text_pos = (x0, y0) # if drawing boxes, put text on the box corner.
452
+ horiz_align = "left"
453
+ else:
454
+ continue # drawing the box confidence for keypoints isn't very useful.
455
+ # for small objects, draw text at the side to avoid occlusion
456
+ instance_area = (y1 - y0) * (x1 - x0)
457
+ if (
458
+ instance_area < _SMALL_OBJECT_AREA_THRESH * self.output.scale
459
+ or y1 - y0 < 40 * self.output.scale
460
+ ):
461
+ if y1 >= self.output.height - 5:
462
+ text_pos = (x1, y0)
463
+ else:
464
+ text_pos = (x0, y1)
465
+
466
+ height_ratio = (y1 - y0) / np.sqrt(self.output.height * self.output.width)
467
+ lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
468
+ font_size = (
469
+ np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)
470
+ * 0.5
471
+ * self._default_font_size
472
+ )
473
+ self.draw_text(
474
+ labels[i],
475
+ text_pos,
476
+ color=lighter_color,
477
+ horizontal_alignment=horiz_align,
478
+ font_size=font_size,
479
+ )
480
+
481
+ return self.output
482
+
483
+
484
+ def draw_instance_predictions(self, predictions, alpha=0.8, is_text=True):
485
+ """
486
+ Draw instance-level prediction results on an image.
487
+ Args:
488
+ predictions (Instances): the output of an instance detection/segmentation
489
+ model. Following fields will be used to draw:
490
+ "pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle").
491
+ Returns:
492
+ output (VisImage): image object with visualizations.
493
+ """
494
+ boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None
495
+ scores = predictions.scores if predictions.has("scores") else None
496
+ classes = predictions.pred_classes.tolist() if predictions.has("pred_classes") else None
497
+ labels = _create_text_labels(classes, scores, self.metadata.get("stuff_classes", None))
498
+ keypoints = predictions.pred_keypoints if predictions.has("pred_keypoints") else None
499
+
500
+ if predictions.has("pred_masks"):
501
+ masks = np.asarray(predictions.pred_masks)
502
+ masks = [GenericMask(x, self.output.height, self.output.width) for x in masks]
503
+ else:
504
+ masks = None
505
+
506
+ if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("stuff_colors"):
507
+ # colors = [
508
+ # self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in classes
509
+ # ]
510
+ colors = [
511
+ instance_color(rgb=True, idx=c, maximum=1) for c in classes
512
+ ]
513
+ else:
514
+ colors = None
515
+
516
+ if self._instance_mode == ColorMode.IMAGE_BW:
517
+ self.output.reset_image(
518
+ self._create_grayscale_image(
519
+ (predictions.pred_masks.any(dim=0) > 0).numpy()
520
+ if predictions.has("pred_masks")
521
+ else None
522
+ )
523
+ )
524
+
525
+ self.overlay_instances(
526
+ masks=masks,
527
+ boxes=boxes,
528
+ labels=labels,
529
+ keypoints=keypoints,
530
+ assigned_colors=colors,
531
+ alpha=alpha,
532
+ is_text=is_text,
533
+ )
534
+ return self.output
535
+
536
+ def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.8, is_text=True):
537
+ """
538
+ Draw semantic segmentation predictions/labels.
539
+ Args:
540
+ sem_seg (Tensor or ndarray): the segmentation of shape (H, W).
541
+ Each value is the integer label of the pixel.
542
+ area_threshold (int): segments with less than `area_threshold` are not drawn.
543
+ alpha (float): the larger it is, the more opaque the segmentations are.
544
+ Returns:
545
+ output (VisImage): image object with visualizations.
546
+ """
547
+ if isinstance(sem_seg, torch.Tensor):
548
+ sem_seg = sem_seg.numpy()
549
+ labels, areas = np.unique(sem_seg, return_counts=True)
550
+ sorted_idxs = np.argsort(-areas).tolist()
551
+ labels = labels[sorted_idxs]
552
+ for label in filter(lambda l: l < len(self.metadata.stuff_classes), labels):
553
+ try:
554
+ mask_color = [x / 255 for x in self.metadata.stuff_colors[label]]
555
+ except (AttributeError, IndexError):
556
+ mask_color = None
557
+
558
+ binary_mask = (sem_seg == label).astype(np.uint8)
559
+ text = self.metadata.stuff_classes[label]
560
+ self.draw_binary_mask(
561
+ binary_mask,
562
+ color=mask_color,
563
+ edge_color=_OFF_WHITE,
564
+ text=text,
565
+ alpha=alpha,
566
+ area_threshold=area_threshold,
567
+ is_text=is_text,
568
+ )
569
+ return self.output
570
+
571
+ def draw_panoptic_seg(self, panoptic_seg, segments_info, area_threshold=None, alpha=0.7, is_text=True,):
572
+ """
573
+ Draw panoptic prediction annotations or results.
574
+ Args:
575
+ panoptic_seg (Tensor): of shape (height, width) where the values are ids for each
576
+ segment.
577
+ segments_info (list[dict] or None): Describe each segment in `panoptic_seg`.
578
+ If it is a ``list[dict]``, each dict contains keys "id", "category_id".
579
+ If None, category id of each pixel is computed by
580
+ ``pixel // metadata.label_divisor``.
581
+ area_threshold (int): stuff segments with less than `area_threshold` are not drawn.
582
+ Returns:
583
+ output (VisImage): image object with visualizations.
584
+ """
585
+ pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata)
586
+
587
+ if self._instance_mode == ColorMode.IMAGE_BW:
588
+ self.output.reset_image(self._create_grayscale_image(pred.non_empty_mask()))
589
+
590
+ # draw mask for all semantic segments first i.e. "stuff"
591
+ for mask, sinfo in pred.semantic_masks():
592
+ category_idx = sinfo["category_id"]
593
+ try:
594
+ mask_color = [x / 255 for x in self.metadata.stuff_colors[category_idx]]
595
+ except AttributeError:
596
+ mask_color = None
597
+
598
+ text = self.metadata.stuff_classes[category_idx]
599
+ self.draw_binary_mask(
600
+ mask,
601
+ color=mask_color,
602
+ edge_color=_OFF_WHITE,
603
+ text=text,
604
+ alpha=alpha,
605
+ area_threshold=area_threshold,
606
+ is_text=is_text,
607
+ )
608
+
609
+ # draw mask for all instances second
610
+ all_instances = list(pred.instance_masks())
611
+ if len(all_instances) == 0:
612
+ return self.output
613
+ masks, sinfo = list(zip(*all_instances))
614
+ category_ids = [x["category_id"] for x in sinfo]
615
+
616
+ try:
617
+ scores = [x["score"] for x in sinfo]
618
+ except KeyError:
619
+ scores = None
620
+ labels = _create_text_labels(
621
+ category_ids, scores, self.metadata.stuff_classes, [x.get("iscrowd", 0) for x in sinfo]
622
+ )
623
+
624
+ try:
625
+ colors = [
626
+ self._jitter([x / 255 for x in self.metadata.stuff_colors[c]]) for c in category_ids
627
+ ]
628
+ except AttributeError:
629
+ colors = None
630
+ self.overlay_instances(masks=masks, labels=labels, assigned_colors=colors, alpha=alpha, is_text=is_text)
631
+
632
+ return self.output
633
+
634
+ draw_panoptic_seg_predictions = draw_panoptic_seg # backward compatibility
635
+
636
+ def draw_dataset_dict(self, dic):
637
+ """
638
+ Draw annotations/segmentations in Detectron2 Dataset format.
639
+ Args:
640
+ dic (dict): annotation/segmentation data of one image, in Detectron2 Dataset format.
641
+ Returns:
642
+ output (VisImage): image object with visualizations.
643
+ """
644
+ annos = dic.get("annotations", None)
645
+ if annos:
646
+ if "segmentation" in annos[0]:
647
+ masks = [x["segmentation"] for x in annos]
648
+ else:
649
+ masks = None
650
+ if "keypoints" in annos[0]:
651
+ keypts = [x["keypoints"] for x in annos]
652
+ keypts = np.array(keypts).reshape(len(annos), -1, 3)
653
+ else:
654
+ keypts = None
655
+
656
+ boxes = [
657
+ BoxMode.convert(x["bbox"], x["bbox_mode"], BoxMode.XYXY_ABS)
658
+ if len(x["bbox"]) == 4
659
+ else x["bbox"]
660
+ for x in annos
661
+ ]
662
+
663
+ colors = None
664
+ category_ids = [x["category_id"] for x in annos]
665
+ if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("stuff_colors"):
666
+ colors = [
667
+ self._jitter([x / 255 for x in self.metadata.stuff_colors[c]])
668
+ for c in category_ids
669
+ ]
670
+ names = self.metadata.get("stuff_classes", None)
671
+ labels = _create_text_labels(
672
+ category_ids,
673
+ scores=None,
674
+ class_names=names,
675
+ is_crowd=[x.get("iscrowd", 0) for x in annos],
676
+ )
677
+ self.overlay_instances(
678
+ labels=labels, boxes=boxes, masks=masks, keypoints=keypts, assigned_colors=colors
679
+ )
680
+
681
+ sem_seg = dic.get("sem_seg", None)
682
+ if sem_seg is None and "sem_seg_file_name" in dic:
683
+ with PathManager.open(dic["sem_seg_file_name"], "rb") as f:
684
+ sem_seg = Image.open(f)
685
+ sem_seg = np.asarray(sem_seg, dtype="uint8")
686
+ if sem_seg is not None:
687
+ self.draw_sem_seg(sem_seg, area_threshold=0, alpha=0.5)
688
+
689
+ pan_seg = dic.get("pan_seg", None)
690
+ if pan_seg is None and "pan_seg_file_name" in dic:
691
+ with PathManager.open(dic["pan_seg_file_name"], "rb") as f:
692
+ pan_seg = Image.open(f)
693
+ pan_seg = np.asarray(pan_seg)
694
+ from panopticapi.utils import rgb2id
695
+
696
+ pan_seg = rgb2id(pan_seg)
697
+ if pan_seg is not None:
698
+ segments_info = dic["segments_info"]
699
+ pan_seg = torch.tensor(pan_seg)
700
+ self.draw_panoptic_seg(pan_seg, segments_info, area_threshold=0, alpha=0.5)
701
+ return self.output
702
+
703
+ def overlay_instances(
704
+ self,
705
+ *,
706
+ boxes=None,
707
+ labels=None,
708
+ masks=None,
709
+ keypoints=None,
710
+ assigned_colors=None,
711
+ alpha=0.5,
712
+ is_text=True,
713
+ ):
714
+ """
715
+ Args:
716
+ boxes (Boxes, RotatedBoxes or ndarray): either a :class:`Boxes`,
717
+ or an Nx4 numpy array of XYXY_ABS format for the N objects in a single image,
718
+ or a :class:`RotatedBoxes`,
719
+ or an Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format
720
+ for the N objects in a single image,
721
+ labels (list[str]): the text to be displayed for each instance.
722
+ masks (masks-like object): Supported types are:
723
+ * :class:`detectron2.structures.PolygonMasks`,
724
+ :class:`detectron2.structures.BitMasks`.
725
+ * list[list[ndarray]]: contains the segmentation masks for all objects in one image.
726
+ The first level of the list corresponds to individual instances. The second
727
+ level to all the polygons that compose the instance, and the third level
728
+ to the polygon coordinates. The third level should have the format of
729
+ [x0, y0, x1, y1, ..., xn, yn] (n >= 3).
730
+ * list[ndarray]: each ndarray is a binary mask of shape (H, W).
731
+ * list[dict]: each dict is a COCO-style RLE.
732
+ keypoints (Keypoint or array like): an array-like object of shape (N, K, 3),
733
+ where the N is the number of instances and K is the number of keypoints.
734
+ The last dimension corresponds to (x, y, visibility or score).
735
+ assigned_colors (list[matplotlib.colors]): a list of colors, where each color
736
+ corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
737
+ for full list of formats that the colors are accepted in.
738
+ Returns:
739
+ output (VisImage): image object with visualizations.
740
+ """
741
+ num_instances = 0
742
+ if boxes is not None:
743
+ boxes = self._convert_boxes(boxes)
744
+ num_instances = len(boxes)
745
+ if masks is not None:
746
+ masks = self._convert_masks(masks)
747
+ if num_instances:
748
+ assert len(masks) == num_instances
749
+ else:
750
+ num_instances = len(masks)
751
+ if keypoints is not None:
752
+ if num_instances:
753
+ assert len(keypoints) == num_instances
754
+ else:
755
+ num_instances = len(keypoints)
756
+ keypoints = self._convert_keypoints(keypoints)
757
+ if labels is not None:
758
+ assert len(labels) == num_instances
759
+ if assigned_colors is None:
760
+ # assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]
761
+ assigned_colors = [instance_color(rgb=True, idx=i, maximum=1) for i in range(num_instances)]
762
+ if num_instances == 0:
763
+ return self.output
764
+ if boxes is not None and boxes.shape[1] == 5:
765
+ return self.overlay_rotated_instances(
766
+ boxes=boxes, labels=labels, assigned_colors=assigned_colors
767
+ )
768
+
769
+ # Display in largest to smallest order to reduce occlusion.
770
+ areas = None
771
+ if boxes is not None:
772
+ areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1)
773
+ elif masks is not None:
774
+ areas = np.asarray([x.area() for x in masks])
775
+
776
+ if areas is not None:
777
+ sorted_idxs = np.argsort(-areas).tolist()
778
+ # Re-order overlapped instances in descending order.
779
+ boxes = boxes[sorted_idxs] if boxes is not None else None
780
+ labels = [labels[k] for k in sorted_idxs] if labels is not None else None
781
+ masks = [masks[idx] for idx in sorted_idxs] if masks is not None else None
782
+ assigned_colors = [assigned_colors[idx] for idx in sorted_idxs]
783
+ keypoints = keypoints[sorted_idxs] if keypoints is not None else None
784
+
785
+ for i in range(num_instances):
786
+ color = assigned_colors[i]
787
+ if boxes is not None:
788
+ self.draw_box(boxes[i], edge_color=color)
789
+
790
+ if masks is not None:
791
+ for segment in masks[i].polygons:
792
+ self.draw_polygon(segment.reshape(-1, 2), color, alpha=alpha)
793
+
794
+ if labels is not None:
795
+ # first get a box
796
+ if boxes is not None:
797
+ x0, y0, x1, y1 = boxes[i]
798
+ text_pos = (x0, y0) # if drawing boxes, put text on the box corner.
799
+ horiz_align = "left"
800
+ elif masks is not None:
801
+ # skip small mask without polygon
802
+ if len(masks[i].polygons) == 0:
803
+ continue
804
+
805
+ x0, y0, x1, y1 = masks[i].bbox()
806
+
807
+ # draw text in the center (defined by median) when box is not drawn
808
+ # median is less sensitive to outliers.
809
+ text_pos = np.median(masks[i].mask.nonzero(), axis=1)[::-1]
810
+ horiz_align = "center"
811
+ else:
812
+ continue # drawing the box confidence for keypoints isn't very useful.
813
+ # for small objects, draw text at the side to avoid occlusion
814
+ instance_area = (y1 - y0) * (x1 - x0)
815
+ if (
816
+ instance_area < _SMALL_OBJECT_AREA_THRESH * self.output.scale
817
+ or y1 - y0 < 40 * self.output.scale
818
+ ):
819
+ if y1 >= self.output.height - 5:
820
+ text_pos = (x1, y0)
821
+ else:
822
+ text_pos = (x0, y1)
823
+
824
+ height_ratio = (y1 - y0) / np.sqrt(self.output.height * self.output.width)
825
+ lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
826
+ font_size = (
827
+ np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)
828
+ * 0.5
829
+ * self._default_font_size
830
+ )
831
+ if is_text:
832
+ self.draw_text(
833
+ labels[i],
834
+ text_pos,
835
+ color=lighter_color,
836
+ horizontal_alignment=horiz_align,
837
+ font_size=font_size,
838
+ )
839
+
840
+ # draw keypoints
841
+ if keypoints is not None:
842
+ for keypoints_per_instance in keypoints:
843
+ self.draw_and_connect_keypoints(keypoints_per_instance)
844
+
845
+ return self.output
846
+
847
+ def overlay_rotated_instances(self, boxes=None, labels=None, assigned_colors=None):
848
+ """
849
+ Args:
850
+ boxes (ndarray): an Nx5 numpy array of
851
+ (x_center, y_center, width, height, angle_degrees) format
852
+ for the N objects in a single image.
853
+ labels (list[str]): the text to be displayed for each instance.
854
+ assigned_colors (list[matplotlib.colors]): a list of colors, where each color
855
+ corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
856
+ for full list of formats that the colors are accepted in.
857
+ Returns:
858
+ output (VisImage): image object with visualizations.
859
+ """
860
+ num_instances = len(boxes)
861
+
862
+ if assigned_colors is None:
863
+ # assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]
864
+ assigned_colors = [instance_color(rgb=True, idx=i, maximum=1) for i in range(num_instances)]
865
+ if num_instances == 0:
866
+ return self.output
867
+
868
+ # Display in largest to smallest order to reduce occlusion.
869
+ if boxes is not None:
870
+ areas = boxes[:, 2] * boxes[:, 3]
871
+
872
+ sorted_idxs = np.argsort(-areas).tolist()
873
+ # Re-order overlapped instances in descending order.
874
+ boxes = boxes[sorted_idxs]
875
+ labels = [labels[k] for k in sorted_idxs] if labels is not None else None
876
+ colors = [assigned_colors[idx] for idx in sorted_idxs]
877
+
878
+ for i in range(num_instances):
879
+ self.draw_rotated_box_with_label(
880
+ boxes[i], edge_color=colors[i], label=labels[i] if labels is not None else None
881
+ )
882
+
883
+ return self.output
884
+
885
+ def draw_and_connect_keypoints(self, keypoints):
886
+ """
887
+ Draws keypoints of an instance and follows the rules for keypoint connections
888
+ to draw lines between appropriate keypoints. This follows color heuristics for
889
+ line color.
890
+ Args:
891
+ keypoints (Tensor): a tensor of shape (K, 3), where K is the number of keypoints
892
+ and the last dimension corresponds to (x, y, probability).
893
+ Returns:
894
+ output (VisImage): image object with visualizations.
895
+ """
896
+ visible = {}
897
+ keypoint_names = self.metadata.get("keypoint_names")
898
+ for idx, keypoint in enumerate(keypoints):
899
+
900
+ # draw keypoint
901
+ x, y, prob = keypoint
902
+ if prob > self.keypoint_threshold:
903
+ self.draw_circle((x, y), color=_RED)
904
+ if keypoint_names:
905
+ keypoint_name = keypoint_names[idx]
906
+ visible[keypoint_name] = (x, y)
907
+
908
+ if self.metadata.get("keypoint_connection_rules"):
909
+ for kp0, kp1, color in self.metadata.keypoint_connection_rules:
910
+ if kp0 in visible and kp1 in visible:
911
+ x0, y0 = visible[kp0]
912
+ x1, y1 = visible[kp1]
913
+ color = tuple(x / 255.0 for x in color)
914
+ self.draw_line([x0, x1], [y0, y1], color=color)
915
+
916
+ # draw lines from nose to mid-shoulder and mid-shoulder to mid-hip
917
+ # Note that this strategy is specific to person keypoints.
918
+ # For other keypoints, it should just do nothing
919
+ try:
920
+ ls_x, ls_y = visible["left_shoulder"]
921
+ rs_x, rs_y = visible["right_shoulder"]
922
+ mid_shoulder_x, mid_shoulder_y = (ls_x + rs_x) / 2, (ls_y + rs_y) / 2
923
+ except KeyError:
924
+ pass
925
+ else:
926
+ # draw line from nose to mid-shoulder
927
+ nose_x, nose_y = visible.get("nose", (None, None))
928
+ if nose_x is not None:
929
+ self.draw_line([nose_x, mid_shoulder_x], [nose_y, mid_shoulder_y], color=_RED)
930
+
931
+ try:
932
+ # draw line from mid-shoulder to mid-hip
933
+ lh_x, lh_y = visible["left_hip"]
934
+ rh_x, rh_y = visible["right_hip"]
935
+ except KeyError:
936
+ pass
937
+ else:
938
+ mid_hip_x, mid_hip_y = (lh_x + rh_x) / 2, (lh_y + rh_y) / 2
939
+ self.draw_line([mid_hip_x, mid_shoulder_x], [mid_hip_y, mid_shoulder_y], color=_RED)
940
+ return self.output
941
+
942
+ """
943
+ Primitive drawing functions:
944
+ """
945
+
946
+ def draw_text(
947
+ self,
948
+ text,
949
+ position,
950
+ *,
951
+ font_size=None,
952
+ color="g",
953
+ horizontal_alignment="center",
954
+ rotation=0,
955
+ ):
956
+ """
957
+ Args:
958
+ text (str): class label
959
+ position (tuple): a tuple of the x and y coordinates to place text on image.
960
+ font_size (int, optional): font size of the text. If not provided, a font size
961
+ proportional to the image width is calculated and used.
962
+ color: color of the text. Refer to `matplotlib.colors` for full list
963
+ of formats that are accepted.
964
+ horizontal_alignment (str): see `matplotlib.text.Text`
965
+ rotation: rotation angle in degrees CCW
966
+ Returns:
967
+ output (VisImage): image object with text drawn.
968
+ """
969
+ if not font_size:
970
+ font_size = self._default_font_size
971
+
972
+ # since the text background is dark, we don't want the text to be dark
973
+ color = np.maximum(list(mplc.to_rgb(color)), 0.2)
974
+ color[np.argmax(color)] = max(0.8, np.max(color))
975
+
976
+ x, y = position
977
+ self.output.ax.text(
978
+ x,
979
+ y,
980
+ text,
981
+ size=font_size * self.output.scale,
982
+ family="sans-serif",
983
+ bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"},
984
+ verticalalignment="top",
985
+ horizontalalignment=horizontal_alignment,
986
+ color=color,
987
+ zorder=10,
988
+ rotation=rotation,
989
+ )
990
+ return self.output
991
+
992
+ def draw_box(self, box_coord, alpha=1.0, edge_color="g", line_style="-"):
993
+ """
994
+ Args:
995
+ box_coord (tuple): a tuple containing x0, y0, x1, y1 coordinates, where x0 and y0
996
+ are the coordinates of the image's top left corner. x1 and y1 are the
997
+ coordinates of the image's bottom right corner.
998
+ alpha (float): blending coefficient. Smaller values lead to more transparent masks.
999
+ edge_color: color of the outline of the box. Refer to `matplotlib.colors`
1000
+ for full list of formats that are accepted.
1001
+ line_style (string): the string to use to create the outline of the boxes.
1002
+ Returns:
1003
+ output (VisImage): image object with box drawn.
1004
+ """
1005
+ x0, y0, x1, y1 = box_coord
1006
+ width = x1 - x0
1007
+ height = y1 - y0
1008
+
1009
+ linewidth = 2
1010
+
1011
+ self.output.ax.add_patch(
1012
+ mpl.patches.Rectangle(
1013
+ (x0, y0),
1014
+ width,
1015
+ height,
1016
+ fill=False,
1017
+ edgecolor=edge_color,
1018
+ linewidth=linewidth * self.output.scale,
1019
+ alpha=alpha,
1020
+ linestyle=line_style,
1021
+ )
1022
+ )
1023
+ return self.output
1024
+
1025
+ def draw_rotated_box_with_label(
1026
+ self, rotated_box, alpha=0.5, edge_color="g", line_style="-", label=None
1027
+ ):
1028
+ """
1029
+ Draw a rotated box with label on its top-left corner.
1030
+ Args:
1031
+ rotated_box (tuple): a tuple containing (cnt_x, cnt_y, w, h, angle),
1032
+ where cnt_x and cnt_y are the center coordinates of the box.
1033
+ w and h are the width and height of the box. angle represents how
1034
+ many degrees the box is rotated CCW with regard to the 0-degree box.
1035
+ alpha (float): blending coefficient. Smaller values lead to more transparent masks.
1036
+ edge_color: color of the outline of the box. Refer to `matplotlib.colors`
1037
+ for full list of formats that are accepted.
1038
+ line_style (string): the string to use to create the outline of the boxes.
1039
+ label (string): label for rotated box. It will not be rendered when set to None.
1040
+ Returns:
1041
+ output (VisImage): image object with box drawn.
1042
+ """
1043
+ cnt_x, cnt_y, w, h, angle = rotated_box
1044
+ area = w * h
1045
+ # use thinner lines when the box is small
1046
+ linewidth = self._default_font_size / (
1047
+ 6 if area < _SMALL_OBJECT_AREA_THRESH * self.output.scale else 3
1048
+ )
1049
+
1050
+ theta = angle * math.pi / 180.0
1051
+ c = math.cos(theta)
1052
+ s = math.sin(theta)
1053
+ rect = [(-w / 2, h / 2), (-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2)]
1054
+ # x: left->right ; y: top->down
1055
+ rotated_rect = [(s * yy + c * xx + cnt_x, c * yy - s * xx + cnt_y) for (xx, yy) in rect]
1056
+ for k in range(4):
1057
+ j = (k + 1) % 4
1058
+ self.draw_line(
1059
+ [rotated_rect[k][0], rotated_rect[j][0]],
1060
+ [rotated_rect[k][1], rotated_rect[j][1]],
1061
+ color=edge_color,
1062
+ linestyle="--" if k == 1 else line_style,
1063
+ linewidth=linewidth,
1064
+ )
1065
+
1066
+ if label is not None:
1067
+ text_pos = rotated_rect[1] # topleft corner
1068
+
1069
+ height_ratio = h / np.sqrt(self.output.height * self.output.width)
1070
+ label_color = self._change_color_brightness(edge_color, brightness_factor=0.7)
1071
+ font_size = (
1072
+ np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) * 0.5 * self._default_font_size
1073
+ )
1074
+ self.draw_text(label, text_pos, color=label_color, font_size=font_size, rotation=angle)
1075
+
1076
+ return self.output
1077
+
1078
+ def draw_circle(self, circle_coord, color, radius=3):
1079
+ """
1080
+ Args:
1081
+ circle_coord (list(int) or tuple(int)): contains the x and y coordinates
1082
+ of the center of the circle.
1083
+ color: color of the polygon. Refer to `matplotlib.colors` for a full list of
1084
+ formats that are accepted.
1085
+ radius (int): radius of the circle.
1086
+ Returns:
1087
+ output (VisImage): image object with box drawn.
1088
+ """
1089
+ x, y = circle_coord
1090
+ self.output.ax.add_patch(
1091
+ mpl.patches.Circle(circle_coord, radius=radius, fill=True, color=color)
1092
+ )
1093
+ return self.output
1094
+
1095
+ def draw_line(self, x_data, y_data, color, linestyle="-", linewidth=None):
1096
+ """
1097
+ Args:
1098
+ x_data (list[int]): a list containing x values of all the points being drawn.
1099
+ Length of list should match the length of y_data.
1100
+ y_data (list[int]): a list containing y values of all the points being drawn.
1101
+ Length of list should match the length of x_data.
1102
+ color: color of the line. Refer to `matplotlib.colors` for a full list of
1103
+ formats that are accepted.
1104
+ linestyle: style of the line. Refer to `matplotlib.lines.Line2D`
1105
+ for a full list of formats that are accepted.
1106
+ linewidth (float or None): width of the line. When it's None,
1107
+ a default value will be computed and used.
1108
+ Returns:
1109
+ output (VisImage): image object with line drawn.
1110
+ """
1111
+ if linewidth is None:
1112
+ linewidth = self._default_font_size / 3
1113
+ linewidth = max(linewidth, 1)
1114
+ self.output.ax.add_line(
1115
+ mpl.lines.Line2D(
1116
+ x_data,
1117
+ y_data,
1118
+ linewidth=linewidth * self.output.scale,
1119
+ color=color,
1120
+ linestyle=linestyle,
1121
+ )
1122
+ )
1123
+ return self.output
1124
+
1125
+ def draw_binary_mask(
1126
+ self, binary_mask, color=None, *, edge_color=None, text=None, alpha=0.5, area_threshold=10, is_text=True,
1127
+ ):
1128
+ """
1129
+ Args:
1130
+ binary_mask (ndarray): numpy array of shape (H, W), where H is the image height and
1131
+ W is the image width. Each value in the array is either a 0 or 1 value of uint8
1132
+ type.
1133
+ color: color of the mask. Refer to `matplotlib.colors` for a full list of
1134
+ formats that are accepted. If None, will pick a random color.
1135
+ edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
1136
+ full list of formats that are accepted.
1137
+ text (str): if not None, will be drawn on the object
1138
+ alpha (float): blending coefficient. Smaller values lead to more transparent masks.
1139
+ area_threshold (float): a connected component smaller than this area will not be shown.
1140
+ Returns:
1141
+ output (VisImage): image object with mask drawn.
1142
+ """
1143
+ if color is None:
1144
+ color = random_color(rgb=True, maximum=1)
1145
+ color = mplc.to_rgb(color)
1146
+
1147
+ has_valid_segment = False
1148
+ binary_mask = binary_mask.astype("uint8") # opencv needs uint8
1149
+ mask = GenericMask(binary_mask, self.output.height, self.output.width)
1150
+ shape2d = (binary_mask.shape[0], binary_mask.shape[1])
1151
+
1152
+ if not mask.has_holes:
1153
+ # draw polygons for regular masks
1154
+ for segment in mask.polygons:
1155
+ area = mask_util.area(mask_util.frPyObjects([segment], shape2d[0], shape2d[1]))
1156
+ if area < (area_threshold or 0):
1157
+ continue
1158
+ has_valid_segment = True
1159
+ segment = segment.reshape(-1, 2)
1160
+ self.draw_polygon(segment, color=color, edge_color=edge_color, alpha=alpha)
1161
+ else:
1162
+ # TODO: Use Path/PathPatch to draw vector graphics:
1163
+ # https://stackoverflow.com/questions/8919719/how-to-plot-a-complex-polygon
1164
+ rgba = np.zeros(shape2d + (4,), dtype="float32")
1165
+ rgba[:, :, :3] = color
1166
+ rgba[:, :, 3] = (mask.mask == 1).astype("float32") * alpha
1167
+ has_valid_segment = True
1168
+ self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0))
1169
+
1170
+ if is_text:
1171
+ if text is not None and has_valid_segment:
1172
+ lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
1173
+ self._draw_text_in_mask(binary_mask, text, lighter_color)
1174
+ return self.output
1175
+
1176
+ def draw_soft_mask(self, soft_mask, color=None, *, text=None, alpha=0.5):
1177
+ """
1178
+ Args:
1179
+ soft_mask (ndarray): float array of shape (H, W), each value in [0, 1].
1180
+ color: color of the mask. Refer to `matplotlib.colors` for a full list of
1181
+ formats that are accepted. If None, will pick a random color.
1182
+ text (str): if not None, will be drawn on the object
1183
+ alpha (float): blending coefficient. Smaller values lead to more transparent masks.
1184
+ Returns:
1185
+ output (VisImage): image object with mask drawn.
1186
+ """
1187
+ if color is None:
1188
+ color = random_color(rgb=True, maximum=1)
1189
+ color = mplc.to_rgb(color)
1190
+
1191
+ shape2d = (soft_mask.shape[0], soft_mask.shape[1])
1192
+ rgba = np.zeros(shape2d + (4,), dtype="float32")
1193
+ rgba[:, :, :3] = color
1194
+ rgba[:, :, 3] = soft_mask * alpha
1195
+ self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0))
1196
+
1197
+ if text is not None:
1198
+ lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
1199
+ binary_mask = (soft_mask > 0.5).astype("uint8")
1200
+ # self._draw_text_in_mask(binary_mask, text, lighter_color)
1201
+ return self.output
1202
+
1203
+ def draw_polygon(self, segment, color, edge_color=None, alpha=0.5):
1204
+ """
1205
+ Args:
1206
+ segment: numpy array of shape Nx2, containing all the points in the polygon.
1207
+ color: color of the polygon. Refer to `matplotlib.colors` for a full list of
1208
+ formats that are accepted.
1209
+ edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
1210
+ full list of formats that are accepted. If not provided, a darker shade
1211
+ of the polygon color will be used instead.
1212
+ alpha (float): blending coefficient. Smaller values lead to more transparent masks.
1213
+ Returns:
1214
+ output (VisImage): image object with polygon drawn.
1215
+ """
1216
+ if edge_color is None:
1217
+ # make edge color darker than the polygon color
1218
+ if alpha > 0.8:
1219
+ edge_color = self._change_color_brightness(color, brightness_factor=-0.7)
1220
+ else:
1221
+ edge_color = color
1222
+ edge_color = mplc.to_rgb(edge_color) + (1,)
1223
+
1224
+ polygon = mpl.patches.Polygon(
1225
+ segment,
1226
+ fill=True,
1227
+ facecolor=mplc.to_rgb(color) + (alpha,),
1228
+ edgecolor=edge_color,
1229
+ linewidth=max(self._default_font_size // 15 * self.output.scale, 1),
1230
+ )
1231
+ self.output.ax.add_patch(polygon)
1232
+ return self.output
1233
+
1234
+ """
1235
+ Internal methods:
1236
+ """
1237
+
1238
+ def _jitter(self, color):
1239
+ """
1240
+ Randomly modifies given color to produce a slightly different color than the color given.
1241
+ Args:
1242
+ color (tuple[double]): a tuple of 3 elements, containing the RGB values of the color
1243
+ picked. The values in the list are in the [0.0, 1.0] range.
1244
+ Returns:
1245
+ jittered_color (tuple[double]): a tuple of 3 elements, containing the RGB values of the
1246
+ color after being jittered. The values in the list are in the [0.0, 1.0] range.
1247
+ """
1248
+ color = mplc.to_rgb(color)
1249
+ vec = np.random.rand(3)
1250
+ # better to do it in another color space
1251
+ vec = vec / np.linalg.norm(vec) * 0.5
1252
+ res = np.clip(vec + color, 0, 1)
1253
+ return tuple(res)
1254
+
1255
+ def _create_grayscale_image(self, mask=None):
1256
+ """
1257
+ Create a grayscale version of the original image.
1258
+ The colors in masked area, if given, will be kept.
1259
+ """
1260
+ img_bw = self.img.astype("f4").mean(axis=2)
1261
+ img_bw = np.stack([img_bw] * 3, axis=2)
1262
+ if mask is not None:
1263
+ img_bw[mask] = self.img[mask]
1264
+ return img_bw
1265
+
1266
+ def _change_color_brightness(self, color, brightness_factor):
1267
+ """
1268
+ Depending on the brightness_factor, gives a lighter or darker color, i.e. a color with
1269
+ lower or higher lightness than the original color.
1270
+ Args:
1271
+ color: color of the polygon. Refer to `matplotlib.colors` for a full list of
1272
+ formats that are accepted.
1273
+ brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of
1274
+ 0 will correspond to no change, a factor in [-1.0, 0) range will result in
1275
+ a darker color and a factor in (0, 1.0] range will result in a lighter color.
1276
+ Returns:
1277
+ modified_color (tuple[double]): a tuple containing the RGB values of the
1278
+ modified color. Each value in the tuple is in the [0.0, 1.0] range.
1279
+ """
1280
+ assert brightness_factor >= -1.0 and brightness_factor <= 1.0
1281
+ color = mplc.to_rgb(color)
1282
+ polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
1283
+ modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1])
1284
+ modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
1285
+ modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
1286
+ modified_color = colorsys.hls_to_rgb(polygon_color[0], modified_lightness, polygon_color[2])
1287
+ return modified_color
1288
+
1289
+ def _convert_boxes(self, boxes):
1290
+ """
1291
+ Convert different format of boxes to an NxB array, where B = 4 or 5 is the box dimension.
1292
+ """
1293
+ if isinstance(boxes, Boxes) or isinstance(boxes, RotatedBoxes):
1294
+ return boxes.tensor.detach().numpy()
1295
+ else:
1296
+ return np.asarray(boxes)
1297
+
1298
+ def _convert_masks(self, masks_or_polygons):
1299
+ """
1300
+ Convert different format of masks or polygons to a tuple of masks and polygons.
1301
+ Returns:
1302
+ list[GenericMask]:
1303
+ """
1304
+
1305
+ m = masks_or_polygons
1306
+ if isinstance(m, PolygonMasks):
1307
+ m = m.polygons
1308
+ if isinstance(m, BitMasks):
1309
+ m = m.tensor.numpy()
1310
+ if isinstance(m, torch.Tensor):
1311
+ m = m.numpy()
1312
+ ret = []
1313
+ for x in m:
1314
+ if isinstance(x, GenericMask):
1315
+ ret.append(x)
1316
+ else:
1317
+ ret.append(GenericMask(x, self.output.height, self.output.width))
1318
+ return ret
1319
+
1320
+ def _draw_text_in_mask(self, binary_mask, text, color):
1321
+ """
1322
+ Find proper places to draw text given a binary mask.
1323
+ """
1324
+ # TODO: the text is sometimes drawn on the wrong object; the heuristics here could be improved.
1325
+ _num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask, 8)
1326
+ if stats[1:, -1].size == 0:
1327
+ return
1328
+ largest_component_id = np.argmax(stats[1:, -1]) + 1
1329
+
1330
+ # draw text on the largest component, as well as other very large components.
1331
+ for cid in range(1, _num_cc):
1332
+ if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH:
1333
+ # median is more stable than centroid
1334
+ # center = centroids[largest_component_id]
1335
+ center = np.median((cc_labels == cid).nonzero(), axis=1)[::-1]
1336
+ self.draw_text(text, center, color=color)
1337
+
1338
+ def _convert_keypoints(self, keypoints):
1339
+ if isinstance(keypoints, Keypoints):
1340
+ keypoints = keypoints.tensor
1341
+ keypoints = np.asarray(keypoints)
1342
+ return keypoints
1343
+
1344
+ def get_output(self):
1345
+ """
1346
+ Returns:
1347
+ output (VisImage): the image output containing the visualizations added
1348
+ to the image.
1349
+ """
1350
+ return self.output
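
Before moving on to the app entry points, here is a minimal, self-contained sketch of exercising the Visualizer above with a hand-built panoptic map, so the drawing path can be checked without loading a model. The dataset key "__visualizer_demo__", the class names, and the colors are illustrative placeholders, not values defined by this repository.

import numpy as np
import torch
from detectron2.data import MetadataCatalog
from demo.visualizer import Visualizer, ColorMode

# Toy RGB image and a hand-built panoptic map: segment 0 covers the left half
# (a "stuff" region), segment 1 the right half (a "thing" instance).
H, W = 128, 192
img = np.full((H, W, 3), 200, dtype=np.uint8)
panoptic_seg = torch.zeros((H, W), dtype=torch.int64)
panoptic_seg[:, W // 2:] = 1
segments_info = [
    {"id": 0, "category_id": 0, "isthing": False},               # e.g. a wall
    {"id": 1, "category_id": 1, "isthing": True, "score": 0.9},  # e.g. a person
]

# Minimal metadata: draw_panoptic_seg looks up stuff_classes / stuff_colors by category_id.
metadata = MetadataCatalog.get("__visualizer_demo__").set(
    stuff_classes=["wall", "person"],
    stuff_colors=[[120, 120, 120], [220, 20, 60]],
)

vis = Visualizer(img, metadata=metadata, instance_mode=ColorMode.IMAGE)
out = vis.draw_panoptic_seg(panoptic_seg, segments_info, alpha=0.5)
print(out.get_image().shape)  # (128, 192, 3) uint8 RGB overlay
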
gradio_app.py ADDED
@@ -0,0 +1,222 @@
1
+ import torch
2
+
3
+ print("Installed the dependencies!")
4
+
5
+ import numpy as np
6
+ from PIL import Image
7
+ import cv2
8
+ import imutils
9
+ import os
10
+ import sys
11
+ import time
12
+ from detectron2.config import get_cfg
13
+ from detectron2.projects.deeplab import add_deeplab_config
14
+ from detectron2.data import MetadataCatalog
15
+ torch.set_num_threads(16) # Use 16 CPU threads
16
+ torch.set_num_interop_threads(16) # Inter-op parallelism
17
+ from oneformer import (
18
+ add_oneformer_config,
19
+ add_common_config,
20
+ add_swin_config,
21
+ add_dinat_config,
22
+ )
23
+
24
+ from demo.defaults import DefaultPredictor
25
+ from demo.visualizer import Visualizer, ColorMode
26
+
27
+ import gradio as gr
28
+ from huggingface_hub import hf_hub_download
29
+
30
+ KEY_DICT = {"Cityscapes (19 classes)": "cityscapes",
31
+ "COCO (133 classes)": "coco",
32
+ "ADE20K (150 classes)": "ade20k",}
33
+
34
+ SWIN_CFG_DICT = {"cityscapes": "configs/cityscapes/oneformer_swin_large_IN21k_384_bs16_90k.yaml",
35
+ "coco": "configs/coco/oneformer_swin_large_IN21k_384_bs16_100ep.yaml",
36
+ "ade20k": "configs/ade20k/oneformer_swin_large_IN21k_384_bs16_160k.yaml",}
37
+
38
+ SWIN_MODEL_DICT = {"cityscapes": hf_hub_download(repo_id="shi-labs/oneformer_cityscapes_swin_large",
39
+ filename="250_16_swin_l_oneformer_cityscapes_90k.pth"),
40
+ "coco": hf_hub_download(repo_id="shi-labs/oneformer_coco_swin_large",
41
+ filename="150_16_swin_l_oneformer_coco_100ep.pth"),
42
+ "ade20k": hf_hub_download(repo_id="shi-labs/oneformer_ade20k_swin_large",
43
+ filename="250_16_swin_l_oneformer_ade20k_160k.pth")
44
+ }
45
+
46
+ DINAT_CFG_DICT = {"cityscapes": "configs/cityscapes/oneformer_dinat_large_bs16_90k.yaml",
47
+ "coco": "configs/coco/oneformer_dinat_large_bs16_100ep.yaml",
48
+ "ade20k": "configs/ade20k/oneformer_dinat_large_IN21k_384_bs16_160k.yaml",}
49
+
50
+ DINAT_MODEL_DICT = {"cityscapes": hf_hub_download(repo_id="shi-labs/oneformer_cityscapes_dinat_large",
51
+ filename="250_16_dinat_l_oneformer_cityscapes_90k.pth"),
52
+ "coco": hf_hub_download(repo_id="shi-labs/oneformer_coco_dinat_large",
53
+ filename="150_16_dinat_l_oneformer_coco_100ep.pth"),
54
+ "ade20k": hf_hub_download(repo_id="shi-labs/oneformer_ade20k_dinat_large",
55
+ filename="250_16_dinat_l_oneformer_ade20k_160k.pth")
56
+ }
57
+
58
+ MODEL_DICT = {"DiNAT-L": DINAT_MODEL_DICT,
59
+ "Swin-L": SWIN_MODEL_DICT }
60
+
61
+ CFG_DICT = {"DiNAT-L": DINAT_CFG_DICT,
62
+ "Swin-L": SWIN_CFG_DICT }
63
+
64
+ WIDTH_DICT = {"cityscapes": 512,
65
+ "coco": 512,
66
+ "ade20k": 640}
67
+
68
+ cpu_device = torch.device("cpu")
69
+
70
+ PREDICTORS = {
71
+ "DiNAT-L": {
72
+ "Cityscapes (19 classes)": None,
73
+ "COCO (133 classes)": None,
74
+ "ADE20K (150 classes)": None
75
+ },
76
+ "Swin-L": {
77
+ "Cityscapes (19 classes)": None,
78
+ "COCO (133 classes)": None,
79
+ "ADE20K (150 classes)": None
80
+ }
81
+ }
82
+
83
+ METADATA = {
84
+ "DiNAT-L": {
85
+ "Cityscapes (19 classes)": None,
86
+ "COCO (133 classes)": None,
87
+ "ADE20K (150 classes)": None
88
+ },
89
+ "Swin-L": {
90
+ "Cityscapes (19 classes)": None,
91
+ "COCO (133 classes)": None,
92
+ "ADE20K (150 classes)": None
93
+ }
94
+ }
95
+
96
+ def setup_modules():
97
+ for dataset in ["Cityscapes (19 classes)", "COCO (133 classes)", "ADE20K (150 classes)"]:
98
+ for backbone in ["DiNAT-L", "Swin-L"]:
99
+ cfg = setup_cfg(dataset, backbone)
100
+ metadata = MetadataCatalog.get(
101
+ cfg.DATASETS.TEST_PANOPTIC[0] if len(cfg.DATASETS.TEST_PANOPTIC) else "__unused"
102
+ )
103
+ if 'cityscapes_fine_sem_seg_val' in cfg.DATASETS.TEST_PANOPTIC[0]:
104
+ from cityscapesscripts.helpers.labels import labels
105
+ stuff_colors = [k.color for k in labels if k.trainId != 255]
106
+ metadata = metadata.set(stuff_colors=stuff_colors)
107
+ PREDICTORS[backbone][dataset] = DefaultPredictor(cfg)
108
+ METADATA[backbone][dataset] = metadata
109
+
110
+ def setup_cfg(dataset, backbone):
111
+ # load config from file and command-line arguments
112
+ cfg = get_cfg()
113
+ add_deeplab_config(cfg)
114
+ add_common_config(cfg)
115
+ add_swin_config(cfg)
116
+ add_oneformer_config(cfg)
117
+ add_dinat_config(cfg)
118
+ dataset = KEY_DICT[dataset]
119
+ cfg_path = CFG_DICT[backbone][dataset]
120
+ cfg.merge_from_file(cfg_path)
121
+ if torch.cuda.is_available():
122
+ cfg.MODEL.DEVICE = 'cuda'
123
+ else:
124
+ cfg.MODEL.DEVICE = 'cpu'
125
+ cfg.MODEL.WEIGHTS = MODEL_DICT[backbone][dataset]
126
+ cfg.freeze()
127
+ return cfg
128
+
129
+ # def setup_modules(dataset, backbone):
130
+ # cfg = setup_cfg(dataset, backbone)
131
+ # predictor = DefaultPredictor(cfg)
132
+ # # predictor = PREDICTORS[backbone][dataset]
133
+ # metadata = MetadataCatalog.get(
134
+ # cfg.DATASETS.TEST_PANOPTIC[0] if len(cfg.DATASETS.TEST_PANOPTIC) else "__unused"
135
+ # )
136
+ # if 'cityscapes_fine_sem_seg_val' in cfg.DATASETS.TEST_PANOPTIC[0]:
137
+ # from cityscapesscripts.helpers.labels import labels
138
+ # stuff_colors = [k.color for k in labels if k.trainId != 255]
139
+ # metadata = metadata.set(stuff_colors=stuff_colors)
140
+
141
+ # return predictor, metadata
142
+
143
+ def panoptic_run(img, predictor, metadata):
144
+ visualizer = Visualizer(img[:, :, ::-1], metadata=metadata, instance_mode=ColorMode.IMAGE)
145
+ predictions = predictor(img, "panoptic")
146
+ panoptic_seg, segments_info = predictions["panoptic_seg"]
147
+ out = visualizer.draw_panoptic_seg_predictions(
148
+ panoptic_seg.to(cpu_device), segments_info, alpha=0.5
149
+ )
150
+ visualizer_map = Visualizer(img[:, :, ::-1], is_img=False, metadata=metadata, instance_mode=ColorMode.IMAGE)
151
+ out_map = visualizer_map.draw_panoptic_seg_predictions(
152
+ panoptic_seg.to(cpu_device), segments_info, alpha=1, is_text=False
153
+ )
154
+ return out, out_map
155
+
156
+ def instance_run(img, predictor, metadata):
157
+ visualizer = Visualizer(img[:, :, ::-1], metadata=metadata, instance_mode=ColorMode.IMAGE)
158
+ predictions = predictor(img, "instance")
159
+ instances = predictions["instances"].to(cpu_device)
160
+ out = visualizer.draw_instance_predictions(predictions=instances, alpha=0.5)
161
+ visualizer_map = Visualizer(img[:, :, ::-1], is_img=False, metadata=metadata, instance_mode=ColorMode.IMAGE)
162
+ out_map = visualizer_map.draw_instance_predictions(predictions=instances, alpha=1, is_text=False)
163
+ return out, out_map
164
+
165
+ def semantic_run(img, predictor, metadata):
166
+ visualizer = Visualizer(img[:, :, ::-1], metadata=metadata, instance_mode=ColorMode.IMAGE)
167
+ predictions = predictor(img, "semantic")
168
+ out = visualizer.draw_sem_seg(
169
+ predictions["sem_seg"].argmax(dim=0).to(cpu_device), alpha=0.5
170
+ )
171
+ visualizer_map = Visualizer(img[:, :, ::-1], is_img=False, metadata=metadata, instance_mode=ColorMode.IMAGE)
172
+ out_map = visualizer_map.draw_sem_seg(
173
+ predictions["sem_seg"].argmax(dim=0).to(cpu_device), alpha=1, is_text=False
174
+ )
175
+ return out, out_map
176
+
177
+ TASK_INFER = {"the task is panoptic": panoptic_run, "the task is instance": instance_run, "the task is semantic": semantic_run}
178
+
179
+ def segment(path, task, dataset, backbone):
180
+ # predictor, metadata = setup_modules(dataset, backbone)
181
+ predictor = PREDICTORS[backbone][dataset]
182
+ metadata = METADATA[backbone][dataset]
183
+ img = cv2.imread(path)
184
+ width = WIDTH_DICT[KEY_DICT[dataset]]
185
+ img = imutils.resize(img, width=width)
186
+ out, out_map = TASK_INFER[task](img, predictor, metadata)
187
+ out = Image.fromarray(out.get_image())
188
+ out_map = Image.fromarray(out_map.get_image())
189
+ return out, out_map
190
+
191
+ title = "<h1 style='text-align: center'>OneFormer:DIEGO MENTORIA MILIONÁRIA - APP 1</h1>"
192
+ # style='margin-bottom: -10px;
193
+ description = "<p style='font-size: 14px; margin: 5px; font-weight: w300; text-align: center'> <a href='https://praeclarumjj3.github.io/' style='text-decoration:none' target='_blank'>Jitesh Jain, </a> <a href='https://chrisjuniorli.github.io/' style='text-decoration:none' target='_blank'>Jiachen Li<sup>*</sup>, </a> <a href='https://www.linkedin.com/in/mtchiu/' style='text-decoration:none' target='_blank'>MangTik Chiu<sup>*</sup>, </a> <a href='https://alihassanijr.com/' style='text-decoration:none' target='_blank'>Ali Hassani, </a> <a href='https://www.linkedin.com/in/nukich74/' style='text-decoration:none' target='_blank'>Nikita Orlov, </a> <a href='https://www.humphreyshi.com/home' style='text-decoration:none' target='_blank'>Humphrey Shi</a></p>" \
194
+ + "<p style='font-size: 16px; margin: 5px; font-weight: w600; text-align: center'> <a href='https://praeclarumjj3.github.io/oneformer/' target='_blank'>Project Page</a> | <a href='https://arxiv.org/abs/2211.06220' target='_blank'>ArXiv Paper</a> | <a href='https://github.com/SHI-Labs/OneFormer' target='_blank'>Github Repo</a></p>" \
195
+ + "<p style='text-align: center; margin: 5px; font-size: 14px; font-weight: w300;'> \
196
+ OneFormer is the first multi-task universal image segmentation framework based on transformers. Our single OneFormer model achieves state-of-the-art performance across all three segmentation tasks with a single task-conditioned joint training process. OneFormer uses a task token to condition the model on the task in focus, making our architecture task-guided for training, and task-dynamic for inference, all with a single model. We believe OneFormer is a significant step towards making image segmentation more universal and accessible.\
197
+ </p>" \
198
+ + "<p style='text-align: center; font-size: 14px; margin: 5px; font-weight: w300;'> [Note: Inference on CPU may take upto 2 minutes. On a single RTX A6000 GPU, OneFormer is able to inference at more than 15 FPS.]</p>"
199
+
200
+ setup_modules()
201
+
202
+ gradio_inputs = [gr.Image(label="Input Image",type="filepath"),
203
+ gr.Radio(choices=["the task is panoptic" ,"the task is instance", "the task is semantic"], type="value", value="the task is panoptic", label="Task Token Input"),
204
+ gr.Radio(choices=["COCO (133 classes)" ,"Cityscapes (19 classes)", "ADE20K (150 classes)"], type="value", value="COCO (133 classes)", label="Model"),
205
+ gr.Radio(choices=["DiNAT-L" ,"Swin-L"], type="value", value="DiNAT-L", label="Backbone"),
206
+ ]
207
+ gradio_outputs = [gr.Image(type="pil", label="Segmentation Overlay"), gr.Image(type="pil", label="Segmentation Map")]
208
+
209
+
210
+ examples = [["examples/coco.jpeg", "the task is panoptic", "COCO (133 classes)", "DiNAT-L"],
211
+ ["examples/cityscapes.png", "the task is panoptic", "Cityscapes (19 classes)", "DiNAT-L"],
212
+ ["examples/ade20k.jpeg", "the task is panoptic", "ADE20K (150 classes)", "DiNAT-L"]]
213
+
214
+
215
+ iface = gr.Interface(fn=segment, inputs=gradio_inputs,
216
+ outputs=gradio_outputs,
217
+ examples_per_page=5,
218
+ allow_flagging="never",
219
+ examples=examples, title=title,
220
+ description=description)
221
+
222
+ iface.launch(server_name="0.0.0.0",share=True)
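
As a usage note, the same pipeline can be driven without the Gradio UI, e.g. for batch processing. The sketch below is illustrative only: it assumes the code is placed in this module after setup_modules() has completed (so PREDICTORS, METADATA, WIDTH_DICT and panoptic_run are in scope) and that the example image from the examples list exists on disk.

import cv2
import imutils
from PIL import Image

img = cv2.imread("examples/ade20k.jpeg")               # one of the bundled example images
img = imutils.resize(img, width=WIDTH_DICT["ade20k"])

predictor = PREDICTORS["DiNAT-L"]["ADE20K (150 classes)"]
metadata = METADATA["DiNAT-L"]["ADE20K (150 classes)"]

overlay, seg_map = panoptic_run(img, predictor, metadata)   # two VisImage objects
Image.fromarray(overlay.get_image()).save("overlay.png")
Image.fromarray(seg_map.get_image()).save("segmentation_map.png")
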
gradio_combine.py ADDED
@@ -0,0 +1,384 @@
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ import cv2
5
+ import imutils
6
+ import os
7
+ import sys
8
+ import time
9
+ import colorsys
10
+ from scipy import ndimage
11
+ import gradio as gr
12
+ from huggingface_hub import hf_hub_download
13
+
14
+ # Detectron2 imports
15
+ from detectron2.config import get_cfg
16
+ from detectron2.projects.deeplab import add_deeplab_config
17
+ from detectron2.data import MetadataCatalog
18
+ from detectron2.engine import DefaultPredictor as DetectronPredictor
19
+ from detectron2 import model_zoo
20
+
21
+ # OneFormer imports
22
+ from oneformer import (
23
+ add_oneformer_config,
24
+ add_common_config,
25
+ add_swin_config,
26
+ add_dinat_config,
27
+ )
28
+ from demo.defaults import DefaultPredictor as OneFormerPredictor
29
+ from demo.visualizer import Visualizer, ColorMode
30
+
31
+ # NeuroNest contrast detection imports
32
+ from utils.contrast_detector import ContrastDetector
33
+ from utils.luminance_contrast import LuminanceContrastDetector
34
+ from utils.hue_contrast import HueContrastDetector
35
+ from utils.saturation_contrast import SaturationContrastDetector
36
+ from utils.combined_contrast import CombinedContrastDetector
37
+
38
+ # Set threads for CPU optimization
39
+ torch.set_num_threads(4)
40
+
41
+ ########################################
42
+ # GLOBAL CONFIGURATIONS
43
+ ########################################
44
+
45
+ # OneFormer configurations
46
+ KEY_DICT = {"ADE20K (150 classes)": "ade20k"}
47
+
48
+ SWIN_CFG_DICT = {
49
+ "ade20k": "configs/ade20k/oneformer_swin_large_IN21k_384_bs16_160k.yaml",
50
+ }
51
+
52
+ SWIN_MODEL_DICT = {
53
+ "ade20k": hf_hub_download(
54
+ repo_id="shi-labs/oneformer_ade20k_swin_large",
55
+ filename="250_16_swin_l_oneformer_ade20k_160k.pth"
56
+ )
57
+ }
58
+
59
+ DINAT_CFG_DICT = {
60
+ "ade20k": "configs/ade20k/oneformer_dinat_large_IN21k_384_bs16_160k.yaml",
61
+ }
62
+
63
+ DINAT_MODEL_DICT = {
64
+ "ade20k": hf_hub_download(
65
+ repo_id="shi-labs/oneformer_ade20k_dinat_large",
66
+ filename="250_16_dinat_l_oneformer_ade20k_160k.pth"
67
+ )
68
+ }
69
+
70
+ MODEL_DICT = {"DiNAT-L": DINAT_MODEL_DICT, "Swin-L": SWIN_MODEL_DICT}
71
+ CFG_DICT = {"DiNAT-L": DINAT_CFG_DICT, "Swin-L": SWIN_CFG_DICT}
72
+ WIDTH_DICT = {"ade20k": 640}
73
+
74
+ # Contrast detector mapping
75
+ CONTRAST_DETECTORS = {
76
+ "Luminance (WCAG)": LuminanceContrastDetector(),
77
+ "Hue": HueContrastDetector(),
78
+ "Saturation": SaturationContrastDetector(),
79
+ "Combined": CombinedContrastDetector()
80
+ }
81
+
82
+ # Device configuration
83
+ cpu_device = torch.device("cpu")
84
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
85
+ print(f"Using device: {device}")
86
+
87
+ # Model storage
88
+ ONEFORMER_PREDICTORS = {
89
+ "DiNAT-L": {"ADE20K (150 classes)": None},
90
+ "Swin-L": {"ADE20K (150 classes)": None}
91
+ }
92
+
93
+ ONEFORMER_METADATA = {
94
+ "DiNAT-L": {"ADE20K (150 classes)": None},
95
+ "Swin-L": {"ADE20K (150 classes)": None}
96
+ }
97
+
98
+ ########################################
99
+ # MASK R-CNN SETUP AND FUNCTIONS
100
+ ########################################
101
+
102
+ def load_maskrcnn_model(weights_path, device="cuda", threshold=0.5):
103
+ """Load Mask R-CNN model for blackspot detection"""
104
+ cfg = get_cfg()
105
+ cfg.merge_from_file(
106
+ model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
107
+ )
108
+ cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2 # [Floors, blackspot]
109
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = threshold
110
+ cfg.MODEL.WEIGHTS = weights_path
111
+ # Fix: Convert torch.device to string
112
+ cfg.MODEL.DEVICE = str(device) if isinstance(device, torch.device) else device
113
+ return DetectronPredictor(cfg)
114
+ def postprocess_blackspot_masks(im, instances, show_floor=True, show_blackspot=True):
115
+ """Extract floor and blackspot masks from Mask R-CNN predictions"""
116
+ height, width = im.shape[:2]
117
+ pred_classes = instances.pred_classes.cpu().numpy()
118
+ pred_masks = instances.pred_masks.cpu().numpy()
119
+
120
+ combined_floor_mask = np.zeros((height, width), dtype=bool)
121
+ final_blackspot = np.zeros((height, width), dtype=bool)
122
+
123
+ for cls_id, mask in zip(pred_classes, pred_masks):
124
+ if cls_id == 0 and show_floor: # Floor class
125
+ combined_floor_mask |= mask
126
+ elif cls_id == 1 and show_blackspot: # Blackspot class
127
+ final_blackspot |= mask
128
+
129
+ return combined_floor_mask.astype(np.uint8), final_blackspot.astype(np.uint8)
130
+
131
+ ########################################
132
+ # ONEFORMER SETUP AND FUNCTIONS
133
+ ########################################
134
+
135
+ def setup_oneformer_modules():
136
+ """Initialize OneFormer models"""
137
+ for dataset in ["ADE20K (150 classes)"]:
138
+ for backbone in ["DiNAT-L", "Swin-L"]:
139
+ cfg = setup_oneformer_cfg(dataset, backbone)
140
+ metadata = MetadataCatalog.get(
141
+ cfg.DATASETS.TEST_PANOPTIC[0] if len(cfg.DATASETS.TEST_PANOPTIC) else "__unused"
142
+ )
143
+ ONEFORMER_PREDICTORS[backbone][dataset] = OneFormerPredictor(cfg)
144
+ ONEFORMER_METADATA[backbone][dataset] = metadata
145
+
146
+ def setup_oneformer_cfg(dataset, backbone):
147
+ """Setup OneFormer configuration"""
148
+ cfg = get_cfg()
149
+ add_deeplab_config(cfg)
150
+ add_common_config(cfg)
151
+ add_swin_config(cfg)
152
+ add_oneformer_config(cfg)
153
+ add_dinat_config(cfg)
154
+ dataset = KEY_DICT[dataset]
155
+ cfg_path = CFG_DICT[backbone][dataset]
156
+ cfg.merge_from_file(cfg_path)
157
+ cfg.MODEL.DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
158
+ cfg.MODEL.WEIGHTS = MODEL_DICT[backbone][dataset]
159
+ cfg.freeze()
160
+ return cfg
161
+
162
+ def semantic_run(img, predictor, metadata):
163
+ """Run OneFormer semantic segmentation"""
164
+ visualizer = Visualizer(img[:, :, ::-1], metadata=metadata, instance_mode=ColorMode.IMAGE)
165
+ predictions = predictor(img, "semantic")
166
+ out = visualizer.draw_sem_seg(
167
+ predictions["sem_seg"].argmax(dim=0).to(cpu_device), alpha=0.5
168
+ )
169
+ return out, predictions["sem_seg"].argmax(dim=0).to(cpu_device).numpy()
170
+
171
+ ########################################
172
+ # INTEGRATED ANALYSIS FUNCTION
173
+ ########################################
174
+
175
+ def integrated_analysis(image_path,
176
+ # Blackspot detection parameters
177
+ blackspot_threshold, show_floor, show_blackspot,
178
+ # Contrast detection parameters
179
+ enable_contrast, backbone, contrast_method, contrast_threshold):
180
+ """
181
+ Perform integrated analysis with both blackspot detection and contrast analysis
182
+ """
183
+ # Read the image
184
+ im = cv2.imread(image_path)
185
+ if im is None:
186
+ return "Error: could not read image!", None, None, None
187
+
188
+ # Resize for OneFormer if contrast analysis is enabled
189
+ if enable_contrast:
190
+ width = WIDTH_DICT["ade20k"]
191
+ im_resized = imutils.resize(im, width=width)
192
+ else:
193
+ im_resized = im
194
+
195
+ # Part 1: Blackspot Detection with Mask R-CNN
196
+ blackspot_text = []
197
+ blackspot_viz = None
198
+
199
+ if show_floor or show_blackspot:
200
+ weights_path = "./output_floor_blackspot/model_0004999.pth"
201
+ maskrcnn_predictor = load_maskrcnn_model(weights_path, device, blackspot_threshold)
202
+
203
+ # Run blackspot detection
204
+ outputs = maskrcnn_predictor(im)
205
+ instances = outputs["instances"]
206
+
207
+ # Post-process masks
208
+ floor_mask, blackspot_mask = postprocess_blackspot_masks(im, instances, show_floor, show_blackspot)
209
+
210
+ # Create visualization
211
+ blackspot_overlay = im.copy()
212
+ overlay = np.zeros_like(im)
213
+
214
+ if show_floor:
215
+ overlay[floor_mask > 0] = (0, 255, 0) # Green for floor
216
+ if show_blackspot:
217
+ overlay[blackspot_mask > 0] = (0, 0, 255) # Red for blackspot
218
+
219
+ blackspot_overlay = cv2.addWeighted(im, 1.0, overlay, 0.5, 0)
220
+ blackspot_viz = Image.fromarray(cv2.cvtColor(blackspot_overlay, cv2.COLOR_BGR2RGB))
221
+
222
+ # Calculate statistics
223
+ blackspot_area = int(blackspot_mask.sum())
224
+ floor_area = int(floor_mask.sum())
225
+
226
+ blackspot_text.append(f"### Blackspot Detection Results")
227
+ blackspot_text.append(f"**Threshold:** {blackspot_threshold:.2f}")
228
+
229
+ if show_floor:
230
+ blackspot_text.append(f"**Floor area:** {floor_area} pixels")
231
+ if show_blackspot:
232
+ blackspot_text.append(f"**Blackspot area:** {blackspot_area} pixels")
233
+ if floor_area > 0 and show_floor:
234
+ percentage = (blackspot_area / floor_area) * 100
235
+ blackspot_text.append(f"**Blackspot coverage:** {percentage:.2f}% of floor area")
236
+
237
+ # Part 2: Contrast Analysis with OneFormer
238
+ segmentation_viz = None
239
+ contrast_viz = None
240
+ contrast_text = []
241
+
242
+ if enable_contrast:
243
+ dataset = "ADE20K (150 classes)"
244
+ predictor = ONEFORMER_PREDICTORS[backbone][dataset]
245
+ metadata = ONEFORMER_METADATA[backbone][dataset]
246
+
247
+ # Get segmentation
248
+ out, seg_mask = semantic_run(im_resized, predictor, metadata)
249
+ segmentation_viz = Image.fromarray(out.get_image())
250
+
251
+ # Analyze contrast
252
+ img_rgb = cv2.cvtColor(im_resized, cv2.COLOR_BGR2RGB)
253
+ detector = CONTRAST_DETECTORS[contrast_method]
254
+ contrast_image, problem_areas, stats = detector.analyze(
255
+ img_rgb, seg_mask, contrast_threshold
256
+ )
257
+
258
+ contrast_viz = Image.fromarray(contrast_image)
259
+
260
+ # Create stats text
261
+ contrast_text.append(f"### Contrast Analysis Results")
262
+ contrast_text.append(f"**Method:** {contrast_method}")
263
+ contrast_text.append(f"**Threshold:** {contrast_threshold:.2f}")
264
+ contrast_text.append(f"**Problem Areas:** {stats['problem_count']}")
265
+
266
+ if 'min_contrast' in stats:
267
+ contrast_text.append(f"**Min Contrast:** {stats['min_contrast']:.2f}")
268
+ if 'max_contrast' in stats:
269
+ contrast_text.append(f"**Max Contrast:** {stats['max_contrast']:.2f}")
270
+ if 'average_contrast' in stats:
271
+ contrast_text.append(f"**Average Contrast:** {stats['average_contrast']:.2f}")
272
+
273
+ # Combine results
274
+ combined_text = []
275
+ if blackspot_text:
276
+ combined_text.extend(blackspot_text)
277
+ if contrast_text:
278
+ if blackspot_text:
279
+ combined_text.append("\n")
280
+ combined_text.extend(contrast_text)
281
+
282
+ return "\n".join(combined_text), blackspot_viz, segmentation_viz, contrast_viz
283
+
284
+ ########################################
285
+ # GRADIO INTERFACE
286
+ ########################################
287
+
288
+ # Initialize models
289
+ print("Initializing OneFormer models...")
290
+ setup_oneformer_modules()
291
+
292
+ title = "NeuroNest: Integrated Blackspot & Contrast Detection"
293
+ description = """
294
+ This integrated system combines:
295
+ 1. **Blackspot Detection**: Uses Mask R-CNN to detect blackspots on floors
296
+ 2. **Contrast Analysis**: Uses OneFormer segmentation to analyze contrast between objects
297
+
298
+ Both analyses help identify potential accessibility issues for individuals with Alzheimer's disease.
299
+ """
300
+
301
+ # Create the Gradio interface
302
+ with gr.Blocks(title=title) as demo:
303
+ gr.Markdown(f"# {title}")
304
+ gr.Markdown(description)
305
+
306
+ with gr.Row():
307
+ with gr.Column(scale=1):
308
+ # Input image
309
+ image_input = gr.Image(label="Input Image", type="filepath")
310
+
311
+ # Blackspot detection controls
312
+ with gr.Accordion("Blackspot Detection Settings", open=True):
313
+ blackspot_threshold = gr.Slider(
314
+ minimum=0.1, maximum=0.9, value=0.5, step=0.05,
315
+ label="Blackspot Detection Threshold"
316
+ )
317
+ with gr.Row():
318
+ show_floor = gr.Checkbox(value=True, label="Show Floor")
319
+ show_blackspot = gr.Checkbox(value=True, label="Show Blackspots")
320
+
321
+ # Contrast analysis controls
322
+ with gr.Accordion("Contrast Analysis Settings", open=True):
323
+ enable_contrast = gr.Checkbox(value=True, label="Enable Contrast Analysis")
324
+ backbone = gr.Radio(
325
+ choices=["Swin-L", "DiNAT-L"],
326
+ value="Swin-L",
327
+ label="OneFormer Backbone"
328
+ )
329
+ contrast_method = gr.Radio(
330
+ choices=["Luminance (WCAG)", "Hue", "Saturation", "Combined"],
331
+ value="Luminance (WCAG)",
332
+ label="Contrast Detection Method"
333
+ )
334
+ contrast_threshold = gr.Slider(
335
+ minimum=1.0, maximum=10.0, value=4.5, step=0.1,
336
+ label="Contrast Threshold"
337
+ )
338
+
339
+ analyze_btn = gr.Button("Analyze", variant="primary")
340
+
341
+ with gr.Column(scale=2):
342
+ # Output displays
343
+ with gr.Tabs():
344
+ with gr.Tab("Analysis Report"):
345
+ analysis_text = gr.Textbox(label="Analysis Results", lines=10)
346
+
347
+ with gr.Tab("Blackspot Detection"):
348
+ blackspot_output = gr.Image(label="Blackspot Visualization")
349
+
350
+ with gr.Tab("Segmentation"):
351
+ segmentation_output = gr.Image(label="OneFormer Segmentation")
352
+
353
+ with gr.Tab("Contrast Analysis"):
354
+ contrast_output = gr.Image(label="Contrast Visualization")
355
+
356
+ # Connect the interface
357
+ analyze_btn.click(
358
+ fn=integrated_analysis,
359
+ inputs=[
360
+ image_input,
361
+ blackspot_threshold, show_floor, show_blackspot,
362
+ enable_contrast, backbone, contrast_method, contrast_threshold
363
+ ],
364
+ outputs=[
365
+ analysis_text, blackspot_output, segmentation_output, contrast_output
366
+ ]
367
+ )
368
+
369
+ # Examples
370
+ gr.Examples(
371
+ examples=[
372
+ ["examples/indoor_room.jpg", 0.5, True, True, True, "Swin-L", "Luminance (WCAG)", 4.5],
373
+ ["examples/living_room.jpg", 0.7, True, True, True, "DiNAT-L", "Combined", 3.0],
374
+ ],
375
+ inputs=[
376
+ image_input,
377
+ blackspot_threshold, show_floor, show_blackspot,
378
+ enable_contrast, backbone, contrast_method, contrast_threshold
379
+ ]
380
+ )
381
+
382
+ if __name__ == "__main__":
383
+ print(f"Launching integrated NeuroNest app on device: {device}")
384
+ demo.queue().launch(server_name="0.0.0.0", share=True)
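The coverage figure reported by `integrated_analysis` is a plain ratio of boolean mask areas. A small self-contained check of that arithmetic on synthetic masks (the values here are made up for illustration):

import numpy as np

# Synthetic 100x100 scene: floor occupies the lower half, a 10x10 blackspot sits on it
floor_mask = np.zeros((100, 100), dtype=bool)
floor_mask[50:, :] = True
blackspot_mask = np.zeros((100, 100), dtype=bool)
blackspot_mask[60:70, 10:20] = True

floor_area = int(floor_mask.sum())          # 5000 pixels
blackspot_area = int(blackspot_mask.sum())  # 100 pixels
coverage = blackspot_area / floor_area * 100
print(f"Blackspot coverage: {coverage:.2f}% of floor area")  # 2.00%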
gradio_contrast.py ADDED
@@ -0,0 +1,455 @@
1
+ import torch
2
+ import os
3
+ import sys
4
+ import time
5
+ import numpy as np
6
+ from PIL import Image
7
+ import cv2
8
+ import imutils
9
+ import colorsys
10
+ from scipy import ndimage
11
+
12
+ # Set CUDA device explicitly at the start
13
+ if torch.cuda.is_available():
14
+ torch.cuda.set_device(0) # Use first GPU
15
+ print(f"Using GPU: {torch.cuda.get_device_name(0)}")
16
+ else:
17
+ print("WARNING: No GPU available, using CPU")
18
+
19
+ print("Installed the dependencies!")
20
+
21
+ from detectron2.config import get_cfg
22
+ from detectron2.projects.deeplab import add_deeplab_config
23
+ from detectron2.data import MetadataCatalog
24
+
25
+ from oneformer import (
26
+ add_oneformer_config,
27
+ add_common_config,
28
+ add_swin_config,
29
+ add_dinat_config,
30
+ )
31
+
32
+ from demo.defaults import DefaultPredictor
33
+ from demo.visualizer import Visualizer, ColorMode
34
+
35
+ import gradio as gr
36
+ from huggingface_hub import hf_hub_download
37
+
38
+ # Force unbuffered output for SLURM logs
39
+ sys.stdout = sys.__stdout__
40
+ sys.stderr = sys.__stderr__
41
+
42
+ # CUDA debug settings: synchronous kernel launches and device-side assertions (helpful for debugging, can slow inference)
43
+ os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
44
+ os.environ['TORCH_USE_CUDA_DSA'] = '1'
45
+
46
+ # Contrast Detection Classes
47
+ class ContrastDetector:
48
+ """Base class for contrast detection between segments"""
49
+
50
+ @staticmethod
51
+ def calculate_luminance_contrast(color1, color2):
52
+ """Calculate WCAG luminance contrast ratio"""
53
+ def get_relative_luminance(rgb):
54
+ r, g, b = [val/255.0 for val in rgb]
55
+ r = r/12.92 if r <= 0.03928 else ((r + 0.055)/1.055) ** 2.4
56
+ g = g/12.92 if g <= 0.03928 else ((g + 0.055)/1.055) ** 2.4
57
+ b = b/12.92 if b <= 0.03928 else ((b + 0.055)/1.055) ** 2.4
58
+ return 0.2126 * r + 0.7152 * g + 0.0722 * b
59
+
60
+ lum1 = get_relative_luminance(color1)
61
+ lum2 = get_relative_luminance(color2)
62
+
63
+ lighter = max(lum1, lum2)
64
+ darker = min(lum1, lum2)
65
+
66
+ return (lighter + 0.05) / (darker + 0.05)
67
+
68
+ @staticmethod
69
+ def calculate_hue_contrast(color1, color2):
70
+ """Calculate hue difference between two colors"""
71
+ hsv1 = colorsys.rgb_to_hsv(color1[0]/255.0, color1[1]/255.0, color1[2]/255.0)
72
+ hsv2 = colorsys.rgb_to_hsv(color2[0]/255.0, color2[1]/255.0, color2[2]/255.0)
73
+
74
+ hue_diff = abs(hsv1[0] - hsv2[0])
75
+ if hue_diff > 0.5:
76
+ hue_diff = 1 - hue_diff
77
+
78
+ return hue_diff * 2
79
+
80
+ @staticmethod
81
+ def calculate_saturation_contrast(color1, color2):
82
+ """Calculate saturation difference between two colors"""
83
+ hsv1 = colorsys.rgb_to_hsv(color1[0]/255.0, color1[1]/255.0, color1[2]/255.0)
84
+ hsv2 = colorsys.rgb_to_hsv(color2[0]/255.0, color2[1]/255.0, color2[2]/255.0)
85
+
86
+ return abs(hsv1[1] - hsv2[1])
87
+
88
+ @staticmethod
89
+ def analyze_contrast(image, segmentation, method="luminance", threshold=4.5):
90
+ """Analyze contrast between adjacent segments"""
91
+ unique_segments = np.unique(segmentation)
92
+ h, w = segmentation.shape
93
+ contrast_mask = np.zeros((h, w), dtype=bool)
94
+ problem_areas = []
95
+
96
+ # Calculate average colors for each segment
97
+ segment_colors = {}
98
+ for seg_id in unique_segments:
99
+ mask = segmentation == seg_id
100
+ if np.any(mask):
101
+ segment_colors[seg_id] = np.mean(image[mask], axis=0).astype(int)
102
+
103
+ # Check contrast between adjacent segments
104
+ for i in range(h):
105
+ for j in range(w):
106
+ current_seg = segmentation[i, j]
107
+
108
+ # Check 4-connected neighbors
109
+ for di, dj in [(0, 1), (1, 0), (0, -1), (-1, 0)]:
110
+ ni, nj = i + di, j + dj
111
+ if 0 <= ni < h and 0 <= nj < w:
112
+ neighbor_seg = segmentation[ni, nj]
113
+
114
+ if current_seg != neighbor_seg:
115
+ color1 = segment_colors[current_seg]
116
+ color2 = segment_colors[neighbor_seg]
117
+
118
+ if method == "luminance":
119
+ contrast = ContrastDetector.calculate_luminance_contrast(color1, color2)
120
+ elif method == "hue":
121
+ contrast = ContrastDetector.calculate_hue_contrast(color1, color2)
122
+ threshold = 0.3 # Adjust threshold for hue
123
+ elif method == "saturation":
124
+ contrast = ContrastDetector.calculate_saturation_contrast(color1, color2)
125
+ threshold = 0.3 # Adjust threshold for saturation
126
+
127
+ if contrast < threshold:
128
+ contrast_mask[i, j] = True
129
+ problem_areas.append((current_seg, neighbor_seg, contrast))
130
+
131
+ return contrast_mask, problem_areas, segment_colors
132
+
133
+ # Rest of your code remains the same until setup_cfg function
134
+ KEY_DICT = {"Cityscapes (19 classes)": "cityscapes",
135
+ "COCO (133 classes)": "coco",
136
+ "ADE20K (150 classes)": "ade20k",}
137
+
138
+ SWIN_CFG_DICT = {"cityscapes": "configs/cityscapes/oneformer_swin_large_IN21k_384_bs16_90k.yaml",
139
+ "coco": "configs/coco/oneformer_swin_large_IN21k_384_bs16_100ep.yaml",
140
+ "ade20k": "configs/ade20k/oneformer_swin_large_IN21k_384_bs16_160k.yaml",}
141
+
142
+ SWIN_MODEL_DICT = {"cityscapes": hf_hub_download(repo_id="shi-labs/oneformer_cityscapes_swin_large",
143
+ filename="250_16_swin_l_oneformer_cityscapes_90k.pth"),
144
+ "coco": hf_hub_download(repo_id="shi-labs/oneformer_coco_swin_large",
145
+ filename="150_16_swin_l_oneformer_coco_100ep.pth"),
146
+ "ade20k": hf_hub_download(repo_id="shi-labs/oneformer_ade20k_swin_large",
147
+ filename="250_16_swin_l_oneformer_ade20k_160k.pth")
148
+ }
149
+
150
+ DINAT_CFG_DICT = {"cityscapes": "configs/cityscapes/oneformer_dinat_large_bs16_90k.yaml",
151
+ "coco": "configs/coco/oneformer_dinat_large_bs16_100ep.yaml",
152
+ "ade20k": "configs/ade20k/oneformer_dinat_large_IN21k_384_bs16_160k.yaml",}
153
+
154
+ DINAT_MODEL_DICT = {"cityscapes": hf_hub_download(repo_id="shi-labs/oneformer_cityscapes_dinat_large",
155
+ filename="250_16_dinat_l_oneformer_cityscapes_90k.pth"),
156
+ "coco": hf_hub_download(repo_id="shi-labs/oneformer_coco_dinat_large",
157
+ filename="150_16_dinat_l_oneformer_coco_100ep.pth"),
158
+ "ade20k": hf_hub_download(repo_id="shi-labs/oneformer_ade20k_dinat_large",
159
+ filename="250_16_dinat_l_oneformer_ade20k_160k.pth")
160
+ }
161
+
162
+ MODEL_DICT = {"DiNAT-L": DINAT_MODEL_DICT,
163
+ "Swin-L": SWIN_MODEL_DICT }
164
+
165
+ CFG_DICT = {"DiNAT-L": DINAT_CFG_DICT,
166
+ "Swin-L": SWIN_CFG_DICT }
167
+
168
+ WIDTH_DICT = {"cityscapes": 512,
169
+ "coco": 512,
170
+ "ade20k": 640}
171
+
172
+ # Modified to ensure CUDA device
173
+ if torch.cuda.is_available():
174
+ device = torch.device("cuda:0")
175
+ print(f"Using device: {device}")
176
+ else:
177
+ device = torch.device("cpu")
178
+ print(f"WARNING: Using CPU device")
179
+
180
+ cpu_device = torch.device("cpu")
181
+
182
+ PREDICTORS = {
183
+ "DiNAT-L": {
184
+ "Cityscapes (19 classes)": None,
185
+ "COCO (133 classes)": None,
186
+ "ADE20K (150 classes)": None
187
+ },
188
+ "Swin-L": {
189
+ "Cityscapes (19 classes)": None,
190
+ "COCO (133 classes)": None,
191
+ "ADE20K (150 classes)": None
192
+ }
193
+ }
194
+
195
+ METADATA = {
196
+ "DiNAT-L": {
197
+ "Cityscapes (19 classes)": None,
198
+ "COCO (133 classes)": None,
199
+ "ADE20K (150 classes)": None
200
+ },
201
+ "Swin-L": {
202
+ "Cityscapes (19 classes)": None,
203
+ "COCO (133 classes)": None,
204
+ "ADE20K (150 classes)": None
205
+ }
206
+ }
207
+
208
+ def setup_modules():
209
+ print("Setting up modules...")
210
+ for dataset in ["Cityscapes (19 classes)", "COCO (133 classes)", "ADE20K (150 classes)"]:
211
+ for backbone in ["DiNAT-L", "Swin-L"]:
212
+ print(f"Loading {backbone} for {dataset}...")
213
+ cfg = setup_cfg(dataset, backbone)
214
+ metadata = MetadataCatalog.get(
215
+ cfg.DATASETS.TEST_PANOPTIC[0] if len(cfg.DATASETS.TEST_PANOPTIC) else "__unused"
216
+ )
217
+ if 'cityscapes_fine_sem_seg_val' in cfg.DATASETS.TEST_PANOPTIC[0]:
218
+ from cityscapesscripts.helpers.labels import labels
219
+ stuff_colors = [k.color for k in labels if k.trainId != 255]
220
+ metadata = metadata.set(stuff_colors=stuff_colors)
221
+
222
+ # Create predictor with explicit device
223
+ predictor = DefaultPredictor(cfg)
224
+ predictor.model.to(device)
225
+
226
+ PREDICTORS[backbone][dataset] = predictor
227
+ METADATA[backbone][dataset] = metadata
228
+ print(f"✓ Loaded {backbone} for {dataset}")
229
+ print("All modules setup complete!")
230
+
231
+ def setup_cfg(dataset, backbone):
232
+ # load config from file and command-line arguments
233
+ cfg = get_cfg()
234
+ add_deeplab_config(cfg)
235
+ add_common_config(cfg)
236
+ add_swin_config(cfg)
237
+ add_oneformer_config(cfg)
238
+ add_dinat_config(cfg)
239
+ dataset = KEY_DICT[dataset]
240
+ cfg_path = CFG_DICT[backbone][dataset]
241
+ cfg.merge_from_file(cfg_path)
242
+
243
+ # Explicitly set device to CUDA if available
244
+ if torch.cuda.is_available():
245
+ cfg.MODEL.DEVICE = 'cuda:0'
246
+ print(f"Config set to use CUDA device")
247
+ else:
248
+ cfg.MODEL.DEVICE = 'cpu'
249
+ print(f"Config set to use CPU device")
250
+
251
+ cfg.MODEL.WEIGHTS = MODEL_DICT[backbone][dataset]
252
+ cfg.freeze()
253
+ return cfg
254
+
255
+ # Rest of your functions remain the same
256
+ def panoptic_run(img, predictor, metadata):
257
+ visualizer = Visualizer(img[:, :, ::-1], metadata=metadata, instance_mode=ColorMode.IMAGE)
258
+ predictions = predictor(img, "panoptic")
259
+ panoptic_seg, segments_info = predictions["panoptic_seg"]
260
+ out = visualizer.draw_panoptic_seg_predictions(
261
+ panoptic_seg.to(cpu_device), segments_info, alpha=0.5
262
+ )
263
+ visualizer_map = Visualizer(img[:, :, ::-1], is_img=False, metadata=metadata, instance_mode=ColorMode.IMAGE)
264
+ out_map = visualizer_map.draw_panoptic_seg_predictions(
265
+ panoptic_seg.to(cpu_device), segments_info, alpha=1, is_text=False
266
+ )
267
+ return out, out_map, predictions
268
+
269
+ def instance_run(img, predictor, metadata):
270
+ visualizer = Visualizer(img[:, :, ::-1], metadata=metadata, instance_mode=ColorMode.IMAGE)
271
+ predictions = predictor(img, "instance")
272
+ instances = predictions["instances"].to(cpu_device)
273
+ out = visualizer.draw_instance_predictions(predictions=instances, alpha=0.5)
274
+ visualizer_map = Visualizer(img[:, :, ::-1], is_img=False, metadata=metadata, instance_mode=ColorMode.IMAGE)
275
+ out_map = visualizer_map.draw_instance_predictions(predictions=instances, alpha=1, is_text=False)
276
+ return out, out_map, predictions
277
+
278
+ def semantic_run(img, predictor, metadata):
279
+ visualizer = Visualizer(img[:, :, ::-1], metadata=metadata, instance_mode=ColorMode.IMAGE)
280
+ predictions = predictor(img, "semantic")
281
+ out = visualizer.draw_sem_seg(
282
+ predictions["sem_seg"].argmax(dim=0).to(cpu_device), alpha=0.5
283
+ )
284
+ visualizer_map = Visualizer(img[:, :, ::-1], is_img=False, metadata=metadata, instance_mode=ColorMode.IMAGE)
285
+ out_map = visualizer_map.draw_sem_seg(
286
+ predictions["sem_seg"].argmax(dim=0).to(cpu_device), alpha=1, is_text=False
287
+ )
288
+ return out, out_map, predictions
289
+
290
+ TASK_INFER = {"the task is panoptic": panoptic_run, "the task is instance": instance_run, "the task is semantic": semantic_run}
291
+
292
+ def create_contrast_visualization(img, contrast_mask, problem_areas, segment_colors):
293
+ """Create visualization of contrast issues"""
294
+ # Copy original image
295
+ contrast_viz = img.copy()
296
+
297
+ # Highlight low contrast boundaries
298
+ boundary_color = (255, 0, 0) # Red for problem areas
299
+ contrast_viz[contrast_mask] = boundary_color
300
+
301
+ # Add information overlay
302
+ info_text = f"Low contrast areas detected: {len(problem_areas)}"
303
+ cv2.putText(contrast_viz, info_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
304
+
305
+ return contrast_viz
306
+
307
+ def segment_and_analyze(path, task, dataset, backbone, enable_contrast, contrast_method, contrast_threshold):
308
+ # Get predictions and segmentation visualization
309
+ predictor = PREDICTORS[backbone][dataset]
310
+ metadata = METADATA[backbone][dataset]
311
+ img = cv2.imread(path)
312
+ width = WIDTH_DICT[KEY_DICT[dataset]]
313
+ img = imutils.resize(img, width=width)
314
+
315
+ out, out_map, predictions = TASK_INFER[task](img, predictor, metadata)
316
+ out_img = Image.fromarray(out.get_image())
317
+ out_map_img = Image.fromarray(out_map.get_image())
318
+
319
+ if not enable_contrast:
320
+ return out_img, out_map_img, None, None
321
+
322
+ # Extract segmentation mask from predictions
323
+ if task == "the task is semantic":
324
+ seg_mask = predictions["sem_seg"].argmax(dim=0).cpu().numpy()
325
+ elif task == "the task is panoptic":
326
+ seg_mask, _ = predictions["panoptic_seg"]
327
+ seg_mask = seg_mask.cpu().numpy()
328
+ elif task == "the task is instance":
329
+ # For instance segmentation, create a mask from instances
330
+ instances = predictions["instances"].to("cpu")
331
+ seg_mask = np.zeros(img.shape[:2], dtype=np.int32)
332
+ for i, mask in enumerate(instances.pred_masks):
333
+ seg_mask[mask] = i + 1
334
+
335
+ # Analyze contrast
336
+ contrast_mask, problem_areas, segment_colors = ContrastDetector.analyze_contrast(
337
+ img, seg_mask, method=contrast_method, threshold=contrast_threshold
338
+ )
339
+
340
+ # Create contrast visualization
341
+ contrast_viz = create_contrast_visualization(img, contrast_mask, problem_areas, segment_colors)
342
+ contrast_viz_img = Image.fromarray(contrast_viz[:, :, ::-1]) # Convert BGR to RGB
343
+
344
+ # Generate analysis report
345
+ report = f"### Contrast Analysis Report\n\n"
346
+ report += f"**Method:** {contrast_method.capitalize()}\n"
347
+ report += f"**Threshold:** {contrast_threshold}\n"
348
+ report += f"**Total segments:** {len(segment_colors)}\n"
349
+ report += f"**Low contrast boundaries found:** {len(problem_areas)}\n\n"
350
+
351
+ if problem_areas:
352
+ report += "**Problem Areas:**\n"
353
+ for i, (seg1, seg2, contrast_value) in enumerate(problem_areas[:10]): # Show first 10
354
+ report += f"- Segments {seg1} and {seg2}: Contrast ratio = {contrast_value:.2f}\n"
355
+ if len(problem_areas) > 10:
356
+ report += f"... and {len(problem_areas) - 10} more\n"
357
+
358
+ return out_img, out_map_img, contrast_viz_img, report
359
+
360
+ title = "<h1 style='text-align: center'>OneFormer:DIEGO MENTORIA MILIONÁRIA - APP 1</h1>"
361
+ description = "<p style='font-size: 14px; margin: 5px; font-weight: w300; text-align: center'> <a href='https://github.com/lolout1/sam2Contrast' style='text-decoration:none' target='_blank'>NeuroNest Contrast Model</a></p>" \
362
+ + "<p style='font-size: 16px; margin: 5px; font-weight: w600; text-align: center'> <a href='https://praeclarumjj3.github.io/oneformer/' target='_blank'>Project Page</a> | <a href='https://arxiv.org/abs/2211.06220' target='_blank'>ArXiv Paper</a> | <a href='https://github.com/SHI-Labs/OneFormer' target='_blank'>Github Repo</a></p>" \
363
+ + "<p style='text-align: center; margin: 5px; font-size: 14px; font-weight: w300;'> \
364
+ This model leverages the OneFormer architecture to perform comprehensive image segmentation and labeling across multiple tasks. The system can identify and segment various objects, structures, and regions within images with high accuracy. It supports semantic, instance, and panoptic segmentation modes, enabling detailed analysis of indoor and outdoor environments. The model excels at distinguishing between different classes of objects, from common everyday items to complex urban structures, making it particularly useful for environmental analysis and scene understanding applications.\
365
+ </p>" \
366
+ + "<p style='text-align: center; font-size: 14px; margin: 5px; font-weight: w300;'> [Note: Inference on CPU may take upto 2 minutes. On a single RTX A6000 GPU, OneFormer is able to inference at more than 15 FPS.]</p>"
367
+
368
+ # Main execution with error handling
369
+ if __name__ == "__main__":
370
+ try:
371
+ print("Starting setup...")
372
+ setup_modules()
373
+
374
+ print("Creating Gradio interface...")
375
+ with gr.Blocks(title="OneFormer with Contrast Detection") as iface:
376
+ gr.Markdown(title)
377
+ gr.Markdown(description)
378
+
379
+ with gr.Row():
380
+ with gr.Column(scale=1):
381
+ input_image = gr.Image(label="Input Image", type="filepath")
382
+ task = gr.Radio(
383
+ choices=["the task is panoptic", "the task is instance", "the task is semantic"],
384
+ value="the task is panoptic",
385
+ label="Task Token Input"
386
+ )
387
+ dataset = gr.Radio(
388
+ choices=["COCO (133 classes)", "Cityscapes (19 classes)", "ADE20K (150 classes)"],
389
+ value="COCO (133 classes)",
390
+ label="Model"
391
+ )
392
+ backbone = gr.Radio(
393
+ choices=["DiNAT-L", "Swin-L"],
394
+ value="DiNAT-L",
395
+ label="Backbone"
396
+ )
397
+
398
+ with gr.Accordion("Contrast Detection Options", open=False):
399
+ enable_contrast = gr.Checkbox(
400
+ label="Enable Contrast Detection",
401
+ value=False
402
+ )
403
+ contrast_method = gr.Radio(
404
+ choices=["luminance", "hue", "saturation"],
405
+ value="luminance",
406
+ label="Contrast Method"
407
+ )
408
+ contrast_threshold = gr.Slider(
409
+ minimum=1.0,
410
+ maximum=10.0,
411
+ value=4.5,
412
+ step=0.1,
413
+ label="Contrast Threshold (WCAG AA is 4.5)"
414
+ )
415
+
416
+ submit_btn = gr.Button("Analyze", variant="primary")
417
+
418
+ with gr.Column(scale=2):
419
+ with gr.Tabs():
420
+ with gr.TabItem("Segmentation Results"):
421
+ seg_output = gr.Image(type="pil", label="Segmentation Overlay")
422
+ seg_map = gr.Image(type="pil", label="Segmentation Map")
423
+
424
+ with gr.TabItem("Contrast Analysis"):
425
+ contrast_viz = gr.Image(type="pil", label="Contrast Visualization")
426
+ contrast_report = gr.Markdown(label="Contrast Analysis Report")
427
+
428
+ examples = [
429
+ ["examples/coco.jpeg", "the task is panoptic", "COCO (133 classes)", "DiNAT-L", False, "luminance", 4.5],
430
+ ["examples/cityscapes.png", "the task is panoptic", "Cityscapes (19 classes)", "DiNAT-L", False, "luminance", 4.5],
431
+ ["examples/ade20k.jpeg", "the task is panoptic", "ADE20K (150 classes)", "DiNAT-L", False, "luminance", 4.5]
432
+ ]
433
+
434
+ gr.Examples(
435
+ examples=examples,
436
+ inputs=[input_image, task, dataset, backbone, enable_contrast, contrast_method, contrast_threshold],
437
+ outputs=[seg_output, seg_map, contrast_viz, contrast_report],
438
+ fn=segment_and_analyze,
439
+ cache_examples=False
440
+ )
441
+
442
+ submit_btn.click(
443
+ fn=segment_and_analyze,
444
+ inputs=[input_image, task, dataset, backbone, enable_contrast, contrast_method, contrast_threshold],
445
+ outputs=[seg_output, seg_map, contrast_viz, contrast_report]
446
+ )
447
+
448
+ print("Launching Gradio app...")
449
+ iface.launch(server_name="0.0.0.0", server_port=7860, share=True)
450
+
451
+ except Exception as e:
452
+ print(f"Error occurred: {str(e)}")
453
+ import traceback
454
+ traceback.print_exc()
455
+ sys.exit(1)
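As a sanity check on `ContrastDetector.calculate_luminance_contrast` above: the WCAG formula yields 21:1 for black on white and 1:1 for identical colors, and the default 4.5 threshold sits near mid-grey on white. A standalone recomputation of those reference values using the same formula:

def wcag_contrast(rgb1, rgb2):
    # Same WCAG 2.x relative-luminance formula as in ContrastDetector above
    def lum(rgb):
        def chan(c):
            c = c / 255.0
            return c / 12.92 if c <= 0.03928 else ((c + 0.055) / 1.055) ** 2.4
        r, g, b = (chan(v) for v in rgb)
        return 0.2126 * r + 0.7152 * g + 0.0722 * b
    l1, l2 = lum(rgb1), lum(rgb2)
    return (max(l1, l2) + 0.05) / (min(l1, l2) + 0.05)

print(wcag_contrast((0, 0, 0), (255, 255, 255)))        # 21.0 -> passes the 4.5 threshold
print(wcag_contrast((119, 119, 119), (255, 255, 255)))  # ~4.48 -> just below 4.5
print(wcag_contrast((128, 128, 128), (128, 128, 128)))  # 1.0  -> fails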
gradio_gpu.py ADDED
@@ -0,0 +1,234 @@
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ import cv2
5
+ import imutils
6
+ import os
7
+ import sys
8
+ import time
9
+ from detectron2.config import get_cfg
10
+ from detectron2.projects.deeplab import add_deeplab_config
11
+ from detectron2.data import MetadataCatalog
12
+ from scipy import ndimage
13
+ import colorsys
14
+ import math
15
+
16
+ torch.set_num_threads(16)
17
+ torch.set_num_interop_threads(16)
18
+
19
+ from oneformer import (
20
+ add_oneformer_config,
21
+ add_common_config,
22
+ add_swin_config,
23
+ add_dinat_config,
24
+ )
25
+
26
+ from demo.defaults import DefaultPredictor
27
+ from demo.visualizer import Visualizer, ColorMode
28
+
29
+ import gradio as gr
30
+ from huggingface_hub import hf_hub_download
31
+
32
+ # NeuroNest specific imports
33
+ from utils.contrast_detector import ContrastDetector
34
+ from utils.luminance_contrast import LuminanceContrastDetector
35
+ from utils.hue_contrast import HueContrastDetector
36
+ from utils.saturation_contrast import SaturationContrastDetector
37
+ from utils.combined_contrast import CombinedContrastDetector
38
+
39
+ KEY_DICT = {
40
+ "ADE20K (150 classes)": "ade20k",
41
+ }
42
+
43
+ SWIN_CFG_DICT = {
44
+ "ade20k": "configs/ade20k/oneformer_swin_large_IN21k_384_bs16_160k.yaml",
45
+ }
46
+
47
+ SWIN_MODEL_DICT = {
48
+ "ade20k": hf_hub_download(
49
+ repo_id="shi-labs/oneformer_ade20k_swin_large",
50
+ filename="250_16_swin_l_oneformer_ade20k_160k.pth"
51
+ )
52
+ }
53
+
54
+ DINAT_CFG_DICT = {
55
+ "ade20k": "configs/ade20k/oneformer_dinat_large_IN21k_384_bs16_160k.yaml",
56
+ }
57
+
58
+ DINAT_MODEL_DICT = {
59
+ "ade20k": hf_hub_download(
60
+ repo_id="shi-labs/oneformer_ade20k_dinat_large",
61
+ filename="250_16_dinat_l_oneformer_ade20k_160k.pth"
62
+ )
63
+ }
64
+
65
+ MODEL_DICT = {"DiNAT-L": DINAT_MODEL_DICT, "Swin-L": SWIN_MODEL_DICT}
66
+ CFG_DICT = {"DiNAT-L": DINAT_CFG_DICT, "Swin-L": SWIN_CFG_DICT}
67
+ WIDTH_DICT = {"ade20k": 640}
68
+
69
+ cpu_device = torch.device("cpu")
70
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
71
+ print(f"Using device: {device}")
72
+
73
+ PREDICTORS = {
74
+ "DiNAT-L": {"ADE20K (150 classes)": None},
75
+ "Swin-L": {"ADE20K (150 classes)": None}
76
+ }
77
+
78
+ METADATA = {
79
+ "DiNAT-L": {"ADE20K (150 classes)": None},
80
+ "Swin-L": {"ADE20K (150 classes)": None}
81
+ }
82
+
83
+ # Contrast detector mapping
84
+ CONTRAST_DETECTORS = {
85
+ "Luminance (WCAG)": LuminanceContrastDetector(),
86
+ "Hue": HueContrastDetector(),
87
+ "Saturation": SaturationContrastDetector(),
88
+ "Combined": CombinedContrastDetector()
89
+ }
90
+
91
+ def setup_modules():
92
+ for dataset in ["ADE20K (150 classes)"]:
93
+ for backbone in ["DiNAT-L", "Swin-L"]:
94
+ cfg = setup_cfg(dataset, backbone)
95
+ metadata = MetadataCatalog.get(
96
+ cfg.DATASETS.TEST_PANOPTIC[0] if len(cfg.DATASETS.TEST_PANOPTIC) else "__unused"
97
+ )
98
+ PREDICTORS[backbone][dataset] = DefaultPredictor(cfg)
99
+ METADATA[backbone][dataset] = metadata
100
+
101
+ def setup_cfg(dataset, backbone):
102
+ cfg = get_cfg()
103
+ add_deeplab_config(cfg)
104
+ add_common_config(cfg)
105
+ add_swin_config(cfg)
106
+ add_oneformer_config(cfg)
107
+ add_dinat_config(cfg)
108
+ dataset = KEY_DICT[dataset]
109
+ cfg_path = CFG_DICT[backbone][dataset]
110
+ cfg.merge_from_file(cfg_path)
111
+ cfg.MODEL.DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
112
+ cfg.MODEL.WEIGHTS = MODEL_DICT[backbone][dataset]
113
+ cfg.freeze()
114
+ return cfg
115
+
116
+ def semantic_run(img, predictor, metadata):
117
+ visualizer = Visualizer(img[:, :, ::-1], metadata=metadata, instance_mode=ColorMode.IMAGE)
118
+ predictions = predictor(img, "semantic")
119
+ out = visualizer.draw_sem_seg(
120
+ predictions["sem_seg"].argmax(dim=0).to(cpu_device), alpha=0.5
121
+ )
122
+ return out, predictions["sem_seg"].argmax(dim=0).to(cpu_device).numpy()
123
+
124
+ def analyze_contrast(image, segmentation, contrast_method, threshold):
125
+ """Analyze contrast between segments using selected method"""
126
+ detector = CONTRAST_DETECTORS[contrast_method]
127
+
128
+ # Perform contrast analysis
129
+ contrast_image, problem_areas, stats = detector.analyze(
130
+ image, segmentation, threshold
131
+ )
132
+
133
+ return contrast_image, problem_areas, stats
134
+
135
+ def segment_and_analyze_contrast(path, backbone, contrast_method, threshold):
136
+ """Main function to segment and analyze contrast"""
137
+ dataset = "ADE20K (150 classes)"
138
+ predictor = PREDICTORS[backbone][dataset]
139
+ metadata = METADATA[backbone][dataset]
140
+
141
+ # Read and resize image
142
+ img = cv2.imread(path)
143
+ if img is None:
144
+ return None, None, "Error: Could not load image"
145
+
146
+ width = WIDTH_DICT[KEY_DICT[dataset]]
147
+ img = imutils.resize(img, width=width)
148
+
149
+ # Get segmentation
150
+ out, seg_mask = semantic_run(img, predictor, metadata)
151
+ out_img = Image.fromarray(out.get_image())
152
+
153
+ # Analyze contrast
154
+ img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
155
+ contrast_img, problem_areas, stats = analyze_contrast(
156
+ img_rgb, seg_mask, contrast_method, threshold
157
+ )
158
+
159
+ # Create stats text
160
+ stats_text = f"### Contrast Analysis Results\n\n"
161
+ stats_text += f"**Method:** {contrast_method}\n"
162
+ stats_text += f"**Threshold:** {threshold:.2f}\n"
163
+ stats_text += f"**Problem Areas:** {stats['problem_count']}\n"
164
+
165
+ if 'min_contrast' in stats:
166
+ stats_text += f"**Min Contrast:** {stats['min_contrast']:.2f}\n"
167
+ if 'max_contrast' in stats:
168
+ stats_text += f"**Max Contrast:** {stats['max_contrast']:.2f}\n"
169
+ if 'average_contrast' in stats:
170
+ stats_text += f"**Average Contrast:** {stats['average_contrast']:.2f}\n"
171
+
172
+ # Convert contrast image to PIL
173
+ contrast_pil = Image.fromarray(contrast_img)
174
+
175
+ return out_img, contrast_pil, stats_text
176
+
177
+ # Initialize models
178
+ setup_modules()
179
+
180
+ # Gradio Interface
181
+ title = "<h1 style='text-align: center'>NeuroNest: Abheek Pradhan - Contrast Model</h1>"
182
+ description = "<p style='font-size: 16px; margin: 5px; font-weight: w600; text-align: center'> "\
183
+ "<a href='https://github.com/lolout1/sam2Contrast' target='_blank'>Github Repo</a></p>" \
184
+ "<p style='text-align: center; margin: 5px; font-size: 14px; font-weight: w300;'>" \
185
+ "I am developing NeuroNest, a contrast detection system designed to identify areas with insufficient contrast " \
186
+ "for individuals with Alzheimer's disease. This tool leverages OneFormer's state-of-the-art segmentation " \
187
+ "capabilities trained on ADE20K dataset to detect indoor objects like floors, furniture, walls, and ceilings. " \
188
+ "By analyzing contrast ratios between adjacent segments, NeuroNest flags potential visual accessibility issues " \
189
+ "that may trigger confusion or disorientation in elderly individuals with cognitive impairments.</p>" \
190
+ "<p style='text-align: center; font-size: 14px; margin: 5px; font-weight: w300;'>" \
191
+ "[Note: When running on my Linux cluster, please request a GPU node for optimal performance. " \
192
+ "On login nodes, CUDA may not be available.]</p>"
193
+
194
+ gradio_inputs = [
195
+ gr.Image(label="Input Image", type="filepath"),
196
+ gr.Radio(choices=["Swin-L", "DiNAT-L"], value="Swin-L", label="Backbone"),
197
+ gr.Radio(
198
+ choices=["Luminance (WCAG)", "Hue", "Saturation", "Combined"],
199
+ value="Luminance (WCAG)",
200
+ label="Contrast Detection Method"
201
+ ),
202
+ gr.Slider(
203
+ minimum=1.0,
204
+ maximum=10.0,
205
+ value=4.5,
206
+ step=0.1,
207
+ label="Contrast Threshold (Lower = More Strict)"
208
+ )
209
+ ]
210
+
211
+ gradio_outputs = [
212
+ gr.Image(type="pil", label="Segmentation Result"),
213
+ gr.Image(type="pil", label="Contrast Analysis"),
214
+ gr.Markdown(label="Analysis Results")
215
+ ]
216
+
217
+ examples = [
218
+ ["examples/indoor_room.jpg", "Swin-L", "Luminance (WCAG)", 4.5],
219
+ ["examples/living_room.jpg", "DiNAT-L", "Combined", 3.0],
220
+ ]
221
+
222
+ iface = gr.Interface(
223
+ fn=segment_and_analyze_contrast,
224
+ inputs=gradio_inputs,
225
+ outputs=gradio_outputs,
226
+ examples_per_page=5,
227
+ allow_flagging="never",
228
+ examples=examples,
229
+ title=title,
230
+ description=description
231
+ )
232
+
233
+ if __name__ == "__main__":
234
+ iface.launch(server_name="0.0.0.0", share=True)
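The detectors imported from `utils` are consumed here only through one contract: `analyze(image_rgb, segmentation, threshold)` returning an annotated RGB image, a list of problem areas, and a stats dict containing at least `problem_count` (with `min_contrast`, `max_contrast`, and `average_contrast` optional, as the guards in `segment_and_analyze_contrast` show). The skeleton below is an assumption-level sketch of that contract, not the actual `utils` implementation:

import numpy as np

class NoOpContrastDetector:
    """Illustrative stand-in satisfying the interface CONTRAST_DETECTORS expects."""

    def analyze(self, image, segmentation, threshold):
        # A real detector compares colors across segment boundaries; this stub
        # returns the image unchanged and reports zero problem areas.
        contrast_image = np.asarray(image).copy()
        problem_areas = []
        stats = {"problem_count": len(problem_areas)}
        return contrast_image, problem_areas, stats

# Usage mirrors segment_and_analyze_contrast:
# contrast_img, problems, stats = NoOpContrastDetector().analyze(img_rgb, seg_mask, 4.5)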
gradio_test.py ADDED
@@ -0,0 +1,1080 @@
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ import cv2
5
+ import os
6
+ import sys
7
+ import time
8
+ import logging
9
+ from pathlib import Path
10
+ from typing import Tuple, Dict, List, Optional, Union
11
+ import gradio as gr
12
+ from huggingface_hub import hf_hub_download
13
+ import warnings
14
+ warnings.filterwarnings("ignore")
15
+
16
+ # Detectron2 imports
17
+ from detectron2.config import get_cfg
18
+ from detectron2.projects.deeplab import add_deeplab_config
19
+ from detectron2.data import MetadataCatalog
20
+ from detectron2.engine import DefaultPredictor as DetectronPredictor
21
+ from detectron2 import model_zoo
22
+ from detectron2.utils.visualizer import Visualizer, ColorMode
23
+
24
+ # OneFormer imports
25
+ try:
26
+ from oneformer import (
27
+ add_oneformer_config,
28
+ add_common_config,
29
+ add_swin_config,
30
+ add_dinat_config,
31
+ )
32
+ from demo.defaults import DefaultPredictor as OneFormerPredictor
33
+ ONEFORMER_AVAILABLE = True
34
+ except ImportError as e:
35
+ print(f"OneFormer not available: {e}")
36
+ ONEFORMER_AVAILABLE = False
37
+
38
+ # Setup logging
39
+ logging.basicConfig(level=logging.INFO)
40
+ logger = logging.getLogger(__name__)
41
+
42
+ ########################################
43
+ # GLOBAL CONFIGURATIONS
44
+ ########################################
45
+
46
+ # Device configuration
47
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
48
+ CPU_DEVICE = torch.device("cpu")
49
+ torch.set_num_threads(4)
50
+
51
+ # ADE20K class mappings for floor detection
52
+ FLOOR_CLASSES = {
53
+ 'floor': [3, 4, 13], # floor, wood floor, rug
54
+ 'carpet': [28], # carpet
55
+ 'mat': [78], # mat
56
+ }
57
+
58
+ # OneFormer configurations
59
+ ONEFORMER_CONFIG = {
60
+ "ADE20K": {
61
+ "key": "ade20k",
62
+ "swin_cfg": "configs/ade20k/oneformer_swin_large_IN21k_384_bs16_160k.yaml",
63
+ "swin_model": "shi-labs/oneformer_ade20k_swin_large",
64
+ "swin_file": "250_16_swin_l_oneformer_ade20k_160k.pth",
65
+ "width": 640
66
+ }
67
+ }
68
+
69
+ ########################################
70
+ # IMPORT UNIVERSAL CONTRAST ANALYZER
71
+ ########################################
72
+
73
+ from utils.universal_contrast_analyzer import UniversalContrastAnalyzer
74
+
75
+ # Deprecated: kept only for backward compatibility
76
+ class RobustContrastAnalyzer:
77
+ """Advanced contrast analyzer for Alzheimer's-friendly environments"""
78
+
79
+ def __init__(self, wcag_threshold: float = 4.5):
80
+ self.wcag_threshold = wcag_threshold
81
+
82
+ # ADE20K class mappings for important objects
83
+ self.semantic_classes = {
84
+ 'floor': [3, 4, 13, 28, 78], # floor, wood floor, rug, carpet, mat
85
+ 'wall': [0, 1, 9], # wall, building, brick
86
+ 'ceiling': [5], # ceiling
87
+ 'furniture': [10, 19, 15, 7, 18, 23], # sofa, chair, table, bed, armchair, cabinet
88
+ 'door': [25], # door
89
+ 'window': [8], # window
90
+ 'stairs': [53], # stairs
91
+ }
92
+
93
+ # Priority relationships for safety
94
+ self.priority_relationships = {
95
+ ('floor', 'furniture'): ('critical', 'Furniture must be clearly visible against floor'),
96
+ ('floor', 'stairs'): ('critical', 'Stairs must have clear contrast with floor'),
97
+ ('floor', 'door'): ('high', 'Door should be easily distinguishable from floor'),
98
+ ('wall', 'furniture'): ('high', 'Furniture should stand out from walls'),
99
+ ('wall', 'door'): ('high', 'Doors should be clearly visible on walls'),
100
+ ('wall', 'window'): ('medium', 'Windows should have adequate contrast'),
101
+ ('ceiling', 'wall'): ('low', 'Ceiling-wall contrast is less critical'),
102
+ }
103
+
104
+ def get_object_category(self, class_id: int) -> str:
105
+ """Map segmentation class to object category"""
106
+ for category, class_ids in self.semantic_classes.items():
107
+ if class_id in class_ids:
108
+ return category
109
+ return 'other'
110
+
111
+ def calculate_wcag_contrast(self, color1: np.ndarray, color2: np.ndarray) -> float:
112
+ """Calculate WCAG contrast ratio"""
113
+ def relative_luminance(rgb):
114
+ rgb_norm = rgb / 255.0
115
+ rgb_linear = np.where(rgb_norm <= 0.03928,
116
+ rgb_norm / 12.92,
117
+ ((rgb_norm + 0.055) / 1.055) ** 2.4)
118
+ return np.dot(rgb_linear, [0.2126, 0.7152, 0.0722])
119
+
120
+ lum1 = relative_luminance(color1)
121
+ lum2 = relative_luminance(color2)
122
+
123
+ lighter = max(lum1, lum2)
124
+ darker = min(lum1, lum2)
125
+
126
+ return (lighter + 0.05) / (darker + 0.05)
127
+
128
+ def extract_dominant_color(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
129
+ """Extract dominant color from masked region"""
130
+ if not np.any(mask):
131
+ return np.array([128, 128, 128])
132
+
133
+ masked_pixels = image[mask]
134
+ if len(masked_pixels) == 0:
135
+ return np.array([128, 128, 128])
136
+
137
+ # Use median for robustness against outliers
138
+ return np.median(masked_pixels, axis=0).astype(int)
139
+
140
+ def find_adjacent_segments(self, seg1_mask: np.ndarray, seg2_mask: np.ndarray,
141
+ min_boundary_length: int = 30) -> np.ndarray:
142
+ """Find clean boundaries between segments"""
143
+ kernel = np.ones((3, 3), np.uint8)
144
+ dilated1 = cv2.dilate(seg1_mask.astype(np.uint8), kernel, iterations=1)
145
+ dilated2 = cv2.dilate(seg2_mask.astype(np.uint8), kernel, iterations=1)
146
+
147
+ boundary = dilated1 & dilated2
148
+
149
+ # Remove small disconnected components
150
+ contours, _ = cv2.findContours(boundary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
151
+ clean_boundary = np.zeros_like(boundary)
152
+
153
+ for contour in contours:
154
+ if cv2.contourArea(contour) >= min_boundary_length:
155
+ cv2.fillPoly(clean_boundary, [contour], 1)
156
+
157
+ return clean_boundary.astype(bool)
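# find_adjacent_segments above treats two masks as adjacent wherever their 1-pixel
# dilations overlap, then drops boundary fragments smaller than min_boundary_length.
# Standalone illustration of the dilation-overlap step (illustrative values only):
import numpy as np
import cv2

a = np.zeros((10, 10), dtype=np.uint8); a[:, :5] = 1   # left half
b = np.zeros((10, 10), dtype=np.uint8); b[:, 5:] = 1   # right half
kernel = np.ones((3, 3), np.uint8)
boundary = cv2.dilate(a, kernel) & cv2.dilate(b, kernel)
print(int(boundary.sum()))  # 20 -> columns 4 and 5 along the shared edge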
158
+
159
+ def analyze_contrast(self, image: np.ndarray, segmentation: np.ndarray) -> Dict:
160
+ """Perform comprehensive contrast analysis"""
161
+ h, w = segmentation.shape
162
+ results = {
163
+ 'critical_issues': [],
164
+ 'high_issues': [],
165
+ 'medium_issues': [],
166
+ 'visualization': image.copy(),
167
+ 'statistics': {}
168
+ }
169
+
170
+ # Build segment information
171
+ unique_segments = np.unique(segmentation)
172
+ segment_info = {}
173
+
174
+ for seg_id in unique_segments:
175
+ if seg_id == 0: # Skip background
176
+ continue
177
+
178
+ mask = segmentation == seg_id
179
+ if np.sum(mask) < 100: # Skip very small segments
180
+ continue
181
+
182
+ category = self.get_object_category(seg_id)
183
+ if category == 'other':
184
+ continue
185
+
186
+ segment_info[seg_id] = {
187
+ 'category': category,
188
+ 'mask': mask,
189
+ 'color': self.extract_dominant_color(image, mask),
190
+ 'area': np.sum(mask)
191
+ }
192
+
193
+ # Analyze priority relationships
194
+ issue_counts = {'critical': 0, 'high': 0, 'medium': 0}
195
+
196
+ for seg_id1, info1 in segment_info.items():
197
+ for seg_id2, info2 in segment_info.items():
198
+ if seg_id1 >= seg_id2:
199
+ continue
200
+
201
+ # Check if this is a priority relationship
202
+ # The keys above are not alphabetically ordered, so try both orderings of the pair
+ pair = (info1['category'], info2['category'])
+ relationship = pair if pair in self.priority_relationships else (pair[1], pair[0])
203
+ if relationship not in self.priority_relationships:
204
+ continue
205
+
206
+ priority, description = self.priority_relationships[relationship]
207
+
208
+ # Check if segments are adjacent
209
+ boundary = self.find_adjacent_segments(info1['mask'], info2['mask'])
210
+ if not np.any(boundary):
211
+ continue
212
+
213
+ # Calculate contrast
214
+ wcag_contrast = self.calculate_wcag_contrast(info1['color'], info2['color'])
215
+
216
+ # Determine if there's an issue
217
+ if wcag_contrast < self.wcag_threshold:
218
+ issue = {
219
+ 'categories': (info1['category'], info2['category']),
220
+ 'contrast_ratio': wcag_contrast,
221
+ 'boundary_area': np.sum(boundary),
222
+ 'description': description,
223
+ 'priority': priority
224
+ }
225
+
226
+ # Color-code boundaries and store issues
227
+ if priority == 'critical':
228
+ results['critical_issues'].append(issue)
229
+ results['visualization'][boundary] = [255, 0, 0] # Red
230
+ issue_counts['critical'] += 1
231
+ elif priority == 'high':
232
+ results['high_issues'].append(issue)
233
+ results['visualization'][boundary] = [255, 165, 0] # Orange
234
+ issue_counts['high'] += 1
235
+ elif priority == 'medium':
236
+ results['medium_issues'].append(issue)
237
+ results['visualization'][boundary] = [255, 255, 0] # Yellow
238
+ issue_counts['medium'] += 1
239
+
240
+ # Calculate statistics
241
+ results['statistics'] = {
242
+ 'total_segments': len(segment_info),
243
+ 'total_issues': sum(issue_counts.values()),
244
+ 'critical_count': issue_counts['critical'],
245
+ 'high_count': issue_counts['high'],
246
+ 'medium_count': issue_counts['medium'],
247
+ 'wcag_threshold': self.wcag_threshold
248
+ }
249
+
250
+ return results
251
+
252
+ ########################################
253
+ # ONEFORMER INTEGRATION
254
+ ########################################
255
+
256
+ class OneFormerManager:
257
+ """Manages OneFormer model loading and inference"""
258
+
259
+ def __init__(self):
260
+ self.predictor = None
261
+ self.metadata = None
262
+ self.initialized = False
263
+
264
+ def initialize(self, backbone: str = "swin"):
265
+ """Initialize OneFormer model"""
266
+ if not ONEFORMER_AVAILABLE:
267
+ logger.error("OneFormer not available")
268
+ return False
269
+
270
+ try:
271
+ cfg = get_cfg()
272
+ add_deeplab_config(cfg)
273
+ add_common_config(cfg)
274
+ add_swin_config(cfg)
275
+ add_oneformer_config(cfg)
276
+ add_dinat_config(cfg)
277
+
278
+ config = ONEFORMER_CONFIG["ADE20K"]
279
+ cfg.merge_from_file(config["swin_cfg"])
280
+ cfg.MODEL.DEVICE = DEVICE
281
+
282
+ # Download model if not exists
283
+ model_path = hf_hub_download(
284
+ repo_id=config["swin_model"],
285
+ filename=config["swin_file"]
286
+ )
287
+ cfg.MODEL.WEIGHTS = model_path
288
+ cfg.freeze()
289
+
290
+ self.predictor = OneFormerPredictor(cfg)
291
+ self.metadata = MetadataCatalog.get(
292
+ cfg.DATASETS.TEST_PANOPTIC[0] if len(cfg.DATASETS.TEST_PANOPTIC) else "__unused"
293
+ )
294
+ self.initialized = True
295
+ logger.info("OneFormer initialized successfully")
296
+ return True
297
+
298
+ except Exception as e:
299
+ logger.error(f"Failed to initialize OneFormer: {e}")
300
+ return False
301
+
302
+ def semantic_segmentation(self, image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
303
+ """Perform semantic segmentation"""
304
+ if not self.initialized:
305
+ raise RuntimeError("OneFormer not initialized")
306
+
307
+ # Resize image to expected width
308
+ width = ONEFORMER_CONFIG["ADE20K"]["width"]
309
+ h, w = image.shape[:2]
310
+ if w != width:
311
+ scale = width / w
312
+ new_h = int(h * scale)
313
+ image_resized = cv2.resize(image, (width, new_h))
314
+ else:
315
+ image_resized = image
316
+
317
+ # Run prediction
318
+ predictions = self.predictor(image_resized, "semantic")
319
+ seg_mask = predictions["sem_seg"].argmax(dim=0).cpu().numpy()
320
+
321
+ # Create visualization
322
+ visualizer = Visualizer(
323
+ image_resized[:, :, ::-1],
324
+ metadata=self.metadata,
325
+ instance_mode=ColorMode.IMAGE
326
+ )
327
+ vis_output = visualizer.draw_sem_seg(seg_mask, alpha=0.5)
328
+ vis_image = vis_output.get_image()[:, :, ::-1] # BGR to RGB
329
+
330
+ return seg_mask, vis_image
331
+
332
+ def extract_floor_areas(self, segmentation: np.ndarray) -> np.ndarray:
333
+ """Extract floor areas from segmentation"""
334
+ floor_mask = np.zeros_like(segmentation, dtype=bool)
335
+
336
+ for class_ids in FLOOR_CLASSES.values():
337
+ for class_id in class_ids:
338
+ floor_mask |= (segmentation == class_id)
339
+
340
+ return floor_mask
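# Hedged usage sketch of OneFormerManager: initialize -> semantic_segmentation ->
# extract_floor_areas. Assumes the weights download succeeds and that
# "examples/indoor_room.jpg" (referenced in the examples elsewhere in this repo) exists.
import cv2

manager = OneFormerManager()
if manager.initialize():
    image = cv2.imread("examples/indoor_room.jpg")
    if image is not None:
        seg_mask, vis_image = manager.semantic_segmentation(image)
        floor_mask = manager.extract_floor_areas(seg_mask)
        print(f"Floor pixels: {int(floor_mask.sum())} of {seg_mask.size}")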
341
+
342
+
343
+ ########################################
344
+ # ENHANCED BLACKSPOT DETECTION WITH CLEAR VISUALIZATION
345
+ ########################################
346
+
347
+ class BlackspotDetector:
348
+ """Manages blackspot detection with MaskRCNN - Enhanced Version"""
349
+
350
+ def __init__(self, model_path: str):
351
+ self.model_path = model_path
352
+ self.predictor = None
353
+
354
+ def initialize(self, threshold: float = 0.5) -> bool:
355
+ """Initialize MaskRCNN model"""
356
+ try:
357
+ cfg = get_cfg()
358
+ cfg.merge_from_file(
359
+ model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
360
+ )
361
+ cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2 # [floors, blackspot]
362
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = threshold
363
+ cfg.MODEL.WEIGHTS = self.model_path
364
+ cfg.MODEL.DEVICE = DEVICE
365
+
366
+ self.predictor = DetectronPredictor(cfg)
367
+ logger.info("MaskRCNN blackspot detector initialized")
368
+ return True
369
+
370
+ except Exception as e:
371
+ logger.error(f"Failed to initialize blackspot detector: {e}")
372
+ return False
373
+
374
+ def create_enhanced_visualizations(self, image: np.ndarray, floor_mask: np.ndarray,
375
+ blackspot_mask: np.ndarray) -> Dict:
376
+ """Create multiple enhanced visualizations of blackspot detection"""
377
+
378
+ # 1. Pure Segmentation View (like semantic segmentation output)
379
+ segmentation_view = np.zeros((*image.shape[:2], 3), dtype=np.uint8)
380
+ segmentation_view[floor_mask] = [34, 139, 34] # Forest green for floor
381
+ segmentation_view[blackspot_mask] = [255, 0, 0] # Bright red for blackspots
382
+ segmentation_view[~(floor_mask | blackspot_mask)] = [128, 128, 128] # Gray for other areas
383
+
384
+ # 2. High Contrast Overlay
385
+ high_contrast_overlay = image.copy()
386
+ # Make background slightly darker to emphasize blackspots
387
+ high_contrast_overlay = cv2.convertScaleAbs(high_contrast_overlay, alpha=0.6, beta=0)
388
+ # Add bright overlays
389
+ high_contrast_overlay[floor_mask] = cv2.addWeighted(
390
+ high_contrast_overlay[floor_mask], 0.7,
391
+ np.full_like(high_contrast_overlay[floor_mask], [0, 255, 0]), 0.3, 0
392
+ )
393
+ high_contrast_overlay[blackspot_mask] = [255, 0, 255] # Magenta for maximum visibility
394
+
395
+ # 3. Blackspot-only View (white blackspots on black background)
396
+ blackspot_only = np.zeros((*image.shape[:2], 3), dtype=np.uint8)
397
+ blackspot_only[blackspot_mask] = [255, 255, 255] # White blackspots
398
+ blackspot_only[floor_mask & ~blackspot_mask] = [64, 64, 64] # Dark gray for floor areas
399
+
400
+ # 4. Side-by-side comparison
401
+ h, w = image.shape[:2]
402
+ side_by_side = np.zeros((h, w * 2, 3), dtype=np.uint8)
403
+ side_by_side[:, :w] = image
404
+ side_by_side[:, w:] = segmentation_view
405
+
406
+ # Add text labels
407
+ cv2.putText(side_by_side, "Original", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
408
+ cv2.putText(side_by_side, "Blackspot Detection", (w + 10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
409
+
410
+ # 5. Annotated view with bounding boxes and labels
411
+ annotated_view = image.copy()
412
+
413
+ # Find blackspot contours for bounding boxes
414
+ blackspot_contours, _ = cv2.findContours(
415
+ blackspot_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
416
+ )
417
+
418
+ for i, contour in enumerate(blackspot_contours):
419
+ if cv2.contourArea(contour) > 50: # Filter small artifacts
420
+ # Draw bounding box
421
+ x, y, w, h = cv2.boundingRect(contour)
422
+ cv2.rectangle(annotated_view, (x, y), (x + w, y + h), (255, 0, 0), 2)
423
+
424
+ # Draw contour
425
+ cv2.drawContours(annotated_view, [contour], -1, (255, 0, 255), 2)
426
+
427
+ # Add label
428
+ area = cv2.contourArea(contour)
429
+ label = f"Blackspot {i+1}: {area:.0f}px"
430
+ cv2.putText(annotated_view, label, (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 0), 2)
431
+
432
+ return {
433
+ 'segmentation_view': segmentation_view,
434
+ 'high_contrast_overlay': high_contrast_overlay,
435
+ 'blackspot_only': blackspot_only,
436
+ 'side_by_side': side_by_side,
437
+ 'annotated_view': annotated_view
438
+ }
439
+
440
+ def detect_blackspots(self, image: np.ndarray, floor_prior: Optional[np.ndarray] = None) -> Dict:
441
+ """Detect blackspots with enhanced visualizations"""
442
+ if self.predictor is None:
443
+ raise RuntimeError("Blackspot detector not initialized")
444
+
445
+ # Get original image dimensions
446
+ original_h, original_w = image.shape[:2]
447
+
448
+ # Handle floor prior shape mismatch
449
+ processed_image = image.copy()
450
+ if floor_prior is not None:
451
+ prior_h, prior_w = floor_prior.shape
452
+
453
+ # Resize floor_prior to match original image if needed
454
+ if (prior_h, prior_w) != (original_h, original_w):
455
+ logger.info(f"Resizing floor prior from {(prior_h, prior_w)} to {(original_h, original_w)}")
456
+ floor_prior_resized = cv2.resize(
457
+ floor_prior.astype(np.uint8),
458
+ (original_w, original_h),
459
+ interpolation=cv2.INTER_NEAREST
460
+ ).astype(bool)
461
+ else:
462
+ floor_prior_resized = floor_prior
463
+ else:
464
+ floor_prior_resized = None
465
+
466
+ # Run detection on the processed image
467
+ try:
468
+ outputs = self.predictor(processed_image)
469
+ instances = outputs["instances"].to("cpu")
470
+ except Exception as e:
471
+ logger.error(f"Error in MaskRCNN prediction: {e}")
472
+ # Return empty results
473
+ empty_mask = np.zeros(image.shape[:2], dtype=bool)
474
+ return {
475
+ 'visualization': image,
476
+ 'floor_mask': empty_mask,
477
+ 'blackspot_mask': empty_mask,
478
+ 'floor_area': 0,
479
+ 'blackspot_area': 0,
480
+ 'coverage_percentage': 0,
481
+ 'num_detections': 0,
482
+ 'avg_confidence': 0.0,
483
+ 'enhanced_views': self.create_enhanced_visualizations(image, empty_mask, empty_mask)
484
+ }
485
+
486
+ # Process results
487
+ if len(instances) == 0:
488
+ # No detections
489
+ combined_floor = floor_prior_resized if floor_prior_resized is not None else np.zeros(image.shape[:2], dtype=bool)
490
+ combined_blackspot = np.zeros(image.shape[:2], dtype=bool)
491
+ blackspot_scores = []
492
+ else:
493
+ pred_classes = instances.pred_classes.numpy()
494
+ pred_masks = instances.pred_masks.numpy()
495
+ scores = instances.scores.numpy()
496
+
497
+ # Separate floor and blackspot masks
498
+ floor_indices = pred_classes == 0
499
+ blackspot_indices = pred_classes == 1
500
+
501
+ floor_masks = pred_masks[floor_indices] if np.any(floor_indices) else []
502
+ blackspot_masks = pred_masks[blackspot_indices] if np.any(blackspot_indices) else []
503
+ blackspot_scores = scores[blackspot_indices] if np.any(blackspot_indices) else []
504
+
505
+ # Combine masks
506
+ combined_floor = np.zeros(image.shape[:2], dtype=bool)
507
+ combined_blackspot = np.zeros(image.shape[:2], dtype=bool)
508
+
509
+ for mask in floor_masks:
510
+ combined_floor |= mask
511
+
512
+ for mask in blackspot_masks:
513
+ combined_blackspot |= mask
514
+
515
+ # Apply floor prior if available
516
+ if floor_prior_resized is not None:
517
+ # Combine OneFormer floor detection with MaskRCNN floor detection
518
+ combined_floor |= floor_prior_resized
519
+ # Keep only blackspots that are on floors
520
+ combined_blackspot &= combined_floor
521
+
522
+ # Create all enhanced visualizations
523
+ enhanced_views = self.create_enhanced_visualizations(image, combined_floor, combined_blackspot)
524
+
525
+ # Calculate statistics
526
+ floor_area = int(np.sum(combined_floor))
527
+ blackspot_area = int(np.sum(combined_blackspot))
528
+ coverage_percentage = (blackspot_area / floor_area * 100) if floor_area > 0 else 0
529
+
530
+ # Count individual blackspot instances
531
+ blackspot_contours, _ = cv2.findContours(
532
+ combined_blackspot.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
533
+ )
534
+ actual_detections = len([c for c in blackspot_contours if cv2.contourArea(c) > 50])
535
+
536
+ return {
537
+ 'visualization': enhanced_views['high_contrast_overlay'], # Main view
538
+ 'floor_mask': combined_floor,
539
+ 'blackspot_mask': combined_blackspot,
540
+ 'floor_area': floor_area,
541
+ 'blackspot_area': blackspot_area,
542
+ 'coverage_percentage': coverage_percentage,
543
+ 'num_detections': actual_detections,
544
+ 'avg_confidence': float(np.mean(blackspot_scores)) if len(blackspot_scores) > 0 else 0.0,
545
+ 'enhanced_views': enhanced_views # All visualization options
546
+ }
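A hedged sketch of driving the detector above on its own; the module name and image path are placeholders, while the result keys match the dictionary returned by `detect_blackspots`:

```python
# Hypothetical standalone use of BlackspotDetector (paths/module name are placeholders).
import cv2
from neuronest_app import BlackspotDetector  # hypothetical import path for this file

detector = BlackspotDetector("./output_floor_blackspot/model_0004999.pth")
if detector.initialize(threshold=0.5):
    rgb = cv2.cvtColor(cv2.imread("room.jpg"), cv2.COLOR_BGR2RGB)
    result = detector.detect_blackspots(rgb, floor_prior=None)
    print(f"{result['num_detections']} blackspots cover "
          f"{result['coverage_percentage']:.2f}% of the detected floor")
    # Every entry in 'enhanced_views' is an HxWx3 RGB uint8 image ready to save or display.
    for name, view in result["enhanced_views"].items():
        cv2.imwrite(f"blackspot_{name}.png", cv2.cvtColor(view, cv2.COLOR_RGB2BGR))
```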
547
+ ########################################
548
+ # FIXED MAIN APPLICATION CLASS
549
+ ########################################
550
+
551
+ class NeuroNestApp:
552
+ """Main application class integrating all components - FIXED VERSION"""
553
+
554
+ def __init__(self):
555
+ self.oneformer = OneFormerManager()
556
+ self.blackspot_detector = None
557
+ self.contrast_analyzer = UniversalContrastAnalyzer()
558
+ self.initialized = False
559
+
560
+ def initialize(self, blackspot_model_path: str = "./output_floor_blackspot/model_0004999.pth"):
561
+ """Initialize all components"""
562
+ logger.info("Initializing NeuroNest application...")
563
+
564
+ # Initialize OneFormer
565
+ oneformer_success = self.oneformer.initialize()
566
+
567
+ # Initialize blackspot detector if model exists
568
+ blackspot_success = False
569
+ if os.path.exists(blackspot_model_path):
570
+ self.blackspot_detector = BlackspotDetector(blackspot_model_path)
571
+ blackspot_success = True
572
+ else:
573
+ logger.warning(f"Blackspot model not found at {blackspot_model_path}")
574
+
575
+ self.initialized = oneformer_success
576
+ return oneformer_success, blackspot_success
577
+
578
+ def analyze_image(self,
579
+ image_path: str,
580
+ blackspot_threshold: float = 0.5,
581
+ contrast_threshold: float = 4.5,
582
+ enable_blackspot: bool = True,
583
+ enable_contrast: bool = True) -> Dict:
584
+ """Perform complete image analysis - FIXED VERSION"""
585
+
586
+ if not self.initialized:
587
+ return {"error": "Application not properly initialized"}
588
+
589
+ try:
590
+ # Load and preprocess image
591
+ image = cv2.imread(image_path)
592
+ if image is None:
593
+ return {"error": "Could not load image"}
594
+
595
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
596
+ logger.info(f"Loaded image with shape: {image_rgb.shape}")
597
+
598
+ results = {
599
+ 'original_image': image_rgb,
600
+ 'segmentation': None,
601
+ 'blackspot': None,
602
+ 'contrast': None,
603
+ 'statistics': {}
604
+ }
605
+
606
+ # 1. Semantic Segmentation (always performed)
607
+ logger.info("Running semantic segmentation...")
608
+ seg_mask, seg_visualization = self.oneformer.semantic_segmentation(image_rgb)
609
+ logger.info(f"Segmentation mask shape: {seg_mask.shape}")
610
+
611
+ results['segmentation'] = {
612
+ 'visualization': seg_visualization,
613
+ 'mask': seg_mask
614
+ }
615
+
616
+ # Extract floor areas for blackspot detection
617
+ floor_prior = self.oneformer.extract_floor_areas(seg_mask)
618
+ logger.info(f"Floor prior shape: {floor_prior.shape}, total floor pixels: {np.sum(floor_prior)}")
619
+
620
+ # 2. Blackspot Detection (if enabled and model available)
621
+ if enable_blackspot and self.blackspot_detector is not None:
622
+ logger.info("Running blackspot detection...")
623
+ try:
624
+ self.blackspot_detector.initialize(threshold=blackspot_threshold)
625
+ blackspot_results = self.blackspot_detector.detect_blackspots(image_rgb, floor_prior)
626
+ results['blackspot'] = blackspot_results
627
+ logger.info("Blackspot detection completed successfully")
628
+ except Exception as e:
629
+ logger.error(f"Error in blackspot detection: {e}")
630
+ # Continue without blackspot results
631
+ results['blackspot'] = None
632
+
633
+ # 3. Contrast Analysis (if enabled)
634
+ if enable_contrast:
635
+ logger.info("Running contrast analysis...")
636
+ try:
637
+ # Use the resized image for contrast analysis to match segmentation
638
+ width = ONEFORMER_CONFIG["ADE20K"]["width"]
639
+ h, w = image_rgb.shape[:2]
640
+ if w != width:
641
+ scale = width / w
642
+ new_h = int(h * scale)
643
+ image_for_contrast = cv2.resize(image_rgb, (width, new_h))
644
+ else:
645
+ image_for_contrast = image_rgb
646
+
647
+ contrast_results = self.contrast_analyzer.analyze_contrast(image_for_contrast, seg_mask)
648
+ results['contrast'] = contrast_results
649
+ logger.info("Contrast analysis completed successfully")
650
+ except Exception as e:
651
+ logger.error(f"Error in contrast analysis: {e}")
652
+ # Continue without contrast results
653
+ results['contrast'] = None
654
+
655
+ # 4. Generate combined statistics
656
+ stats = self._generate_statistics(results)
657
+ results['statistics'] = stats
658
+
659
+ logger.info("Image analysis completed successfully")
660
+ return results
661
+
662
+ except Exception as e:
663
+ logger.error(f"Error in image analysis: {e}")
664
+ import traceback
665
+ traceback.print_exc()
666
+ return {"error": f"Analysis failed: {str(e)}"}
667
+
668
+ def _generate_statistics(self, results: Dict) -> Dict:
669
+ """Generate comprehensive statistics"""
670
+ stats = {}
671
+
672
+ # Segmentation stats
673
+ if results['segmentation']:
674
+ unique_classes = np.unique(results['segmentation']['mask'])
675
+ stats['segmentation'] = {
676
+ 'num_classes': len(unique_classes),
677
+ 'image_size': results['segmentation']['mask'].shape
678
+ }
679
+
680
+ # Blackspot stats
681
+ if results['blackspot']:
682
+ bs = results['blackspot']
683
+ stats['blackspot'] = {
684
+ 'floor_area_pixels': bs['floor_area'],
685
+ 'blackspot_area_pixels': bs['blackspot_area'],
686
+ 'coverage_percentage': bs['coverage_percentage'],
687
+ 'num_detections': bs['num_detections'],
688
+ 'avg_confidence': bs['avg_confidence']
689
+ }
690
+
691
+ # Contrast stats
692
+ if results['contrast']:
693
+ cs = results['contrast']['statistics']
694
+ # Count issues by severity
695
+ critical_count = sum(1 for issue in results['contrast'].get('issues', []) if issue['severity'] == 'critical')
696
+ high_count = sum(1 for issue in results['contrast'].get('issues', []) if issue['severity'] == 'high')
697
+ medium_count = sum(1 for issue in results['contrast'].get('issues', []) if issue['severity'] == 'medium')
698
+
699
+ stats['contrast'] = {
700
+ 'total_issues': cs.get('low_contrast_pairs', 0),
701
+ 'critical_issues': critical_count,
702
+ 'high_priority_issues': high_count,
703
+ 'medium_priority_issues': medium_count,
704
+ 'segments_analyzed': cs.get('total_segments', 0),
705
+ 'floor_object_issues': cs.get('floor_object_issues', 0)
706
+ }
707
+
708
+ return stats
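A minimal headless sketch of the class above, useful for batch runs without the Gradio UI; the module name and image path are placeholders:

```python
# Hypothetical headless run of the full pipeline (no Gradio involved).
from neuronest_app import NeuroNestApp  # hypothetical import path for this file

app = NeuroNestApp()
oneformer_ok, blackspot_ok = app.initialize()
if oneformer_ok:
    results = app.analyze_image(
        "room.jpg",                     # hypothetical image path
        blackspot_threshold=0.5,
        contrast_threshold=4.5,
        enable_blackspot=blackspot_ok,  # only if the MaskRCNN weights were found
        enable_contrast=True,
    )
    if "error" not in results:
        print(results["statistics"])
```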
709
+ ########################################
710
+ # GRADIO INTERFACE
711
+ ########################################
712
+
713
+ ########################################
714
+ # ENHANCED GRADIO INTERFACE WITH MULTIPLE BLACKSPOT VIEWS
715
+ ########################################
716
+
717
+ def create_gradio_interface():
718
+ """Create the enhanced Gradio interface with better blackspot visualization"""
719
+
720
+ # Initialize the application
721
+ app = NeuroNestApp()
722
+ oneformer_ok, blackspot_ok = app.initialize()
723
+
724
+ if not oneformer_ok:
725
+ raise RuntimeError("Failed to initialize OneFormer")
726
+
727
+ def analyze_wrapper(image_path, blackspot_threshold, contrast_threshold,
728
+ enable_blackspot, enable_contrast, blackspot_view_type):
729
+ """Enhanced wrapper function for Gradio interface"""
730
+ if image_path is None:
731
+ return None, None, None, None, "Please upload an image"  # matches the 5 Gradio outputs
732
+
733
+ results = app.analyze_image(
734
+ image_path=image_path,
735
+ blackspot_threshold=blackspot_threshold,
736
+ contrast_threshold=contrast_threshold,
737
+ enable_blackspot=enable_blackspot,
738
+ enable_contrast=enable_contrast
739
+ )
740
+
741
+ if "error" in results:
742
+ return None, None, None, None, f"Error: {results['error']}"
743
+
744
+ # Extract outputs
745
+ seg_output = results['segmentation']['visualization'] if results['segmentation'] else None
746
+
747
+ # Enhanced blackspot output selection
748
+ blackspot_output = None
749
+ blackspot_segmentation = None
750
+ if results['blackspot'] and 'enhanced_views' in results['blackspot']:
751
+ views = results['blackspot']['enhanced_views']
752
+
753
+ # Select view based on user choice
754
+ if blackspot_view_type == "High Contrast":
755
+ blackspot_output = views['high_contrast_overlay']
756
+ elif blackspot_view_type == "Segmentation Only":
757
+ blackspot_output = views['segmentation_view']
758
+ elif blackspot_view_type == "Blackspots Only":
759
+ blackspot_output = views['blackspot_only']
760
+ elif blackspot_view_type == "Side by Side":
761
+ blackspot_output = views['side_by_side']
762
+ elif blackspot_view_type == "Annotated":
763
+ blackspot_output = views['annotated_view']
764
+ else:
765
+ blackspot_output = views['high_contrast_overlay']
766
+
767
+ # Always provide segmentation view for the dedicated tab
768
+ blackspot_segmentation = views['segmentation_view']
769
+
770
+ contrast_output = results['contrast']['visualization'] if results['contrast'] else None
771
+
772
+ # Generate report
773
+ report = generate_analysis_report(results)
774
+
775
+ return seg_output, blackspot_output, blackspot_segmentation, contrast_output, report
776
+
777
+ # Build the markdown report shown in the Analysis Report tab
778
+ def generate_analysis_report(results: Dict) -> str:
779
+ """Generate enhanced analysis report text"""
780
+ report = ["# NeuroNest Analysis Report\n"]
781
+
782
+ # Segmentation results
783
+ if results['segmentation']:
784
+ stats = results['statistics'].get('segmentation', {})
785
+ report.append(f"## 🎯 Semantic Segmentation")
786
+ report.append(f"- **Objects detected:** {stats.get('num_classes', 'N/A')}")
787
+ report.append(f"- **Image size:** {stats.get('image_size', 'N/A')}")
788
+ report.append("")
789
+
790
+ # Enhanced blackspot results
791
+ if results['blackspot']:
792
+ bs_stats = results['statistics'].get('blackspot', {})
793
+ report.append(f"## ⚫ Blackspot Detection")
794
+ report.append(f"- **Floor area:** {bs_stats.get('floor_area_pixels', 0):,} pixels")
795
+ report.append(f"- **Blackspot area:** {bs_stats.get('blackspot_area_pixels', 0):,} pixels")
796
+ report.append(f"- **Coverage:** {bs_stats.get('coverage_percentage', 0):.2f}% of floor")
797
+ report.append(f"- **Individual blackspots:** {bs_stats.get('num_detections', 0)}")
798
+ report.append(f"- **Average confidence:** {bs_stats.get('avg_confidence', 0):.2f}")
799
+
800
+ # Risk assessment
801
+ coverage = bs_stats.get('coverage_percentage', 0)
802
+ if coverage > 5:
803
+ report.append(f"- **⚠️ Risk Level:** HIGH - Significant blackspot coverage detected")
804
+ elif coverage > 1:
805
+ report.append(f"- **⚠️ Risk Level:** MEDIUM - Moderate blackspot coverage")
806
+ elif coverage > 0:
807
+ report.append(f"- **✓ Risk Level:** LOW - Minimal blackspot coverage")
808
+ else:
809
+ report.append(f"- **✓ Risk Level:** NONE - No blackspots detected")
810
+ report.append("")
811
+
812
+ # Contrast analysis results (updated for universal analyzer)
813
+ if results['contrast']:
814
+ contrast_stats = results['statistics'].get('contrast', {})
815
+ report.append(f"## 🎨 Universal Contrast Analysis")
816
+ report.append(f"- **Adjacent pairs analyzed:** {results['contrast']['statistics'].get('analyzed_pairs', 0)}")
817
+ report.append(f"- **Total contrast issues:** {contrast_stats.get('total_issues', 0)}")
818
+ report.append(f"- **🔴 Critical:** {contrast_stats.get('critical_issues', 0)}")
819
+ report.append(f"- **🟠 High priority:** {contrast_stats.get('high_priority_issues', 0)}")
820
+ report.append(f"- **🟡 Medium priority:** {contrast_stats.get('medium_priority_issues', 0)}")
821
+ report.append(f"- **⚠️ Floor-object issues:** {contrast_stats.get('floor_object_issues', 0)}")
822
+ report.append("")
823
+
824
+ # Add detailed issues
825
+ issues = results['contrast'].get('issues', [])
826
+ if issues:
827
+ # Group by severity
828
+ critical_issues = [i for i in issues if i['severity'] == 'critical']
829
+ high_issues = [i for i in issues if i['severity'] == 'high']
830
+
831
+ if critical_issues:
832
+ report.append("### 🔴 Critical Issues (Immediate Attention Required)")
833
+ for issue in critical_issues[:5]: # Show top 5
834
+ cats = f"{issue['categories'][0]} ↔ {issue['categories'][1]}"
835
+ ratio = issue['wcag_ratio']
836
+ report.append(f"- **{cats}**: {ratio:.1f}:1 contrast ratio")
837
+ if issue['is_floor_object']:
838
+ report.append(f" _⚠️ Object on floor - high visibility required!_")
839
+ report.append("")
840
+
841
+ if high_issues:
842
+ report.append("### 🟠 High Priority Issues")
843
+ for issue in high_issues[:3]: # Show top 3
844
+ cats = f"{issue['categories'][0]} ↔ {issue['categories'][1]}"
845
+ ratio = issue['wcag_ratio']
846
+ report.append(f"- **{cats}**: {ratio:.1f}:1 contrast ratio")
847
+ report.append("")
848
+
849
+ # Enhanced recommendations
850
+ report.append("## 📋 Recommendations")
851
+
852
+ # Blackspot-specific recommendations
853
+ if results['blackspot']:
854
+ coverage = results['statistics'].get('blackspot', {}).get('coverage_percentage', 0)
855
+ if coverage > 0:
856
+ report.append("### Blackspot Mitigation")
857
+ report.append("- Remove or replace dark-colored floor materials in detected areas")
858
+ report.append("- Improve lighting in blackspot areas")
859
+ report.append("- Consider using light-colored rugs or mats to cover blackspots")
860
+ report.append("- Add visual cues like contrasting tape around problem areas")
861
+ report.append("")
862
+
863
+ # Contrast-specific recommendations
864
+ contrast_issues = results['statistics'].get('contrast', {}).get('total_issues', 0)
865
+ if contrast_issues > 0:
866
+ report.append("### Contrast Improvements")
867
+ report.append("- Increase lighting in low-contrast areas")
868
+ report.append("- Use contrasting colors for furniture and floors")
869
+ report.append("- Add visual markers for important boundaries")
870
+ report.append("- Consider color therapy guidelines for dementia")
871
+ report.append("")
872
+
873
+ if results['statistics'].get('blackspot', {}).get('coverage_percentage', 0) == 0 and contrast_issues == 0:
874
+ report.append("✅ **Environment Assessment: EXCELLENT**")
875
+ report.append("No significant safety issues detected. This environment appears well-suited for individuals with Alzheimer's.")
876
+
877
+ return "\n".join(report)
878
+
879
+ # Create the interface with enhanced controls
880
+ title = "🧠 NeuroNest: Advanced Environment Analysis for Alzheimer's Care"
881
+ description = """
882
+ **Comprehensive analysis system for creating Alzheimer's-friendly environments**
883
+
884
+ This application integrates:
885
+ - **Semantic Segmentation**: Identifies rooms, furniture, and objects
886
+ - **Enhanced Blackspot Detection**: Locates and visualizes dangerous black areas on floors
887
+ - **Contrast Analysis**: Evaluates color contrast for visual accessibility
888
+ """
889
+
890
+ with gr.Blocks(
891
+ title=title,
892
+ theme=gr.themes.Soft(primary_hue="orange", secondary_hue="blue"),
893
+ css="""
894
+ .main-header { text-align: center; margin-bottom: 2rem; }
895
+ .analysis-section { border: 2px solid #f0f0f0; border-radius: 10px; padding: 1rem; margin: 1rem 0; }
896
+ .critical-text { color: #ff0000; font-weight: bold; }
897
+ .high-text { color: #ff8800; font-weight: bold; }
898
+ .medium-text { color: #ffaa00; font-weight: bold; }
899
+ """
900
+ ) as interface:
901
+
902
+ gr.Markdown(f"# {title}")
903
+ gr.Markdown(description)
904
+
905
+ with gr.Row():
906
+ # Input Column
907
+ with gr.Column(scale=1):
908
+ # Image upload
909
+ image_input = gr.Image(
910
+ label="📸 Upload Room Image",
911
+ type="filepath",
912
+ height=300
913
+ )
914
+
915
+ # Analysis settings
916
+ with gr.Accordion("🔧 Analysis Settings", open=True):
917
+ enable_blackspot = gr.Checkbox(
918
+ value=blackspot_ok,
919
+ label="Enable Blackspot Detection",
920
+ interactive=blackspot_ok
921
+ )
922
+
923
+ blackspot_threshold = gr.Slider(
924
+ minimum=0.1,
925
+ maximum=0.9,
926
+ value=0.5,
927
+ step=0.05,
928
+ label="Blackspot Detection Threshold",
929
+ visible=blackspot_ok
930
+ )
931
+
932
+ # NEW: Blackspot visualization options
933
+ blackspot_view_type = gr.Radio(
934
+ choices=["High Contrast", "Segmentation Only", "Blackspots Only", "Side by Side", "Annotated"],
935
+ value="High Contrast",
936
+ label="Blackspot Visualization Style",
937
+ visible=blackspot_ok
938
+ )
939
+
940
+ enable_contrast = gr.Checkbox(
941
+ value=True,
942
+ label="Enable Contrast Analysis"
943
+ )
944
+
945
+ contrast_threshold = gr.Slider(
946
+ minimum=1.0,
947
+ maximum=10.0,
948
+ value=4.5,
949
+ step=0.1,
950
+ label="WCAG Contrast Threshold"
951
+ )
952
+
953
+ # Analysis button
954
+ analyze_button = gr.Button(
955
+ "🔍 Analyze Environment",
956
+ variant="primary",
957
+ size="lg"
958
+ )
959
+
960
+ # Output Column
961
+ with gr.Column(scale=2):
962
+ # Main display (Segmentation by default)
963
+ main_display = gr.Image(
964
+ label="🎯 Object Detection & Segmentation",
965
+ height=400,
966
+ interactive=False
967
+ )
968
+
969
+ # Enhanced analysis tabs
970
+ with gr.Tabs():
971
+ with gr.Tab("📊 Analysis Report"):
972
+ analysis_report = gr.Markdown(
973
+ value="Upload an image and click 'Analyze Environment' to see results.",
974
+ elem_classes=["analysis-section"]
975
+ )
976
+
977
+ if blackspot_ok:
978
+ with gr.Tab("⚫ Blackspot Detection"):
979
+ blackspot_display = gr.Image(
980
+ label="Blackspot Analysis (Selected View)",
981
+ height=300,
982
+ interactive=False
983
+ )
984
+
985
+ with gr.Tab("🔍 Blackspot Segmentation"):
986
+ blackspot_segmentation_display = gr.Image(
987
+ label="Pure Blackspot Segmentation",
988
+ height=300,
989
+ interactive=False
990
+ )
991
+ else:
992
+ blackspot_display = gr.Image(visible=False)
993
+ blackspot_segmentation_display = gr.Image(visible=False)
994
+
995
+ with gr.Tab("🎨 Contrast Analysis"):
996
+ contrast_display = gr.Image(
997
+ label="Contrast Issues Visualization",
998
+ height=300,
999
+ interactive=False
1000
+ )
1001
+
1002
+ # Connect the interface
1003
+ analyze_button.click(
1004
+ fn=analyze_wrapper,
1005
+ inputs=[
1006
+ image_input,
1007
+ blackspot_threshold,
1008
+ contrast_threshold,
1009
+ enable_blackspot,
1010
+ enable_contrast,
1011
+ blackspot_view_type
1012
+ ],
1013
+ outputs=[
1014
+ main_display,
1015
+ blackspot_display,
1016
+ blackspot_segmentation_display,
1017
+ contrast_display,
1018
+ analysis_report
1019
+ ]
1020
+ )
1021
+
1022
+ # Example images (optional)
1023
+ example_dir = Path("examples")
1024
+ if example_dir.exists():
1025
+ examples = [
1026
+ [str(img), 0.5, 4.5, True, True, "High Contrast"]
1027
+ for img in example_dir.glob("*.jpg")
1028
+ ]
1029
+
1030
+ if examples:
1031
+ gr.Examples(
1032
+ examples=examples[:3], # Show max 3 examples
1033
+ inputs=[
1034
+ image_input,
1035
+ blackspot_threshold,
1036
+ contrast_threshold,
1037
+ enable_blackspot,
1038
+ enable_contrast,
1039
+ blackspot_view_type
1040
+ ],
1041
+ outputs=[
1042
+ main_display,
1043
+ blackspot_display,
1044
+ blackspot_segmentation_display,
1045
+ contrast_display,
1046
+ analysis_report
1047
+ ],
1048
+ fn=analyze_wrapper,
1049
+ label="🖼️ Example Images"
1050
+ )
1051
+
1052
+ # Footer
1053
+ gr.Markdown("""
1054
+ ---
1055
+ **NeuroNest** - Advanced AI for Alzheimer's-friendly environments
1056
+ *Helping create safer, more accessible spaces for cognitive health*
1057
+ """)
1058
+
1059
+ return interface
1060
+
1061
+ ########################################
1062
+ # MAIN EXECUTION - FIXED
1063
+ ########################################
1064
+
1065
+ if __name__ == "__main__":
1066
+ print(f"🚀 Starting NeuroNest on {DEVICE}")
1067
+ print(f"OneFormer available: {ONEFORMER_AVAILABLE}")
1068
+
1069
+ try:
1070
+ interface = create_gradio_interface()
1071
+
1072
+ # Fixed launch call - removed incompatible parameters
1073
+ interface.queue(max_size=10).launch(
1074
+ server_name="0.0.0.0",
1075
+ server_port=7860,
1076
+ share=True
1077
+ )
1078
+ except Exception as e:
1079
+ logger.error(f"Failed to launch application: {e}")
1080
+ raise
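The report above grades adjacent segments against a WCAG contrast threshold (default 4.5:1). For reference, a minimal sketch of the published WCAG 2.x relative-luminance and contrast-ratio formula; the `UniversalContrastAnalyzer` used in this app may implement this differently internally:

```python
# Reference WCAG 2.x contrast ratio between two sRGB colors (0-255 per channel).
def _linear(c: float) -> float:
    c /= 255.0
    return c / 12.92 if c <= 0.03928 else ((c + 0.055) / 1.055) ** 2.4

def relative_luminance(rgb) -> float:
    r, g, b = (_linear(v) for v in rgb)
    return 0.2126 * r + 0.7152 * g + 0.0722 * b

def contrast_ratio(rgb1, rgb2) -> float:
    brighter, darker = sorted((relative_luminance(rgb1), relative_luminance(rgb2)), reverse=True)
    return (brighter + 0.05) / (darker + 0.05)

print(round(contrast_ratio((90, 90, 90), (60, 60, 60)), 1))   # ~1.6, fails the 4.5:1 threshold
print(round(contrast_ratio((255, 255, 255), (0, 0, 0)), 1))   # 21.0, the maximum possible ratio
```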
oneformer/.DS_Store ADDED
Binary file (6.15 kB). View file
 
oneformer/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from . import data # register all new datasets
3
+ from . import modeling
4
+
5
+ # config
6
+ from .config import *
7
+
8
+ # models
9
+ from .oneformer_model import OneFormer
oneformer/config.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ from detectron2.config import CfgNode as CN
4
+
5
+ __all__ = ["add_common_config", "add_oneformer_config", "add_swin_config",
6
+ "add_dinat_config", "add_beit_adapter_config", "add_convnext_config"]
7
+
8
+ def add_common_config(cfg):
9
+ """
10
+ Add config for common configuration
11
+ """
12
+ # data config
13
+ # select the dataset mapper
14
+ cfg.INPUT.DATASET_MAPPER_NAME = "oneformer_unified"
15
+ # Color augmentation
16
+ cfg.INPUT.COLOR_AUG_SSD = False
17
+ # We retry random cropping until no single category in semantic segmentation GT occupies more
18
+ # than `SINGLE_CATEGORY_MAX_AREA` part of the crop.
19
+ cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0
20
+ # Pad image and segmentation GT in dataset mapper.
21
+ cfg.INPUT.SIZE_DIVISIBILITY = -1
22
+
23
+ cfg.INPUT.TASK_SEQ_LEN = 77
24
+ cfg.INPUT.MAX_SEQ_LEN = 77
25
+
26
+ cfg.INPUT.TASK_PROB = CN()
27
+ cfg.INPUT.TASK_PROB.SEMANTIC = 0.33
28
+ cfg.INPUT.TASK_PROB.INSTANCE = 0.66
29
+
30
+ # test dataset
31
+ cfg.DATASETS.TEST_PANOPTIC = ("",)
32
+ cfg.DATASETS.TEST_INSTANCE = ("",)
33
+ cfg.DATASETS.TEST_SEMANTIC = ("",)
34
+
35
+ # solver config
36
+ # weight decay on embedding
37
+ cfg.SOLVER.WEIGHT_DECAY_EMBED = 0.0
38
+ # optimizer
39
+ cfg.SOLVER.OPTIMIZER = "ADAMW"
40
+ cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1
41
+
42
+ # wandb
43
+ cfg.WANDB = CN()
44
+ cfg.WANDB.PROJECT = "unified_dense_recognition"
45
+ cfg.WANDB.NAME = None
46
+
47
+ cfg.MODEL.IS_TRAIN = False
48
+ cfg.MODEL.IS_DEMO = True
49
+
50
+ # text encoder config
51
+ cfg.MODEL.TEXT_ENCODER = CN()
52
+
53
+ cfg.MODEL.TEXT_ENCODER.WIDTH = 256
54
+ cfg.MODEL.TEXT_ENCODER.CONTEXT_LENGTH = 77
55
+ cfg.MODEL.TEXT_ENCODER.NUM_LAYERS = 12
56
+ cfg.MODEL.TEXT_ENCODER.VOCAB_SIZE = 49408
57
+ cfg.MODEL.TEXT_ENCODER.PROJ_NUM_LAYERS = 2
58
+ cfg.MODEL.TEXT_ENCODER.N_CTX = 16
59
+
60
+ # mask_former inference config
61
+ cfg.MODEL.TEST = CN()
62
+ cfg.MODEL.TEST.SEMANTIC_ON = True
63
+ cfg.MODEL.TEST.INSTANCE_ON = False
64
+ cfg.MODEL.TEST.PANOPTIC_ON = False
65
+ cfg.MODEL.TEST.DETECTION_ON = False
66
+ cfg.MODEL.TEST.OBJECT_MASK_THRESHOLD = 0.0
67
+ cfg.MODEL.TEST.OVERLAP_THRESHOLD = 0.0
68
+ cfg.MODEL.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE = False
69
+ cfg.MODEL.TEST.TASK = "panoptic"
70
+
71
+ # TEST AUG Slide
72
+ cfg.TEST.AUG.IS_SLIDE = False
73
+ cfg.TEST.AUG.CROP_SIZE = (640, 640)
74
+ cfg.TEST.AUG.STRIDE = (426, 426)
75
+ cfg.TEST.AUG.SCALE = (2048, 640)
76
+ cfg.TEST.AUG.SETR_MULTI_SCALE = True
77
+ cfg.TEST.AUG.KEEP_RATIO = True
78
+ cfg.TEST.AUG.SIZE_DIVISOR = 32
79
+
80
+ # pixel decoder config
81
+ cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
82
+ # adding transformer in pixel decoder
83
+ cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 0
84
+ # pixel decoder
85
+ cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME = "BasePixelDecoder"
86
+ cfg.MODEL.SEM_SEG_HEAD.SEM_EMBED_DIM = 256
87
+ cfg.MODEL.SEM_SEG_HEAD.INST_EMBED_DIM = 256
88
+
89
+ # LSJ aug
90
+ cfg.INPUT.IMAGE_SIZE = 1024
91
+ cfg.INPUT.MIN_SCALE = 0.1
92
+ cfg.INPUT.MAX_SCALE = 2.0
93
+
94
+ # MSDeformAttn encoder configs
95
+ cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["res3", "res4", "res5"]
96
+ cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_N_POINTS = 4
97
+ cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_N_HEADS = 8
98
+
99
+ def add_oneformer_config(cfg):
100
+ """
101
+ Add config for ONE_FORMER.
102
+ """
103
+
104
+ # mask_former model config
105
+ cfg.MODEL.ONE_FORMER = CN()
106
+
107
+ # loss
108
+ cfg.MODEL.ONE_FORMER.DEEP_SUPERVISION = True
109
+ cfg.MODEL.ONE_FORMER.NO_OBJECT_WEIGHT = 0.1
110
+ cfg.MODEL.ONE_FORMER.CLASS_WEIGHT = 1.0
111
+ cfg.MODEL.ONE_FORMER.DICE_WEIGHT = 1.0
112
+ cfg.MODEL.ONE_FORMER.MASK_WEIGHT = 20.0
113
+ cfg.MODEL.ONE_FORMER.CONTRASTIVE_WEIGHT = 0.5
114
+ cfg.MODEL.ONE_FORMER.CONTRASTIVE_TEMPERATURE = 0.07
115
+
116
+ # transformer config
117
+ cfg.MODEL.ONE_FORMER.NHEADS = 8
118
+ cfg.MODEL.ONE_FORMER.DROPOUT = 0.1
119
+ cfg.MODEL.ONE_FORMER.DIM_FEEDFORWARD = 2048
120
+ cfg.MODEL.ONE_FORMER.ENC_LAYERS = 0
121
+ cfg.MODEL.ONE_FORMER.CLASS_DEC_LAYERS = 2
122
+ cfg.MODEL.ONE_FORMER.DEC_LAYERS = 6
123
+ cfg.MODEL.ONE_FORMER.PRE_NORM = False
124
+
125
+ cfg.MODEL.ONE_FORMER.HIDDEN_DIM = 256
126
+ cfg.MODEL.ONE_FORMER.NUM_OBJECT_QUERIES = 120
127
+ cfg.MODEL.ONE_FORMER.NUM_OBJECT_CTX = 16
128
+ cfg.MODEL.ONE_FORMER.USE_TASK_NORM = True
129
+
130
+ cfg.MODEL.ONE_FORMER.TRANSFORMER_IN_FEATURE = "res5"
131
+ cfg.MODEL.ONE_FORMER.ENFORCE_INPUT_PROJ = False
132
+
133
+ # Sometimes `backbone.size_divisibility` is set to 0 for some backbone (e.g. ResNet)
134
+ # you can use this config to override
135
+ cfg.MODEL.ONE_FORMER.SIZE_DIVISIBILITY = 32
136
+
137
+ # transformer module
138
+ cfg.MODEL.ONE_FORMER.TRANSFORMER_DECODER_NAME = "ContrastiveMultiScaleMaskedTransformerDecoder"
139
+
140
+ # point loss configs
141
+ # Number of points sampled during training for a mask point head.
142
+ cfg.MODEL.ONE_FORMER.TRAIN_NUM_POINTS = 112 * 112
143
+ # Oversampling parameter for PointRend point sampling during training. Parameter `k` in the
144
+ # original paper.
145
+ cfg.MODEL.ONE_FORMER.OVERSAMPLE_RATIO = 3.0
146
+ # Importance sampling parameter for PointRend point sampling during training. Parametr `beta` in
147
+ # the original paper.
148
+ cfg.MODEL.ONE_FORMER.IMPORTANCE_SAMPLE_RATIO = 0.75
149
+
150
+ def add_swin_config(cfg):
151
+ """
152
+ Add config for Swin Backbone.
153
+ """
154
+
155
+ # swin transformer backbone
156
+ cfg.MODEL.SWIN = CN()
157
+ cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE = 224
158
+ cfg.MODEL.SWIN.PATCH_SIZE = 4
159
+ cfg.MODEL.SWIN.EMBED_DIM = 96
160
+ cfg.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
161
+ cfg.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
162
+ cfg.MODEL.SWIN.WINDOW_SIZE = 7
163
+ cfg.MODEL.SWIN.MLP_RATIO = 4.0
164
+ cfg.MODEL.SWIN.QKV_BIAS = True
165
+ cfg.MODEL.SWIN.QK_SCALE = None
166
+ cfg.MODEL.SWIN.DROP_RATE = 0.0
167
+ cfg.MODEL.SWIN.ATTN_DROP_RATE = 0.0
168
+ cfg.MODEL.SWIN.DROP_PATH_RATE = 0.3
169
+ cfg.MODEL.SWIN.APE = False
170
+ cfg.MODEL.SWIN.PATCH_NORM = True
171
+ cfg.MODEL.SWIN.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
172
+ cfg.MODEL.SWIN.USE_CHECKPOINT = False
173
+ ## Semask additions
174
+ cfg.MODEL.SWIN.SEM_WINDOW_SIZE = 7
175
+ cfg.MODEL.SWIN.NUM_SEM_BLOCKS = 1
176
+
177
+ def add_dinat_config(cfg):
178
+ """
179
+ Add config for DiNAT Backbone.
180
+ """
181
+
182
+ # DINAT transformer backbone
183
+ cfg.MODEL.DiNAT = CN()
184
+ cfg.MODEL.DiNAT.DEPTHS = [3, 4, 18, 5]
185
+ cfg.MODEL.DiNAT.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
186
+ cfg.MODEL.DiNAT.EMBED_DIM = 64
187
+ cfg.MODEL.DiNAT.MLP_RATIO = 3.0
188
+ cfg.MODEL.DiNAT.NUM_HEADS = [2, 4, 8, 16]
189
+ cfg.MODEL.DiNAT.DROP_PATH_RATE = 0.2
190
+ cfg.MODEL.DiNAT.KERNEL_SIZE = 7
191
+ cfg.MODEL.DiNAT.DILATIONS = [[1, 16, 1], [1, 4, 1, 8], [1, 2, 1, 3, 1, 4], [1, 2, 1, 2, 1]]
192
+ cfg.MODEL.DiNAT.OUT_INDICES = (0, 1, 2, 3)
193
+ cfg.MODEL.DiNAT.QKV_BIAS = True
194
+ cfg.MODEL.DiNAT.QK_SCALE = None
195
+ cfg.MODEL.DiNAT.DROP_RATE = 0
196
+ cfg.MODEL.DiNAT.ATTN_DROP_RATE = 0.
197
+ cfg.MODEL.DiNAT.IN_PATCH_SIZE = 4
198
+
199
+ def add_convnext_config(cfg):
200
+ """
201
+ Add config for ConvNeXt Backbone.
202
+ """
203
+
204
+ # convnext backbone
205
+ cfg.MODEL.CONVNEXT = CN()
206
+ cfg.MODEL.CONVNEXT.IN_CHANNELS = 3
207
+ cfg.MODEL.CONVNEXT.DEPTHS = [3, 3, 27, 3]
208
+ cfg.MODEL.CONVNEXT.DIMS = [192, 384, 768, 1536]
209
+ cfg.MODEL.CONVNEXT.DROP_PATH_RATE = 0.4
210
+ cfg.MODEL.CONVNEXT.LSIT = 1.0
211
+ cfg.MODEL.CONVNEXT.OUT_INDICES = [0, 1, 2, 3]
212
+ cfg.MODEL.CONVNEXT.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
213
+
214
+ def add_beit_adapter_config(cfg):
215
+ """
216
+ Add config for BEiT Adapter Backbone.
217
+ """
218
+
219
+ # beit adapter backbone
220
+ cfg.MODEL.BEiTAdapter = CN()
221
+ cfg.MODEL.BEiTAdapter.IMG_SIZE = 640
222
+ cfg.MODEL.BEiTAdapter.PATCH_SIZE = 16
223
+ cfg.MODEL.BEiTAdapter.EMBED_DIM = 1024
224
+ cfg.MODEL.BEiTAdapter.DEPTH = 24
225
+ cfg.MODEL.BEiTAdapter.NUM_HEADS = 16
226
+ cfg.MODEL.BEiTAdapter.MLP_RATIO = 4
227
+ cfg.MODEL.BEiTAdapter.QKV_BIAS = True
228
+ cfg.MODEL.BEiTAdapter.USE_ABS_POS_EMB = False
229
+ cfg.MODEL.BEiTAdapter.USE_REL_POS_BIAS = True
230
+ cfg.MODEL.BEiTAdapter.INIT_VALUES = 1e-6
231
+ cfg.MODEL.BEiTAdapter.DROP_PATH_RATE = 0.3
232
+ cfg.MODEL.BEiTAdapter.CONV_INPLANE = 64
233
+ cfg.MODEL.BEiTAdapter.N_POINTS = 4
234
+ cfg.MODEL.BEiTAdapter.DEFORM_NUM_HEADS = 16
235
+ cfg.MODEL.BEiTAdapter.CFFN_RATIO = 0.25
236
+ cfg.MODEL.BEiTAdapter.DEFORM_RATIO = 0.5
237
+ cfg.MODEL.BEiTAdapter.WITH_CP = True
238
+ cfg.MODEL.BEiTAdapter.INTERACTION_INDEXES=[[0, 5], [6, 11], [12, 17], [18, 23]]
239
+ cfg.MODEL.BEiTAdapter.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
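A hedged sketch of how these helpers are typically chained onto a detectron2 config before merging one of the YAMLs under configs/; the DeepLab base config and the exact YAML path are assumptions drawn from common OneFormer demo setups:

```python
# Illustrative cfg assembly (the YAML path and the DeepLab base config are assumptions).
from detectron2.config import get_cfg
from detectron2.projects.deeplab import add_deeplab_config
from oneformer import (add_common_config, add_swin_config, add_dinat_config,
                       add_convnext_config, add_beit_adapter_config,
                       add_oneformer_config)

cfg = get_cfg()
add_deeplab_config(cfg)        # commonly added first in OneFormer setups
add_common_config(cfg)
add_swin_config(cfg)
add_dinat_config(cfg)
add_convnext_config(cfg)
add_beit_adapter_config(cfg)
add_oneformer_config(cfg)
cfg.merge_from_file("configs/ade20k/oneformer_swin_large_IN21k_384_bs16_160k.yaml")  # assumed path
cfg.MODEL.IS_TRAIN = False
cfg.freeze()
```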
oneformer/data/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from . import datasets
oneformer/data/bpe_simple_vocab_16e6.txt ADDED
The diff for this file is too large to render. See raw diff
 
oneformer/data/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
oneformer/data/build.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from typing import Any, Callable, Dict, List, Optional, Union
3
+ import torch.utils.data as torchdata
4
+
5
+ from detectron2.config import configurable
6
+
7
+
8
+ from detectron2.data.common import DatasetFromList, MapDataset
9
+ from detectron2.data.dataset_mapper import DatasetMapper
10
+ from detectron2.data.samplers import (
11
+ InferenceSampler,
12
+ )
13
+ from detectron2.data.build import (
14
+ get_detection_dataset_dicts,
15
+ trivial_batch_collator
16
+ )
17
+ """
18
+ This file contains the default logic to build a dataloader for training or testing.
19
+ """
20
+
21
+ __all__ = [
22
+ "build_detection_test_loader",
23
+ ]
24
+
25
+
26
+ def _test_loader_from_config(cfg, dataset_name, mapper=None):
27
+ """
28
+ Uses the given `dataset_name` argument (instead of the names in cfg), because the
29
+ standard practice is to evaluate each test set individually (not combining them).
30
+ """
31
+ if isinstance(dataset_name, str):
32
+ dataset_name = [dataset_name]
33
+
34
+ dataset = get_detection_dataset_dicts(
35
+ dataset_name,
36
+ filter_empty=False,
37
+ proposal_files=[
38
+ cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(x)] for x in dataset_name
39
+ ]
40
+ if cfg.MODEL.LOAD_PROPOSALS
41
+ else None,
42
+ )
43
+ if mapper is None:
44
+ mapper = DatasetMapper(cfg, False)
45
+ return {
46
+ "dataset": dataset,
47
+ "mapper": mapper,
48
+ "num_workers": cfg.DATALOADER.NUM_WORKERS,
49
+ "sampler": InferenceSampler(len(dataset))
50
+ if not isinstance(dataset, torchdata.IterableDataset)
51
+ else None,
52
+ }
53
+
54
+
55
+ @configurable(from_config=_test_loader_from_config)
56
+ def build_detection_test_loader(
57
+ dataset: Union[List[Any], torchdata.Dataset],
58
+ *,
59
+ mapper: Callable[[Dict[str, Any]], Any],
60
+ sampler: Optional[torchdata.Sampler] = None,
61
+ batch_size: int = 1,
62
+ num_workers: int = 0,
63
+ collate_fn: Optional[Callable[[List[Any]], Any]] = None,
64
+ ) -> torchdata.DataLoader:
65
+ """
66
+ Similar to `build_detection_train_loader`, with default batch size = 1,
67
+ and sampler = :class:`InferenceSampler`. This sampler coordinates all workers
68
+ to produce the exact set of all samples.
69
+
70
+ Args:
71
+ dataset: a list of dataset dicts,
72
+ or a pytorch dataset (either map-style or iterable). They can be obtained
73
+ by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
74
+ mapper: a callable which takes a sample (dict) from dataset
75
+ and returns the format to be consumed by the model.
76
+ When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``.
77
+ sampler: a sampler that produces
78
+ indices to be applied on ``dataset``. Default to :class:`InferenceSampler`,
79
+ which splits the dataset across all workers. Sampler must be None
80
+ if `dataset` is iterable.
81
+ batch_size: the batch size of the data loader to be created.
82
+ Default to 1 image per worker since this is the standard when reporting
83
+ inference time in papers.
84
+ num_workers: number of parallel data loading workers
85
+ collate_fn: same as the argument of `torch.utils.data.DataLoader`.
86
+ Defaults to do no collation and return a list of data.
87
+
88
+ Returns:
89
+ DataLoader: a torch DataLoader, that loads the given detection
90
+ dataset, with test-time transformation and batching.
91
+
92
+ Examples:
93
+ ::
94
+ data_loader = build_detection_test_loader(
95
+ DatasetRegistry.get("my_test"),
96
+ mapper=DatasetMapper(...))
97
+
98
+ # or, instantiate with a CfgNode:
99
+ data_loader = build_detection_test_loader(cfg, "my_test")
100
+ """
101
+ if isinstance(dataset, list):
102
+ dataset = DatasetFromList(dataset, copy=False)
103
+ if mapper is not None:
104
+ dataset = MapDataset(dataset, mapper)
105
+ if isinstance(dataset, torchdata.IterableDataset):
106
+ assert sampler is None, "sampler must be None if dataset is IterableDataset"
107
+ else:
108
+ if sampler is None:
109
+ sampler = InferenceSampler(len(dataset))
110
+ return torchdata.DataLoader(
111
+ dataset,
112
+ batch_size=batch_size,
113
+ sampler=sampler,
114
+ drop_last=False,
115
+ num_workers=num_workers,
116
+ collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
117
+ )
oneformer/data/dataset_mappers/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
oneformer/data/dataset_mappers/coco_unified_new_baseline_dataset_mapper.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------
2
+ # Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/data/dataset_mappers/coco_panoptic_new_baseline_dataset_mapper.py
3
+ # Modified by Jitesh Jain (https://github.com/praeclarumjj3)
4
+ # ------------------------------------------------------------------------------
5
+
6
+ import copy
7
+ import logging
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+ from detectron2.data import MetadataCatalog
13
+ from detectron2.config import configurable
14
+ from detectron2.data import detection_utils as utils
15
+ from detectron2.data import transforms as T
16
+ from detectron2.structures import BitMasks, Instances
17
+ from oneformer.utils.box_ops import masks_to_boxes
18
+ from oneformer.data.tokenizer import SimpleTokenizer, Tokenize
19
+
20
+ __all__ = ["COCOUnifiedNewBaselineDatasetMapper"]
21
+
22
+
23
+ def build_transform_gen(cfg, is_train):
24
+ """
25
+ Create a list of default :class:`Augmentation` from config.
26
+ Now it includes resizing and flipping.
27
+ Returns:
28
+ list[Augmentation]
29
+ """
30
+ assert is_train, "Only support training augmentation"
31
+ image_size = cfg.INPUT.IMAGE_SIZE
32
+ min_scale = cfg.INPUT.MIN_SCALE
33
+ max_scale = cfg.INPUT.MAX_SCALE
34
+
35
+ augmentation = []
36
+
37
+ if cfg.INPUT.RANDOM_FLIP != "none":
38
+ augmentation.append(
39
+ T.RandomFlip(
40
+ horizontal=cfg.INPUT.RANDOM_FLIP == "horizontal",
41
+ vertical=cfg.INPUT.RANDOM_FLIP == "vertical",
42
+ )
43
+ )
44
+
45
+ augmentation.extend([
46
+ T.ResizeScale(
47
+ min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size
48
+ ),
49
+ T.FixedSizeCrop(crop_size=(image_size, image_size)),
50
+ ])
51
+
52
+ return augmentation
53
+
54
+
55
+ # This is specifically designed for the COCO dataset.
56
+ class COCOUnifiedNewBaselineDatasetMapper:
57
+ """
58
+ A callable which takes a dataset dict in Detectron2 Dataset format,
59
+ and map it into a format used by OneFormer.
60
+
61
+ This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.
62
+
63
+ The callable currently does the following:
64
+
65
+ 1. Read the image from "file_name"
66
+ 2. Applies geometric transforms to the image and annotation
67
+ 3. Find and applies suitable cropping to the image and annotation
68
+ 4. Prepare image and annotation to Tensors
69
+ """
70
+
71
+ @configurable
72
+ def __init__(
73
+ self,
74
+ is_train=True,
75
+ *,
76
+ num_queries,
77
+ tfm_gens,
78
+ meta,
79
+ image_format,
80
+ max_seq_len,
81
+ task_seq_len,
82
+ semantic_prob,
83
+ instance_prob,
84
+ ):
85
+ """
86
+ NOTE: this interface is experimental.
87
+ Args:
88
+ is_train: for training or inference
89
+ augmentations: a list of augmentations or deterministic transforms to apply
90
+ crop_gen: crop augmentation
91
+ tfm_gens: data augmentation
92
+ image_format: an image format supported by :func:`detection_utils.read_image`.
93
+ """
94
+ self.tfm_gens = tfm_gens
95
+ logging.getLogger(__name__).info(
96
+ "[COCOUnifiedNewBaselineDatasetMapper] Full TransformGens used in training: {}".format(
97
+ str(self.tfm_gens)
98
+ )
99
+ )
100
+
101
+ self.img_format = image_format
102
+ self.is_train = is_train
103
+ self.meta = meta
104
+ self.ignore_label = self.meta.ignore_label
105
+ self.num_queries = num_queries
106
+
107
+ self.things = []
108
+ for k,v in self.meta.thing_dataset_id_to_contiguous_id.items():
109
+ self.things.append(v)
110
+ self.class_names = self.meta.stuff_classes
111
+ self.text_tokenizer = Tokenize(SimpleTokenizer(), max_seq_len=max_seq_len)
112
+ self.task_tokenizer = Tokenize(SimpleTokenizer(), max_seq_len=task_seq_len)
113
+ self.semantic_prob = semantic_prob
114
+ self.instance_prob = instance_prob
115
+
116
+ @classmethod
117
+ def from_config(cls, cfg, is_train=True):
118
+ # Build augmentation
119
+ tfm_gens = build_transform_gen(cfg, is_train)
120
+ dataset_names = cfg.DATASETS.TRAIN
121
+ meta = MetadataCatalog.get(dataset_names[0])
122
+
123
+ ret = {
124
+ "is_train": is_train,
125
+ "meta": meta,
126
+ "tfm_gens": tfm_gens,
127
+ "image_format": cfg.INPUT.FORMAT,
128
+ "num_queries": cfg.MODEL.ONE_FORMER.NUM_OBJECT_QUERIES - cfg.MODEL.TEXT_ENCODER.N_CTX,
129
+ "task_seq_len": cfg.INPUT.TASK_SEQ_LEN,
130
+ "max_seq_len": cfg.INPUT.MAX_SEQ_LEN,
131
+ "semantic_prob": cfg.INPUT.TASK_PROB.SEMANTIC,
132
+ "instance_prob": cfg.INPUT.TASK_PROB.INSTANCE,
133
+ }
134
+ return ret
135
+
136
+ def _get_semantic_dict(self, pan_seg_gt, image_shape, segments_info, num_class_obj):
137
+ instances = Instances(image_shape)
138
+
139
+ classes = []
140
+ texts = ["a semantic photo"] * self.num_queries
141
+ masks = []
142
+ label = np.ones_like(pan_seg_gt) * self.ignore_label
143
+
144
+ for segment_info in segments_info:
145
+ class_id = segment_info["category_id"]
146
+ if not segment_info["iscrowd"]:
147
+ mask = pan_seg_gt == segment_info["id"]
148
+ if not np.all(mask == False):
149
+ if class_id not in classes:
150
+ cls_name = self.class_names[class_id]
151
+ classes.append(class_id)
152
+ masks.append(mask)
153
+ num_class_obj[cls_name] += 1
154
+ else:
155
+ idx = classes.index(class_id)
156
+ masks[idx] += mask
157
+ masks[idx] = np.clip(masks[idx], 0, 1).astype(bool)  # np.bool was removed in NumPy >= 1.24
158
+ label[mask] = class_id
159
+
160
+ num = 0
161
+ for i, cls_name in enumerate(self.class_names):
162
+ if num_class_obj[cls_name] > 0:
163
+ for _ in range(num_class_obj[cls_name]):
164
+ if num >= len(texts):
165
+ break
166
+ texts[num] = f"a photo with a {cls_name}"
167
+ num += 1
168
+
169
+ classes = np.array(classes)
170
+ instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
171
+ if len(masks) == 0:
172
+ # Some image does not have annotation (all ignored)
173
+ instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
174
+ instances.gt_bboxes = torch.zeros((0, 4))
175
+ else:
176
+ masks = BitMasks(
177
+ torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
178
+ )
179
+ instances.gt_masks = masks.tensor
180
+ # Placeholder bounding boxes for stuff regions. Note that these are not used during training.
181
+ instances.gt_bboxes = torch.stack([torch.tensor([0., 0., 1., 1.])] * instances.gt_masks.shape[0])
182
+ return instances, texts, label
183
+
184
+ def _get_instance_dict(self, pan_seg_gt, image_shape, segments_info, num_class_obj):
185
+ instances = Instances(image_shape)
186
+
187
+ classes = []
188
+ texts = ["an instance photo"] * self.num_queries
189
+ masks = []
190
+ label = np.ones_like(pan_seg_gt) * self.ignore_label
191
+
192
+ for segment_info in segments_info:
193
+ class_id = segment_info["category_id"]
194
+ if class_id in self.things:
195
+ if not segment_info["iscrowd"]:
196
+ mask = pan_seg_gt == segment_info["id"]
197
+ if not np.all(mask == False):
198
+ cls_name = self.class_names[class_id]
199
+ classes.append(class_id)
200
+ masks.append(mask)
201
+ num_class_obj[cls_name] += 1
202
+ label[mask] = class_id
203
+
204
+ num = 0
205
+ for i, cls_name in enumerate(self.class_names):
206
+ if num_class_obj[cls_name] > 0:
207
+ for _ in range(num_class_obj[cls_name]):
208
+ if num >= len(texts):
209
+ break
210
+ texts[num] = f"a photo with a {cls_name}"
211
+ num += 1
212
+
213
+ classes = np.array(classes)
214
+ instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
215
+ if len(masks) == 0:
216
+ # Some image does not have annotation (all ignored)
217
+ instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
218
+ instances.gt_bboxes = torch.zeros((0, 4))
219
+ else:
220
+ masks = BitMasks(
221
+ torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
222
+ )
223
+ instances.gt_masks = masks.tensor
224
+ instances.gt_bboxes = masks_to_boxes(instances.gt_masks)
225
+ return instances, texts, label
226
+
227
+ def _get_panoptic_dict(self, pan_seg_gt, image_shape, segments_info, num_class_obj):
228
+ instances = Instances(image_shape)
229
+
230
+ classes = []
231
+ texts = ["a panoptic photo"] * self.num_queries
232
+ masks = []
233
+ label = np.ones_like(pan_seg_gt) * self.ignore_label
234
+
235
+ for segment_info in segments_info:
236
+ class_id = segment_info["category_id"]
237
+ if not segment_info["iscrowd"]:
238
+ mask = pan_seg_gt == segment_info["id"]
239
+ if not np.all(mask == False):
240
+ cls_name = self.class_names[class_id]
241
+ classes.append(class_id)
242
+ masks.append(mask)
243
+ num_class_obj[cls_name] += 1
244
+ label[mask] = class_id
245
+
246
+ num = 0
247
+ for i, cls_name in enumerate(self.class_names):
248
+ if num_class_obj[cls_name] > 0:
249
+ for _ in range(num_class_obj[cls_name]):
250
+ if num >= len(texts):
251
+ break
252
+ texts[num] = f"a photo with a {cls_name}"
253
+ num += 1
254
+
255
+ classes = np.array(classes)
256
+ instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
257
+ if len(masks) == 0:
258
+ # Some image does not have annotation (all ignored)
259
+ instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
260
+ instances.gt_bboxes = torch.zeros((0, 4))
261
+ else:
262
+ masks = BitMasks(
263
+ torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
264
+ )
265
+ instances.gt_masks = masks.tensor
266
+ instances.gt_bboxes = masks_to_boxes(instances.gt_masks)
267
+ for i in range(instances.gt_classes.shape[0]):
268
+ # Placeholder bounding boxes for stuff regions. Note that these are not used during training.
269
+ if instances.gt_classes[i].item() not in self.things:
270
+ instances.gt_bboxes[i] = torch.tensor([0., 0., 1., 1.])
271
+ return instances, texts, label
272
+
273
+ def __call__(self, dataset_dict):
274
+ """
275
+ Args:
276
+ dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
277
+
278
+ Returns:
279
+ dict: a format that builtin models in detectron2 accept
280
+ """
281
+ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
282
+ image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
283
+ utils.check_image_size(dataset_dict, image)
284
+
285
+ image, transforms = T.apply_transform_gens(self.tfm_gens, image)
286
+ image_shape = image.shape[:2] # h, w
287
+
288
+ # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
289
+ # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
290
+ # Therefore it's important to use torch.Tensor.
291
+ dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
292
+
293
+ if not self.is_train:
294
+ # USER: Modify this if you want to keep them for some reason.
295
+ dataset_dict.pop("annotations", None)
296
+ return dataset_dict
297
+
298
+ # semantic segmentation
299
+ if "sem_seg_file_name" in dataset_dict:
300
+ # PyTorch transformation not implemented for uint16, so converting it to double first
301
+ sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name")).astype("double")
302
+ sem_seg_gt = transforms.apply_segmentation(sem_seg_gt)
303
+ else:
304
+ sem_seg_gt = None
305
+
306
+ if "pan_seg_file_name" in dataset_dict:
307
+ pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB")
308
+ segments_info = dataset_dict["segments_info"]
309
+
310
+ # apply the same transformation to panoptic segmentation
311
+ pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)
312
+
313
+ from panopticapi.utils import rgb2id
314
+ pan_seg_gt = rgb2id(pan_seg_gt)
315
+
316
+ prob_task = np.random.uniform(0,1.)
317
+
318
+ num_class_obj = {}
319
+
320
+ for name in self.class_names:
321
+ num_class_obj[name] = 0
322
+
323
+ if prob_task < self.semantic_prob:
324
+ task = "The task is semantic"
325
+ instances, text, sem_seg = self._get_semantic_dict(pan_seg_gt, image_shape, segments_info, num_class_obj)
326
+ elif prob_task < self.instance_prob:
327
+ task = "The task is instance"
328
+ instances, text, sem_seg = self._get_instance_dict(pan_seg_gt, image_shape, segments_info, num_class_obj)
329
+ else:
330
+ task = "The task is panoptic"
331
+ instances, text, sem_seg = self._get_panoptic_dict(pan_seg_gt, image_shape, segments_info, num_class_obj)
332
+
333
+
334
+ dataset_dict["sem_seg"] = torch.from_numpy(sem_seg).long()
335
+ dataset_dict["instances"] = instances
336
+ dataset_dict["orig_shape"] = image_shape
337
+ dataset_dict["task"] = task
338
+ dataset_dict["text"] = text
339
+ dataset_dict["thing_ids"] = self.things
340
+
341
+ return dataset_dict
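A hedged sketch of wiring this mapper into a detectron2 training dataloader; dataset registration and the `cfg` assembly (as in oneformer/config.py above) are assumed to have happened already:

```python
# Illustrative training-dataloader wiring (assumes `cfg` is already built and a
# COCO-style panoptic dataset is registered under cfg.DATASETS.TRAIN).
from detectron2.data import build_detection_train_loader
from oneformer.data.dataset_mappers.coco_unified_new_baseline_dataset_mapper import (
    COCOUnifiedNewBaselineDatasetMapper,
)

mapper = COCOUnifiedNewBaselineDatasetMapper(cfg, is_train=True)
train_loader = build_detection_train_loader(cfg, mapper=mapper)
batch = next(iter(train_loader))  # list of dicts with "image", "instances", "task", "text", ...
```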
oneformer/data/dataset_mappers/dataset_mapper.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------
2
+ # Reference: https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/dataset_mapper.py
3
+ # Modified by Jitesh Jain (https://github.com/praeclarumjj3)
4
+ # ------------------------------------------------------------------------------
5
+
6
+ import copy
7
+ import logging
8
+ import numpy as np
9
+ from typing import List, Optional, Union
10
+ import torch
11
+
12
+ from detectron2.config import configurable
13
+
14
+ from detectron2.data import detection_utils as utils
15
+ from detectron2.data import transforms as T
16
+ from oneformer.data.tokenizer import SimpleTokenizer, Tokenize
17
+
18
+ __all__ = ["DatasetMapper"]
19
+
20
+
21
+ class DatasetMapper:
22
+ """
23
+ A callable which takes a dataset dict in Detectron2 Dataset format,
24
+ and map it into a format used by the model.
25
+
26
+ This is the default callable to be used to map your dataset dict into training data.
27
+ You may need to follow it to implement your own one for customized logic,
28
+ such as a different way to read or transform images.
29
+ See :doc:`/tutorials/data_loading` for details.
30
+
31
+ The callable currently does the following:
32
+
33
+ 1. Read the image from "file_name"
34
+ 2. Applies cropping/geometric transforms to the image and annotations
35
+ 3. Prepare data and annotations to Tensor and :class:`Instances`
36
+ """
37
+
38
+ @configurable
39
+ def __init__(
40
+ self,
41
+ is_train: bool,
42
+ *,
43
+ augmentations: List[Union[T.Augmentation, T.Transform]],
44
+ image_format: str,
45
+ task_seq_len: int,
46
+ task: str = "panoptic",
47
+ use_instance_mask: bool = False,
48
+ use_keypoint: bool = False,
49
+ instance_mask_format: str = "polygon",
50
+ keypoint_hflip_indices: Optional[np.ndarray] = None,
51
+ precomputed_proposal_topk: Optional[int] = None,
52
+ recompute_boxes: bool = False,
53
+ ):
54
+ """
55
+ NOTE: this interface is experimental.
56
+
57
+ Args:
58
+ is_train: whether it's used in training or inference
59
+ augmentations: a list of augmentations or deterministic transforms to apply
60
+ image_format: an image format supported by :func:`detection_utils.read_image`.
61
+ use_instance_mask: whether to process instance segmentation annotations, if available
62
+ use_keypoint: whether to process keypoint annotations if available
63
+ instance_mask_format: one of "polygon" or "bitmask". Process instance segmentation
64
+ masks into this format.
65
+ keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices`
66
+ precomputed_proposal_topk: if given, will load pre-computed
67
+ proposals from dataset_dict and keep the top k proposals for each image.
68
+ recompute_boxes: whether to overwrite bounding box annotations
69
+ by computing tight bounding boxes from instance mask annotations.
70
+ """
71
+ if recompute_boxes:
72
+ assert use_instance_mask, "recompute_boxes requires instance masks"
73
+ # fmt: off
74
+ self.is_train = is_train
75
+ self.augmentations = T.AugmentationList(augmentations)
76
+ self.image_format = image_format
77
+ self.use_instance_mask = use_instance_mask
78
+ self.instance_mask_format = instance_mask_format
79
+ self.use_keypoint = use_keypoint
80
+ self.keypoint_hflip_indices = keypoint_hflip_indices
81
+ self.proposal_topk = precomputed_proposal_topk
82
+ self.recompute_boxes = recompute_boxes
83
+ self.task_tokenizer = Tokenize(SimpleTokenizer(), max_seq_len=task_seq_len)
84
+ self.task = task
85
+ assert self.task in ["panoptic", "semantic", "instance"]
86
+
87
+ # fmt: on
88
+ logger = logging.getLogger(__name__)
89
+ mode = "training" if is_train else "inference"
90
+ logger.info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}")
91
+
92
+ @classmethod
93
+ def from_config(cls, cfg, is_train: bool = True):
94
+ augs = utils.build_augmentation(cfg, is_train)
95
+ if cfg.INPUT.CROP.ENABLED and is_train:
96
+ augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE))
97
+ recompute_boxes = cfg.MODEL.MASK_ON
98
+ else:
99
+ recompute_boxes = False
100
+
101
+ ret = {
102
+ "is_train": is_train,
103
+ "augmentations": augs,
104
+ "image_format": cfg.INPUT.FORMAT,
105
+ "use_instance_mask": cfg.MODEL.MASK_ON,
106
+ "instance_mask_format": cfg.INPUT.MASK_FORMAT,
107
+ "use_keypoint": cfg.MODEL.KEYPOINT_ON,
108
+ "task_seq_len": cfg.INPUT.TASK_SEQ_LEN,
109
+ "recompute_boxes": recompute_boxes,
110
+ "task": cfg.MODEL.TEST.TASK,
111
+ }
112
+
113
+ if cfg.MODEL.KEYPOINT_ON:
114
+ ret["keypoint_hflip_indices"] = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)
115
+
116
+ if cfg.MODEL.LOAD_PROPOSALS:
117
+ ret["precomputed_proposal_topk"] = (
118
+ cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN
119
+ if is_train
120
+ else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST
121
+ )
122
+ return ret
123
+
124
+ def _transform_annotations(self, dataset_dict, transforms, image_shape):
125
+ # USER: Modify this if you want to keep them for some reason.
126
+ for anno in dataset_dict["annotations"]:
127
+ if not self.use_instance_mask:
128
+ anno.pop("segmentation", None)
129
+ if not self.use_keypoint:
130
+ anno.pop("keypoints", None)
131
+
132
+ # USER: Implement additional transformations if you have other types of data
133
+ annos = [
134
+ utils.transform_instance_annotations(
135
+ obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
136
+ )
137
+ for obj in dataset_dict.pop("annotations")
138
+ if obj.get("iscrowd", 0) == 0
139
+ ]
140
+ instances = utils.annotations_to_instances(
141
+ annos, image_shape, mask_format=self.instance_mask_format
142
+ )
143
+
144
+ # After transforms such as cropping are applied, the bounding box may no longer
145
+ # tightly bound the object. As an example, imagine a triangle object
146
+ # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight
147
+ # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to
148
+ # the intersection of original bounding box and the cropping box.
149
+ if self.recompute_boxes:
150
+ instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
151
+ dataset_dict["instances"] = utils.filter_empty_instances(instances)
152
+
153
+ def __call__(self, dataset_dict):
154
+ """
155
+ Args:
156
+ dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
157
+
158
+ Returns:
159
+ dict: a format that builtin models in detectron2 accept
160
+ """
161
+ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
162
+ # USER: Write your own image loading if it's not from a file
163
+ image = utils.read_image(dataset_dict["file_name"], format=self.image_format)
164
+ utils.check_image_size(dataset_dict, image)
165
+
166
+ task = f"The task is {self.task}"
167
+ dataset_dict["task"] = task
168
+
169
+ # USER: Remove if you don't do semantic/panoptic segmentation.
170
+ if "sem_seg_file_name" in dataset_dict:
171
+ sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2)
172
+ else:
173
+ sem_seg_gt = None
174
+
175
+ aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
176
+ transforms = self.augmentations(aug_input)
177
+ image, sem_seg_gt = aug_input.image, aug_input.sem_seg
178
+
179
+ image_shape = image.shape[:2] # h, w
180
+ # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
181
+ # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
182
+ # Therefore it's important to use torch.Tensor.
183
+ dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
184
+ if sem_seg_gt is not None:
185
+ dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))
186
+
187
+ # USER: Remove if you don't use pre-computed proposals.
188
+ # Most users would not need this feature.
189
+ if self.proposal_topk is not None:
190
+ utils.transform_proposals(
191
+ dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk
192
+ )
193
+
194
+ if not self.is_train:
195
+ # USER: Modify this if you want to keep them for some reason.
196
+ dataset_dict.pop("annotations", None)
197
+ dataset_dict.pop("sem_seg_file_name", None)
198
+ return dataset_dict
199
+
200
+ if "annotations" in dataset_dict:
201
+ self._transform_annotations(dataset_dict, transforms, image_shape)
202
+
203
+ return dataset_dict
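A minimal usage sketch for this mapper, assuming `cfg` is an OneFormer config that already defines the extra keys read by `from_config` (e.g. INPUT.TASK_SEQ_LEN and MODEL.TEST.TASK); `build_detection_train_loader` is the standard detectron2 helper:

    from detectron2.data import build_detection_train_loader
    from oneformer.data.dataset_mappers.dataset_mapper import DatasetMapper

    def build_train_loader(cfg):
        # @configurable lets the class be constructed directly from a config node
        mapper = DatasetMapper(cfg, is_train=True)
        return build_detection_train_loader(cfg, mapper=mapper)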
oneformer/data/dataset_mappers/oneformer_unified_dataset_mapper.py ADDED
@@ -0,0 +1,375 @@
1
+ # ------------------------------------------------------------------------------
2
+ # Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/data/dataset_mappers/mask_former_panoptic_dataset_mapper.py
3
+ # Modified by Jitesh Jain (https://github.com/praeclarumjj3)
4
+ # ------------------------------------------------------------------------------
5
+
6
+ import copy
7
+ import logging
8
+ import os
9
+
10
+ import numpy as np
11
+ import torch
12
+ from torch.nn import functional as F
13
+
14
+ from detectron2.config import configurable
15
+ from detectron2.data import detection_utils as utils
16
+ from detectron2.data import transforms as T
17
+ from detectron2.structures import BitMasks, Instances
18
+ from detectron2.data import MetadataCatalog
19
+ from detectron2.projects.point_rend import ColorAugSSDTransform
20
+ from oneformer.utils.box_ops import masks_to_boxes
21
+ from oneformer.data.tokenizer import SimpleTokenizer, Tokenize
22
+
23
+ __all__ = ["OneFormerUnifiedDatasetMapper"]
24
+
25
+
26
+ class OneFormerUnifiedDatasetMapper:
27
+ """
28
+ A callable which takes a dataset dict in Detectron2 Dataset format,
29
+ and maps it into a format used by OneFormer for universal segmentation.
30
+
31
+ The callable currently does the following:
32
+
33
+ 1. Reads the image from "file_name"
34
+ 2. Applies geometric transforms to the image and annotation
35
+ 3. Finds and applies suitable cropping to the image and annotation
36
+ 4. Prepares image and annotation as Tensors
37
+ """
38
+
39
+ @configurable
40
+ def __init__(
41
+ self,
42
+ is_train=True,
43
+ *,
44
+ name,
45
+ num_queries,
46
+ meta,
47
+ augmentations,
48
+ image_format,
49
+ ignore_label,
50
+ size_divisibility,
51
+ task_seq_len,
52
+ max_seq_len,
53
+ semantic_prob,
54
+ instance_prob,
55
+ ):
56
+ """
57
+ NOTE: this interface is experimental.
58
+ Args:
59
+ is_train: for training or inference
60
+ augmentations: a list of augmentations or deterministic transforms to apply
61
+ image_format: an image format supported by :func:`detection_utils.read_image`.
62
+ ignore_label: the label that is ignored during evaluation
63
+ size_divisibility: pad image size to be divisible by this value
64
+ """
65
+ self.is_train = is_train
66
+ self.meta = meta
67
+ self.name = name
68
+ self.tfm_gens = augmentations
69
+ self.img_format = image_format
70
+ self.ignore_label = ignore_label
71
+ self.size_divisibility = size_divisibility
72
+ self.num_queries = num_queries
73
+
74
+ logger = logging.getLogger(__name__)
75
+ mode = "training" if is_train else "inference"
76
+ logger.info(f"[{self.__class__.__name__}] Augmentations used in {mode}: {augmentations}")
77
+
78
+ self.things = []
79
+ for k,v in self.meta.thing_dataset_id_to_contiguous_id.items():
80
+ self.things.append(v)
81
+ self.class_names = self.meta.stuff_classes
82
+ self.text_tokenizer = Tokenize(SimpleTokenizer(), max_seq_len=max_seq_len)
83
+ self.task_tokenizer = Tokenize(SimpleTokenizer(), max_seq_len=task_seq_len)
84
+ self.semantic_prob = semantic_prob
85
+ self.instance_prob = instance_prob
86
+
87
+ @classmethod
88
+ def from_config(cls, cfg, is_train=True):
89
+ # Build augmentation
90
+ augs = [
91
+ T.ResizeShortestEdge(
92
+ cfg.INPUT.MIN_SIZE_TRAIN,
93
+ cfg.INPUT.MAX_SIZE_TRAIN,
94
+ cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING,
95
+ )
96
+ ]
97
+ if cfg.INPUT.CROP.ENABLED:
98
+ augs.append(
99
+ T.RandomCrop_CategoryAreaConstraint(
100
+ cfg.INPUT.CROP.TYPE,
101
+ cfg.INPUT.CROP.SIZE,
102
+ cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA,
103
+ cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
104
+ )
105
+ )
106
+ if cfg.INPUT.COLOR_AUG_SSD:
107
+ augs.append(ColorAugSSDTransform(img_format=cfg.INPUT.FORMAT))
108
+ augs.append(T.RandomFlip())
109
+
110
+ # Assumed to always apply to the training set.
111
+ dataset_names = cfg.DATASETS.TRAIN
112
+ meta = MetadataCatalog.get(dataset_names[0])
113
+ ignore_label = meta.ignore_label
114
+
115
+ ret = {
116
+ "is_train": is_train,
117
+ "meta": meta,
118
+ "name": dataset_names[0],
119
+ "num_queries": cfg.MODEL.ONE_FORMER.NUM_OBJECT_QUERIES - cfg.MODEL.TEXT_ENCODER.N_CTX,
120
+ "task_seq_len": cfg.INPUT.TASK_SEQ_LEN,
121
+ "max_seq_len": cfg.INPUT.MAX_SEQ_LEN,
122
+ "augmentations": augs,
123
+ "image_format": cfg.INPUT.FORMAT,
124
+ "ignore_label": ignore_label,
125
+ "size_divisibility": cfg.INPUT.SIZE_DIVISIBILITY,
126
+ "semantic_prob": cfg.INPUT.TASK_PROB.SEMANTIC,
127
+ "instance_prob": cfg.INPUT.TASK_PROB.INSTANCE,
128
+ }
129
+ return ret
130
+
131
+ def _get_semantic_dict(self, pan_seg_gt, image_shape, segments_info, num_class_obj):
132
+ pan_seg_gt = pan_seg_gt.numpy()
133
+ instances = Instances(image_shape)
134
+
135
+ classes = []
136
+ texts = ["a semantic photo"] * self.num_queries
137
+ masks = []
138
+ label = np.ones_like(pan_seg_gt) * self.ignore_label
139
+
140
+ for segment_info in segments_info:
141
+ class_id = segment_info["category_id"]
142
+ if not segment_info["iscrowd"]:
143
+ mask = pan_seg_gt == segment_info["id"]
144
+ if not np.all(mask == False):
145
+ if class_id not in classes:
146
+ cls_name = self.class_names[class_id]
147
+ classes.append(class_id)
148
+ masks.append(mask)
149
+ num_class_obj[cls_name] += 1
150
+ else:
151
+ idx = classes.index(class_id)
152
+ masks[idx] += mask
153
+ masks[idx] = np.clip(masks[idx], 0, 1).astype(bool)
154
+ label[mask] = class_id
155
+
156
+ num = 0
157
+ for i, cls_name in enumerate(self.class_names):
158
+ if num_class_obj[cls_name] > 0:
159
+ for _ in range(num_class_obj[cls_name]):
160
+ if num >= len(texts):
161
+ break
162
+ texts[num] = f"a photo with a {cls_name}"
163
+ num += 1
164
+
165
+ classes = np.array(classes)
166
+ instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
167
+ if len(masks) == 0:
168
+ # Some image does not have annotation (all ignored)
169
+ instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
170
+ instances.gt_bboxes = torch.zeros((0, 4))
171
+ else:
172
+ masks = BitMasks(
173
+ torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
174
+ )
175
+ instances.gt_masks = masks.tensor
176
+ # Placeholder bounding boxes for stuff regions. Note that these are not used during training.
177
+ instances.gt_bboxes = torch.stack([torch.tensor([0., 0., 1., 1.])] * instances.gt_masks.shape[0])
178
+ return instances, texts, label
179
+
180
+ def _get_instance_dict(self, pan_seg_gt, image_shape, segments_info, num_class_obj):
181
+ pan_seg_gt = pan_seg_gt.numpy()
182
+ instances = Instances(image_shape)
183
+
184
+ classes = []
185
+ texts = ["an instance photo"] * self.num_queries
186
+ masks = []
187
+ label = np.ones_like(pan_seg_gt) * self.ignore_label
188
+
189
+ for segment_info in segments_info:
190
+ class_id = segment_info["category_id"]
191
+ if class_id in self.things:
192
+ if not segment_info["iscrowd"]:
193
+ mask = pan_seg_gt == segment_info["id"]
194
+ if not np.all(mask == False):
195
+ cls_name = self.class_names[class_id]
196
+ classes.append(class_id)
197
+ masks.append(mask)
198
+ num_class_obj[cls_name] += 1
199
+ label[mask] = class_id
200
+
201
+ num = 0
202
+ for i, cls_name in enumerate(self.class_names):
203
+ if num_class_obj[cls_name] > 0:
204
+ for _ in range(num_class_obj[cls_name]):
205
+ if num >= len(texts):
206
+ break
207
+ texts[num] = f"a photo with a {cls_name}"
208
+ num += 1
209
+
210
+ classes = np.array(classes)
211
+ instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
212
+ if len(masks) == 0:
213
+ # Some image does not have annotation (all ignored)
214
+ instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
215
+ instances.gt_bboxes = torch.zeros((0, 4))
216
+ else:
217
+ masks = BitMasks(
218
+ torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
219
+ )
220
+ instances.gt_masks = masks.tensor
221
+ instances.gt_bboxes = masks_to_boxes(instances.gt_masks)
222
+ return instances, texts, label
223
+
224
+ def _get_panoptic_dict(self, pan_seg_gt, image_shape, segments_info, num_class_obj):
225
+ pan_seg_gt = pan_seg_gt.numpy()
226
+ instances = Instances(image_shape)
227
+
228
+ classes = []
229
+ texts = ["a panoptic photo"] * self.num_queries
230
+ masks = []
231
+ label = np.ones_like(pan_seg_gt) * self.ignore_label
232
+
233
+ for segment_info in segments_info:
234
+ class_id = segment_info["category_id"]
235
+ if not segment_info["iscrowd"]:
236
+ mask = pan_seg_gt == segment_info["id"]
237
+ if not np.all(mask == False):
238
+ cls_name = self.class_names[class_id]
239
+ classes.append(class_id)
240
+ masks.append(mask)
241
+ num_class_obj[cls_name] += 1
242
+ label[mask] = class_id
243
+
244
+ num = 0
245
+ for i, cls_name in enumerate(self.class_names):
246
+ if num_class_obj[cls_name] > 0:
247
+ for _ in range(num_class_obj[cls_name]):
248
+ if num >= len(texts):
249
+ break
250
+ texts[num] = f"a photo with a {cls_name}"
251
+ num += 1
252
+
253
+ classes = np.array(classes)
254
+ instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
255
+ if len(masks) == 0:
256
+ # Some image does not have annotation (all ignored)
257
+ instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
258
+ instances.gt_bboxes = torch.zeros((0, 4))
259
+ else:
260
+ masks = BitMasks(
261
+ torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
262
+ )
263
+ instances.gt_masks = masks.tensor
264
+ instances.gt_bboxes = masks_to_boxes(instances.gt_masks)
265
+ for i in range(instances.gt_classes.shape[0]):
266
+ # Placeholder bounding boxes for stuff regions. Note that these are not used during training.
267
+ if instances.gt_classes[i].item() not in self.things:
268
+ instances.gt_bboxes[i] = torch.tensor([0., 0., 1., 1.])
269
+ return instances, texts, label
270
+
271
+ def __call__(self, dataset_dict):
272
+ """
273
+ Args:
274
+ dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
275
+
276
+ Returns:
277
+ dict: a format that builtin models in detectron2 accept
278
+ """
279
+ assert self.is_train, "OneFormerUnifiedDatasetMapper should only be used for training!"
280
+
281
+ dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
282
+ image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
283
+ utils.check_image_size(dataset_dict, image)
284
+
285
+ # semantic segmentation
286
+ if "sem_seg_file_name" in dataset_dict:
287
+ # PyTorch transformation not implemented for uint16, so converting it to double first
288
+ sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name")).astype("double")
289
+ else:
290
+ sem_seg_gt = None
291
+
292
+ # panoptic segmentation
293
+ if "pan_seg_file_name" in dataset_dict:
294
+ pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB")
295
+ segments_info = dataset_dict["segments_info"]
296
+ else:
297
+ pan_seg_gt = None
298
+ segments_info = None
299
+
300
+ if pan_seg_gt is None:
301
+ raise ValueError(
302
+ "Cannot find 'pan_seg_file_name' for panoptic segmentation dataset {}.".format(
303
+ dataset_dict["file_name"]
304
+ )
305
+ )
306
+
307
+ aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
308
+ aug_input, transforms = T.apply_transform_gens(self.tfm_gens, aug_input)
309
+ image = aug_input.image
310
+ if sem_seg_gt is not None:
311
+ sem_seg_gt = aug_input.sem_seg
312
+
313
+ # apply the same transformation to panoptic segmentation
314
+ pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)
315
+
316
+ from panopticapi.utils import rgb2id
317
+
318
+ pan_seg_gt = rgb2id(pan_seg_gt)
319
+
320
+ # Pad image and segmentation label here!
321
+ image = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
322
+ if sem_seg_gt is not None:
323
+ sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
324
+ pan_seg_gt = torch.as_tensor(pan_seg_gt.astype("long"))
325
+
326
+ if self.size_divisibility > 0:
327
+ image_size = (image.shape[-2], image.shape[-1])
328
+ padding_size = [
329
+ 0,
330
+ self.size_divisibility - image_size[1],
331
+ 0,
332
+ self.size_divisibility - image_size[0],
333
+ ]
334
+ image = F.pad(image, padding_size, value=128).contiguous()
335
+ if sem_seg_gt is not None:
336
+ sem_seg_gt = F.pad(sem_seg_gt, padding_size, value=self.ignore_label).contiguous()
337
+ pan_seg_gt = F.pad(
338
+ pan_seg_gt, padding_size, value=0
339
+ ).contiguous() # 0 is the VOID panoptic label
340
+
341
+ image_shape = (image.shape[-2], image.shape[-1]) # h, w
342
+
343
+ # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
344
+ # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
345
+ # Therefore it's important to use torch.Tensor.
346
+ dataset_dict["image"] = image
347
+
348
+ if "annotations" in dataset_dict:
349
+ raise ValueError("Panoptic segmentation dataset should not have 'annotations'.")
350
+
351
+ prob_task = np.random.uniform(0,1.)
352
+
353
+ num_class_obj = {}
354
+
355
+ for name in self.class_names:
356
+ num_class_obj[name] = 0
357
+
358
+ if prob_task < self.semantic_prob:
359
+ task = "The task is semantic"
360
+ instances, text, sem_seg = self._get_semantic_dict(pan_seg_gt, image_shape, segments_info, num_class_obj)
361
+ elif prob_task < self.instance_prob:
362
+ task = "The task is instance"
363
+ instances, text, sem_seg = self._get_instance_dict(pan_seg_gt, image_shape, segments_info, num_class_obj)
364
+ else:
365
+ task = "The task is panoptic"
366
+ instances, text, sem_seg = self._get_panoptic_dict(pan_seg_gt, image_shape, segments_info, num_class_obj)
367
+
368
+ dataset_dict["sem_seg"] = torch.from_numpy(sem_seg).long()
369
+ dataset_dict["instances"] = instances
370
+ dataset_dict["orig_shape"] = image_shape
371
+ dataset_dict["task"] = task
372
+ dataset_dict["text"] = text
373
+ dataset_dict["thing_ids"] = self.things
374
+
375
+ return dataset_dict
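The size-divisibility padding above pads only on the right and bottom, because F.pad interprets the 4-element list as (left, right, top, bottom) over the last two dimensions. A small sketch with illustrative shapes (it assumes the augmented image is no larger than size_divisibility, as with the fixed training crop):

    import torch
    import torch.nn.functional as F

    image = torch.zeros(3, 480, 300)              # C, H, W after augmentation
    size_divisibility = 512                        # illustrative value
    pad_w = size_divisibility - image.shape[-1]
    pad_h = size_divisibility - image.shape[-2]
    padded = F.pad(image, [0, pad_w, 0, pad_h], value=128)
    print(padded.shape)                            # torch.Size([3, 512, 512])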
oneformer/data/datasets/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from . import (
2
+ register_ade20k_panoptic,
3
+ register_cityscapes_panoptic,
4
+ register_coco_panoptic_annos_semseg,
5
+ register_ade20k_instance,
6
+ register_coco_panoptic2instance,
7
+ )
oneformer/data/datasets/register_ade20k_instance.py ADDED
@@ -0,0 +1,56 @@
1
+ # ------------------------------------------------------------------------------
2
+ # Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/data/datasets/register_ade20k_instance.py
3
+ # ------------------------------------------------------------------------------
4
+
5
+ import json
6
+ import logging
7
+ import numpy as np
8
+ import os
9
+ from PIL import Image
10
+
11
+ from detectron2.data import DatasetCatalog, MetadataCatalog
12
+ from detectron2.data.datasets.coco import load_coco_json, register_coco_instances
13
+ from detectron2.utils.file_io import PathManager
14
+
15
+ ADE_CATEGORIES = [{'id': 7, 'name': 'bed'}, {'id': 8, 'name': 'windowpane'}, {'id': 10, 'name': 'cabinet'}, {'id': 12, 'name': 'person'}, {'id': 14, 'name': 'door'}, {'id': 15, 'name': 'table'}, {'id': 18, 'name': 'curtain'}, {'id': 19, 'name': 'chair'}, {'id': 20, 'name': 'car'}, {'id': 22, 'name': 'painting'}, {'id': 23, 'name': 'sofa'}, {'id': 24, 'name': 'shelf'}, {'id': 27, 'name': 'mirror'}, {'id': 30, 'name': 'armchair'}, {'id': 31, 'name': 'seat'}, {'id': 32, 'name': 'fence'}, {'id': 33, 'name': 'desk'}, {'id': 35, 'name': 'wardrobe'}, {'id': 36, 'name': 'lamp'}, {'id': 37, 'name': 'bathtub'}, {'id': 38, 'name': 'railing'}, {'id': 39, 'name': 'cushion'}, {'id': 41, 'name': 'box'}, {'id': 42, 'name': 'column'}, {'id': 43, 'name': 'signboard'}, {'id': 44, 'name': 'chest of drawers'}, {'id': 45, 'name': 'counter'}, {'id': 47, 'name': 'sink'}, {'id': 49, 'name': 'fireplace'}, {'id': 50, 'name': 'refrigerator'}, {'id': 53, 'name': 'stairs'}, {'id': 55, 'name': 'case'}, {'id': 56, 'name': 'pool table'}, {'id': 57, 'name': 'pillow'}, {'id': 58, 'name': 'screen door'}, {'id': 62, 'name': 'bookcase'}, {'id': 64, 'name': 'coffee table'}, {'id': 65, 'name': 'toilet'}, {'id': 66, 'name': 'flower'}, {'id': 67, 'name': 'book'}, {'id': 69, 'name': 'bench'}, {'id': 70, 'name': 'countertop'}, {'id': 71, 'name': 'stove'}, {'id': 72, 'name': 'palm'}, {'id': 73, 'name': 'kitchen island'}, {'id': 74, 'name': 'computer'}, {'id': 75, 'name': 'swivel chair'}, {'id': 76, 'name': 'boat'}, {'id': 78, 'name': 'arcade machine'}, {'id': 80, 'name': 'bus'}, {'id': 81, 'name': 'towel'}, {'id': 82, 'name': 'light'}, {'id': 83, 'name': 'truck'}, {'id': 85, 'name': 'chandelier'}, {'id': 86, 'name': 'awning'}, {'id': 87, 'name': 'streetlight'}, {'id': 88, 'name': 'booth'}, {'id': 89, 'name': 'television receiver'}, {'id': 90, 'name': 'airplane'}, {'id': 92, 'name': 'apparel'}, {'id': 93, 'name': 'pole'}, {'id': 95, 'name': 'bannister'}, {'id': 97, 'name': 'ottoman'}, {'id': 98, 'name': 'bottle'}, {'id': 102, 'name': 'van'}, {'id': 103, 'name': 'ship'}, {'id': 104, 'name': 'fountain'}, {'id': 107, 'name': 'washer'}, {'id': 108, 'name': 'plaything'}, {'id': 110, 'name': 'stool'}, {'id': 111, 'name': 'barrel'}, {'id': 112, 'name': 'basket'}, {'id': 115, 'name': 'bag'}, {'id': 116, 'name': 'minibike'}, {'id': 118, 'name': 'oven'}, {'id': 119, 'name': 'ball'}, {'id': 120, 'name': 'food'}, {'id': 121, 'name': 'step'}, {'id': 123, 'name': 'trade name'}, {'id': 124, 'name': 'microwave'}, {'id': 125, 'name': 'pot'}, {'id': 126, 'name': 'animal'}, {'id': 127, 'name': 'bicycle'}, {'id': 129, 'name': 'dishwasher'}, {'id': 130, 'name': 'screen'}, {'id': 132, 'name': 'sculpture'}, {'id': 133, 'name': 'hood'}, {'id': 134, 'name': 'sconce'}, {'id': 135, 'name': 'vase'}, {'id': 136, 'name': 'traffic light'}, {'id': 137, 'name': 'tray'}, {'id': 138, 'name': 'ashcan'}, {'id': 139, 'name': 'fan'}, {'id': 142, 'name': 'plate'}, {'id': 143, 'name': 'monitor'}, {'id': 144, 'name': 'bulletin board'}, {'id': 146, 'name': 'radiator'}, {'id': 147, 'name': 'glass'}, {'id': 148, 'name': 'clock'}, {'id': 149, 'name': 'flag'}]
16
+
17
+
18
+ _PREDEFINED_SPLITS = {
19
+ # point annotations without masks
20
+ "ade20k_instance_train": (
21
+ "ADEChallengeData2016/images/training",
22
+ "ADEChallengeData2016/ade20k_instance_train.json",
23
+ ),
24
+ "ade20k_instance_val": (
25
+ "ADEChallengeData2016/images/validation",
26
+ "ADEChallengeData2016/ade20k_instance_val.json",
27
+ ),
28
+ }
29
+
30
+
31
+ def _get_ade_instances_meta():
32
+ thing_ids = [k["id"] for k in ADE_CATEGORIES]
33
+ assert len(thing_ids) == 100, len(thing_ids)
34
+ # Mapping from the non-contiguous ADE category id to an id in [0, 99]
35
+ thing_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(thing_ids)}
36
+ thing_classes = [k["name"] for k in ADE_CATEGORIES]
37
+ ret = {
38
+ "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id,
39
+ "thing_classes": thing_classes,
40
+ }
41
+ return ret
42
+
43
+
44
+ def register_all_ade20k_instance(root):
45
+ for key, (image_root, json_file) in _PREDEFINED_SPLITS.items():
46
+ # Assume pre-defined datasets live in `./datasets`.
47
+ register_coco_instances(
48
+ key,
49
+ _get_ade_instances_meta(),
50
+ os.path.join(root, json_file) if "://" not in json_file else json_file,
51
+ os.path.join(root, image_root),
52
+ )
53
+
54
+
55
+ _root = os.getenv("DETECTRON2_DATASETS", "datasets")
56
+ register_all_ade20k_instance(_root)
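Once DETECTRON2_DATASETS points at a prepared ADEChallengeData2016 folder (with the instance json generated), the registered splits can be consumed like any other detectron2 dataset; a small sketch:

    from detectron2.data import DatasetCatalog, MetadataCatalog

    dataset_dicts = DatasetCatalog.get("ade20k_instance_val")   # list of per-image dicts
    metadata = MetadataCatalog.get("ade20k_instance_val")
    print(len(dataset_dicts), metadata.thing_classes[:5])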
oneformer/data/datasets/register_ade20k_panoptic.py ADDED
@@ -0,0 +1,394 @@
1
+ # ------------------------------------------------------------------------------
2
+ # Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/data/datasets/register_ade20k_panoptic.py
3
+ # Modified by Jitesh Jain (https://github.com/praeclarumjj3)
4
+ # ------------------------------------------------------------------------------
5
+
6
+ import json
7
+ import os
8
+
9
+ from detectron2.data import DatasetCatalog, MetadataCatalog
10
+ from detectron2.utils.file_io import PathManager
11
+
12
+ ADE20K_150_CATEGORIES = [
13
+ {"color": [120, 120, 120], "id": 0, "isthing": 0, "name": "wall"},
14
+ {"color": [180, 120, 120], "id": 1, "isthing": 0, "name": "building"},
15
+ {"color": [6, 230, 230], "id": 2, "isthing": 0, "name": "sky"},
16
+ {"color": [80, 50, 50], "id": 3, "isthing": 0, "name": "floor"},
17
+ {"color": [4, 200, 3], "id": 4, "isthing": 0, "name": "tree"},
18
+ {"color": [120, 120, 80], "id": 5, "isthing": 0, "name": "ceiling"},
19
+ {"color": [140, 140, 140], "id": 6, "isthing": 0, "name": "road, route"},
20
+ {"color": [204, 5, 255], "id": 7, "isthing": 1, "name": "bed"},
21
+ {"color": [230, 230, 230], "id": 8, "isthing": 1, "name": "window "},
22
+ {"color": [4, 250, 7], "id": 9, "isthing": 0, "name": "grass"},
23
+ {"color": [224, 5, 255], "id": 10, "isthing": 1, "name": "cabinet"},
24
+ {"color": [235, 255, 7], "id": 11, "isthing": 0, "name": "sidewalk, pavement"},
25
+ {"color": [150, 5, 61], "id": 12, "isthing": 1, "name": "person"},
26
+ {"color": [120, 120, 70], "id": 13, "isthing": 0, "name": "earth, ground"},
27
+ {"color": [8, 255, 51], "id": 14, "isthing": 1, "name": "door"},
28
+ {"color": [255, 6, 82], "id": 15, "isthing": 1, "name": "table"},
29
+ {"color": [143, 255, 140], "id": 16, "isthing": 0, "name": "mountain, mount"},
30
+ {"color": [204, 255, 4], "id": 17, "isthing": 0, "name": "plant"},
31
+ {"color": [255, 51, 7], "id": 18, "isthing": 1, "name": "curtain"},
32
+ {"color": [204, 70, 3], "id": 19, "isthing": 1, "name": "chair"},
33
+ {"color": [0, 102, 200], "id": 20, "isthing": 1, "name": "car"},
34
+ {"color": [61, 230, 250], "id": 21, "isthing": 0, "name": "water"},
35
+ {"color": [255, 6, 51], "id": 22, "isthing": 1, "name": "painting, picture"},
36
+ {"color": [11, 102, 255], "id": 23, "isthing": 1, "name": "sofa"},
37
+ {"color": [255, 7, 71], "id": 24, "isthing": 1, "name": "shelf"},
38
+ {"color": [255, 9, 224], "id": 25, "isthing": 0, "name": "house"},
39
+ {"color": [9, 7, 230], "id": 26, "isthing": 0, "name": "sea"},
40
+ {"color": [220, 220, 220], "id": 27, "isthing": 1, "name": "mirror"},
41
+ {"color": [255, 9, 92], "id": 28, "isthing": 0, "name": "rug"},
42
+ {"color": [112, 9, 255], "id": 29, "isthing": 0, "name": "field"},
43
+ {"color": [8, 255, 214], "id": 30, "isthing": 1, "name": "armchair"},
44
+ {"color": [7, 255, 224], "id": 31, "isthing": 1, "name": "seat"},
45
+ {"color": [255, 184, 6], "id": 32, "isthing": 1, "name": "fence"},
46
+ {"color": [10, 255, 71], "id": 33, "isthing": 1, "name": "desk"},
47
+ {"color": [255, 41, 10], "id": 34, "isthing": 0, "name": "rock, stone"},
48
+ {"color": [7, 255, 255], "id": 35, "isthing": 1, "name": "wardrobe, closet, press"},
49
+ {"color": [224, 255, 8], "id": 36, "isthing": 1, "name": "lamp"},
50
+ {"color": [102, 8, 255], "id": 37, "isthing": 1, "name": "tub"},
51
+ {"color": [255, 61, 6], "id": 38, "isthing": 1, "name": "rail"},
52
+ {"color": [255, 194, 7], "id": 39, "isthing": 1, "name": "cushion"},
53
+ {"color": [255, 122, 8], "id": 40, "isthing": 0, "name": "base, pedestal, stand"},
54
+ {"color": [0, 255, 20], "id": 41, "isthing": 1, "name": "box"},
55
+ {"color": [255, 8, 41], "id": 42, "isthing": 1, "name": "column, pillar"},
56
+ {"color": [255, 5, 153], "id": 43, "isthing": 1, "name": "signboard, sign"},
57
+ {
58
+ "color": [6, 51, 255],
59
+ "id": 44,
60
+ "isthing": 1,
61
+ "name": "chest of drawers, chest, bureau, dresser",
62
+ },
63
+ {"color": [235, 12, 255], "id": 45, "isthing": 1, "name": "counter"},
64
+ {"color": [160, 150, 20], "id": 46, "isthing": 0, "name": "sand"},
65
+ {"color": [0, 163, 255], "id": 47, "isthing": 1, "name": "sink"},
66
+ {"color": [140, 140, 140], "id": 48, "isthing": 0, "name": "skyscraper"},
67
+ {"color": [250, 10, 15], "id": 49, "isthing": 1, "name": "fireplace"},
68
+ {"color": [20, 255, 0], "id": 50, "isthing": 1, "name": "refrigerator, icebox"},
69
+ {"color": [31, 255, 0], "id": 51, "isthing": 0, "name": "grandstand, covered stand"},
70
+ {"color": [255, 31, 0], "id": 52, "isthing": 0, "name": "path"},
71
+ {"color": [255, 224, 0], "id": 53, "isthing": 1, "name": "stairs"},
72
+ {"color": [153, 255, 0], "id": 54, "isthing": 0, "name": "runway"},
73
+ {"color": [0, 0, 255], "id": 55, "isthing": 1, "name": "case, display case, showcase, vitrine"},
74
+ {
75
+ "color": [255, 71, 0],
76
+ "id": 56,
77
+ "isthing": 1,
78
+ "name": "pool table, billiard table, snooker table",
79
+ },
80
+ {"color": [0, 235, 255], "id": 57, "isthing": 1, "name": "pillow"},
81
+ {"color": [0, 173, 255], "id": 58, "isthing": 1, "name": "screen door, screen"},
82
+ {"color": [31, 0, 255], "id": 59, "isthing": 0, "name": "stairway, staircase"},
83
+ {"color": [11, 200, 200], "id": 60, "isthing": 0, "name": "river"},
84
+ {"color": [255, 82, 0], "id": 61, "isthing": 0, "name": "bridge, span"},
85
+ {"color": [0, 255, 245], "id": 62, "isthing": 1, "name": "bookcase"},
86
+ {"color": [0, 61, 255], "id": 63, "isthing": 0, "name": "blind, screen"},
87
+ {"color": [0, 255, 112], "id": 64, "isthing": 1, "name": "coffee table"},
88
+ {
89
+ "color": [0, 255, 133],
90
+ "id": 65,
91
+ "isthing": 1,
92
+ "name": "toilet, can, commode, crapper, pot, potty, stool, throne",
93
+ },
94
+ {"color": [255, 0, 0], "id": 66, "isthing": 1, "name": "flower"},
95
+ {"color": [255, 163, 0], "id": 67, "isthing": 1, "name": "book"},
96
+ {"color": [255, 102, 0], "id": 68, "isthing": 0, "name": "hill"},
97
+ {"color": [194, 255, 0], "id": 69, "isthing": 1, "name": "bench"},
98
+ {"color": [0, 143, 255], "id": 70, "isthing": 1, "name": "countertop"},
99
+ {"color": [51, 255, 0], "id": 71, "isthing": 1, "name": "stove"},
100
+ {"color": [0, 82, 255], "id": 72, "isthing": 1, "name": "palm, palm tree"},
101
+ {"color": [0, 255, 41], "id": 73, "isthing": 1, "name": "kitchen island"},
102
+ {"color": [0, 255, 173], "id": 74, "isthing": 1, "name": "computer"},
103
+ {"color": [10, 0, 255], "id": 75, "isthing": 1, "name": "swivel chair"},
104
+ {"color": [173, 255, 0], "id": 76, "isthing": 1, "name": "boat"},
105
+ {"color": [0, 255, 153], "id": 77, "isthing": 0, "name": "bar"},
106
+ {"color": [255, 92, 0], "id": 78, "isthing": 1, "name": "arcade machine"},
107
+ {"color": [255, 0, 255], "id": 79, "isthing": 0, "name": "hovel, hut, hutch, shack, shanty"},
108
+ {"color": [255, 0, 245], "id": 80, "isthing": 1, "name": "bus"},
109
+ {"color": [255, 0, 102], "id": 81, "isthing": 1, "name": "towel"},
110
+ {"color": [255, 173, 0], "id": 82, "isthing": 1, "name": "light"},
111
+ {"color": [255, 0, 20], "id": 83, "isthing": 1, "name": "truck"},
112
+ {"color": [255, 184, 184], "id": 84, "isthing": 0, "name": "tower"},
113
+ {"color": [0, 31, 255], "id": 85, "isthing": 1, "name": "chandelier"},
114
+ {"color": [0, 255, 61], "id": 86, "isthing": 1, "name": "awning, sunshade, sunblind"},
115
+ {"color": [0, 71, 255], "id": 87, "isthing": 1, "name": "street lamp"},
116
+ {"color": [255, 0, 204], "id": 88, "isthing": 1, "name": "booth"},
117
+ {"color": [0, 255, 194], "id": 89, "isthing": 1, "name": "tv"},
118
+ {"color": [0, 255, 82], "id": 90, "isthing": 1, "name": "plane"},
119
+ {"color": [0, 10, 255], "id": 91, "isthing": 0, "name": "dirt track"},
120
+ {"color": [0, 112, 255], "id": 92, "isthing": 1, "name": "clothes"},
121
+ {"color": [51, 0, 255], "id": 93, "isthing": 1, "name": "pole"},
122
+ {"color": [0, 194, 255], "id": 94, "isthing": 0, "name": "land, ground, soil"},
123
+ {
124
+ "color": [0, 122, 255],
125
+ "id": 95,
126
+ "isthing": 1,
127
+ "name": "bannister, banister, balustrade, balusters, handrail",
128
+ },
129
+ {
130
+ "color": [0, 255, 163],
131
+ "id": 96,
132
+ "isthing": 0,
133
+ "name": "escalator, moving staircase, moving stairway",
134
+ },
135
+ {
136
+ "color": [255, 153, 0],
137
+ "id": 97,
138
+ "isthing": 1,
139
+ "name": "ottoman, pouf, pouffe, puff, hassock",
140
+ },
141
+ {"color": [0, 255, 10], "id": 98, "isthing": 1, "name": "bottle"},
142
+ {"color": [255, 112, 0], "id": 99, "isthing": 0, "name": "buffet, counter, sideboard"},
143
+ {
144
+ "color": [143, 255, 0],
145
+ "id": 100,
146
+ "isthing": 0,
147
+ "name": "poster, posting, placard, notice, bill, card",
148
+ },
149
+ {"color": [82, 0, 255], "id": 101, "isthing": 0, "name": "stage"},
150
+ {"color": [163, 255, 0], "id": 102, "isthing": 1, "name": "van"},
151
+ {"color": [255, 235, 0], "id": 103, "isthing": 1, "name": "ship"},
152
+ {"color": [8, 184, 170], "id": 104, "isthing": 1, "name": "fountain"},
153
+ {
154
+ "color": [133, 0, 255],
155
+ "id": 105,
156
+ "isthing": 0,
157
+ "name": "conveyer belt, conveyor belt, conveyer, conveyor, transporter",
158
+ },
159
+ {"color": [0, 255, 92], "id": 106, "isthing": 0, "name": "canopy"},
160
+ {
161
+ "color": [184, 0, 255],
162
+ "id": 107,
163
+ "isthing": 1,
164
+ "name": "washer, automatic washer, washing machine",
165
+ },
166
+ {"color": [255, 0, 31], "id": 108, "isthing": 1, "name": "plaything, toy"},
167
+ {"color": [0, 184, 255], "id": 109, "isthing": 0, "name": "pool"},
168
+ {"color": [0, 214, 255], "id": 110, "isthing": 1, "name": "stool"},
169
+ {"color": [255, 0, 112], "id": 111, "isthing": 1, "name": "barrel, cask"},
170
+ {"color": [92, 255, 0], "id": 112, "isthing": 1, "name": "basket, handbasket"},
171
+ {"color": [0, 224, 255], "id": 113, "isthing": 0, "name": "falls"},
172
+ {"color": [112, 224, 255], "id": 114, "isthing": 0, "name": "tent"},
173
+ {"color": [70, 184, 160], "id": 115, "isthing": 1, "name": "bag"},
174
+ {"color": [163, 0, 255], "id": 116, "isthing": 1, "name": "minibike, motorbike"},
175
+ {"color": [153, 0, 255], "id": 117, "isthing": 0, "name": "cradle"},
176
+ {"color": [71, 255, 0], "id": 118, "isthing": 1, "name": "oven"},
177
+ {"color": [255, 0, 163], "id": 119, "isthing": 1, "name": "ball"},
178
+ {"color": [255, 204, 0], "id": 120, "isthing": 1, "name": "food, solid food"},
179
+ {"color": [255, 0, 143], "id": 121, "isthing": 1, "name": "step, stair"},
180
+ {"color": [0, 255, 235], "id": 122, "isthing": 0, "name": "tank, storage tank"},
181
+ {"color": [133, 255, 0], "id": 123, "isthing": 1, "name": "trade name"},
182
+ {"color": [255, 0, 235], "id": 124, "isthing": 1, "name": "microwave"},
183
+ {"color": [245, 0, 255], "id": 125, "isthing": 1, "name": "pot"},
184
+ {"color": [255, 0, 122], "id": 126, "isthing": 1, "name": "animal"},
185
+ {"color": [255, 245, 0], "id": 127, "isthing": 1, "name": "bicycle"},
186
+ {"color": [10, 190, 212], "id": 128, "isthing": 0, "name": "lake"},
187
+ {"color": [214, 255, 0], "id": 129, "isthing": 1, "name": "dishwasher"},
188
+ {"color": [0, 204, 255], "id": 130, "isthing": 1, "name": "screen"},
189
+ {"color": [20, 0, 255], "id": 131, "isthing": 0, "name": "blanket, cover"},
190
+ {"color": [255, 255, 0], "id": 132, "isthing": 1, "name": "sculpture"},
191
+ {"color": [0, 153, 255], "id": 133, "isthing": 1, "name": "hood, exhaust hood"},
192
+ {"color": [0, 41, 255], "id": 134, "isthing": 1, "name": "sconce"},
193
+ {"color": [0, 255, 204], "id": 135, "isthing": 1, "name": "vase"},
194
+ {"color": [41, 0, 255], "id": 136, "isthing": 1, "name": "traffic light"},
195
+ {"color": [41, 255, 0], "id": 137, "isthing": 1, "name": "tray"},
196
+ {"color": [173, 0, 255], "id": 138, "isthing": 1, "name": "trash can"},
197
+ {"color": [0, 245, 255], "id": 139, "isthing": 1, "name": "fan"},
198
+ {"color": [71, 0, 255], "id": 140, "isthing": 0, "name": "pier"},
199
+ {"color": [122, 0, 255], "id": 141, "isthing": 0, "name": "crt screen"},
200
+ {"color": [0, 255, 184], "id": 142, "isthing": 1, "name": "plate"},
201
+ {"color": [0, 92, 255], "id": 143, "isthing": 1, "name": "monitor"},
202
+ {"color": [184, 255, 0], "id": 144, "isthing": 1, "name": "bulletin board"},
203
+ {"color": [0, 133, 255], "id": 145, "isthing": 0, "name": "shower"},
204
+ {"color": [255, 214, 0], "id": 146, "isthing": 1, "name": "radiator"},
205
+ {"color": [25, 194, 194], "id": 147, "isthing": 1, "name": "glass, drinking glass"},
206
+ {"color": [102, 255, 0], "id": 148, "isthing": 1, "name": "clock"},
207
+ {"color": [92, 0, 255], "id": 149, "isthing": 1, "name": "flag"},
208
+ ]
209
+
210
+ ADE20k_COLORS = [k["color"] for k in ADE20K_150_CATEGORIES]
211
+
212
+ MetadataCatalog.get("ade20k_sem_seg_train").set(
213
+ stuff_colors=ADE20k_COLORS[:],
214
+ )
215
+
216
+ MetadataCatalog.get("ade20k_sem_seg_val").set(
217
+ stuff_colors=ADE20k_COLORS[:],
218
+ )
219
+
220
+
221
+ def load_ade20k_panoptic_json(json_file, image_dir, gt_dir, semseg_dir, meta):
222
+ """
223
+ Args:
224
+ image_dir (str): path to the raw dataset. e.g., "~/coco/train2017".
225
+ gt_dir (str): path to the raw annotations. e.g., "~/coco/panoptic_train2017".
226
+ json_file (str): path to the json file. e.g., "~/coco/annotations/panoptic_train2017.json".
227
+ Returns:
228
+ list[dict]: a list of dicts in Detectron2 standard format. (See
229
+ `Using Custom Datasets </tutorials/datasets.html>`_ )
230
+ """
231
+
232
+ def _convert_category_id(segment_info, meta):
233
+ if segment_info["category_id"] in meta["thing_dataset_id_to_contiguous_id"]:
234
+ segment_info["category_id"] = meta["thing_dataset_id_to_contiguous_id"][
235
+ segment_info["category_id"]
236
+ ]
237
+ segment_info["isthing"] = True
238
+ else:
239
+ segment_info["category_id"] = meta["stuff_dataset_id_to_contiguous_id"][
240
+ segment_info["category_id"]
241
+ ]
242
+ segment_info["isthing"] = False
243
+ return segment_info
244
+
245
+ with PathManager.open(json_file) as f:
246
+ json_info = json.load(f)
247
+
248
+ ret = []
249
+ for ann in json_info["annotations"]:
250
+ image_id = ann["image_id"]
251
+ # TODO: currently we assume image and label have the same filename but
252
+ # different extension, and images have extension ".jpg" for COCO. Need
253
+ # to make image extension a user-provided argument if we extend this
254
+ # function to support other COCO-like datasets.
255
+ image_file = os.path.join(image_dir, os.path.splitext(ann["file_name"])[0] + ".jpg")
256
+ label_file = os.path.join(gt_dir, ann["file_name"])
257
+ sem_label_file = os.path.join(semseg_dir, ann["file_name"])
258
+ segments_info = [_convert_category_id(x, meta) for x in ann["segments_info"]]
259
+ ret.append(
260
+ {
261
+ "file_name": image_file,
262
+ "image_id": image_id,
263
+ "pan_seg_file_name": label_file,
264
+ "sem_seg_file_name": sem_label_file,
265
+ "segments_info": segments_info,
266
+ }
267
+ )
268
+ assert len(ret), f"No images found in {image_dir}!"
269
+ assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"]
270
+ assert PathManager.isfile(ret[0]["pan_seg_file_name"]), ret[0]["pan_seg_file_name"]
271
+ assert PathManager.isfile(ret[0]["sem_seg_file_name"]), ret[0]["sem_seg_file_name"]
272
+ return ret
273
+
274
+
275
+ def register_ade20k_panoptic(
276
+ name, metadata, image_root, panoptic_root, semantic_root, panoptic_json, instances_json=None,
277
+ ):
278
+ """
279
+ Register a "standard" version of ADE20k panoptic segmentation dataset named `name`.
280
+ The dictionaries in this registered dataset follow detectron2's standard format.
281
+ Hence it's called "standard".
282
+ Args:
283
+ name (str): the name that identifies a dataset,
284
+ e.g. "ade20k_panoptic_train"
285
+ metadata (dict): extra metadata associated with this dataset.
286
+ image_root (str): directory which contains all the images
287
+ panoptic_root (str): directory which contains panoptic annotation images in COCO format
288
+ panoptic_json (str): path to the json panoptic annotation file in COCO format
289
+ semantic_root (str): directory which contains the semantic segmentation
290
+ annotation images used alongside the panoptic annotations
291
+ instances_json (str): path to the json instance annotation file
292
+ """
293
+ panoptic_name = name
294
+ DatasetCatalog.register(
295
+ panoptic_name,
296
+ lambda: load_ade20k_panoptic_json(
297
+ panoptic_json, image_root, panoptic_root, semantic_root, metadata
298
+ ),
299
+ )
300
+ MetadataCatalog.get(panoptic_name).set(
301
+ panoptic_root=panoptic_root,
302
+ image_root=image_root,
303
+ panoptic_json=panoptic_json,
304
+ json_file=instances_json,
305
+ evaluator_type="ade20k_panoptic_seg",
306
+ ignore_label=255,
307
+ label_divisor=1000,
308
+ **metadata,
309
+ )
310
+
311
+
312
+ _PREDEFINED_SPLITS_ADE20K_PANOPTIC = {
313
+ "ade20k_panoptic_train": (
314
+ "ADEChallengeData2016/images/training",
315
+ "ADEChallengeData2016/ade20k_panoptic_train",
316
+ "ADEChallengeData2016/ade20k_panoptic_train.json",
317
+ "ADEChallengeData2016/annotations_detectron2/training",
318
+ "ADEChallengeData2016/ade20k_instance_train.json",
319
+ ),
320
+ "ade20k_panoptic_val": (
321
+ "ADEChallengeData2016/images/validation",
322
+ "ADEChallengeData2016/ade20k_panoptic_val",
323
+ "ADEChallengeData2016/ade20k_panoptic_val.json",
324
+ "ADEChallengeData2016/annotations_detectron2/validation",
325
+ "ADEChallengeData2016/ade20k_instance_val.json",
326
+ ),
327
+ }
328
+
329
+
330
+ def get_metadata():
331
+ meta = {}
332
+ # The following metadata maps contiguous id from [0, #thing categories +
333
+ # #stuff categories) to their names and colors. We keep two replicas of the
334
+ # same name and color under "thing_*" and "stuff_*" because the current
335
+ # visualization function in D2 handles thing and stuff classes differently
336
+ # due to some heuristic used in Panoptic FPN. We keep the same naming to
337
+ # enable reusing existing visualization functions.
338
+ thing_classes = [k["name"] for k in ADE20K_150_CATEGORIES if k["isthing"] == 1]
339
+ thing_colors = [k["color"] for k in ADE20K_150_CATEGORIES if k["isthing"] == 1]
340
+ stuff_classes = [k["name"] for k in ADE20K_150_CATEGORIES]
341
+ stuff_colors = [k["color"] for k in ADE20K_150_CATEGORIES]
342
+
343
+ meta["thing_classes"] = thing_classes
344
+ meta["thing_colors"] = thing_colors
345
+ meta["stuff_classes"] = stuff_classes
346
+ meta["stuff_colors"] = stuff_colors
347
+
348
+ # Convert category id for training:
349
+ # category id: like semantic segmentation, it is the class id for each
350
+ # pixel. Since there are some classes not used in evaluation, the category
351
+ # id is not always contiguous and thus we have two sets of category ids:
352
+ # - original category id: category id in the original dataset, mainly
353
+ # used for evaluation.
354
+ # - contiguous category id: [0, #classes), in order to train the linear
355
+ # softmax classifier.
356
+ thing_dataset_id_to_contiguous_id = {}
357
+ stuff_dataset_id_to_contiguous_id = {}
358
+
359
+ for i, cat in enumerate(ADE20K_150_CATEGORIES):
360
+ if cat["isthing"]:
361
+ thing_dataset_id_to_contiguous_id[cat["id"]] = i
362
+ # else:
363
+ # stuff_dataset_id_to_contiguous_id[cat["id"]] = i
364
+
365
+ # in order to use sem_seg evaluator
366
+ stuff_dataset_id_to_contiguous_id[cat["id"]] = i
367
+
368
+ meta["thing_dataset_id_to_contiguous_id"] = thing_dataset_id_to_contiguous_id
369
+ meta["stuff_dataset_id_to_contiguous_id"] = stuff_dataset_id_to_contiguous_id
370
+
371
+ return meta
372
+
373
+
374
+ def register_all_ade20k_panoptic(root):
375
+ metadata = get_metadata()
376
+ for (
377
+ prefix,
378
+ (image_root, panoptic_root, panoptic_json, semantic_root, instance_json),
379
+ ) in _PREDEFINED_SPLITS_ADE20K_PANOPTIC.items():
380
+ # The "standard" version of the ADE20k panoptic segmentation dataset,
381
+ # e.g. used by Panoptic-DeepLab
382
+ register_ade20k_panoptic(
383
+ prefix,
384
+ metadata,
385
+ os.path.join(root, image_root),
386
+ os.path.join(root, panoptic_root),
387
+ os.path.join(root, semantic_root),
388
+ os.path.join(root, panoptic_json),
389
+ os.path.join(root, instance_json),
390
+ )
391
+
392
+
393
+ _root = os.getenv("DETECTRON2_DATASETS", "datasets")
394
+ register_all_ade20k_panoptic(_root)
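As the comments in get_metadata() note, only "thing" ids go into thing_dataset_id_to_contiguous_id while every category goes into stuff_dataset_id_to_contiguous_id so the sem_seg evaluator can be reused. A quick check, assuming the module is importable from an installed oneformer package:

    from oneformer.data.datasets.register_ade20k_panoptic import get_metadata

    meta = get_metadata()
    print(len(meta["thing_dataset_id_to_contiguous_id"]))   # 100 "thing" categories
    print(len(meta["stuff_dataset_id_to_contiguous_id"]))   # 150, i.e. all categories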
oneformer/data/datasets/register_cityscapes_panoptic.py ADDED
@@ -0,0 +1,199 @@
1
+ # ------------------------------------------------------------------------------
2
+ # Reference: https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/datasets/cityscapes_panoptic.py
3
+ # Modified by Jitesh Jain (https://github.com/praeclarumjj3)
4
+ # ------------------------------------------------------------------------------
5
+
6
+ import json
7
+ import logging
8
+ import os
9
+
10
+ from detectron2.data import DatasetCatalog, MetadataCatalog
11
+ from detectron2.data.datasets.builtin_meta import CITYSCAPES_CATEGORIES
12
+ from detectron2.utils.file_io import PathManager
13
+
14
+ """
15
+ This file contains functions to register the Cityscapes panoptic dataset to the DatasetCatalog.
16
+ """
17
+
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ def get_cityscapes_panoptic_files(image_dir, gt_dir, json_info):
23
+ files = []
24
+ # scan through the directory
25
+ cities = PathManager.ls(image_dir)
26
+ logger.info(f"{len(cities)} cities found in '{image_dir}'.")
27
+ image_dict = {}
28
+ for city in cities:
29
+ city_img_dir = os.path.join(image_dir, city)
30
+ for basename in PathManager.ls(city_img_dir):
31
+ image_file = os.path.join(city_img_dir, basename)
32
+
33
+ suffix = "_leftImg8bit.png"
34
+ assert basename.endswith(suffix), basename
35
+ basename = os.path.basename(basename)[: -len(suffix)]
36
+
37
+ image_dict[basename] = image_file
38
+
39
+ for ann in json_info["annotations"]:
40
+ image_file = image_dict.get(ann["image_id"], None)
41
+ assert image_file is not None, "No image {} found for annotation {}".format(
42
+ ann["image_id"], ann["file_name"]
43
+ )
44
+ label_file = os.path.join(gt_dir, ann["file_name"])
45
+ segments_info = ann["segments_info"]
46
+ files.append((image_file, label_file, segments_info))
47
+
48
+ assert len(files), "No images found in {}".format(image_dir)
49
+ assert PathManager.isfile(files[0][0]), files[0][0]
50
+ assert PathManager.isfile(files[0][1]), files[0][1]
51
+ return files
52
+
53
+
54
+ def load_cityscapes_panoptic(image_dir, gt_dir, gt_json, meta):
55
+ """
56
+ Args:
57
+ image_dir (str): path to the raw dataset. e.g., "~/cityscapes/leftImg8bit/train".
58
+ gt_dir (str): path to the raw annotations. e.g.,
59
+ "~/cityscapes/gtFine/cityscapes_panoptic_train".
60
+ gt_json (str): path to the json file. e.g.,
61
+ "~/cityscapes/gtFine/cityscapes_panoptic_train.json".
62
+ meta (dict): dictionary containing "thing_dataset_id_to_contiguous_id"
63
+ and "stuff_dataset_id_to_contiguous_id" to map category ids to
64
+ contiguous ids for training.
65
+
66
+ Returns:
67
+ list[dict]: a list of dicts in Detectron2 standard format. (See
68
+ `Using Custom Datasets </tutorials/datasets.html>`_ )
69
+ """
70
+
71
+ def _convert_category_id(segment_info, meta):
72
+ if segment_info["category_id"] in meta["thing_dataset_id_to_contiguous_id"]:
73
+ segment_info["category_id"] = meta["thing_dataset_id_to_contiguous_id"][
74
+ segment_info["category_id"]
75
+ ]
76
+ else:
77
+ segment_info["category_id"] = meta["stuff_dataset_id_to_contiguous_id"][
78
+ segment_info["category_id"]
79
+ ]
80
+ return segment_info
81
+
82
+ assert os.path.exists(
83
+ gt_json
84
+ ), "Please run `python cityscapesscripts/preparation/createPanopticImgs.py` to generate label files." # noqa
85
+
86
+
87
+ with open(gt_json) as f:
88
+ json_info = json.load(f)
89
+
90
+ files = get_cityscapes_panoptic_files(image_dir, gt_dir, json_info)
91
+ ret = []
92
+ for image_file, label_file, segments_info in files:
93
+ sem_label_file = (
94
+ image_file.replace("leftImg8bit", "gtFine").split(".")[0] + "_labelTrainIds.png"
95
+ )
96
+ segments_info = [_convert_category_id(x, meta) for x in segments_info]
97
+ ret.append(
98
+ {
99
+ "file_name": image_file,
100
+ "image_id": "_".join(
101
+ os.path.splitext(os.path.basename(image_file))[0].split("_")[:3]
102
+ ),
103
+ "sem_seg_file_name": sem_label_file,
104
+ "pan_seg_file_name": label_file,
105
+ "segments_info": segments_info,
106
+ }
107
+ )
108
+ assert len(ret), f"No images found in {image_dir}!"
109
+ assert PathManager.isfile(
110
+ ret[0]["sem_seg_file_name"]
111
+ ), "Please generate labelTrainIds.png with cityscapesscripts/preparation/createTrainIdLabelImgs.py" # noqa
112
+ assert PathManager.isfile(
113
+ ret[0]["pan_seg_file_name"]
114
+ ), "Please generate panoptic annotation with python cityscapesscripts/preparation/createPanopticImgs.py" # noqa
115
+ return ret
116
+
117
+
118
+ _RAW_CITYSCAPES_PANOPTIC_SPLITS = {
119
+ "cityscapes_fine_panoptic_train": (
120
+ "cityscapes/leftImg8bit/train",
121
+ "cityscapes/gtFine/cityscapes_panoptic_train",
122
+ "cityscapes/gtFine/cityscapes_panoptic_train.json",
123
+ ),
124
+ "cityscapes_fine_panoptic_val": (
125
+ "cityscapes/leftImg8bit/val",
126
+ "cityscapes/gtFine/cityscapes_panoptic_val",
127
+ "cityscapes/gtFine/cityscapes_panoptic_val.json",
128
+ ),
129
+ # "cityscapes_fine_panoptic_test": not supported yet
130
+ }
131
+
132
+
133
+ def register_all_cityscapes_panoptic(root):
134
+ meta = {}
135
+ # The following metadata maps contiguous id from [0, #thing categories +
136
+ # #stuff categories) to their names and colors. We keep two replicas of the
137
+ # same name and color under "thing_*" and "stuff_*" because the current
138
+ # visualization function in D2 handles thing and stuff classes differently
139
+ # due to some heuristic used in Panoptic FPN. We keep the same naming to
140
+ # enable reusing existing visualization functions.
141
+ thing_classes = [k["name"] for k in CITYSCAPES_CATEGORIES]
142
+ thing_colors = [k["color"] for k in CITYSCAPES_CATEGORIES]
143
+ stuff_classes = [k["name"] for k in CITYSCAPES_CATEGORIES]
144
+ stuff_colors = [k["color"] for k in CITYSCAPES_CATEGORIES]
145
+
146
+ meta["thing_classes"] = thing_classes
147
+ meta["thing_colors"] = thing_colors
148
+ meta["stuff_classes"] = stuff_classes
149
+ meta["stuff_colors"] = stuff_colors
150
+
151
+ # There are three types of ids in cityscapes panoptic segmentation:
152
+ # (1) category id: like semantic segmentation, it is the class id for each
153
+ # pixel. Since there are some classes not used in evaluation, the category
154
+ # id is not always contiguous and thus we have two sets of category ids:
155
+ # - original category id: category id in the original dataset, mainly
156
+ # used for evaluation.
157
+ # - contiguous category id: [0, #classes), in order to train the classifier
158
+ # (2) instance id: this id is used to differentiate different instances from
159
+ # the same category. For "stuff" classes, the instance id is always 0; for
160
+ # "thing" classes, the instance id starts from 1 and 0 is reserved for
161
+ # ignored instances (e.g. crowd annotation).
162
+ # (3) panoptic id: this is the compact id that encodes both category and
163
+ # instance id by: category_id * 1000 + instance_id.
164
+ thing_dataset_id_to_contiguous_id = {}
165
+ stuff_dataset_id_to_contiguous_id = {}
166
+
167
+ for k in CITYSCAPES_CATEGORIES:
168
+ if k["isthing"] == 1:
169
+ thing_dataset_id_to_contiguous_id[k["id"]] = k["trainId"]
170
+ else:
171
+ stuff_dataset_id_to_contiguous_id[k["id"]] = k["trainId"]
172
+
173
+ meta["thing_dataset_id_to_contiguous_id"] = thing_dataset_id_to_contiguous_id
174
+ meta["stuff_dataset_id_to_contiguous_id"] = stuff_dataset_id_to_contiguous_id
175
+
176
+ for key, (image_dir, gt_dir, gt_json) in _RAW_CITYSCAPES_PANOPTIC_SPLITS.items():
177
+ image_dir = os.path.join(root, image_dir)
178
+ gt_dir = os.path.join(root, gt_dir)
179
+ gt_json = os.path.join(root, gt_json)
180
+
181
+ if key in DatasetCatalog.list():
182
+ DatasetCatalog.remove(key)
183
+
184
+ DatasetCatalog.register(
185
+ key, lambda x=image_dir, y=gt_dir, z=gt_json: load_cityscapes_panoptic(x, y, z, meta)
186
+ )
187
+ MetadataCatalog.get(key).set(
188
+ panoptic_root=gt_dir,
189
+ image_root=image_dir,
190
+ panoptic_json=gt_json,
191
+ gt_dir=gt_dir.replace("cityscapes_panoptic_", ""),
192
+ evaluator_type="cityscapes_panoptic_seg",
193
+ ignore_label=255,
194
+ label_divisor=1000,
195
+ **meta,
196
+ )
197
+
198
+ _root = os.getenv("DETECTRON2_DATASETS", "datasets")
199
+ register_all_cityscapes_panoptic(_root)
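The label_divisor=1000 set above matches the panoptic id convention described in the comments (category_id * 1000 + instance_id); a tiny sketch of decoding such an id, with a hypothetical pixel value:

    label_divisor = 1000
    panoptic_id = 26005                          # hypothetical encoded pixel value
    category_id = panoptic_id // label_divisor   # 26
    instance_id = panoptic_id % label_divisor    # 5
    print(category_id, instance_id)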
oneformer/data/datasets/register_coco_panoptic2instance.py ADDED
@@ -0,0 +1,44 @@
1
+ # ------------------------------------------------------------------------------
2
+ # Reference: https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/datasets/builtin.py
3
+ # Modified by Jitesh Jain (https://github.com/praeclarumjj3)
4
+ # ------------------------------------------------------------------------------
5
+
6
+
7
+ """
8
+ This file registers pre-defined datasets at hard-coded paths, and their metadata.
9
+
10
+ We hard-code metadata for common datasets. This will enable:
11
+ 1. Consistency checks when loading the datasets
12
+ 2. Using models on these standard datasets directly and running demos,
13
+ without having to download the dataset annotations
14
+
15
+ We hard-code some paths to the dataset that's assumed to
16
+ exist in "./datasets/".
17
+
18
+ Users SHOULD NOT use this file to create new datasets / metadata for new datasets.
19
+ To add a new dataset, refer to the tutorial "docs/DATASETS.md".
20
+ """
21
+
22
+ import os
23
+ from detectron2.data.datasets.builtin_meta import _get_builtin_metadata
24
+ from detectron2.data.datasets.coco import register_coco_instances
25
+
26
+
27
+ _PREDEFINED_SPLITS_COCO = {
28
+ "coco_2017_val_panoptic2instance": ("coco/val2017", "coco/annotations/panoptic2instances_val2017.json"),
29
+ }
30
+
31
+
32
+ def register_panoptic2instances_coco(root):
33
+ for key, (image_root, json_file) in _PREDEFINED_SPLITS_COCO.items():
34
+ # Assume pre-defined datasets live in `./datasets`.
35
+ register_coco_instances(
36
+ key,
37
+ _get_builtin_metadata("coco"),
38
+ os.path.join(root, json_file) if "://" not in json_file else json_file,
39
+ os.path.join(root, image_root),
40
+ )
41
+
42
+
43
+ _root = os.path.expanduser(os.getenv("DETECTRON2_DATASETS", "datasets"))
44
+ register_panoptic2instances_coco(_root)
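Once this module has been imported, the split is available through detectron2's catalogs. A small usage sketch, assuming detectron2 is installed and the COCO val2017 images plus panoptic2instances_val2017.json actually exist under $DETECTRON2_DATASETS; it is not part of the commit:

# Usage sketch: look up the dataset registered above.
from detectron2.data import DatasetCatalog, MetadataCatalog

# importing the module runs register_panoptic2instances_coco at the bottom
import oneformer.data.datasets.register_coco_panoptic2instance  # noqa: F401

dicts = DatasetCatalog.get("coco_2017_val_panoptic2instance")
meta = MetadataCatalog.get("coco_2017_val_panoptic2instance")
print(len(dicts), meta.thing_classes[:5])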
oneformer/data/datasets/register_coco_panoptic_annos_semseg.py ADDED
@@ -0,0 +1,367 @@
1
+ # ------------------------------------------------------------------------------
2
+ # Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/data/datasets/register_coco_panoptic_annos_semseg.py
3
+ # Modified by Jitesh Jain (https://github.com/praeclarumjj3)
4
+ # ------------------------------------------------------------------------------
5
+
6
+ import json
7
+ import os
8
+
9
+ from detectron2.data import DatasetCatalog, MetadataCatalog
10
+ from detectron2.data.datasets import load_sem_seg
11
+ from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
12
+ from detectron2.utils.file_io import PathManager
13
+ import contextlib
14
+ import logging
15
+ import io
16
+ from fvcore.common.timer import Timer
17
+ import pycocotools.mask as mask_util
18
+ from detectron2.structures import BoxMode
19
+
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ _PREDEFINED_SPLITS_COCO_PANOPTIC = {
25
+ "coco_2017_train_panoptic": (
26
+ # This is the original panoptic annotation directory
27
+ "coco/panoptic_train2017",
28
+ "coco/annotations/panoptic_train2017.json",
29
+ # This directory contains semantic annotations that are
30
+ # converted from panoptic annotations.
31
+ # It is used by PanopticFPN.
32
+ # You can use the script at detectron2/datasets/prepare_panoptic_fpn.py
33
+ # to create these directories.
34
+ "coco/panoptic_semseg_train2017",
35
+ ),
36
+ "coco_2017_val_panoptic": (
37
+ "coco/panoptic_val2017",
38
+ "coco/annotations/panoptic_val2017.json",
39
+ "coco/panoptic_semseg_val2017",
40
+ ),
41
+ }
42
+
43
+ def load_coco_instance_json(json_file, image_root, dataset_name=None):
44
+ from pycocotools.coco import COCO
45
+
46
+ timer = Timer()
47
+ json_file = PathManager.get_local_path(json_file)
48
+ with contextlib.redirect_stdout(io.StringIO()):
49
+ coco_api = COCO(json_file)
50
+ if timer.seconds() > 1:
51
+ logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))
52
+
53
+ id_map = None
54
+ if dataset_name is not None:
55
+ meta = MetadataCatalog.get(dataset_name)
56
+ cat_ids = sorted(coco_api.getCatIds())
57
+ cats = coco_api.loadCats(cat_ids)
58
+ # The categories in a custom json file may not be sorted.
59
+ thing_classes = [c["name"] for c in sorted(cats, key=lambda x: x["id"])]
60
+ meta.thing_classes = thing_classes
61
+
62
+ # In COCO, certain category ids are artificially removed,
63
+ # and by convention they are always ignored.
64
+ # We deal with COCO's id issue and translate
65
+ # the category ids to contiguous ids in [0, 80).
66
+
67
+ # It works by looking at the "categories" field in the json, therefore
68
+ # if users' own json also has incontiguous ids, we'll
69
+ # apply this mapping as well but print a warning.
70
+ if not (min(cat_ids) == 1 and max(cat_ids) == len(cat_ids)):
71
+ if "coco" not in dataset_name:
72
+ logger.warning(
73
+ """
74
+ Category ids in annotations are not in [1, #categories]! We'll apply a mapping for you.
75
+ """
76
+ )
77
+ id_map = {v: i for i, v in enumerate(cat_ids)}
78
+ meta.thing_dataset_id_to_contiguous_id = id_map
79
+
80
+ # sort indices for reproducible results
81
+ img_ids = sorted(coco_api.imgs.keys())
82
+ # imgs is a list of dicts, each looks something like:
83
+ # {'license': 4,
84
+ # 'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
85
+ # 'file_name': 'COCO_val2014_000000001268.jpg',
86
+ # 'height': 427,
87
+ # 'width': 640,
88
+ # 'date_captured': '2013-11-17 05:57:24',
89
+ # 'id': 1268}
90
+ imgs = coco_api.loadImgs(img_ids)
91
+ # anns is a list[list[dict]], where each dict is an annotation
92
+ # record for an object. The inner list enumerates the objects in an image
93
+ # and the outer list enumerates over images. Example of anns[0]:
94
+ # [{'segmentation': [[192.81,
95
+ # 247.09,
96
+ # ...
97
+ # 219.03,
98
+ # 249.06]],
99
+ # 'area': 1035.749,
100
+ # 'iscrowd': 0,
101
+ # 'image_id': 1268,
102
+ # 'bbox': [192.81, 224.8, 74.73, 33.43],
103
+ # 'category_id': 16,
104
+ # 'id': 42986},
105
+ # ...]
106
+ anns = [coco_api.imgToAnns[img_id] for img_id in img_ids]
107
+ total_num_valid_anns = sum([len(x) for x in anns])
108
+ total_num_anns = len(coco_api.anns)
109
+ if total_num_valid_anns < total_num_anns:
110
+ logger.warning(
111
+ f"{json_file} contains {total_num_anns} annotations, but only "
112
+ f"{total_num_valid_anns} of them match to images in the file."
113
+ )
114
+
115
+ if "minival" not in json_file:
116
+ # The popular valminusminival & minival annotations for COCO2014 contain this bug.
117
+ # However the ratio of buggy annotations there is tiny and does not affect accuracy.
118
+ # Therefore we explicitly white-list them.
119
+ ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
120
+ assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique!".format(
121
+ json_file
122
+ )
123
+
124
+ imgs_anns = list(zip(imgs, anns))
125
+ logger.info("Loaded {} images in COCO format from {}".format(len(imgs_anns), json_file))
126
+
127
+ dataset_dicts = {}
128
+
129
+ ann_keys = ["iscrowd", "bbox", "keypoints", "category_id"]
130
+
131
+ num_instances_without_valid_segmentation = 0
132
+
133
+ for (img_dict, anno_dict_list) in imgs_anns:
134
+ record = {}
135
+ record["file_name"] = os.path.join(image_root, img_dict["file_name"])
136
+ record["height"] = img_dict["height"]
137
+ record["width"] = img_dict["width"]
138
+ image_id = record["image_id"] = img_dict["id"]
139
+
140
+ objs = []
141
+ for anno in anno_dict_list:
142
+ # Check that the image_id in this annotation is the same as
143
+ # the image_id we're looking at.
144
+ # This fails only when the data parsing logic or the annotation file is buggy.
145
+
146
+ # The original COCO valminusminival2014 & minival2014 annotation files
147
+ # actually contain bugs that, together with certain ways of using the COCO API,
148
+ # can trigger this assertion.
149
+ assert anno["image_id"] == image_id
150
+
151
+ assert anno.get("ignore", 0) == 0, '"ignore" in COCO json file is not supported.'
152
+
153
+ obj = {key: anno[key] for key in ann_keys if key in anno}
154
+ if "bbox" in obj and len(obj["bbox"]) == 0:
155
+ raise ValueError(
156
+ f"One annotation of image {image_id} contains empty 'bbox' value! "
157
+ "This json does not have valid COCO format."
158
+ )
159
+
160
+ segm = anno.get("segmentation", None)
161
+ if segm: # either list[list[float]] or dict(RLE)
162
+ if isinstance(segm, dict):
163
+ if isinstance(segm["counts"], list):
164
+ # convert to compressed RLE
165
+ segm = mask_util.frPyObjects(segm, *segm["size"])
166
+ else:
167
+ # filter out invalid polygons (< 3 points)
168
+ segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
169
+ if len(segm) == 0:
170
+ num_instances_without_valid_segmentation += 1
171
+ continue # ignore this instance
172
+ obj["segmentation"] = segm
173
+
174
+ keypts = anno.get("keypoints", None)
175
+ if keypts: # list[int]
176
+ for idx, v in enumerate(keypts):
177
+ if idx % 3 != 2:
178
+ # COCO's segmentation coordinates are floating points in [0, H or W],
179
+ # but keypoint coordinates are integers in [0, H-1 or W-1]
180
+ # Therefore we assume the coordinates are "pixel indices" and
181
+ # add 0.5 to convert to floating point coordinates.
182
+ keypts[idx] = v + 0.5
183
+ obj["keypoints"] = keypts
184
+
185
+ obj["bbox_mode"] = BoxMode.XYWH_ABS
186
+ if id_map:
187
+ annotation_category_id = obj["category_id"]
188
+ try:
189
+ obj["category_id"] = id_map[annotation_category_id]
190
+ except KeyError as e:
191
+ raise KeyError(
192
+ f"Encountered category_id={annotation_category_id} "
193
+ "but this id does not exist in 'categories' of the json file."
194
+ ) from e
195
+ objs.append(obj)
196
+ record["annotations"] = objs
197
+ dataset_dicts[image_id] = record
198
+
199
+ if num_instances_without_valid_segmentation > 0:
200
+ logger.warning(
201
+ "Filtered out {} instances without valid segmentation. ".format(
202
+ num_instances_without_valid_segmentation
203
+ )
204
+ + "There might be issues in your dataset generation process. Please "
205
+ "check https://detectron2.readthedocs.io/en/latest/tutorials/datasets.html carefully"
206
+ )
207
+ return dataset_dicts
208
+
209
+ def get_metadata():
210
+ meta = {}
211
+ # The following metadata maps contiguous id from [0, #thing categories +
212
+ # #stuff categories) to their names and colors. We have to replicate the
213
+ # same name and color under "thing_*" and "stuff_*" because the current
214
+ # visualization function in D2 handles thing and stuff classes differently
215
+ # due to some heuristic used in Panoptic FPN. We keep the same naming to
216
+ # enable reusing existing visualization functions.
217
+ thing_classes = [k["name"] for k in COCO_CATEGORIES if k["isthing"] == 1]
218
+ thing_colors = [k["color"] for k in COCO_CATEGORIES if k["isthing"] == 1]
219
+ stuff_classes = [k["name"] for k in COCO_CATEGORIES]
220
+ stuff_colors = [k["color"] for k in COCO_CATEGORIES]
221
+
222
+ meta["thing_classes"] = thing_classes
223
+ meta["thing_colors"] = thing_colors
224
+ meta["stuff_classes"] = stuff_classes
225
+ meta["stuff_colors"] = stuff_colors
226
+
227
+ # Convert category id for training:
228
+ # category id: like semantic segmentation, it is the class id for each
229
+ # pixel. Since there are some classes not used in evaluation, the category
230
+ # id is not always contiguous and thus we have two sets of category ids:
231
+ # - original category id: category id in the original dataset, mainly
232
+ # used for evaluation.
233
+ # - contiguous category id: [0, #classes), in order to train the linear
234
+ # softmax classifier.
235
+ thing_dataset_id_to_contiguous_id = {}
236
+ stuff_dataset_id_to_contiguous_id = {}
237
+
238
+ for i, cat in enumerate(COCO_CATEGORIES):
239
+ if cat["isthing"]:
240
+ thing_dataset_id_to_contiguous_id[cat["id"]] = i
241
+ # else:
242
+ # stuff_dataset_id_to_contiguous_id[cat["id"]] = i
243
+
244
+ # in order to use sem_seg evaluator
245
+ stuff_dataset_id_to_contiguous_id[cat["id"]] = i
246
+
247
+ meta["thing_dataset_id_to_contiguous_id"] = thing_dataset_id_to_contiguous_id
248
+ meta["stuff_dataset_id_to_contiguous_id"] = stuff_dataset_id_to_contiguous_id
249
+
250
+ return meta
251
+
252
+
253
+ def load_coco_panoptic_json(json_file, instances_json, instances_name, image_dir, gt_dir, semseg_dir, meta):
254
+ """
255
+ Args:
256
+ image_dir (str): path to the raw dataset. e.g., "~/coco/train2017".
257
+ gt_dir (str): path to the raw annotations. e.g., "~/coco/panoptic_train2017".
258
+ json_file (str): path to the json file. e.g., "~/coco/annotations/panoptic_train2017.json".
259
+ Returns:
260
+ list[dict]: a list of dicts in Detectron2 standard format. (See
261
+ `Using Custom Datasets </tutorials/datasets.html>`_ )
262
+ """
263
+
264
+ def _convert_category_id(segment_info, meta):
265
+ if segment_info["category_id"] in meta["thing_dataset_id_to_contiguous_id"]:
266
+ segment_info["category_id"] = meta["thing_dataset_id_to_contiguous_id"][
267
+ segment_info["category_id"]
268
+ ]
269
+ segment_info["isthing"] = True
270
+ else:
271
+ segment_info["category_id"] = meta["stuff_dataset_id_to_contiguous_id"][
272
+ segment_info["category_id"]
273
+ ]
274
+ segment_info["isthing"] = False
275
+ return segment_info
276
+
277
+ with PathManager.open(json_file) as f:
278
+ json_info = json.load(f)
279
+
280
+ instance_data_dicts = load_coco_instance_json(instances_json, image_dir.replace("panoptic_", ""), instances_name)
281
+
282
+ ret = []
283
+ for ann in json_info["annotations"]:
284
+ image_id = int(ann["image_id"])
285
+ # TODO: currently we assume image and label have the same filename but
286
+ # different extension, and images have extension ".jpg" for COCO. Need
287
+ # to make image extension a user-provided argument if we extend this
288
+ # function to support other COCO-like datasets.
289
+ image_file = os.path.join(image_dir, os.path.splitext(ann["file_name"])[0] + ".jpg")
290
+ label_file = os.path.join(gt_dir, ann["file_name"])
291
+ sem_label_file = os.path.join(semseg_dir, ann["file_name"])
292
+ segments_info = [_convert_category_id(x, meta) for x in ann["segments_info"]]
293
+ ret.append(
294
+ {
295
+ "file_name": image_file,
296
+ "image_id": image_id,
297
+ "pan_seg_file_name": label_file,
298
+ "sem_seg_file_name": sem_label_file,
299
+ "segments_info": segments_info,
300
+ "annotations": instance_data_dicts[image_id]["annotations"],
301
+ }
302
+ )
303
+ assert len(ret), f"No images found in {image_dir}!"
304
+ assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"]
305
+ assert PathManager.isfile(ret[0]["pan_seg_file_name"]), ret[0]["pan_seg_file_name"]
306
+ assert PathManager.isfile(ret[0]["sem_seg_file_name"]), ret[0]["sem_seg_file_name"]
307
+ return ret
308
+
309
+
310
+ def register_coco_panoptic_annos_sem_seg(
311
+ name, metadata, image_root, panoptic_root, panoptic_json, sem_seg_root, instances_json, instances_name,
312
+ ):
313
+ panoptic_name = name
314
+ delattr(MetadataCatalog.get(panoptic_name), "thing_classes")
315
+ delattr(MetadataCatalog.get(panoptic_name), "thing_colors")
316
+ MetadataCatalog.get(panoptic_name).set(
317
+ thing_classes=metadata["thing_classes"],
318
+ thing_colors=metadata["thing_colors"],
319
+ # thing_dataset_id_to_contiguous_id=metadata["thing_dataset_id_to_contiguous_id"],
320
+ )
321
+
322
+ # the name is "coco_2017_train_panoptic_with_sem_seg" and "coco_2017_val_panoptic_with_sem_seg"
323
+ semantic_name = name + "_with_sem_seg"
324
+ DatasetCatalog.register(
325
+ semantic_name,
326
+ lambda: load_coco_panoptic_json(panoptic_json, instances_json, instances_name, image_root, panoptic_root, sem_seg_root, metadata),
327
+ )
328
+ MetadataCatalog.get(semantic_name).set(
329
+ sem_seg_root=sem_seg_root,
330
+ panoptic_root=panoptic_root,
331
+ image_root=image_root,
332
+ panoptic_json=panoptic_json,
333
+ json_file=instances_json,
334
+ evaluator_type="coco_panoptic_seg",
335
+ ignore_label=255,
336
+ label_divisor=1000,
337
+ **metadata,
338
+ )
339
+
340
+
341
+ def register_all_coco_panoptic_annos_sem_seg(root):
342
+ for (
343
+ prefix,
344
+ (panoptic_root, panoptic_json, semantic_root),
345
+ ) in _PREDEFINED_SPLITS_COCO_PANOPTIC.items():
346
+
347
+ prefix_instances = prefix[: -len("_panoptic")]
348
+ instances_meta = MetadataCatalog.get(prefix_instances)
349
+ image_root, instances_json = instances_meta.image_root, instances_meta.json_file
350
+
351
+ if 'val' in instances_json:
352
+ instances_json = instances_json.replace('instances_', 'panoptic2instances_')
353
+
354
+ register_coco_panoptic_annos_sem_seg(
355
+ prefix,
356
+ get_metadata(),
357
+ image_root,
358
+ os.path.join(root, panoptic_root),
359
+ os.path.join(root, panoptic_json),
360
+ os.path.join(root, semantic_root),
361
+ instances_json,
362
+ prefix_instances,
363
+ )
364
+
365
+
366
+ _root = os.getenv("DETECTRON2_DATASETS", "datasets")
367
+ register_all_coco_panoptic_annos_sem_seg(_root)
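As a sanity check on the mappings built by get_metadata() above: every thing category gets an entry in the thing mapping, while the stuff mapping deliberately covers every COCO panoptic category, things included, so the semantic-segmentation evaluator can be reused. A short sketch of that check (assumes detectron2 is installed so COCO_CATEGORIES is importable; not part of the commit):

# Sketch: verify the contiguous-id mappings produced by get_metadata().
from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
from oneformer.data.datasets.register_coco_panoptic_annos_semseg import get_metadata

meta = get_metadata()
num_things = sum(c["isthing"] for c in COCO_CATEGORIES)  # thing categories only
assert len(meta["thing_dataset_id_to_contiguous_id"]) == num_things
# things are included in the stuff mapping on purpose (see the comment above)
assert len(meta["stuff_dataset_id_to_contiguous_id"]) == len(COCO_CATEGORIES)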
oneformer/data/tokenizer.py ADDED
@@ -0,0 +1,200 @@
1
+ # -------------------------------------------------------------------------
2
+ # MIT License
3
+ #
4
+ # Copyright (c) 2021 OpenAI
5
+ #
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+ #
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+ #
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.
23
+ #
24
+ # Modified by Jiarui Xu
25
+ # -------------------------------------------------------------------------
26
+
27
+ import gzip
28
+ import html
29
+ import os
30
+ from functools import lru_cache
31
+
32
+ import ftfy
33
+ import regex as re
34
+ import torch
35
+
36
+
37
+ @lru_cache()
38
+ def default_bpe():
39
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), 'bpe_simple_vocab_16e6.txt')
40
+
41
+ @lru_cache()
42
+ def bytes_to_unicode():
43
+ """Returns list of utf-8 byte and a corresponding list of unicode strings.
44
+
45
+ The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
46
+ if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for decent
47
+ coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup tables
48
+ between utf-8 bytes and unicode strings. And avoids mapping to whitespace/control characters the bpe code barfs on.
49
+ """
50
+ bs = list(range(ord('!'), ord('~') + 1)) + list(range(ord('¡'), ord('¬') + 1)) + list(range(ord('®'), ord('ÿ') + 1))
51
+ cs = bs[:]
52
+ n = 0
53
+ for b in range(2**8):
54
+ if b not in bs:
55
+ bs.append(b)
56
+ cs.append(2**8 + n)
57
+ n += 1
58
+ cs = [chr(n) for n in cs]
59
+ return dict(zip(bs, cs))
60
+
61
+
62
+ def get_pairs(word):
63
+ """Return set of symbol pairs in a word.
64
+
65
+ Word is represented as tuple of symbols (symbols being variable-length strings).
66
+ """
67
+ pairs = set()
68
+ prev_char = word[0]
69
+ for char in word[1:]:
70
+ pairs.add((prev_char, char))
71
+ prev_char = char
72
+ return pairs
73
+
74
+
75
+ def basic_clean(text):
76
+ text = ftfy.fix_text(text)
77
+ text = html.unescape(html.unescape(text))
78
+ return text.strip()
79
+
80
+
81
+ def whitespace_clean(text):
82
+ text = re.sub(r'\s+', ' ', text)
83
+ text = text.strip()
84
+ return text
85
+
86
+ class Tokenize:
87
+
88
+ def __init__(self, tokenizer, max_seq_len=77, truncate=True):
89
+ self.tokenizer = tokenizer
90
+ self.max_seq_len = max_seq_len
91
+ self.truncate = truncate
92
+
93
+ def __call__(self, texts):
94
+ expanded_dim = False
95
+ if isinstance(texts, str):
96
+ texts = [texts]
97
+ expanded_dim = True
98
+
99
+ sot_token = self.tokenizer.encoder['<|startoftext|>']
100
+ eot_token = self.tokenizer.encoder['<|endoftext|>']
101
+ all_tokens = [[sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts]
102
+ result = torch.zeros(len(all_tokens), self.max_seq_len, dtype=torch.long)
103
+
104
+ for i, tokens in enumerate(all_tokens):
105
+ if len(tokens) > self.max_seq_len:
106
+ if self.truncate:
107
+ tokens = tokens[:self.max_seq_len]
108
+ tokens[-1] = eot_token
109
+ else:
110
+ raise RuntimeError(f'Input {texts[i]} is too long for context length {self.max_seq_len}')
111
+ result[i, :len(tokens)] = torch.tensor(tokens)
112
+
113
+ if expanded_dim:
114
+ return result[0]
115
+
116
+ return result
117
+
118
+
119
+ class SimpleTokenizer(object):
120
+
121
+ def __init__(self, bpe_path: str = default_bpe()):
122
+ self.byte_encoder = bytes_to_unicode()
123
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
124
+
125
+ with open(bpe_path) as f:
126
+ contents = f.readlines()
127
+ merges = []
128
+ for cnt in contents:
129
+ merges.append(cnt.split('\n')[0])
130
+ merges.append("")
131
+
132
+ # merges = gzip.open(bpe_path).read().decode('utf-8').split('\n')
133
+ merges = merges[1:49152 - 256 - 2 + 1]
134
+ merges = [tuple(merge.split()) for merge in merges]
135
+ vocab = list(bytes_to_unicode().values())
136
+ vocab = vocab + [v + '</w>' for v in vocab]
137
+ for merge in merges:
138
+ vocab.append(''.join(merge))
139
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
140
+ self.encoder = dict(zip(vocab, range(len(vocab))))
141
+ self.decoder = {v: k for k, v in self.encoder.items()}
142
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
143
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
144
+ self.pat = re.compile(
145
+ r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
146
+ re.IGNORECASE)
147
+
148
+ def bpe(self, token):
149
+ if token in self.cache:
150
+ return self.cache[token]
151
+ word = tuple(token[:-1]) + (token[-1] + '</w>', )
152
+ pairs = get_pairs(word)
153
+
154
+ if not pairs:
155
+ return token + '</w>'
156
+
157
+ while True:
158
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
159
+ if bigram not in self.bpe_ranks:
160
+ break
161
+ first, second = bigram
162
+ new_word = []
163
+ i = 0
164
+ while i < len(word):
165
+ try:
166
+ j = word.index(first, i)
167
+ new_word.extend(word[i:j])
168
+ i = j
169
+ except: # noqa: E722
170
+ new_word.extend(word[i:])
171
+ break
172
+
173
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
174
+ new_word.append(first + second)
175
+ i += 2
176
+ else:
177
+ new_word.append(word[i])
178
+ i += 1
179
+ new_word = tuple(new_word)
180
+ word = new_word
181
+ if len(word) == 1:
182
+ break
183
+ else:
184
+ pairs = get_pairs(word)
185
+ word = ' '.join(word)
186
+ self.cache[token] = word
187
+ return word
188
+
189
+ def encode(self, text):
190
+ bpe_tokens = []
191
+ text = whitespace_clean(basic_clean(text)).lower()
192
+ for token in re.findall(self.pat, text):
193
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
194
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
195
+ return bpe_tokens
196
+
197
+ def decode(self, tokens):
198
+ text = ''.join([self.decoder[token] for token in tokens])
199
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors='replace').replace('</w>', ' ')
200
+ return text
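A short usage sketch for the tokenizer above (assumes ftfy, regex and torch are installed and that bpe_simple_vocab_16e6.txt sits next to this file, as default_bpe() expects; not part of the commit):

# Usage sketch for SimpleTokenizer / Tokenize.
tokenizer = SimpleTokenizer()
tokenize = Tokenize(tokenizer, max_seq_len=77, truncate=True)

single = tokenize("a photo of a dining room")               # 1-D LongTensor, length 77
batch = tokenize(["the floor", "a dark rug on the floor"])  # shape (2, 77)
text = tokenizer.decode(tokenizer.encode("a photo of a dining room"))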
oneformer/evaluation/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .detection_coco_evaluator import *
2
+ from .coco_evaluator import *
3
+ from .cityscapes_evaluation import CityscapesInstanceEvaluator