Commit
·
6524e7a
0
Parent(s):
Initial deployment to Hugging Face
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +10 -0
- .gitignore +17 -0
- CLAUDE.md +76 -0
- Dockerfile +61 -0
- NeuroNest/.gitattributes +35 -0
- NeuroNest/README.md +12 -0
- app.py +38 -0
- configs/.DS_Store +0 -0
- configs/ade20k/Base-ADE20K-UnifiedSegmentation.yaml +68 -0
- configs/ade20k/oneformer_R50_bs16_160k.yaml +58 -0
- configs/ade20k/oneformer_dinat_large_IN21k_384_bs16_160k.yaml +42 -0
- configs/ade20k/oneformer_swin_large_IN21k_384_bs16_160k.yaml +40 -0
- configs/cityscapes/.DS_Store +0 -0
- configs/cityscapes/Base-Cityscapes-UnifiedSegmentation.yaml +68 -0
- configs/cityscapes/oneformer_R50_bs16_90k.yaml +59 -0
- configs/cityscapes/oneformer_dinat_large_bs16_90k.yaml +22 -0
- configs/cityscapes/oneformer_swin_large_IN21k_384_bs16_90k.yaml +20 -0
- configs/coco/Base-COCO-UnifiedSegmentation.yaml +54 -0
- configs/coco/oneformer_R50_bs16_50ep.yaml +59 -0
- configs/coco/oneformer_dinat_large_bs16_100ep.yaml +22 -0
- configs/coco/oneformer_swin_large_IN21k_384_bs16_100ep.yaml +25 -0
- deform_setup.sh +21 -0
- demo/colormap.py +170 -0
- demo/defaults.py +77 -0
- demo/predictor.py +190 -0
- demo/visualizer.py +1350 -0
- gradio_app.py +222 -0
- gradio_combine.py +384 -0
- gradio_contrast.py +455 -0
- gradio_gpu.py +234 -0
- gradio_test.py +1080 -0
- oneformer/.DS_Store +0 -0
- oneformer/__init__.py +9 -0
- oneformer/config.py +239 -0
- oneformer/data/__init__.py +2 -0
- oneformer/data/bpe_simple_vocab_16e6.txt +0 -0
- oneformer/data/bpe_simple_vocab_16e6.txt.gz +3 -0
- oneformer/data/build.py +117 -0
- oneformer/data/dataset_mappers/__init__.py +1 -0
- oneformer/data/dataset_mappers/coco_unified_new_baseline_dataset_mapper.py +341 -0
- oneformer/data/dataset_mappers/dataset_mapper.py +203 -0
- oneformer/data/dataset_mappers/oneformer_unified_dataset_mapper.py +375 -0
- oneformer/data/datasets/__init__.py +7 -0
- oneformer/data/datasets/register_ade20k_instance.py +56 -0
- oneformer/data/datasets/register_ade20k_panoptic.py +394 -0
- oneformer/data/datasets/register_cityscapes_panoptic.py +199 -0
- oneformer/data/datasets/register_coco_panoptic2instance.py +44 -0
- oneformer/data/datasets/register_coco_panoptic_annos_semseg.py +367 -0
- oneformer/data/tokenizer.py +200 -0
- oneformer/evaluation/__init__.py +3 -0
.gitattributes
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.so filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.o filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.egg filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.tar.gz filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.out
|
2 |
+
__pycache__/
|
3 |
+
*.pyc
|
4 |
+
*.pyo
|
5 |
+
*.pyd
|
6 |
+
slurm-*.out
|
7 |
+
output_*/
|
8 |
+
build/
|
9 |
+
dist/
|
10 |
+
*.egg-info/
|
11 |
+
.env
|
12 |
+
.claude/
|
13 |
+
oneformer/modeling/pixel_decoder/ops/build/
|
14 |
+
oneformer/modeling/pixel_decoder/ops/dist/
|
15 |
+
oneformer/modeling/pixel_decoder/ops/*.egg-info/
|
16 |
+
output_floor_blackspot/
|
17 |
+
environment.yml
|
CLAUDE.md
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# CLAUDE.md
|
2 |
+
|
3 |
+
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
4 |
+
|
5 |
+
## Project Overview
|
6 |
+
This is a OneFormer-based universal image segmentation application with integrated contrast detection capabilities. It supports panoptic, instance, and semantic segmentation using transformer-based models (Swin-L and DiNAT-L backbones).
|
7 |
+
|
8 |
+
## Essential Commands
|
9 |
+
|
10 |
+
### Setup and Installation
|
11 |
+
```bash
|
12 |
+
# Download model checkpoints from HuggingFace
|
13 |
+
python setup.py --download
|
14 |
+
|
15 |
+
# Build deformable attention CUDA operations (required)
|
16 |
+
bash deform_setup.sh
|
17 |
+
|
18 |
+
# Install NATTEN if needed
|
19 |
+
cd NATTEN && pip install -e .
|
20 |
+
```
|
21 |
+
|
22 |
+
### Running the Application
|
23 |
+
```bash
|
24 |
+
# Main application launcher
|
25 |
+
bash t.sh
|
26 |
+
|
27 |
+
# Or run directly
|
28 |
+
python gradio_test.py
|
29 |
+
```
|
30 |
+
|
31 |
+
### Testing
|
32 |
+
```bash
|
33 |
+
# Run NATTEN tests
|
34 |
+
cd NATTEN && python -m unittest discover -v -s ./tests
|
35 |
+
```
|
36 |
+
|
37 |
+
## Architecture Overview
|
38 |
+
|
39 |
+
### Core Components
|
40 |
+
1. **OneFormer Model** (`oneformer/`): Universal segmentation transformer implementation
|
41 |
+
- `oneformer_model.py`: Main model class
|
42 |
+
- `modeling/`: Model components including backbones, decoders, and criterion
|
43 |
+
- `data/`: Dataset mappers and tokenizers for text queries
|
44 |
+
|
45 |
+
2. **Gradio Applications** (multiple entry points):
|
46 |
+
- `gradio_test.py`: Main combined interface with pre/post-processing logic
|
47 |
+
- `gradio_app.py`: Basic segmentation interface
|
48 |
+
- `gradio_contrast.py`: Contrast detection focused interface
|
49 |
+
|
50 |
+
3. **Contrast Detection** (`utils/`): Multiple contrast analysis algorithms
|
51 |
+
- `improved_contrast_analyzer.py`: Main analyzer combining all methods
|
52 |
+
- Individual analyzers: luminance, saturation, hue contrast detection
|
53 |
+
|
54 |
+
4. **Model Configurations** (`configs/`): Pre-defined configs for:
|
55 |
+
- ADE20K (150 classes)
|
56 |
+
- Cityscapes (19 classes)
|
57 |
+
- COCO (133 classes)
|
58 |
+
|
59 |
+
### Key Implementation Details
|
60 |
+
- Models are loaded from HuggingFace hub (shi-labs organization)
|
61 |
+
- Supports both CPU and GPU inference (auto-detection)
|
62 |
+
- Custom CUDA kernels for multi-scale deformable attention
|
63 |
+
- NATTEN library provides neighborhood attention operations
|
64 |
+
|
65 |
+
### Model Checkpoints
|
66 |
+
The application uses pre-trained models from:
|
67 |
+
- `shi-labs/oneformer_cityscapes_swin_large`
|
68 |
+
- `shi-labs/oneformer_coco_swin_large`
|
69 |
+
- `shi-labs/oneformer_ade20k_swin_large`
|
70 |
+
- DiNAT variants also available
|
71 |
+
|
72 |
+
### Important Notes
|
73 |
+
- Always run `deform_setup.sh` after environment setup for CUDA operations
|
74 |
+
- The application expects model checkpoints to be downloaded before running
|
75 |
+
- Gradio interface runs on port 7860 by default
|
76 |
+
- GPU memory requirements vary by model (Swin-L models need ~12GB)
|
Dockerfile
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM nvidia/cuda:11.3.1-cudnn8-devel-ubuntu18.04
|
2 |
+
CMD nvidia-smi
|
3 |
+
|
4 |
+
ENV DEBIAN_FRONTEND noninteractive
|
5 |
+
RUN apt-get update && apt-get install -y \
|
6 |
+
git \
|
7 |
+
make build-essential libssl-dev zlib1g-dev \
|
8 |
+
libbz2-dev libreadline-dev libsqlite3-dev wget curl llvm \
|
9 |
+
libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev libffi-dev liblzma-dev \
|
10 |
+
ffmpeg libsm6 libxext6 cmake libgl1-mesa-glx \
|
11 |
+
&& rm -rf /var/lib/apt/lists/*
|
12 |
+
|
13 |
+
RUN useradd -ms /bin/bash user
|
14 |
+
USER user
|
15 |
+
|
16 |
+
ENV HOME=/home/user \
|
17 |
+
PATH=/home/user/.local/bin:$PATH
|
18 |
+
|
19 |
+
RUN curl https://pyenv.run | bash
|
20 |
+
ENV PATH=$HOME/.pyenv/shims:$HOME/.pyenv/bin:$PATH
|
21 |
+
RUN pyenv install 3.8.15 && \
|
22 |
+
pyenv global 3.8.15 && \
|
23 |
+
pyenv rehash && \
|
24 |
+
pip install --no-cache-dir --upgrade pip setuptools wheel
|
25 |
+
|
26 |
+
ENV WORKDIR=/code
|
27 |
+
WORKDIR $WORKDIR
|
28 |
+
RUN chown -R user:user $WORKDIR
|
29 |
+
RUN chmod -R 777 $WORKDIR
|
30 |
+
|
31 |
+
COPY requirements.txt $WORKDIR/requirements.txt
|
32 |
+
RUN pip install --no-cache-dir --upgrade -r $WORKDIR/requirements.txt
|
33 |
+
RUN pip install ninja
|
34 |
+
|
35 |
+
COPY . .
|
36 |
+
|
37 |
+
ARG TORCH_CUDA_ARCH_LIST=7.5+PTX
|
38 |
+
|
39 |
+
USER root
|
40 |
+
RUN chown -R user:user $HOME
|
41 |
+
RUN chmod -R 777 $HOME
|
42 |
+
RUN chown -R user:user $WORKDIR
|
43 |
+
RUN chmod -R 777 $WORKDIR
|
44 |
+
|
45 |
+
USER user
|
46 |
+
RUN ln -s $WORKDIR/oneformer/modeling/pixel_decoder/ops/ $WORKDIR/ && ls && cd ops/ && FORCE_CUDA=1 python setup.py build --build-base=$WORKDIR/ install --user && cd ..
|
47 |
+
RUN sh deform_setup.sh
|
48 |
+
|
49 |
+
USER user
|
50 |
+
RUN sh deform_setup.sh
|
51 |
+
|
52 |
+
RUN mkdir -p examples
|
53 |
+
RUN wget https://praeclarumjj3.github.io/files/ade20k.jpeg -P $WORKDIR/examples/
|
54 |
+
RUN wget https://praeclarumjj3.github.io/files/cityscapes.png -P $WORKDIR/examples/
|
55 |
+
RUN wget https://praeclarumjj3.github.io/files/coco.jpeg -P $WORKDIR/examples/
|
56 |
+
|
57 |
+
USER user
|
58 |
+
|
59 |
+
EXPOSE 7860
|
60 |
+
|
61 |
+
ENTRYPOINT ["python", "gradio_app.py"]
|
NeuroNest/.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
NeuroNest/README.md
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: NeuroNest
|
3 |
+
emoji: 🏆
|
4 |
+
colorFrom: indigo
|
5 |
+
colorTo: green
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 5.30.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
---
|
11 |
+
|
12 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
Hugging Face Spaces entry point for OneFormer application
|
4 |
+
"""
|
5 |
+
|
6 |
+
import os
|
7 |
+
import sys
|
8 |
+
import subprocess
|
9 |
+
|
10 |
+
# Set up environment variables for HF Spaces
|
11 |
+
os.environ['CUDA_HOME'] = '/usr/local/cuda' if os.path.exists('/usr/local/cuda') else ''
|
12 |
+
|
13 |
+
# Install deformable attention ops if not already installed
|
14 |
+
def setup_deformable_attention():
|
15 |
+
ops_dir = os.path.join(os.path.dirname(__file__), 'oneformer/modeling/pixel_decoder/ops')
|
16 |
+
if os.path.exists(ops_dir):
|
17 |
+
try:
|
18 |
+
subprocess.run(['bash', 'deform_setup.sh'], check=True, cwd=os.path.dirname(__file__))
|
19 |
+
print("Deformable attention ops installed successfully")
|
20 |
+
except Exception as e:
|
21 |
+
print(f"Warning: Could not install deformable attention ops: {e}")
|
22 |
+
print("Continuing without custom CUDA kernels...")
|
23 |
+
|
24 |
+
# Run setup on first launch
|
25 |
+
if not os.path.exists('oneformer/modeling/pixel_decoder/ops/build'):
|
26 |
+
setup_deformable_attention()
|
27 |
+
|
28 |
+
# Import and run the main gradio app
|
29 |
+
from gradio_test import demo
|
30 |
+
|
31 |
+
if __name__ == "__main__":
|
32 |
+
# Launch with HF Spaces compatible settings
|
33 |
+
demo.launch(
|
34 |
+
server_name="0.0.0.0",
|
35 |
+
server_port=7860,
|
36 |
+
share=False, # Disabled on HF Spaces
|
37 |
+
debug=False
|
38 |
+
)
|
configs/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
configs/ade20k/Base-ADE20K-UnifiedSegmentation.yaml
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MODEL:
|
2 |
+
BACKBONE:
|
3 |
+
FREEZE_AT: 0
|
4 |
+
NAME: "build_resnet_backbone"
|
5 |
+
WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl"
|
6 |
+
PIXEL_MEAN: [123.675, 116.280, 103.530]
|
7 |
+
PIXEL_STD: [58.395, 57.120, 57.375]
|
8 |
+
RESNETS:
|
9 |
+
DEPTH: 50
|
10 |
+
STEM_TYPE: "basic" # not used
|
11 |
+
STEM_OUT_CHANNELS: 64
|
12 |
+
STRIDE_IN_1X1: False
|
13 |
+
OUT_FEATURES: ["res2", "res3", "res4", "res5"]
|
14 |
+
# NORM: "SyncBN"
|
15 |
+
RES5_MULTI_GRID: [1, 1, 1] # not used
|
16 |
+
DATASETS:
|
17 |
+
TRAIN: ("ade20k_panoptic_train",)
|
18 |
+
TEST_PANOPTIC: ("ade20k_panoptic_val",)
|
19 |
+
TEST_INSTANCE: ("ade20k_instance_val",)
|
20 |
+
TEST_SEMANTIC: ("ade20k_sem_seg_val",)
|
21 |
+
SOLVER:
|
22 |
+
IMS_PER_BATCH: 16
|
23 |
+
BASE_LR: 0.0001
|
24 |
+
MAX_ITER: 160000
|
25 |
+
WARMUP_FACTOR: 1.0
|
26 |
+
WARMUP_ITERS: 0
|
27 |
+
WEIGHT_DECAY: 0.05
|
28 |
+
OPTIMIZER: "ADAMW"
|
29 |
+
LR_SCHEDULER_NAME: "WarmupPolyLR"
|
30 |
+
BACKBONE_MULTIPLIER: 0.1
|
31 |
+
CLIP_GRADIENTS:
|
32 |
+
ENABLED: True
|
33 |
+
CLIP_TYPE: "full_model"
|
34 |
+
CLIP_VALUE: 0.01
|
35 |
+
NORM_TYPE: 2.0
|
36 |
+
AMP:
|
37 |
+
ENABLED: True
|
38 |
+
INPUT:
|
39 |
+
MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 512) for x in range(5, 21)]"]
|
40 |
+
MIN_SIZE_TRAIN_SAMPLING: "choice"
|
41 |
+
MIN_SIZE_TEST: 512
|
42 |
+
MAX_SIZE_TRAIN: 2048
|
43 |
+
MAX_SIZE_TEST: 2048
|
44 |
+
CROP:
|
45 |
+
ENABLED: True
|
46 |
+
TYPE: "absolute"
|
47 |
+
SIZE: (512, 512)
|
48 |
+
SINGLE_CATEGORY_MAX_AREA: 1.0
|
49 |
+
COLOR_AUG_SSD: True
|
50 |
+
SIZE_DIVISIBILITY: 512 # used in dataset mapper
|
51 |
+
FORMAT: "RGB"
|
52 |
+
DATASET_MAPPER_NAME: "oneformer_unified"
|
53 |
+
MAX_SEQ_LEN: 77
|
54 |
+
TASK_SEQ_LEN: 77
|
55 |
+
TASK_PROB:
|
56 |
+
SEMANTIC: 0.33
|
57 |
+
INSTANCE: 0.66
|
58 |
+
TEST:
|
59 |
+
EVAL_PERIOD: 5000
|
60 |
+
AUG:
|
61 |
+
ENABLED: False
|
62 |
+
MIN_SIZES: [256, 384, 512, 640, 768, 896]
|
63 |
+
MAX_SIZE: 3584
|
64 |
+
FLIP: True
|
65 |
+
DATALOADER:
|
66 |
+
FILTER_EMPTY_ANNOTATIONS: True
|
67 |
+
NUM_WORKERS: 4
|
68 |
+
VERSION: 2
|
configs/ade20k/oneformer_R50_bs16_160k.yaml
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_BASE_: Base-ADE20K-UnifiedSegmentation.yaml
|
2 |
+
MODEL:
|
3 |
+
META_ARCHITECTURE: "OneFormer"
|
4 |
+
SEM_SEG_HEAD:
|
5 |
+
NAME: "OneFormerHead"
|
6 |
+
IGNORE_VALUE: 255
|
7 |
+
NUM_CLASSES: 150
|
8 |
+
LOSS_WEIGHT: 1.0
|
9 |
+
CONVS_DIM: 256
|
10 |
+
MASK_DIM: 256
|
11 |
+
NORM: "GN"
|
12 |
+
# pixel decoder
|
13 |
+
PIXEL_DECODER_NAME: "MSDeformAttnPixelDecoder"
|
14 |
+
IN_FEATURES: ["res2", "res3", "res4", "res5"]
|
15 |
+
DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"]
|
16 |
+
COMMON_STRIDE: 4
|
17 |
+
TRANSFORMER_ENC_LAYERS: 6
|
18 |
+
ONE_FORMER:
|
19 |
+
TRANSFORMER_DECODER_NAME: "ContrastiveMultiScaleMaskedTransformerDecoder"
|
20 |
+
TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder"
|
21 |
+
DEEP_SUPERVISION: True
|
22 |
+
NO_OBJECT_WEIGHT: 0.1
|
23 |
+
CLASS_WEIGHT: 2.0
|
24 |
+
MASK_WEIGHT: 5.0
|
25 |
+
DICE_WEIGHT: 5.0
|
26 |
+
CONTRASTIVE_WEIGHT: 0.5
|
27 |
+
CONTRASTIVE_TEMPERATURE: 0.07
|
28 |
+
HIDDEN_DIM: 256
|
29 |
+
NUM_OBJECT_QUERIES: 150
|
30 |
+
USE_TASK_NORM: True
|
31 |
+
NHEADS: 8
|
32 |
+
DROPOUT: 0.1
|
33 |
+
DIM_FEEDFORWARD: 2048
|
34 |
+
ENC_LAYERS: 0
|
35 |
+
PRE_NORM: False
|
36 |
+
ENFORCE_INPUT_PROJ: False
|
37 |
+
SIZE_DIVISIBILITY: 32
|
38 |
+
CLASS_DEC_LAYERS: 2
|
39 |
+
DEC_LAYERS: 10 # 9 decoder layers, add one for the loss on learnable query
|
40 |
+
TRAIN_NUM_POINTS: 12544
|
41 |
+
OVERSAMPLE_RATIO: 3.0
|
42 |
+
IMPORTANCE_SAMPLE_RATIO: 0.75
|
43 |
+
TEXT_ENCODER:
|
44 |
+
WIDTH: 256
|
45 |
+
CONTEXT_LENGTH: 77
|
46 |
+
NUM_LAYERS: 6
|
47 |
+
VOCAB_SIZE: 49408
|
48 |
+
PROJ_NUM_LAYERS: 2
|
49 |
+
N_CTX: 16
|
50 |
+
TEST:
|
51 |
+
SEMANTIC_ON: True
|
52 |
+
INSTANCE_ON: True
|
53 |
+
PANOPTIC_ON: True
|
54 |
+
OVERLAP_THRESHOLD: 0.8
|
55 |
+
OBJECT_MASK_THRESHOLD: 0.8
|
56 |
+
TASK: "panoptic"
|
57 |
+
TEST:
|
58 |
+
DETECTIONS_PER_IMAGE: 150
|
configs/ade20k/oneformer_dinat_large_IN21k_384_bs16_160k.yaml
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_BASE_: oneformer_R50_bs16_160k.yaml
|
2 |
+
MODEL:
|
3 |
+
BACKBONE:
|
4 |
+
NAME: "D2DiNAT"
|
5 |
+
DiNAT:
|
6 |
+
EMBED_DIM: 192
|
7 |
+
MLP_RATIO: 2.0
|
8 |
+
DEPTHS: [3, 4, 18, 5]
|
9 |
+
NUM_HEADS: [6, 12, 24, 48]
|
10 |
+
KERNEL_SIZE: 11
|
11 |
+
DROP_PATH_RATE: 0.3
|
12 |
+
DILATIONS: [[1, 20, 1], [1, 5, 1, 10], [1, 2, 1, 3, 1, 4, 1, 5, 1, 2, 1, 3, 1, 4, 1, 5, 1, 5], [1, 2, 1, 2, 1]]
|
13 |
+
WEIGHTS: "dinat_large_in22k_in1k_384_11x11.pkl"
|
14 |
+
PIXEL_MEAN: [123.675, 116.280, 103.530]
|
15 |
+
PIXEL_STD: [58.395, 57.120, 57.375]
|
16 |
+
ONE_FORMER:
|
17 |
+
NUM_OBJECT_QUERIES: 250
|
18 |
+
SOLVER:
|
19 |
+
AMP:
|
20 |
+
ENABLED: False
|
21 |
+
INPUT:
|
22 |
+
MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 640) for x in range(5, 21)]"]
|
23 |
+
MIN_SIZE_TRAIN_SAMPLING: "choice"
|
24 |
+
MIN_SIZE_TEST: 640
|
25 |
+
MAX_SIZE_TRAIN: 2560
|
26 |
+
MAX_SIZE_TEST: 2560
|
27 |
+
CROP:
|
28 |
+
ENABLED: True
|
29 |
+
TYPE: "absolute"
|
30 |
+
SIZE: (640, 640)
|
31 |
+
SINGLE_CATEGORY_MAX_AREA: 1.0
|
32 |
+
COLOR_AUG_SSD: True
|
33 |
+
SIZE_DIVISIBILITY: 640 # used in dataset mapper
|
34 |
+
FORMAT: "RGB"
|
35 |
+
TEST:
|
36 |
+
DETECTIONS_PER_IMAGE: 250
|
37 |
+
EVAL_PERIOD: 5000
|
38 |
+
AUG:
|
39 |
+
ENABLED: False
|
40 |
+
MIN_SIZES: [320, 480, 640, 800, 960, 1120]
|
41 |
+
MAX_SIZE: 4480
|
42 |
+
FLIP: True
|
configs/ade20k/oneformer_swin_large_IN21k_384_bs16_160k.yaml
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_BASE_: oneformer_R50_bs16_160k.yaml
|
2 |
+
MODEL:
|
3 |
+
BACKBONE:
|
4 |
+
NAME: "D2SwinTransformer"
|
5 |
+
SWIN:
|
6 |
+
EMBED_DIM: 192
|
7 |
+
DEPTHS: [2, 2, 18, 2]
|
8 |
+
NUM_HEADS: [6, 12, 24, 48]
|
9 |
+
WINDOW_SIZE: 12
|
10 |
+
APE: False
|
11 |
+
DROP_PATH_RATE: 0.3
|
12 |
+
PATCH_NORM: True
|
13 |
+
PRETRAIN_IMG_SIZE: 384
|
14 |
+
WEIGHTS: "swin_large_patch4_window12_384_22k.pkl"
|
15 |
+
PIXEL_MEAN: [123.675, 116.280, 103.530]
|
16 |
+
PIXEL_STD: [58.395, 57.120, 57.375]
|
17 |
+
ONE_FORMER:
|
18 |
+
NUM_OBJECT_QUERIES: 250
|
19 |
+
INPUT:
|
20 |
+
MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 640) for x in range(5, 21)]"]
|
21 |
+
MIN_SIZE_TRAIN_SAMPLING: "choice"
|
22 |
+
MIN_SIZE_TEST: 640
|
23 |
+
MAX_SIZE_TRAIN: 2560
|
24 |
+
MAX_SIZE_TEST: 2560
|
25 |
+
CROP:
|
26 |
+
ENABLED: True
|
27 |
+
TYPE: "absolute"
|
28 |
+
SIZE: (640, 640)
|
29 |
+
SINGLE_CATEGORY_MAX_AREA: 1.0
|
30 |
+
COLOR_AUG_SSD: True
|
31 |
+
SIZE_DIVISIBILITY: 640 # used in dataset mapper
|
32 |
+
FORMAT: "RGB"
|
33 |
+
TEST:
|
34 |
+
DETECTIONS_PER_IMAGE: 250
|
35 |
+
EVAL_PERIOD: 5000
|
36 |
+
AUG:
|
37 |
+
ENABLED: False
|
38 |
+
MIN_SIZES: [320, 480, 640, 800, 960, 1120]
|
39 |
+
MAX_SIZE: 4480
|
40 |
+
FLIP: True
|
configs/cityscapes/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
configs/cityscapes/Base-Cityscapes-UnifiedSegmentation.yaml
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MODEL:
|
2 |
+
BACKBONE:
|
3 |
+
FREEZE_AT: 0
|
4 |
+
NAME: "build_resnet_backbone"
|
5 |
+
WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl"
|
6 |
+
PIXEL_MEAN: [123.675, 116.280, 103.530]
|
7 |
+
PIXEL_STD: [58.395, 57.120, 57.375]
|
8 |
+
RESNETS:
|
9 |
+
DEPTH: 50
|
10 |
+
STEM_TYPE: "basic" # not used
|
11 |
+
STEM_OUT_CHANNELS: 64
|
12 |
+
STRIDE_IN_1X1: False
|
13 |
+
OUT_FEATURES: ["res2", "res3", "res4", "res5"]
|
14 |
+
NORM: "SyncBN" # use syncbn for cityscapes dataset
|
15 |
+
RES5_MULTI_GRID: [1, 1, 1] # not used
|
16 |
+
DATASETS:
|
17 |
+
TRAIN: ("cityscapes_fine_panoptic_train",)
|
18 |
+
TEST_PANOPTIC: ("cityscapes_fine_panoptic_val",)
|
19 |
+
TEST_INSTANCE: ("cityscapes_fine_instance_seg_val",)
|
20 |
+
TEST_SEMANTIC: ("cityscapes_fine_sem_seg_val",)
|
21 |
+
SOLVER:
|
22 |
+
IMS_PER_BATCH: 16
|
23 |
+
BASE_LR: 0.0001
|
24 |
+
MAX_ITER: 90000
|
25 |
+
WARMUP_FACTOR: 1.0
|
26 |
+
WARMUP_ITERS: 0
|
27 |
+
WEIGHT_DECAY: 0.05
|
28 |
+
OPTIMIZER: "ADAMW"
|
29 |
+
LR_SCHEDULER_NAME: "WarmupPolyLR"
|
30 |
+
BACKBONE_MULTIPLIER: 0.1
|
31 |
+
CLIP_GRADIENTS:
|
32 |
+
ENABLED: True
|
33 |
+
CLIP_TYPE: "full_model"
|
34 |
+
CLIP_VALUE: 0.01
|
35 |
+
NORM_TYPE: 2.0
|
36 |
+
AMP:
|
37 |
+
ENABLED: True
|
38 |
+
INPUT:
|
39 |
+
MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 1024) for x in range(5, 21)]"]
|
40 |
+
MIN_SIZE_TRAIN_SAMPLING: "choice"
|
41 |
+
MIN_SIZE_TEST: 1024
|
42 |
+
MAX_SIZE_TRAIN: 4096
|
43 |
+
MAX_SIZE_TEST: 2048
|
44 |
+
CROP:
|
45 |
+
ENABLED: True
|
46 |
+
TYPE: "absolute"
|
47 |
+
SIZE: (512, 1024)
|
48 |
+
SINGLE_CATEGORY_MAX_AREA: 1.0
|
49 |
+
COLOR_AUG_SSD: True
|
50 |
+
SIZE_DIVISIBILITY: -1
|
51 |
+
FORMAT: "RGB"
|
52 |
+
DATASET_MAPPER_NAME: "oneformer_unified"
|
53 |
+
MAX_SEQ_LEN: 77
|
54 |
+
TASK_SEQ_LEN: 77
|
55 |
+
TASK_PROB:
|
56 |
+
SEMANTIC: 0.33
|
57 |
+
INSTANCE: 0.66
|
58 |
+
TEST:
|
59 |
+
EVAL_PERIOD: 5000
|
60 |
+
AUG:
|
61 |
+
ENABLED: False
|
62 |
+
MIN_SIZES: [512, 768, 1024, 1280, 1536, 1792]
|
63 |
+
MAX_SIZE: 4096
|
64 |
+
FLIP: True
|
65 |
+
DATALOADER:
|
66 |
+
FILTER_EMPTY_ANNOTATIONS: True
|
67 |
+
NUM_WORKERS: 4
|
68 |
+
VERSION: 2
|
configs/cityscapes/oneformer_R50_bs16_90k.yaml
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_BASE_: Base-Cityscapes-UnifiedSegmentation.yaml
|
2 |
+
MODEL:
|
3 |
+
META_ARCHITECTURE: "OneFormer"
|
4 |
+
SEM_SEG_HEAD:
|
5 |
+
NAME: "OneFormerHead"
|
6 |
+
IGNORE_VALUE: 255
|
7 |
+
NUM_CLASSES: 19
|
8 |
+
LOSS_WEIGHT: 1.0
|
9 |
+
CONVS_DIM: 256
|
10 |
+
MASK_DIM: 256
|
11 |
+
NORM: "GN"
|
12 |
+
# pixel decoder
|
13 |
+
PIXEL_DECODER_NAME: "MSDeformAttnPixelDecoder"
|
14 |
+
IN_FEATURES: ["res2", "res3", "res4", "res5"]
|
15 |
+
DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"]
|
16 |
+
COMMON_STRIDE: 4
|
17 |
+
TRANSFORMER_ENC_LAYERS: 6
|
18 |
+
ONE_FORMER:
|
19 |
+
TRANSFORMER_DECODER_NAME: "ContrastiveMultiScaleMaskedTransformerDecoder"
|
20 |
+
TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder"
|
21 |
+
DEEP_SUPERVISION: True
|
22 |
+
NO_OBJECT_WEIGHT: 0.1
|
23 |
+
CLASS_WEIGHT: 2.0
|
24 |
+
MASK_WEIGHT: 5.0
|
25 |
+
DICE_WEIGHT: 5.0
|
26 |
+
CONTRASTIVE_WEIGHT: 0.5
|
27 |
+
CONTRASTIVE_TEMPERATURE: 0.07
|
28 |
+
HIDDEN_DIM: 256
|
29 |
+
NUM_OBJECT_QUERIES: 150
|
30 |
+
USE_TASK_NORM: True
|
31 |
+
NHEADS: 8
|
32 |
+
DROPOUT: 0.1
|
33 |
+
DIM_FEEDFORWARD: 2048
|
34 |
+
ENC_LAYERS: 0
|
35 |
+
PRE_NORM: False
|
36 |
+
ENFORCE_INPUT_PROJ: False
|
37 |
+
SIZE_DIVISIBILITY: 32
|
38 |
+
ENC_LAYERS: 0
|
39 |
+
CLASS_DEC_LAYERS: 2
|
40 |
+
DEC_LAYERS: 10 # 9 decoder layers, add one for the loss on learnable query
|
41 |
+
TRAIN_NUM_POINTS: 12544
|
42 |
+
OVERSAMPLE_RATIO: 3.0
|
43 |
+
IMPORTANCE_SAMPLE_RATIO: 0.75
|
44 |
+
TEXT_ENCODER:
|
45 |
+
WIDTH: 256
|
46 |
+
CONTEXT_LENGTH: 77
|
47 |
+
NUM_LAYERS: 6
|
48 |
+
VOCAB_SIZE: 49408
|
49 |
+
PROJ_NUM_LAYERS: 2
|
50 |
+
N_CTX: 16
|
51 |
+
TEST:
|
52 |
+
SEMANTIC_ON: True
|
53 |
+
INSTANCE_ON: True
|
54 |
+
PANOPTIC_ON: True
|
55 |
+
OVERLAP_THRESHOLD: 0.8
|
56 |
+
OBJECT_MASK_THRESHOLD: 0.8
|
57 |
+
TASK: "panoptic"
|
58 |
+
TEST:
|
59 |
+
DETECTIONS_PER_IMAGE: 150
|
configs/cityscapes/oneformer_dinat_large_bs16_90k.yaml
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_BASE_: oneformer_R50_bs16_90k.yaml
|
2 |
+
MODEL:
|
3 |
+
BACKBONE:
|
4 |
+
NAME: "D2DiNAT"
|
5 |
+
DiNAT:
|
6 |
+
EMBED_DIM: 192
|
7 |
+
MLP_RATIO: 2.0
|
8 |
+
DEPTHS: [3, 4, 18, 5]
|
9 |
+
NUM_HEADS: [6, 12, 24, 48]
|
10 |
+
KERNEL_SIZE: 7
|
11 |
+
DROP_PATH_RATE: 0.3
|
12 |
+
DILATIONS: [[1, 18, 1], [1, 5, 1, 9], [1, 2, 1, 3, 1, 4, 1, 2, 1, 3, 1, 4, 1, 2, 1, 3, 1, 4], [1, 2, 1, 2, 1]]
|
13 |
+
WEIGHTS: "dinat_large_in22k_224.pkl"
|
14 |
+
PIXEL_MEAN: [123.675, 116.280, 103.530]
|
15 |
+
PIXEL_STD: [58.395, 57.120, 57.375]
|
16 |
+
ONE_FORMER:
|
17 |
+
NUM_OBJECT_QUERIES: 250
|
18 |
+
SOLVER:
|
19 |
+
AMP:
|
20 |
+
ENABLED: False
|
21 |
+
TEST:
|
22 |
+
DETECTIONS_PER_IMAGE: 250
|
configs/cityscapes/oneformer_swin_large_IN21k_384_bs16_90k.yaml
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_BASE_: oneformer_R50_bs16_90k.yaml
|
2 |
+
MODEL:
|
3 |
+
BACKBONE:
|
4 |
+
NAME: "D2SwinTransformer"
|
5 |
+
SWIN:
|
6 |
+
EMBED_DIM: 192
|
7 |
+
DEPTHS: [2, 2, 18, 2]
|
8 |
+
NUM_HEADS: [6, 12, 24, 48]
|
9 |
+
WINDOW_SIZE: 12
|
10 |
+
APE: False
|
11 |
+
DROP_PATH_RATE: 0.3
|
12 |
+
PATCH_NORM: True
|
13 |
+
PRETRAIN_IMG_SIZE: 384
|
14 |
+
WEIGHTS: "swin_large_patch4_window12_384_22k.pkl"
|
15 |
+
PIXEL_MEAN: [123.675, 116.280, 103.530]
|
16 |
+
PIXEL_STD: [58.395, 57.120, 57.375]
|
17 |
+
ONE_FORMER:
|
18 |
+
NUM_OBJECT_QUERIES: 250
|
19 |
+
TEST:
|
20 |
+
DETECTIONS_PER_IMAGE: 250
|
configs/coco/Base-COCO-UnifiedSegmentation.yaml
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MODEL:
|
2 |
+
BACKBONE:
|
3 |
+
FREEZE_AT: 0
|
4 |
+
NAME: "build_resnet_backbone"
|
5 |
+
WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl"
|
6 |
+
PIXEL_MEAN: [123.675, 116.280, 103.530]
|
7 |
+
PIXEL_STD: [58.395, 57.120, 57.375]
|
8 |
+
RESNETS:
|
9 |
+
DEPTH: 50
|
10 |
+
STEM_TYPE: "basic" # not used
|
11 |
+
STEM_OUT_CHANNELS: 64
|
12 |
+
STRIDE_IN_1X1: False
|
13 |
+
OUT_FEATURES: ["res2", "res3", "res4", "res5"]
|
14 |
+
# NORM: "SyncBN"
|
15 |
+
RES5_MULTI_GRID: [1, 1, 1] # not used
|
16 |
+
DATASETS:
|
17 |
+
TRAIN: ("coco_2017_train_panoptic_with_sem_seg",)
|
18 |
+
TEST_PANOPTIC: ("coco_2017_val_panoptic_with_sem_seg",) # to evaluate instance and semantic performance as well
|
19 |
+
TEST_INSTANCE: ("coco_2017_val",)
|
20 |
+
TEST_SEMANTIC: ("coco_2017_val_panoptic_with_sem_seg",)
|
21 |
+
SOLVER:
|
22 |
+
IMS_PER_BATCH: 16
|
23 |
+
BASE_LR: 0.0001
|
24 |
+
STEPS: (327778, 355092)
|
25 |
+
MAX_ITER: 368750
|
26 |
+
WARMUP_FACTOR: 1.0
|
27 |
+
WARMUP_ITERS: 10
|
28 |
+
WEIGHT_DECAY: 0.05
|
29 |
+
OPTIMIZER: "ADAMW"
|
30 |
+
BACKBONE_MULTIPLIER: 0.1
|
31 |
+
CLIP_GRADIENTS:
|
32 |
+
ENABLED: True
|
33 |
+
CLIP_TYPE: "full_model"
|
34 |
+
CLIP_VALUE: 0.01
|
35 |
+
NORM_TYPE: 2.0
|
36 |
+
AMP:
|
37 |
+
ENABLED: True
|
38 |
+
INPUT:
|
39 |
+
IMAGE_SIZE: 1024
|
40 |
+
MIN_SCALE: 0.1
|
41 |
+
MAX_SCALE: 2.0
|
42 |
+
FORMAT: "RGB"
|
43 |
+
DATASET_MAPPER_NAME: "coco_unified_lsj"
|
44 |
+
MAX_SEQ_LEN: 77
|
45 |
+
TASK_SEQ_LEN: 77
|
46 |
+
TASK_PROB:
|
47 |
+
SEMANTIC: 0.33
|
48 |
+
INSTANCE: 0.66
|
49 |
+
TEST:
|
50 |
+
EVAL_PERIOD: 5000
|
51 |
+
DATALOADER:
|
52 |
+
FILTER_EMPTY_ANNOTATIONS: True
|
53 |
+
NUM_WORKERS: 4
|
54 |
+
VERSION: 2
|
configs/coco/oneformer_R50_bs16_50ep.yaml
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_BASE_: Base-COCO-UnifiedSegmentation.yaml
|
2 |
+
MODEL:
|
3 |
+
META_ARCHITECTURE: "OneFormer"
|
4 |
+
SEM_SEG_HEAD:
|
5 |
+
NAME: "OneFormerHead"
|
6 |
+
IGNORE_VALUE: 255
|
7 |
+
NUM_CLASSES: 133
|
8 |
+
LOSS_WEIGHT: 1.0
|
9 |
+
CONVS_DIM: 256
|
10 |
+
MASK_DIM: 256
|
11 |
+
NORM: "GN"
|
12 |
+
# pixel decoder
|
13 |
+
PIXEL_DECODER_NAME: "MSDeformAttnPixelDecoder"
|
14 |
+
IN_FEATURES: ["res2", "res3", "res4", "res5"]
|
15 |
+
DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"]
|
16 |
+
COMMON_STRIDE: 4
|
17 |
+
TRANSFORMER_ENC_LAYERS: 6
|
18 |
+
ONE_FORMER:
|
19 |
+
TRANSFORMER_DECODER_NAME: "ContrastiveMultiScaleMaskedTransformerDecoder"
|
20 |
+
TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder"
|
21 |
+
DEEP_SUPERVISION: True
|
22 |
+
NO_OBJECT_WEIGHT: 0.1
|
23 |
+
CLASS_WEIGHT: 2.0
|
24 |
+
MASK_WEIGHT: 5.0
|
25 |
+
DICE_WEIGHT: 5.0
|
26 |
+
CONTRASTIVE_WEIGHT: 0.5
|
27 |
+
CONTRASTIVE_TEMPERATURE: 0.07
|
28 |
+
HIDDEN_DIM: 256
|
29 |
+
NUM_OBJECT_QUERIES: 150
|
30 |
+
USE_TASK_NORM: True
|
31 |
+
NHEADS: 8
|
32 |
+
DROPOUT: 0.1
|
33 |
+
DIM_FEEDFORWARD: 2048
|
34 |
+
ENC_LAYERS: 0
|
35 |
+
PRE_NORM: False
|
36 |
+
ENFORCE_INPUT_PROJ: False
|
37 |
+
SIZE_DIVISIBILITY: 32
|
38 |
+
CLASS_DEC_LAYERS: 2
|
39 |
+
DEC_LAYERS: 10 # 9 decoder layers, add one for the loss on learnable query
|
40 |
+
TRAIN_NUM_POINTS: 12544
|
41 |
+
OVERSAMPLE_RATIO: 3.0
|
42 |
+
IMPORTANCE_SAMPLE_RATIO: 0.75
|
43 |
+
TEXT_ENCODER:
|
44 |
+
WIDTH: 256
|
45 |
+
CONTEXT_LENGTH: 77
|
46 |
+
NUM_LAYERS: 6
|
47 |
+
VOCAB_SIZE: 49408
|
48 |
+
PROJ_NUM_LAYERS: 2
|
49 |
+
N_CTX: 16
|
50 |
+
TEST:
|
51 |
+
SEMANTIC_ON: True
|
52 |
+
INSTANCE_ON: True
|
53 |
+
PANOPTIC_ON: True
|
54 |
+
DETECTION_ON: False
|
55 |
+
OVERLAP_THRESHOLD: 0.8
|
56 |
+
OBJECT_MASK_THRESHOLD: 0.8
|
57 |
+
TASK: "panoptic"
|
58 |
+
TEST:
|
59 |
+
DETECTIONS_PER_IMAGE: 150
|
configs/coco/oneformer_dinat_large_bs16_100ep.yaml
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_BASE_: oneformer_R50_bs16_50ep.yaml
|
2 |
+
MODEL:
|
3 |
+
BACKBONE:
|
4 |
+
NAME: "D2DiNAT"
|
5 |
+
DiNAT:
|
6 |
+
EMBED_DIM: 192
|
7 |
+
MLP_RATIO: 2.0
|
8 |
+
DEPTHS: [3, 4, 18, 5]
|
9 |
+
NUM_HEADS: [6, 12, 24, 48]
|
10 |
+
KERNEL_SIZE: 11
|
11 |
+
DROP_PATH_RATE: 0.3
|
12 |
+
DILATIONS: [[1, 20, 1], [1, 5, 1, 10], [1, 2, 1, 3, 1, 4, 1, 5, 1, 2, 1, 3, 1, 4, 1, 5, 1, 5], [1, 2, 1, 2, 1]]
|
13 |
+
WEIGHTS: "dinat_large_in22k_in1k_384_11x11.pkl"
|
14 |
+
PIXEL_MEAN: [123.675, 116.280, 103.530]
|
15 |
+
PIXEL_STD: [58.395, 57.120, 57.375]
|
16 |
+
ONE_FORMER:
|
17 |
+
NUM_OBJECT_QUERIES: 150
|
18 |
+
SOLVER:
|
19 |
+
STEPS: (655556, 710184)
|
20 |
+
MAX_ITER: 737500
|
21 |
+
TEST:
|
22 |
+
DETECTIONS_PER_IMAGE: 150
|
configs/coco/oneformer_swin_large_IN21k_384_bs16_100ep.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_BASE_: oneformer_R50_bs16_50ep.yaml
|
2 |
+
MODEL:
|
3 |
+
BACKBONE:
|
4 |
+
NAME: "D2SwinTransformer"
|
5 |
+
SWIN:
|
6 |
+
EMBED_DIM: 192
|
7 |
+
DEPTHS: [2, 2, 18, 2]
|
8 |
+
NUM_HEADS: [6, 12, 24, 48]
|
9 |
+
WINDOW_SIZE: 12
|
10 |
+
APE: False
|
11 |
+
DROP_PATH_RATE: 0.3
|
12 |
+
PATCH_NORM: True
|
13 |
+
PRETRAIN_IMG_SIZE: 384
|
14 |
+
WEIGHTS: "swin_large_patch4_window12_384_22k.pkl"
|
15 |
+
PIXEL_MEAN: [123.675, 116.280, 103.530]
|
16 |
+
PIXEL_STD: [58.395, 57.120, 57.375]
|
17 |
+
ONE_FORMER:
|
18 |
+
NUM_OBJECT_QUERIES: 150
|
19 |
+
SOLVER:
|
20 |
+
STEPS: (655556, 735184)
|
21 |
+
MAX_ITER: 737500
|
22 |
+
AMP:
|
23 |
+
ENABLED: False
|
24 |
+
TEST:
|
25 |
+
DETECTIONS_PER_IMAGE: 150
|
deform_setup.sh
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
# ln -s ./oneformer/modeling/pixel_decoder/ops/ ./
|
4 |
+
# ls
|
5 |
+
# cd ops/ && bash make.sh && cd ..
|
6 |
+
echo '----------------------------------------------------------------'
|
7 |
+
echo '----------------------------------------------------------------'
|
8 |
+
pip3 freeze | grep MultiScaleDeformableAttention
|
9 |
+
pip3 freeze | grep torch
|
10 |
+
pip3 freeze | grep detectron2
|
11 |
+
pip3 freeze | grep natten
|
12 |
+
echo '----------------------------------------------------------------'
|
13 |
+
echo '----------------------------------------------------------------'
|
14 |
+
|
15 |
+
# echo '----------------------------------------------------------------'
|
16 |
+
# echo '----------------------------------------------------------------'
|
17 |
+
# cd /home/user/.pyenv/versions/3.8.15/lib/python3.8/site-packages
|
18 |
+
# ls
|
19 |
+
# ls | grep MultiScale
|
20 |
+
# echo '----------------------------------------------------------------'
|
21 |
+
# echo '----------------------------------------------------------------'
|
demo/colormap.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
|
3 |
+
"""
|
4 |
+
An awesome colormap for really neat visualizations.
|
5 |
+
Copied from Detectron, and removed gray colors.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import random
|
10 |
+
random.seed(0)
|
11 |
+
|
12 |
+
__all__ = ["colormap", "random_color", "random_colors"]
|
13 |
+
|
14 |
+
# fmt: off
|
15 |
+
# RGB:
|
16 |
+
# _COLORS = np.array(
|
17 |
+
# [
|
18 |
+
# 0.000, 0.447, 0.741,
|
19 |
+
# 0.850, 0.325, 0.098,
|
20 |
+
# 0.929, 0.694, 0.125,
|
21 |
+
# 0.494, 0.184, 0.556,
|
22 |
+
# 0.466, 0.674, 0.188,
|
23 |
+
# 0.301, 0.745, 0.933,
|
24 |
+
# 0.635, 0.078, 0.184,
|
25 |
+
# 0.300, 0.300, 0.300,
|
26 |
+
# 0.600, 0.600, 0.600,
|
27 |
+
# 1.000, 0.000, 0.000,
|
28 |
+
# 1.000, 0.500, 0.000,
|
29 |
+
# 0.749, 0.749, 0.000,
|
30 |
+
# 0.000, 1.000, 0.000,
|
31 |
+
# 0.000, 0.000, 1.000,
|
32 |
+
# 0.667, 0.000, 1.000,
|
33 |
+
# 0.333, 0.333, 0.000,
|
34 |
+
# 0.333, 0.667, 0.000,
|
35 |
+
# 0.333, 1.000, 0.000,
|
36 |
+
# 0.667, 0.333, 0.000,
|
37 |
+
# 0.667, 0.667, 0.000,
|
38 |
+
# 0.667, 1.000, 0.000,
|
39 |
+
# 1.000, 0.333, 0.000,
|
40 |
+
# 1.000, 0.667, 0.000,
|
41 |
+
# 1.000, 1.000, 0.000,
|
42 |
+
# 0.000, 0.333, 0.500,
|
43 |
+
# 0.000, 0.667, 0.500,
|
44 |
+
# 0.000, 1.000, 0.500,
|
45 |
+
# 0.333, 0.000, 0.500,
|
46 |
+
# 0.333, 0.333, 0.500,
|
47 |
+
# 0.333, 0.667, 0.500,
|
48 |
+
# 0.333, 1.000, 0.500,
|
49 |
+
# 0.667, 0.000, 0.500,
|
50 |
+
# 0.667, 0.333, 0.500,
|
51 |
+
# 0.667, 0.667, 0.500,
|
52 |
+
# 0.667, 1.000, 0.500,
|
53 |
+
# 1.000, 0.000, 0.500,
|
54 |
+
# 1.000, 0.333, 0.500,
|
55 |
+
# 1.000, 0.667, 0.500,
|
56 |
+
# 1.000, 1.000, 0.500,
|
57 |
+
# 0.000, 0.333, 1.000,
|
58 |
+
# 0.000, 0.667, 1.000,
|
59 |
+
# 0.000, 1.000, 1.000,
|
60 |
+
# 0.333, 0.000, 1.000,
|
61 |
+
# 0.333, 0.333, 1.000,
|
62 |
+
# 0.333, 0.667, 1.000,
|
63 |
+
# 0.333, 1.000, 1.000,
|
64 |
+
# 0.667, 0.000, 1.000,
|
65 |
+
# 0.667, 0.333, 1.000,
|
66 |
+
# 0.667, 0.667, 1.000,
|
67 |
+
# 0.667, 1.000, 1.000,
|
68 |
+
# 1.000, 0.000, 1.000,
|
69 |
+
# 1.000, 0.333, 1.000,
|
70 |
+
# 1.000, 0.667, 1.000,
|
71 |
+
# 0.333, 0.000, 0.000,
|
72 |
+
# 0.500, 0.000, 0.000,
|
73 |
+
# 0.667, 0.000, 0.000,
|
74 |
+
# 0.833, 0.000, 0.000,
|
75 |
+
# 1.000, 0.000, 0.000,
|
76 |
+
# 0.000, 0.167, 0.000,
|
77 |
+
# 0.000, 0.333, 0.000,
|
78 |
+
# 0.000, 0.500, 0.000,
|
79 |
+
# 0.000, 0.667, 0.000,
|
80 |
+
# 0.000, 0.833, 0.000,
|
81 |
+
# 0.000, 1.000, 0.000,
|
82 |
+
# 0.000, 0.000, 0.167,
|
83 |
+
# 0.000, 0.000, 0.333,
|
84 |
+
# 0.000, 0.000, 0.500,
|
85 |
+
# 0.000, 0.000, 0.667,
|
86 |
+
# 0.000, 0.000, 0.833,
|
87 |
+
# 0.000, 0.000, 1.000,
|
88 |
+
# 0.000, 0.000, 0.000,
|
89 |
+
# 0.143, 0.143, 0.143,
|
90 |
+
# 0.857, 0.857, 0.857,
|
91 |
+
# 1.000, 1.000, 1.000
|
92 |
+
# ]
|
93 |
+
# ).astype(np.float32).reshape(-1, 3)
|
94 |
+
# fmt: on
|
95 |
+
|
96 |
+
_COLORS = []
|
97 |
+
|
98 |
+
|
99 |
+
def gen_color():
|
100 |
+
color = tuple(np.round(np.random.choice(range(256), size=3)/255, 3))
|
101 |
+
if color not in _COLORS and np.mean(color) != 0.0:
|
102 |
+
_COLORS.append(color)
|
103 |
+
else:
|
104 |
+
gen_color()
|
105 |
+
|
106 |
+
|
107 |
+
for _ in range(300):
|
108 |
+
gen_color()
|
109 |
+
|
110 |
+
|
111 |
+
def colormap(rgb=False, maximum=255):
|
112 |
+
"""
|
113 |
+
Args:
|
114 |
+
rgb (bool): whether to return RGB colors or BGR colors.
|
115 |
+
maximum (int): either 255 or 1
|
116 |
+
Returns:
|
117 |
+
ndarray: a float32 array of Nx3 colors, in range [0, 255] or [0, 1]
|
118 |
+
"""
|
119 |
+
assert maximum in [255, 1], maximum
|
120 |
+
c = _COLORS * maximum
|
121 |
+
if not rgb:
|
122 |
+
c = c[:, ::-1]
|
123 |
+
return c
|
124 |
+
|
125 |
+
|
126 |
+
def random_color(rgb=False, maximum=255):
|
127 |
+
"""
|
128 |
+
Args:
|
129 |
+
rgb (bool): whether to return RGB colors or BGR colors.
|
130 |
+
maximum (int): either 255 or 1
|
131 |
+
Returns:
|
132 |
+
ndarray: a vector of 3 numbers
|
133 |
+
"""
|
134 |
+
idx = np.random.randint(0, len(_COLORS))
|
135 |
+
ret = _COLORS[idx] * maximum
|
136 |
+
if not rgb:
|
137 |
+
ret = ret[::-1]
|
138 |
+
return ret
|
139 |
+
|
140 |
+
|
141 |
+
def random_colors(N, rgb=False, maximum=255):
|
142 |
+
"""
|
143 |
+
Args:
|
144 |
+
N (int): number of unique colors needed
|
145 |
+
rgb (bool): whether to return RGB colors or BGR colors.
|
146 |
+
maximum (int): either 255 or 1
|
147 |
+
Returns:
|
148 |
+
ndarray: a list of random_color
|
149 |
+
"""
|
150 |
+
indices = random.sample(range(len(_COLORS)), N)
|
151 |
+
ret = [_COLORS[i] * maximum for i in indices]
|
152 |
+
if not rgb:
|
153 |
+
ret = [x[::-1] for x in ret]
|
154 |
+
return ret
|
155 |
+
|
156 |
+
|
157 |
+
if __name__ == "__main__":
|
158 |
+
import cv2
|
159 |
+
|
160 |
+
size = 100
|
161 |
+
H, W = 10, 10
|
162 |
+
canvas = np.random.rand(H * size, W * size, 3).astype("float32")
|
163 |
+
for h in range(H):
|
164 |
+
for w in range(W):
|
165 |
+
idx = h * W + w
|
166 |
+
if idx >= len(_COLORS):
|
167 |
+
break
|
168 |
+
canvas[h * size : (h + 1) * size, w * size : (w + 1) * size] = _COLORS[idx]
|
169 |
+
cv2.imshow("a", canvas)
|
170 |
+
cv2.waitKey(0)
|
demo/defaults.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import detectron2.data.transforms as T
|
3 |
+
from detectron2.checkpoint import DetectionCheckpointer
|
4 |
+
from detectron2.data import (
|
5 |
+
MetadataCatalog,
|
6 |
+
)
|
7 |
+
from detectron2.modeling import build_model
|
8 |
+
|
9 |
+
|
10 |
+
__all__ = [
|
11 |
+
"DefaultPredictor",
|
12 |
+
]
|
13 |
+
|
14 |
+
|
15 |
+
class DefaultPredictor:
|
16 |
+
"""
|
17 |
+
Create a simple end-to-end predictor with the given config that runs on
|
18 |
+
single device for a single input image.
|
19 |
+
Compared to using the model directly, this class does the following additions:
|
20 |
+
1. Load checkpoint from `cfg.MODEL.WEIGHTS`.
|
21 |
+
2. Always take BGR image as the input and apply conversion defined by `cfg.INPUT.FORMAT`.
|
22 |
+
3. Apply resizing defined by `cfg.INPUT.{MIN,MAX}_SIZE_TEST`.
|
23 |
+
4. Take one input image and produce a single output, instead of a batch.
|
24 |
+
This is meant for simple demo purposes, so it does the above steps automatically.
|
25 |
+
This is not meant for benchmarks or running complicated inference logic.
|
26 |
+
If you'd like to do anything more complicated, please refer to its source code as
|
27 |
+
examples to build and use the model manually.
|
28 |
+
Attributes:
|
29 |
+
metadata (Metadata): the metadata of the underlying dataset, obtained from
|
30 |
+
cfg.DATASETS.TEST.
|
31 |
+
Examples:
|
32 |
+
::
|
33 |
+
pred = DefaultPredictor(cfg)
|
34 |
+
inputs = cv2.imread("input.jpg")
|
35 |
+
outputs = pred(inputs)
|
36 |
+
"""
|
37 |
+
|
38 |
+
def __init__(self, cfg):
|
39 |
+
self.cfg = cfg.clone() # cfg can be modified by model
|
40 |
+
self.model = build_model(self.cfg)
|
41 |
+
self.model.eval()
|
42 |
+
if len(cfg.DATASETS.TEST):
|
43 |
+
self.metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0])
|
44 |
+
|
45 |
+
checkpointer = DetectionCheckpointer(self.model)
|
46 |
+
checkpointer.load(cfg.MODEL.WEIGHTS)
|
47 |
+
|
48 |
+
self.aug = T.ResizeShortestEdge(
|
49 |
+
[cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
|
50 |
+
)
|
51 |
+
|
52 |
+
self.input_format = cfg.INPUT.FORMAT
|
53 |
+
assert self.input_format in ["RGB", "BGR"], self.input_format
|
54 |
+
|
55 |
+
def __call__(self, original_image, task):
|
56 |
+
"""
|
57 |
+
Args:
|
58 |
+
original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).
|
59 |
+
Returns:
|
60 |
+
predictions (dict):
|
61 |
+
the output of the model for one image only.
|
62 |
+
See :doc:`/tutorials/models` for details about the format.
|
63 |
+
"""
|
64 |
+
with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258
|
65 |
+
# Apply pre-processing to image.
|
66 |
+
if self.input_format == "RGB":
|
67 |
+
# whether the model expects BGR inputs or RGB
|
68 |
+
original_image = original_image[:, :, ::-1]
|
69 |
+
height, width = original_image.shape[:2]
|
70 |
+
image = self.aug.get_transform(original_image).apply_image(original_image)
|
71 |
+
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
|
72 |
+
|
73 |
+
task = f"The task is {task}"
|
74 |
+
|
75 |
+
inputs = {"image": image, "height": height, "width": width, "task": task}
|
76 |
+
predictions = self.model([inputs])[0]
|
77 |
+
return predictions
|
demo/predictor.py
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# Copied from: https://github.com/facebookresearch/detectron2/blob/master/demo/predictor.py
|
3 |
+
import atexit
|
4 |
+
import bisect
|
5 |
+
import multiprocessing as mp
|
6 |
+
from collections import deque
|
7 |
+
|
8 |
+
import cv2
|
9 |
+
import torch
|
10 |
+
|
11 |
+
from detectron2.data import MetadataCatalog
|
12 |
+
from defaults import DefaultPredictor
|
13 |
+
from detectron2.utils.video_visualizer import VideoVisualizer
|
14 |
+
from visualizer import ColorMode, Visualizer
|
15 |
+
|
16 |
+
|
17 |
+
class VisualizationDemo(object):
|
18 |
+
def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False):
|
19 |
+
"""
|
20 |
+
Args:
|
21 |
+
cfg (CfgNode):
|
22 |
+
instance_mode (ColorMode):
|
23 |
+
parallel (bool): whether to run the model in different processes from visualization.
|
24 |
+
Useful since the visualization logic can be slow.
|
25 |
+
"""
|
26 |
+
self.metadata = MetadataCatalog.get(
|
27 |
+
cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
|
28 |
+
)
|
29 |
+
if 'cityscapes_fine_sem_seg_val' in cfg.DATASETS.TEST[0]:
|
30 |
+
from cityscapesscripts.helpers.labels import labels
|
31 |
+
stuff_colors = [k.color for k in labels if k.trainId != 255]
|
32 |
+
self.metadata = self.metadata.set(stuff_colors=stuff_colors)
|
33 |
+
self.cpu_device = torch.device("cpu")
|
34 |
+
self.instance_mode = instance_mode
|
35 |
+
|
36 |
+
self.parallel = parallel
|
37 |
+
if parallel:
|
38 |
+
num_gpu = torch.cuda.device_count()
|
39 |
+
self.predictor = AsyncPredictor(cfg, num_gpus=num_gpu)
|
40 |
+
else:
|
41 |
+
self.predictor = DefaultPredictor(cfg)
|
42 |
+
|
43 |
+
def run_on_image(self, image, task, sem_gt, pan_gt, ins_gt, box_gt):
|
44 |
+
"""
|
45 |
+
Args:
|
46 |
+
image (np.ndarray): an image of shape (H, W, C) (in BGR order).
|
47 |
+
This is the format used by OpenCV.
|
48 |
+
Returns:
|
49 |
+
predictions (dict): the output of the model.
|
50 |
+
vis_output (VisImage): the visualized image output.
|
51 |
+
"""
|
52 |
+
vis_output = None
|
53 |
+
# Convert image from OpenCV BGR format to Matplotlib RGB format.
|
54 |
+
image = image[:, :, ::-1]
|
55 |
+
vis_output = {}
|
56 |
+
|
57 |
+
if task == 'panoptic':
|
58 |
+
visualizer = Visualizer(image, metadata=self.metadata, instance_mode=0)
|
59 |
+
predictions = self.predictor(image, "panoptic")
|
60 |
+
panoptic_seg, segments_info = predictions["panoptic_seg"]
|
61 |
+
vis_output['panoptic'] = visualizer.draw_panoptic_seg_predictions(
|
62 |
+
panoptic_seg.to(self.cpu_device), segments_info, alpha=1
|
63 |
+
)
|
64 |
+
|
65 |
+
# visualizer = Visualizer(image, metadata=self.metadata, instance_mode=0)
|
66 |
+
# vis_output['pan_gt'] = visualizer.draw_panoptic_seg(
|
67 |
+
# pan_gt[0].to(self.cpu_device), pan_gt[1], alpha=1
|
68 |
+
# )
|
69 |
+
|
70 |
+
if task == 'panoptic' or task == 'semantic':
|
71 |
+
visualizer = Visualizer(image, metadata=self.metadata, instance_mode=1)
|
72 |
+
predictions = self.predictor(image, "semantic")
|
73 |
+
vis_output['semantic'] = visualizer.draw_sem_seg(
|
74 |
+
predictions["sem_seg"].argmax(dim=0).to(self.cpu_device), alpha=1
|
75 |
+
)
|
76 |
+
|
77 |
+
# visualizer = Visualizer(image, metadata=self.metadata, instance_mode=1)
|
78 |
+
# vis_output['gt_sem'] = visualizer.draw_sem_seg(
|
79 |
+
# sem_gt.to(self.cpu_device), alpha=1
|
80 |
+
# )
|
81 |
+
|
82 |
+
if task == 'panoptic' or task == 'instance':
|
83 |
+
visualizer = Visualizer(image, metadata=self.metadata, instance_mode=2)
|
84 |
+
predictions = self.predictor(image, "instance")
|
85 |
+
instances = predictions["instances"].to(self.cpu_device)
|
86 |
+
vis_output['instance'] = visualizer.draw_instance_predictions(predictions=instances, alpha=1)
|
87 |
+
|
88 |
+
if 'boxes' in predictions:
|
89 |
+
boxes, labels, scores = predictions["boxes"]
|
90 |
+
visualizer = Visualizer(image, False, metadata=self.metadata, instance_mode=0)
|
91 |
+
vis_output['boxes'] = visualizer.draw_box_predictions(
|
92 |
+
boxes.to(self.cpu_device), labels.to(self.cpu_device), scores.to(self.cpu_device))
|
93 |
+
|
94 |
+
|
95 |
+
# visualizer = Visualizer(image, metadata=self.metadata, instance_mode=2)
|
96 |
+
# vis_output['ins_gt'] = visualizer.draw_instance_predictions(predictions=ins_gt.to(self.cpu_device), alpha=1)
|
97 |
+
# vis_output['input'] = visualizer.get_image(image)
|
98 |
+
|
99 |
+
return predictions, vis_output
|
100 |
+
|
101 |
+
|
102 |
+
class AsyncPredictor:
|
103 |
+
"""
|
104 |
+
A predictor that runs the model asynchronously, possibly on >1 GPUs.
|
105 |
+
Because rendering the visualization takes considerably amount of time,
|
106 |
+
this helps improve throughput a little bit when rendering videos.
|
107 |
+
"""
|
108 |
+
|
109 |
+
class _StopToken:
|
110 |
+
pass
|
111 |
+
|
112 |
+
class _PredictWorker(mp.Process):
|
113 |
+
def __init__(self, cfg, task_queue, result_queue):
|
114 |
+
self.cfg = cfg
|
115 |
+
self.task_queue = task_queue
|
116 |
+
self.result_queue = result_queue
|
117 |
+
super().__init__()
|
118 |
+
|
119 |
+
def run(self):
|
120 |
+
predictor = DefaultPredictor(self.cfg)
|
121 |
+
|
122 |
+
while True:
|
123 |
+
task = self.task_queue.get()
|
124 |
+
if isinstance(task, AsyncPredictor._StopToken):
|
125 |
+
break
|
126 |
+
idx, data = task
|
127 |
+
result = predictor(data)
|
128 |
+
self.result_queue.put((idx, result))
|
129 |
+
|
130 |
+
def __init__(self, cfg, num_gpus: int = 1):
|
131 |
+
"""
|
132 |
+
Args:
|
133 |
+
cfg (CfgNode):
|
134 |
+
num_gpus (int): if 0, will run on CPU
|
135 |
+
"""
|
136 |
+
num_workers = max(num_gpus, 1)
|
137 |
+
self.task_queue = mp.Queue(maxsize=num_workers * 3)
|
138 |
+
self.result_queue = mp.Queue(maxsize=num_workers * 3)
|
139 |
+
self.procs = []
|
140 |
+
for gpuid in range(max(num_gpus, 1)):
|
141 |
+
cfg = cfg.clone()
|
142 |
+
cfg.defrost()
|
143 |
+
cfg.MODEL.DEVICE = "cuda:{}".format(gpuid) if num_gpus > 0 else "cpu"
|
144 |
+
self.procs.append(
|
145 |
+
AsyncPredictor._PredictWorker(cfg, self.task_queue, self.result_queue)
|
146 |
+
)
|
147 |
+
|
148 |
+
self.put_idx = 0
|
149 |
+
self.get_idx = 0
|
150 |
+
self.result_rank = []
|
151 |
+
self.result_data = []
|
152 |
+
|
153 |
+
for p in self.procs:
|
154 |
+
p.start()
|
155 |
+
atexit.register(self.shutdown)
|
156 |
+
|
157 |
+
def put(self, image):
|
158 |
+
self.put_idx += 1
|
159 |
+
self.task_queue.put((self.put_idx, image))
|
160 |
+
|
161 |
+
def get(self):
|
162 |
+
self.get_idx += 1 # the index needed for this request
|
163 |
+
if len(self.result_rank) and self.result_rank[0] == self.get_idx:
|
164 |
+
res = self.result_data[0]
|
165 |
+
del self.result_data[0], self.result_rank[0]
|
166 |
+
return res
|
167 |
+
|
168 |
+
while True:
|
169 |
+
# make sure the results are returned in the correct order
|
170 |
+
idx, res = self.result_queue.get()
|
171 |
+
if idx == self.get_idx:
|
172 |
+
return res
|
173 |
+
insert = bisect.bisect(self.result_rank, idx)
|
174 |
+
self.result_rank.insert(insert, idx)
|
175 |
+
self.result_data.insert(insert, res)
|
176 |
+
|
177 |
+
def __len__(self):
|
178 |
+
return self.put_idx - self.get_idx
|
179 |
+
|
180 |
+
def __call__(self, image):
|
181 |
+
self.put(image)
|
182 |
+
return self.get()
|
183 |
+
|
184 |
+
def shutdown(self):
|
185 |
+
for _ in self.procs:
|
186 |
+
self.task_queue.put(AsyncPredictor._StopToken())
|
187 |
+
|
188 |
+
@property
|
189 |
+
def default_buffer_size(self):
|
190 |
+
return len(self.procs) * 5
|
demo/visualizer.py
ADDED
@@ -0,0 +1,1350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
import colorsys
|
3 |
+
import logging
|
4 |
+
import math
|
5 |
+
import numpy as np
|
6 |
+
from enum import Enum, unique
|
7 |
+
import cv2
|
8 |
+
import matplotlib as mpl
|
9 |
+
import matplotlib.colors as mplc
|
10 |
+
import matplotlib.figure as mplfigure
|
11 |
+
import pycocotools.mask as mask_util
|
12 |
+
import torch
|
13 |
+
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
14 |
+
from PIL import Image
|
15 |
+
|
16 |
+
from detectron2.data import MetadataCatalog
|
17 |
+
from detectron2.structures import BitMasks, Boxes, BoxMode, Keypoints, PolygonMasks, RotatedBoxes
|
18 |
+
from detectron2.utils.file_io import PathManager
|
19 |
+
import random
|
20 |
+
random.seed(0)
|
21 |
+
from .colormap import random_color, _COLORS
|
22 |
+
logger = logging.getLogger(__name__)
|
23 |
+
|
24 |
+
__all__ = ["ColorMode", "VisImage", "Visualizer"]
|
25 |
+
|
26 |
+
|
27 |
+
_SMALL_OBJECT_AREA_THRESH = 1000
|
28 |
+
_LARGE_MASK_AREA_THRESH = 120000
|
29 |
+
_OFF_WHITE = (1.0, 1.0, 240.0 / 255)
|
30 |
+
_BLACK = (0, 0, 0)
|
31 |
+
_RED = (1.0, 0, 0)
|
32 |
+
|
33 |
+
_KEYPOINT_THRESHOLD = 0.05
|
34 |
+
|
35 |
+
|
36 |
+
def instance_color(rgb=False, idx=1, maximum=255):
|
37 |
+
"""
|
38 |
+
Args:
|
39 |
+
rgb (bool): whether to return RGB colors or BGR colors.
|
40 |
+
maximum (int): either 255 or 1
|
41 |
+
Returns:
|
42 |
+
ndarray: a vector of 3 numbers
|
43 |
+
"""
|
44 |
+
ret = _COLORS[idx] * maximum
|
45 |
+
if not rgb:
|
46 |
+
ret = ret[::-1]
|
47 |
+
return ret
|
48 |
+
|
49 |
+
@unique
|
50 |
+
class ColorMode(Enum):
|
51 |
+
"""
|
52 |
+
Enum of different color modes to use for instance visualizations.
|
53 |
+
"""
|
54 |
+
|
55 |
+
IMAGE = 0
|
56 |
+
"""
|
57 |
+
Picks a random color for every instance and overlay segmentations with low opacity.
|
58 |
+
"""
|
59 |
+
SEGMENTATION = 1
|
60 |
+
"""
|
61 |
+
Let instances of the same category have similar colors
|
62 |
+
(from metadata.thing_colors), and overlay them with
|
63 |
+
high opacity. This provides more attention on the quality of segmentation.
|
64 |
+
"""
|
65 |
+
IMAGE_BW = 2
|
66 |
+
"""
|
67 |
+
Same as IMAGE, but convert all areas without masks to gray-scale.
|
68 |
+
Only available for drawing per-instance mask predictions.
|
69 |
+
"""
|
70 |
+
|
71 |
+
|
72 |
+
class GenericMask:
|
73 |
+
"""
|
74 |
+
Attribute:
|
75 |
+
polygons (list[ndarray]): list[ndarray]: polygons for this mask.
|
76 |
+
Each ndarray has format [x, y, x, y, ...]
|
77 |
+
mask (ndarray): a binary mask
|
78 |
+
"""
|
79 |
+
|
80 |
+
def __init__(self, mask_or_polygons, height, width):
|
81 |
+
self._mask = self._polygons = self._has_holes = None
|
82 |
+
self.height = height
|
83 |
+
self.width = width
|
84 |
+
|
85 |
+
m = mask_or_polygons
|
86 |
+
if isinstance(m, dict):
|
87 |
+
# RLEs
|
88 |
+
assert "counts" in m and "size" in m
|
89 |
+
if isinstance(m["counts"], list): # uncompressed RLEs
|
90 |
+
h, w = m["size"]
|
91 |
+
assert h == height and w == width
|
92 |
+
m = mask_util.frPyObjects(m, h, w)
|
93 |
+
self._mask = mask_util.decode(m)[:, :]
|
94 |
+
return
|
95 |
+
|
96 |
+
if isinstance(m, list): # list[ndarray]
|
97 |
+
self._polygons = [np.asarray(x).reshape(-1) for x in m]
|
98 |
+
return
|
99 |
+
|
100 |
+
if isinstance(m, np.ndarray): # assumed to be a binary mask
|
101 |
+
assert m.shape[1] != 2, m.shape
|
102 |
+
assert m.shape == (
|
103 |
+
height,
|
104 |
+
width,
|
105 |
+
), f"mask shape: {m.shape}, target dims: {height}, {width}"
|
106 |
+
self._mask = m.astype("uint8")
|
107 |
+
return
|
108 |
+
|
109 |
+
raise ValueError("GenericMask cannot handle object {} of type '{}'".format(m, type(m)))
|
110 |
+
|
111 |
+
@property
|
112 |
+
def mask(self):
|
113 |
+
if self._mask is None:
|
114 |
+
self._mask = self.polygons_to_mask(self._polygons)
|
115 |
+
return self._mask
|
116 |
+
|
117 |
+
@property
|
118 |
+
def polygons(self):
|
119 |
+
if self._polygons is None:
|
120 |
+
self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
|
121 |
+
return self._polygons
|
122 |
+
|
123 |
+
@property
|
124 |
+
def has_holes(self):
|
125 |
+
if self._has_holes is None:
|
126 |
+
if self._mask is not None:
|
127 |
+
self._polygons, self._has_holes = self.mask_to_polygons(self._mask)
|
128 |
+
else:
|
129 |
+
self._has_holes = False # if original format is polygon, does not have holes
|
130 |
+
return self._has_holes
|
131 |
+
|
132 |
+
def mask_to_polygons(self, mask):
|
133 |
+
# cv2.RETR_CCOMP flag retrieves all the contours and arranges them to a 2-level
|
134 |
+
# hierarchy. External contours (boundary) of the object are placed in hierarchy-1.
|
135 |
+
# Internal contours (holes) are placed in hierarchy-2.
|
136 |
+
# cv2.CHAIN_APPROX_NONE flag gets vertices of polygons from contours.
|
137 |
+
mask = np.ascontiguousarray(mask) # some versions of cv2 does not support incontiguous arr
|
138 |
+
res = cv2.findContours(mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
|
139 |
+
hierarchy = res[-1]
|
140 |
+
if hierarchy is None: # empty mask
|
141 |
+
return [], False
|
142 |
+
has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0
|
143 |
+
res = res[-2]
|
144 |
+
res = [x.flatten() for x in res]
|
145 |
+
# These coordinates from OpenCV are integers in range [0, W-1 or H-1].
|
146 |
+
# We add 0.5 to turn them into real-value coordinate space. A better solution
|
147 |
+
# would be to first +0.5 and then dilate the returned polygon by 0.5.
|
148 |
+
res = [x + 0.5 for x in res if len(x) >= 6]
|
149 |
+
return res, has_holes
|
150 |
+
|
151 |
+
def polygons_to_mask(self, polygons):
|
152 |
+
rle = mask_util.frPyObjects(polygons, self.height, self.width)
|
153 |
+
rle = mask_util.merge(rle)
|
154 |
+
return mask_util.decode(rle)[:, :]
|
155 |
+
|
156 |
+
def area(self):
|
157 |
+
return self.mask.sum()
|
158 |
+
|
159 |
+
def bbox(self):
|
160 |
+
p = mask_util.frPyObjects(self.polygons, self.height, self.width)
|
161 |
+
p = mask_util.merge(p)
|
162 |
+
bbox = mask_util.toBbox(p)
|
163 |
+
bbox[2] += bbox[0]
|
164 |
+
bbox[3] += bbox[1]
|
165 |
+
return bbox
|
166 |
+
|
167 |
+
|
168 |
+
class _PanopticPrediction:
|
169 |
+
"""
|
170 |
+
Unify different panoptic annotation/prediction formats
|
171 |
+
"""
|
172 |
+
|
173 |
+
def __init__(self, panoptic_seg, segments_info, metadata=None):
|
174 |
+
if segments_info is None:
|
175 |
+
assert metadata is not None
|
176 |
+
# If "segments_info" is None, we assume "panoptic_img" is a
|
177 |
+
# H*W int32 image storing the panoptic_id in the format of
|
178 |
+
# category_id * label_divisor + instance_id. We reserve -1 for
|
179 |
+
# VOID label.
|
180 |
+
label_divisor = metadata.label_divisor
|
181 |
+
segments_info = []
|
182 |
+
for panoptic_label in np.unique(panoptic_seg.numpy()):
|
183 |
+
if panoptic_label == -1:
|
184 |
+
# VOID region.
|
185 |
+
continue
|
186 |
+
pred_class = panoptic_label // label_divisor
|
187 |
+
isthing = pred_class in metadata.thing_dataset_id_to_contiguous_id.values()
|
188 |
+
segments_info.append(
|
189 |
+
{
|
190 |
+
"id": int(panoptic_label),
|
191 |
+
"category_id": int(pred_class),
|
192 |
+
"isthing": bool(isthing),
|
193 |
+
}
|
194 |
+
)
|
195 |
+
del metadata
|
196 |
+
|
197 |
+
self._seg = panoptic_seg
|
198 |
+
|
199 |
+
self._sinfo = {s["id"]: s for s in segments_info} # seg id -> seg info
|
200 |
+
segment_ids, areas = torch.unique(panoptic_seg, sorted=True, return_counts=True)
|
201 |
+
areas = areas.numpy()
|
202 |
+
sorted_idxs = np.argsort(-areas)
|
203 |
+
self._seg_ids, self._seg_areas = segment_ids[sorted_idxs], areas[sorted_idxs]
|
204 |
+
self._seg_ids = self._seg_ids.tolist()
|
205 |
+
for sid, area in zip(self._seg_ids, self._seg_areas):
|
206 |
+
if sid in self._sinfo:
|
207 |
+
self._sinfo[sid]["area"] = float(area)
|
208 |
+
|
209 |
+
def non_empty_mask(self):
|
210 |
+
"""
|
211 |
+
Returns:
|
212 |
+
(H, W) array, a mask for all pixels that have a prediction
|
213 |
+
"""
|
214 |
+
empty_ids = []
|
215 |
+
for id in self._seg_ids:
|
216 |
+
if id not in self._sinfo:
|
217 |
+
empty_ids.append(id)
|
218 |
+
if len(empty_ids) == 0:
|
219 |
+
return np.zeros(self._seg.shape, dtype=np.uint8)
|
220 |
+
assert (
|
221 |
+
len(empty_ids) == 1
|
222 |
+
), ">1 ids corresponds to no labels. This is currently not supported"
|
223 |
+
return (self._seg != empty_ids[0]).numpy().astype(np.bool)
|
224 |
+
|
225 |
+
def semantic_masks(self):
|
226 |
+
for sid in self._seg_ids:
|
227 |
+
sinfo = self._sinfo.get(sid)
|
228 |
+
if sinfo is None or sinfo["isthing"]:
|
229 |
+
# Some pixels (e.g. id 0 in PanopticFPN) have no instance or semantic predictions.
|
230 |
+
continue
|
231 |
+
yield (self._seg == sid).numpy().astype(np.bool), sinfo
|
232 |
+
|
233 |
+
def instance_masks(self):
|
234 |
+
for sid in self._seg_ids:
|
235 |
+
sinfo = self._sinfo.get(sid)
|
236 |
+
if sinfo is None or not sinfo["isthing"]:
|
237 |
+
continue
|
238 |
+
mask = (self._seg == sid).numpy().astype(np.bool)
|
239 |
+
if mask.sum() > 0:
|
240 |
+
yield mask, sinfo
|
241 |
+
|
242 |
+
|
243 |
+
def _create_text_labels(classes, scores, class_names, is_crowd=None):
|
244 |
+
"""
|
245 |
+
Args:
|
246 |
+
classes (list[int] or None):
|
247 |
+
scores (list[float] or None):
|
248 |
+
class_names (list[str] or None):
|
249 |
+
is_crowd (list[bool] or None):
|
250 |
+
Returns:
|
251 |
+
list[str] or None
|
252 |
+
"""
|
253 |
+
labels = None
|
254 |
+
if classes is not None:
|
255 |
+
if class_names is not None and len(class_names) > 0:
|
256 |
+
labels = [class_names[i] for i in classes]
|
257 |
+
else:
|
258 |
+
labels = [str(i) for i in classes]
|
259 |
+
if scores is not None:
|
260 |
+
if labels is None:
|
261 |
+
labels = ["{:.0f}%".format(s * 100) for s in scores]
|
262 |
+
else:
|
263 |
+
labels = ["{} {:.0f}%".format(l, s * 100) for l, s in zip(labels, scores)]
|
264 |
+
if labels is not None and is_crowd is not None:
|
265 |
+
labels = [l + ("|crowd" if crowd else "") for l, crowd in zip(labels, is_crowd)]
|
266 |
+
return labels
|
267 |
+
|
268 |
+
|
269 |
+
class VisImage:
|
270 |
+
def __init__(self, img, scale=1.0):
|
271 |
+
"""
|
272 |
+
Args:
|
273 |
+
img (ndarray): an RGB image of shape (H, W, 3) in range [0, 255].
|
274 |
+
scale (float): scale the input image
|
275 |
+
"""
|
276 |
+
self.img = img
|
277 |
+
self.scale = scale
|
278 |
+
self.width, self.height = img.shape[1], img.shape[0]
|
279 |
+
self._setup_figure(img)
|
280 |
+
|
281 |
+
def _setup_figure(self, img):
|
282 |
+
"""
|
283 |
+
Args:
|
284 |
+
Same as in :meth:`__init__()`.
|
285 |
+
Returns:
|
286 |
+
fig (matplotlib.pyplot.figure): top level container for all the image plot elements.
|
287 |
+
ax (matplotlib.pyplot.Axes): contains figure elements and sets the coordinate system.
|
288 |
+
"""
|
289 |
+
fig = mplfigure.Figure(frameon=False)
|
290 |
+
self.dpi = fig.get_dpi()
|
291 |
+
# add a small 1e-2 to avoid precision lost due to matplotlib's truncation
|
292 |
+
# (https://github.com/matplotlib/matplotlib/issues/15363)
|
293 |
+
fig.set_size_inches(
|
294 |
+
(self.width * self.scale + 1e-2) / self.dpi,
|
295 |
+
(self.height * self.scale + 1e-2) / self.dpi,
|
296 |
+
)
|
297 |
+
self.canvas = FigureCanvasAgg(fig)
|
298 |
+
# self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig)
|
299 |
+
ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
|
300 |
+
ax.axis("off")
|
301 |
+
self.fig = fig
|
302 |
+
self.ax = ax
|
303 |
+
self.reset_image(img)
|
304 |
+
|
305 |
+
def reset_image(self, img):
|
306 |
+
"""
|
307 |
+
Args:
|
308 |
+
img: same as in __init__
|
309 |
+
"""
|
310 |
+
img = img.astype("uint8")
|
311 |
+
self.ax.imshow(img, extent=(0, self.width, self.height, 0), interpolation="nearest")
|
312 |
+
|
313 |
+
def save(self, filepath):
|
314 |
+
"""
|
315 |
+
Args:
|
316 |
+
filepath (str): a string that contains the absolute path, including the file name, where
|
317 |
+
the visualized image will be saved.
|
318 |
+
"""
|
319 |
+
self.fig.savefig(filepath)
|
320 |
+
|
321 |
+
def get_image(self):
|
322 |
+
"""
|
323 |
+
Returns:
|
324 |
+
ndarray:
|
325 |
+
the visualized image of shape (H, W, 3) (RGB) in uint8 type.
|
326 |
+
The shape is scaled w.r.t the input image using the given `scale` argument.
|
327 |
+
"""
|
328 |
+
canvas = self.canvas
|
329 |
+
s, (width, height) = canvas.print_to_buffer()
|
330 |
+
# buf = io.BytesIO() # works for cairo backend
|
331 |
+
# canvas.print_rgba(buf)
|
332 |
+
# width, height = self.width, self.height
|
333 |
+
# s = buf.getvalue()
|
334 |
+
|
335 |
+
buffer = np.frombuffer(s, dtype="uint8")
|
336 |
+
|
337 |
+
img_rgba = buffer.reshape(height, width, 4)
|
338 |
+
rgb, alpha = np.split(img_rgba, [3], axis=2)
|
339 |
+
return rgb.astype("uint8")
|
340 |
+
|
341 |
+
|
342 |
+
class Visualizer:
|
343 |
+
"""
|
344 |
+
Visualizer that draws data about detection/segmentation on images.
|
345 |
+
It contains methods like `draw_{text,box,circle,line,binary_mask,polygon}`
|
346 |
+
that draw primitive objects to images, as well as high-level wrappers like
|
347 |
+
`draw_{instance_predictions,sem_seg,panoptic_seg_predictions,dataset_dict}`
|
348 |
+
that draw composite data in some pre-defined style.
|
349 |
+
Note that the exact visualization style for the high-level wrappers are subject to change.
|
350 |
+
Style such as color, opacity, label contents, visibility of labels, or even the visibility
|
351 |
+
of objects themselves (e.g. when the object is too small) may change according
|
352 |
+
to different heuristics, as long as the results still look visually reasonable.
|
353 |
+
To obtain a consistent style, you can implement custom drawing functions with the
|
354 |
+
abovementioned primitive methods instead. If you need more customized visualization
|
355 |
+
styles, you can process the data yourself following their format documented in
|
356 |
+
tutorials (:doc:`/tutorials/models`, :doc:`/tutorials/datasets`). This class does not
|
357 |
+
intend to satisfy everyone's preference on drawing styles.
|
358 |
+
This visualizer focuses on high rendering quality rather than performance. It is not
|
359 |
+
designed to be used for real-time applications.
|
360 |
+
"""
|
361 |
+
|
362 |
+
# TODO implement a fast, rasterized version using OpenCV
|
363 |
+
|
364 |
+
def __init__(self, img_rgb, is_img=True, metadata=None, scale=1.0, instance_mode=ColorMode.IMAGE):
|
365 |
+
"""
|
366 |
+
Args:
|
367 |
+
img_rgb: a numpy array of shape (H, W, C), where H and W correspond to
|
368 |
+
the height and width of the image respectively. C is the number of
|
369 |
+
color channels. The image is required to be in RGB format since that
|
370 |
+
is a requirement of the Matplotlib library. The image is also expected
|
371 |
+
to be in the range [0, 255].
|
372 |
+
metadata (Metadata): dataset metadata (e.g. class names and colors)
|
373 |
+
instance_mode (ColorMode): defines one of the pre-defined style for drawing
|
374 |
+
instances on an image.
|
375 |
+
"""
|
376 |
+
if is_img:
|
377 |
+
self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)
|
378 |
+
else:
|
379 |
+
self.img = np.zeros_like(img_rgb).clip(0, 255).astype(np.uint8)
|
380 |
+
if metadata is None:
|
381 |
+
metadata = MetadataCatalog.get("__nonexist__")
|
382 |
+
self.metadata = metadata
|
383 |
+
self.output = VisImage(self.img, scale=scale)
|
384 |
+
self.cpu_device = torch.device("cpu")
|
385 |
+
|
386 |
+
# too small texts are useless, therefore clamp to 9
|
387 |
+
self._default_font_size = max(
|
388 |
+
np.sqrt(self.output.height * self.output.width) // 90, 10 // scale
|
389 |
+
)
|
390 |
+
self._instance_mode = instance_mode
|
391 |
+
self.keypoint_threshold = _KEYPOINT_THRESHOLD
|
392 |
+
|
393 |
+
def get_image(self, img):
|
394 |
+
img = np.asarray(img).clip(0, 255).astype(np.uint8)
|
395 |
+
return VisImage(img, scale=1.0)
|
396 |
+
|
397 |
+
def draw_box_predictions(
|
398 |
+
self,
|
399 |
+
boxes=None,
|
400 |
+
labels=None,
|
401 |
+
scores=None,
|
402 |
+
assigned_colors=None
|
403 |
+
):
|
404 |
+
"""
|
405 |
+
Args:
|
406 |
+
boxes (Boxes, RotatedBoxes or ndarray): either a :class:`Boxes`,
|
407 |
+
or an Nx4 numpy array of XYXY_ABS format for the N objects in a single image,
|
408 |
+
or a :class:`RotatedBoxes`,
|
409 |
+
or an Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format
|
410 |
+
for the N objects in a single image,
|
411 |
+
labels (list[str]): the text to be displayed for each instance.
|
412 |
+
assigned_colors (list[matplotlib.colors]): a list of colors, where each color
|
413 |
+
corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
|
414 |
+
for full list of formats that the colors are accepted in.
|
415 |
+
Returns:
|
416 |
+
output (VisImage): image object with visualizations.
|
417 |
+
"""
|
418 |
+
num_instances = 0
|
419 |
+
boxes = self._convert_boxes(boxes)
|
420 |
+
classes = labels.tolist()
|
421 |
+
scores = scores.tolist()
|
422 |
+
labels = _create_text_labels(classes, scores, self.metadata.get("stuff_classes", None))
|
423 |
+
num_instances = len(boxes)
|
424 |
+
assert len(labels) == num_instances
|
425 |
+
if assigned_colors is None:
|
426 |
+
# assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]
|
427 |
+
assigned_colors = [instance_color(rgb=True, idx=i, maximum=1) for i in range(num_instances)]
|
428 |
+
if num_instances == 0:
|
429 |
+
return self.output
|
430 |
+
|
431 |
+
# Display in largest to smallest order to reduce occlusion.
|
432 |
+
areas = None
|
433 |
+
areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1)
|
434 |
+
|
435 |
+
if areas is not None:
|
436 |
+
sorted_idxs = np.argsort(-areas).tolist()
|
437 |
+
# Re-order overlapped instances in descending order.
|
438 |
+
boxes = boxes[sorted_idxs] if boxes is not None else None
|
439 |
+
labels = [labels[k] for k in sorted_idxs] if labels is not None else None
|
440 |
+
assigned_colors = [assigned_colors[idx] for idx in sorted_idxs]
|
441 |
+
|
442 |
+
for i in range(num_instances):
|
443 |
+
color = assigned_colors[i]
|
444 |
+
if boxes is not None:
|
445 |
+
self.draw_box(boxes[i], edge_color=color)
|
446 |
+
|
447 |
+
if labels is not None:
|
448 |
+
# first get a box
|
449 |
+
if boxes is not None:
|
450 |
+
x0, y0, x1, y1 = boxes[i]
|
451 |
+
text_pos = (x0, y0) # if drawing boxes, put text on the box corner.
|
452 |
+
horiz_align = "left"
|
453 |
+
else:
|
454 |
+
continue # drawing the box confidence for keypoints isn't very useful.
|
455 |
+
# for small objects, draw text at the side to avoid occlusion
|
456 |
+
instance_area = (y1 - y0) * (x1 - x0)
|
457 |
+
if (
|
458 |
+
instance_area < _SMALL_OBJECT_AREA_THRESH * self.output.scale
|
459 |
+
or y1 - y0 < 40 * self.output.scale
|
460 |
+
):
|
461 |
+
if y1 >= self.output.height - 5:
|
462 |
+
text_pos = (x1, y0)
|
463 |
+
else:
|
464 |
+
text_pos = (x0, y1)
|
465 |
+
|
466 |
+
height_ratio = (y1 - y0) / np.sqrt(self.output.height * self.output.width)
|
467 |
+
lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
|
468 |
+
font_size = (
|
469 |
+
np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)
|
470 |
+
* 0.5
|
471 |
+
* self._default_font_size
|
472 |
+
)
|
473 |
+
self.draw_text(
|
474 |
+
labels[i],
|
475 |
+
text_pos,
|
476 |
+
color=lighter_color,
|
477 |
+
horizontal_alignment=horiz_align,
|
478 |
+
font_size=font_size,
|
479 |
+
)
|
480 |
+
|
481 |
+
return self.output
|
482 |
+
|
483 |
+
|
484 |
+
def draw_instance_predictions(self, predictions, alpha=0.8, is_text=True):
|
485 |
+
"""
|
486 |
+
Draw instance-level prediction results on an image.
|
487 |
+
Args:
|
488 |
+
predictions (Instances): the output of an instance detection/segmentation
|
489 |
+
model. Following fields will be used to draw:
|
490 |
+
"pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle").
|
491 |
+
Returns:
|
492 |
+
output (VisImage): image object with visualizations.
|
493 |
+
"""
|
494 |
+
boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None
|
495 |
+
scores = predictions.scores if predictions.has("scores") else None
|
496 |
+
classes = predictions.pred_classes.tolist() if predictions.has("pred_classes") else None
|
497 |
+
labels = _create_text_labels(classes, scores, self.metadata.get("stuff_classes", None))
|
498 |
+
keypoints = predictions.pred_keypoints if predictions.has("pred_keypoints") else None
|
499 |
+
|
500 |
+
if predictions.has("pred_masks"):
|
501 |
+
masks = np.asarray(predictions.pred_masks)
|
502 |
+
masks = [GenericMask(x, self.output.height, self.output.width) for x in masks]
|
503 |
+
else:
|
504 |
+
masks = None
|
505 |
+
|
506 |
+
if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("stuff_colors"):
|
507 |
+
# colors = [
|
508 |
+
# self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in classes
|
509 |
+
# ]
|
510 |
+
colors = [
|
511 |
+
instance_color(rgb=True, idx=c, maximum=1) for c in classes
|
512 |
+
]
|
513 |
+
else:
|
514 |
+
colors = None
|
515 |
+
|
516 |
+
if self._instance_mode == ColorMode.IMAGE_BW:
|
517 |
+
self.output.reset_image(
|
518 |
+
self._create_grayscale_image(
|
519 |
+
(predictions.pred_masks.any(dim=0) > 0).numpy()
|
520 |
+
if predictions.has("pred_masks")
|
521 |
+
else None
|
522 |
+
)
|
523 |
+
)
|
524 |
+
|
525 |
+
self.overlay_instances(
|
526 |
+
masks=masks,
|
527 |
+
boxes=boxes,
|
528 |
+
labels=labels,
|
529 |
+
keypoints=keypoints,
|
530 |
+
assigned_colors=colors,
|
531 |
+
alpha=alpha,
|
532 |
+
is_text=is_text,
|
533 |
+
)
|
534 |
+
return self.output
|
535 |
+
|
536 |
+
def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.8, is_text=True):
|
537 |
+
"""
|
538 |
+
Draw semantic segmentation predictions/labels.
|
539 |
+
Args:
|
540 |
+
sem_seg (Tensor or ndarray): the segmentation of shape (H, W).
|
541 |
+
Each value is the integer label of the pixel.
|
542 |
+
area_threshold (int): segments with less than `area_threshold` are not drawn.
|
543 |
+
alpha (float): the larger it is, the more opaque the segmentations are.
|
544 |
+
Returns:
|
545 |
+
output (VisImage): image object with visualizations.
|
546 |
+
"""
|
547 |
+
if isinstance(sem_seg, torch.Tensor):
|
548 |
+
sem_seg = sem_seg.numpy()
|
549 |
+
labels, areas = np.unique(sem_seg, return_counts=True)
|
550 |
+
sorted_idxs = np.argsort(-areas).tolist()
|
551 |
+
labels = labels[sorted_idxs]
|
552 |
+
for label in filter(lambda l: l < len(self.metadata.stuff_classes), labels):
|
553 |
+
try:
|
554 |
+
mask_color = [x / 255 for x in self.metadata.stuff_colors[label]]
|
555 |
+
except (AttributeError, IndexError):
|
556 |
+
mask_color = None
|
557 |
+
|
558 |
+
binary_mask = (sem_seg == label).astype(np.uint8)
|
559 |
+
text = self.metadata.stuff_classes[label]
|
560 |
+
self.draw_binary_mask(
|
561 |
+
binary_mask,
|
562 |
+
color=mask_color,
|
563 |
+
edge_color=_OFF_WHITE,
|
564 |
+
text=text,
|
565 |
+
alpha=alpha,
|
566 |
+
area_threshold=area_threshold,
|
567 |
+
is_text=is_text,
|
568 |
+
)
|
569 |
+
return self.output
|
570 |
+
|
571 |
+
def draw_panoptic_seg(self, panoptic_seg, segments_info, area_threshold=None, alpha=0.7, is_text=True,):
|
572 |
+
"""
|
573 |
+
Draw panoptic prediction annotations or results.
|
574 |
+
Args:
|
575 |
+
panoptic_seg (Tensor): of shape (height, width) where the values are ids for each
|
576 |
+
segment.
|
577 |
+
segments_info (list[dict] or None): Describe each segment in `panoptic_seg`.
|
578 |
+
If it is a ``list[dict]``, each dict contains keys "id", "category_id".
|
579 |
+
If None, category id of each pixel is computed by
|
580 |
+
``pixel // metadata.label_divisor``.
|
581 |
+
area_threshold (int): stuff segments with less than `area_threshold` are not drawn.
|
582 |
+
Returns:
|
583 |
+
output (VisImage): image object with visualizations.
|
584 |
+
"""
|
585 |
+
pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata)
|
586 |
+
|
587 |
+
if self._instance_mode == ColorMode.IMAGE_BW:
|
588 |
+
self.output.reset_image(self._create_grayscale_image(pred.non_empty_mask()))
|
589 |
+
|
590 |
+
# draw mask for all semantic segments first i.e. "stuff"
|
591 |
+
for mask, sinfo in pred.semantic_masks():
|
592 |
+
category_idx = sinfo["category_id"]
|
593 |
+
try:
|
594 |
+
mask_color = [x / 255 for x in self.metadata.stuff_colors[category_idx]]
|
595 |
+
except AttributeError:
|
596 |
+
mask_color = None
|
597 |
+
|
598 |
+
text = self.metadata.stuff_classes[category_idx]
|
599 |
+
self.draw_binary_mask(
|
600 |
+
mask,
|
601 |
+
color=mask_color,
|
602 |
+
edge_color=_OFF_WHITE,
|
603 |
+
text=text,
|
604 |
+
alpha=alpha,
|
605 |
+
area_threshold=area_threshold,
|
606 |
+
is_text=is_text,
|
607 |
+
)
|
608 |
+
|
609 |
+
# draw mask for all instances second
|
610 |
+
all_instances = list(pred.instance_masks())
|
611 |
+
if len(all_instances) == 0:
|
612 |
+
return self.output
|
613 |
+
masks, sinfo = list(zip(*all_instances))
|
614 |
+
category_ids = [x["category_id"] for x in sinfo]
|
615 |
+
|
616 |
+
try:
|
617 |
+
scores = [x["score"] for x in sinfo]
|
618 |
+
except KeyError:
|
619 |
+
scores = None
|
620 |
+
labels = _create_text_labels(
|
621 |
+
category_ids, scores, self.metadata.stuff_classes, [x.get("iscrowd", 0) for x in sinfo]
|
622 |
+
)
|
623 |
+
|
624 |
+
try:
|
625 |
+
colors = [
|
626 |
+
self._jitter([x / 255 for x in self.metadata.stuff_colors[c]]) for c in category_ids
|
627 |
+
]
|
628 |
+
except AttributeError:
|
629 |
+
colors = None
|
630 |
+
self.overlay_instances(masks=masks, labels=labels, assigned_colors=colors, alpha=alpha, is_text=is_text)
|
631 |
+
|
632 |
+
return self.output
|
633 |
+
|
634 |
+
draw_panoptic_seg_predictions = draw_panoptic_seg # backward compatibility
|
635 |
+
|
636 |
+
def draw_dataset_dict(self, dic):
|
637 |
+
"""
|
638 |
+
Draw annotations/segmentaions in Detectron2 Dataset format.
|
639 |
+
Args:
|
640 |
+
dic (dict): annotation/segmentation data of one image, in Detectron2 Dataset format.
|
641 |
+
Returns:
|
642 |
+
output (VisImage): image object with visualizations.
|
643 |
+
"""
|
644 |
+
annos = dic.get("annotations", None)
|
645 |
+
if annos:
|
646 |
+
if "segmentation" in annos[0]:
|
647 |
+
masks = [x["segmentation"] for x in annos]
|
648 |
+
else:
|
649 |
+
masks = None
|
650 |
+
if "keypoints" in annos[0]:
|
651 |
+
keypts = [x["keypoints"] for x in annos]
|
652 |
+
keypts = np.array(keypts).reshape(len(annos), -1, 3)
|
653 |
+
else:
|
654 |
+
keypts = None
|
655 |
+
|
656 |
+
boxes = [
|
657 |
+
BoxMode.convert(x["bbox"], x["bbox_mode"], BoxMode.XYXY_ABS)
|
658 |
+
if len(x["bbox"]) == 4
|
659 |
+
else x["bbox"]
|
660 |
+
for x in annos
|
661 |
+
]
|
662 |
+
|
663 |
+
colors = None
|
664 |
+
category_ids = [x["category_id"] for x in annos]
|
665 |
+
if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("stuff_colors"):
|
666 |
+
colors = [
|
667 |
+
self._jitter([x / 255 for x in self.metadata.stuff_colors[c]])
|
668 |
+
for c in category_ids
|
669 |
+
]
|
670 |
+
names = self.metadata.get("stuff_classes", None)
|
671 |
+
labels = _create_text_labels(
|
672 |
+
category_ids,
|
673 |
+
scores=None,
|
674 |
+
class_names=names,
|
675 |
+
is_crowd=[x.get("iscrowd", 0) for x in annos],
|
676 |
+
)
|
677 |
+
self.overlay_instances(
|
678 |
+
labels=labels, boxes=boxes, masks=masks, keypoints=keypts, assigned_colors=colors
|
679 |
+
)
|
680 |
+
|
681 |
+
sem_seg = dic.get("sem_seg", None)
|
682 |
+
if sem_seg is None and "sem_seg_file_name" in dic:
|
683 |
+
with PathManager.open(dic["sem_seg_file_name"], "rb") as f:
|
684 |
+
sem_seg = Image.open(f)
|
685 |
+
sem_seg = np.asarray(sem_seg, dtype="uint8")
|
686 |
+
if sem_seg is not None:
|
687 |
+
self.draw_sem_seg(sem_seg, area_threshold=0, alpha=0.5)
|
688 |
+
|
689 |
+
pan_seg = dic.get("pan_seg", None)
|
690 |
+
if pan_seg is None and "pan_seg_file_name" in dic:
|
691 |
+
with PathManager.open(dic["pan_seg_file_name"], "rb") as f:
|
692 |
+
pan_seg = Image.open(f)
|
693 |
+
pan_seg = np.asarray(pan_seg)
|
694 |
+
from panopticapi.utils import rgb2id
|
695 |
+
|
696 |
+
pan_seg = rgb2id(pan_seg)
|
697 |
+
if pan_seg is not None:
|
698 |
+
segments_info = dic["segments_info"]
|
699 |
+
pan_seg = torch.tensor(pan_seg)
|
700 |
+
self.draw_panoptic_seg(pan_seg, segments_info, area_threshold=0, alpha=0.5)
|
701 |
+
return self.output
|
702 |
+
|
703 |
+
def overlay_instances(
|
704 |
+
self,
|
705 |
+
*,
|
706 |
+
boxes=None,
|
707 |
+
labels=None,
|
708 |
+
masks=None,
|
709 |
+
keypoints=None,
|
710 |
+
assigned_colors=None,
|
711 |
+
alpha=0.5,
|
712 |
+
is_text=True,
|
713 |
+
):
|
714 |
+
"""
|
715 |
+
Args:
|
716 |
+
boxes (Boxes, RotatedBoxes or ndarray): either a :class:`Boxes`,
|
717 |
+
or an Nx4 numpy array of XYXY_ABS format for the N objects in a single image,
|
718 |
+
or a :class:`RotatedBoxes`,
|
719 |
+
or an Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format
|
720 |
+
for the N objects in a single image,
|
721 |
+
labels (list[str]): the text to be displayed for each instance.
|
722 |
+
masks (masks-like object): Supported types are:
|
723 |
+
* :class:`detectron2.structures.PolygonMasks`,
|
724 |
+
:class:`detectron2.structures.BitMasks`.
|
725 |
+
* list[list[ndarray]]: contains the segmentation masks for all objects in one image.
|
726 |
+
The first level of the list corresponds to individual instances. The second
|
727 |
+
level to all the polygon that compose the instance, and the third level
|
728 |
+
to the polygon coordinates. The third level should have the format of
|
729 |
+
[x0, y0, x1, y1, ..., xn, yn] (n >= 3).
|
730 |
+
* list[ndarray]: each ndarray is a binary mask of shape (H, W).
|
731 |
+
* list[dict]: each dict is a COCO-style RLE.
|
732 |
+
keypoints (Keypoint or array like): an array-like object of shape (N, K, 3),
|
733 |
+
where the N is the number of instances and K is the number of keypoints.
|
734 |
+
The last dimension corresponds to (x, y, visibility or score).
|
735 |
+
assigned_colors (list[matplotlib.colors]): a list of colors, where each color
|
736 |
+
corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
|
737 |
+
for full list of formats that the colors are accepted in.
|
738 |
+
Returns:
|
739 |
+
output (VisImage): image object with visualizations.
|
740 |
+
"""
|
741 |
+
num_instances = 0
|
742 |
+
if boxes is not None:
|
743 |
+
boxes = self._convert_boxes(boxes)
|
744 |
+
num_instances = len(boxes)
|
745 |
+
if masks is not None:
|
746 |
+
masks = self._convert_masks(masks)
|
747 |
+
if num_instances:
|
748 |
+
assert len(masks) == num_instances
|
749 |
+
else:
|
750 |
+
num_instances = len(masks)
|
751 |
+
if keypoints is not None:
|
752 |
+
if num_instances:
|
753 |
+
assert len(keypoints) == num_instances
|
754 |
+
else:
|
755 |
+
num_instances = len(keypoints)
|
756 |
+
keypoints = self._convert_keypoints(keypoints)
|
757 |
+
if labels is not None:
|
758 |
+
assert len(labels) == num_instances
|
759 |
+
if assigned_colors is None:
|
760 |
+
# assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]
|
761 |
+
assigned_colors = [instance_color(rgb=True, idx=i, maximum=1) for i in range(num_instances)]
|
762 |
+
if num_instances == 0:
|
763 |
+
return self.output
|
764 |
+
if boxes is not None and boxes.shape[1] == 5:
|
765 |
+
return self.overlay_rotated_instances(
|
766 |
+
boxes=boxes, labels=labels, assigned_colors=assigned_colors
|
767 |
+
)
|
768 |
+
|
769 |
+
# Display in largest to smallest order to reduce occlusion.
|
770 |
+
areas = None
|
771 |
+
if boxes is not None:
|
772 |
+
areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1)
|
773 |
+
elif masks is not None:
|
774 |
+
areas = np.asarray([x.area() for x in masks])
|
775 |
+
|
776 |
+
if areas is not None:
|
777 |
+
sorted_idxs = np.argsort(-areas).tolist()
|
778 |
+
# Re-order overlapped instances in descending order.
|
779 |
+
boxes = boxes[sorted_idxs] if boxes is not None else None
|
780 |
+
labels = [labels[k] for k in sorted_idxs] if labels is not None else None
|
781 |
+
masks = [masks[idx] for idx in sorted_idxs] if masks is not None else None
|
782 |
+
assigned_colors = [assigned_colors[idx] for idx in sorted_idxs]
|
783 |
+
keypoints = keypoints[sorted_idxs] if keypoints is not None else None
|
784 |
+
|
785 |
+
for i in range(num_instances):
|
786 |
+
color = assigned_colors[i]
|
787 |
+
if boxes is not None:
|
788 |
+
self.draw_box(boxes[i], edge_color=color)
|
789 |
+
|
790 |
+
if masks is not None:
|
791 |
+
for segment in masks[i].polygons:
|
792 |
+
self.draw_polygon(segment.reshape(-1, 2), color, alpha=alpha)
|
793 |
+
|
794 |
+
if labels is not None:
|
795 |
+
# first get a box
|
796 |
+
if boxes is not None:
|
797 |
+
x0, y0, x1, y1 = boxes[i]
|
798 |
+
text_pos = (x0, y0) # if drawing boxes, put text on the box corner.
|
799 |
+
horiz_align = "left"
|
800 |
+
elif masks is not None:
|
801 |
+
# skip small mask without polygon
|
802 |
+
if len(masks[i].polygons) == 0:
|
803 |
+
continue
|
804 |
+
|
805 |
+
x0, y0, x1, y1 = masks[i].bbox()
|
806 |
+
|
807 |
+
# draw text in the center (defined by median) when box is not drawn
|
808 |
+
# median is less sensitive to outliers.
|
809 |
+
text_pos = np.median(masks[i].mask.nonzero(), axis=1)[::-1]
|
810 |
+
horiz_align = "center"
|
811 |
+
else:
|
812 |
+
continue # drawing the box confidence for keypoints isn't very useful.
|
813 |
+
# for small objects, draw text at the side to avoid occlusion
|
814 |
+
instance_area = (y1 - y0) * (x1 - x0)
|
815 |
+
if (
|
816 |
+
instance_area < _SMALL_OBJECT_AREA_THRESH * self.output.scale
|
817 |
+
or y1 - y0 < 40 * self.output.scale
|
818 |
+
):
|
819 |
+
if y1 >= self.output.height - 5:
|
820 |
+
text_pos = (x1, y0)
|
821 |
+
else:
|
822 |
+
text_pos = (x0, y1)
|
823 |
+
|
824 |
+
height_ratio = (y1 - y0) / np.sqrt(self.output.height * self.output.width)
|
825 |
+
lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
|
826 |
+
font_size = (
|
827 |
+
np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)
|
828 |
+
* 0.5
|
829 |
+
* self._default_font_size
|
830 |
+
)
|
831 |
+
if is_text:
|
832 |
+
self.draw_text(
|
833 |
+
labels[i],
|
834 |
+
text_pos,
|
835 |
+
color=lighter_color,
|
836 |
+
horizontal_alignment=horiz_align,
|
837 |
+
font_size=font_size,
|
838 |
+
)
|
839 |
+
|
840 |
+
# draw keypoints
|
841 |
+
if keypoints is not None:
|
842 |
+
for keypoints_per_instance in keypoints:
|
843 |
+
self.draw_and_connect_keypoints(keypoints_per_instance)
|
844 |
+
|
845 |
+
return self.output
|
846 |
+
|
847 |
+
def overlay_rotated_instances(self, boxes=None, labels=None, assigned_colors=None):
|
848 |
+
"""
|
849 |
+
Args:
|
850 |
+
boxes (ndarray): an Nx5 numpy array of
|
851 |
+
(x_center, y_center, width, height, angle_degrees) format
|
852 |
+
for the N objects in a single image.
|
853 |
+
labels (list[str]): the text to be displayed for each instance.
|
854 |
+
assigned_colors (list[matplotlib.colors]): a list of colors, where each color
|
855 |
+
corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
|
856 |
+
for full list of formats that the colors are accepted in.
|
857 |
+
Returns:
|
858 |
+
output (VisImage): image object with visualizations.
|
859 |
+
"""
|
860 |
+
num_instances = len(boxes)
|
861 |
+
|
862 |
+
if assigned_colors is None:
|
863 |
+
# assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]
|
864 |
+
assigned_colors = [instance_color(rgb=True, idx=i, maximum=1) for i in range(num_instances)]
|
865 |
+
if num_instances == 0:
|
866 |
+
return self.output
|
867 |
+
|
868 |
+
# Display in largest to smallest order to reduce occlusion.
|
869 |
+
if boxes is not None:
|
870 |
+
areas = boxes[:, 2] * boxes[:, 3]
|
871 |
+
|
872 |
+
sorted_idxs = np.argsort(-areas).tolist()
|
873 |
+
# Re-order overlapped instances in descending order.
|
874 |
+
boxes = boxes[sorted_idxs]
|
875 |
+
labels = [labels[k] for k in sorted_idxs] if labels is not None else None
|
876 |
+
colors = [assigned_colors[idx] for idx in sorted_idxs]
|
877 |
+
|
878 |
+
for i in range(num_instances):
|
879 |
+
self.draw_rotated_box_with_label(
|
880 |
+
boxes[i], edge_color=colors[i], label=labels[i] if labels is not None else None
|
881 |
+
)
|
882 |
+
|
883 |
+
return self.output
|
884 |
+
|
885 |
+
def draw_and_connect_keypoints(self, keypoints):
|
886 |
+
"""
|
887 |
+
Draws keypoints of an instance and follows the rules for keypoint connections
|
888 |
+
to draw lines between appropriate keypoints. This follows color heuristics for
|
889 |
+
line color.
|
890 |
+
Args:
|
891 |
+
keypoints (Tensor): a tensor of shape (K, 3), where K is the number of keypoints
|
892 |
+
and the last dimension corresponds to (x, y, probability).
|
893 |
+
Returns:
|
894 |
+
output (VisImage): image object with visualizations.
|
895 |
+
"""
|
896 |
+
visible = {}
|
897 |
+
keypoint_names = self.metadata.get("keypoint_names")
|
898 |
+
for idx, keypoint in enumerate(keypoints):
|
899 |
+
|
900 |
+
# draw keypoint
|
901 |
+
x, y, prob = keypoint
|
902 |
+
if prob > self.keypoint_threshold:
|
903 |
+
self.draw_circle((x, y), color=_RED)
|
904 |
+
if keypoint_names:
|
905 |
+
keypoint_name = keypoint_names[idx]
|
906 |
+
visible[keypoint_name] = (x, y)
|
907 |
+
|
908 |
+
if self.metadata.get("keypoint_connection_rules"):
|
909 |
+
for kp0, kp1, color in self.metadata.keypoint_connection_rules:
|
910 |
+
if kp0 in visible and kp1 in visible:
|
911 |
+
x0, y0 = visible[kp0]
|
912 |
+
x1, y1 = visible[kp1]
|
913 |
+
color = tuple(x / 255.0 for x in color)
|
914 |
+
self.draw_line([x0, x1], [y0, y1], color=color)
|
915 |
+
|
916 |
+
# draw lines from nose to mid-shoulder and mid-shoulder to mid-hip
|
917 |
+
# Note that this strategy is specific to person keypoints.
|
918 |
+
# For other keypoints, it should just do nothing
|
919 |
+
try:
|
920 |
+
ls_x, ls_y = visible["left_shoulder"]
|
921 |
+
rs_x, rs_y = visible["right_shoulder"]
|
922 |
+
mid_shoulder_x, mid_shoulder_y = (ls_x + rs_x) / 2, (ls_y + rs_y) / 2
|
923 |
+
except KeyError:
|
924 |
+
pass
|
925 |
+
else:
|
926 |
+
# draw line from nose to mid-shoulder
|
927 |
+
nose_x, nose_y = visible.get("nose", (None, None))
|
928 |
+
if nose_x is not None:
|
929 |
+
self.draw_line([nose_x, mid_shoulder_x], [nose_y, mid_shoulder_y], color=_RED)
|
930 |
+
|
931 |
+
try:
|
932 |
+
# draw line from mid-shoulder to mid-hip
|
933 |
+
lh_x, lh_y = visible["left_hip"]
|
934 |
+
rh_x, rh_y = visible["right_hip"]
|
935 |
+
except KeyError:
|
936 |
+
pass
|
937 |
+
else:
|
938 |
+
mid_hip_x, mid_hip_y = (lh_x + rh_x) / 2, (lh_y + rh_y) / 2
|
939 |
+
self.draw_line([mid_hip_x, mid_shoulder_x], [mid_hip_y, mid_shoulder_y], color=_RED)
|
940 |
+
return self.output
|
941 |
+
|
942 |
+
"""
|
943 |
+
Primitive drawing functions:
|
944 |
+
"""
|
945 |
+
|
946 |
+
def draw_text(
|
947 |
+
self,
|
948 |
+
text,
|
949 |
+
position,
|
950 |
+
*,
|
951 |
+
font_size=None,
|
952 |
+
color="g",
|
953 |
+
horizontal_alignment="center",
|
954 |
+
rotation=0,
|
955 |
+
):
|
956 |
+
"""
|
957 |
+
Args:
|
958 |
+
text (str): class label
|
959 |
+
position (tuple): a tuple of the x and y coordinates to place text on image.
|
960 |
+
font_size (int, optional): font of the text. If not provided, a font size
|
961 |
+
proportional to the image width is calculated and used.
|
962 |
+
color: color of the text. Refer to `matplotlib.colors` for full list
|
963 |
+
of formats that are accepted.
|
964 |
+
horizontal_alignment (str): see `matplotlib.text.Text`
|
965 |
+
rotation: rotation angle in degrees CCW
|
966 |
+
Returns:
|
967 |
+
output (VisImage): image object with text drawn.
|
968 |
+
"""
|
969 |
+
if not font_size:
|
970 |
+
font_size = self._default_font_size
|
971 |
+
|
972 |
+
# since the text background is dark, we don't want the text to be dark
|
973 |
+
color = np.maximum(list(mplc.to_rgb(color)), 0.2)
|
974 |
+
color[np.argmax(color)] = max(0.8, np.max(color))
|
975 |
+
|
976 |
+
x, y = position
|
977 |
+
self.output.ax.text(
|
978 |
+
x,
|
979 |
+
y,
|
980 |
+
text,
|
981 |
+
size=font_size * self.output.scale,
|
982 |
+
family="sans-serif",
|
983 |
+
bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"},
|
984 |
+
verticalalignment="top",
|
985 |
+
horizontalalignment=horizontal_alignment,
|
986 |
+
color=color,
|
987 |
+
zorder=10,
|
988 |
+
rotation=rotation,
|
989 |
+
)
|
990 |
+
return self.output
|
991 |
+
|
992 |
+
def draw_box(self, box_coord, alpha=1.0, edge_color="g", line_style="-"):
|
993 |
+
"""
|
994 |
+
Args:
|
995 |
+
box_coord (tuple): a tuple containing x0, y0, x1, y1 coordinates, where x0 and y0
|
996 |
+
are the coordinates of the image's top left corner. x1 and y1 are the
|
997 |
+
coordinates of the image's bottom right corner.
|
998 |
+
alpha (float): blending efficient. Smaller values lead to more transparent masks.
|
999 |
+
edge_color: color of the outline of the box. Refer to `matplotlib.colors`
|
1000 |
+
for full list of formats that are accepted.
|
1001 |
+
line_style (string): the string to use to create the outline of the boxes.
|
1002 |
+
Returns:
|
1003 |
+
output (VisImage): image object with box drawn.
|
1004 |
+
"""
|
1005 |
+
x0, y0, x1, y1 = box_coord
|
1006 |
+
width = x1 - x0
|
1007 |
+
height = y1 - y0
|
1008 |
+
|
1009 |
+
linewidth = 2
|
1010 |
+
|
1011 |
+
self.output.ax.add_patch(
|
1012 |
+
mpl.patches.Rectangle(
|
1013 |
+
(x0, y0),
|
1014 |
+
width,
|
1015 |
+
height,
|
1016 |
+
fill=False,
|
1017 |
+
edgecolor=edge_color,
|
1018 |
+
linewidth=linewidth * self.output.scale,
|
1019 |
+
alpha=alpha,
|
1020 |
+
linestyle=line_style,
|
1021 |
+
)
|
1022 |
+
)
|
1023 |
+
return self.output
|
1024 |
+
|
1025 |
+
def draw_rotated_box_with_label(
|
1026 |
+
self, rotated_box, alpha=0.5, edge_color="g", line_style="-", label=None
|
1027 |
+
):
|
1028 |
+
"""
|
1029 |
+
Draw a rotated box with label on its top-left corner.
|
1030 |
+
Args:
|
1031 |
+
rotated_box (tuple): a tuple containing (cnt_x, cnt_y, w, h, angle),
|
1032 |
+
where cnt_x and cnt_y are the center coordinates of the box.
|
1033 |
+
w and h are the width and height of the box. angle represents how
|
1034 |
+
many degrees the box is rotated CCW with regard to the 0-degree box.
|
1035 |
+
alpha (float): blending efficient. Smaller values lead to more transparent masks.
|
1036 |
+
edge_color: color of the outline of the box. Refer to `matplotlib.colors`
|
1037 |
+
for full list of formats that are accepted.
|
1038 |
+
line_style (string): the string to use to create the outline of the boxes.
|
1039 |
+
label (string): label for rotated box. It will not be rendered when set to None.
|
1040 |
+
Returns:
|
1041 |
+
output (VisImage): image object with box drawn.
|
1042 |
+
"""
|
1043 |
+
cnt_x, cnt_y, w, h, angle = rotated_box
|
1044 |
+
area = w * h
|
1045 |
+
# use thinner lines when the box is small
|
1046 |
+
linewidth = self._default_font_size / (
|
1047 |
+
6 if area < _SMALL_OBJECT_AREA_THRESH * self.output.scale else 3
|
1048 |
+
)
|
1049 |
+
|
1050 |
+
theta = angle * math.pi / 180.0
|
1051 |
+
c = math.cos(theta)
|
1052 |
+
s = math.sin(theta)
|
1053 |
+
rect = [(-w / 2, h / 2), (-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2)]
|
1054 |
+
# x: left->right ; y: top->down
|
1055 |
+
rotated_rect = [(s * yy + c * xx + cnt_x, c * yy - s * xx + cnt_y) for (xx, yy) in rect]
|
1056 |
+
for k in range(4):
|
1057 |
+
j = (k + 1) % 4
|
1058 |
+
self.draw_line(
|
1059 |
+
[rotated_rect[k][0], rotated_rect[j][0]],
|
1060 |
+
[rotated_rect[k][1], rotated_rect[j][1]],
|
1061 |
+
color=edge_color,
|
1062 |
+
linestyle="--" if k == 1 else line_style,
|
1063 |
+
linewidth=linewidth,
|
1064 |
+
)
|
1065 |
+
|
1066 |
+
if label is not None:
|
1067 |
+
text_pos = rotated_rect[1] # topleft corner
|
1068 |
+
|
1069 |
+
height_ratio = h / np.sqrt(self.output.height * self.output.width)
|
1070 |
+
label_color = self._change_color_brightness(edge_color, brightness_factor=0.7)
|
1071 |
+
font_size = (
|
1072 |
+
np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) * 0.5 * self._default_font_size
|
1073 |
+
)
|
1074 |
+
self.draw_text(label, text_pos, color=label_color, font_size=font_size, rotation=angle)
|
1075 |
+
|
1076 |
+
return self.output
|
1077 |
+
|
1078 |
+
def draw_circle(self, circle_coord, color, radius=3):
|
1079 |
+
"""
|
1080 |
+
Args:
|
1081 |
+
circle_coord (list(int) or tuple(int)): contains the x and y coordinates
|
1082 |
+
of the center of the circle.
|
1083 |
+
color: color of the polygon. Refer to `matplotlib.colors` for a full list of
|
1084 |
+
formats that are accepted.
|
1085 |
+
radius (int): radius of the circle.
|
1086 |
+
Returns:
|
1087 |
+
output (VisImage): image object with box drawn.
|
1088 |
+
"""
|
1089 |
+
x, y = circle_coord
|
1090 |
+
self.output.ax.add_patch(
|
1091 |
+
mpl.patches.Circle(circle_coord, radius=radius, fill=True, color=color)
|
1092 |
+
)
|
1093 |
+
return self.output
|
1094 |
+
|
1095 |
+
def draw_line(self, x_data, y_data, color, linestyle="-", linewidth=None):
|
1096 |
+
"""
|
1097 |
+
Args:
|
1098 |
+
x_data (list[int]): a list containing x values of all the points being drawn.
|
1099 |
+
Length of list should match the length of y_data.
|
1100 |
+
y_data (list[int]): a list containing y values of all the points being drawn.
|
1101 |
+
Length of list should match the length of x_data.
|
1102 |
+
color: color of the line. Refer to `matplotlib.colors` for a full list of
|
1103 |
+
formats that are accepted.
|
1104 |
+
linestyle: style of the line. Refer to `matplotlib.lines.Line2D`
|
1105 |
+
for a full list of formats that are accepted.
|
1106 |
+
linewidth (float or None): width of the line. When it's None,
|
1107 |
+
a default value will be computed and used.
|
1108 |
+
Returns:
|
1109 |
+
output (VisImage): image object with line drawn.
|
1110 |
+
"""
|
1111 |
+
if linewidth is None:
|
1112 |
+
linewidth = self._default_font_size / 3
|
1113 |
+
linewidth = max(linewidth, 1)
|
1114 |
+
self.output.ax.add_line(
|
1115 |
+
mpl.lines.Line2D(
|
1116 |
+
x_data,
|
1117 |
+
y_data,
|
1118 |
+
linewidth=linewidth * self.output.scale,
|
1119 |
+
color=color,
|
1120 |
+
linestyle=linestyle,
|
1121 |
+
)
|
1122 |
+
)
|
1123 |
+
return self.output
|
1124 |
+
|
1125 |
+
def draw_binary_mask(
|
1126 |
+
self, binary_mask, color=None, *, edge_color=None, text=None, alpha=0.5, area_threshold=10, is_text=True,
|
1127 |
+
):
|
1128 |
+
"""
|
1129 |
+
Args:
|
1130 |
+
binary_mask (ndarray): numpy array of shape (H, W), where H is the image height and
|
1131 |
+
W is the image width. Each value in the array is either a 0 or 1 value of uint8
|
1132 |
+
type.
|
1133 |
+
color: color of the mask. Refer to `matplotlib.colors` for a full list of
|
1134 |
+
formats that are accepted. If None, will pick a random color.
|
1135 |
+
edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
|
1136 |
+
full list of formats that are accepted.
|
1137 |
+
text (str): if None, will be drawn on the object
|
1138 |
+
alpha (float): blending efficient. Smaller values lead to more transparent masks.
|
1139 |
+
area_threshold (float): a connected component smaller than this area will not be shown.
|
1140 |
+
Returns:
|
1141 |
+
output (VisImage): image object with mask drawn.
|
1142 |
+
"""
|
1143 |
+
if color is None:
|
1144 |
+
color = random_color(rgb=True, maximum=1)
|
1145 |
+
color = mplc.to_rgb(color)
|
1146 |
+
|
1147 |
+
has_valid_segment = False
|
1148 |
+
binary_mask = binary_mask.astype("uint8") # opencv needs uint8
|
1149 |
+
mask = GenericMask(binary_mask, self.output.height, self.output.width)
|
1150 |
+
shape2d = (binary_mask.shape[0], binary_mask.shape[1])
|
1151 |
+
|
1152 |
+
if not mask.has_holes:
|
1153 |
+
# draw polygons for regular masks
|
1154 |
+
for segment in mask.polygons:
|
1155 |
+
area = mask_util.area(mask_util.frPyObjects([segment], shape2d[0], shape2d[1]))
|
1156 |
+
if area < (area_threshold or 0):
|
1157 |
+
continue
|
1158 |
+
has_valid_segment = True
|
1159 |
+
segment = segment.reshape(-1, 2)
|
1160 |
+
self.draw_polygon(segment, color=color, edge_color=edge_color, alpha=alpha)
|
1161 |
+
else:
|
1162 |
+
# TODO: Use Path/PathPatch to draw vector graphics:
|
1163 |
+
# https://stackoverflow.com/questions/8919719/how-to-plot-a-complex-polygon
|
1164 |
+
rgba = np.zeros(shape2d + (4,), dtype="float32")
|
1165 |
+
rgba[:, :, :3] = color
|
1166 |
+
rgba[:, :, 3] = (mask.mask == 1).astype("float32") * alpha
|
1167 |
+
has_valid_segment = True
|
1168 |
+
self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0))
|
1169 |
+
|
1170 |
+
if is_text:
|
1171 |
+
if text is not None and has_valid_segment:
|
1172 |
+
lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
|
1173 |
+
self._draw_text_in_mask(binary_mask, text, lighter_color)
|
1174 |
+
return self.output
|
1175 |
+
|
1176 |
+
def draw_soft_mask(self, soft_mask, color=None, *, text=None, alpha=0.5):
|
1177 |
+
"""
|
1178 |
+
Args:
|
1179 |
+
soft_mask (ndarray): float array of shape (H, W), each value in [0, 1].
|
1180 |
+
color: color of the mask. Refer to `matplotlib.colors` for a full list of
|
1181 |
+
formats that are accepted. If None, will pick a random color.
|
1182 |
+
text (str): if None, will be drawn on the object
|
1183 |
+
alpha (float): blending efficient. Smaller values lead to more transparent masks.
|
1184 |
+
Returns:
|
1185 |
+
output (VisImage): image object with mask drawn.
|
1186 |
+
"""
|
1187 |
+
if color is None:
|
1188 |
+
color = random_color(rgb=True, maximum=1)
|
1189 |
+
color = mplc.to_rgb(color)
|
1190 |
+
|
1191 |
+
shape2d = (soft_mask.shape[0], soft_mask.shape[1])
|
1192 |
+
rgba = np.zeros(shape2d + (4,), dtype="float32")
|
1193 |
+
rgba[:, :, :3] = color
|
1194 |
+
rgba[:, :, 3] = soft_mask * alpha
|
1195 |
+
self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0))
|
1196 |
+
|
1197 |
+
if text is not None:
|
1198 |
+
lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
|
1199 |
+
binary_mask = (soft_mask > 0.5).astype("uint8")
|
1200 |
+
# self._draw_text_in_mask(binary_mask, text, lighter_color)
|
1201 |
+
return self.output
|
1202 |
+
|
1203 |
+
def draw_polygon(self, segment, color, edge_color=None, alpha=0.5):
|
1204 |
+
"""
|
1205 |
+
Args:
|
1206 |
+
segment: numpy array of shape Nx2, containing all the points in the polygon.
|
1207 |
+
color: color of the polygon. Refer to `matplotlib.colors` for a full list of
|
1208 |
+
formats that are accepted.
|
1209 |
+
edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a
|
1210 |
+
full list of formats that are accepted. If not provided, a darker shade
|
1211 |
+
of the polygon color will be used instead.
|
1212 |
+
alpha (float): blending efficient. Smaller values lead to more transparent masks.
|
1213 |
+
Returns:
|
1214 |
+
output (VisImage): image object with polygon drawn.
|
1215 |
+
"""
|
1216 |
+
if edge_color is None:
|
1217 |
+
# make edge color darker than the polygon color
|
1218 |
+
if alpha > 0.8:
|
1219 |
+
edge_color = self._change_color_brightness(color, brightness_factor=-0.7)
|
1220 |
+
else:
|
1221 |
+
edge_color = color
|
1222 |
+
edge_color = mplc.to_rgb(edge_color) + (1,)
|
1223 |
+
|
1224 |
+
polygon = mpl.patches.Polygon(
|
1225 |
+
segment,
|
1226 |
+
fill=True,
|
1227 |
+
facecolor=mplc.to_rgb(color) + (alpha,),
|
1228 |
+
edgecolor=edge_color,
|
1229 |
+
linewidth=max(self._default_font_size // 15 * self.output.scale, 1),
|
1230 |
+
)
|
1231 |
+
self.output.ax.add_patch(polygon)
|
1232 |
+
return self.output
|
1233 |
+
|
1234 |
+
"""
|
1235 |
+
Internal methods:
|
1236 |
+
"""
|
1237 |
+
|
1238 |
+
def _jitter(self, color):
|
1239 |
+
"""
|
1240 |
+
Randomly modifies given color to produce a slightly different color than the color given.
|
1241 |
+
Args:
|
1242 |
+
color (tuple[double]): a tuple of 3 elements, containing the RGB values of the color
|
1243 |
+
picked. The values in the list are in the [0.0, 1.0] range.
|
1244 |
+
Returns:
|
1245 |
+
jittered_color (tuple[double]): a tuple of 3 elements, containing the RGB values of the
|
1246 |
+
color after being jittered. The values in the list are in the [0.0, 1.0] range.
|
1247 |
+
"""
|
1248 |
+
color = mplc.to_rgb(color)
|
1249 |
+
vec = np.random.rand(3)
|
1250 |
+
# better to do it in another color space
|
1251 |
+
vec = vec / np.linalg.norm(vec) * 0.5
|
1252 |
+
res = np.clip(vec + color, 0, 1)
|
1253 |
+
return tuple(res)
|
1254 |
+
|
1255 |
+
def _create_grayscale_image(self, mask=None):
|
1256 |
+
"""
|
1257 |
+
Create a grayscale version of the original image.
|
1258 |
+
The colors in masked area, if given, will be kept.
|
1259 |
+
"""
|
1260 |
+
img_bw = self.img.astype("f4").mean(axis=2)
|
1261 |
+
img_bw = np.stack([img_bw] * 3, axis=2)
|
1262 |
+
if mask is not None:
|
1263 |
+
img_bw[mask] = self.img[mask]
|
1264 |
+
return img_bw
|
1265 |
+
|
1266 |
+
def _change_color_brightness(self, color, brightness_factor):
|
1267 |
+
"""
|
1268 |
+
Depending on the brightness_factor, gives a lighter or darker color i.e. a color with
|
1269 |
+
less or more saturation than the original color.
|
1270 |
+
Args:
|
1271 |
+
color: color of the polygon. Refer to `matplotlib.colors` for a full list of
|
1272 |
+
formats that are accepted.
|
1273 |
+
brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of
|
1274 |
+
0 will correspond to no change, a factor in [-1.0, 0) range will result in
|
1275 |
+
a darker color and a factor in (0, 1.0] range will result in a lighter color.
|
1276 |
+
Returns:
|
1277 |
+
modified_color (tuple[double]): a tuple containing the RGB values of the
|
1278 |
+
modified color. Each value in the tuple is in the [0.0, 1.0] range.
|
1279 |
+
"""
|
1280 |
+
assert brightness_factor >= -1.0 and brightness_factor <= 1.0
|
1281 |
+
color = mplc.to_rgb(color)
|
1282 |
+
polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
|
1283 |
+
modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1])
|
1284 |
+
modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
|
1285 |
+
modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
|
1286 |
+
modified_color = colorsys.hls_to_rgb(polygon_color[0], modified_lightness, polygon_color[2])
|
1287 |
+
return modified_color
|
1288 |
+
|
1289 |
+
def _convert_boxes(self, boxes):
|
1290 |
+
"""
|
1291 |
+
Convert different format of boxes to an NxB array, where B = 4 or 5 is the box dimension.
|
1292 |
+
"""
|
1293 |
+
if isinstance(boxes, Boxes) or isinstance(boxes, RotatedBoxes):
|
1294 |
+
return boxes.tensor.detach().numpy()
|
1295 |
+
else:
|
1296 |
+
return np.asarray(boxes)
|
1297 |
+
|
1298 |
+
def _convert_masks(self, masks_or_polygons):
|
1299 |
+
"""
|
1300 |
+
Convert different format of masks or polygons to a tuple of masks and polygons.
|
1301 |
+
Returns:
|
1302 |
+
list[GenericMask]:
|
1303 |
+
"""
|
1304 |
+
|
1305 |
+
m = masks_or_polygons
|
1306 |
+
if isinstance(m, PolygonMasks):
|
1307 |
+
m = m.polygons
|
1308 |
+
if isinstance(m, BitMasks):
|
1309 |
+
m = m.tensor.numpy()
|
1310 |
+
if isinstance(m, torch.Tensor):
|
1311 |
+
m = m.numpy()
|
1312 |
+
ret = []
|
1313 |
+
for x in m:
|
1314 |
+
if isinstance(x, GenericMask):
|
1315 |
+
ret.append(x)
|
1316 |
+
else:
|
1317 |
+
ret.append(GenericMask(x, self.output.height, self.output.width))
|
1318 |
+
return ret
|
1319 |
+
|
1320 |
+
def _draw_text_in_mask(self, binary_mask, text, color):
|
1321 |
+
"""
|
1322 |
+
Find proper places to draw text given a binary mask.
|
1323 |
+
"""
|
1324 |
+
# TODO sometimes drawn on wrong objects. the heuristics here can improve.
|
1325 |
+
_num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask, 8)
|
1326 |
+
if stats[1:, -1].size == 0:
|
1327 |
+
return
|
1328 |
+
largest_component_id = np.argmax(stats[1:, -1]) + 1
|
1329 |
+
|
1330 |
+
# draw text on the largest component, as well as other very large components.
|
1331 |
+
for cid in range(1, _num_cc):
|
1332 |
+
if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH:
|
1333 |
+
# median is more stable than centroid
|
1334 |
+
# center = centroids[largest_component_id]
|
1335 |
+
center = np.median((cc_labels == cid).nonzero(), axis=1)[::-1]
|
1336 |
+
self.draw_text(text, center, color=color)
|
1337 |
+
|
1338 |
+
def _convert_keypoints(self, keypoints):
|
1339 |
+
if isinstance(keypoints, Keypoints):
|
1340 |
+
keypoints = keypoints.tensor
|
1341 |
+
keypoints = np.asarray(keypoints)
|
1342 |
+
return keypoints
|
1343 |
+
|
1344 |
+
def get_output(self):
|
1345 |
+
"""
|
1346 |
+
Returns:
|
1347 |
+
output (VisImage): the image output containing the visualizations added
|
1348 |
+
to the image.
|
1349 |
+
"""
|
1350 |
+
return self.output
|
gradio_app.py
ADDED
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
print("Installed the dependencies!")
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image
|
7 |
+
import cv2
|
8 |
+
import imutils
|
9 |
+
import os
|
10 |
+
import sys
|
11 |
+
import time
|
12 |
+
from detectron2.config import get_cfg
|
13 |
+
from detectron2.projects.deeplab import add_deeplab_config
|
14 |
+
from detectron2.data import MetadataCatalog
|
15 |
+
torch.set_num_threads(16) # Use 16 CPU threads
|
16 |
+
torch.set_num_interop_threads(16) # Inter-op parallelism
|
17 |
+
from oneformer import (
|
18 |
+
add_oneformer_config,
|
19 |
+
add_common_config,
|
20 |
+
add_swin_config,
|
21 |
+
add_dinat_config,
|
22 |
+
)
|
23 |
+
|
24 |
+
from demo.defaults import DefaultPredictor
|
25 |
+
from demo.visualizer import Visualizer, ColorMode
|
26 |
+
|
27 |
+
import gradio as gr
|
28 |
+
from huggingface_hub import hf_hub_download
|
29 |
+
|
30 |
+
KEY_DICT = {"Cityscapes (19 classes)": "cityscapes",
|
31 |
+
"COCO (133 classes)": "coco",
|
32 |
+
"ADE20K (150 classes)": "ade20k",}
|
33 |
+
|
34 |
+
SWIN_CFG_DICT = {"cityscapes": "configs/cityscapes/oneformer_swin_large_IN21k_384_bs16_90k.yaml",
|
35 |
+
"coco": "configs/coco/oneformer_swin_large_IN21k_384_bs16_100ep.yaml",
|
36 |
+
"ade20k": "configs/ade20k/oneformer_swin_large_IN21k_384_bs16_160k.yaml",}
|
37 |
+
|
38 |
+
SWIN_MODEL_DICT = {"cityscapes": hf_hub_download(repo_id="shi-labs/oneformer_cityscapes_swin_large",
|
39 |
+
filename="250_16_swin_l_oneformer_cityscapes_90k.pth"),
|
40 |
+
"coco": hf_hub_download(repo_id="shi-labs/oneformer_coco_swin_large",
|
41 |
+
filename="150_16_swin_l_oneformer_coco_100ep.pth"),
|
42 |
+
"ade20k": hf_hub_download(repo_id="shi-labs/oneformer_ade20k_swin_large",
|
43 |
+
filename="250_16_swin_l_oneformer_ade20k_160k.pth")
|
44 |
+
}
|
45 |
+
|
46 |
+
DINAT_CFG_DICT = {"cityscapes": "configs/cityscapes/oneformer_dinat_large_bs16_90k.yaml",
|
47 |
+
"coco": "configs/coco/oneformer_dinat_large_bs16_100ep.yaml",
|
48 |
+
"ade20k": "configs/ade20k/oneformer_dinat_large_IN21k_384_bs16_160k.yaml",}
|
49 |
+
|
50 |
+
DINAT_MODEL_DICT = {"cityscapes": hf_hub_download(repo_id="shi-labs/oneformer_cityscapes_dinat_large",
|
51 |
+
filename="250_16_dinat_l_oneformer_cityscapes_90k.pth"),
|
52 |
+
"coco": hf_hub_download(repo_id="shi-labs/oneformer_coco_dinat_large",
|
53 |
+
filename="150_16_dinat_l_oneformer_coco_100ep.pth"),
|
54 |
+
"ade20k": hf_hub_download(repo_id="shi-labs/oneformer_ade20k_dinat_large",
|
55 |
+
filename="250_16_dinat_l_oneformer_ade20k_160k.pth")
|
56 |
+
}
|
57 |
+
|
58 |
+
MODEL_DICT = {"DiNAT-L": DINAT_MODEL_DICT,
|
59 |
+
"Swin-L": SWIN_MODEL_DICT }
|
60 |
+
|
61 |
+
CFG_DICT = {"DiNAT-L": DINAT_CFG_DICT,
|
62 |
+
"Swin-L": SWIN_CFG_DICT }
|
63 |
+
|
64 |
+
WIDTH_DICT = {"cityscapes": 512,
|
65 |
+
"coco": 512,
|
66 |
+
"ade20k": 640}
|
67 |
+
|
68 |
+
cpu_device = torch.device("cpu")
|
69 |
+
|
70 |
+
PREDICTORS = {
|
71 |
+
"DiNAT-L": {
|
72 |
+
"Cityscapes (19 classes)": None,
|
73 |
+
"COCO (133 classes)": None,
|
74 |
+
"ADE20K (150 classes)": None
|
75 |
+
},
|
76 |
+
"Swin-L": {
|
77 |
+
"Cityscapes (19 classes)": None,
|
78 |
+
"COCO (133 classes)": None,
|
79 |
+
"ADE20K (150 classes)": None
|
80 |
+
}
|
81 |
+
}
|
82 |
+
|
83 |
+
METADATA = {
|
84 |
+
"DiNAT-L": {
|
85 |
+
"Cityscapes (19 classes)": None,
|
86 |
+
"COCO (133 classes)": None,
|
87 |
+
"ADE20K (150 classes)": None
|
88 |
+
},
|
89 |
+
"Swin-L": {
|
90 |
+
"Cityscapes (19 classes)": None,
|
91 |
+
"COCO (133 classes)": None,
|
92 |
+
"ADE20K (150 classes)": None
|
93 |
+
}
|
94 |
+
}
|
95 |
+
|
96 |
+
def setup_modules():
|
97 |
+
for dataset in ["Cityscapes (19 classes)", "COCO (133 classes)", "ADE20K (150 classes)"]:
|
98 |
+
for backbone in ["DiNAT-L", "Swin-L"]:
|
99 |
+
cfg = setup_cfg(dataset, backbone)
|
100 |
+
metadata = MetadataCatalog.get(
|
101 |
+
cfg.DATASETS.TEST_PANOPTIC[0] if len(cfg.DATASETS.TEST_PANOPTIC) else "__unused"
|
102 |
+
)
|
103 |
+
if 'cityscapes_fine_sem_seg_val' in cfg.DATASETS.TEST_PANOPTIC[0]:
|
104 |
+
from cityscapesscripts.helpers.labels import labels
|
105 |
+
stuff_colors = [k.color for k in labels if k.trainId != 255]
|
106 |
+
metadata = metadata.set(stuff_colors=stuff_colors)
|
107 |
+
PREDICTORS[backbone][dataset] = DefaultPredictor(cfg)
|
108 |
+
METADATA[backbone][dataset] = metadata
|
109 |
+
|
110 |
+
def setup_cfg(dataset, backbone):
|
111 |
+
# load config from file and command-line arguments
|
112 |
+
cfg = get_cfg()
|
113 |
+
add_deeplab_config(cfg)
|
114 |
+
add_common_config(cfg)
|
115 |
+
add_swin_config(cfg)
|
116 |
+
add_oneformer_config(cfg)
|
117 |
+
add_dinat_config(cfg)
|
118 |
+
dataset = KEY_DICT[dataset]
|
119 |
+
cfg_path = CFG_DICT[backbone][dataset]
|
120 |
+
cfg.merge_from_file(cfg_path)
|
121 |
+
if torch.cuda.is_available():
|
122 |
+
cfg.MODEL.DEVICE = 'cuda'
|
123 |
+
else:
|
124 |
+
cfg.MODEL.DEVICE = 'cpu'
|
125 |
+
cfg.MODEL.WEIGHTS = MODEL_DICT[backbone][dataset]
|
126 |
+
cfg.freeze()
|
127 |
+
return cfg
|
128 |
+
|
129 |
+
# def setup_modules(dataset, backbone):
|
130 |
+
# cfg = setup_cfg(dataset, backbone)
|
131 |
+
# predictor = DefaultPredictor(cfg)
|
132 |
+
# # predictor = PREDICTORS[backbone][dataset]
|
133 |
+
# metadata = MetadataCatalog.get(
|
134 |
+
# cfg.DATASETS.TEST_PANOPTIC[0] if len(cfg.DATASETS.TEST_PANOPTIC) else "__unused"
|
135 |
+
# )
|
136 |
+
# if 'cityscapes_fine_sem_seg_val' in cfg.DATASETS.TEST_PANOPTIC[0]:
|
137 |
+
# from cityscapesscripts.helpers.labels import labels
|
138 |
+
# stuff_colors = [k.color for k in labels if k.trainId != 255]
|
139 |
+
# metadata = metadata.set(stuff_colors=stuff_colors)
|
140 |
+
|
141 |
+
# return predictor, metadata
|
142 |
+
|
143 |
+
def panoptic_run(img, predictor, metadata):
|
144 |
+
visualizer = Visualizer(img[:, :, ::-1], metadata=metadata, instance_mode=ColorMode.IMAGE)
|
145 |
+
predictions = predictor(img, "panoptic")
|
146 |
+
panoptic_seg, segments_info = predictions["panoptic_seg"]
|
147 |
+
out = visualizer.draw_panoptic_seg_predictions(
|
148 |
+
panoptic_seg.to(cpu_device), segments_info, alpha=0.5
|
149 |
+
)
|
150 |
+
visualizer_map = Visualizer(img[:, :, ::-1], is_img=False, metadata=metadata, instance_mode=ColorMode.IMAGE)
|
151 |
+
out_map = visualizer_map.draw_panoptic_seg_predictions(
|
152 |
+
panoptic_seg.to(cpu_device), segments_info, alpha=1, is_text=False
|
153 |
+
)
|
154 |
+
return out, out_map
|
155 |
+
|
156 |
+
def instance_run(img, predictor, metadata):
|
157 |
+
visualizer = Visualizer(img[:, :, ::-1], metadata=metadata, instance_mode=ColorMode.IMAGE)
|
158 |
+
predictions = predictor(img, "instance")
|
159 |
+
instances = predictions["instances"].to(cpu_device)
|
160 |
+
out = visualizer.draw_instance_predictions(predictions=instances, alpha=0.5)
|
161 |
+
visualizer_map = Visualizer(img[:, :, ::-1], is_img=False, metadata=metadata, instance_mode=ColorMode.IMAGE)
|
162 |
+
out_map = visualizer_map.draw_instance_predictions(predictions=instances, alpha=1, is_text=False)
|
163 |
+
return out, out_map
|
164 |
+
|
165 |
+
def semantic_run(img, predictor, metadata):
|
166 |
+
visualizer = Visualizer(img[:, :, ::-1], metadata=metadata, instance_mode=ColorMode.IMAGE)
|
167 |
+
predictions = predictor(img, "semantic")
|
168 |
+
out = visualizer.draw_sem_seg(
|
169 |
+
predictions["sem_seg"].argmax(dim=0).to(cpu_device), alpha=0.5
|
170 |
+
)
|
171 |
+
visualizer_map = Visualizer(img[:, :, ::-1], is_img=False, metadata=metadata, instance_mode=ColorMode.IMAGE)
|
172 |
+
out_map = visualizer_map.draw_sem_seg(
|
173 |
+
predictions["sem_seg"].argmax(dim=0).to(cpu_device), alpha=1, is_text=False
|
174 |
+
)
|
175 |
+
return out, out_map
|
176 |
+
|
177 |
+
TASK_INFER = {"the task is panoptic": panoptic_run, "the task is instance": instance_run, "the task is semantic": semantic_run}
|
178 |
+
|
179 |
+
def segment(path, task, dataset, backbone):
|
180 |
+
# predictor, metadata = setup_modules(dataset, backbone)
|
181 |
+
predictor = PREDICTORS[backbone][dataset]
|
182 |
+
metadata = METADATA[backbone][dataset]
|
183 |
+
img = cv2.imread(path)
|
184 |
+
width = WIDTH_DICT[KEY_DICT[dataset]]
|
185 |
+
img = imutils.resize(img, width=width)
|
186 |
+
out, out_map = TASK_INFER[task](img, predictor, metadata)
|
187 |
+
out = Image.fromarray(out.get_image())
|
188 |
+
out_map = Image.fromarray(out_map.get_image())
|
189 |
+
return out, out_map
|
190 |
+
|
191 |
+
title = "<h1 style='text-align: center'>OneFormer:DIEGO MENTORIA MILIONÁRIA - APP 1</h1>"
|
192 |
+
# style='margin-bottom: -10px;
|
193 |
+
description = "<p style='font-size: 14px; margin: 5px; font-weight: w300; text-align: center'> <a href='https://praeclarumjj3.github.io/' style='text-decoration:none' target='_blank'>Jitesh Jain, </a> <a href='https://chrisjuniorli.github.io/' style='text-decoration:none' target='_blank'>Jiachen Li<sup>*</sup>, </a> <a href='https://www.linkedin.com/in/mtchiu/' style='text-decoration:none' target='_blank'>MangTik Chiu<sup>*</sup>, </a> <a href='https://alihassanijr.com/' style='text-decoration:none' target='_blank'>Ali Hassani, </a> <a href='https://www.linkedin.com/in/nukich74/' style='text-decoration:none' target='_blank'>Nikita Orlov, </a> <a href='https://www.humphreyshi.com/home' style='text-decoration:none' target='_blank'>Humphrey Shi</a></p>" \
|
194 |
+
+ "<p style='font-size: 16px; margin: 5px; font-weight: w600; text-align: center'> <a href='https://praeclarumjj3.github.io/oneformer/' target='_blank'>Project Page</a> | <a href='https://arxiv.org/abs/2211.06220' target='_blank'>ArXiv Paper</a> | <a href='https://github.com/SHI-Labs/OneFormer' target='_blank'>Github Repo</a></p>" \
|
195 |
+
+ "<p style='text-align: center; margin: 5px; font-size: 14px; font-weight: w300;'> \
|
196 |
+
OneFormer is the first multi-task universal image segmentation framework based on transformers. Our single OneFormer model achieves state-of-the-art performance across all three segmentation tasks with a single task-conditioned joint training process. OneFormer uses a task token to condition the model on the task in focus, making our architecture task-guided for training, and task-dynamic for inference, all with a single model. We believe OneFormer is a significant step towards making image segmentation more universal and accessible.\
|
197 |
+
</p>" \
|
198 |
+
+ "<p style='text-align: center; font-size: 14px; margin: 5px; font-weight: w300;'> [Note: Inference on CPU may take upto 2 minutes. On a single RTX A6000 GPU, OneFormer is able to inference at more than 15 FPS.]</p>"
|
199 |
+
|
200 |
+
setup_modules()
|
201 |
+
|
202 |
+
gradio_inputs = [gr.Image(label="Input Image",type="filepath"),
|
203 |
+
gr.Radio(choices=["the task is panoptic" ,"the task is instance", "the task is semantic"], type="value", value="the task is panoptic", label="Task Token Input"),
|
204 |
+
gr.Radio(choices=["COCO (133 classes)" ,"Cityscapes (19 classes)", "ADE20K (150 classes)"], type="value", value="COCO (133 classes)", label="Model"),
|
205 |
+
gr.Radio(choices=["DiNAT-L" ,"Swin-L"], type="value", value="DiNAT-L", label="Backbone"),
|
206 |
+
]
|
207 |
+
gradio_outputs = [gr.Image(type="pil", label="Segmentation Overlay"), gr.Image(type="pil", label="Segmentation Map")]
|
208 |
+
|
209 |
+
|
210 |
+
examples = [["examples/coco.jpeg", "the task is panoptic", "COCO (133 classes)", "DiNAT-L"],
|
211 |
+
["examples/cityscapes.png", "the task is panoptic", "Cityscapes (19 classes)", "DiNAT-L"],
|
212 |
+
["examples/ade20k.jpeg", "the task is panoptic", "ADE20K (150 classes)", "DiNAT-L"]]
|
213 |
+
|
214 |
+
|
215 |
+
iface = gr.Interface(fn=segment, inputs=gradio_inputs,
|
216 |
+
outputs=gradio_outputs,
|
217 |
+
examples_per_page=5,
|
218 |
+
allow_flagging="never",
|
219 |
+
examples=examples, title=title,
|
220 |
+
description=description)
|
221 |
+
|
222 |
+
iface.launch(server_name="0.0.0.0",share=True)
|
gradio_combine.py
ADDED
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image
|
4 |
+
import cv2
|
5 |
+
import imutils
|
6 |
+
import os
|
7 |
+
import sys
|
8 |
+
import time
|
9 |
+
import colorsys
|
10 |
+
from scipy import ndimage
|
11 |
+
import gradio as gr
|
12 |
+
from huggingface_hub import hf_hub_download
|
13 |
+
|
14 |
+
# Detectron2 imports
|
15 |
+
from detectron2.config import get_cfg
|
16 |
+
from detectron2.projects.deeplab import add_deeplab_config
|
17 |
+
from detectron2.data import MetadataCatalog
|
18 |
+
from detectron2.engine import DefaultPredictor as DetectronPredictor
|
19 |
+
from detectron2 import model_zoo
|
20 |
+
|
21 |
+
# OneFormer imports
|
22 |
+
from oneformer import (
|
23 |
+
add_oneformer_config,
|
24 |
+
add_common_config,
|
25 |
+
add_swin_config,
|
26 |
+
add_dinat_config,
|
27 |
+
)
|
28 |
+
from demo.defaults import DefaultPredictor as OneFormerPredictor
|
29 |
+
from demo.visualizer import Visualizer, ColorMode
|
30 |
+
|
31 |
+
# NeuroNest contrast detection imports
|
32 |
+
from utils.contrast_detector import ContrastDetector
|
33 |
+
from utils.luminance_contrast import LuminanceContrastDetector
|
34 |
+
from utils.hue_contrast import HueContrastDetector
|
35 |
+
from utils.saturation_contrast import SaturationContrastDetector
|
36 |
+
from utils.combined_contrast import CombinedContrastDetector
|
37 |
+
|
38 |
+
# Set threads for CPU optimization
|
39 |
+
torch.set_num_threads(4)
|
40 |
+
|
41 |
+
########################################
|
42 |
+
# GLOBAL CONFIGURATIONS
|
43 |
+
########################################
|
44 |
+
|
45 |
+
# OneFormer configurations
|
46 |
+
KEY_DICT = {"ADE20K (150 classes)": "ade20k"}
|
47 |
+
|
48 |
+
SWIN_CFG_DICT = {
|
49 |
+
"ade20k": "configs/ade20k/oneformer_swin_large_IN21k_384_bs16_160k.yaml",
|
50 |
+
}
|
51 |
+
|
52 |
+
SWIN_MODEL_DICT = {
|
53 |
+
"ade20k": hf_hub_download(
|
54 |
+
repo_id="shi-labs/oneformer_ade20k_swin_large",
|
55 |
+
filename="250_16_swin_l_oneformer_ade20k_160k.pth"
|
56 |
+
)
|
57 |
+
}
|
58 |
+
|
59 |
+
DINAT_CFG_DICT = {
|
60 |
+
"ade20k": "configs/ade20k/oneformer_dinat_large_IN21k_384_bs16_160k.yaml",
|
61 |
+
}
|
62 |
+
|
63 |
+
DINAT_MODEL_DICT = {
|
64 |
+
"ade20k": hf_hub_download(
|
65 |
+
repo_id="shi-labs/oneformer_ade20k_dinat_large",
|
66 |
+
filename="250_16_dinat_l_oneformer_ade20k_160k.pth"
|
67 |
+
)
|
68 |
+
}
|
69 |
+
|
70 |
+
MODEL_DICT = {"DiNAT-L": DINAT_MODEL_DICT, "Swin-L": SWIN_MODEL_DICT}
|
71 |
+
CFG_DICT = {"DiNAT-L": DINAT_CFG_DICT, "Swin-L": SWIN_CFG_DICT}
|
72 |
+
WIDTH_DICT = {"ade20k": 640}
|
73 |
+
|
74 |
+
# Contrast detector mapping
|
75 |
+
CONTRAST_DETECTORS = {
|
76 |
+
"Luminance (WCAG)": LuminanceContrastDetector(),
|
77 |
+
"Hue": HueContrastDetector(),
|
78 |
+
"Saturation": SaturationContrastDetector(),
|
79 |
+
"Combined": CombinedContrastDetector()
|
80 |
+
}
|
81 |
+
|
82 |
+
# Device configuration
|
83 |
+
cpu_device = torch.device("cpu")
|
84 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
85 |
+
print(f"Using device: {device}")
|
86 |
+
|
87 |
+
# Model storage
|
88 |
+
ONEFORMER_PREDICTORS = {
|
89 |
+
"DiNAT-L": {"ADE20K (150 classes)": None},
|
90 |
+
"Swin-L": {"ADE20K (150 classes)": None}
|
91 |
+
}
|
92 |
+
|
93 |
+
ONEFORMER_METADATA = {
|
94 |
+
"DiNAT-L": {"ADE20K (150 classes)": None},
|
95 |
+
"Swin-L": {"ADE20K (150 classes)": None}
|
96 |
+
}
|
97 |
+
|
98 |
+
########################################
|
99 |
+
# MASK R-CNN SETUP AND FUNCTIONS
|
100 |
+
########################################
|
101 |
+
|
102 |
+
def load_maskrcnn_model(weights_path, device="cuda", threshold=0.5):
|
103 |
+
"""Load Mask R-CNN model for blackspot detection"""
|
104 |
+
cfg = get_cfg()
|
105 |
+
cfg.merge_from_file(
|
106 |
+
model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
|
107 |
+
)
|
108 |
+
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2 # [Floors, blackspot]
|
109 |
+
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = threshold
|
110 |
+
cfg.MODEL.WEIGHTS = weights_path
|
111 |
+
# Fix: Convert torch.device to string
|
112 |
+
cfg.MODEL.DEVICE = str(device) if isinstance(device, torch.device) else device
|
113 |
+
return DetectronPredictor(cfg)
|
114 |
+
def postprocess_blackspot_masks(im, instances, show_floor=True, show_blackspot=True):
|
115 |
+
"""Extract floor and blackspot masks from Mask R-CNN predictions"""
|
116 |
+
height, width = im.shape[:2]
|
117 |
+
pred_classes = instances.pred_classes.cpu().numpy()
|
118 |
+
pred_masks = instances.pred_masks.cpu().numpy()
|
119 |
+
|
120 |
+
combined_floor_mask = np.zeros((height, width), dtype=bool)
|
121 |
+
final_blackspot = np.zeros((height, width), dtype=bool)
|
122 |
+
|
123 |
+
for cls_id, mask in zip(pred_classes, pred_masks):
|
124 |
+
if cls_id == 0 and show_floor: # Floor class
|
125 |
+
combined_floor_mask |= mask
|
126 |
+
elif cls_id == 1 and show_blackspot: # Blackspot class
|
127 |
+
final_blackspot |= mask
|
128 |
+
|
129 |
+
return combined_floor_mask.astype(np.uint8), final_blackspot.astype(np.uint8)
|
130 |
+
|
131 |
+
########################################
|
132 |
+
# ONEFORMER SETUP AND FUNCTIONS
|
133 |
+
########################################
|
134 |
+
|
135 |
+
def setup_oneformer_modules():
|
136 |
+
"""Initialize OneFormer models"""
|
137 |
+
for dataset in ["ADE20K (150 classes)"]:
|
138 |
+
for backbone in ["DiNAT-L", "Swin-L"]:
|
139 |
+
cfg = setup_oneformer_cfg(dataset, backbone)
|
140 |
+
metadata = MetadataCatalog.get(
|
141 |
+
cfg.DATASETS.TEST_PANOPTIC[0] if len(cfg.DATASETS.TEST_PANOPTIC) else "__unused"
|
142 |
+
)
|
143 |
+
ONEFORMER_PREDICTORS[backbone][dataset] = OneFormerPredictor(cfg)
|
144 |
+
ONEFORMER_METADATA[backbone][dataset] = metadata
|
145 |
+
|
146 |
+
def setup_oneformer_cfg(dataset, backbone):
|
147 |
+
"""Setup OneFormer configuration"""
|
148 |
+
cfg = get_cfg()
|
149 |
+
add_deeplab_config(cfg)
|
150 |
+
add_common_config(cfg)
|
151 |
+
add_swin_config(cfg)
|
152 |
+
add_oneformer_config(cfg)
|
153 |
+
add_dinat_config(cfg)
|
154 |
+
dataset = KEY_DICT[dataset]
|
155 |
+
cfg_path = CFG_DICT[backbone][dataset]
|
156 |
+
cfg.merge_from_file(cfg_path)
|
157 |
+
cfg.MODEL.DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
158 |
+
cfg.MODEL.WEIGHTS = MODEL_DICT[backbone][dataset]
|
159 |
+
cfg.freeze()
|
160 |
+
return cfg
|
161 |
+
|
162 |
+
def semantic_run(img, predictor, metadata):
|
163 |
+
"""Run OneFormer semantic segmentation"""
|
164 |
+
visualizer = Visualizer(img[:, :, ::-1], metadata=metadata, instance_mode=ColorMode.IMAGE)
|
165 |
+
predictions = predictor(img, "semantic")
|
166 |
+
out = visualizer.draw_sem_seg(
|
167 |
+
predictions["sem_seg"].argmax(dim=0).to(cpu_device), alpha=0.5
|
168 |
+
)
|
169 |
+
return out, predictions["sem_seg"].argmax(dim=0).to(cpu_device).numpy()
|
170 |
+
|
171 |
+
########################################
|
172 |
+
# INTEGRATED ANALYSIS FUNCTION
|
173 |
+
########################################
|
174 |
+
|
175 |
+
def integrated_analysis(image_path,
|
176 |
+
# Blackspot detection parameters
|
177 |
+
blackspot_threshold, show_floor, show_blackspot,
|
178 |
+
# Contrast detection parameters
|
179 |
+
enable_contrast, backbone, contrast_method, contrast_threshold):
|
180 |
+
"""
|
181 |
+
Perform integrated analysis with both blackspot detection and contrast analysis
|
182 |
+
"""
|
183 |
+
# Read the image
|
184 |
+
im = cv2.imread(image_path)
|
185 |
+
if im is None:
|
186 |
+
return "Error: could not read image!", None, None, None
|
187 |
+
|
188 |
+
# Resize for OneFormer if contrast analysis is enabled
|
189 |
+
if enable_contrast:
|
190 |
+
width = WIDTH_DICT["ade20k"]
|
191 |
+
im_resized = imutils.resize(im, width=width)
|
192 |
+
else:
|
193 |
+
im_resized = im
|
194 |
+
|
195 |
+
# Part 1: Blackspot Detection with Mask R-CNN
|
196 |
+
blackspot_text = []
|
197 |
+
blackspot_viz = None
|
198 |
+
|
199 |
+
if show_floor or show_blackspot:
|
200 |
+
weights_path = "./output_floor_blackspot/model_0004999.pth"
|
201 |
+
maskrcnn_predictor = load_maskrcnn_model(weights_path, device, blackspot_threshold)
|
202 |
+
|
203 |
+
# Run blackspot detection
|
204 |
+
outputs = maskrcnn_predictor(im)
|
205 |
+
instances = outputs["instances"]
|
206 |
+
|
207 |
+
# Post-process masks
|
208 |
+
floor_mask, blackspot_mask = postprocess_blackspot_masks(im, instances, show_floor, show_blackspot)
|
209 |
+
|
210 |
+
# Create visualization
|
211 |
+
blackspot_overlay = im.copy()
|
212 |
+
overlay = np.zeros_like(im)
|
213 |
+
|
214 |
+
if show_floor:
|
215 |
+
overlay[floor_mask > 0] = (0, 255, 0) # Green for floor
|
216 |
+
if show_blackspot:
|
217 |
+
overlay[blackspot_mask > 0] = (0, 0, 255) # Red for blackspot
|
218 |
+
|
219 |
+
blackspot_overlay = cv2.addWeighted(im, 1.0, overlay, 0.5, 0)
|
220 |
+
blackspot_viz = Image.fromarray(cv2.cvtColor(blackspot_overlay, cv2.COLOR_BGR2RGB))
|
221 |
+
|
222 |
+
# Calculate statistics
|
223 |
+
blackspot_area = int(blackspot_mask.sum())
|
224 |
+
floor_area = int(floor_mask.sum())
|
225 |
+
|
226 |
+
blackspot_text.append(f"### Blackspot Detection Results")
|
227 |
+
blackspot_text.append(f"**Threshold:** {blackspot_threshold:.2f}")
|
228 |
+
|
229 |
+
if show_floor:
|
230 |
+
blackspot_text.append(f"**Floor area:** {floor_area} pixels")
|
231 |
+
if show_blackspot:
|
232 |
+
blackspot_text.append(f"**Blackspot area:** {blackspot_area} pixels")
|
233 |
+
if floor_area > 0 and show_floor:
|
234 |
+
percentage = (blackspot_area / floor_area) * 100
|
235 |
+
blackspot_text.append(f"**Blackspot coverage:** {percentage:.2f}% of floor area")
|
236 |
+
|
237 |
+
# Part 2: Contrast Analysis with OneFormer
|
238 |
+
segmentation_viz = None
|
239 |
+
contrast_viz = None
|
240 |
+
contrast_text = []
|
241 |
+
|
242 |
+
if enable_contrast:
|
243 |
+
dataset = "ADE20K (150 classes)"
|
244 |
+
predictor = ONEFORMER_PREDICTORS[backbone][dataset]
|
245 |
+
metadata = ONEFORMER_METADATA[backbone][dataset]
|
246 |
+
|
247 |
+
# Get segmentation
|
248 |
+
out, seg_mask = semantic_run(im_resized, predictor, metadata)
|
249 |
+
segmentation_viz = Image.fromarray(out.get_image())
|
250 |
+
|
251 |
+
# Analyze contrast
|
252 |
+
img_rgb = cv2.cvtColor(im_resized, cv2.COLOR_BGR2RGB)
|
253 |
+
detector = CONTRAST_DETECTORS[contrast_method]
|
254 |
+
contrast_image, problem_areas, stats = detector.analyze(
|
255 |
+
img_rgb, seg_mask, contrast_threshold
|
256 |
+
)
|
257 |
+
|
258 |
+
contrast_viz = Image.fromarray(contrast_image)
|
259 |
+
|
260 |
+
# Create stats text
|
261 |
+
contrast_text.append(f"### Contrast Analysis Results")
|
262 |
+
contrast_text.append(f"**Method:** {contrast_method}")
|
263 |
+
contrast_text.append(f"**Threshold:** {contrast_threshold:.2f}")
|
264 |
+
contrast_text.append(f"**Problem Areas:** {stats['problem_count']}")
|
265 |
+
|
266 |
+
if 'min_contrast' in stats:
|
267 |
+
contrast_text.append(f"**Min Contrast:** {stats['min_contrast']:.2f}")
|
268 |
+
if 'max_contrast' in stats:
|
269 |
+
contrast_text.append(f"**Max Contrast:** {stats['max_contrast']:.2f}")
|
270 |
+
if 'average_contrast' in stats:
|
271 |
+
contrast_text.append(f"**Average Contrast:** {stats['average_contrast']:.2f}")
|
272 |
+
|
273 |
+
# Combine results
|
274 |
+
combined_text = []
|
275 |
+
if blackspot_text:
|
276 |
+
combined_text.extend(blackspot_text)
|
277 |
+
if contrast_text:
|
278 |
+
if blackspot_text:
|
279 |
+
combined_text.append("\n")
|
280 |
+
combined_text.extend(contrast_text)
|
281 |
+
|
282 |
+
return "\n".join(combined_text), blackspot_viz, segmentation_viz, contrast_viz
|
283 |
+
|
284 |
+
########################################
|
285 |
+
# GRADIO INTERFACE
|
286 |
+
########################################
|
287 |
+
|
288 |
+
# Initialize models
|
289 |
+
print("Initializing OneFormer models...")
|
290 |
+
setup_oneformer_modules()
|
291 |
+
|
292 |
+
title = "NeuroNest: Integrated Blackspot & Contrast Detection"
|
293 |
+
description = """
|
294 |
+
This integrated system combines:
|
295 |
+
1. **Blackspot Detection**: Uses Mask R-CNN to detect blackspots on floors
|
296 |
+
2. **Contrast Analysis**: Uses OneFormer segmentation to analyze contrast between objects
|
297 |
+
|
298 |
+
Both analyses help identify potential accessibility issues for individuals with Alzheimer's disease.
|
299 |
+
"""
|
300 |
+
|
301 |
+
# Create the Gradio interface
|
302 |
+
with gr.Blocks(title=title) as demo:
|
303 |
+
gr.Markdown(f"# {title}")
|
304 |
+
gr.Markdown(description)
|
305 |
+
|
306 |
+
with gr.Row():
|
307 |
+
with gr.Column(scale=1):
|
308 |
+
# Input image
|
309 |
+
image_input = gr.Image(label="Input Image", type="filepath")
|
310 |
+
|
311 |
+
# Blackspot detection controls
|
312 |
+
with gr.Accordion("Blackspot Detection Settings", open=True):
|
313 |
+
blackspot_threshold = gr.Slider(
|
314 |
+
minimum=0.1, maximum=0.9, value=0.5, step=0.05,
|
315 |
+
label="Blackspot Detection Threshold"
|
316 |
+
)
|
317 |
+
with gr.Row():
|
318 |
+
show_floor = gr.Checkbox(value=True, label="Show Floor")
|
319 |
+
show_blackspot = gr.Checkbox(value=True, label="Show Blackspots")
|
320 |
+
|
321 |
+
# Contrast analysis controls
|
322 |
+
with gr.Accordion("Contrast Analysis Settings", open=True):
|
323 |
+
enable_contrast = gr.Checkbox(value=True, label="Enable Contrast Analysis")
|
324 |
+
backbone = gr.Radio(
|
325 |
+
choices=["Swin-L", "DiNAT-L"],
|
326 |
+
value="Swin-L",
|
327 |
+
label="OneFormer Backbone"
|
328 |
+
)
|
329 |
+
contrast_method = gr.Radio(
|
330 |
+
choices=["Luminance (WCAG)", "Hue", "Saturation", "Combined"],
|
331 |
+
value="Luminance (WCAG)",
|
332 |
+
label="Contrast Detection Method"
|
333 |
+
)
|
334 |
+
contrast_threshold = gr.Slider(
|
335 |
+
minimum=1.0, maximum=10.0, value=4.5, step=0.1,
|
336 |
+
label="Contrast Threshold"
|
337 |
+
)
|
338 |
+
|
339 |
+
analyze_btn = gr.Button("Analyze", variant="primary")
|
340 |
+
|
341 |
+
with gr.Column(scale=2):
|
342 |
+
# Output displays
|
343 |
+
with gr.Tabs():
|
344 |
+
with gr.Tab("Analysis Report"):
|
345 |
+
analysis_text = gr.Textbox(label="Analysis Results", lines=10)
|
346 |
+
|
347 |
+
with gr.Tab("Blackspot Detection"):
|
348 |
+
blackspot_output = gr.Image(label="Blackspot Visualization")
|
349 |
+
|
350 |
+
with gr.Tab("Segmentation"):
|
351 |
+
segmentation_output = gr.Image(label="OneFormer Segmentation")
|
352 |
+
|
353 |
+
with gr.Tab("Contrast Analysis"):
|
354 |
+
contrast_output = gr.Image(label="Contrast Visualization")
|
355 |
+
|
356 |
+
# Connect the interface
|
357 |
+
analyze_btn.click(
|
358 |
+
fn=integrated_analysis,
|
359 |
+
inputs=[
|
360 |
+
image_input,
|
361 |
+
blackspot_threshold, show_floor, show_blackspot,
|
362 |
+
enable_contrast, backbone, contrast_method, contrast_threshold
|
363 |
+
],
|
364 |
+
outputs=[
|
365 |
+
analysis_text, blackspot_output, segmentation_output, contrast_output
|
366 |
+
]
|
367 |
+
)
|
368 |
+
|
369 |
+
# Examples
|
370 |
+
gr.Examples(
|
371 |
+
examples=[
|
372 |
+
["examples/indoor_room.jpg", 0.5, True, True, True, "Swin-L", "Luminance (WCAG)", 4.5],
|
373 |
+
["examples/living_room.jpg", 0.7, True, True, True, "DiNAT-L", "Combined", 3.0],
|
374 |
+
],
|
375 |
+
inputs=[
|
376 |
+
image_input,
|
377 |
+
blackspot_threshold, show_floor, show_blackspot,
|
378 |
+
enable_contrast, backbone, contrast_method, contrast_threshold
|
379 |
+
]
|
380 |
+
)
|
381 |
+
|
382 |
+
if __name__ == "__main__":
|
383 |
+
print(f"Launching integrated NeuroNest app on device: {device}")
|
384 |
+
demo.queue().launch(server_name="0.0.0.0", share=True)
|
gradio_contrast.py
ADDED
@@ -0,0 +1,455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import time
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image
|
7 |
+
import cv2
|
8 |
+
import imutils
|
9 |
+
import colorsys
|
10 |
+
from scipy import ndimage
|
11 |
+
|
12 |
+
# Set CUDA device explicitly at the start
|
13 |
+
if torch.cuda.is_available():
|
14 |
+
torch.cuda.set_device(0) # Use first GPU
|
15 |
+
print(f"Using GPU: {torch.cuda.get_device_name(0)}")
|
16 |
+
else:
|
17 |
+
print("WARNING: No GPU available, using CPU")
|
18 |
+
|
19 |
+
print("Installed the dependencies!")
|
20 |
+
|
21 |
+
from detectron2.config import get_cfg
|
22 |
+
from detectron2.projects.deeplab import add_deeplab_config
|
23 |
+
from detectron2.data import MetadataCatalog
|
24 |
+
|
25 |
+
from oneformer import (
|
26 |
+
add_oneformer_config,
|
27 |
+
add_common_config,
|
28 |
+
add_swin_config,
|
29 |
+
add_dinat_config,
|
30 |
+
)
|
31 |
+
|
32 |
+
from demo.defaults import DefaultPredictor
|
33 |
+
from demo.visualizer import Visualizer, ColorMode
|
34 |
+
|
35 |
+
import gradio as gr
|
36 |
+
from huggingface_hub import hf_hub_download
|
37 |
+
|
38 |
+
# Force unbuffered output for SLURM logs
|
39 |
+
sys.stdout = sys.__stdout__
|
40 |
+
sys.stderr = sys.__stderr__
|
41 |
+
|
42 |
+
# Set environment variables for better GPU performance
|
43 |
+
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
|
44 |
+
os.environ['TORCH_USE_CUDA_DSA'] = '1'
|
45 |
+
|
46 |
+
# Contrast Detection Classes
|
47 |
+
class ContrastDetector:
|
48 |
+
"""Base class for contrast detection between segments"""
|
49 |
+
|
50 |
+
@staticmethod
|
51 |
+
def calculate_luminance_contrast(color1, color2):
|
52 |
+
"""Calculate WCAG luminance contrast ratio"""
|
53 |
+
def get_relative_luminance(rgb):
|
54 |
+
r, g, b = [val/255.0 for val in rgb]
|
55 |
+
r = r/12.92 if r <= 0.03928 else ((r + 0.055)/1.055) ** 2.4
|
56 |
+
g = g/12.92 if g <= 0.03928 else ((g + 0.055)/1.055) ** 2.4
|
57 |
+
b = b/12.92 if b <= 0.03928 else ((b + 0.055)/1.055) ** 2.4
|
58 |
+
return 0.2126 * r + 0.7152 * g + 0.0722 * b
|
59 |
+
|
60 |
+
lum1 = get_relative_luminance(color1)
|
61 |
+
lum2 = get_relative_luminance(color2)
|
62 |
+
|
63 |
+
lighter = max(lum1, lum2)
|
64 |
+
darker = min(lum1, lum2)
|
65 |
+
|
66 |
+
return (lighter + 0.05) / (darker + 0.05)
|
67 |
+
|
68 |
+
@staticmethod
|
69 |
+
def calculate_hue_contrast(color1, color2):
|
70 |
+
"""Calculate hue difference between two colors"""
|
71 |
+
hsv1 = colorsys.rgb_to_hsv(color1[0]/255.0, color1[1]/255.0, color1[2]/255.0)
|
72 |
+
hsv2 = colorsys.rgb_to_hsv(color2[0]/255.0, color2[1]/255.0, color2[2]/255.0)
|
73 |
+
|
74 |
+
hue_diff = abs(hsv1[0] - hsv2[0])
|
75 |
+
if hue_diff > 0.5:
|
76 |
+
hue_diff = 1 - hue_diff
|
77 |
+
|
78 |
+
return hue_diff * 2
|
79 |
+
|
80 |
+
@staticmethod
|
81 |
+
def calculate_saturation_contrast(color1, color2):
|
82 |
+
"""Calculate saturation difference between two colors"""
|
83 |
+
hsv1 = colorsys.rgb_to_hsv(color1[0]/255.0, color1[1]/255.0, color1[2]/255.0)
|
84 |
+
hsv2 = colorsys.rgb_to_hsv(color2[0]/255.0, color2[1]/255.0, color2[2]/255.0)
|
85 |
+
|
86 |
+
return abs(hsv1[1] - hsv2[1])
|
87 |
+
|
88 |
+
@staticmethod
|
89 |
+
def analyze_contrast(image, segmentation, method="luminance", threshold=4.5):
|
90 |
+
"""Analyze contrast between adjacent segments"""
|
91 |
+
unique_segments = np.unique(segmentation)
|
92 |
+
h, w = segmentation.shape
|
93 |
+
contrast_mask = np.zeros((h, w), dtype=bool)
|
94 |
+
problem_areas = []
|
95 |
+
|
96 |
+
# Calculate average colors for each segment
|
97 |
+
segment_colors = {}
|
98 |
+
for seg_id in unique_segments:
|
99 |
+
mask = segmentation == seg_id
|
100 |
+
if np.any(mask):
|
101 |
+
segment_colors[seg_id] = np.mean(image[mask], axis=0).astype(int)
|
102 |
+
|
103 |
+
# Check contrast between adjacent segments
|
104 |
+
for i in range(h):
|
105 |
+
for j in range(w):
|
106 |
+
current_seg = segmentation[i, j]
|
107 |
+
|
108 |
+
# Check 4-connected neighbors
|
109 |
+
for di, dj in [(0, 1), (1, 0), (0, -1), (-1, 0)]:
|
110 |
+
ni, nj = i + di, j + dj
|
111 |
+
if 0 <= ni < h and 0 <= nj < w:
|
112 |
+
neighbor_seg = segmentation[ni, nj]
|
113 |
+
|
114 |
+
if current_seg != neighbor_seg:
|
115 |
+
color1 = segment_colors[current_seg]
|
116 |
+
color2 = segment_colors[neighbor_seg]
|
117 |
+
|
118 |
+
if method == "luminance":
|
119 |
+
contrast = ContrastDetector.calculate_luminance_contrast(color1, color2)
|
120 |
+
elif method == "hue":
|
121 |
+
contrast = ContrastDetector.calculate_hue_contrast(color1, color2)
|
122 |
+
threshold = 0.3 # Adjust threshold for hue
|
123 |
+
elif method == "saturation":
|
124 |
+
contrast = ContrastDetector.calculate_saturation_contrast(color1, color2)
|
125 |
+
threshold = 0.3 # Adjust threshold for saturation
|
126 |
+
|
127 |
+
if contrast < threshold:
|
128 |
+
contrast_mask[i, j] = True
|
129 |
+
problem_areas.append((current_seg, neighbor_seg, contrast))
|
130 |
+
|
131 |
+
return contrast_mask, problem_areas, segment_colors
|
132 |
+
|
133 |
+
# Rest of your code remains the same until setup_cfg function
|
134 |
+
KEY_DICT = {"Cityscapes (19 classes)": "cityscapes",
|
135 |
+
"COCO (133 classes)": "coco",
|
136 |
+
"ADE20K (150 classes)": "ade20k",}
|
137 |
+
|
138 |
+
SWIN_CFG_DICT = {"cityscapes": "configs/cityscapes/oneformer_swin_large_IN21k_384_bs16_90k.yaml",
|
139 |
+
"coco": "configs/coco/oneformer_swin_large_IN21k_384_bs16_100ep.yaml",
|
140 |
+
"ade20k": "configs/ade20k/oneformer_swin_large_IN21k_384_bs16_160k.yaml",}
|
141 |
+
|
142 |
+
SWIN_MODEL_DICT = {"cityscapes": hf_hub_download(repo_id="shi-labs/oneformer_cityscapes_swin_large",
|
143 |
+
filename="250_16_swin_l_oneformer_cityscapes_90k.pth"),
|
144 |
+
"coco": hf_hub_download(repo_id="shi-labs/oneformer_coco_swin_large",
|
145 |
+
filename="150_16_swin_l_oneformer_coco_100ep.pth"),
|
146 |
+
"ade20k": hf_hub_download(repo_id="shi-labs/oneformer_ade20k_swin_large",
|
147 |
+
filename="250_16_swin_l_oneformer_ade20k_160k.pth")
|
148 |
+
}
|
149 |
+
|
150 |
+
DINAT_CFG_DICT = {"cityscapes": "configs/cityscapes/oneformer_dinat_large_bs16_90k.yaml",
|
151 |
+
"coco": "configs/coco/oneformer_dinat_large_bs16_100ep.yaml",
|
152 |
+
"ade20k": "configs/ade20k/oneformer_dinat_large_IN21k_384_bs16_160k.yaml",}
|
153 |
+
|
154 |
+
DINAT_MODEL_DICT = {"cityscapes": hf_hub_download(repo_id="shi-labs/oneformer_cityscapes_dinat_large",
|
155 |
+
filename="250_16_dinat_l_oneformer_cityscapes_90k.pth"),
|
156 |
+
"coco": hf_hub_download(repo_id="shi-labs/oneformer_coco_dinat_large",
|
157 |
+
filename="150_16_dinat_l_oneformer_coco_100ep.pth"),
|
158 |
+
"ade20k": hf_hub_download(repo_id="shi-labs/oneformer_ade20k_dinat_large",
|
159 |
+
filename="250_16_dinat_l_oneformer_ade20k_160k.pth")
|
160 |
+
}
|
161 |
+
|
162 |
+
MODEL_DICT = {"DiNAT-L": DINAT_MODEL_DICT,
|
163 |
+
"Swin-L": SWIN_MODEL_DICT }
|
164 |
+
|
165 |
+
CFG_DICT = {"DiNAT-L": DINAT_CFG_DICT,
|
166 |
+
"Swin-L": SWIN_CFG_DICT }
|
167 |
+
|
168 |
+
WIDTH_DICT = {"cityscapes": 512,
|
169 |
+
"coco": 512,
|
170 |
+
"ade20k": 640}
|
171 |
+
|
172 |
+
# Modified to ensure CUDA device
|
173 |
+
if torch.cuda.is_available():
|
174 |
+
device = torch.device("cuda:0")
|
175 |
+
print(f"Using device: {device}")
|
176 |
+
else:
|
177 |
+
device = torch.device("cpu")
|
178 |
+
print(f"WARNING: Using CPU device")
|
179 |
+
|
180 |
+
cpu_device = torch.device("cpu")
|
181 |
+
|
182 |
+
PREDICTORS = {
|
183 |
+
"DiNAT-L": {
|
184 |
+
"Cityscapes (19 classes)": None,
|
185 |
+
"COCO (133 classes)": None,
|
186 |
+
"ADE20K (150 classes)": None
|
187 |
+
},
|
188 |
+
"Swin-L": {
|
189 |
+
"Cityscapes (19 classes)": None,
|
190 |
+
"COCO (133 classes)": None,
|
191 |
+
"ADE20K (150 classes)": None
|
192 |
+
}
|
193 |
+
}
|
194 |
+
|
195 |
+
METADATA = {
|
196 |
+
"DiNAT-L": {
|
197 |
+
"Cityscapes (19 classes)": None,
|
198 |
+
"COCO (133 classes)": None,
|
199 |
+
"ADE20K (150 classes)": None
|
200 |
+
},
|
201 |
+
"Swin-L": {
|
202 |
+
"Cityscapes (19 classes)": None,
|
203 |
+
"COCO (133 classes)": None,
|
204 |
+
"ADE20K (150 classes)": None
|
205 |
+
}
|
206 |
+
}
|
207 |
+
|
208 |
+
def setup_modules():
|
209 |
+
print("Setting up modules...")
|
210 |
+
for dataset in ["Cityscapes (19 classes)", "COCO (133 classes)", "ADE20K (150 classes)"]:
|
211 |
+
for backbone in ["DiNAT-L", "Swin-L"]:
|
212 |
+
print(f"Loading {backbone} for {dataset}...")
|
213 |
+
cfg = setup_cfg(dataset, backbone)
|
214 |
+
metadata = MetadataCatalog.get(
|
215 |
+
cfg.DATASETS.TEST_PANOPTIC[0] if len(cfg.DATASETS.TEST_PANOPTIC) else "__unused"
|
216 |
+
)
|
217 |
+
if 'cityscapes_fine_sem_seg_val' in cfg.DATASETS.TEST_PANOPTIC[0]:
|
218 |
+
from cityscapesscripts.helpers.labels import labels
|
219 |
+
stuff_colors = [k.color for k in labels if k.trainId != 255]
|
220 |
+
metadata = metadata.set(stuff_colors=stuff_colors)
|
221 |
+
|
222 |
+
# Create predictor with explicit device
|
223 |
+
predictor = DefaultPredictor(cfg)
|
224 |
+
predictor.model.to(device)
|
225 |
+
|
226 |
+
PREDICTORS[backbone][dataset] = predictor
|
227 |
+
METADATA[backbone][dataset] = metadata
|
228 |
+
print(f"✓ Loaded {backbone} for {dataset}")
|
229 |
+
print("All modules setup complete!")
|
230 |
+
|
231 |
+
def setup_cfg(dataset, backbone):
|
232 |
+
# load config from file and command-line arguments
|
233 |
+
cfg = get_cfg()
|
234 |
+
add_deeplab_config(cfg)
|
235 |
+
add_common_config(cfg)
|
236 |
+
add_swin_config(cfg)
|
237 |
+
add_oneformer_config(cfg)
|
238 |
+
add_dinat_config(cfg)
|
239 |
+
dataset = KEY_DICT[dataset]
|
240 |
+
cfg_path = CFG_DICT[backbone][dataset]
|
241 |
+
cfg.merge_from_file(cfg_path)
|
242 |
+
|
243 |
+
# Explicitly set device to CUDA if available
|
244 |
+
if torch.cuda.is_available():
|
245 |
+
cfg.MODEL.DEVICE = 'cuda:0'
|
246 |
+
print(f"Config set to use CUDA device")
|
247 |
+
else:
|
248 |
+
cfg.MODEL.DEVICE = 'cpu'
|
249 |
+
print(f"Config set to use CPU device")
|
250 |
+
|
251 |
+
cfg.MODEL.WEIGHTS = MODEL_DICT[backbone][dataset]
|
252 |
+
cfg.freeze()
|
253 |
+
return cfg
|
254 |
+
|
255 |
+
# Rest of your functions remain the same
|
256 |
+
def panoptic_run(img, predictor, metadata):
|
257 |
+
visualizer = Visualizer(img[:, :, ::-1], metadata=metadata, instance_mode=ColorMode.IMAGE)
|
258 |
+
predictions = predictor(img, "panoptic")
|
259 |
+
panoptic_seg, segments_info = predictions["panoptic_seg"]
|
260 |
+
out = visualizer.draw_panoptic_seg_predictions(
|
261 |
+
panoptic_seg.to(cpu_device), segments_info, alpha=0.5
|
262 |
+
)
|
263 |
+
visualizer_map = Visualizer(img[:, :, ::-1], is_img=False, metadata=metadata, instance_mode=ColorMode.IMAGE)
|
264 |
+
out_map = visualizer_map.draw_panoptic_seg_predictions(
|
265 |
+
panoptic_seg.to(cpu_device), segments_info, alpha=1, is_text=False
|
266 |
+
)
|
267 |
+
return out, out_map, predictions
|
268 |
+
|
269 |
+
def instance_run(img, predictor, metadata):
|
270 |
+
visualizer = Visualizer(img[:, :, ::-1], metadata=metadata, instance_mode=ColorMode.IMAGE)
|
271 |
+
predictions = predictor(img, "instance")
|
272 |
+
instances = predictions["instances"].to(cpu_device)
|
273 |
+
out = visualizer.draw_instance_predictions(predictions=instances, alpha=0.5)
|
274 |
+
visualizer_map = Visualizer(img[:, :, ::-1], is_img=False, metadata=metadata, instance_mode=ColorMode.IMAGE)
|
275 |
+
out_map = visualizer_map.draw_instance_predictions(predictions=instances, alpha=1, is_text=False)
|
276 |
+
return out, out_map, predictions
|
277 |
+
|
278 |
+
def semantic_run(img, predictor, metadata):
|
279 |
+
visualizer = Visualizer(img[:, :, ::-1], metadata=metadata, instance_mode=ColorMode.IMAGE)
|
280 |
+
predictions = predictor(img, "semantic")
|
281 |
+
out = visualizer.draw_sem_seg(
|
282 |
+
predictions["sem_seg"].argmax(dim=0).to(cpu_device), alpha=0.5
|
283 |
+
)
|
284 |
+
visualizer_map = Visualizer(img[:, :, ::-1], is_img=False, metadata=metadata, instance_mode=ColorMode.IMAGE)
|
285 |
+
out_map = visualizer_map.draw_sem_seg(
|
286 |
+
predictions["sem_seg"].argmax(dim=0).to(cpu_device), alpha=1, is_text=False
|
287 |
+
)
|
288 |
+
return out, out_map, predictions
|
289 |
+
|
290 |
+
TASK_INFER = {"the task is panoptic": panoptic_run, "the task is instance": instance_run, "the task is semantic": semantic_run}
|
291 |
+
|
292 |
+
def create_contrast_visualization(img, contrast_mask, problem_areas, segment_colors):
|
293 |
+
"""Create visualization of contrast issues"""
|
294 |
+
# Copy original image
|
295 |
+
contrast_viz = img.copy()
|
296 |
+
|
297 |
+
# Highlight low contrast boundaries
|
298 |
+
boundary_color = (255, 0, 0) # Red for problem areas
|
299 |
+
contrast_viz[contrast_mask] = boundary_color
|
300 |
+
|
301 |
+
# Add information overlay
|
302 |
+
info_text = f"Low contrast areas detected: {len(problem_areas)}"
|
303 |
+
cv2.putText(contrast_viz, info_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
|
304 |
+
|
305 |
+
return contrast_viz
|
306 |
+
|
307 |
+
def segment_and_analyze(path, task, dataset, backbone, enable_contrast, contrast_method, contrast_threshold):
|
308 |
+
# Get predictions and segmentation visualization
|
309 |
+
predictor = PREDICTORS[backbone][dataset]
|
310 |
+
metadata = METADATA[backbone][dataset]
|
311 |
+
img = cv2.imread(path)
|
312 |
+
width = WIDTH_DICT[KEY_DICT[dataset]]
|
313 |
+
img = imutils.resize(img, width=width)
|
314 |
+
|
315 |
+
out, out_map, predictions = TASK_INFER[task](img, predictor, metadata)
|
316 |
+
out_img = Image.fromarray(out.get_image())
|
317 |
+
out_map_img = Image.fromarray(out_map.get_image())
|
318 |
+
|
319 |
+
if not enable_contrast:
|
320 |
+
return out_img, out_map_img, None, None
|
321 |
+
|
322 |
+
# Extract segmentation mask from predictions
|
323 |
+
if task == "the task is semantic":
|
324 |
+
seg_mask = predictions["sem_seg"].argmax(dim=0).cpu().numpy()
|
325 |
+
elif task == "the task is panoptic":
|
326 |
+
seg_mask, _ = predictions["panoptic_seg"]
|
327 |
+
seg_mask = seg_mask.cpu().numpy()
|
328 |
+
elif task == "the task is instance":
|
329 |
+
# For instance segmentation, create a mask from instances
|
330 |
+
instances = predictions["instances"].to("cpu")
|
331 |
+
seg_mask = np.zeros(img.shape[:2], dtype=np.int32)
|
332 |
+
for i, mask in enumerate(instances.pred_masks):
|
333 |
+
seg_mask[mask] = i + 1
|
334 |
+
|
335 |
+
# Analyze contrast
|
336 |
+
contrast_mask, problem_areas, segment_colors = ContrastDetector.analyze_contrast(
|
337 |
+
img, seg_mask, method=contrast_method, threshold=contrast_threshold
|
338 |
+
)
|
339 |
+
|
340 |
+
# Create contrast visualization
|
341 |
+
contrast_viz = create_contrast_visualization(img, contrast_mask, problem_areas, segment_colors)
|
342 |
+
contrast_viz_img = Image.fromarray(contrast_viz[:, :, ::-1]) # Convert BGR to RGB
|
343 |
+
|
344 |
+
# Generate analysis report
|
345 |
+
report = f"### Contrast Analysis Report\n\n"
|
346 |
+
report += f"**Method:** {contrast_method.capitalize()}\n"
|
347 |
+
report += f"**Threshold:** {contrast_threshold}\n"
|
348 |
+
report += f"**Total segments:** {len(segment_colors)}\n"
|
349 |
+
report += f"**Low contrast boundaries found:** {len(problem_areas)}\n\n"
|
350 |
+
|
351 |
+
if problem_areas:
|
352 |
+
report += "**Problem Areas:**\n"
|
353 |
+
for i, (seg1, seg2, contrast_value) in enumerate(problem_areas[:10]): # Show first 10
|
354 |
+
report += f"- Segments {seg1} and {seg2}: Contrast ratio = {contrast_value:.2f}\n"
|
355 |
+
if len(problem_areas) > 10:
|
356 |
+
report += f"... and {len(problem_areas) - 10} more\n"
|
357 |
+
|
358 |
+
return out_img, out_map_img, contrast_viz_img, report
|
359 |
+
|
360 |
+
title = "<h1 style='text-align: center'>OneFormer:DIEGO MENTORIA MILIONÁRIA - APP 1</h1>"
|
361 |
+
description = "<p style='font-size: 14px; margin: 5px; font-weight: w300; text-align: center'> <a href='https://github.com/lolout1/sam2Contrast' style='text-decoration:none' target='_blank'>NeuroNest Contrast Model</a></p>" \
|
362 |
+
+ "<p style='font-size: 16px; margin: 5px; font-weight: w600; text-align: center'> <a href='https://praeclarumjj3.github.io/oneformer/' target='_blank'>Project Page</a> | <a href='https://arxiv.org/abs/2211.06220' target='_blank'>ArXiv Paper</a> | <a href='https://github.com/SHI-Labs/OneFormer' target='_blank'>Github Repo</a></p>" \
|
363 |
+
+ "<p style='text-align: center; margin: 5px; font-size: 14px; font-weight: w300;'> \
|
364 |
+
This model leverages the OneFormer architecture to perform comprehensive image segmentation and labeling across multiple tasks. The system can identify and segment various objects, structures, and regions within images with high accuracy. It supports semantic, instance, and panoptic segmentation modes, enabling detailed analysis of indoor and outdoor environments. The model excels at distinguishing between different classes of objects, from common everyday items to complex urban structures, making it particularly useful for environmental analysis and scene understanding applications.\
|
365 |
+
</p>" \
|
366 |
+
+ "<p style='text-align: center; font-size: 14px; margin: 5px; font-weight: w300;'> [Note: Inference on CPU may take upto 2 minutes. On a single RTX A6000 GPU, OneFormer is able to inference at more than 15 FPS.]</p>"
|
367 |
+
|
368 |
+
# Main execution with error handling
|
369 |
+
if __name__ == "__main__":
|
370 |
+
try:
|
371 |
+
print("Starting setup...")
|
372 |
+
setup_modules()
|
373 |
+
|
374 |
+
print("Creating Gradio interface...")
|
375 |
+
with gr.Blocks(title="OneFormer with Contrast Detection") as iface:
|
376 |
+
gr.Markdown(title)
|
377 |
+
gr.Markdown(description)
|
378 |
+
|
379 |
+
with gr.Row():
|
380 |
+
with gr.Column(scale=1):
|
381 |
+
input_image = gr.Image(label="Input Image", type="filepath")
|
382 |
+
task = gr.Radio(
|
383 |
+
choices=["the task is panoptic", "the task is instance", "the task is semantic"],
|
384 |
+
value="the task is panoptic",
|
385 |
+
label="Task Token Input"
|
386 |
+
)
|
387 |
+
dataset = gr.Radio(
|
388 |
+
choices=["COCO (133 classes)", "Cityscapes (19 classes)", "ADE20K (150 classes)"],
|
389 |
+
value="COCO (133 classes)",
|
390 |
+
label="Model"
|
391 |
+
)
|
392 |
+
backbone = gr.Radio(
|
393 |
+
choices=["DiNAT-L", "Swin-L"],
|
394 |
+
value="DiNAT-L",
|
395 |
+
label="Backbone"
|
396 |
+
)
|
397 |
+
|
398 |
+
with gr.Accordion("Contrast Detection Options", open=False):
|
399 |
+
enable_contrast = gr.Checkbox(
|
400 |
+
label="Enable Contrast Detection",
|
401 |
+
value=False
|
402 |
+
)
|
403 |
+
contrast_method = gr.Radio(
|
404 |
+
choices=["luminance", "hue", "saturation"],
|
405 |
+
value="luminance",
|
406 |
+
label="Contrast Method"
|
407 |
+
)
|
408 |
+
contrast_threshold = gr.Slider(
|
409 |
+
minimum=1.0,
|
410 |
+
maximum=10.0,
|
411 |
+
value=4.5,
|
412 |
+
step=0.1,
|
413 |
+
label="Contrast Threshold (WCAG AA is 4.5)"
|
414 |
+
)
|
415 |
+
|
416 |
+
submit_btn = gr.Button("Analyze", variant="primary")
|
417 |
+
|
418 |
+
with gr.Column(scale=2):
|
419 |
+
with gr.Tabs():
|
420 |
+
with gr.TabItem("Segmentation Results"):
|
421 |
+
seg_output = gr.Image(type="pil", label="Segmentation Overlay")
|
422 |
+
seg_map = gr.Image(type="pil", label="Segmentation Map")
|
423 |
+
|
424 |
+
with gr.TabItem("Contrast Analysis"):
|
425 |
+
contrast_viz = gr.Image(type="pil", label="Contrast Visualization")
|
426 |
+
contrast_report = gr.Markdown(label="Contrast Analysis Report")
|
427 |
+
|
428 |
+
examples = [
|
429 |
+
["examples/coco.jpeg", "the task is panoptic", "COCO (133 classes)", "DiNAT-L", False, "luminance", 4.5],
|
430 |
+
["examples/cityscapes.png", "the task is panoptic", "Cityscapes (19 classes)", "DiNAT-L", False, "luminance", 4.5],
|
431 |
+
["examples/ade20k.jpeg", "the task is panoptic", "ADE20K (150 classes)", "DiNAT-L", False, "luminance", 4.5]
|
432 |
+
]
|
433 |
+
|
434 |
+
gr.Examples(
|
435 |
+
examples=examples,
|
436 |
+
inputs=[input_image, task, dataset, backbone, enable_contrast, contrast_method, contrast_threshold],
|
437 |
+
outputs=[seg_output, seg_map, contrast_viz, contrast_report],
|
438 |
+
fn=segment_and_analyze,
|
439 |
+
cache_examples=False
|
440 |
+
)
|
441 |
+
|
442 |
+
submit_btn.click(
|
443 |
+
fn=segment_and_analyze,
|
444 |
+
inputs=[input_image, task, dataset, backbone, enable_contrast, contrast_method, contrast_threshold],
|
445 |
+
outputs=[seg_output, seg_map, contrast_viz, contrast_report]
|
446 |
+
)
|
447 |
+
|
448 |
+
print("Launching Gradio app...")
|
449 |
+
iface.launch(server_name="0.0.0.0", server_port=7860, share=True)
|
450 |
+
|
451 |
+
except Exception as e:
|
452 |
+
print(f"Error occurred: {str(e)}")
|
453 |
+
import traceback
|
454 |
+
traceback.print_exc()
|
455 |
+
sys.exit(1)
|
gradio_gpu.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image
|
4 |
+
import cv2
|
5 |
+
import imutils
|
6 |
+
import os
|
7 |
+
import sys
|
8 |
+
import time
|
9 |
+
from detectron2.config import get_cfg
|
10 |
+
from detectron2.projects.deeplab import add_deeplab_config
|
11 |
+
from detectron2.data import MetadataCatalog
|
12 |
+
from scipy import ndimage
|
13 |
+
import colorsys
|
14 |
+
import math
|
15 |
+
|
16 |
+
torch.set_num_threads(16)
|
17 |
+
torch.set_num_interop_threads(16)
|
18 |
+
|
19 |
+
from oneformer import (
|
20 |
+
add_oneformer_config,
|
21 |
+
add_common_config,
|
22 |
+
add_swin_config,
|
23 |
+
add_dinat_config,
|
24 |
+
)
|
25 |
+
|
26 |
+
from demo.defaults import DefaultPredictor
|
27 |
+
from demo.visualizer import Visualizer, ColorMode
|
28 |
+
|
29 |
+
import gradio as gr
|
30 |
+
from huggingface_hub import hf_hub_download
|
31 |
+
|
32 |
+
# NeuroNest specific imports
|
33 |
+
from utils.contrast_detector import ContrastDetector
|
34 |
+
from utils.luminance_contrast import LuminanceContrastDetector
|
35 |
+
from utils.hue_contrast import HueContrastDetector
|
36 |
+
from utils.saturation_contrast import SaturationContrastDetector
|
37 |
+
from utils.combined_contrast import CombinedContrastDetector
|
38 |
+
|
39 |
+
KEY_DICT = {
|
40 |
+
"ADE20K (150 classes)": "ade20k",
|
41 |
+
}
|
42 |
+
|
43 |
+
SWIN_CFG_DICT = {
|
44 |
+
"ade20k": "configs/ade20k/oneformer_swin_large_IN21k_384_bs16_160k.yaml",
|
45 |
+
}
|
46 |
+
|
47 |
+
SWIN_MODEL_DICT = {
|
48 |
+
"ade20k": hf_hub_download(
|
49 |
+
repo_id="shi-labs/oneformer_ade20k_swin_large",
|
50 |
+
filename="250_16_swin_l_oneformer_ade20k_160k.pth"
|
51 |
+
)
|
52 |
+
}
|
53 |
+
|
54 |
+
DINAT_CFG_DICT = {
|
55 |
+
"ade20k": "configs/ade20k/oneformer_dinat_large_IN21k_384_bs16_160k.yaml",
|
56 |
+
}
|
57 |
+
|
58 |
+
DINAT_MODEL_DICT = {
|
59 |
+
"ade20k": hf_hub_download(
|
60 |
+
repo_id="shi-labs/oneformer_ade20k_dinat_large",
|
61 |
+
filename="250_16_dinat_l_oneformer_ade20k_160k.pth"
|
62 |
+
)
|
63 |
+
}
|
64 |
+
|
65 |
+
MODEL_DICT = {"DiNAT-L": DINAT_MODEL_DICT, "Swin-L": SWIN_MODEL_DICT}
|
66 |
+
CFG_DICT = {"DiNAT-L": DINAT_CFG_DICT, "Swin-L": SWIN_CFG_DICT}
|
67 |
+
WIDTH_DICT = {"ade20k": 640}
|
68 |
+
|
69 |
+
cpu_device = torch.device("cpu")
|
70 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
71 |
+
print(f"Using device: {device}")
|
72 |
+
|
73 |
+
PREDICTORS = {
|
74 |
+
"DiNAT-L": {"ADE20K (150 classes)": None},
|
75 |
+
"Swin-L": {"ADE20K (150 classes)": None}
|
76 |
+
}
|
77 |
+
|
78 |
+
METADATA = {
|
79 |
+
"DiNAT-L": {"ADE20K (150 classes)": None},
|
80 |
+
"Swin-L": {"ADE20K (150 classes)": None}
|
81 |
+
}
|
82 |
+
|
83 |
+
# Contrast detector mapping
|
84 |
+
CONTRAST_DETECTORS = {
|
85 |
+
"Luminance (WCAG)": LuminanceContrastDetector(),
|
86 |
+
"Hue": HueContrastDetector(),
|
87 |
+
"Saturation": SaturationContrastDetector(),
|
88 |
+
"Combined": CombinedContrastDetector()
|
89 |
+
}
|
90 |
+
|
91 |
+
def setup_modules():
|
92 |
+
for dataset in ["ADE20K (150 classes)"]:
|
93 |
+
for backbone in ["DiNAT-L", "Swin-L"]:
|
94 |
+
cfg = setup_cfg(dataset, backbone)
|
95 |
+
metadata = MetadataCatalog.get(
|
96 |
+
cfg.DATASETS.TEST_PANOPTIC[0] if len(cfg.DATASETS.TEST_PANOPTIC) else "__unused"
|
97 |
+
)
|
98 |
+
PREDICTORS[backbone][dataset] = DefaultPredictor(cfg)
|
99 |
+
METADATA[backbone][dataset] = metadata
|
100 |
+
|
101 |
+
def setup_cfg(dataset, backbone):
|
102 |
+
cfg = get_cfg()
|
103 |
+
add_deeplab_config(cfg)
|
104 |
+
add_common_config(cfg)
|
105 |
+
add_swin_config(cfg)
|
106 |
+
add_oneformer_config(cfg)
|
107 |
+
add_dinat_config(cfg)
|
108 |
+
dataset = KEY_DICT[dataset]
|
109 |
+
cfg_path = CFG_DICT[backbone][dataset]
|
110 |
+
cfg.merge_from_file(cfg_path)
|
111 |
+
cfg.MODEL.DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
112 |
+
cfg.MODEL.WEIGHTS = MODEL_DICT[backbone][dataset]
|
113 |
+
cfg.freeze()
|
114 |
+
return cfg
|
115 |
+
|
116 |
+
def semantic_run(img, predictor, metadata):
|
117 |
+
visualizer = Visualizer(img[:, :, ::-1], metadata=metadata, instance_mode=ColorMode.IMAGE)
|
118 |
+
predictions = predictor(img, "semantic")
|
119 |
+
out = visualizer.draw_sem_seg(
|
120 |
+
predictions["sem_seg"].argmax(dim=0).to(cpu_device), alpha=0.5
|
121 |
+
)
|
122 |
+
return out, predictions["sem_seg"].argmax(dim=0).to(cpu_device).numpy()
|
123 |
+
|
124 |
+
def analyze_contrast(image, segmentation, contrast_method, threshold):
|
125 |
+
"""Analyze contrast between segments using selected method"""
|
126 |
+
detector = CONTRAST_DETECTORS[contrast_method]
|
127 |
+
|
128 |
+
# Perform contrast analysis
|
129 |
+
contrast_image, problem_areas, stats = detector.analyze(
|
130 |
+
image, segmentation, threshold
|
131 |
+
)
|
132 |
+
|
133 |
+
return contrast_image, problem_areas, stats
|
134 |
+
|
135 |
+
def segment_and_analyze_contrast(path, backbone, contrast_method, threshold):
|
136 |
+
"""Main function to segment and analyze contrast"""
|
137 |
+
dataset = "ADE20K (150 classes)"
|
138 |
+
predictor = PREDICTORS[backbone][dataset]
|
139 |
+
metadata = METADATA[backbone][dataset]
|
140 |
+
|
141 |
+
# Read and resize image
|
142 |
+
img = cv2.imread(path)
|
143 |
+
if img is None:
|
144 |
+
return None, None, "Error: Could not load image"
|
145 |
+
|
146 |
+
width = WIDTH_DICT[KEY_DICT[dataset]]
|
147 |
+
img = imutils.resize(img, width=width)
|
148 |
+
|
149 |
+
# Get segmentation
|
150 |
+
out, seg_mask = semantic_run(img, predictor, metadata)
|
151 |
+
out_img = Image.fromarray(out.get_image())
|
152 |
+
|
153 |
+
# Analyze contrast
|
154 |
+
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
155 |
+
contrast_img, problem_areas, stats = analyze_contrast(
|
156 |
+
img_rgb, seg_mask, contrast_method, threshold
|
157 |
+
)
|
158 |
+
|
159 |
+
# Create stats text
|
160 |
+
stats_text = f"### Contrast Analysis Results\n\n"
|
161 |
+
stats_text += f"**Method:** {contrast_method}\n"
|
162 |
+
stats_text += f"**Threshold:** {threshold:.2f}\n"
|
163 |
+
stats_text += f"**Problem Areas:** {stats['problem_count']}\n"
|
164 |
+
|
165 |
+
if 'min_contrast' in stats:
|
166 |
+
stats_text += f"**Min Contrast:** {stats['min_contrast']:.2f}\n"
|
167 |
+
if 'max_contrast' in stats:
|
168 |
+
stats_text += f"**Max Contrast:** {stats['max_contrast']:.2f}\n"
|
169 |
+
if 'average_contrast' in stats:
|
170 |
+
stats_text += f"**Average Contrast:** {stats['average_contrast']:.2f}\n"
|
171 |
+
|
172 |
+
# Convert contrast image to PIL
|
173 |
+
contrast_pil = Image.fromarray(contrast_img)
|
174 |
+
|
175 |
+
return out_img, contrast_pil, stats_text
|
176 |
+
|
177 |
+
# Initialize models
|
178 |
+
setup_modules()
|
179 |
+
|
180 |
+
# Gradio Interface
|
181 |
+
title = "<h1 style='text-align: center'>NeuroNest: Abheek Pradhan - Contrast Model</h1>"
|
182 |
+
description = "<p style='font-size: 16px; margin: 5px; font-weight: w600; text-align: center'> "\
|
183 |
+
"<a href='https://github.com/lolout1/sam2Contrast' target='_blank'>Github Repo</a></p>" \
|
184 |
+
"<p style='text-align: center; margin: 5px; font-size: 14px; font-weight: w300;'>" \
|
185 |
+
"I am developing NeuroNest, a contrast detection system designed to identify areas with insufficient contrast " \
|
186 |
+
"for individuals with Alzheimer's disease. This tool leverages OneFormer's state-of-the-art segmentation " \
|
187 |
+
"capabilities trained on ADE20K dataset to detect indoor objects like floors, furniture, walls, and ceilings. " \
|
188 |
+
"By analyzing contrast ratios between adjacent segments, NeuroNest flags potential visual accessibility issues " \
|
189 |
+
"that may trigger confusion or disorientation in elderly individuals with cognitive impairments.</p>" \
|
190 |
+
"<p style='text-align: center; font-size: 14px; margin: 5px; font-weight: w300;'>" \
|
191 |
+
"[Note: When running on my Linux cluster, please request a GPU node for optimal performance. " \
|
192 |
+
"On login nodes, CUDA may not be available.]</p>"
|
193 |
+
|
194 |
+
gradio_inputs = [
|
195 |
+
gr.Image(label="Input Image", type="filepath"),
|
196 |
+
gr.Radio(choices=["Swin-L", "DiNAT-L"], value="Swin-L", label="Backbone"),
|
197 |
+
gr.Radio(
|
198 |
+
choices=["Luminance (WCAG)", "Hue", "Saturation", "Combined"],
|
199 |
+
value="Luminance (WCAG)",
|
200 |
+
label="Contrast Detection Method"
|
201 |
+
),
|
202 |
+
gr.Slider(
|
203 |
+
minimum=1.0,
|
204 |
+
maximum=10.0,
|
205 |
+
value=4.5,
|
206 |
+
step=0.1,
|
207 |
+
label="Contrast Threshold (Lower = More Strict)"
|
208 |
+
)
|
209 |
+
]
|
210 |
+
|
211 |
+
gradio_outputs = [
|
212 |
+
gr.Image(type="pil", label="Segmentation Result"),
|
213 |
+
gr.Image(type="pil", label="Contrast Analysis"),
|
214 |
+
gr.Markdown(label="Analysis Results")
|
215 |
+
]
|
216 |
+
|
217 |
+
examples = [
|
218 |
+
["examples/indoor_room.jpg", "Swin-L", "Luminance (WCAG)", 4.5],
|
219 |
+
["examples/living_room.jpg", "DiNAT-L", "Combined", 3.0],
|
220 |
+
]
|
221 |
+
|
222 |
+
iface = gr.Interface(
|
223 |
+
fn=segment_and_analyze_contrast,
|
224 |
+
inputs=gradio_inputs,
|
225 |
+
outputs=gradio_outputs,
|
226 |
+
examples_per_page=5,
|
227 |
+
allow_flagging="never",
|
228 |
+
examples=examples,
|
229 |
+
title=title,
|
230 |
+
description=description
|
231 |
+
)
|
232 |
+
|
233 |
+
if __name__ == "__main__":
|
234 |
+
iface.launch(server_name="0.0.0.0", share=True)
|
gradio_test.py
ADDED
@@ -0,0 +1,1080 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image
|
4 |
+
import cv2
|
5 |
+
import os
|
6 |
+
import sys
|
7 |
+
import time
|
8 |
+
import logging
|
9 |
+
from pathlib import Path
|
10 |
+
from typing import Tuple, Dict, List, Optional, Union
|
11 |
+
import gradio as gr
|
12 |
+
from huggingface_hub import hf_hub_download
|
13 |
+
import warnings
|
14 |
+
warnings.filterwarnings("ignore")
|
15 |
+
|
16 |
+
# Detectron2 imports
|
17 |
+
from detectron2.config import get_cfg
|
18 |
+
from detectron2.projects.deeplab import add_deeplab_config
|
19 |
+
from detectron2.data import MetadataCatalog
|
20 |
+
from detectron2.engine import DefaultPredictor as DetectronPredictor
|
21 |
+
from detectron2 import model_zoo
|
22 |
+
from detectron2.utils.visualizer import Visualizer, ColorMode
|
23 |
+
|
24 |
+
# OneFormer imports
|
25 |
+
try:
|
26 |
+
from oneformer import (
|
27 |
+
add_oneformer_config,
|
28 |
+
add_common_config,
|
29 |
+
add_swin_config,
|
30 |
+
add_dinat_config,
|
31 |
+
)
|
32 |
+
from demo.defaults import DefaultPredictor as OneFormerPredictor
|
33 |
+
ONEFORMER_AVAILABLE = True
|
34 |
+
except ImportError as e:
|
35 |
+
print(f"OneFormer not available: {e}")
|
36 |
+
ONEFORMER_AVAILABLE = False
|
37 |
+
|
38 |
+
# Setup logging
|
39 |
+
logging.basicConfig(level=logging.INFO)
|
40 |
+
logger = logging.getLogger(__name__)
|
41 |
+
|
42 |
+
########################################
|
43 |
+
# GLOBAL CONFIGURATIONS
|
44 |
+
########################################
|
45 |
+
|
46 |
+
# Device configuration
|
47 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
48 |
+
CPU_DEVICE = torch.device("cpu")
|
49 |
+
torch.set_num_threads(4)
|
50 |
+
|
51 |
+
# ADE20K class mappings for floor detection
|
52 |
+
FLOOR_CLASSES = {
|
53 |
+
'floor': [3, 4, 13], # floor, wood floor, rug
|
54 |
+
'carpet': [28], # carpet
|
55 |
+
'mat': [78], # mat
|
56 |
+
}
|
57 |
+
|
58 |
+
# OneFormer configurations
|
59 |
+
ONEFORMER_CONFIG = {
|
60 |
+
"ADE20K": {
|
61 |
+
"key": "ade20k",
|
62 |
+
"swin_cfg": "configs/ade20k/oneformer_swin_large_IN21k_384_bs16_160k.yaml",
|
63 |
+
"swin_model": "shi-labs/oneformer_ade20k_swin_large",
|
64 |
+
"swin_file": "250_16_swin_l_oneformer_ade20k_160k.pth",
|
65 |
+
"width": 640
|
66 |
+
}
|
67 |
+
}
|
68 |
+
|
69 |
+
########################################
|
70 |
+
# IMPORT UNIVERSAL CONTRAST ANALYZER
|
71 |
+
########################################
|
72 |
+
|
73 |
+
from utils.universal_contrast_analyzer import UniversalContrastAnalyzer
|
74 |
+
|
75 |
+
# Keep old class for compatibility but deprecated
|
76 |
+
class RobustContrastAnalyzer:
|
77 |
+
"""Advanced contrast analyzer for Alzheimer's-friendly environments"""
|
78 |
+
|
79 |
+
def __init__(self, wcag_threshold: float = 4.5):
|
80 |
+
self.wcag_threshold = wcag_threshold
|
81 |
+
|
82 |
+
# ADE20K class mappings for important objects
|
83 |
+
self.semantic_classes = {
|
84 |
+
'floor': [3, 4, 13, 28, 78], # floor, wood floor, rug, carpet, mat
|
85 |
+
'wall': [0, 1, 9], # wall, building, brick
|
86 |
+
'ceiling': [5], # ceiling
|
87 |
+
'furniture': [10, 19, 15, 7, 18, 23], # sofa, chair, table, bed, armchair, cabinet
|
88 |
+
'door': [25], # door
|
89 |
+
'window': [8], # window
|
90 |
+
'stairs': [53], # stairs
|
91 |
+
}
|
92 |
+
|
93 |
+
# Priority relationships for safety
|
94 |
+
self.priority_relationships = {
|
95 |
+
('floor', 'furniture'): ('critical', 'Furniture must be clearly visible against floor'),
|
96 |
+
('floor', 'stairs'): ('critical', 'Stairs must have clear contrast with floor'),
|
97 |
+
('floor', 'door'): ('high', 'Door should be easily distinguishable from floor'),
|
98 |
+
('wall', 'furniture'): ('high', 'Furniture should stand out from walls'),
|
99 |
+
('wall', 'door'): ('high', 'Doors should be clearly visible on walls'),
|
100 |
+
('wall', 'window'): ('medium', 'Windows should have adequate contrast'),
|
101 |
+
('ceiling', 'wall'): ('low', 'Ceiling-wall contrast is less critical'),
|
102 |
+
}
|
103 |
+
|
104 |
+
def get_object_category(self, class_id: int) -> str:
|
105 |
+
"""Map segmentation class to object category"""
|
106 |
+
for category, class_ids in self.semantic_classes.items():
|
107 |
+
if class_id in class_ids:
|
108 |
+
return category
|
109 |
+
return 'other'
|
110 |
+
|
111 |
+
def calculate_wcag_contrast(self, color1: np.ndarray, color2: np.ndarray) -> float:
|
112 |
+
"""Calculate WCAG contrast ratio"""
|
113 |
+
def relative_luminance(rgb):
|
114 |
+
rgb_norm = rgb / 255.0
|
115 |
+
rgb_linear = np.where(rgb_norm <= 0.03928,
|
116 |
+
rgb_norm / 12.92,
|
117 |
+
((rgb_norm + 0.055) / 1.055) ** 2.4)
|
118 |
+
return np.dot(rgb_linear, [0.2126, 0.7152, 0.0722])
|
119 |
+
|
120 |
+
lum1 = relative_luminance(color1)
|
121 |
+
lum2 = relative_luminance(color2)
|
122 |
+
|
123 |
+
lighter = max(lum1, lum2)
|
124 |
+
darker = min(lum1, lum2)
|
125 |
+
|
126 |
+
return (lighter + 0.05) / (darker + 0.05)
|
127 |
+
|
128 |
+
def extract_dominant_color(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
129 |
+
"""Extract dominant color from masked region"""
|
130 |
+
if not np.any(mask):
|
131 |
+
return np.array([128, 128, 128])
|
132 |
+
|
133 |
+
masked_pixels = image[mask]
|
134 |
+
if len(masked_pixels) == 0:
|
135 |
+
return np.array([128, 128, 128])
|
136 |
+
|
137 |
+
# Use median for robustness against outliers
|
138 |
+
return np.median(masked_pixels, axis=0).astype(int)
|
139 |
+
|
140 |
+
def find_adjacent_segments(self, seg1_mask: np.ndarray, seg2_mask: np.ndarray,
|
141 |
+
min_boundary_length: int = 30) -> np.ndarray:
|
142 |
+
"""Find clean boundaries between segments"""
|
143 |
+
kernel = np.ones((3, 3), np.uint8)
|
144 |
+
dilated1 = cv2.dilate(seg1_mask.astype(np.uint8), kernel, iterations=1)
|
145 |
+
dilated2 = cv2.dilate(seg2_mask.astype(np.uint8), kernel, iterations=1)
|
146 |
+
|
147 |
+
boundary = dilated1 & dilated2
|
148 |
+
|
149 |
+
# Remove small disconnected components
|
150 |
+
contours, _ = cv2.findContours(boundary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
151 |
+
clean_boundary = np.zeros_like(boundary)
|
152 |
+
|
153 |
+
for contour in contours:
|
154 |
+
if cv2.contourArea(contour) >= min_boundary_length:
|
155 |
+
cv2.fillPoly(clean_boundary, [contour], 1)
|
156 |
+
|
157 |
+
return clean_boundary.astype(bool)
|
158 |
+
|
159 |
+
def analyze_contrast(self, image: np.ndarray, segmentation: np.ndarray) -> Dict:
|
160 |
+
"""Perform comprehensive contrast analysis"""
|
161 |
+
h, w = segmentation.shape
|
162 |
+
results = {
|
163 |
+
'critical_issues': [],
|
164 |
+
'high_issues': [],
|
165 |
+
'medium_issues': [],
|
166 |
+
'visualization': image.copy(),
|
167 |
+
'statistics': {}
|
168 |
+
}
|
169 |
+
|
170 |
+
# Build segment information
|
171 |
+
unique_segments = np.unique(segmentation)
|
172 |
+
segment_info = {}
|
173 |
+
|
174 |
+
for seg_id in unique_segments:
|
175 |
+
if seg_id == 0: # Skip background
|
176 |
+
continue
|
177 |
+
|
178 |
+
mask = segmentation == seg_id
|
179 |
+
if np.sum(mask) < 100: # Skip very small segments
|
180 |
+
continue
|
181 |
+
|
182 |
+
category = self.get_object_category(seg_id)
|
183 |
+
if category == 'other':
|
184 |
+
continue
|
185 |
+
|
186 |
+
segment_info[seg_id] = {
|
187 |
+
'category': category,
|
188 |
+
'mask': mask,
|
189 |
+
'color': self.extract_dominant_color(image, mask),
|
190 |
+
'area': np.sum(mask)
|
191 |
+
}
|
192 |
+
|
193 |
+
# Analyze priority relationships
|
194 |
+
issue_counts = {'critical': 0, 'high': 0, 'medium': 0}
|
195 |
+
|
196 |
+
for seg_id1, info1 in segment_info.items():
|
197 |
+
for seg_id2, info2 in segment_info.items():
|
198 |
+
if seg_id1 >= seg_id2:
|
199 |
+
continue
|
200 |
+
|
201 |
+
# Check if this is a priority relationship
|
202 |
+
relationship = tuple(sorted([info1['category'], info2['category']]))
|
203 |
+
if relationship not in self.priority_relationships:
|
204 |
+
continue
|
205 |
+
|
206 |
+
priority, description = self.priority_relationships[relationship]
|
207 |
+
|
208 |
+
# Check if segments are adjacent
|
209 |
+
boundary = self.find_adjacent_segments(info1['mask'], info2['mask'])
|
210 |
+
if not np.any(boundary):
|
211 |
+
continue
|
212 |
+
|
213 |
+
# Calculate contrast
|
214 |
+
wcag_contrast = self.calculate_wcag_contrast(info1['color'], info2['color'])
|
215 |
+
|
216 |
+
# Determine if there's an issue
|
217 |
+
if wcag_contrast < self.wcag_threshold:
|
218 |
+
issue = {
|
219 |
+
'categories': (info1['category'], info2['category']),
|
220 |
+
'contrast_ratio': wcag_contrast,
|
221 |
+
'boundary_area': np.sum(boundary),
|
222 |
+
'description': description,
|
223 |
+
'priority': priority
|
224 |
+
}
|
225 |
+
|
226 |
+
# Color-code boundaries and store issues
|
227 |
+
if priority == 'critical':
|
228 |
+
results['critical_issues'].append(issue)
|
229 |
+
results['visualization'][boundary] = [255, 0, 0] # Red
|
230 |
+
issue_counts['critical'] += 1
|
231 |
+
elif priority == 'high':
|
232 |
+
results['high_issues'].append(issue)
|
233 |
+
results['visualization'][boundary] = [255, 165, 0] # Orange
|
234 |
+
issue_counts['high'] += 1
|
235 |
+
elif priority == 'medium':
|
236 |
+
results['medium_issues'].append(issue)
|
237 |
+
results['visualization'][boundary] = [255, 255, 0] # Yellow
|
238 |
+
issue_counts['medium'] += 1
|
239 |
+
|
240 |
+
# Calculate statistics
|
241 |
+
results['statistics'] = {
|
242 |
+
'total_segments': len(segment_info),
|
243 |
+
'total_issues': sum(issue_counts.values()),
|
244 |
+
'critical_count': issue_counts['critical'],
|
245 |
+
'high_count': issue_counts['high'],
|
246 |
+
'medium_count': issue_counts['medium'],
|
247 |
+
'wcag_threshold': self.wcag_threshold
|
248 |
+
}
|
249 |
+
|
250 |
+
return results
|
251 |
+
|
252 |
+
########################################
|
253 |
+
# ONEFORMER INTEGRATION
|
254 |
+
########################################
|
255 |
+
|
256 |
+
class OneFormerManager:
|
257 |
+
"""Manages OneFormer model loading and inference"""
|
258 |
+
|
259 |
+
def __init__(self):
|
260 |
+
self.predictor = None
|
261 |
+
self.metadata = None
|
262 |
+
self.initialized = False
|
263 |
+
|
264 |
+
def initialize(self, backbone: str = "swin"):
|
265 |
+
"""Initialize OneFormer model"""
|
266 |
+
if not ONEFORMER_AVAILABLE:
|
267 |
+
logger.error("OneFormer not available")
|
268 |
+
return False
|
269 |
+
|
270 |
+
try:
|
271 |
+
cfg = get_cfg()
|
272 |
+
add_deeplab_config(cfg)
|
273 |
+
add_common_config(cfg)
|
274 |
+
add_swin_config(cfg)
|
275 |
+
add_oneformer_config(cfg)
|
276 |
+
add_dinat_config(cfg)
|
277 |
+
|
278 |
+
config = ONEFORMER_CONFIG["ADE20K"]
|
279 |
+
cfg.merge_from_file(config["swin_cfg"])
|
280 |
+
cfg.MODEL.DEVICE = DEVICE
|
281 |
+
|
282 |
+
# Download model if not exists
|
283 |
+
model_path = hf_hub_download(
|
284 |
+
repo_id=config["swin_model"],
|
285 |
+
filename=config["swin_file"]
|
286 |
+
)
|
287 |
+
cfg.MODEL.WEIGHTS = model_path
|
288 |
+
cfg.freeze()
|
289 |
+
|
290 |
+
self.predictor = OneFormerPredictor(cfg)
|
291 |
+
self.metadata = MetadataCatalog.get(
|
292 |
+
cfg.DATASETS.TEST_PANOPTIC[0] if len(cfg.DATASETS.TEST_PANOPTIC) else "__unused"
|
293 |
+
)
|
294 |
+
self.initialized = True
|
295 |
+
logger.info("OneFormer initialized successfully")
|
296 |
+
return True
|
297 |
+
|
298 |
+
except Exception as e:
|
299 |
+
logger.error(f"Failed to initialize OneFormer: {e}")
|
300 |
+
return False
|
301 |
+
|
302 |
+
def semantic_segmentation(self, image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
303 |
+
"""Perform semantic segmentation"""
|
304 |
+
if not self.initialized:
|
305 |
+
raise RuntimeError("OneFormer not initialized")
|
306 |
+
|
307 |
+
# Resize image to expected width
|
308 |
+
width = ONEFORMER_CONFIG["ADE20K"]["width"]
|
309 |
+
h, w = image.shape[:2]
|
310 |
+
if w != width:
|
311 |
+
scale = width / w
|
312 |
+
new_h = int(h * scale)
|
313 |
+
image_resized = cv2.resize(image, (width, new_h))
|
314 |
+
else:
|
315 |
+
image_resized = image
|
316 |
+
|
317 |
+
# Run prediction
|
318 |
+
predictions = self.predictor(image_resized, "semantic")
|
319 |
+
seg_mask = predictions["sem_seg"].argmax(dim=0).cpu().numpy()
|
320 |
+
|
321 |
+
# Create visualization
|
322 |
+
visualizer = Visualizer(
|
323 |
+
image_resized[:, :, ::-1],
|
324 |
+
metadata=self.metadata,
|
325 |
+
instance_mode=ColorMode.IMAGE
|
326 |
+
)
|
327 |
+
vis_output = visualizer.draw_sem_seg(seg_mask, alpha=0.5)
|
328 |
+
vis_image = vis_output.get_image()[:, :, ::-1] # BGR to RGB
|
329 |
+
|
330 |
+
return seg_mask, vis_image
|
331 |
+
|
332 |
+
def extract_floor_areas(self, segmentation: np.ndarray) -> np.ndarray:
|
333 |
+
"""Extract floor areas from segmentation"""
|
334 |
+
floor_mask = np.zeros_like(segmentation, dtype=bool)
|
335 |
+
|
336 |
+
for class_ids in FLOOR_CLASSES.values():
|
337 |
+
for class_id in class_ids:
|
338 |
+
floor_mask |= (segmentation == class_id)
|
339 |
+
|
340 |
+
return floor_mask
|
341 |
+
|
342 |
+
|
343 |
+
########################################
|
344 |
+
# ENHANCED BLACKSPOT DETECTION WITH CLEAR VISUALIZATION
|
345 |
+
########################################
|
346 |
+
|
347 |
+
class BlackspotDetector:
|
348 |
+
"""Manages blackspot detection with MaskRCNN - Enhanced Version"""
|
349 |
+
|
350 |
+
def __init__(self, model_path: str):
|
351 |
+
self.model_path = model_path
|
352 |
+
self.predictor = None
|
353 |
+
|
354 |
+
def initialize(self, threshold: float = 0.5) -> bool:
|
355 |
+
"""Initialize MaskRCNN model"""
|
356 |
+
try:
|
357 |
+
cfg = get_cfg()
|
358 |
+
cfg.merge_from_file(
|
359 |
+
model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
|
360 |
+
)
|
361 |
+
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2 # [floors, blackspot]
|
362 |
+
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = threshold
|
363 |
+
cfg.MODEL.WEIGHTS = self.model_path
|
364 |
+
cfg.MODEL.DEVICE = DEVICE
|
365 |
+
|
366 |
+
self.predictor = DetectronPredictor(cfg)
|
367 |
+
logger.info("MaskRCNN blackspot detector initialized")
|
368 |
+
return True
|
369 |
+
|
370 |
+
except Exception as e:
|
371 |
+
logger.error(f"Failed to initialize blackspot detector: {e}")
|
372 |
+
return False
|
373 |
+
|
374 |
+
def create_enhanced_visualizations(self, image: np.ndarray, floor_mask: np.ndarray,
|
375 |
+
blackspot_mask: np.ndarray) -> Dict:
|
376 |
+
"""Create multiple enhanced visualizations of blackspot detection"""
|
377 |
+
|
378 |
+
# 1. Pure Segmentation View (like semantic segmentation output)
|
379 |
+
segmentation_view = np.zeros((*image.shape[:2], 3), dtype=np.uint8)
|
380 |
+
segmentation_view[floor_mask] = [34, 139, 34] # Forest green for floor
|
381 |
+
segmentation_view[blackspot_mask] = [255, 0, 0] # Bright red for blackspots
|
382 |
+
segmentation_view[~(floor_mask | blackspot_mask)] = [128, 128, 128] # Gray for other areas
|
383 |
+
|
384 |
+
# 2. High Contrast Overlay
|
385 |
+
high_contrast_overlay = image.copy()
|
386 |
+
# Make background slightly darker to emphasize blackspots
|
387 |
+
high_contrast_overlay = cv2.convertScaleAbs(high_contrast_overlay, alpha=0.6, beta=0)
|
388 |
+
# Add bright overlays
|
389 |
+
high_contrast_overlay[floor_mask] = cv2.addWeighted(
|
390 |
+
high_contrast_overlay[floor_mask], 0.7,
|
391 |
+
np.full_like(high_contrast_overlay[floor_mask], [0, 255, 0]), 0.3, 0
|
392 |
+
)
|
393 |
+
high_contrast_overlay[blackspot_mask] = [255, 0, 255] # Magenta for maximum visibility
|
394 |
+
|
395 |
+
# 3. Blackspot-only View (white blackspots on black background)
|
396 |
+
blackspot_only = np.zeros((*image.shape[:2], 3), dtype=np.uint8)
|
397 |
+
blackspot_only[blackspot_mask] = [255, 255, 255] # White blackspots
|
398 |
+
blackspot_only[floor_mask & ~blackspot_mask] = [64, 64, 64] # Dark gray for floor areas
|
399 |
+
|
400 |
+
# 4. Side-by-side comparison
|
401 |
+
h, w = image.shape[:2]
|
402 |
+
side_by_side = np.zeros((h, w * 2, 3), dtype=np.uint8)
|
403 |
+
side_by_side[:, :w] = image
|
404 |
+
side_by_side[:, w:] = segmentation_view
|
405 |
+
|
406 |
+
# Add text labels
|
407 |
+
cv2.putText(side_by_side, "Original", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
|
408 |
+
cv2.putText(side_by_side, "Blackspot Detection", (w + 10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
|
409 |
+
|
410 |
+
# 5. Annotated view with bounding boxes and labels
|
411 |
+
annotated_view = image.copy()
|
412 |
+
|
413 |
+
# Find blackspot contours for bounding boxes
|
414 |
+
blackspot_contours, _ = cv2.findContours(
|
415 |
+
blackspot_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
|
416 |
+
)
|
417 |
+
|
418 |
+
for i, contour in enumerate(blackspot_contours):
|
419 |
+
if cv2.contourArea(contour) > 50: # Filter small artifacts
|
420 |
+
# Draw bounding box
|
421 |
+
x, y, w, h = cv2.boundingRect(contour)
|
422 |
+
cv2.rectangle(annotated_view, (x, y), (x + w, y + h), (255, 0, 0), 2)
|
423 |
+
|
424 |
+
# Draw contour
|
425 |
+
cv2.drawContours(annotated_view, [contour], -1, (255, 0, 255), 2)
|
426 |
+
|
427 |
+
# Add label
|
428 |
+
area = cv2.contourArea(contour)
|
429 |
+
label = f"Blackspot {i+1}: {area:.0f}px"
|
430 |
+
cv2.putText(annotated_view, label, (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 0), 2)
|
431 |
+
|
432 |
+
return {
|
433 |
+
'segmentation_view': segmentation_view,
|
434 |
+
'high_contrast_overlay': high_contrast_overlay,
|
435 |
+
'blackspot_only': blackspot_only,
|
436 |
+
'side_by_side': side_by_side,
|
437 |
+
'annotated_view': annotated_view
|
438 |
+
}
|
439 |
+
|
440 |
+
def detect_blackspots(self, image: np.ndarray, floor_prior: Optional[np.ndarray] = None) -> Dict:
|
441 |
+
"""Detect blackspots with enhanced visualizations"""
|
442 |
+
if self.predictor is None:
|
443 |
+
raise RuntimeError("Blackspot detector not initialized")
|
444 |
+
|
445 |
+
# Get original image dimensions
|
446 |
+
original_h, original_w = image.shape[:2]
|
447 |
+
|
448 |
+
# Handle floor prior shape mismatch
|
449 |
+
processed_image = image.copy()
|
450 |
+
if floor_prior is not None:
|
451 |
+
prior_h, prior_w = floor_prior.shape
|
452 |
+
|
453 |
+
# Resize floor_prior to match original image if needed
|
454 |
+
if (prior_h, prior_w) != (original_h, original_w):
|
455 |
+
logger.info(f"Resizing floor prior from {(prior_h, prior_w)} to {(original_h, original_w)}")
|
456 |
+
floor_prior_resized = cv2.resize(
|
457 |
+
floor_prior.astype(np.uint8),
|
458 |
+
(original_w, original_h),
|
459 |
+
interpolation=cv2.INTER_NEAREST
|
460 |
+
).astype(bool)
|
461 |
+
else:
|
462 |
+
floor_prior_resized = floor_prior
|
463 |
+
else:
|
464 |
+
floor_prior_resized = None
|
465 |
+
|
466 |
+
# Run detection on the processed image
|
467 |
+
try:
|
468 |
+
outputs = self.predictor(processed_image)
|
469 |
+
instances = outputs["instances"].to("cpu")
|
470 |
+
except Exception as e:
|
471 |
+
logger.error(f"Error in MaskRCNN prediction: {e}")
|
472 |
+
# Return empty results
|
473 |
+
empty_mask = np.zeros(image.shape[:2], dtype=bool)
|
474 |
+
return {
|
475 |
+
'visualization': image,
|
476 |
+
'floor_mask': empty_mask,
|
477 |
+
'blackspot_mask': empty_mask,
|
478 |
+
'floor_area': 0,
|
479 |
+
'blackspot_area': 0,
|
480 |
+
'coverage_percentage': 0,
|
481 |
+
'num_detections': 0,
|
482 |
+
'avg_confidence': 0.0,
|
483 |
+
'enhanced_views': self.create_enhanced_visualizations(image, empty_mask, empty_mask)
|
484 |
+
}
|
485 |
+
|
486 |
+
# Process results
|
487 |
+
if len(instances) == 0:
|
488 |
+
# No detections
|
489 |
+
combined_floor = floor_prior_resized if floor_prior_resized is not None else np.zeros(image.shape[:2], dtype=bool)
|
490 |
+
combined_blackspot = np.zeros(image.shape[:2], dtype=bool)
|
491 |
+
blackspot_scores = []
|
492 |
+
else:
|
493 |
+
pred_classes = instances.pred_classes.numpy()
|
494 |
+
pred_masks = instances.pred_masks.numpy()
|
495 |
+
scores = instances.scores.numpy()
|
496 |
+
|
497 |
+
# Separate floor and blackspot masks
|
498 |
+
floor_indices = pred_classes == 0
|
499 |
+
blackspot_indices = pred_classes == 1
|
500 |
+
|
501 |
+
floor_masks = pred_masks[floor_indices] if np.any(floor_indices) else []
|
502 |
+
blackspot_masks = pred_masks[blackspot_indices] if np.any(blackspot_indices) else []
|
503 |
+
blackspot_scores = scores[blackspot_indices] if np.any(blackspot_indices) else []
|
504 |
+
|
505 |
+
# Combine masks
|
506 |
+
combined_floor = np.zeros(image.shape[:2], dtype=bool)
|
507 |
+
combined_blackspot = np.zeros(image.shape[:2], dtype=bool)
|
508 |
+
|
509 |
+
for mask in floor_masks:
|
510 |
+
combined_floor |= mask
|
511 |
+
|
512 |
+
for mask in blackspot_masks:
|
513 |
+
combined_blackspot |= mask
|
514 |
+
|
515 |
+
# Apply floor prior if available
|
516 |
+
if floor_prior_resized is not None:
|
517 |
+
# Combine OneFormer floor detection with MaskRCNN floor detection
|
518 |
+
combined_floor |= floor_prior_resized
|
519 |
+
# Keep only blackspots that are on floors
|
520 |
+
combined_blackspot &= combined_floor
|
521 |
+
|
522 |
+
# Create all enhanced visualizations
|
523 |
+
enhanced_views = self.create_enhanced_visualizations(image, combined_floor, combined_blackspot)
|
524 |
+
|
525 |
+
# Calculate statistics
|
526 |
+
floor_area = int(np.sum(combined_floor))
|
527 |
+
blackspot_area = int(np.sum(combined_blackspot))
|
528 |
+
coverage_percentage = (blackspot_area / floor_area * 100) if floor_area > 0 else 0
|
529 |
+
|
530 |
+
# Count individual blackspot instances
|
531 |
+
blackspot_contours, _ = cv2.findContours(
|
532 |
+
combined_blackspot.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
|
533 |
+
)
|
534 |
+
actual_detections = len([c for c in blackspot_contours if cv2.contourArea(c) > 50])
|
535 |
+
|
536 |
+
return {
|
537 |
+
'visualization': enhanced_views['high_contrast_overlay'], # Main view
|
538 |
+
'floor_mask': combined_floor,
|
539 |
+
'blackspot_mask': combined_blackspot,
|
540 |
+
'floor_area': floor_area,
|
541 |
+
'blackspot_area': blackspot_area,
|
542 |
+
'coverage_percentage': coverage_percentage,
|
543 |
+
'num_detections': actual_detections,
|
544 |
+
'avg_confidence': float(np.mean(blackspot_scores)) if len(blackspot_scores) > 0 else 0.0,
|
545 |
+
'enhanced_views': enhanced_views # All visualization options
|
546 |
+
}
|
547 |
+
########################################
|
548 |
+
# FIXED MAIN APPLICATION CLASS
|
549 |
+
########################################
|
550 |
+
|
551 |
+
class NeuroNestApp:
|
552 |
+
"""Main application class integrating all components - FIXED VERSION"""
|
553 |
+
|
554 |
+
def __init__(self):
|
555 |
+
self.oneformer = OneFormerManager()
|
556 |
+
self.blackspot_detector = None
|
557 |
+
self.contrast_analyzer = UniversalContrastAnalyzer()
|
558 |
+
self.initialized = False
|
559 |
+
|
560 |
+
def initialize(self, blackspot_model_path: str = "./output_floor_blackspot/model_0004999.pth"):
|
561 |
+
"""Initialize all components"""
|
562 |
+
logger.info("Initializing NeuroNest application...")
|
563 |
+
|
564 |
+
# Initialize OneFormer
|
565 |
+
oneformer_success = self.oneformer.initialize()
|
566 |
+
|
567 |
+
# Initialize blackspot detector if model exists
|
568 |
+
blackspot_success = False
|
569 |
+
if os.path.exists(blackspot_model_path):
|
570 |
+
self.blackspot_detector = BlackspotDetector(blackspot_model_path)
|
571 |
+
blackspot_success = True
|
572 |
+
else:
|
573 |
+
logger.warning(f"Blackspot model not found at {blackspot_model_path}")
|
574 |
+
|
575 |
+
self.initialized = oneformer_success
|
576 |
+
return oneformer_success, blackspot_success
|
577 |
+
|
578 |
+
def analyze_image(self,
|
579 |
+
image_path: str,
|
580 |
+
blackspot_threshold: float = 0.5,
|
581 |
+
contrast_threshold: float = 4.5,
|
582 |
+
enable_blackspot: bool = True,
|
583 |
+
enable_contrast: bool = True) -> Dict:
|
584 |
+
"""Perform complete image analysis - FIXED VERSION"""
|
585 |
+
|
586 |
+
if not self.initialized:
|
587 |
+
return {"error": "Application not properly initialized"}
|
588 |
+
|
589 |
+
try:
|
590 |
+
# Load and preprocess image
|
591 |
+
image = cv2.imread(image_path)
|
592 |
+
if image is None:
|
593 |
+
return {"error": "Could not load image"}
|
594 |
+
|
595 |
+
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
596 |
+
logger.info(f"Loaded image with shape: {image_rgb.shape}")
|
597 |
+
|
598 |
+
results = {
|
599 |
+
'original_image': image_rgb,
|
600 |
+
'segmentation': None,
|
601 |
+
'blackspot': None,
|
602 |
+
'contrast': None,
|
603 |
+
'statistics': {}
|
604 |
+
}
|
605 |
+
|
606 |
+
# 1. Semantic Segmentation (always performed)
|
607 |
+
logger.info("Running semantic segmentation...")
|
608 |
+
seg_mask, seg_visualization = self.oneformer.semantic_segmentation(image_rgb)
|
609 |
+
logger.info(f"Segmentation mask shape: {seg_mask.shape}")
|
610 |
+
|
611 |
+
results['segmentation'] = {
|
612 |
+
'visualization': seg_visualization,
|
613 |
+
'mask': seg_mask
|
614 |
+
}
|
615 |
+
|
616 |
+
# Extract floor areas for blackspot detection
|
617 |
+
floor_prior = self.oneformer.extract_floor_areas(seg_mask)
|
618 |
+
logger.info(f"Floor prior shape: {floor_prior.shape}, total floor pixels: {np.sum(floor_prior)}")
|
619 |
+
|
620 |
+
# 2. Blackspot Detection (if enabled and model available)
|
621 |
+
if enable_blackspot and self.blackspot_detector is not None:
|
622 |
+
logger.info("Running blackspot detection...")
|
623 |
+
try:
|
624 |
+
self.blackspot_detector.initialize(threshold=blackspot_threshold)
|
625 |
+
blackspot_results = self.blackspot_detector.detect_blackspots(image_rgb, floor_prior)
|
626 |
+
results['blackspot'] = blackspot_results
|
627 |
+
logger.info("Blackspot detection completed successfully")
|
628 |
+
except Exception as e:
|
629 |
+
logger.error(f"Error in blackspot detection: {e}")
|
630 |
+
# Continue without blackspot results
|
631 |
+
results['blackspot'] = None
|
632 |
+
|
633 |
+
# 3. Contrast Analysis (if enabled)
|
634 |
+
if enable_contrast:
|
635 |
+
logger.info("Running contrast analysis...")
|
636 |
+
try:
|
637 |
+
# Use the resized image for contrast analysis to match segmentation
|
638 |
+
width = ONEFORMER_CONFIG["ADE20K"]["width"]
|
639 |
+
h, w = image_rgb.shape[:2]
|
640 |
+
if w != width:
|
641 |
+
scale = width / w
|
642 |
+
new_h = int(h * scale)
|
643 |
+
image_for_contrast = cv2.resize(image_rgb, (width, new_h))
|
644 |
+
else:
|
645 |
+
image_for_contrast = image_rgb
|
646 |
+
|
647 |
+
contrast_results = self.contrast_analyzer.analyze_contrast(image_for_contrast, seg_mask)
|
648 |
+
results['contrast'] = contrast_results
|
649 |
+
logger.info("Contrast analysis completed successfully")
|
650 |
+
except Exception as e:
|
651 |
+
logger.error(f"Error in contrast analysis: {e}")
|
652 |
+
# Continue without contrast results
|
653 |
+
results['contrast'] = None
|
654 |
+
|
655 |
+
# 4. Generate combined statistics
|
656 |
+
stats = self._generate_statistics(results)
|
657 |
+
results['statistics'] = stats
|
658 |
+
|
659 |
+
logger.info("Image analysis completed successfully")
|
660 |
+
return results
|
661 |
+
|
662 |
+
except Exception as e:
|
663 |
+
logger.error(f"Error in image analysis: {e}")
|
664 |
+
import traceback
|
665 |
+
traceback.print_exc()
|
666 |
+
return {"error": f"Analysis failed: {str(e)}"}
|
667 |
+
|
668 |
+
def _generate_statistics(self, results: Dict) -> Dict:
|
669 |
+
"""Generate comprehensive statistics"""
|
670 |
+
stats = {}
|
671 |
+
|
672 |
+
# Segmentation stats
|
673 |
+
if results['segmentation']:
|
674 |
+
unique_classes = np.unique(results['segmentation']['mask'])
|
675 |
+
stats['segmentation'] = {
|
676 |
+
'num_classes': len(unique_classes),
|
677 |
+
'image_size': results['segmentation']['mask'].shape
|
678 |
+
}
|
679 |
+
|
680 |
+
# Blackspot stats
|
681 |
+
if results['blackspot']:
|
682 |
+
bs = results['blackspot']
|
683 |
+
stats['blackspot'] = {
|
684 |
+
'floor_area_pixels': bs['floor_area'],
|
685 |
+
'blackspot_area_pixels': bs['blackspot_area'],
|
686 |
+
'coverage_percentage': bs['coverage_percentage'],
|
687 |
+
'num_detections': bs['num_detections'],
|
688 |
+
'avg_confidence': bs['avg_confidence']
|
689 |
+
}
|
690 |
+
|
691 |
+
# Contrast stats
|
692 |
+
if results['contrast']:
|
693 |
+
cs = results['contrast']['statistics']
|
694 |
+
# Count issues by severity
|
695 |
+
critical_count = sum(1 for issue in results['contrast'].get('issues', []) if issue['severity'] == 'critical')
|
696 |
+
high_count = sum(1 for issue in results['contrast'].get('issues', []) if issue['severity'] == 'high')
|
697 |
+
medium_count = sum(1 for issue in results['contrast'].get('issues', []) if issue['severity'] == 'medium')
|
698 |
+
|
699 |
+
stats['contrast'] = {
|
700 |
+
'total_issues': cs.get('low_contrast_pairs', 0),
|
701 |
+
'critical_issues': critical_count,
|
702 |
+
'high_priority_issues': high_count,
|
703 |
+
'medium_priority_issues': medium_count,
|
704 |
+
'segments_analyzed': cs.get('total_segments', 0),
|
705 |
+
'floor_object_issues': cs.get('floor_object_issues', 0)
|
706 |
+
}
|
707 |
+
|
708 |
+
return stats
|
709 |
+
########################################
|
710 |
+
# GRADIO INTERFACE
|
711 |
+
########################################
|
712 |
+
|
713 |
+
########################################
|
714 |
+
# ENHANCED GRADIO INTERFACE WITH MULTIPLE BLACKSPOT VIEWS
|
715 |
+
########################################
|
716 |
+
|
717 |
+
def create_gradio_interface():
|
718 |
+
"""Create the enhanced Gradio interface with better blackspot visualization"""
|
719 |
+
|
720 |
+
# Initialize the application
|
721 |
+
app = NeuroNestApp()
|
722 |
+
oneformer_ok, blackspot_ok = app.initialize()
|
723 |
+
|
724 |
+
if not oneformer_ok:
|
725 |
+
raise RuntimeError("Failed to initialize OneFormer")
|
726 |
+
|
727 |
+
def analyze_wrapper(image_path, blackspot_threshold, contrast_threshold,
|
728 |
+
enable_blackspot, enable_contrast, blackspot_view_type):
|
729 |
+
"""Enhanced wrapper function for Gradio interface"""
|
730 |
+
if image_path is None:
|
731 |
+
return None, None, None, None, None, "Please upload an image"
|
732 |
+
|
733 |
+
results = app.analyze_image(
|
734 |
+
image_path=image_path,
|
735 |
+
blackspot_threshold=blackspot_threshold,
|
736 |
+
contrast_threshold=contrast_threshold,
|
737 |
+
enable_blackspot=enable_blackspot,
|
738 |
+
enable_contrast=enable_contrast
|
739 |
+
)
|
740 |
+
|
741 |
+
if "error" in results:
|
742 |
+
return None, None, None, None, None, f"Error: {results['error']}"
|
743 |
+
|
744 |
+
# Extract outputs
|
745 |
+
seg_output = results['segmentation']['visualization'] if results['segmentation'] else None
|
746 |
+
|
747 |
+
# Enhanced blackspot output selection
|
748 |
+
blackspot_output = None
|
749 |
+
blackspot_segmentation = None
|
750 |
+
if results['blackspot'] and 'enhanced_views' in results['blackspot']:
|
751 |
+
views = results['blackspot']['enhanced_views']
|
752 |
+
|
753 |
+
# Select view based on user choice
|
754 |
+
if blackspot_view_type == "High Contrast":
|
755 |
+
blackspot_output = views['high_contrast_overlay']
|
756 |
+
elif blackspot_view_type == "Segmentation Only":
|
757 |
+
blackspot_output = views['segmentation_view']
|
758 |
+
elif blackspot_view_type == "Blackspots Only":
|
759 |
+
blackspot_output = views['blackspot_only']
|
760 |
+
elif blackspot_view_type == "Side by Side":
|
761 |
+
blackspot_output = views['side_by_side']
|
762 |
+
elif blackspot_view_type == "Annotated":
|
763 |
+
blackspot_output = views['annotated_view']
|
764 |
+
else:
|
765 |
+
blackspot_output = views['high_contrast_overlay']
|
766 |
+
|
767 |
+
# Always provide segmentation view for the dedicated tab
|
768 |
+
blackspot_segmentation = views['segmentation_view']
|
769 |
+
|
770 |
+
contrast_output = results['contrast']['visualization'] if results['contrast'] else None
|
771 |
+
|
772 |
+
# Generate report
|
773 |
+
report = generate_analysis_report(results)
|
774 |
+
|
775 |
+
return seg_output, blackspot_output, blackspot_segmentation, contrast_output, report
|
776 |
+
|
777 |
+
# Update the generate_analysis_report function
|
778 |
+
def generate_analysis_report(results: Dict) -> str:
|
779 |
+
"""Generate enhanced analysis report text"""
|
780 |
+
report = ["# NeuroNest Analysis Report\n"]
|
781 |
+
|
782 |
+
# Segmentation results
|
783 |
+
if results['segmentation']:
|
784 |
+
stats = results['statistics'].get('segmentation', {})
|
785 |
+
report.append(f"## 🎯 Semantic Segmentation")
|
786 |
+
report.append(f"- **Objects detected:** {stats.get('num_classes', 'N/A')}")
|
787 |
+
report.append(f"- **Image size:** {stats.get('image_size', 'N/A')}")
|
788 |
+
report.append("")
|
789 |
+
|
790 |
+
# Enhanced blackspot results
|
791 |
+
if results['blackspot']:
|
792 |
+
bs_stats = results['statistics'].get('blackspot', {})
|
793 |
+
report.append(f"## ⚫ Blackspot Detection")
|
794 |
+
report.append(f"- **Floor area:** {bs_stats.get('floor_area_pixels', 0):,} pixels")
|
795 |
+
report.append(f"- **Blackspot area:** {bs_stats.get('blackspot_area_pixels', 0):,} pixels")
|
796 |
+
report.append(f"- **Coverage:** {bs_stats.get('coverage_percentage', 0):.2f}% of floor")
|
797 |
+
report.append(f"- **Individual blackspots:** {bs_stats.get('num_detections', 0)}")
|
798 |
+
report.append(f"- **Average confidence:** {bs_stats.get('avg_confidence', 0):.2f}")
|
799 |
+
|
800 |
+
# Risk assessment
|
801 |
+
coverage = bs_stats.get('coverage_percentage', 0)
|
802 |
+
if coverage > 5:
|
803 |
+
report.append(f"- **⚠️ Risk Level:** HIGH - Significant blackspot coverage detected")
|
804 |
+
elif coverage > 1:
|
805 |
+
report.append(f"- **⚠️ Risk Level:** MEDIUM - Moderate blackspot coverage")
|
806 |
+
elif coverage > 0:
|
807 |
+
report.append(f"- **✓ Risk Level:** LOW - Minimal blackspot coverage")
|
808 |
+
else:
|
809 |
+
report.append(f"- **✓ Risk Level:** NONE - No blackspots detected")
|
810 |
+
report.append("")
|
811 |
+
|
812 |
+
# Contrast analysis results (updated for universal analyzer)
|
813 |
+
if results['contrast']:
|
814 |
+
contrast_stats = results['statistics'].get('contrast', {})
|
815 |
+
report.append(f"## 🎨 Universal Contrast Analysis")
|
816 |
+
report.append(f"- **Adjacent pairs analyzed:** {results['contrast']['statistics'].get('analyzed_pairs', 0)}")
|
817 |
+
report.append(f"- **Total contrast issues:** {contrast_stats.get('total_issues', 0)}")
|
818 |
+
report.append(f"- **🔴 Critical:** {contrast_stats.get('critical_issues', 0)}")
|
819 |
+
report.append(f"- **🟠 High priority:** {contrast_stats.get('high_priority_issues', 0)}")
|
820 |
+
report.append(f"- **🟡 Medium priority:** {contrast_stats.get('medium_priority_issues', 0)}")
|
821 |
+
report.append(f"- **⚠️ Floor-object issues:** {contrast_stats.get('floor_object_issues', 0)}")
|
822 |
+
report.append("")
|
823 |
+
|
824 |
+
# Add detailed issues
|
825 |
+
issues = results['contrast'].get('issues', [])
|
826 |
+
if issues:
|
827 |
+
# Group by severity
|
828 |
+
critical_issues = [i for i in issues if i['severity'] == 'critical']
|
829 |
+
high_issues = [i for i in issues if i['severity'] == 'high']
|
830 |
+
|
831 |
+
if critical_issues:
|
832 |
+
report.append("### 🔴 Critical Issues (Immediate Attention Required)")
|
833 |
+
for issue in critical_issues[:5]: # Show top 5
|
834 |
+
cats = f"{issue['categories'][0]} ↔ {issue['categories'][1]}"
|
835 |
+
ratio = issue['wcag_ratio']
|
836 |
+
report.append(f"- **{cats}**: {ratio:.1f}:1 contrast ratio")
|
837 |
+
if issue['is_floor_object']:
|
838 |
+
report.append(f" _⚠️ Object on floor - high visibility required!_")
|
839 |
+
report.append("")
|
840 |
+
|
841 |
+
if high_issues:
|
842 |
+
report.append("### 🟠 High Priority Issues")
|
843 |
+
for issue in high_issues[:3]: # Show top 3
|
844 |
+
cats = f"{issue['categories'][0]} ↔ {issue['categories'][1]}"
|
845 |
+
ratio = issue['wcag_ratio']
|
846 |
+
report.append(f"- **{cats}**: {ratio:.1f}:1 contrast ratio")
|
847 |
+
report.append("")
|
848 |
+
|
849 |
+
# Enhanced recommendations
|
850 |
+
report.append("## 📋 Recommendations")
|
851 |
+
|
852 |
+
# Blackspot-specific recommendations
|
853 |
+
if results['blackspot']:
|
854 |
+
coverage = results['statistics'].get('blackspot', {}).get('coverage_percentage', 0)
|
855 |
+
if coverage > 0:
|
856 |
+
report.append("### Blackspot Mitigation")
|
857 |
+
report.append("- Remove or replace dark-colored floor materials in detected areas")
|
858 |
+
report.append("- Improve lighting in blackspot areas")
|
859 |
+
report.append("- Consider using light-colored rugs or mats to cover blackspots")
|
860 |
+
report.append("- Add visual cues like contrasting tape around problem areas")
|
861 |
+
report.append("")
|
862 |
+
|
863 |
+
# Contrast-specific recommendations
|
864 |
+
contrast_issues = results['statistics'].get('contrast', {}).get('total_issues', 0)
|
865 |
+
if contrast_issues > 0:
|
866 |
+
report.append("### Contrast Improvements")
|
867 |
+
report.append("- Increase lighting in low-contrast areas")
|
868 |
+
report.append("- Use contrasting colors for furniture and floors")
|
869 |
+
report.append("- Add visual markers for important boundaries")
|
870 |
+
report.append("- Consider color therapy guidelines for dementia")
|
871 |
+
report.append("")
|
872 |
+
|
873 |
+
if coverage == 0 and contrast_issues == 0:
|
874 |
+
report.append("✅ **Environment Assessment: EXCELLENT**")
|
875 |
+
report.append("No significant safety issues detected. This environment appears well-suited for individuals with Alzheimer's.")
|
876 |
+
|
877 |
+
return "\n".join(report)
|
878 |
+
|
879 |
+
# Create the interface with enhanced controls
|
880 |
+
title = "🧠 NeuroNest: Advanced Environment Analysis for Alzheimer's Care"
|
881 |
+
description = """
|
882 |
+
**Comprehensive analysis system for creating Alzheimer's-friendly environments**
|
883 |
+
|
884 |
+
This application integrates:
|
885 |
+
- **Semantic Segmentation**: Identifies rooms, furniture, and objects
|
886 |
+
- **Enhanced Blackspot Detection**: Locates and visualizes dangerous black areas on floors
|
887 |
+
- **Contrast Analysis**: Evaluates color contrast for visual accessibility
|
888 |
+
"""
|
889 |
+
|
890 |
+
with gr.Blocks(
|
891 |
+
title=title,
|
892 |
+
theme=gr.themes.Soft(primary_hue="orange", secondary_hue="blue"),
|
893 |
+
css="""
|
894 |
+
.main-header { text-align: center; margin-bottom: 2rem; }
|
895 |
+
.analysis-section { border: 2px solid #f0f0f0; border-radius: 10px; padding: 1rem; margin: 1rem 0; }
|
896 |
+
.critical-text { color: #ff0000; font-weight: bold; }
|
897 |
+
.high-text { color: #ff8800; font-weight: bold; }
|
898 |
+
.medium-text { color: #ffaa00; font-weight: bold; }
|
899 |
+
"""
|
900 |
+
) as interface:
|
901 |
+
|
902 |
+
gr.Markdown(f"# {title}")
|
903 |
+
gr.Markdown(description)
|
904 |
+
|
905 |
+
with gr.Row():
|
906 |
+
# Input Column
|
907 |
+
with gr.Column(scale=1):
|
908 |
+
# Image upload
|
909 |
+
image_input = gr.Image(
|
910 |
+
label="📸 Upload Room Image",
|
911 |
+
type="filepath",
|
912 |
+
height=300
|
913 |
+
)
|
914 |
+
|
915 |
+
# Analysis settings
|
916 |
+
with gr.Accordion("🔧 Analysis Settings", open=True):
|
917 |
+
enable_blackspot = gr.Checkbox(
|
918 |
+
value=blackspot_ok,
|
919 |
+
label="Enable Blackspot Detection",
|
920 |
+
interactive=blackspot_ok
|
921 |
+
)
|
922 |
+
|
923 |
+
blackspot_threshold = gr.Slider(
|
924 |
+
minimum=0.1,
|
925 |
+
maximum=0.9,
|
926 |
+
value=0.5,
|
927 |
+
step=0.05,
|
928 |
+
label="Blackspot Detection Threshold",
|
929 |
+
visible=blackspot_ok
|
930 |
+
)
|
931 |
+
|
932 |
+
# NEW: Blackspot visualization options
|
933 |
+
blackspot_view_type = gr.Radio(
|
934 |
+
choices=["High Contrast", "Segmentation Only", "Blackspots Only", "Side by Side", "Annotated"],
|
935 |
+
value="High Contrast",
|
936 |
+
label="Blackspot Visualization Style",
|
937 |
+
visible=blackspot_ok
|
938 |
+
)
|
939 |
+
|
940 |
+
enable_contrast = gr.Checkbox(
|
941 |
+
value=True,
|
942 |
+
label="Enable Contrast Analysis"
|
943 |
+
)
|
944 |
+
|
945 |
+
contrast_threshold = gr.Slider(
|
946 |
+
minimum=1.0,
|
947 |
+
maximum=10.0,
|
948 |
+
value=4.5,
|
949 |
+
step=0.1,
|
950 |
+
label="WCAG Contrast Threshold"
|
951 |
+
)
|
952 |
+
|
953 |
+
# Analysis button
|
954 |
+
analyze_button = gr.Button(
|
955 |
+
"🔍 Analyze Environment",
|
956 |
+
variant="primary",
|
957 |
+
size="lg"
|
958 |
+
)
|
959 |
+
|
960 |
+
# Output Column
|
961 |
+
with gr.Column(scale=2):
|
962 |
+
# Main display (Segmentation by default)
|
963 |
+
main_display = gr.Image(
|
964 |
+
label="🎯 Object Detection & Segmentation",
|
965 |
+
height=400,
|
966 |
+
interactive=False
|
967 |
+
)
|
968 |
+
|
969 |
+
# Enhanced analysis tabs
|
970 |
+
with gr.Tabs():
|
971 |
+
with gr.Tab("📊 Analysis Report"):
|
972 |
+
analysis_report = gr.Markdown(
|
973 |
+
value="Upload an image and click 'Analyze Environment' to see results.",
|
974 |
+
elem_classes=["analysis-section"]
|
975 |
+
)
|
976 |
+
|
977 |
+
if blackspot_ok:
|
978 |
+
with gr.Tab("⚫ Blackspot Detection"):
|
979 |
+
blackspot_display = gr.Image(
|
980 |
+
label="Blackspot Analysis (Selected View)",
|
981 |
+
height=300,
|
982 |
+
interactive=False
|
983 |
+
)
|
984 |
+
|
985 |
+
with gr.Tab("🔍 Blackspot Segmentation"):
|
986 |
+
blackspot_segmentation_display = gr.Image(
|
987 |
+
label="Pure Blackspot Segmentation",
|
988 |
+
height=300,
|
989 |
+
interactive=False
|
990 |
+
)
|
991 |
+
else:
|
992 |
+
blackspot_display = gr.Image(visible=False)
|
993 |
+
blackspot_segmentation_display = gr.Image(visible=False)
|
994 |
+
|
995 |
+
with gr.Tab("🎨 Contrast Analysis"):
|
996 |
+
contrast_display = gr.Image(
|
997 |
+
label="Contrast Issues Visualization",
|
998 |
+
height=300,
|
999 |
+
interactive=False
|
1000 |
+
)
|
1001 |
+
|
1002 |
+
# Connect the interface
|
1003 |
+
analyze_button.click(
|
1004 |
+
fn=analyze_wrapper,
|
1005 |
+
inputs=[
|
1006 |
+
image_input,
|
1007 |
+
blackspot_threshold,
|
1008 |
+
contrast_threshold,
|
1009 |
+
enable_blackspot,
|
1010 |
+
enable_contrast,
|
1011 |
+
blackspot_view_type
|
1012 |
+
],
|
1013 |
+
outputs=[
|
1014 |
+
main_display,
|
1015 |
+
blackspot_display,
|
1016 |
+
blackspot_segmentation_display,
|
1017 |
+
contrast_display,
|
1018 |
+
analysis_report
|
1019 |
+
]
|
1020 |
+
)
|
1021 |
+
|
1022 |
+
# Example images (optional)
|
1023 |
+
example_dir = Path("examples")
|
1024 |
+
if example_dir.exists():
|
1025 |
+
examples = [
|
1026 |
+
[str(img), 0.5, 4.5, True, True, "High Contrast"]
|
1027 |
+
for img in example_dir.glob("*.jpg")
|
1028 |
+
]
|
1029 |
+
|
1030 |
+
if examples:
|
1031 |
+
gr.Examples(
|
1032 |
+
examples=examples[:3], # Show max 3 examples
|
1033 |
+
inputs=[
|
1034 |
+
image_input,
|
1035 |
+
blackspot_threshold,
|
1036 |
+
contrast_threshold,
|
1037 |
+
enable_blackspot,
|
1038 |
+
enable_contrast,
|
1039 |
+
blackspot_view_type
|
1040 |
+
],
|
1041 |
+
outputs=[
|
1042 |
+
main_display,
|
1043 |
+
blackspot_display,
|
1044 |
+
blackspot_segmentation_display,
|
1045 |
+
contrast_display,
|
1046 |
+
analysis_report
|
1047 |
+
],
|
1048 |
+
fn=analyze_wrapper,
|
1049 |
+
label="🖼️ Example Images"
|
1050 |
+
)
|
1051 |
+
|
1052 |
+
# Footer
|
1053 |
+
gr.Markdown("""
|
1054 |
+
---
|
1055 |
+
**NeuroNest** - Advanced AI for Alzheimer's-friendly environments
|
1056 |
+
*Helping create safer, more accessible spaces for cognitive health*
|
1057 |
+
""")
|
1058 |
+
|
1059 |
+
return interface
|
1060 |
+
|
1061 |
+
###############################
|
1062 |
+
# MAIN EXECUTION - FIXED
|
1063 |
+
########################################
|
1064 |
+
|
1065 |
+
if __name__ == "__main__":
|
1066 |
+
print(f"🚀 Starting NeuroNest on {DEVICE}")
|
1067 |
+
print(f"OneFormer available: {ONEFORMER_AVAILABLE}")
|
1068 |
+
|
1069 |
+
try:
|
1070 |
+
interface = create_gradio_interface()
|
1071 |
+
|
1072 |
+
# Fixed launch call - removed incompatible parameters
|
1073 |
+
interface.queue(max_size=10).launch(
|
1074 |
+
server_name="0.0.0.0",
|
1075 |
+
server_port=7860,
|
1076 |
+
share=True
|
1077 |
+
)
|
1078 |
+
except Exception as e:
|
1079 |
+
logger.error(f"Failed to launch application: {e}")
|
1080 |
+
raise
|
oneformer/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
oneformer/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
from . import data # register all new datasets
|
3 |
+
from . import modeling
|
4 |
+
|
5 |
+
# config
|
6 |
+
from .config import *
|
7 |
+
|
8 |
+
# models
|
9 |
+
from .oneformer_model import OneFormer
|
oneformer/config.py
ADDED
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
3 |
+
from detectron2.config import CfgNode as CN
|
4 |
+
|
5 |
+
__all__ = ["add_common_config", "add_oneformer_config", "add_swin_config",
|
6 |
+
"add_dinat_config", "add_beit_adapter_config", "add_convnext_config"]
|
7 |
+
|
8 |
+
def add_common_config(cfg):
|
9 |
+
"""
|
10 |
+
Add config for common configuration
|
11 |
+
"""
|
12 |
+
# data config
|
13 |
+
# select the dataset mapper
|
14 |
+
cfg.INPUT.DATASET_MAPPER_NAME = "oneformer_unified"
|
15 |
+
# Color augmentation
|
16 |
+
cfg.INPUT.COLOR_AUG_SSD = False
|
17 |
+
# We retry random cropping until no single category in semantic segmentation GT occupies more
|
18 |
+
# than `SINGLE_CATEGORY_MAX_AREA` part of the crop.
|
19 |
+
cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0
|
20 |
+
# Pad image and segmentation GT in dataset mapper.
|
21 |
+
cfg.INPUT.SIZE_DIVISIBILITY = -1
|
22 |
+
|
23 |
+
cfg.INPUT.TASK_SEQ_LEN = 77
|
24 |
+
cfg.INPUT.MAX_SEQ_LEN = 77
|
25 |
+
|
26 |
+
cfg.INPUT.TASK_PROB = CN()
|
27 |
+
cfg.INPUT.TASK_PROB.SEMANTIC = 0.33
|
28 |
+
cfg.INPUT.TASK_PROB.INSTANCE = 0.66
|
29 |
+
|
30 |
+
# test dataset
|
31 |
+
cfg.DATASETS.TEST_PANOPTIC = ("",)
|
32 |
+
cfg.DATASETS.TEST_INSTANCE = ("",)
|
33 |
+
cfg.DATASETS.TEST_SEMANTIC = ("",)
|
34 |
+
|
35 |
+
# solver config
|
36 |
+
# weight decay on embedding
|
37 |
+
cfg.SOLVER.WEIGHT_DECAY_EMBED = 0.0
|
38 |
+
# optimizer
|
39 |
+
cfg.SOLVER.OPTIMIZER = "ADAMW"
|
40 |
+
cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1
|
41 |
+
|
42 |
+
# wandb
|
43 |
+
cfg.WANDB = CN()
|
44 |
+
cfg.WANDB.PROJECT = "unified_dense_recognition"
|
45 |
+
cfg.WANDB.NAME = None
|
46 |
+
|
47 |
+
cfg.MODEL.IS_TRAIN = False
|
48 |
+
cfg.MODEL.IS_DEMO = True
|
49 |
+
|
50 |
+
# text encoder config
|
51 |
+
cfg.MODEL.TEXT_ENCODER = CN()
|
52 |
+
|
53 |
+
cfg.MODEL.TEXT_ENCODER.WIDTH = 256
|
54 |
+
cfg.MODEL.TEXT_ENCODER.CONTEXT_LENGTH = 77
|
55 |
+
cfg.MODEL.TEXT_ENCODER.NUM_LAYERS = 12
|
56 |
+
cfg.MODEL.TEXT_ENCODER.VOCAB_SIZE = 49408
|
57 |
+
cfg.MODEL.TEXT_ENCODER.PROJ_NUM_LAYERS = 2
|
58 |
+
cfg.MODEL.TEXT_ENCODER.N_CTX = 16
|
59 |
+
|
60 |
+
# mask_former inference config
|
61 |
+
cfg.MODEL.TEST = CN()
|
62 |
+
cfg.MODEL.TEST.SEMANTIC_ON = True
|
63 |
+
cfg.MODEL.TEST.INSTANCE_ON = False
|
64 |
+
cfg.MODEL.TEST.PANOPTIC_ON = False
|
65 |
+
cfg.MODEL.TEST.DETECTION_ON = False
|
66 |
+
cfg.MODEL.TEST.OBJECT_MASK_THRESHOLD = 0.0
|
67 |
+
cfg.MODEL.TEST.OVERLAP_THRESHOLD = 0.0
|
68 |
+
cfg.MODEL.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE = False
|
69 |
+
cfg.MODEL.TEST.TASK = "panoptic"
|
70 |
+
|
71 |
+
# TEST AUG Slide
|
72 |
+
cfg.TEST.AUG.IS_SLIDE = False
|
73 |
+
cfg.TEST.AUG.CROP_SIZE = (640, 640)
|
74 |
+
cfg.TEST.AUG.STRIDE = (426, 426)
|
75 |
+
cfg.TEST.AUG.SCALE = (2048, 640)
|
76 |
+
cfg.TEST.AUG.SETR_MULTI_SCALE = True
|
77 |
+
cfg.TEST.AUG.KEEP_RATIO = True
|
78 |
+
cfg.TEST.AUG.SIZE_DIVISOR = 32
|
79 |
+
|
80 |
+
# pixel decoder config
|
81 |
+
cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
|
82 |
+
# adding transformer in pixel decoder
|
83 |
+
cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 0
|
84 |
+
# pixel decoder
|
85 |
+
cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME = "BasePixelDecoder"
|
86 |
+
cfg.MODEL.SEM_SEG_HEAD.SEM_EMBED_DIM = 256
|
87 |
+
cfg.MODEL.SEM_SEG_HEAD.INST_EMBED_DIM = 256
|
88 |
+
|
89 |
+
# LSJ aug
|
90 |
+
cfg.INPUT.IMAGE_SIZE = 1024
|
91 |
+
cfg.INPUT.MIN_SCALE = 0.1
|
92 |
+
cfg.INPUT.MAX_SCALE = 2.0
|
93 |
+
|
94 |
+
# MSDeformAttn encoder configs
|
95 |
+
cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["res3", "res4", "res5"]
|
96 |
+
cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_N_POINTS = 4
|
97 |
+
cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_N_HEADS = 8
|
98 |
+
|
99 |
+
def add_oneformer_config(cfg):
|
100 |
+
"""
|
101 |
+
Add config for ONE_FORMER.
|
102 |
+
"""
|
103 |
+
|
104 |
+
# mask_former model config
|
105 |
+
cfg.MODEL.ONE_FORMER = CN()
|
106 |
+
|
107 |
+
# loss
|
108 |
+
cfg.MODEL.ONE_FORMER.DEEP_SUPERVISION = True
|
109 |
+
cfg.MODEL.ONE_FORMER.NO_OBJECT_WEIGHT = 0.1
|
110 |
+
cfg.MODEL.ONE_FORMER.CLASS_WEIGHT = 1.0
|
111 |
+
cfg.MODEL.ONE_FORMER.DICE_WEIGHT = 1.0
|
112 |
+
cfg.MODEL.ONE_FORMER.MASK_WEIGHT = 20.0
|
113 |
+
cfg.MODEL.ONE_FORMER.CONTRASTIVE_WEIGHT = 0.5
|
114 |
+
cfg.MODEL.ONE_FORMER.CONTRASTIVE_TEMPERATURE = 0.07
|
115 |
+
|
116 |
+
# transformer config
|
117 |
+
cfg.MODEL.ONE_FORMER.NHEADS = 8
|
118 |
+
cfg.MODEL.ONE_FORMER.DROPOUT = 0.1
|
119 |
+
cfg.MODEL.ONE_FORMER.DIM_FEEDFORWARD = 2048
|
120 |
+
cfg.MODEL.ONE_FORMER.ENC_LAYERS = 0
|
121 |
+
cfg.MODEL.ONE_FORMER.CLASS_DEC_LAYERS = 2
|
122 |
+
cfg.MODEL.ONE_FORMER.DEC_LAYERS = 6
|
123 |
+
cfg.MODEL.ONE_FORMER.PRE_NORM = False
|
124 |
+
|
125 |
+
cfg.MODEL.ONE_FORMER.HIDDEN_DIM = 256
|
126 |
+
cfg.MODEL.ONE_FORMER.NUM_OBJECT_QUERIES = 120
|
127 |
+
cfg.MODEL.ONE_FORMER.NUM_OBJECT_CTX = 16
|
128 |
+
cfg.MODEL.ONE_FORMER.USE_TASK_NORM = True
|
129 |
+
|
130 |
+
cfg.MODEL.ONE_FORMER.TRANSFORMER_IN_FEATURE = "res5"
|
131 |
+
cfg.MODEL.ONE_FORMER.ENFORCE_INPUT_PROJ = False
|
132 |
+
|
133 |
+
# Sometimes `backbone.size_divisibility` is set to 0 for some backbone (e.g. ResNet)
|
134 |
+
# you can use this config to override
|
135 |
+
cfg.MODEL.ONE_FORMER.SIZE_DIVISIBILITY = 32
|
136 |
+
|
137 |
+
# transformer module
|
138 |
+
cfg.MODEL.ONE_FORMER.TRANSFORMER_DECODER_NAME = "ContrastiveMultiScaleMaskedTransformerDecoder"
|
139 |
+
|
140 |
+
# point loss configs
|
141 |
+
# Number of points sampled during training for a mask point head.
|
142 |
+
cfg.MODEL.ONE_FORMER.TRAIN_NUM_POINTS = 112 * 112
|
143 |
+
# Oversampling parameter for PointRend point sampling during training. Parameter `k` in the
|
144 |
+
# original paper.
|
145 |
+
cfg.MODEL.ONE_FORMER.OVERSAMPLE_RATIO = 3.0
|
146 |
+
# Importance sampling parameter for PointRend point sampling during training. Parametr `beta` in
|
147 |
+
# the original paper.
|
148 |
+
cfg.MODEL.ONE_FORMER.IMPORTANCE_SAMPLE_RATIO = 0.75
|
149 |
+
|
150 |
+
def add_swin_config(cfg):
|
151 |
+
"""
|
152 |
+
Add config forSWIN Backbone.
|
153 |
+
"""
|
154 |
+
|
155 |
+
# swin transformer backbone
|
156 |
+
cfg.MODEL.SWIN = CN()
|
157 |
+
cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE = 224
|
158 |
+
cfg.MODEL.SWIN.PATCH_SIZE = 4
|
159 |
+
cfg.MODEL.SWIN.EMBED_DIM = 96
|
160 |
+
cfg.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
|
161 |
+
cfg.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
|
162 |
+
cfg.MODEL.SWIN.WINDOW_SIZE = 7
|
163 |
+
cfg.MODEL.SWIN.MLP_RATIO = 4.0
|
164 |
+
cfg.MODEL.SWIN.QKV_BIAS = True
|
165 |
+
cfg.MODEL.SWIN.QK_SCALE = None
|
166 |
+
cfg.MODEL.SWIN.DROP_RATE = 0.0
|
167 |
+
cfg.MODEL.SWIN.ATTN_DROP_RATE = 0.0
|
168 |
+
cfg.MODEL.SWIN.DROP_PATH_RATE = 0.3
|
169 |
+
cfg.MODEL.SWIN.APE = False
|
170 |
+
cfg.MODEL.SWIN.PATCH_NORM = True
|
171 |
+
cfg.MODEL.SWIN.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
|
172 |
+
cfg.MODEL.SWIN.USE_CHECKPOINT = False
|
173 |
+
## Semask additions
|
174 |
+
cfg.MODEL.SWIN.SEM_WINDOW_SIZE = 7
|
175 |
+
cfg.MODEL.SWIN.NUM_SEM_BLOCKS = 1
|
176 |
+
|
177 |
+
def add_dinat_config(cfg):
|
178 |
+
"""
|
179 |
+
Add config for NAT Backbone.
|
180 |
+
"""
|
181 |
+
|
182 |
+
# DINAT transformer backbone
|
183 |
+
cfg.MODEL.DiNAT = CN()
|
184 |
+
cfg.MODEL.DiNAT.DEPTHS = [3, 4, 18, 5]
|
185 |
+
cfg.MODEL.DiNAT.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
|
186 |
+
cfg.MODEL.DiNAT.EMBED_DIM = 64
|
187 |
+
cfg.MODEL.DiNAT.MLP_RATIO = 3.0
|
188 |
+
cfg.MODEL.DiNAT.NUM_HEADS = [2, 4, 8, 16]
|
189 |
+
cfg.MODEL.DiNAT.DROP_PATH_RATE = 0.2
|
190 |
+
cfg.MODEL.DiNAT.KERNEL_SIZE = 7
|
191 |
+
cfg.MODEL.DiNAT.DILATIONS = [[1, 16, 1], [1, 4, 1, 8], [1, 2, 1, 3, 1, 4], [1, 2, 1, 2, 1]]
|
192 |
+
cfg.MODEL.DiNAT.OUT_INDICES = (0, 1, 2, 3)
|
193 |
+
cfg.MODEL.DiNAT.QKV_BIAS = True
|
194 |
+
cfg.MODEL.DiNAT.QK_SCALE = None
|
195 |
+
cfg.MODEL.DiNAT.DROP_RATE = 0
|
196 |
+
cfg.MODEL.DiNAT.ATTN_DROP_RATE = 0.
|
197 |
+
cfg.MODEL.DiNAT.IN_PATCH_SIZE = 4
|
198 |
+
|
199 |
+
def add_convnext_config(cfg):
|
200 |
+
"""
|
201 |
+
Add config for ConvNeXt Backbone.
|
202 |
+
"""
|
203 |
+
|
204 |
+
# swin transformer backbone
|
205 |
+
cfg.MODEL.CONVNEXT = CN()
|
206 |
+
cfg.MODEL.CONVNEXT.IN_CHANNELS = 3
|
207 |
+
cfg.MODEL.CONVNEXT.DEPTHS = [3, 3, 27, 3]
|
208 |
+
cfg.MODEL.CONVNEXT.DIMS = [192, 384, 768, 1536]
|
209 |
+
cfg.MODEL.CONVNEXT.DROP_PATH_RATE = 0.4
|
210 |
+
cfg.MODEL.CONVNEXT.LSIT = 1.0
|
211 |
+
cfg.MODEL.CONVNEXT.OUT_INDICES = [0, 1, 2, 3]
|
212 |
+
cfg.MODEL.CONVNEXT.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
|
213 |
+
|
214 |
+
def add_beit_adapter_config(cfg):
|
215 |
+
"""
|
216 |
+
Add config for BEiT Adapter Backbone.
|
217 |
+
"""
|
218 |
+
|
219 |
+
# beit adapter backbone
|
220 |
+
cfg.MODEL.BEiTAdapter = CN()
|
221 |
+
cfg.MODEL.BEiTAdapter.IMG_SIZE = 640
|
222 |
+
cfg.MODEL.BEiTAdapter.PATCH_SIZE = 16
|
223 |
+
cfg.MODEL.BEiTAdapter.EMBED_DIM = 1024
|
224 |
+
cfg.MODEL.BEiTAdapter.DEPTH = 24
|
225 |
+
cfg.MODEL.BEiTAdapter.NUM_HEADS = 16
|
226 |
+
cfg.MODEL.BEiTAdapter.MLP_RATIO = 4
|
227 |
+
cfg.MODEL.BEiTAdapter.QKV_BIAS = True
|
228 |
+
cfg.MODEL.BEiTAdapter.USE_ABS_POS_EMB = False
|
229 |
+
cfg.MODEL.BEiTAdapter.USE_REL_POS_BIAS = True
|
230 |
+
cfg.MODEL.BEiTAdapter.INIT_VALUES = 1e-6
|
231 |
+
cfg.MODEL.BEiTAdapter.DROP_PATH_RATE = 0.3
|
232 |
+
cfg.MODEL.BEiTAdapter.CONV_INPLANE = 64
|
233 |
+
cfg.MODEL.BEiTAdapter.N_POINTS = 4
|
234 |
+
cfg.MODEL.BEiTAdapter.DEFORM_NUM_HEADS = 16
|
235 |
+
cfg.MODEL.BEiTAdapter.CFFN_RATIO = 0.25
|
236 |
+
cfg.MODEL.BEiTAdapter.DEFORM_RATIO = 0.5
|
237 |
+
cfg.MODEL.BEiTAdapter.WITH_CP = True
|
238 |
+
cfg.MODEL.BEiTAdapter.INTERACTION_INDEXES=[[0, 5], [6, 11], [12, 17], [18, 23]]
|
239 |
+
cfg.MODEL.BEiTAdapter.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
|
oneformer/data/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
from . import datasets
|
oneformer/data/bpe_simple_vocab_16e6.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
oneformer/data/bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
|
3 |
+
size 1356917
|
oneformer/data/build.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
3 |
+
import torch.utils.data as torchdata
|
4 |
+
|
5 |
+
from detectron2.config import configurable
|
6 |
+
|
7 |
+
|
8 |
+
from detectron2.data.common import DatasetFromList, MapDataset
|
9 |
+
from detectron2.data.dataset_mapper import DatasetMapper
|
10 |
+
from detectron2.data.samplers import (
|
11 |
+
InferenceSampler,
|
12 |
+
)
|
13 |
+
from detectron2.data.build import (
|
14 |
+
get_detection_dataset_dicts,
|
15 |
+
trivial_batch_collator
|
16 |
+
)
|
17 |
+
"""
|
18 |
+
This file contains the default logic to build a dataloader for training or testing.
|
19 |
+
"""
|
20 |
+
|
21 |
+
__all__ = [
|
22 |
+
"build_detection_test_loader",
|
23 |
+
]
|
24 |
+
|
25 |
+
|
26 |
+
def _test_loader_from_config(cfg, dataset_name, mapper=None):
|
27 |
+
"""
|
28 |
+
Uses the given `dataset_name` argument (instead of the names in cfg), because the
|
29 |
+
standard practice is to evaluate each test set individually (not combining them).
|
30 |
+
"""
|
31 |
+
if isinstance(dataset_name, str):
|
32 |
+
dataset_name = [dataset_name]
|
33 |
+
|
34 |
+
dataset = get_detection_dataset_dicts(
|
35 |
+
dataset_name,
|
36 |
+
filter_empty=False,
|
37 |
+
proposal_files=[
|
38 |
+
cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(x)] for x in dataset_name
|
39 |
+
]
|
40 |
+
if cfg.MODEL.LOAD_PROPOSALS
|
41 |
+
else None,
|
42 |
+
)
|
43 |
+
if mapper is None:
|
44 |
+
mapper = DatasetMapper(cfg, False)
|
45 |
+
return {
|
46 |
+
"dataset": dataset,
|
47 |
+
"mapper": mapper,
|
48 |
+
"num_workers": cfg.DATALOADER.NUM_WORKERS,
|
49 |
+
"sampler": InferenceSampler(len(dataset))
|
50 |
+
if not isinstance(dataset, torchdata.IterableDataset)
|
51 |
+
else None,
|
52 |
+
}
|
53 |
+
|
54 |
+
|
55 |
+
@configurable(from_config=_test_loader_from_config)
|
56 |
+
def build_detection_test_loader(
|
57 |
+
dataset: Union[List[Any], torchdata.Dataset],
|
58 |
+
*,
|
59 |
+
mapper: Callable[[Dict[str, Any]], Any],
|
60 |
+
sampler: Optional[torchdata.Sampler] = None,
|
61 |
+
batch_size: int = 1,
|
62 |
+
num_workers: int = 0,
|
63 |
+
collate_fn: Optional[Callable[[List[Any]], Any]] = None,
|
64 |
+
) -> torchdata.DataLoader:
|
65 |
+
"""
|
66 |
+
Similar to `build_detection_train_loader`, with default batch size = 1,
|
67 |
+
and sampler = :class:`InferenceSampler`. This sampler coordinates all workers
|
68 |
+
to produce the exact set of all samples.
|
69 |
+
|
70 |
+
Args:
|
71 |
+
dataset: a list of dataset dicts,
|
72 |
+
or a pytorch dataset (either map-style or iterable). They can be obtained
|
73 |
+
by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
|
74 |
+
mapper: a callable which takes a sample (dict) from dataset
|
75 |
+
and returns the format to be consumed by the model.
|
76 |
+
When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``.
|
77 |
+
sampler: a sampler that produces
|
78 |
+
indices to be applied on ``dataset``. Default to :class:`InferenceSampler`,
|
79 |
+
which splits the dataset across all workers. Sampler must be None
|
80 |
+
if `dataset` is iterable.
|
81 |
+
batch_size: the batch size of the data loader to be created.
|
82 |
+
Default to 1 image per worker since this is the standard when reporting
|
83 |
+
inference time in papers.
|
84 |
+
num_workers: number of parallel data loading workers
|
85 |
+
collate_fn: same as the argument of `torch.utils.data.DataLoader`.
|
86 |
+
Defaults to do no collation and return a list of data.
|
87 |
+
|
88 |
+
Returns:
|
89 |
+
DataLoader: a torch DataLoader, that loads the given detection
|
90 |
+
dataset, with test-time transformation and batching.
|
91 |
+
|
92 |
+
Examples:
|
93 |
+
::
|
94 |
+
data_loader = build_detection_test_loader(
|
95 |
+
DatasetRegistry.get("my_test"),
|
96 |
+
mapper=DatasetMapper(...))
|
97 |
+
|
98 |
+
# or, instantiate with a CfgNode:
|
99 |
+
data_loader = build_detection_test_loader(cfg, "my_test")
|
100 |
+
"""
|
101 |
+
if isinstance(dataset, list):
|
102 |
+
dataset = DatasetFromList(dataset, copy=False)
|
103 |
+
if mapper is not None:
|
104 |
+
dataset = MapDataset(dataset, mapper)
|
105 |
+
if isinstance(dataset, torchdata.IterableDataset):
|
106 |
+
assert sampler is None, "sampler must be None if dataset is IterableDataset"
|
107 |
+
else:
|
108 |
+
if sampler is None:
|
109 |
+
sampler = InferenceSampler(len(dataset))
|
110 |
+
return torchdata.DataLoader(
|
111 |
+
dataset,
|
112 |
+
batch_size=batch_size,
|
113 |
+
sampler=sampler,
|
114 |
+
drop_last=False,
|
115 |
+
num_workers=num_workers,
|
116 |
+
collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
|
117 |
+
)
|
oneformer/data/dataset_mappers/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
oneformer/data/dataset_mappers/coco_unified_new_baseline_dataset_mapper.py
ADDED
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------------
|
2 |
+
# Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/data/dataset_mappers/coco_panoptic_new_baseline_dataset_mapper.py
|
3 |
+
# Modified by Jitesh Jain (https://github.com/praeclarumjj3)
|
4 |
+
# ------------------------------------------------------------------------------
|
5 |
+
|
6 |
+
import copy
|
7 |
+
import logging
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from detectron2.data import MetadataCatalog
|
13 |
+
from detectron2.config import configurable
|
14 |
+
from detectron2.data import detection_utils as utils
|
15 |
+
from detectron2.data import transforms as T
|
16 |
+
from detectron2.structures import BitMasks, Instances
|
17 |
+
from oneformer.utils.box_ops import masks_to_boxes
|
18 |
+
from oneformer.data.tokenizer import SimpleTokenizer, Tokenize
|
19 |
+
|
20 |
+
__all__ = ["COCOUnifiedNewBaselineDatasetMapper"]
|
21 |
+
|
22 |
+
|
23 |
+
def build_transform_gen(cfg, is_train):
|
24 |
+
"""
|
25 |
+
Create a list of default :class:`Augmentation` from config.
|
26 |
+
Now it includes resizing and flipping.
|
27 |
+
Returns:
|
28 |
+
list[Augmentation]
|
29 |
+
"""
|
30 |
+
assert is_train, "Only support training augmentation"
|
31 |
+
image_size = cfg.INPUT.IMAGE_SIZE
|
32 |
+
min_scale = cfg.INPUT.MIN_SCALE
|
33 |
+
max_scale = cfg.INPUT.MAX_SCALE
|
34 |
+
|
35 |
+
augmentation = []
|
36 |
+
|
37 |
+
if cfg.INPUT.RANDOM_FLIP != "none":
|
38 |
+
augmentation.append(
|
39 |
+
T.RandomFlip(
|
40 |
+
horizontal=cfg.INPUT.RANDOM_FLIP == "horizontal",
|
41 |
+
vertical=cfg.INPUT.RANDOM_FLIP == "vertical",
|
42 |
+
)
|
43 |
+
)
|
44 |
+
|
45 |
+
augmentation.extend([
|
46 |
+
T.ResizeScale(
|
47 |
+
min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size
|
48 |
+
),
|
49 |
+
T.FixedSizeCrop(crop_size=(image_size, image_size)),
|
50 |
+
])
|
51 |
+
|
52 |
+
return augmentation
|
53 |
+
|
54 |
+
|
55 |
+
# This is specifically designed for the COCO dataset.
|
56 |
+
class COCOUnifiedNewBaselineDatasetMapper:
|
57 |
+
"""
|
58 |
+
A callable which takes a dataset dict in Detectron2 Dataset format,
|
59 |
+
and map it into a format used by OneFormer.
|
60 |
+
|
61 |
+
This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.
|
62 |
+
|
63 |
+
The callable currently does the following:
|
64 |
+
|
65 |
+
1. Read the image from "file_name"
|
66 |
+
2. Applies geometric transforms to the image and annotation
|
67 |
+
3. Find and applies suitable cropping to the image and annotation
|
68 |
+
4. Prepare image and annotation to Tensors
|
69 |
+
"""
|
70 |
+
|
71 |
+
@configurable
|
72 |
+
def __init__(
|
73 |
+
self,
|
74 |
+
is_train=True,
|
75 |
+
*,
|
76 |
+
num_queries,
|
77 |
+
tfm_gens,
|
78 |
+
meta,
|
79 |
+
image_format,
|
80 |
+
max_seq_len,
|
81 |
+
task_seq_len,
|
82 |
+
semantic_prob,
|
83 |
+
instance_prob,
|
84 |
+
):
|
85 |
+
"""
|
86 |
+
NOTE: this interface is experimental.
|
87 |
+
Args:
|
88 |
+
is_train: for training or inference
|
89 |
+
augmentations: a list of augmentations or deterministic transforms to apply
|
90 |
+
crop_gen: crop augmentation
|
91 |
+
tfm_gens: data augmentation
|
92 |
+
image_format: an image format supported by :func:`detection_utils.read_image`.
|
93 |
+
"""
|
94 |
+
self.tfm_gens = tfm_gens
|
95 |
+
logging.getLogger(__name__).info(
|
96 |
+
"[COCOUnifiedNewBaselineDatasetMapper] Full TransformGens used in training: {}".format(
|
97 |
+
str(self.tfm_gens)
|
98 |
+
)
|
99 |
+
)
|
100 |
+
|
101 |
+
self.img_format = image_format
|
102 |
+
self.is_train = is_train
|
103 |
+
self.meta = meta
|
104 |
+
self.ignore_label = self.meta.ignore_label
|
105 |
+
self.num_queries = num_queries
|
106 |
+
|
107 |
+
self.things = []
|
108 |
+
for k,v in self.meta.thing_dataset_id_to_contiguous_id.items():
|
109 |
+
self.things.append(v)
|
110 |
+
self.class_names = self.meta.stuff_classes
|
111 |
+
self.text_tokenizer = Tokenize(SimpleTokenizer(), max_seq_len=max_seq_len)
|
112 |
+
self.task_tokenizer = Tokenize(SimpleTokenizer(), max_seq_len=task_seq_len)
|
113 |
+
self.semantic_prob = semantic_prob
|
114 |
+
self.instance_prob = instance_prob
|
115 |
+
|
116 |
+
@classmethod
|
117 |
+
def from_config(cls, cfg, is_train=True):
|
118 |
+
# Build augmentation
|
119 |
+
tfm_gens = build_transform_gen(cfg, is_train)
|
120 |
+
dataset_names = cfg.DATASETS.TRAIN
|
121 |
+
meta = MetadataCatalog.get(dataset_names[0])
|
122 |
+
|
123 |
+
ret = {
|
124 |
+
"is_train": is_train,
|
125 |
+
"meta": meta,
|
126 |
+
"tfm_gens": tfm_gens,
|
127 |
+
"image_format": cfg.INPUT.FORMAT,
|
128 |
+
"num_queries": cfg.MODEL.ONE_FORMER.NUM_OBJECT_QUERIES - cfg.MODEL.TEXT_ENCODER.N_CTX,
|
129 |
+
"task_seq_len": cfg.INPUT.TASK_SEQ_LEN,
|
130 |
+
"max_seq_len": cfg.INPUT.MAX_SEQ_LEN,
|
131 |
+
"semantic_prob": cfg.INPUT.TASK_PROB.SEMANTIC,
|
132 |
+
"instance_prob": cfg.INPUT.TASK_PROB.INSTANCE,
|
133 |
+
}
|
134 |
+
return ret
|
135 |
+
|
136 |
+
def _get_semantic_dict(self, pan_seg_gt, image_shape, segments_info, num_class_obj):
|
137 |
+
instances = Instances(image_shape)
|
138 |
+
|
139 |
+
classes = []
|
140 |
+
texts = ["a semantic photo"] * self.num_queries
|
141 |
+
masks = []
|
142 |
+
label = np.ones_like(pan_seg_gt) * self.ignore_label
|
143 |
+
|
144 |
+
for segment_info in segments_info:
|
145 |
+
class_id = segment_info["category_id"]
|
146 |
+
if not segment_info["iscrowd"]:
|
147 |
+
mask = pan_seg_gt == segment_info["id"]
|
148 |
+
if not np.all(mask == False):
|
149 |
+
if class_id not in classes:
|
150 |
+
cls_name = self.class_names[class_id]
|
151 |
+
classes.append(class_id)
|
152 |
+
masks.append(mask)
|
153 |
+
num_class_obj[cls_name] += 1
|
154 |
+
else:
|
155 |
+
idx = classes.index(class_id)
|
156 |
+
masks[idx] += mask
|
157 |
+
masks[idx] = np.clip(masks[idx], 0, 1).astype(np.bool)
|
158 |
+
label[mask] = class_id
|
159 |
+
|
160 |
+
num = 0
|
161 |
+
for i, cls_name in enumerate(self.class_names):
|
162 |
+
if num_class_obj[cls_name] > 0:
|
163 |
+
for _ in range(num_class_obj[cls_name]):
|
164 |
+
if num >= len(texts):
|
165 |
+
break
|
166 |
+
texts[num] = f"a photo with a {cls_name}"
|
167 |
+
num += 1
|
168 |
+
|
169 |
+
classes = np.array(classes)
|
170 |
+
instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
|
171 |
+
if len(masks) == 0:
|
172 |
+
# Some image does not have annotation (all ignored)
|
173 |
+
instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
|
174 |
+
instances.gt_bboxes = torch.zeros((0, 4))
|
175 |
+
else:
|
176 |
+
masks = BitMasks(
|
177 |
+
torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
|
178 |
+
)
|
179 |
+
instances.gt_masks = masks.tensor
|
180 |
+
# Placeholder bounding boxes for stuff regions. Note that these are not used during training.
|
181 |
+
instances.gt_bboxes = torch.stack([torch.tensor([0., 0., 1., 1.])] * instances.gt_masks.shape[0])
|
182 |
+
return instances, texts, label
|
183 |
+
|
184 |
+
def _get_instance_dict(self, pan_seg_gt, image_shape, segments_info, num_class_obj):
|
185 |
+
instances = Instances(image_shape)
|
186 |
+
|
187 |
+
classes = []
|
188 |
+
texts = ["an instance photo"] * self.num_queries
|
189 |
+
masks = []
|
190 |
+
label = np.ones_like(pan_seg_gt) * self.ignore_label
|
191 |
+
|
192 |
+
for segment_info in segments_info:
|
193 |
+
class_id = segment_info["category_id"]
|
194 |
+
if class_id in self.things:
|
195 |
+
if not segment_info["iscrowd"]:
|
196 |
+
mask = pan_seg_gt == segment_info["id"]
|
197 |
+
if not np.all(mask == False):
|
198 |
+
cls_name = self.class_names[class_id]
|
199 |
+
classes.append(class_id)
|
200 |
+
masks.append(mask)
|
201 |
+
num_class_obj[cls_name] += 1
|
202 |
+
label[mask] = class_id
|
203 |
+
|
204 |
+
num = 0
|
205 |
+
for i, cls_name in enumerate(self.class_names):
|
206 |
+
if num_class_obj[cls_name] > 0:
|
207 |
+
for _ in range(num_class_obj[cls_name]):
|
208 |
+
if num >= len(texts):
|
209 |
+
break
|
210 |
+
texts[num] = f"a photo with a {cls_name}"
|
211 |
+
num += 1
|
212 |
+
|
213 |
+
classes = np.array(classes)
|
214 |
+
instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
|
215 |
+
if len(masks) == 0:
|
216 |
+
# Some image does not have annotation (all ignored)
|
217 |
+
instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
|
218 |
+
instances.gt_bboxes = torch.zeros((0, 4))
|
219 |
+
else:
|
220 |
+
masks = BitMasks(
|
221 |
+
torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
|
222 |
+
)
|
223 |
+
instances.gt_masks = masks.tensor
|
224 |
+
instances.gt_bboxes = masks_to_boxes(instances.gt_masks)
|
225 |
+
return instances, texts, label
|
226 |
+
|
227 |
+
def _get_panoptic_dict(self, pan_seg_gt, image_shape, segments_info, num_class_obj):
|
228 |
+
instances = Instances(image_shape)
|
229 |
+
|
230 |
+
classes = []
|
231 |
+
texts = ["a panoptic photo"] * self.num_queries
|
232 |
+
masks = []
|
233 |
+
label = np.ones_like(pan_seg_gt) * self.ignore_label
|
234 |
+
|
235 |
+
for segment_info in segments_info:
|
236 |
+
class_id = segment_info["category_id"]
|
237 |
+
if not segment_info["iscrowd"]:
|
238 |
+
mask = pan_seg_gt == segment_info["id"]
|
239 |
+
if not np.all(mask == False):
|
240 |
+
cls_name = self.class_names[class_id]
|
241 |
+
classes.append(class_id)
|
242 |
+
masks.append(mask)
|
243 |
+
num_class_obj[cls_name] += 1
|
244 |
+
label[mask] = class_id
|
245 |
+
|
246 |
+
num = 0
|
247 |
+
for i, cls_name in enumerate(self.class_names):
|
248 |
+
if num_class_obj[cls_name] > 0:
|
249 |
+
for _ in range(num_class_obj[cls_name]):
|
250 |
+
if num >= len(texts):
|
251 |
+
break
|
252 |
+
texts[num] = f"a photo with a {cls_name}"
|
253 |
+
num += 1
|
254 |
+
|
255 |
+
classes = np.array(classes)
|
256 |
+
instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
|
257 |
+
if len(masks) == 0:
|
258 |
+
# Some image does not have annotation (all ignored)
|
259 |
+
instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
|
260 |
+
instances.gt_bboxes = torch.zeros((0, 4))
|
261 |
+
else:
|
262 |
+
masks = BitMasks(
|
263 |
+
torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
|
264 |
+
)
|
265 |
+
instances.gt_masks = masks.tensor
|
266 |
+
instances.gt_bboxes = masks_to_boxes(instances.gt_masks)
|
267 |
+
for i in range(instances.gt_classes.shape[0]):
|
268 |
+
# Placeholder bounding boxes for stuff regions. Note that these are not used during training.
|
269 |
+
if instances.gt_classes[i].item() not in self.things:
|
270 |
+
instances.gt_bboxes[i] = torch.tensor([0., 0., 1., 1.])
|
271 |
+
return instances, texts, label
|
272 |
+
|
273 |
+
def __call__(self, dataset_dict):
|
274 |
+
"""
|
275 |
+
Args:
|
276 |
+
dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
|
277 |
+
|
278 |
+
Returns:
|
279 |
+
dict: a format that builtin models in detectron2 accept
|
280 |
+
"""
|
281 |
+
dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
|
282 |
+
image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
|
283 |
+
utils.check_image_size(dataset_dict, image)
|
284 |
+
|
285 |
+
image, transforms = T.apply_transform_gens(self.tfm_gens, image)
|
286 |
+
image_shape = image.shape[:2] # h, w
|
287 |
+
|
288 |
+
# Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
|
289 |
+
# but not efficient on large generic data structures due to the use of pickle & mp.Queue.
|
290 |
+
# Therefore it's important to use torch.Tensor.
|
291 |
+
dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
|
292 |
+
|
293 |
+
if not self.is_train:
|
294 |
+
# USER: Modify this if you want to keep them for some reason.
|
295 |
+
dataset_dict.pop("annotations", None)
|
296 |
+
return dataset_dict
|
297 |
+
|
298 |
+
# semantic segmentation
|
299 |
+
if "sem_seg_file_name" in dataset_dict:
|
300 |
+
# PyTorch transformation not implemented for uint16, so converting it to double first
|
301 |
+
sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name")).astype("double")
|
302 |
+
sem_seg_gt = transforms.apply_segmentation(sem_seg_gt)
|
303 |
+
else:
|
304 |
+
sem_seg_gt = None
|
305 |
+
|
306 |
+
if "pan_seg_file_name" in dataset_dict:
|
307 |
+
pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB")
|
308 |
+
segments_info = dataset_dict["segments_info"]
|
309 |
+
|
310 |
+
# apply the same transformation to panoptic segmentation
|
311 |
+
pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)
|
312 |
+
|
313 |
+
from panopticapi.utils import rgb2id
|
314 |
+
pan_seg_gt = rgb2id(pan_seg_gt)
|
315 |
+
|
316 |
+
prob_task = np.random.uniform(0,1.)
|
317 |
+
|
318 |
+
num_class_obj = {}
|
319 |
+
|
320 |
+
for name in self.class_names:
|
321 |
+
num_class_obj[name] = 0
|
322 |
+
|
323 |
+
if prob_task < self.semantic_prob:
|
324 |
+
task = "The task is semantic"
|
325 |
+
instances, text, sem_seg = self._get_semantic_dict(pan_seg_gt, image_shape, segments_info, num_class_obj)
|
326 |
+
elif prob_task < self.instance_prob:
|
327 |
+
task = "The task is instance"
|
328 |
+
instances, text, sem_seg = self._get_instance_dict(pan_seg_gt, image_shape, segments_info, num_class_obj)
|
329 |
+
else:
|
330 |
+
task = "The task is panoptic"
|
331 |
+
instances, text, sem_seg = self._get_panoptic_dict(pan_seg_gt, image_shape, segments_info, num_class_obj)
|
332 |
+
|
333 |
+
|
334 |
+
dataset_dict["sem_seg"] = torch.from_numpy(sem_seg).long()
|
335 |
+
dataset_dict["instances"] = instances
|
336 |
+
dataset_dict["orig_shape"] = image_shape
|
337 |
+
dataset_dict["task"] = task
|
338 |
+
dataset_dict["text"] = text
|
339 |
+
dataset_dict["thing_ids"] = self.things
|
340 |
+
|
341 |
+
return dataset_dict
|
oneformer/data/dataset_mappers/dataset_mapper.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------------
|
2 |
+
# Reference: https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/dataset_mapper.py
|
3 |
+
# Modified by Jitesh Jain (https://github.com/praeclarumjj3)
|
4 |
+
# ------------------------------------------------------------------------------
|
5 |
+
|
6 |
+
import copy
|
7 |
+
import logging
|
8 |
+
import numpy as np
|
9 |
+
from typing import List, Optional, Union
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from detectron2.config import configurable
|
13 |
+
|
14 |
+
from detectron2.data import detection_utils as utils
|
15 |
+
from detectron2.data import transforms as T
|
16 |
+
from oneformer.data.tokenizer import SimpleTokenizer, Tokenize
|
17 |
+
|
18 |
+
__all__ = ["DatasetMapper"]
|
19 |
+
|
20 |
+
|
21 |
+
class DatasetMapper:
|
22 |
+
"""
|
23 |
+
A callable which takes a dataset dict in Detectron2 Dataset format,
|
24 |
+
and map it into a format used by the model.
|
25 |
+
|
26 |
+
This is the default callable to be used to map your dataset dict into training data.
|
27 |
+
You may need to follow it to implement your own one for customized logic,
|
28 |
+
such as a different way to read or transform images.
|
29 |
+
See :doc:`/tutorials/data_loading` for details.
|
30 |
+
|
31 |
+
The callable currently does the following:
|
32 |
+
|
33 |
+
1. Read the image from "file_name"
|
34 |
+
2. Applies cropping/geometric transforms to the image and annotations
|
35 |
+
3. Prepare data and annotations to Tensor and :class:`Instances`
|
36 |
+
"""
|
37 |
+
|
38 |
+
@configurable
|
39 |
+
def __init__(
|
40 |
+
self,
|
41 |
+
is_train: bool,
|
42 |
+
*,
|
43 |
+
augmentations: List[Union[T.Augmentation, T.Transform]],
|
44 |
+
image_format: str,
|
45 |
+
task_seq_len: int,
|
46 |
+
task: str = "panoptic",
|
47 |
+
use_instance_mask: bool = False,
|
48 |
+
use_keypoint: bool = False,
|
49 |
+
instance_mask_format: str = "polygon",
|
50 |
+
keypoint_hflip_indices: Optional[np.ndarray] = None,
|
51 |
+
precomputed_proposal_topk: Optional[int] = None,
|
52 |
+
recompute_boxes: bool = False,
|
53 |
+
):
|
54 |
+
"""
|
55 |
+
NOTE: this interface is experimental.
|
56 |
+
|
57 |
+
Args:
|
58 |
+
is_train: whether it's used in training or inference
|
59 |
+
augmentations: a list of augmentations or deterministic transforms to apply
|
60 |
+
image_format: an image format supported by :func:`detection_utils.read_image`.
|
61 |
+
use_instance_mask: whether to process instance segmentation annotations, if available
|
62 |
+
use_keypoint: whether to process keypoint annotations if available
|
63 |
+
instance_mask_format: one of "polygon" or "bitmask". Process instance segmentation
|
64 |
+
masks into this format.
|
65 |
+
keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices`
|
66 |
+
precomputed_proposal_topk: if given, will load pre-computed
|
67 |
+
proposals from dataset_dict and keep the top k proposals for each image.
|
68 |
+
recompute_boxes: whether to overwrite bounding box annotations
|
69 |
+
by computing tight bounding boxes from instance mask annotations.
|
70 |
+
"""
|
71 |
+
if recompute_boxes:
|
72 |
+
assert use_instance_mask, "recompute_boxes requires instance masks"
|
73 |
+
# fmt: off
|
74 |
+
self.is_train = is_train
|
75 |
+
self.augmentations = T.AugmentationList(augmentations)
|
76 |
+
self.image_format = image_format
|
77 |
+
self.use_instance_mask = use_instance_mask
|
78 |
+
self.instance_mask_format = instance_mask_format
|
79 |
+
self.use_keypoint = use_keypoint
|
80 |
+
self.keypoint_hflip_indices = keypoint_hflip_indices
|
81 |
+
self.proposal_topk = precomputed_proposal_topk
|
82 |
+
self.recompute_boxes = recompute_boxes
|
83 |
+
self.task_tokenizer = Tokenize(SimpleTokenizer(), max_seq_len=task_seq_len)
|
84 |
+
self.task = task
|
85 |
+
assert self.task in ["panoptic", "semantic", "instance"]
|
86 |
+
|
87 |
+
# fmt: on
|
88 |
+
logger = logging.getLogger(__name__)
|
89 |
+
mode = "training" if is_train else "inference"
|
90 |
+
logger.info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}")
|
91 |
+
|
92 |
+
@classmethod
|
93 |
+
def from_config(cls, cfg, is_train: bool = True):
|
94 |
+
augs = utils.build_augmentation(cfg, is_train)
|
95 |
+
if cfg.INPUT.CROP.ENABLED and is_train:
|
96 |
+
augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE))
|
97 |
+
recompute_boxes = cfg.MODEL.MASK_ON
|
98 |
+
else:
|
99 |
+
recompute_boxes = False
|
100 |
+
|
101 |
+
ret = {
|
102 |
+
"is_train": is_train,
|
103 |
+
"augmentations": augs,
|
104 |
+
"image_format": cfg.INPUT.FORMAT,
|
105 |
+
"use_instance_mask": cfg.MODEL.MASK_ON,
|
106 |
+
"instance_mask_format": cfg.INPUT.MASK_FORMAT,
|
107 |
+
"use_keypoint": cfg.MODEL.KEYPOINT_ON,
|
108 |
+
"task_seq_len": cfg.INPUT.TASK_SEQ_LEN,
|
109 |
+
"recompute_boxes": recompute_boxes,
|
110 |
+
"task": cfg.MODEL.TEST.TASK,
|
111 |
+
}
|
112 |
+
|
113 |
+
if cfg.MODEL.KEYPOINT_ON:
|
114 |
+
ret["keypoint_hflip_indices"] = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)
|
115 |
+
|
116 |
+
if cfg.MODEL.LOAD_PROPOSALS:
|
117 |
+
ret["precomputed_proposal_topk"] = (
|
118 |
+
cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN
|
119 |
+
if is_train
|
120 |
+
else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST
|
121 |
+
)
|
122 |
+
return ret
|
123 |
+
|
124 |
+
def _transform_annotations(self, dataset_dict, transforms, image_shape):
|
125 |
+
# USER: Modify this if you want to keep them for some reason.
|
126 |
+
for anno in dataset_dict["annotations"]:
|
127 |
+
if not self.use_instance_mask:
|
128 |
+
anno.pop("segmentation", None)
|
129 |
+
if not self.use_keypoint:
|
130 |
+
anno.pop("keypoints", None)
|
131 |
+
|
132 |
+
# USER: Implement additional transformations if you have other types of data
|
133 |
+
annos = [
|
134 |
+
utils.transform_instance_annotations(
|
135 |
+
obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
|
136 |
+
)
|
137 |
+
for obj in dataset_dict.pop("annotations")
|
138 |
+
if obj.get("iscrowd", 0) == 0
|
139 |
+
]
|
140 |
+
instances = utils.annotations_to_instances(
|
141 |
+
annos, image_shape, mask_format=self.instance_mask_format
|
142 |
+
)
|
143 |
+
|
144 |
+
# After transforms such as cropping are applied, the bounding box may no longer
|
145 |
+
# tightly bound the object. As an example, imagine a triangle object
|
146 |
+
# [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight
|
147 |
+
# bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to
|
148 |
+
# the intersection of original bounding box and the cropping box.
|
149 |
+
if self.recompute_boxes:
|
150 |
+
instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
|
151 |
+
dataset_dict["instances"] = utils.filter_empty_instances(instances)
|
152 |
+
|
153 |
+
def __call__(self, dataset_dict):
|
154 |
+
"""
|
155 |
+
Args:
|
156 |
+
dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
|
157 |
+
|
158 |
+
Returns:
|
159 |
+
dict: a format that builtin models in detectron2 accept
|
160 |
+
"""
|
161 |
+
dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
|
162 |
+
# USER: Write your own image loading if it's not from a file
|
163 |
+
image = utils.read_image(dataset_dict["file_name"], format=self.image_format)
|
164 |
+
utils.check_image_size(dataset_dict, image)
|
165 |
+
|
166 |
+
task = f"The task is {self.task}"
|
167 |
+
dataset_dict["task"] = task
|
168 |
+
|
169 |
+
# USER: Remove if you don't do semantic/panoptic segmentation.
|
170 |
+
if "sem_seg_file_name" in dataset_dict:
|
171 |
+
sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2)
|
172 |
+
else:
|
173 |
+
sem_seg_gt = None
|
174 |
+
|
175 |
+
aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
|
176 |
+
transforms = self.augmentations(aug_input)
|
177 |
+
image, sem_seg_gt = aug_input.image, aug_input.sem_seg
|
178 |
+
|
179 |
+
image_shape = image.shape[:2] # h, w
|
180 |
+
# Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
|
181 |
+
# but not efficient on large generic data structures due to the use of pickle & mp.Queue.
|
182 |
+
# Therefore it's important to use torch.Tensor.
|
183 |
+
dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
|
184 |
+
if sem_seg_gt is not None:
|
185 |
+
dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))
|
186 |
+
|
187 |
+
# USER: Remove if you don't use pre-computed proposals.
|
188 |
+
# Most users would not need this feature.
|
189 |
+
if self.proposal_topk is not None:
|
190 |
+
utils.transform_proposals(
|
191 |
+
dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk
|
192 |
+
)
|
193 |
+
|
194 |
+
if not self.is_train:
|
195 |
+
# USER: Modify this if you want to keep them for some reason.
|
196 |
+
dataset_dict.pop("annotations", None)
|
197 |
+
dataset_dict.pop("sem_seg_file_name", None)
|
198 |
+
return dataset_dict
|
199 |
+
|
200 |
+
if "annotations" in dataset_dict:
|
201 |
+
self._transform_annotations(dataset_dict, transforms, image_shape)
|
202 |
+
|
203 |
+
return dataset_dict
|
oneformer/data/dataset_mappers/oneformer_unified_dataset_mapper.py
ADDED
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------------
|
2 |
+
# Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/data/dataset_mappers/mask_former_panoptic_dataset_mapper.py
|
3 |
+
# Modified by Jitesh Jain (https://github.com/praeclarumjj3)
|
4 |
+
# ------------------------------------------------------------------------------
|
5 |
+
|
6 |
+
import copy
|
7 |
+
import logging
|
8 |
+
import os
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
from torch.nn import functional as F
|
13 |
+
|
14 |
+
from detectron2.config import configurable
|
15 |
+
from detectron2.data import detection_utils as utils
|
16 |
+
from detectron2.data import transforms as T
|
17 |
+
from detectron2.structures import BitMasks, Instances
|
18 |
+
from detectron2.data import MetadataCatalog
|
19 |
+
from detectron2.projects.point_rend import ColorAugSSDTransform
|
20 |
+
from oneformer.utils.box_ops import masks_to_boxes
|
21 |
+
from oneformer.data.tokenizer import SimpleTokenizer, Tokenize
|
22 |
+
|
23 |
+
__all__ = ["OneFormerUnifiedDatasetMapper"]
|
24 |
+
|
25 |
+
|
26 |
+
class OneFormerUnifiedDatasetMapper:
|
27 |
+
"""
|
28 |
+
A callable which takes a dataset dict in Detectron2 Dataset format,
|
29 |
+
and map it into a format used by OneFormer for universal segmentation.
|
30 |
+
|
31 |
+
The callable currently does the following:
|
32 |
+
|
33 |
+
1. Read the image from "file_name"
|
34 |
+
2. Applies geometric transforms to the image and annotation
|
35 |
+
3. Find and applies suitable cropping to the image and annotation
|
36 |
+
4. Prepare image and annotation to Tensors
|
37 |
+
"""
|
38 |
+
|
39 |
+
@configurable
|
40 |
+
def __init__(
|
41 |
+
self,
|
42 |
+
is_train=True,
|
43 |
+
*,
|
44 |
+
name,
|
45 |
+
num_queries,
|
46 |
+
meta,
|
47 |
+
augmentations,
|
48 |
+
image_format,
|
49 |
+
ignore_label,
|
50 |
+
size_divisibility,
|
51 |
+
task_seq_len,
|
52 |
+
max_seq_len,
|
53 |
+
semantic_prob,
|
54 |
+
instance_prob,
|
55 |
+
):
|
56 |
+
"""
|
57 |
+
NOTE: this interface is experimental.
|
58 |
+
Args:
|
59 |
+
is_train: for training or inference
|
60 |
+
augmentations: a list of augmentations or deterministic transforms to apply
|
61 |
+
image_format: an image format supported by :func:`detection_utils.read_image`.
|
62 |
+
ignore_label: the label that is ignored to evaluation
|
63 |
+
size_divisibility: pad image size to be divisible by this value
|
64 |
+
"""
|
65 |
+
self.is_train = is_train
|
66 |
+
self.meta = meta
|
67 |
+
self.name = name
|
68 |
+
self.tfm_gens = augmentations
|
69 |
+
self.img_format = image_format
|
70 |
+
self.ignore_label = ignore_label
|
71 |
+
self.size_divisibility = size_divisibility
|
72 |
+
self.num_queries = num_queries
|
73 |
+
|
74 |
+
logger = logging.getLogger(__name__)
|
75 |
+
mode = "training" if is_train else "inference"
|
76 |
+
logger.info(f"[{self.__class__.__name__}] Augmentations used in {mode}: {augmentations}")
|
77 |
+
|
78 |
+
self.things = []
|
79 |
+
for k,v in self.meta.thing_dataset_id_to_contiguous_id.items():
|
80 |
+
self.things.append(v)
|
81 |
+
self.class_names = self.meta.stuff_classes
|
82 |
+
self.text_tokenizer = Tokenize(SimpleTokenizer(), max_seq_len=max_seq_len)
|
83 |
+
self.task_tokenizer = Tokenize(SimpleTokenizer(), max_seq_len=task_seq_len)
|
84 |
+
self.semantic_prob = semantic_prob
|
85 |
+
self.instance_prob = instance_prob
|
86 |
+
|
87 |
+
@classmethod
|
88 |
+
def from_config(cls, cfg, is_train=True):
|
89 |
+
# Build augmentation
|
90 |
+
augs = [
|
91 |
+
T.ResizeShortestEdge(
|
92 |
+
cfg.INPUT.MIN_SIZE_TRAIN,
|
93 |
+
cfg.INPUT.MAX_SIZE_TRAIN,
|
94 |
+
cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING,
|
95 |
+
)
|
96 |
+
]
|
97 |
+
if cfg.INPUT.CROP.ENABLED:
|
98 |
+
augs.append(
|
99 |
+
T.RandomCrop_CategoryAreaConstraint(
|
100 |
+
cfg.INPUT.CROP.TYPE,
|
101 |
+
cfg.INPUT.CROP.SIZE,
|
102 |
+
cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA,
|
103 |
+
cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
|
104 |
+
)
|
105 |
+
)
|
106 |
+
if cfg.INPUT.COLOR_AUG_SSD:
|
107 |
+
augs.append(ColorAugSSDTransform(img_format=cfg.INPUT.FORMAT))
|
108 |
+
augs.append(T.RandomFlip())
|
109 |
+
|
110 |
+
# Assume always applies to the training set.
|
111 |
+
dataset_names = cfg.DATASETS.TRAIN
|
112 |
+
meta = MetadataCatalog.get(dataset_names[0])
|
113 |
+
ignore_label = meta.ignore_label
|
114 |
+
|
115 |
+
ret = {
|
116 |
+
"is_train": is_train,
|
117 |
+
"meta": meta,
|
118 |
+
"name": dataset_names[0],
|
119 |
+
"num_queries": cfg.MODEL.ONE_FORMER.NUM_OBJECT_QUERIES - cfg.MODEL.TEXT_ENCODER.N_CTX,
|
120 |
+
"task_seq_len": cfg.INPUT.TASK_SEQ_LEN,
|
121 |
+
"max_seq_len": cfg.INPUT.MAX_SEQ_LEN,
|
122 |
+
"augmentations": augs,
|
123 |
+
"image_format": cfg.INPUT.FORMAT,
|
124 |
+
"ignore_label": ignore_label,
|
125 |
+
"size_divisibility": cfg.INPUT.SIZE_DIVISIBILITY,
|
126 |
+
"semantic_prob": cfg.INPUT.TASK_PROB.SEMANTIC,
|
127 |
+
"instance_prob": cfg.INPUT.TASK_PROB.INSTANCE,
|
128 |
+
}
|
129 |
+
return ret
|
130 |
+
|
131 |
+
def _get_semantic_dict(self, pan_seg_gt, image_shape, segments_info, num_class_obj):
|
132 |
+
pan_seg_gt = pan_seg_gt.numpy()
|
133 |
+
instances = Instances(image_shape)
|
134 |
+
|
135 |
+
classes = []
|
136 |
+
texts = ["a semantic photo"] * self.num_queries
|
137 |
+
masks = []
|
138 |
+
label = np.ones_like(pan_seg_gt) * self.ignore_label
|
139 |
+
|
140 |
+
for segment_info in segments_info:
|
141 |
+
class_id = segment_info["category_id"]
|
142 |
+
if not segment_info["iscrowd"]:
|
143 |
+
mask = pan_seg_gt == segment_info["id"]
|
144 |
+
if not np.all(mask == False):
|
145 |
+
if class_id not in classes:
|
146 |
+
cls_name = self.class_names[class_id]
|
147 |
+
classes.append(class_id)
|
148 |
+
masks.append(mask)
|
149 |
+
num_class_obj[cls_name] += 1
|
150 |
+
else:
|
151 |
+
idx = classes.index(class_id)
|
152 |
+
masks[idx] += mask
|
153 |
+
masks[idx] = np.clip(masks[idx], 0, 1).astype(np.bool)
|
154 |
+
label[mask] = class_id
|
155 |
+
|
156 |
+
num = 0
|
157 |
+
for i, cls_name in enumerate(self.class_names):
|
158 |
+
if num_class_obj[cls_name] > 0:
|
159 |
+
for _ in range(num_class_obj[cls_name]):
|
160 |
+
if num >= len(texts):
|
161 |
+
break
|
162 |
+
texts[num] = f"a photo with a {cls_name}"
|
163 |
+
num += 1
|
164 |
+
|
165 |
+
classes = np.array(classes)
|
166 |
+
instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
|
167 |
+
if len(masks) == 0:
|
168 |
+
# Some image does not have annotation (all ignored)
|
169 |
+
instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
|
170 |
+
instances.gt_bboxes = torch.zeros((0, 4))
|
171 |
+
else:
|
172 |
+
masks = BitMasks(
|
173 |
+
torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
|
174 |
+
)
|
175 |
+
instances.gt_masks = masks.tensor
|
176 |
+
# Placeholder bounding boxes for stuff regions. Note that these are not used during training.
|
177 |
+
instances.gt_bboxes = torch.stack([torch.tensor([0., 0., 1., 1.])] * instances.gt_masks.shape[0])
|
178 |
+
return instances, texts, label
|
179 |
+
|
180 |
+
def _get_instance_dict(self, pan_seg_gt, image_shape, segments_info, num_class_obj):
|
181 |
+
pan_seg_gt = pan_seg_gt.numpy()
|
182 |
+
instances = Instances(image_shape)
|
183 |
+
|
184 |
+
classes = []
|
185 |
+
texts = ["an instance photo"] * self.num_queries
|
186 |
+
masks = []
|
187 |
+
label = np.ones_like(pan_seg_gt) * self.ignore_label
|
188 |
+
|
189 |
+
for segment_info in segments_info:
|
190 |
+
class_id = segment_info["category_id"]
|
191 |
+
if class_id in self.things:
|
192 |
+
if not segment_info["iscrowd"]:
|
193 |
+
mask = pan_seg_gt == segment_info["id"]
|
194 |
+
if not np.all(mask == False):
|
195 |
+
cls_name = self.class_names[class_id]
|
196 |
+
classes.append(class_id)
|
197 |
+
masks.append(mask)
|
198 |
+
num_class_obj[cls_name] += 1
|
199 |
+
label[mask] = class_id
|
200 |
+
|
201 |
+
num = 0
|
202 |
+
for i, cls_name in enumerate(self.class_names):
|
203 |
+
if num_class_obj[cls_name] > 0:
|
204 |
+
for _ in range(num_class_obj[cls_name]):
|
205 |
+
if num >= len(texts):
|
206 |
+
break
|
207 |
+
texts[num] = f"a photo with a {cls_name}"
|
208 |
+
num += 1
|
209 |
+
|
210 |
+
classes = np.array(classes)
|
211 |
+
instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
|
212 |
+
if len(masks) == 0:
|
213 |
+
# Some image does not have annotation (all ignored)
|
214 |
+
instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
|
215 |
+
instances.gt_bboxes = torch.zeros((0, 4))
|
216 |
+
else:
|
217 |
+
masks = BitMasks(
|
218 |
+
torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
|
219 |
+
)
|
220 |
+
instances.gt_masks = masks.tensor
|
221 |
+
instances.gt_bboxes = masks_to_boxes(instances.gt_masks)
|
222 |
+
return instances, texts, label
|
223 |
+
|
224 |
+
def _get_panoptic_dict(self, pan_seg_gt, image_shape, segments_info, num_class_obj):
|
225 |
+
pan_seg_gt = pan_seg_gt.numpy()
|
226 |
+
instances = Instances(image_shape)
|
227 |
+
|
228 |
+
classes = []
|
229 |
+
texts = ["a panoptic photo"] * self.num_queries
|
230 |
+
masks = []
|
231 |
+
label = np.ones_like(pan_seg_gt) * self.ignore_label
|
232 |
+
|
233 |
+
for segment_info in segments_info:
|
234 |
+
class_id = segment_info["category_id"]
|
235 |
+
if not segment_info["iscrowd"]:
|
236 |
+
mask = pan_seg_gt == segment_info["id"]
|
237 |
+
if not np.all(mask == False):
|
238 |
+
cls_name = self.class_names[class_id]
|
239 |
+
classes.append(class_id)
|
240 |
+
masks.append(mask)
|
241 |
+
num_class_obj[cls_name] += 1
|
242 |
+
label[mask] = class_id
|
243 |
+
|
244 |
+
num = 0
|
245 |
+
for i, cls_name in enumerate(self.class_names):
|
246 |
+
if num_class_obj[cls_name] > 0:
|
247 |
+
for _ in range(num_class_obj[cls_name]):
|
248 |
+
if num >= len(texts):
|
249 |
+
break
|
250 |
+
texts[num] = f"a photo with a {cls_name}"
|
251 |
+
num += 1
|
252 |
+
|
253 |
+
classes = np.array(classes)
|
254 |
+
instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
|
255 |
+
if len(masks) == 0:
|
256 |
+
# Some image does not have annotation (all ignored)
|
257 |
+
instances.gt_masks = torch.zeros((0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]))
|
258 |
+
instances.gt_bboxes = torch.zeros((0, 4))
|
259 |
+
else:
|
260 |
+
masks = BitMasks(
|
261 |
+
torch.stack([torch.from_numpy(np.ascontiguousarray(x.copy())) for x in masks])
|
262 |
+
)
|
263 |
+
instances.gt_masks = masks.tensor
|
264 |
+
instances.gt_bboxes = masks_to_boxes(instances.gt_masks)
|
265 |
+
for i in range(instances.gt_classes.shape[0]):
|
266 |
+
# Placeholder bounding boxes for stuff regions. Note that these are not used during training.
|
267 |
+
if instances.gt_classes[i].item() not in self.things:
|
268 |
+
instances.gt_bboxes[i] = torch.tensor([0., 0., 1., 1.])
|
269 |
+
return instances, texts, label
|
270 |
+
|
271 |
+
def __call__(self, dataset_dict):
|
272 |
+
"""
|
273 |
+
Args:
|
274 |
+
dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
|
275 |
+
|
276 |
+
Returns:
|
277 |
+
dict: a format that builtin models in detectron2 accept
|
278 |
+
"""
|
279 |
+
assert self.is_train, "OneFormerUnifiedDatasetMapper should only be used for training!"
|
280 |
+
|
281 |
+
dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
|
282 |
+
image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
|
283 |
+
utils.check_image_size(dataset_dict, image)
|
284 |
+
|
285 |
+
# semantic segmentation
|
286 |
+
if "sem_seg_file_name" in dataset_dict:
|
287 |
+
# PyTorch transformation not implemented for uint16, so converting it to double first
|
288 |
+
sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name")).astype("double")
|
289 |
+
else:
|
290 |
+
sem_seg_gt = None
|
291 |
+
|
292 |
+
# panoptic segmentation
|
293 |
+
if "pan_seg_file_name" in dataset_dict:
|
294 |
+
pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB")
|
295 |
+
segments_info = dataset_dict["segments_info"]
|
296 |
+
else:
|
297 |
+
pan_seg_gt = None
|
298 |
+
segments_info = None
|
299 |
+
|
300 |
+
if pan_seg_gt is None:
|
301 |
+
raise ValueError(
|
302 |
+
"Cannot find 'pan_seg_file_name' for panoptic segmentation dataset {}.".format(
|
303 |
+
dataset_dict["file_name"]
|
304 |
+
)
|
305 |
+
)
|
306 |
+
|
307 |
+
aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
|
308 |
+
aug_input, transforms = T.apply_transform_gens(self.tfm_gens, aug_input)
|
309 |
+
image = aug_input.image
|
310 |
+
if sem_seg_gt is not None:
|
311 |
+
sem_seg_gt = aug_input.sem_seg
|
312 |
+
|
313 |
+
# apply the same transformation to panoptic segmentation
|
314 |
+
pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)
|
315 |
+
|
316 |
+
from panopticapi.utils import rgb2id
|
317 |
+
|
318 |
+
pan_seg_gt = rgb2id(pan_seg_gt)
|
319 |
+
|
320 |
+
# Pad image and segmentation label here!
|
321 |
+
image = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
|
322 |
+
if sem_seg_gt is not None:
|
323 |
+
sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
|
324 |
+
pan_seg_gt = torch.as_tensor(pan_seg_gt.astype("long"))
|
325 |
+
|
326 |
+
if self.size_divisibility > 0:
|
327 |
+
image_size = (image.shape[-2], image.shape[-1])
|
328 |
+
padding_size = [
|
329 |
+
0,
|
330 |
+
self.size_divisibility - image_size[1],
|
331 |
+
0,
|
332 |
+
self.size_divisibility - image_size[0],
|
333 |
+
]
|
334 |
+
image = F.pad(image, padding_size, value=128).contiguous()
|
335 |
+
if sem_seg_gt is not None:
|
336 |
+
sem_seg_gt = F.pad(sem_seg_gt, padding_size, value=self.ignore_label).contiguous()
|
337 |
+
pan_seg_gt = F.pad(
|
338 |
+
pan_seg_gt, padding_size, value=0
|
339 |
+
).contiguous() # 0 is the VOID panoptic label
|
340 |
+
|
341 |
+
image_shape = (image.shape[-2], image.shape[-1]) # h, w
|
342 |
+
|
343 |
+
# Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
|
344 |
+
# but not efficient on large generic data structures due to the use of pickle & mp.Queue.
|
345 |
+
# Therefore it's important to use torch.Tensor.
|
346 |
+
dataset_dict["image"] = image
|
347 |
+
|
348 |
+
if "annotations" in dataset_dict:
|
349 |
+
raise ValueError("Pemantic segmentation dataset should not have 'annotations'.")
|
350 |
+
|
351 |
+
prob_task = np.random.uniform(0,1.)
|
352 |
+
|
353 |
+
num_class_obj = {}
|
354 |
+
|
355 |
+
for name in self.class_names:
|
356 |
+
num_class_obj[name] = 0
|
357 |
+
|
358 |
+
if prob_task < self.semantic_prob:
|
359 |
+
task = "The task is semantic"
|
360 |
+
instances, text, sem_seg = self._get_semantic_dict(pan_seg_gt, image_shape, segments_info, num_class_obj)
|
361 |
+
elif prob_task < self.instance_prob:
|
362 |
+
task = "The task is instance"
|
363 |
+
instances, text, sem_seg = self._get_instance_dict(pan_seg_gt, image_shape, segments_info, num_class_obj)
|
364 |
+
else:
|
365 |
+
task = "The task is panoptic"
|
366 |
+
instances, text, sem_seg = self._get_panoptic_dict(pan_seg_gt, image_shape, segments_info, num_class_obj)
|
367 |
+
|
368 |
+
dataset_dict["sem_seg"] = torch.from_numpy(sem_seg).long()
|
369 |
+
dataset_dict["instances"] = instances
|
370 |
+
dataset_dict["orig_shape"] = image_shape
|
371 |
+
dataset_dict["task"] = task
|
372 |
+
dataset_dict["text"] = text
|
373 |
+
dataset_dict["thing_ids"] = self.things
|
374 |
+
|
375 |
+
return dataset_dict
|
oneformer/data/datasets/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from . import (
|
2 |
+
register_ade20k_panoptic,
|
3 |
+
register_cityscapes_panoptic,
|
4 |
+
register_coco_panoptic_annos_semseg,
|
5 |
+
register_ade20k_instance,
|
6 |
+
register_coco_panoptic2instance,
|
7 |
+
)
|
oneformer/data/datasets/register_ade20k_instance.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------------
|
2 |
+
# Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/data/datasets/register_ade20k_instance.py
|
3 |
+
# ------------------------------------------------------------------------------
|
4 |
+
|
5 |
+
import json
|
6 |
+
import logging
|
7 |
+
import numpy as np
|
8 |
+
import os
|
9 |
+
from PIL import Image
|
10 |
+
|
11 |
+
from detectron2.data import DatasetCatalog, MetadataCatalog
|
12 |
+
from detectron2.data.datasets.coco import load_coco_json, register_coco_instances
|
13 |
+
from detectron2.utils.file_io import PathManager
|
14 |
+
|
15 |
+
ADE_CATEGORIES = [{'id': 7, 'name': 'bed'}, {'id': 8, 'name': 'windowpane'}, {'id': 10, 'name': 'cabinet'}, {'id': 12, 'name': 'person'}, {'id': 14, 'name': 'door'}, {'id': 15, 'name': 'table'}, {'id': 18, 'name': 'curtain'}, {'id': 19, 'name': 'chair'}, {'id': 20, 'name': 'car'}, {'id': 22, 'name': 'painting'}, {'id': 23, 'name': 'sofa'}, {'id': 24, 'name': 'shelf'}, {'id': 27, 'name': 'mirror'}, {'id': 30, 'name': 'armchair'}, {'id': 31, 'name': 'seat'}, {'id': 32, 'name': 'fence'}, {'id': 33, 'name': 'desk'}, {'id': 35, 'name': 'wardrobe'}, {'id': 36, 'name': 'lamp'}, {'id': 37, 'name': 'bathtub'}, {'id': 38, 'name': 'railing'}, {'id': 39, 'name': 'cushion'}, {'id': 41, 'name': 'box'}, {'id': 42, 'name': 'column'}, {'id': 43, 'name': 'signboard'}, {'id': 44, 'name': 'chest of drawers'}, {'id': 45, 'name': 'counter'}, {'id': 47, 'name': 'sink'}, {'id': 49, 'name': 'fireplace'}, {'id': 50, 'name': 'refrigerator'}, {'id': 53, 'name': 'stairs'}, {'id': 55, 'name': 'case'}, {'id': 56, 'name': 'pool table'}, {'id': 57, 'name': 'pillow'}, {'id': 58, 'name': 'screen door'}, {'id': 62, 'name': 'bookcase'}, {'id': 64, 'name': 'coffee table'}, {'id': 65, 'name': 'toilet'}, {'id': 66, 'name': 'flower'}, {'id': 67, 'name': 'book'}, {'id': 69, 'name': 'bench'}, {'id': 70, 'name': 'countertop'}, {'id': 71, 'name': 'stove'}, {'id': 72, 'name': 'palm'}, {'id': 73, 'name': 'kitchen island'}, {'id': 74, 'name': 'computer'}, {'id': 75, 'name': 'swivel chair'}, {'id': 76, 'name': 'boat'}, {'id': 78, 'name': 'arcade machine'}, {'id': 80, 'name': 'bus'}, {'id': 81, 'name': 'towel'}, {'id': 82, 'name': 'light'}, {'id': 83, 'name': 'truck'}, {'id': 85, 'name': 'chandelier'}, {'id': 86, 'name': 'awning'}, {'id': 87, 'name': 'streetlight'}, {'id': 88, 'name': 'booth'}, {'id': 89, 'name': 'television receiver'}, {'id': 90, 'name': 'airplane'}, {'id': 92, 'name': 'apparel'}, {'id': 93, 'name': 'pole'}, {'id': 95, 'name': 'bannister'}, {'id': 97, 'name': 'ottoman'}, {'id': 98, 'name': 'bottle'}, {'id': 102, 'name': 'van'}, {'id': 103, 'name': 'ship'}, {'id': 104, 'name': 'fountain'}, {'id': 107, 'name': 'washer'}, {'id': 108, 'name': 'plaything'}, {'id': 110, 'name': 'stool'}, {'id': 111, 'name': 'barrel'}, {'id': 112, 'name': 'basket'}, {'id': 115, 'name': 'bag'}, {'id': 116, 'name': 'minibike'}, {'id': 118, 'name': 'oven'}, {'id': 119, 'name': 'ball'}, {'id': 120, 'name': 'food'}, {'id': 121, 'name': 'step'}, {'id': 123, 'name': 'trade name'}, {'id': 124, 'name': 'microwave'}, {'id': 125, 'name': 'pot'}, {'id': 126, 'name': 'animal'}, {'id': 127, 'name': 'bicycle'}, {'id': 129, 'name': 'dishwasher'}, {'id': 130, 'name': 'screen'}, {'id': 132, 'name': 'sculpture'}, {'id': 133, 'name': 'hood'}, {'id': 134, 'name': 'sconce'}, {'id': 135, 'name': 'vase'}, {'id': 136, 'name': 'traffic light'}, {'id': 137, 'name': 'tray'}, {'id': 138, 'name': 'ashcan'}, {'id': 139, 'name': 'fan'}, {'id': 142, 'name': 'plate'}, {'id': 143, 'name': 'monitor'}, {'id': 144, 'name': 'bulletin board'}, {'id': 146, 'name': 'radiator'}, {'id': 147, 'name': 'glass'}, {'id': 148, 'name': 'clock'}, {'id': 149, 'name': 'flag'}]
|
16 |
+
|
17 |
+
|
18 |
+
_PREDEFINED_SPLITS = {
|
19 |
+
# point annotations without masks
|
20 |
+
"ade20k_instance_train": (
|
21 |
+
"ADEChallengeData2016/images/training",
|
22 |
+
"ADEChallengeData2016/ade20k_instance_train.json",
|
23 |
+
),
|
24 |
+
"ade20k_instance_val": (
|
25 |
+
"ADEChallengeData2016/images/validation",
|
26 |
+
"ADEChallengeData2016/ade20k_instance_val.json",
|
27 |
+
),
|
28 |
+
}
|
29 |
+
|
30 |
+
|
31 |
+
def _get_ade_instances_meta():
|
32 |
+
thing_ids = [k["id"] for k in ADE_CATEGORIES]
|
33 |
+
assert len(thing_ids) == 100, len(thing_ids)
|
34 |
+
# Mapping from the incontiguous ADE category id to an id in [0, 99]
|
35 |
+
thing_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(thing_ids)}
|
36 |
+
thing_classes = [k["name"] for k in ADE_CATEGORIES]
|
37 |
+
ret = {
|
38 |
+
"thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id,
|
39 |
+
"thing_classes": thing_classes,
|
40 |
+
}
|
41 |
+
return ret
|
42 |
+
|
43 |
+
|
44 |
+
def register_all_ade20k_instance(root):
|
45 |
+
for key, (image_root, json_file) in _PREDEFINED_SPLITS.items():
|
46 |
+
# Assume pre-defined datasets live in `./datasets`.
|
47 |
+
register_coco_instances(
|
48 |
+
key,
|
49 |
+
_get_ade_instances_meta(),
|
50 |
+
os.path.join(root, json_file) if "://" not in json_file else json_file,
|
51 |
+
os.path.join(root, image_root),
|
52 |
+
)
|
53 |
+
|
54 |
+
|
55 |
+
_root = os.getenv("DETECTRON2_DATASETS", "datasets")
|
56 |
+
register_all_ade20k_instance(_root)
|
oneformer/data/datasets/register_ade20k_panoptic.py
ADDED
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------------
|
2 |
+
# Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/data/datasets/register_ade20k_panoptic.py
|
3 |
+
# Modified by Jitesh Jain (https://github.com/praeclarumjj3)
|
4 |
+
# ------------------------------------------------------------------------------
|
5 |
+
|
6 |
+
import json
|
7 |
+
import os
|
8 |
+
|
9 |
+
from detectron2.data import DatasetCatalog, MetadataCatalog
|
10 |
+
from detectron2.utils.file_io import PathManager
|
11 |
+
|
12 |
+
ADE20K_150_CATEGORIES = [
|
13 |
+
{"color": [120, 120, 120], "id": 0, "isthing": 0, "name": "wall"},
|
14 |
+
{"color": [180, 120, 120], "id": 1, "isthing": 0, "name": "building"},
|
15 |
+
{"color": [6, 230, 230], "id": 2, "isthing": 0, "name": "sky"},
|
16 |
+
{"color": [80, 50, 50], "id": 3, "isthing": 0, "name": "floor"},
|
17 |
+
{"color": [4, 200, 3], "id": 4, "isthing": 0, "name": "tree"},
|
18 |
+
{"color": [120, 120, 80], "id": 5, "isthing": 0, "name": "ceiling"},
|
19 |
+
{"color": [140, 140, 140], "id": 6, "isthing": 0, "name": "road, route"},
|
20 |
+
{"color": [204, 5, 255], "id": 7, "isthing": 1, "name": "bed"},
|
21 |
+
{"color": [230, 230, 230], "id": 8, "isthing": 1, "name": "window "},
|
22 |
+
{"color": [4, 250, 7], "id": 9, "isthing": 0, "name": "grass"},
|
23 |
+
{"color": [224, 5, 255], "id": 10, "isthing": 1, "name": "cabinet"},
|
24 |
+
{"color": [235, 255, 7], "id": 11, "isthing": 0, "name": "sidewalk, pavement"},
|
25 |
+
{"color": [150, 5, 61], "id": 12, "isthing": 1, "name": "person"},
|
26 |
+
{"color": [120, 120, 70], "id": 13, "isthing": 0, "name": "earth, ground"},
|
27 |
+
{"color": [8, 255, 51], "id": 14, "isthing": 1, "name": "door"},
|
28 |
+
{"color": [255, 6, 82], "id": 15, "isthing": 1, "name": "table"},
|
29 |
+
{"color": [143, 255, 140], "id": 16, "isthing": 0, "name": "mountain, mount"},
|
30 |
+
{"color": [204, 255, 4], "id": 17, "isthing": 0, "name": "plant"},
|
31 |
+
{"color": [255, 51, 7], "id": 18, "isthing": 1, "name": "curtain"},
|
32 |
+
{"color": [204, 70, 3], "id": 19, "isthing": 1, "name": "chair"},
|
33 |
+
{"color": [0, 102, 200], "id": 20, "isthing": 1, "name": "car"},
|
34 |
+
{"color": [61, 230, 250], "id": 21, "isthing": 0, "name": "water"},
|
35 |
+
{"color": [255, 6, 51], "id": 22, "isthing": 1, "name": "painting, picture"},
|
36 |
+
{"color": [11, 102, 255], "id": 23, "isthing": 1, "name": "sofa"},
|
37 |
+
{"color": [255, 7, 71], "id": 24, "isthing": 1, "name": "shelf"},
|
38 |
+
{"color": [255, 9, 224], "id": 25, "isthing": 0, "name": "house"},
|
39 |
+
{"color": [9, 7, 230], "id": 26, "isthing": 0, "name": "sea"},
|
40 |
+
{"color": [220, 220, 220], "id": 27, "isthing": 1, "name": "mirror"},
|
41 |
+
{"color": [255, 9, 92], "id": 28, "isthing": 0, "name": "rug"},
|
42 |
+
{"color": [112, 9, 255], "id": 29, "isthing": 0, "name": "field"},
|
43 |
+
{"color": [8, 255, 214], "id": 30, "isthing": 1, "name": "armchair"},
|
44 |
+
{"color": [7, 255, 224], "id": 31, "isthing": 1, "name": "seat"},
|
45 |
+
{"color": [255, 184, 6], "id": 32, "isthing": 1, "name": "fence"},
|
46 |
+
{"color": [10, 255, 71], "id": 33, "isthing": 1, "name": "desk"},
|
47 |
+
{"color": [255, 41, 10], "id": 34, "isthing": 0, "name": "rock, stone"},
|
48 |
+
{"color": [7, 255, 255], "id": 35, "isthing": 1, "name": "wardrobe, closet, press"},
|
49 |
+
{"color": [224, 255, 8], "id": 36, "isthing": 1, "name": "lamp"},
|
50 |
+
{"color": [102, 8, 255], "id": 37, "isthing": 1, "name": "tub"},
|
51 |
+
{"color": [255, 61, 6], "id": 38, "isthing": 1, "name": "rail"},
|
52 |
+
{"color": [255, 194, 7], "id": 39, "isthing": 1, "name": "cushion"},
|
53 |
+
{"color": [255, 122, 8], "id": 40, "isthing": 0, "name": "base, pedestal, stand"},
|
54 |
+
{"color": [0, 255, 20], "id": 41, "isthing": 1, "name": "box"},
|
55 |
+
{"color": [255, 8, 41], "id": 42, "isthing": 1, "name": "column, pillar"},
|
56 |
+
{"color": [255, 5, 153], "id": 43, "isthing": 1, "name": "signboard, sign"},
|
57 |
+
{
|
58 |
+
"color": [6, 51, 255],
|
59 |
+
"id": 44,
|
60 |
+
"isthing": 1,
|
61 |
+
"name": "chest of drawers, chest, bureau, dresser",
|
62 |
+
},
|
63 |
+
{"color": [235, 12, 255], "id": 45, "isthing": 1, "name": "counter"},
|
64 |
+
{"color": [160, 150, 20], "id": 46, "isthing": 0, "name": "sand"},
|
65 |
+
{"color": [0, 163, 255], "id": 47, "isthing": 1, "name": "sink"},
|
66 |
+
{"color": [140, 140, 140], "id": 48, "isthing": 0, "name": "skyscraper"},
|
67 |
+
{"color": [250, 10, 15], "id": 49, "isthing": 1, "name": "fireplace"},
|
68 |
+
{"color": [20, 255, 0], "id": 50, "isthing": 1, "name": "refrigerator, icebox"},
|
69 |
+
{"color": [31, 255, 0], "id": 51, "isthing": 0, "name": "grandstand, covered stand"},
|
70 |
+
{"color": [255, 31, 0], "id": 52, "isthing": 0, "name": "path"},
|
71 |
+
{"color": [255, 224, 0], "id": 53, "isthing": 1, "name": "stairs"},
|
72 |
+
{"color": [153, 255, 0], "id": 54, "isthing": 0, "name": "runway"},
|
73 |
+
{"color": [0, 0, 255], "id": 55, "isthing": 1, "name": "case, display case, showcase, vitrine"},
|
74 |
+
{
|
75 |
+
"color": [255, 71, 0],
|
76 |
+
"id": 56,
|
77 |
+
"isthing": 1,
|
78 |
+
"name": "pool table, billiard table, snooker table",
|
79 |
+
},
|
80 |
+
{"color": [0, 235, 255], "id": 57, "isthing": 1, "name": "pillow"},
|
81 |
+
{"color": [0, 173, 255], "id": 58, "isthing": 1, "name": "screen door, screen"},
|
82 |
+
{"color": [31, 0, 255], "id": 59, "isthing": 0, "name": "stairway, staircase"},
|
83 |
+
{"color": [11, 200, 200], "id": 60, "isthing": 0, "name": "river"},
|
84 |
+
{"color": [255, 82, 0], "id": 61, "isthing": 0, "name": "bridge, span"},
|
85 |
+
{"color": [0, 255, 245], "id": 62, "isthing": 1, "name": "bookcase"},
|
86 |
+
{"color": [0, 61, 255], "id": 63, "isthing": 0, "name": "blind, screen"},
|
87 |
+
{"color": [0, 255, 112], "id": 64, "isthing": 1, "name": "coffee table"},
|
88 |
+
{
|
89 |
+
"color": [0, 255, 133],
|
90 |
+
"id": 65,
|
91 |
+
"isthing": 1,
|
92 |
+
"name": "toilet, can, commode, crapper, pot, potty, stool, throne",
|
93 |
+
},
|
94 |
+
{"color": [255, 0, 0], "id": 66, "isthing": 1, "name": "flower"},
|
95 |
+
{"color": [255, 163, 0], "id": 67, "isthing": 1, "name": "book"},
|
96 |
+
{"color": [255, 102, 0], "id": 68, "isthing": 0, "name": "hill"},
|
97 |
+
{"color": [194, 255, 0], "id": 69, "isthing": 1, "name": "bench"},
|
98 |
+
{"color": [0, 143, 255], "id": 70, "isthing": 1, "name": "countertop"},
|
99 |
+
{"color": [51, 255, 0], "id": 71, "isthing": 1, "name": "stove"},
|
100 |
+
{"color": [0, 82, 255], "id": 72, "isthing": 1, "name": "palm, palm tree"},
|
101 |
+
{"color": [0, 255, 41], "id": 73, "isthing": 1, "name": "kitchen island"},
|
102 |
+
{"color": [0, 255, 173], "id": 74, "isthing": 1, "name": "computer"},
|
103 |
+
{"color": [10, 0, 255], "id": 75, "isthing": 1, "name": "swivel chair"},
|
104 |
+
{"color": [173, 255, 0], "id": 76, "isthing": 1, "name": "boat"},
|
105 |
+
{"color": [0, 255, 153], "id": 77, "isthing": 0, "name": "bar"},
|
106 |
+
{"color": [255, 92, 0], "id": 78, "isthing": 1, "name": "arcade machine"},
|
107 |
+
{"color": [255, 0, 255], "id": 79, "isthing": 0, "name": "hovel, hut, hutch, shack, shanty"},
|
108 |
+
{"color": [255, 0, 245], "id": 80, "isthing": 1, "name": "bus"},
|
109 |
+
{"color": [255, 0, 102], "id": 81, "isthing": 1, "name": "towel"},
|
110 |
+
{"color": [255, 173, 0], "id": 82, "isthing": 1, "name": "light"},
|
111 |
+
{"color": [255, 0, 20], "id": 83, "isthing": 1, "name": "truck"},
|
112 |
+
{"color": [255, 184, 184], "id": 84, "isthing": 0, "name": "tower"},
|
113 |
+
{"color": [0, 31, 255], "id": 85, "isthing": 1, "name": "chandelier"},
|
114 |
+
{"color": [0, 255, 61], "id": 86, "isthing": 1, "name": "awning, sunshade, sunblind"},
|
115 |
+
{"color": [0, 71, 255], "id": 87, "isthing": 1, "name": "street lamp"},
|
116 |
+
{"color": [255, 0, 204], "id": 88, "isthing": 1, "name": "booth"},
|
117 |
+
{"color": [0, 255, 194], "id": 89, "isthing": 1, "name": "tv"},
|
118 |
+
{"color": [0, 255, 82], "id": 90, "isthing": 1, "name": "plane"},
|
119 |
+
{"color": [0, 10, 255], "id": 91, "isthing": 0, "name": "dirt track"},
|
120 |
+
{"color": [0, 112, 255], "id": 92, "isthing": 1, "name": "clothes"},
|
121 |
+
{"color": [51, 0, 255], "id": 93, "isthing": 1, "name": "pole"},
|
122 |
+
{"color": [0, 194, 255], "id": 94, "isthing": 0, "name": "land, ground, soil"},
|
123 |
+
{
|
124 |
+
"color": [0, 122, 255],
|
125 |
+
"id": 95,
|
126 |
+
"isthing": 1,
|
127 |
+
"name": "bannister, banister, balustrade, balusters, handrail",
|
128 |
+
},
|
129 |
+
{
|
130 |
+
"color": [0, 255, 163],
|
131 |
+
"id": 96,
|
132 |
+
"isthing": 0,
|
133 |
+
"name": "escalator, moving staircase, moving stairway",
|
134 |
+
},
|
135 |
+
{
|
136 |
+
"color": [255, 153, 0],
|
137 |
+
"id": 97,
|
138 |
+
"isthing": 1,
|
139 |
+
"name": "ottoman, pouf, pouffe, puff, hassock",
|
140 |
+
},
|
141 |
+
{"color": [0, 255, 10], "id": 98, "isthing": 1, "name": "bottle"},
|
142 |
+
{"color": [255, 112, 0], "id": 99, "isthing": 0, "name": "buffet, counter, sideboard"},
|
143 |
+
{
|
144 |
+
"color": [143, 255, 0],
|
145 |
+
"id": 100,
|
146 |
+
"isthing": 0,
|
147 |
+
"name": "poster, posting, placard, notice, bill, card",
|
148 |
+
},
|
149 |
+
{"color": [82, 0, 255], "id": 101, "isthing": 0, "name": "stage"},
|
150 |
+
{"color": [163, 255, 0], "id": 102, "isthing": 1, "name": "van"},
|
151 |
+
{"color": [255, 235, 0], "id": 103, "isthing": 1, "name": "ship"},
|
152 |
+
{"color": [8, 184, 170], "id": 104, "isthing": 1, "name": "fountain"},
|
153 |
+
{
|
154 |
+
"color": [133, 0, 255],
|
155 |
+
"id": 105,
|
156 |
+
"isthing": 0,
|
157 |
+
"name": "conveyer belt, conveyor belt, conveyer, conveyor, transporter",
|
158 |
+
},
|
159 |
+
{"color": [0, 255, 92], "id": 106, "isthing": 0, "name": "canopy"},
|
160 |
+
{
|
161 |
+
"color": [184, 0, 255],
|
162 |
+
"id": 107,
|
163 |
+
"isthing": 1,
|
164 |
+
"name": "washer, automatic washer, washing machine",
|
165 |
+
},
|
166 |
+
{"color": [255, 0, 31], "id": 108, "isthing": 1, "name": "plaything, toy"},
|
167 |
+
{"color": [0, 184, 255], "id": 109, "isthing": 0, "name": "pool"},
|
168 |
+
{"color": [0, 214, 255], "id": 110, "isthing": 1, "name": "stool"},
|
169 |
+
{"color": [255, 0, 112], "id": 111, "isthing": 1, "name": "barrel, cask"},
|
170 |
+
{"color": [92, 255, 0], "id": 112, "isthing": 1, "name": "basket, handbasket"},
|
171 |
+
{"color": [0, 224, 255], "id": 113, "isthing": 0, "name": "falls"},
|
172 |
+
{"color": [112, 224, 255], "id": 114, "isthing": 0, "name": "tent"},
|
173 |
+
{"color": [70, 184, 160], "id": 115, "isthing": 1, "name": "bag"},
|
174 |
+
{"color": [163, 0, 255], "id": 116, "isthing": 1, "name": "minibike, motorbike"},
|
175 |
+
{"color": [153, 0, 255], "id": 117, "isthing": 0, "name": "cradle"},
|
176 |
+
{"color": [71, 255, 0], "id": 118, "isthing": 1, "name": "oven"},
|
177 |
+
{"color": [255, 0, 163], "id": 119, "isthing": 1, "name": "ball"},
|
178 |
+
{"color": [255, 204, 0], "id": 120, "isthing": 1, "name": "food, solid food"},
|
179 |
+
{"color": [255, 0, 143], "id": 121, "isthing": 1, "name": "step, stair"},
|
180 |
+
{"color": [0, 255, 235], "id": 122, "isthing": 0, "name": "tank, storage tank"},
|
181 |
+
{"color": [133, 255, 0], "id": 123, "isthing": 1, "name": "trade name"},
|
182 |
+
{"color": [255, 0, 235], "id": 124, "isthing": 1, "name": "microwave"},
|
183 |
+
{"color": [245, 0, 255], "id": 125, "isthing": 1, "name": "pot"},
|
184 |
+
{"color": [255, 0, 122], "id": 126, "isthing": 1, "name": "animal"},
|
185 |
+
{"color": [255, 245, 0], "id": 127, "isthing": 1, "name": "bicycle"},
|
186 |
+
{"color": [10, 190, 212], "id": 128, "isthing": 0, "name": "lake"},
|
187 |
+
{"color": [214, 255, 0], "id": 129, "isthing": 1, "name": "dishwasher"},
|
188 |
+
{"color": [0, 204, 255], "id": 130, "isthing": 1, "name": "screen"},
|
189 |
+
{"color": [20, 0, 255], "id": 131, "isthing": 0, "name": "blanket, cover"},
|
190 |
+
{"color": [255, 255, 0], "id": 132, "isthing": 1, "name": "sculpture"},
|
191 |
+
{"color": [0, 153, 255], "id": 133, "isthing": 1, "name": "hood, exhaust hood"},
|
192 |
+
{"color": [0, 41, 255], "id": 134, "isthing": 1, "name": "sconce"},
|
193 |
+
{"color": [0, 255, 204], "id": 135, "isthing": 1, "name": "vase"},
|
194 |
+
{"color": [41, 0, 255], "id": 136, "isthing": 1, "name": "traffic light"},
|
195 |
+
{"color": [41, 255, 0], "id": 137, "isthing": 1, "name": "tray"},
|
196 |
+
{"color": [173, 0, 255], "id": 138, "isthing": 1, "name": "trash can"},
|
197 |
+
{"color": [0, 245, 255], "id": 139, "isthing": 1, "name": "fan"},
|
198 |
+
{"color": [71, 0, 255], "id": 140, "isthing": 0, "name": "pier"},
|
199 |
+
{"color": [122, 0, 255], "id": 141, "isthing": 0, "name": "crt screen"},
|
200 |
+
{"color": [0, 255, 184], "id": 142, "isthing": 1, "name": "plate"},
|
201 |
+
{"color": [0, 92, 255], "id": 143, "isthing": 1, "name": "monitor"},
|
202 |
+
{"color": [184, 255, 0], "id": 144, "isthing": 1, "name": "bulletin board"},
|
203 |
+
{"color": [0, 133, 255], "id": 145, "isthing": 0, "name": "shower"},
|
204 |
+
{"color": [255, 214, 0], "id": 146, "isthing": 1, "name": "radiator"},
|
205 |
+
{"color": [25, 194, 194], "id": 147, "isthing": 1, "name": "glass, drinking glass"},
|
206 |
+
{"color": [102, 255, 0], "id": 148, "isthing": 1, "name": "clock"},
|
207 |
+
{"color": [92, 0, 255], "id": 149, "isthing": 1, "name": "flag"},
|
208 |
+
]
|
209 |
+
|
210 |
+
ADE20k_COLORS = [k["color"] for k in ADE20K_150_CATEGORIES]
|
211 |
+
|
212 |
+
MetadataCatalog.get("ade20k_sem_seg_train").set(
|
213 |
+
stuff_colors=ADE20k_COLORS[:],
|
214 |
+
)
|
215 |
+
|
216 |
+
MetadataCatalog.get("ade20k_sem_seg_val").set(
|
217 |
+
stuff_colors=ADE20k_COLORS[:],
|
218 |
+
)
|
219 |
+
|
220 |
+
|
221 |
+
def load_ade20k_panoptic_json(json_file, image_dir, gt_dir, semseg_dir, meta):
|
222 |
+
"""
|
223 |
+
Args:
|
224 |
+
image_dir (str): path to the raw dataset. e.g., "~/coco/train2017".
|
225 |
+
gt_dir (str): path to the raw annotations. e.g., "~/coco/panoptic_train2017".
|
226 |
+
json_file (str): path to the json file. e.g., "~/coco/annotations/panoptic_train2017.json".
|
227 |
+
Returns:
|
228 |
+
list[dict]: a list of dicts in Detectron2 standard format. (See
|
229 |
+
`Using Custom Datasets </tutorials/datasets.html>`_ )
|
230 |
+
"""
|
231 |
+
|
232 |
+
def _convert_category_id(segment_info, meta):
|
233 |
+
if segment_info["category_id"] in meta["thing_dataset_id_to_contiguous_id"]:
|
234 |
+
segment_info["category_id"] = meta["thing_dataset_id_to_contiguous_id"][
|
235 |
+
segment_info["category_id"]
|
236 |
+
]
|
237 |
+
segment_info["isthing"] = True
|
238 |
+
else:
|
239 |
+
segment_info["category_id"] = meta["stuff_dataset_id_to_contiguous_id"][
|
240 |
+
segment_info["category_id"]
|
241 |
+
]
|
242 |
+
segment_info["isthing"] = False
|
243 |
+
return segment_info
|
244 |
+
|
245 |
+
with PathManager.open(json_file) as f:
|
246 |
+
json_info = json.load(f)
|
247 |
+
|
248 |
+
ret = []
|
249 |
+
for ann in json_info["annotations"]:
|
250 |
+
image_id = ann["image_id"]
|
251 |
+
# TODO: currently we assume image and label has the same filename but
|
252 |
+
# different extension, and images have extension ".jpg" for COCO. Need
|
253 |
+
# to make image extension a user-provided argument if we extend this
|
254 |
+
# function to support other COCO-like datasets.
|
255 |
+
image_file = os.path.join(image_dir, os.path.splitext(ann["file_name"])[0] + ".jpg")
|
256 |
+
label_file = os.path.join(gt_dir, ann["file_name"])
|
257 |
+
sem_label_file = os.path.join(semseg_dir, ann["file_name"])
|
258 |
+
segments_info = [_convert_category_id(x, meta) for x in ann["segments_info"]]
|
259 |
+
ret.append(
|
260 |
+
{
|
261 |
+
"file_name": image_file,
|
262 |
+
"image_id": image_id,
|
263 |
+
"pan_seg_file_name": label_file,
|
264 |
+
"sem_seg_file_name": sem_label_file,
|
265 |
+
"segments_info": segments_info,
|
266 |
+
}
|
267 |
+
)
|
268 |
+
assert len(ret), f"No images found in {image_dir}!"
|
269 |
+
assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"]
|
270 |
+
assert PathManager.isfile(ret[0]["pan_seg_file_name"]), ret[0]["pan_seg_file_name"]
|
271 |
+
assert PathManager.isfile(ret[0]["sem_seg_file_name"]), ret[0]["sem_seg_file_name"]
|
272 |
+
return ret
|
273 |
+
|
274 |
+
|
275 |
+
def register_ade20k_panoptic(
|
276 |
+
name, metadata, image_root, panoptic_root, semantic_root, panoptic_json, instances_json=None,
|
277 |
+
):
|
278 |
+
"""
|
279 |
+
Register a "standard" version of ADE20k panoptic segmentation dataset named `name`.
|
280 |
+
The dictionaries in this registered dataset follows detectron2's standard format.
|
281 |
+
Hence it's called "standard".
|
282 |
+
Args:
|
283 |
+
name (str): the name that identifies a dataset,
|
284 |
+
e.g. "ade20k_panoptic_train"
|
285 |
+
metadata (dict): extra metadata associated with this dataset.
|
286 |
+
image_root (str): directory which contains all the images
|
287 |
+
panoptic_root (str): directory which contains panoptic annotation images in COCO format
|
288 |
+
panoptic_json (str): path to the json panoptic annotation file in COCO format
|
289 |
+
sem_seg_root (none): not used, to be consistent with
|
290 |
+
`register_coco_panoptic_separated`.
|
291 |
+
instances_json (str): path to the json instance annotation file
|
292 |
+
"""
|
293 |
+
panoptic_name = name
|
294 |
+
DatasetCatalog.register(
|
295 |
+
panoptic_name,
|
296 |
+
lambda: load_ade20k_panoptic_json(
|
297 |
+
panoptic_json, image_root, panoptic_root, semantic_root, metadata
|
298 |
+
),
|
299 |
+
)
|
300 |
+
MetadataCatalog.get(panoptic_name).set(
|
301 |
+
panoptic_root=panoptic_root,
|
302 |
+
image_root=image_root,
|
303 |
+
panoptic_json=panoptic_json,
|
304 |
+
json_file=instances_json,
|
305 |
+
evaluator_type="ade20k_panoptic_seg",
|
306 |
+
ignore_label=255,
|
307 |
+
label_divisor=1000,
|
308 |
+
**metadata,
|
309 |
+
)
|
310 |
+
|
311 |
+
|
312 |
+
_PREDEFINED_SPLITS_ADE20K_PANOPTIC = {
|
313 |
+
"ade20k_panoptic_train": (
|
314 |
+
"ADEChallengeData2016/images/training",
|
315 |
+
"ADEChallengeData2016/ade20k_panoptic_train",
|
316 |
+
"ADEChallengeData2016/ade20k_panoptic_train.json",
|
317 |
+
"ADEChallengeData2016/annotations_detectron2/training",
|
318 |
+
"ADEChallengeData2016/ade20k_instance_train.json",
|
319 |
+
),
|
320 |
+
"ade20k_panoptic_val": (
|
321 |
+
"ADEChallengeData2016/images/validation",
|
322 |
+
"ADEChallengeData2016/ade20k_panoptic_val",
|
323 |
+
"ADEChallengeData2016/ade20k_panoptic_val.json",
|
324 |
+
"ADEChallengeData2016/annotations_detectron2/validation",
|
325 |
+
"ADEChallengeData2016/ade20k_instance_val.json",
|
326 |
+
),
|
327 |
+
}
|
328 |
+
|
329 |
+
|
330 |
+
def get_metadata():
|
331 |
+
meta = {}
|
332 |
+
# The following metadata maps contiguous id from [0, #thing categories +
|
333 |
+
# #stuff categories) to their names and colors. We have to replica of the
|
334 |
+
# same name and color under "thing_*" and "stuff_*" because the current
|
335 |
+
# visualization function in D2 handles thing and class classes differently
|
336 |
+
# due to some heuristic used in Panoptic FPN. We keep the same naming to
|
337 |
+
# enable reusing existing visualization functions.
|
338 |
+
thing_classes = [k["name"] for k in ADE20K_150_CATEGORIES if k["isthing"] == 1]
|
339 |
+
thing_colors = [k["color"] for k in ADE20K_150_CATEGORIES if k["isthing"] == 1]
|
340 |
+
stuff_classes = [k["name"] for k in ADE20K_150_CATEGORIES]
|
341 |
+
stuff_colors = [k["color"] for k in ADE20K_150_CATEGORIES]
|
342 |
+
|
343 |
+
meta["thing_classes"] = thing_classes
|
344 |
+
meta["thing_colors"] = thing_colors
|
345 |
+
meta["stuff_classes"] = stuff_classes
|
346 |
+
meta["stuff_colors"] = stuff_colors
|
347 |
+
|
348 |
+
# Convert category id for training:
|
349 |
+
# category id: like semantic segmentation, it is the class id for each
|
350 |
+
# pixel. Since there are some classes not used in evaluation, the category
|
351 |
+
# id is not always contiguous and thus we have two set of category ids:
|
352 |
+
# - original category id: category id in the original dataset, mainly
|
353 |
+
# used for evaluation.
|
354 |
+
# - contiguous category id: [0, #classes), in order to train the linear
|
355 |
+
# softmax classifier.
|
356 |
+
thing_dataset_id_to_contiguous_id = {}
|
357 |
+
stuff_dataset_id_to_contiguous_id = {}
|
358 |
+
|
359 |
+
for i, cat in enumerate(ADE20K_150_CATEGORIES):
|
360 |
+
if cat["isthing"]:
|
361 |
+
thing_dataset_id_to_contiguous_id[cat["id"]] = i
|
362 |
+
# else:
|
363 |
+
# stuff_dataset_id_to_contiguous_id[cat["id"]] = i
|
364 |
+
|
365 |
+
# in order to use sem_seg evaluator
|
366 |
+
stuff_dataset_id_to_contiguous_id[cat["id"]] = i
|
367 |
+
|
368 |
+
meta["thing_dataset_id_to_contiguous_id"] = thing_dataset_id_to_contiguous_id
|
369 |
+
meta["stuff_dataset_id_to_contiguous_id"] = stuff_dataset_id_to_contiguous_id
|
370 |
+
|
371 |
+
return meta
|
372 |
+
|
373 |
+
|
374 |
+
def register_all_ade20k_panoptic(root):
|
375 |
+
metadata = get_metadata()
|
376 |
+
for (
|
377 |
+
prefix,
|
378 |
+
(image_root, panoptic_root, panoptic_json, semantic_root, instance_json),
|
379 |
+
) in _PREDEFINED_SPLITS_ADE20K_PANOPTIC.items():
|
380 |
+
# The "standard" version of COCO panoptic segmentation dataset,
|
381 |
+
# e.g. used by Panoptic-DeepLab
|
382 |
+
register_ade20k_panoptic(
|
383 |
+
prefix,
|
384 |
+
metadata,
|
385 |
+
os.path.join(root, image_root),
|
386 |
+
os.path.join(root, panoptic_root),
|
387 |
+
os.path.join(root, semantic_root),
|
388 |
+
os.path.join(root, panoptic_json),
|
389 |
+
os.path.join(root, instance_json),
|
390 |
+
)
|
391 |
+
|
392 |
+
|
393 |
+
_root = os.getenv("DETECTRON2_DATASETS", "datasets")
|
394 |
+
register_all_ade20k_panoptic(_root)
|
oneformer/data/datasets/register_cityscapes_panoptic.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------------
|
2 |
+
# Reference: https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/datasets/cityscapes_panoptic.py
|
3 |
+
# Modified by Jitesh Jain (https://github.com/praeclarumjj3)
|
4 |
+
# ------------------------------------------------------------------------------
|
5 |
+
|
6 |
+
import json
|
7 |
+
import logging
|
8 |
+
import os
|
9 |
+
|
10 |
+
from detectron2.data import DatasetCatalog, MetadataCatalog
|
11 |
+
from detectron2.data.datasets.builtin_meta import CITYSCAPES_CATEGORIES
|
12 |
+
from detectron2.utils.file_io import PathManager
|
13 |
+
|
14 |
+
"""
|
15 |
+
This file contains functions to register the Cityscapes panoptic dataset to the DatasetCatalog.
|
16 |
+
"""
|
17 |
+
|
18 |
+
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
def get_cityscapes_panoptic_files(image_dir, gt_dir, json_info):
|
23 |
+
files = []
|
24 |
+
# scan through the directory
|
25 |
+
cities = PathManager.ls(image_dir)
|
26 |
+
logger.info(f"{len(cities)} cities found in '{image_dir}'.")
|
27 |
+
image_dict = {}
|
28 |
+
for city in cities:
|
29 |
+
city_img_dir = os.path.join(image_dir, city)
|
30 |
+
for basename in PathManager.ls(city_img_dir):
|
31 |
+
image_file = os.path.join(city_img_dir, basename)
|
32 |
+
|
33 |
+
suffix = "_leftImg8bit.png"
|
34 |
+
assert basename.endswith(suffix), basename
|
35 |
+
basename = os.path.basename(basename)[: -len(suffix)]
|
36 |
+
|
37 |
+
image_dict[basename] = image_file
|
38 |
+
|
39 |
+
for ann in json_info["annotations"]:
|
40 |
+
image_file = image_dict.get(ann["image_id"], None)
|
41 |
+
assert image_file is not None, "No image {} found for annotation {}".format(
|
42 |
+
ann["image_id"], ann["file_name"]
|
43 |
+
)
|
44 |
+
label_file = os.path.join(gt_dir, ann["file_name"])
|
45 |
+
segments_info = ann["segments_info"]
|
46 |
+
files.append((image_file, label_file, segments_info))
|
47 |
+
|
48 |
+
assert len(files), "No images found in {}".format(image_dir)
|
49 |
+
assert PathManager.isfile(files[0][0]), files[0][0]
|
50 |
+
assert PathManager.isfile(files[0][1]), files[0][1]
|
51 |
+
return files
|
52 |
+
|
53 |
+
|
54 |
+
def load_cityscapes_panoptic(image_dir, gt_dir, gt_json, meta):
|
55 |
+
"""
|
56 |
+
Args:
|
57 |
+
image_dir (str): path to the raw dataset. e.g., "~/cityscapes/leftImg8bit/train".
|
58 |
+
gt_dir (str): path to the raw annotations. e.g.,
|
59 |
+
"~/cityscapes/gtFine/cityscapes_panoptic_train".
|
60 |
+
gt_json (str): path to the json file. e.g.,
|
61 |
+
"~/cityscapes/gtFine/cityscapes_panoptic_train.json".
|
62 |
+
meta (dict): dictionary containing "thing_dataset_id_to_contiguous_id"
|
63 |
+
and "stuff_dataset_id_to_contiguous_id" to map category ids to
|
64 |
+
contiguous ids for training.
|
65 |
+
|
66 |
+
Returns:
|
67 |
+
list[dict]: a list of dicts in Detectron2 standard format. (See
|
68 |
+
`Using Custom Datasets </tutorials/datasets.html>`_ )
|
69 |
+
"""
|
70 |
+
|
71 |
+
def _convert_category_id(segment_info, meta):
|
72 |
+
if segment_info["category_id"] in meta["thing_dataset_id_to_contiguous_id"]:
|
73 |
+
segment_info["category_id"] = meta["thing_dataset_id_to_contiguous_id"][
|
74 |
+
segment_info["category_id"]
|
75 |
+
]
|
76 |
+
else:
|
77 |
+
segment_info["category_id"] = meta["stuff_dataset_id_to_contiguous_id"][
|
78 |
+
segment_info["category_id"]
|
79 |
+
]
|
80 |
+
return segment_info
|
81 |
+
|
82 |
+
assert os.path.exists(
|
83 |
+
gt_json
|
84 |
+
), "Please run `python cityscapesscripts/preparation/createPanopticImgs.py` to generate label files." # noqa
|
85 |
+
|
86 |
+
|
87 |
+
with open(gt_json) as f:
|
88 |
+
json_info = json.load(f)
|
89 |
+
|
90 |
+
files = get_cityscapes_panoptic_files(image_dir, gt_dir, json_info)
|
91 |
+
ret = []
|
92 |
+
for image_file, label_file, segments_info in files:
|
93 |
+
sem_label_file = (
|
94 |
+
image_file.replace("leftImg8bit", "gtFine").split(".")[0] + "_labelTrainIds.png"
|
95 |
+
)
|
96 |
+
segments_info = [_convert_category_id(x, meta) for x in segments_info]
|
97 |
+
ret.append(
|
98 |
+
{
|
99 |
+
"file_name": image_file,
|
100 |
+
"image_id": "_".join(
|
101 |
+
os.path.splitext(os.path.basename(image_file))[0].split("_")[:3]
|
102 |
+
),
|
103 |
+
"sem_seg_file_name": sem_label_file,
|
104 |
+
"pan_seg_file_name": label_file,
|
105 |
+
"segments_info": segments_info,
|
106 |
+
}
|
107 |
+
)
|
108 |
+
assert len(ret), f"No images found in {image_dir}!"
|
109 |
+
assert PathManager.isfile(
|
110 |
+
ret[0]["sem_seg_file_name"]
|
111 |
+
), "Please generate labelTrainIds.png with cityscapesscripts/preparation/createTrainIdLabelImgs.py" # noqa
|
112 |
+
assert PathManager.isfile(
|
113 |
+
ret[0]["pan_seg_file_name"]
|
114 |
+
), "Please generate panoptic annotation with python cityscapesscripts/preparation/createPanopticImgs.py" # noqa
|
115 |
+
return ret
|
116 |
+
|
117 |
+
|
118 |
+
_RAW_CITYSCAPES_PANOPTIC_SPLITS = {
|
119 |
+
"cityscapes_fine_panoptic_train": (
|
120 |
+
"cityscapes/leftImg8bit/train",
|
121 |
+
"cityscapes/gtFine/cityscapes_panoptic_train",
|
122 |
+
"cityscapes/gtFine/cityscapes_panoptic_train.json",
|
123 |
+
),
|
124 |
+
"cityscapes_fine_panoptic_val": (
|
125 |
+
"cityscapes/leftImg8bit/val",
|
126 |
+
"cityscapes/gtFine/cityscapes_panoptic_val",
|
127 |
+
"cityscapes/gtFine/cityscapes_panoptic_val.json",
|
128 |
+
),
|
129 |
+
# "cityscapes_fine_panoptic_test": not supported yet
|
130 |
+
}
|
131 |
+
|
132 |
+
|
133 |
+
def register_all_cityscapes_panoptic(root):
|
134 |
+
meta = {}
|
135 |
+
# The following metadata maps contiguous id from [0, #thing categories +
|
136 |
+
# #stuff categories) to their names and colors. We have to replica of the
|
137 |
+
# same name and color under "thing_*" and "stuff_*" because the current
|
138 |
+
# visualization function in D2 handles thing and class classes differently
|
139 |
+
# due to some heuristic used in Panoptic FPN. We keep the same naming to
|
140 |
+
# enable reusing existing visualization functions.
|
141 |
+
thing_classes = [k["name"] for k in CITYSCAPES_CATEGORIES]
|
142 |
+
thing_colors = [k["color"] for k in CITYSCAPES_CATEGORIES]
|
143 |
+
stuff_classes = [k["name"] for k in CITYSCAPES_CATEGORIES]
|
144 |
+
stuff_colors = [k["color"] for k in CITYSCAPES_CATEGORIES]
|
145 |
+
|
146 |
+
meta["thing_classes"] = thing_classes
|
147 |
+
meta["thing_colors"] = thing_colors
|
148 |
+
meta["stuff_classes"] = stuff_classes
|
149 |
+
meta["stuff_colors"] = stuff_colors
|
150 |
+
|
151 |
+
# There are three types of ids in cityscapes panoptic segmentation:
|
152 |
+
# (1) category id: like semantic segmentation, it is the class id for each
|
153 |
+
# pixel. Since there are some classes not used in evaluation, the category
|
154 |
+
# id is not always contiguous and thus we have two set of category ids:
|
155 |
+
# - original category id: category id in the original dataset, mainly
|
156 |
+
# used for evaluation.
|
157 |
+
# - contiguous category id: [0, #classes), in order to train the classifier
|
158 |
+
# (2) instance id: this id is used to differentiate different instances from
|
159 |
+
# the same category. For "stuff" classes, the instance id is always 0; for
|
160 |
+
# "thing" classes, the instance id starts from 1 and 0 is reserved for
|
161 |
+
# ignored instances (e.g. crowd annotation).
|
162 |
+
# (3) panoptic id: this is the compact id that encode both category and
|
163 |
+
# instance id by: category_id * 1000 + instance_id.
|
164 |
+
thing_dataset_id_to_contiguous_id = {}
|
165 |
+
stuff_dataset_id_to_contiguous_id = {}
|
166 |
+
|
167 |
+
for k in CITYSCAPES_CATEGORIES:
|
168 |
+
if k["isthing"] == 1:
|
169 |
+
thing_dataset_id_to_contiguous_id[k["id"]] = k["trainId"]
|
170 |
+
else:
|
171 |
+
stuff_dataset_id_to_contiguous_id[k["id"]] = k["trainId"]
|
172 |
+
|
173 |
+
meta["thing_dataset_id_to_contiguous_id"] = thing_dataset_id_to_contiguous_id
|
174 |
+
meta["stuff_dataset_id_to_contiguous_id"] = stuff_dataset_id_to_contiguous_id
|
175 |
+
|
176 |
+
for key, (image_dir, gt_dir, gt_json) in _RAW_CITYSCAPES_PANOPTIC_SPLITS.items():
|
177 |
+
image_dir = os.path.join(root, image_dir)
|
178 |
+
gt_dir = os.path.join(root, gt_dir)
|
179 |
+
gt_json = os.path.join(root, gt_json)
|
180 |
+
|
181 |
+
if key in DatasetCatalog.list():
|
182 |
+
DatasetCatalog.remove(key)
|
183 |
+
|
184 |
+
DatasetCatalog.register(
|
185 |
+
key, lambda x=image_dir, y=gt_dir, z=gt_json: load_cityscapes_panoptic(x, y, z, meta)
|
186 |
+
)
|
187 |
+
MetadataCatalog.get(key).set(
|
188 |
+
panoptic_root=gt_dir,
|
189 |
+
image_root=image_dir,
|
190 |
+
panoptic_json=gt_json,
|
191 |
+
gt_dir=gt_dir.replace("cityscapes_panoptic_", ""),
|
192 |
+
evaluator_type="cityscapes_panoptic_seg",
|
193 |
+
ignore_label=255,
|
194 |
+
label_divisor=1000,
|
195 |
+
**meta,
|
196 |
+
)
|
197 |
+
|
198 |
+
_root = os.getenv("DETECTRON2_DATASETS", "datasets")
|
199 |
+
register_all_cityscapes_panoptic(_root)
|
oneformer/data/datasets/register_coco_panoptic2instance.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------------
|
2 |
+
# Reference: https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/datasets/builtin.py
|
3 |
+
# Modified by Jitesh Jain (https://github.com/praeclarumjj3)
|
4 |
+
# ------------------------------------------------------------------------------
|
5 |
+
|
6 |
+
|
7 |
+
"""
|
8 |
+
This file registers pre-defined datasets at hard-coded paths, and their metadata.
|
9 |
+
|
10 |
+
We hard-code metadata for common datasets. This will enable:
|
11 |
+
1. Consistency check when loading the datasets
|
12 |
+
2. Use models on these standard datasets directly and run demos,
|
13 |
+
without having to download the dataset annotations
|
14 |
+
|
15 |
+
We hard-code some paths to the dataset that's assumed to
|
16 |
+
exist in "./datasets/".
|
17 |
+
|
18 |
+
Users SHOULD NOT use this file to create new dataset / metadata for new dataset.
|
19 |
+
To add new dataset, refer to the tutorial "docs/DATASETS.md".
|
20 |
+
"""
|
21 |
+
|
22 |
+
import os
|
23 |
+
from detectron2.data.datasets.builtin_meta import _get_builtin_metadata
|
24 |
+
from detectron2.data.datasets.coco import register_coco_instances
|
25 |
+
|
26 |
+
|
27 |
+
_PREDEFINED_SPLITS_COCO = {
|
28 |
+
"coco_2017_val_panoptic2instance": ("coco/val2017", "coco/annotations/panoptic2instances_val2017.json"),
|
29 |
+
}
|
30 |
+
|
31 |
+
|
32 |
+
def register_panoptic2instances_coco(root):
|
33 |
+
for key, (image_root, json_file) in _PREDEFINED_SPLITS_COCO.items():
|
34 |
+
# Assume pre-defined datasets live in `./datasets`.
|
35 |
+
register_coco_instances(
|
36 |
+
key,
|
37 |
+
_get_builtin_metadata("coco"),
|
38 |
+
os.path.join(root, json_file) if "://" not in json_file else json_file,
|
39 |
+
os.path.join(root, image_root),
|
40 |
+
)
|
41 |
+
|
42 |
+
|
43 |
+
_root = os.path.expanduser(os.getenv("DETECTRON2_DATASETS", "datasets"))
|
44 |
+
register_panoptic2instances_coco(_root)
|
oneformer/data/datasets/register_coco_panoptic_annos_semseg.py
ADDED
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------------
|
2 |
+
# Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/data/datasets/register_coco_panoptic_annos_semseg.py
|
3 |
+
# Modified by Jitesh Jain (https://github.com/praeclarumjj3)
|
4 |
+
# ------------------------------------------------------------------------------
|
5 |
+
|
6 |
+
import json
|
7 |
+
import os
|
8 |
+
|
9 |
+
from detectron2.data import DatasetCatalog, MetadataCatalog
|
10 |
+
from detectron2.data.datasets import load_sem_seg
|
11 |
+
from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
|
12 |
+
from detectron2.utils.file_io import PathManager
|
13 |
+
import contextlib
|
14 |
+
import logging
|
15 |
+
import io
|
16 |
+
from fvcore.common.timer import Timer
|
17 |
+
import pycocotools.mask as mask_util
|
18 |
+
from detectron2.structures import BoxMode
|
19 |
+
|
20 |
+
|
21 |
+
logger = logging.getLogger(__name__)
|
22 |
+
|
23 |
+
|
24 |
+
_PREDEFINED_SPLITS_COCO_PANOPTIC = {
|
25 |
+
"coco_2017_train_panoptic": (
|
26 |
+
# This is the original panoptic annotation directory
|
27 |
+
"coco/panoptic_train2017",
|
28 |
+
"coco/annotations/panoptic_train2017.json",
|
29 |
+
# This directory contains semantic annotations that are
|
30 |
+
# converted from panoptic annotations.
|
31 |
+
# It is used by PanopticFPN.
|
32 |
+
# You can use the script at detectron2/datasets/prepare_panoptic_fpn.py
|
33 |
+
# to create these directories.
|
34 |
+
"coco/panoptic_semseg_train2017",
|
35 |
+
),
|
36 |
+
"coco_2017_val_panoptic": (
|
37 |
+
"coco/panoptic_val2017",
|
38 |
+
"coco/annotations/panoptic_val2017.json",
|
39 |
+
"coco/panoptic_semseg_val2017",
|
40 |
+
),
|
41 |
+
}
|
42 |
+
|
43 |
+
def load_coco_instance_json(json_file, image_root, dataset_name=None):
|
44 |
+
from pycocotools.coco import COCO
|
45 |
+
|
46 |
+
timer = Timer()
|
47 |
+
json_file = PathManager.get_local_path(json_file)
|
48 |
+
with contextlib.redirect_stdout(io.StringIO()):
|
49 |
+
coco_api = COCO(json_file)
|
50 |
+
if timer.seconds() > 1:
|
51 |
+
logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))
|
52 |
+
|
53 |
+
id_map = None
|
54 |
+
if dataset_name is not None:
|
55 |
+
meta = MetadataCatalog.get(dataset_name)
|
56 |
+
cat_ids = sorted(coco_api.getCatIds())
|
57 |
+
cats = coco_api.loadCats(cat_ids)
|
58 |
+
# The categories in a custom json file may not be sorted.
|
59 |
+
thing_classes = [c["name"] for c in sorted(cats, key=lambda x: x["id"])]
|
60 |
+
meta.thing_classes = thing_classes
|
61 |
+
|
62 |
+
# In COCO, certain category ids are artificially removed,
|
63 |
+
# and by convention they are always ignored.
|
64 |
+
# We deal with COCO's id issue and translate
|
65 |
+
# the category ids to contiguous ids in [0, 80).
|
66 |
+
|
67 |
+
# It works by looking at the "categories" field in the json, therefore
|
68 |
+
# if users' own json also have incontiguous ids, we'll
|
69 |
+
# apply this mapping as well but print a warning.
|
70 |
+
if not (min(cat_ids) == 1 and max(cat_ids) == len(cat_ids)):
|
71 |
+
if "coco" not in dataset_name:
|
72 |
+
logger.warning(
|
73 |
+
"""
|
74 |
+
Category ids in annotations are not in [1, #categories]! We'll apply a mapping for you.
|
75 |
+
"""
|
76 |
+
)
|
77 |
+
id_map = {v: i for i, v in enumerate(cat_ids)}
|
78 |
+
meta.thing_dataset_id_to_contiguous_id = id_map
|
79 |
+
|
80 |
+
# sort indices for reproducible results
|
81 |
+
img_ids = sorted(coco_api.imgs.keys())
|
82 |
+
# imgs is a list of dicts, each looks something like:
|
83 |
+
# {'license': 4,
|
84 |
+
# 'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
|
85 |
+
# 'file_name': 'COCO_val2014_000000001268.jpg',
|
86 |
+
# 'height': 427,
|
87 |
+
# 'width': 640,
|
88 |
+
# 'date_captured': '2013-11-17 05:57:24',
|
89 |
+
# 'id': 1268}
|
90 |
+
imgs = coco_api.loadImgs(img_ids)
|
91 |
+
# anns is a list[list[dict]], where each dict is an annotation
|
92 |
+
# record for an object. The inner list enumerates the objects in an image
|
93 |
+
# and the outer list enumerates over images. Example of anns[0]:
|
94 |
+
# [{'segmentation': [[192.81,
|
95 |
+
# 247.09,
|
96 |
+
# ...
|
97 |
+
# 219.03,
|
98 |
+
# 249.06]],
|
99 |
+
# 'area': 1035.749,
|
100 |
+
# 'iscrowd': 0,
|
101 |
+
# 'image_id': 1268,
|
102 |
+
# 'bbox': [192.81, 224.8, 74.73, 33.43],
|
103 |
+
# 'category_id': 16,
|
104 |
+
# 'id': 42986},
|
105 |
+
# ...]
|
106 |
+
anns = [coco_api.imgToAnns[img_id] for img_id in img_ids]
|
107 |
+
total_num_valid_anns = sum([len(x) for x in anns])
|
108 |
+
total_num_anns = len(coco_api.anns)
|
109 |
+
if total_num_valid_anns < total_num_anns:
|
110 |
+
logger.warning(
|
111 |
+
f"{json_file} contains {total_num_anns} annotations, but only "
|
112 |
+
f"{total_num_valid_anns} of them match to images in the file."
|
113 |
+
)
|
114 |
+
|
115 |
+
if "minival" not in json_file:
|
116 |
+
# The popular valminusminival & minival annotations for COCO2014 contain this bug.
|
117 |
+
# However the ratio of buggy annotations there is tiny and does not affect accuracy.
|
118 |
+
# Therefore we explicitly white-list them.
|
119 |
+
ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
|
120 |
+
assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique!".format(
|
121 |
+
json_file
|
122 |
+
)
|
123 |
+
|
124 |
+
imgs_anns = list(zip(imgs, anns))
|
125 |
+
logger.info("Loaded {} images in COCO format from {}".format(len(imgs_anns), json_file))
|
126 |
+
|
127 |
+
dataset_dicts = {}
|
128 |
+
|
129 |
+
ann_keys = ["iscrowd", "bbox", "keypoints", "category_id"]
|
130 |
+
|
131 |
+
num_instances_without_valid_segmentation = 0
|
132 |
+
|
133 |
+
for (img_dict, anno_dict_list) in imgs_anns:
|
134 |
+
record = {}
|
135 |
+
record["file_name"] = os.path.join(image_root, img_dict["file_name"])
|
136 |
+
record["height"] = img_dict["height"]
|
137 |
+
record["width"] = img_dict["width"]
|
138 |
+
image_id = record["image_id"] = img_dict["id"]
|
139 |
+
|
140 |
+
objs = []
|
141 |
+
for anno in anno_dict_list:
|
142 |
+
# Check that the image_id in this annotation is the same as
|
143 |
+
# the image_id we're looking at.
|
144 |
+
# This fails only when the data parsing logic or the annotation file is buggy.
|
145 |
+
|
146 |
+
# The original COCO valminusminival2014 & minival2014 annotation files
|
147 |
+
# actually contains bugs that, together with certain ways of using COCO API,
|
148 |
+
# can trigger this assertion.
|
149 |
+
assert anno["image_id"] == image_id
|
150 |
+
|
151 |
+
assert anno.get("ignore", 0) == 0, '"ignore" in COCO json file is not supported.'
|
152 |
+
|
153 |
+
obj = {key: anno[key] for key in ann_keys if key in anno}
|
154 |
+
if "bbox" in obj and len(obj["bbox"]) == 0:
|
155 |
+
raise ValueError(
|
156 |
+
f"One annotation of image {image_id} contains empty 'bbox' value! "
|
157 |
+
"This json does not have valid COCO format."
|
158 |
+
)
|
159 |
+
|
160 |
+
segm = anno.get("segmentation", None)
|
161 |
+
if segm: # either list[list[float]] or dict(RLE)
|
162 |
+
if isinstance(segm, dict):
|
163 |
+
if isinstance(segm["counts"], list):
|
164 |
+
# convert to compressed RLE
|
165 |
+
segm = mask_util.frPyObjects(segm, *segm["size"])
|
166 |
+
else:
|
167 |
+
# filter out invalid polygons (< 3 points)
|
168 |
+
segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
|
169 |
+
if len(segm) == 0:
|
170 |
+
num_instances_without_valid_segmentation += 1
|
171 |
+
continue # ignore this instance
|
172 |
+
obj["segmentation"] = segm
|
173 |
+
|
174 |
+
keypts = anno.get("keypoints", None)
|
175 |
+
if keypts: # list[int]
|
176 |
+
for idx, v in enumerate(keypts):
|
177 |
+
if idx % 3 != 2:
|
178 |
+
# COCO's segmentation coordinates are floating points in [0, H or W],
|
179 |
+
# but keypoint coordinates are integers in [0, H-1 or W-1]
|
180 |
+
# Therefore we assume the coordinates are "pixel indices" and
|
181 |
+
# add 0.5 to convert to floating point coordinates.
|
182 |
+
keypts[idx] = v + 0.5
|
183 |
+
obj["keypoints"] = keypts
|
184 |
+
|
185 |
+
obj["bbox_mode"] = BoxMode.XYWH_ABS
|
186 |
+
if id_map:
|
187 |
+
annotation_category_id = obj["category_id"]
|
188 |
+
try:
|
189 |
+
obj["category_id"] = id_map[annotation_category_id]
|
190 |
+
except KeyError as e:
|
191 |
+
raise KeyError(
|
192 |
+
f"Encountered category_id={annotation_category_id} "
|
193 |
+
"but this id does not exist in 'categories' of the json file."
|
194 |
+
) from e
|
195 |
+
objs.append(obj)
|
196 |
+
record["annotations"] = objs
|
197 |
+
dataset_dicts[image_id] = record
|
198 |
+
|
199 |
+
if num_instances_without_valid_segmentation > 0:
|
200 |
+
logger.warning(
|
201 |
+
"Filtered out {} instances without valid segmentation. ".format(
|
202 |
+
num_instances_without_valid_segmentation
|
203 |
+
)
|
204 |
+
+ "There might be issues in your dataset generation process. Please "
|
205 |
+
"check https://detectron2.readthedocs.io/en/latest/tutorials/datasets.html carefully"
|
206 |
+
)
|
207 |
+
return dataset_dicts
|
208 |
+
|
209 |
+
def get_metadata():
|
210 |
+
meta = {}
|
211 |
+
# The following metadata maps contiguous id from [0, #thing categories +
|
212 |
+
# #stuff categories) to their names and colors. We have to replica of the
|
213 |
+
# same name and color under "thing_*" and "stuff_*" because the current
|
214 |
+
# visualization function in D2 handles thing and class classes differently
|
215 |
+
# due to some heuristic used in Panoptic FPN. We keep the same naming to
|
216 |
+
# enable reusing existing visualization functions.
|
217 |
+
thing_classes = [k["name"] for k in COCO_CATEGORIES if k["isthing"] == 1]
|
218 |
+
thing_colors = [k["color"] for k in COCO_CATEGORIES if k["isthing"] == 1]
|
219 |
+
stuff_classes = [k["name"] for k in COCO_CATEGORIES]
|
220 |
+
stuff_colors = [k["color"] for k in COCO_CATEGORIES]
|
221 |
+
|
222 |
+
meta["thing_classes"] = thing_classes
|
223 |
+
meta["thing_colors"] = thing_colors
|
224 |
+
meta["stuff_classes"] = stuff_classes
|
225 |
+
meta["stuff_colors"] = stuff_colors
|
226 |
+
|
227 |
+
# Convert category id for training:
|
228 |
+
# category id: like semantic segmentation, it is the class id for each
|
229 |
+
# pixel. Since there are some classes not used in evaluation, the category
|
230 |
+
# id is not always contiguous and thus we have two set of category ids:
|
231 |
+
# - original category id: category id in the original dataset, mainly
|
232 |
+
# used for evaluation.
|
233 |
+
# - contiguous category id: [0, #classes), in order to train the linear
|
234 |
+
# softmax classifier.
|
235 |
+
thing_dataset_id_to_contiguous_id = {}
|
236 |
+
stuff_dataset_id_to_contiguous_id = {}
|
237 |
+
|
238 |
+
for i, cat in enumerate(COCO_CATEGORIES):
|
239 |
+
if cat["isthing"]:
|
240 |
+
thing_dataset_id_to_contiguous_id[cat["id"]] = i
|
241 |
+
# else:
|
242 |
+
# stuff_dataset_id_to_contiguous_id[cat["id"]] = i
|
243 |
+
|
244 |
+
# in order to use sem_seg evaluator
|
245 |
+
stuff_dataset_id_to_contiguous_id[cat["id"]] = i
|
246 |
+
|
247 |
+
meta["thing_dataset_id_to_contiguous_id"] = thing_dataset_id_to_contiguous_id
|
248 |
+
meta["stuff_dataset_id_to_contiguous_id"] = stuff_dataset_id_to_contiguous_id
|
249 |
+
|
250 |
+
return meta
|
251 |
+
|
252 |
+
|
253 |
+
def load_coco_panoptic_json(json_file, instances_json, instances_name, image_dir, gt_dir, semseg_dir, meta):
|
254 |
+
"""
|
255 |
+
Args:
|
256 |
+
image_dir (str): path to the raw dataset. e.g., "~/coco/train2017".
|
257 |
+
gt_dir (str): path to the raw annotations. e.g., "~/coco/panoptic_train2017".
|
258 |
+
json_file (str): path to the json file. e.g., "~/coco/annotations/panoptic_train2017.json".
|
259 |
+
Returns:
|
260 |
+
list[dict]: a list of dicts in Detectron2 standard format. (See
|
261 |
+
`Using Custom Datasets </tutorials/datasets.html>`_ )
|
262 |
+
"""
|
263 |
+
|
264 |
+
def _convert_category_id(segment_info, meta):
|
265 |
+
if segment_info["category_id"] in meta["thing_dataset_id_to_contiguous_id"]:
|
266 |
+
segment_info["category_id"] = meta["thing_dataset_id_to_contiguous_id"][
|
267 |
+
segment_info["category_id"]
|
268 |
+
]
|
269 |
+
segment_info["isthing"] = True
|
270 |
+
else:
|
271 |
+
segment_info["category_id"] = meta["stuff_dataset_id_to_contiguous_id"][
|
272 |
+
segment_info["category_id"]
|
273 |
+
]
|
274 |
+
segment_info["isthing"] = False
|
275 |
+
return segment_info
|
276 |
+
|
277 |
+
with PathManager.open(json_file) as f:
|
278 |
+
json_info = json.load(f)
|
279 |
+
|
280 |
+
instance_data_dicts = load_coco_instance_json(instances_json, image_dir.replace("panoptic_", ""), instances_name)
|
281 |
+
|
282 |
+
ret = []
|
283 |
+
for ann in json_info["annotations"]:
|
284 |
+
image_id = int(ann["image_id"])
|
285 |
+
# TODO: currently we assume image and label has the same filename but
|
286 |
+
# different extension, and images have extension ".jpg" for COCO. Need
|
287 |
+
# to make image extension a user-provided argument if we extend this
|
288 |
+
# function to support other COCO-like datasets.
|
289 |
+
image_file = os.path.join(image_dir, os.path.splitext(ann["file_name"])[0] + ".jpg")
|
290 |
+
label_file = os.path.join(gt_dir, ann["file_name"])
|
291 |
+
sem_label_file = os.path.join(semseg_dir, ann["file_name"])
|
292 |
+
segments_info = [_convert_category_id(x, meta) for x in ann["segments_info"]]
|
293 |
+
ret.append(
|
294 |
+
{
|
295 |
+
"file_name": image_file,
|
296 |
+
"image_id": image_id,
|
297 |
+
"pan_seg_file_name": label_file,
|
298 |
+
"sem_seg_file_name": sem_label_file,
|
299 |
+
"segments_info": segments_info,
|
300 |
+
"annotations": instance_data_dicts[image_id]["annotations"],
|
301 |
+
}
|
302 |
+
)
|
303 |
+
assert len(ret), f"No images found in {image_dir}!"
|
304 |
+
assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"]
|
305 |
+
assert PathManager.isfile(ret[0]["pan_seg_file_name"]), ret[0]["pan_seg_file_name"]
|
306 |
+
assert PathManager.isfile(ret[0]["sem_seg_file_name"]), ret[0]["sem_seg_file_name"]
|
307 |
+
return ret
|
308 |
+
|
309 |
+
|
310 |
+
def register_coco_panoptic_annos_sem_seg(
|
311 |
+
name, metadata, image_root, panoptic_root, panoptic_json, sem_seg_root, instances_json, instances_name,
|
312 |
+
):
|
313 |
+
panoptic_name = name
|
314 |
+
delattr(MetadataCatalog.get(panoptic_name), "thing_classes")
|
315 |
+
delattr(MetadataCatalog.get(panoptic_name), "thing_colors")
|
316 |
+
MetadataCatalog.get(panoptic_name).set(
|
317 |
+
thing_classes=metadata["thing_classes"],
|
318 |
+
thing_colors=metadata["thing_colors"],
|
319 |
+
# thing_dataset_id_to_contiguous_id=metadata["thing_dataset_id_to_contiguous_id"],
|
320 |
+
)
|
321 |
+
|
322 |
+
# the name is "coco_2017_train_panoptic_with_sem_seg" and "coco_2017_val_panoptic_with_sem_seg"
|
323 |
+
semantic_name = name + "_with_sem_seg"
|
324 |
+
DatasetCatalog.register(
|
325 |
+
semantic_name,
|
326 |
+
lambda: load_coco_panoptic_json(panoptic_json, instances_json, instances_name, image_root, panoptic_root, sem_seg_root, metadata),
|
327 |
+
)
|
328 |
+
MetadataCatalog.get(semantic_name).set(
|
329 |
+
sem_seg_root=sem_seg_root,
|
330 |
+
panoptic_root=panoptic_root,
|
331 |
+
image_root=image_root,
|
332 |
+
panoptic_json=panoptic_json,
|
333 |
+
json_file=instances_json,
|
334 |
+
evaluator_type="coco_panoptic_seg",
|
335 |
+
ignore_label=255,
|
336 |
+
label_divisor=1000,
|
337 |
+
**metadata,
|
338 |
+
)
|
339 |
+
|
340 |
+
|
341 |
+
def register_all_coco_panoptic_annos_sem_seg(root):
|
342 |
+
for (
|
343 |
+
prefix,
|
344 |
+
(panoptic_root, panoptic_json, semantic_root),
|
345 |
+
) in _PREDEFINED_SPLITS_COCO_PANOPTIC.items():
|
346 |
+
|
347 |
+
prefix_instances = prefix[: -len("_panoptic")]
|
348 |
+
instances_meta = MetadataCatalog.get(prefix_instances)
|
349 |
+
image_root, instances_json = instances_meta.image_root, instances_meta.json_file
|
350 |
+
|
351 |
+
if 'val' in instances_json:
|
352 |
+
instances_json = instances_json.replace('instances_', 'panoptic2instances_')
|
353 |
+
|
354 |
+
register_coco_panoptic_annos_sem_seg(
|
355 |
+
prefix,
|
356 |
+
get_metadata(),
|
357 |
+
image_root,
|
358 |
+
os.path.join(root, panoptic_root),
|
359 |
+
os.path.join(root, panoptic_json),
|
360 |
+
os.path.join(root, semantic_root),
|
361 |
+
instances_json,
|
362 |
+
prefix_instances,
|
363 |
+
)
|
364 |
+
|
365 |
+
|
366 |
+
_root = os.getenv("DETECTRON2_DATASETS", "datasets")
|
367 |
+
register_all_coco_panoptic_annos_sem_seg(_root)
|
oneformer/data/tokenizer.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -------------------------------------------------------------------------
|
2 |
+
# MIT License
|
3 |
+
#
|
4 |
+
# Copyright (c) 2021 OpenAI
|
5 |
+
#
|
6 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
7 |
+
# of this software and associated documentation files (the "Software"), to deal
|
8 |
+
# in the Software without restriction, including without limitation the rights
|
9 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
10 |
+
# copies of the Software, and to permit persons to whom the Software is
|
11 |
+
# furnished to do so, subject to the following conditions:
|
12 |
+
#
|
13 |
+
# The above copyright notice and this permission notice shall be included in all
|
14 |
+
# copies or substantial portions of the Software.
|
15 |
+
#
|
16 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
17 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
18 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
19 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
20 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
21 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
22 |
+
# SOFTWARE.
|
23 |
+
#
|
24 |
+
# Modified by Jiarui Xu
|
25 |
+
# -------------------------------------------------------------------------
|
26 |
+
|
27 |
+
import gzip
|
28 |
+
import html
|
29 |
+
import os
|
30 |
+
from functools import lru_cache
|
31 |
+
|
32 |
+
import ftfy
|
33 |
+
import regex as re
|
34 |
+
import torch
|
35 |
+
|
36 |
+
|
37 |
+
@lru_cache()
|
38 |
+
def default_bpe():
|
39 |
+
return os.path.join(os.path.dirname(os.path.abspath(__file__)), 'bpe_simple_vocab_16e6.txt')
|
40 |
+
|
41 |
+
@lru_cache()
|
42 |
+
def bytes_to_unicode():
|
43 |
+
"""Returns list of utf-8 byte and a corresponding list of unicode strings.
|
44 |
+
|
45 |
+
The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
|
46 |
+
if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for decent
|
47 |
+
coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup tables
|
48 |
+
between utf-8 bytes and unicode strings. And avoids mapping to whitespace/control characters the bpe code barfs on.
|
49 |
+
"""
|
50 |
+
bs = list(range(ord('!'), ord('~') + 1)) + list(range(ord('¡'), ord('¬') + 1)) + list(range(ord('®'), ord('ÿ') + 1))
|
51 |
+
cs = bs[:]
|
52 |
+
n = 0
|
53 |
+
for b in range(2**8):
|
54 |
+
if b not in bs:
|
55 |
+
bs.append(b)
|
56 |
+
cs.append(2**8 + n)
|
57 |
+
n += 1
|
58 |
+
cs = [chr(n) for n in cs]
|
59 |
+
return dict(zip(bs, cs))
|
60 |
+
|
61 |
+
|
62 |
+
def get_pairs(word):
|
63 |
+
"""Return set of symbol pairs in a word.
|
64 |
+
|
65 |
+
Word is represented as tuple of symbols (symbols being variable-length strings).
|
66 |
+
"""
|
67 |
+
pairs = set()
|
68 |
+
prev_char = word[0]
|
69 |
+
for char in word[1:]:
|
70 |
+
pairs.add((prev_char, char))
|
71 |
+
prev_char = char
|
72 |
+
return pairs
|
73 |
+
|
74 |
+
|
75 |
+
def basic_clean(text):
|
76 |
+
text = ftfy.fix_text(text)
|
77 |
+
text = html.unescape(html.unescape(text))
|
78 |
+
return text.strip()
|
79 |
+
|
80 |
+
|
81 |
+
def whitespace_clean(text):
|
82 |
+
text = re.sub(r'\s+', ' ', text)
|
83 |
+
text = text.strip()
|
84 |
+
return text
|
85 |
+
|
86 |
+
class Tokenize:
|
87 |
+
|
88 |
+
def __init__(self, tokenizer, max_seq_len=77, truncate=True):
|
89 |
+
self.tokenizer = tokenizer
|
90 |
+
self.max_seq_len = max_seq_len
|
91 |
+
self.truncate = truncate
|
92 |
+
|
93 |
+
def __call__(self, texts):
|
94 |
+
expanded_dim = False
|
95 |
+
if isinstance(texts, str):
|
96 |
+
texts = [texts]
|
97 |
+
expanded_dim = True
|
98 |
+
|
99 |
+
sot_token = self.tokenizer.encoder['<|startoftext|>']
|
100 |
+
eot_token = self.tokenizer.encoder['<|endoftext|>']
|
101 |
+
all_tokens = [[sot_token] + self.tokenizer.encode(text) + [eot_token] for text in texts]
|
102 |
+
result = torch.zeros(len(all_tokens), self.max_seq_len, dtype=torch.long)
|
103 |
+
|
104 |
+
for i, tokens in enumerate(all_tokens):
|
105 |
+
if len(tokens) > self.max_seq_len:
|
106 |
+
if self.truncate:
|
107 |
+
tokens = tokens[:self.max_seq_len]
|
108 |
+
tokens[-1] = eot_token
|
109 |
+
else:
|
110 |
+
raise RuntimeError(f'Input {texts[i]} is too long for context length {self.max_seq_len}')
|
111 |
+
result[i, :len(tokens)] = torch.tensor(tokens)
|
112 |
+
|
113 |
+
if expanded_dim:
|
114 |
+
return result[0]
|
115 |
+
|
116 |
+
return result
|
117 |
+
|
118 |
+
|
119 |
+
class SimpleTokenizer(object):
|
120 |
+
|
121 |
+
def __init__(self, bpe_path: str = default_bpe()):
|
122 |
+
self.byte_encoder = bytes_to_unicode()
|
123 |
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
124 |
+
|
125 |
+
with open(bpe_path) as f:
|
126 |
+
contents = f.readlines()
|
127 |
+
merges = []
|
128 |
+
for cnt in contents:
|
129 |
+
merges.append(cnt.split('\n')[0])
|
130 |
+
merges.append("")
|
131 |
+
|
132 |
+
# merges = gzip.open(bpe_path).read().decode('utf-8').split('\n')
|
133 |
+
merges = merges[1:49152 - 256 - 2 + 1]
|
134 |
+
merges = [tuple(merge.split()) for merge in merges]
|
135 |
+
vocab = list(bytes_to_unicode().values())
|
136 |
+
vocab = vocab + [v + '</w>' for v in vocab]
|
137 |
+
for merge in merges:
|
138 |
+
vocab.append(''.join(merge))
|
139 |
+
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
|
140 |
+
self.encoder = dict(zip(vocab, range(len(vocab))))
|
141 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
142 |
+
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
143 |
+
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
|
144 |
+
self.pat = re.compile(
|
145 |
+
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
|
146 |
+
re.IGNORECASE)
|
147 |
+
|
148 |
+
def bpe(self, token):
|
149 |
+
if token in self.cache:
|
150 |
+
return self.cache[token]
|
151 |
+
word = tuple(token[:-1]) + (token[-1] + '</w>', )
|
152 |
+
pairs = get_pairs(word)
|
153 |
+
|
154 |
+
if not pairs:
|
155 |
+
return token + '</w>'
|
156 |
+
|
157 |
+
while True:
|
158 |
+
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
159 |
+
if bigram not in self.bpe_ranks:
|
160 |
+
break
|
161 |
+
first, second = bigram
|
162 |
+
new_word = []
|
163 |
+
i = 0
|
164 |
+
while i < len(word):
|
165 |
+
try:
|
166 |
+
j = word.index(first, i)
|
167 |
+
new_word.extend(word[i:j])
|
168 |
+
i = j
|
169 |
+
except: # noqa: E722
|
170 |
+
new_word.extend(word[i:])
|
171 |
+
break
|
172 |
+
|
173 |
+
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
|
174 |
+
new_word.append(first + second)
|
175 |
+
i += 2
|
176 |
+
else:
|
177 |
+
new_word.append(word[i])
|
178 |
+
i += 1
|
179 |
+
new_word = tuple(new_word)
|
180 |
+
word = new_word
|
181 |
+
if len(word) == 1:
|
182 |
+
break
|
183 |
+
else:
|
184 |
+
pairs = get_pairs(word)
|
185 |
+
word = ' '.join(word)
|
186 |
+
self.cache[token] = word
|
187 |
+
return word
|
188 |
+
|
189 |
+
def encode(self, text):
|
190 |
+
bpe_tokens = []
|
191 |
+
text = whitespace_clean(basic_clean(text)).lower()
|
192 |
+
for token in re.findall(self.pat, text):
|
193 |
+
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
194 |
+
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
195 |
+
return bpe_tokens
|
196 |
+
|
197 |
+
def decode(self, tokens):
|
198 |
+
text = ''.join([self.decoder[token] for token in tokens])
|
199 |
+
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors='replace').replace('</w>', ' ')
|
200 |
+
return text
|
oneformer/evaluation/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .detection_coco_evaluator import *
|
2 |
+
from .coco_evaluator import *
|
3 |
+
from .cityscapes_evaluation import CityscapesInstanceEvaluator
|