Commit 9665c2c · Release

This view is limited to 50 files because it contains too many changes.
- .flake8 +3 -0
- .gitignore +138 -0
- CODE_OF_CONDUCT.md +80 -0
- CONTRIBUTING.md +31 -0
- LICENSE +1 -0
- README.md +229 -0
- assets/demo.jpg +0 -0
- assets/teaser.svg +0 -0
- demo.ipynb +0 -0
- maploc/__init__.py +28 -0
- maploc/conf/__init__.py +0 -0
- maploc/conf/data/__init__.py +0 -0
- maploc/conf/data/kitti.yaml +29 -0
- maploc/conf/data/mapillary.yaml +40 -0
- maploc/conf/model/image_encoder/global.yaml +9 -0
- maploc/conf/model/image_encoder/resnet_fpn.yaml +7 -0
- maploc/conf/model/image_encoder/vgg_unet.yaml +8 -0
- maploc/conf/orienternet.yaml +34 -0
- maploc/conf/overfit.yaml +17 -0
- maploc/conf/training.yaml +22 -0
- maploc/data/__init__.py +4 -0
- maploc/data/dataset.py +264 -0
- maploc/data/image.py +140 -0
- maploc/data/kitti/dataset.py +306 -0
- maploc/data/kitti/prepare.py +123 -0
- maploc/data/kitti/test1_files.txt +0 -0
- maploc/data/kitti/test2_files.txt +0 -0
- maploc/data/kitti/train_files.txt +0 -0
- maploc/data/kitti/utils.py +79 -0
- maploc/data/mapillary/dataset.py +350 -0
- maploc/data/mapillary/download.py +180 -0
- maploc/data/mapillary/prepare.py +406 -0
- maploc/data/mapillary/splits_MGL_13loc.json +0 -0
- maploc/data/mapillary/utils.py +173 -0
- maploc/data/sequential.py +61 -0
- maploc/data/torch.py +111 -0
- maploc/data/utils.py +60 -0
- maploc/demo.py +209 -0
- maploc/evaluation/kitti.py +89 -0
- maploc/evaluation/mapillary.py +111 -0
- maploc/evaluation/run.py +252 -0
- maploc/evaluation/utils.py +40 -0
- maploc/evaluation/viz.py +178 -0
- maploc/models/__init__.py +34 -0
- maploc/models/base.py +123 -0
- maploc/models/bev_net.py +61 -0
- maploc/models/bev_projection.py +91 -0
- maploc/models/feature_extractor.py +231 -0
- maploc/models/feature_extractor_v2.py +192 -0
- maploc/models/map_encoder.py +66 -0
.flake8
ADDED
@@ -0,0 +1,3 @@
[flake8]
max-line-length = 88
extend-ignore = E203
.gitignore
ADDED
@@ -0,0 +1,138 @@
datasets/
experiments/
outputs/
*.mp4
lsf*
.DS_Store

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# vscode
.vscode
CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,80 @@
# Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.

## Our Standards

Examples of behavior that contributes to creating a positive environment
include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or
  advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
  address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.

## Scope

This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.

This Code of Conduct also applies outside the project spaces when there is a
reasonable belief that an individual's behavior may have a negative impact on
the project or its community.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at <opensource-conduct@fb.com>. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq
CONTRIBUTING.md
ADDED
@@ -0,0 +1,31 @@
# Contributing to OrienterNet
We want to make contributing to this project as easy and transparent as
possible.

## Pull Requests
We actively welcome your pull requests.

1. Fork the repo and create your branch from `main`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").

## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>

## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.

## License
By contributing to OrienterNet, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.
LICENSE
ADDED
@@ -0,0 +1 @@
The MGL dataset is made available under the CC-BY-SA license following the data available on the Mapillary platform. The model implementation and the pre-trained weights follow a CC-BY-NC license. OpenStreetMap data follows its own license.
README.md
ADDED
@@ -0,0 +1,229 @@
<p align="center">
  <h1 align="center"><ins>OrienterNet</ins><br>Visual Localization in 2D Public Maps<br>with Neural Matching</h1>
  <p align="center">
    <a href="https://psarlin.com/">Paul-Edouard Sarlin</a>
    ·
    <a href="https://danieldetone.com/">Daniel DeTone</a>
    ·
    <a href="https://scholar.google.com/citations?user=WhISCE4AAAAJ&hl=en">Tsun-Yi Yang</a>
    ·
    <a href="https://scholar.google.com/citations?user=Ta4TDJoAAAAJ&hl=en">Armen Avetisyan</a>
    ·
    <a href="https://scholar.google.com/citations?hl=en&user=49_cCT8AAAAJ">Julian Straub</a>
    <br>
    <a href="https://tom.ai/">Tomasz Malisiewicz</a>
    ·
    <a href="https://scholar.google.com/citations?user=484sccEAAAAJ&hl=en">Samuel Rota Bulo</a>
    ·
    <a href="https://scholar.google.com/citations?hl=en&user=MhowvPkAAAAJ">Richard Newcombe</a>
    ·
    <a href="https://scholar.google.com/citations?hl=en&user=CxbDDRMAAAAJ">Peter Kontschieder</a>
    ·
    <a href="https://scholar.google.com/citations?user=AGoNHcsAAAAJ&hl=en">Vasileios Balntas</a>
  </p>
  <h2 align="center">CVPR 2023</h2>
  <h3 align="center"><a href="https://arxiv.org/pdf/2304.02009.pdf">Paper</a> | <a href="https://psarlin.com/orienternet">Project Page</a> | <a href="https://youtu.be/wglW8jnupSs">Video</a></h3>
  <div align="center"></div>
</p>
<p align="center">
    <a href="https://psarlin.com/orienternet"><img src="assets/teaser.svg" alt="teaser" width="60%"></a>
    <br>
    <em>OrienterNet is a deep neural network that can accurately localize an image<br>using the same 2D semantic maps that humans use to orient themselves.</em>
</p>

##

This repository hosts the source code for OrienterNet, a research project by Meta Reality Labs. OrienterNet leverages the power of deep learning to provide accurate positioning of images using free and globally-available maps from OpenStreetMap. As opposed to complex existing algorithms that rely on 3D point clouds, OrienterNet estimates a position and orientation by matching a neural Bird's-Eye-View with 2D maps.

## Installation

OrienterNet requires Python >= 3.8 and [PyTorch](https://pytorch.org/). To run the demo, clone this repo and install the minimal requirements:

```bash
git clone https://github.com/facebookresearch/OrienterNet
python -m pip install -r requirements/demo.txt
```

To run the evaluation and training, install the full requirements:

```bash
python -m pip install -r requirements/full.txt
```

## Demo ➡️ [Open in Colab](https://colab.research.google.com/drive/1zH_2mzdB18BnJVq48ZvJhMorcRjrWAXI?usp=sharing)

Check out the Jupyter notebook [`demo.ipynb`](./demo.ipynb) ([run it on Colab!](https://colab.research.google.com/drive/1zH_2mzdB18BnJVq48ZvJhMorcRjrWAXI?usp=sharing)) for a minimal demo - take a picture with your phone in any city and find its exact location in a few seconds!

<p align="center">
    <a href="./demo.ipynb"><img src="assets/demo.jpg" alt="demo" width="60%"></a>
    <br>
    <em>OrienterNet positions any image within a large area - try it with your own images!</em>
</p>

## Evaluation

#### Mapillary Geo-Localization dataset

<details>
<summary>[Click to expand]</summary>

To obtain the dataset:

1. Create a developer account at [mapillary.com](https://www.mapillary.com/dashboard/developers) and obtain a free access token.
2. Run the following script to download the data from Mapillary and prepare it:

```bash
python -m maploc.data.mapillary.prepare --token $YOUR_ACCESS_TOKEN
```

By default the data is written to the directory `./datasets/MGL/`. Then run the evaluation with the pre-trained model:

```bash
python -m maploc.evaluation.mapillary --experiment OrienterNet_MGL model.num_rotations=256
```

This downloads the pre-trained models if necessary. The results should be close to the following:

```
Recall xy_max_error: [14.37, 48.69, 61.7] at (1, 3, 5) m/°
Recall yaw_max_error: [20.95, 54.96, 70.17] at (1, 3, 5) m/°
```

This requires a GPU with 11GB of memory. If you run into OOM issues, consider reducing the number of rotations (the default is 256):

```bash
python -m maploc.evaluation.mapillary --experiment OrienterNet_MGL \
    model.num_rotations=128
```

To export visualizations for the first 100 examples:

```bash
python -m maploc.evaluation.mapillary --experiment OrienterNet_MGL \
    --output_dir ./viz_MGL/ --num 100
```

To run the evaluation in sequential mode (by default with 10 frames):

```bash
python -m maploc.evaluation.mapillary --experiment OrienterNet_MGL --sequential
```

</details>

#### KITTI dataset

<details>
<summary>[Click to expand]</summary>

1. Download and prepare the dataset to `./datasets/kitti/`:

```bash
python -m maploc.data.kitti.prepare
```

2. Run the evaluation with the model trained on MGL:

```bash
python -m maploc.evaluation.kitti --experiment OrienterNet_MGL
```

You should expect the following results:

```
Recall directional_error: [[50.33, 85.18, 92.73], [24.38, 56.13, 67.98]] at (1, 3, 5) m/°
Recall yaw_max_error: [29.22, 68.2, 84.49] at (1, 3, 5) m/°
```

You can similarly export some visual examples:

```bash
python -m maploc.evaluation.kitti --experiment OrienterNet_MGL \
    --output_dir ./viz_KITTI/ --num 100
```

</details>

#### Aria Detroit & Seattle

We are currently unable to release the dataset used to evaluate OrienterNet in the CVPR 2023 paper.

## Training

#### MGL dataset

We trained the model on the MGL dataset using 3x 3090 GPUs (24GB VRAM each) and a total batch size of 12 for 340k iterations (about 3-4 days) with the following command:

```bash
python -m maploc.train experiment.name=OrienterNet_MGL_reproduce
```

Feel free to use any other experiment name. Configurations are managed by [Hydra](https://hydra.cc/) and [OmegaConf](https://omegaconf.readthedocs.io) so any entry can be overridden from the command line. You may thus reduce the number of GPUs and the batch size via:

```bash
python -m maploc.train experiment.name=OrienterNet_MGL_reproduce \
    experiment.gpus=1 data.loading.train.batch_size=4
```

Be aware that this can reduce the overall performance. The checkpoints are written to `./experiments/experiment_name/`. Then run the evaluation:

```bash
# the best checkpoint:
python -m maploc.evaluation.mapillary --experiment OrienterNet_MGL_reproduce
# a specific checkpoint:
python -m maploc.evaluation.mapillary \
    --experiment OrienterNet_MGL_reproduce/checkpoint-step=340000.ckpt
```

#### KITTI

To fine-tune a trained model on the KITTI dataset:

```bash
python -m maploc.train experiment.name=OrienterNet_MGL_kitti data=kitti \
    training.finetune_from_checkpoint='"experiments/OrienterNet_MGL_reproduce/checkpoint-step=340000.ckpt"'
```

## Interactive development

We provide several visualization notebooks:

- [Visualize predictions on the MGL dataset](./notebooks/visualize_predictions_mgl.ipynb)
- [Visualize predictions on the KITTI dataset](./notebooks/visualize_predictions_kitti.ipynb)
- [Visualize sequential predictions](./notebooks/visualize_predictions_sequences.ipynb)

## OpenStreetMap data

<details>
<summary>[Click to expand]</summary>

To make sure that the results are consistent over time, we used OSM data downloaded from [Geofabrik](https://download.geofabrik.de/) in November 2021. By default, the dataset scripts `maploc.data.[mapillary,kitti].prepare` download pre-generated raster tiles. If you wish to use different OSM classes, you can pass `--generate_tiles`, which will download and use our prepared raw `.osm` XML files. You may alternatively download more recent files.

</details>

## License

The MGL dataset is made available under the [CC-BY-SA](https://creativecommons.org/licenses/by-sa/4.0/) license following the data available on the Mapillary platform. The model implementation and the pre-trained weights follow a [CC-BY-NC](https://creativecommons.org/licenses/by-nc/2.0/) license. Keep in mind that OpenStreetMap [follows a different license](https://www.openstreetmap.org/copyright).

## BibTex citation

Please consider citing our work if you use any code from this repo or ideas presented in the paper:
```
@inproceedings{sarlin2023orienternet,
  author    = {Paul-Edouard Sarlin and
               Daniel DeTone and
               Tsun-Yi Yang and
               Armen Avetisyan and
               Julian Straub and
               Tomasz Malisiewicz and
               Samuel Rota Bulo and
               Richard Newcombe and
               Peter Kontschieder and
               Vasileios Balntas},
  title     = {{OrienterNet: Visual Localization in 2D Public Maps with Neural Matching}},
  booktitle = {CVPR},
  year      = {2023},
}
```
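The README's claim that OrienterNet localizes by matching a neural Bird's-Eye-View against a 2D map can be made concrete with a short sketch. This is not the code from this commit (the actual matching lives in `maploc/models/`); the shapes and the naive per-rotation loop are illustrative assumptions only:

```python
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import rotate

def match_bev_to_map(bev: torch.Tensor, map_feats: torch.Tensor, num_rotations: int = 64):
    """Score every (x, y, yaw) pose by cross-correlating a BEV template
    (C, h, w) with map features (C, H, W) at num_rotations orientations."""
    scores = []
    for i in range(num_rotations):
        template = rotate(bev, 360.0 * i / num_rotations)  # rotated BEV template
        # conv2d with the template as kernel = correlation over all translations
        scores.append(F.conv2d(map_feats[None], template[None])[0, 0])
    return torch.stack(scores)  # (num_rotations, H - h + 1, W - w + 1)

bev = torch.rand(8, 64, 64)        # matching_dim=8, as in maploc/conf/orienternet.yaml
map_feats = torch.rand(8, 256, 256)
scores = match_bev_to_map(bev, map_feats)
log_probs = F.log_softmax(scores.flatten(), dim=0).view(scores.shape)
```

Exhaustively scoring a discretized pose volume like this is what lets the model express a full probability distribution over positions and orientations rather than a single estimate.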
assets/demo.jpg
ADDED
assets/teaser.svg
ADDED
demo.ipynb
ADDED
The diff for this file is too large to render.
maploc/__init__.py
ADDED
@@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

from pathlib import Path
import logging

import pytorch_lightning  # noqa: F401


formatter = logging.Formatter(
    fmt="[%(asctime)s %(name)s %(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
handler.setLevel(logging.INFO)

logger = logging.getLogger("maploc")
logger.setLevel(logging.INFO)
logger.addHandler(handler)
logger.propagate = False

pl_logger = logging.getLogger("pytorch_lightning")
if len(pl_logger.handlers):
    pl_logger.handlers[0].setFormatter(formatter)

repo_dir = Path(__file__).parent.parent
EXPERIMENTS_PATH = repo_dir / "experiments/"
DATASETS_PATH = repo_dir / "datasets/"
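The package `__init__` above mainly sets up a namespaced logger and repository-relative paths; other modules in this commit import them (e.g. `from ... import logger, DATASETS_PATH` in `maploc/data/kitti/dataset.py`). A minimal usage sketch:

```python
from maploc import DATASETS_PATH, EXPERIMENTS_PATH, logger

logger.info("Datasets are read from %s", DATASETS_PATH)
logger.info("Checkpoints are written to %s", EXPERIMENTS_PATH)
```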
maploc/conf/__init__.py
ADDED
File without changes
maploc/conf/data/__init__.py
ADDED
File without changes
maploc/conf/data/kitti.yaml
ADDED
@@ -0,0 +1,29 @@
name: kitti
loading:
  train:
    batch_size: 9
    num_workers: ${.batch_size}
  val:
    batch_size: ${..train.batch_size}
    num_workers: ${.batch_size}
# make sure train and val locations are at least 5m apart
selection_subset_val: furthest
max_num_val: 500
drop_train_too_close_to_val: 5.0
# map data
num_classes:
  areas: 7
  ways: 10
  nodes: 33
pixel_per_meter: 2
crop_size_meters: 64
max_init_error: 32
# preprocessing
target_focal_length: 256
resize_image: [448, 160]  # multiple of 32 at f=256px
# pad_to_multiple: 32
rectify_pitch: true
augmentation:
  rot90: true
  flip: true
  image: {apply: true}
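The `${.batch_size}` and `${..train.batch_size}` values are OmegaConf relative interpolations: a single leading dot resolves keys against the current node, and each additional dot moves up one level. A standalone sketch of how they resolve, mirroring the `loading` block above:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    """
loading:
  train:
    batch_size: 9
    num_workers: ${.batch_size}
  val:
    batch_size: ${..train.batch_size}
    num_workers: ${.batch_size}
"""
)
assert cfg.loading.train.num_workers == 9  # sibling key in the same node
assert cfg.loading.val.batch_size == 9     # resolved through the parent node
```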
maploc/conf/data/mapillary.yaml
ADDED
@@ -0,0 +1,40 @@
name: mapillary
scenes:
  - sanfrancisco_soma
  - sanfrancisco_hayes
  - amsterdam
  - berlin
  - lemans
  - montrouge
  - toulouse
  - nantes
  - vilnius
  - avignon
  - helsinki
  - milan
  - paris
split: splits_MGL_13loc.json
loading:
  train:
    batch_size: 12
    num_workers: ${.batch_size}
  val:
    batch_size: ${..train.batch_size}
    num_workers: ${.batch_size}
# map data
num_classes:
  areas: 7
  ways: 10
  nodes: 33
pixel_per_meter: 2
crop_size_meters: 64
max_init_error: 48
add_map_mask: true
# preprocessing
resize_image: 512
pad_to_square: true
rectify_pitch: true
augmentation:
  rot90: true
  flip: true
  image: {apply: true}
maploc/conf/model/image_encoder/global.yaml
ADDED
@@ -0,0 +1,9 @@
name: feature_extractor
backbone:
  encoder: resnet18
  pretrained: true
  output_dim: ${...latent_dim}
  output_scales: [5]
  num_downsample: 5
  decoder: null
pooling: mean
maploc/conf/model/image_encoder/resnet_fpn.yaml
ADDED
@@ -0,0 +1,7 @@
name: feature_extractor_v2
backbone:
  encoder: resnet50
  pretrained: true
  output_dim: ${...latent_dim}
  num_downsample: null
  remove_stride_from_first_conv: false
maploc/conf/model/image_encoder/vgg_unet.yaml
ADDED
@@ -0,0 +1,8 @@
name: feature_extractor
backbone:
  encoder: vgg16
  pretrained: true
  output_dim: ${...latent_dim}
  output_scales: [0]
  num_downsample: 4
  decoder: [512, 256, 256, 128]
maploc/conf/orienternet.yaml
ADDED
@@ -0,0 +1,34 @@
defaults:
  - data: mapillary
  - model/image_encoder: resnet_fpn
  - training
  - _self_
model:
  name: orienternet
  latent_dim: 128
  matching_dim: 8
  z_max: 32
  x_max: 32
  pixel_per_meter: ${data.pixel_per_meter}
  num_scale_bins: 33
  num_rotations: 64
  image_encoder:
    backbone:
      encoder: resnet101
  map_encoder:
    embedding_dim: 16
    output_dim: ${..matching_dim}
    num_classes: ${data.num_classes}
    backbone:
      encoder: vgg19
      pretrained: false
      output_scales: [0]
      num_downsample: 3
      decoder: [128, 64, 64]
      padding: replicate
    unary_prior: false
  bev_net:
    num_blocks: 4
    latent_dim: ${..latent_dim}
    output_dim: ${..matching_dim}
    confidence: true
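The `defaults` list is Hydra's composition mechanism: `data: mapillary` loads `conf/data/mapillary.yaml` under the `data` key, `model/image_encoder: resnet_fpn` nests its file under `model.image_encoder`, and `_self_` coming last means values in this file win (e.g. `encoder: resnet101` overrides the `resnet50` from `resnet_fpn.yaml`). A sketch of composing the config from Python; the relative `config_path` is an assumption about where the calling script lives:

```python
from hydra import compose, initialize

with initialize(version_base=None, config_path="maploc/conf"):
    cfg = compose(
        config_name="orienternet",
        overrides=["data=kitti", "model.num_rotations=128"],
    )
print(cfg.model.image_encoder.backbone.encoder)  # resnet101
```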
maploc/conf/overfit.yaml
ADDED
@@ -0,0 +1,17 @@
defaults:
  - orienternet
  - _self_
data:
  loading:
    train:
      batch_size: 6
  random: false
  split: null
model:
  freeze_batch_normalization: true
training:
  trainer:
    overfit_batches: 1
    val_check_interval: 1
    log_every_n_steps: 1
    limit_val_batches: 1
maploc/conf/training.yaml
ADDED
@@ -0,0 +1,22 @@
experiment:
  name: ???
  gpus: 3
  seed: 0
training:
  lr: 1e-4
  lr_scheduler: null
  finetune_from_checkpoint: null
  trainer:
    val_check_interval: 5000
    log_every_n_steps: 100
    limit_val_batches: 1000
    max_steps: 500000
    devices: ${experiment.gpus}
  checkpointing:
    monitor: "loss/total/val"
    save_top_k: 5
    mode: "min"
hydra:
  job:
    name: ${experiment.name}
    chdir: false
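`name: ???` is OmegaConf's marker for a mandatory value: reading it before it is set raises an error, which is why every training command in the README passes `experiment.name=...`. A minimal sketch of that behavior:

```python
from omegaconf import OmegaConf
from omegaconf.errors import MissingMandatoryValue

cfg = OmegaConf.create({"experiment": {"name": "???", "gpus": 3}})
try:
    _ = cfg.experiment.name
except MissingMandatoryValue:
    print("set it first, e.g. experiment.name=OrienterNet_MGL_reproduce")
cfg.experiment.name = "OrienterNet_MGL_reproduce"
assert cfg.experiment.name == "OrienterNet_MGL_reproduce"
```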
maploc/data/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .kitti.dataset import KittiDataModule
from .mapillary.dataset import MapillaryDataModule

modules = {"mapillary": MapillaryDataModule, "kitti": KittiDataModule}
maploc/data/dataset.py
ADDED
@@ -0,0 +1,264 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List

import numpy as np
import torch
import torch.utils.data as torchdata
import torchvision.transforms as tvf
from omegaconf import DictConfig, OmegaConf

from ..models.utils import deg2rad, rotmat2d
from ..osm.tiling import TileManager
from ..utils.geo import BoundaryBox
from ..utils.io import read_image
from ..utils.wrappers import Camera
from .image import pad_image, rectify_image, resize_image
from .utils import decompose_rotmat, random_flip, random_rot90


class MapLocDataset(torchdata.Dataset):
    default_cfg = {
        "seed": 0,
        "accuracy_gps": 15,
        "random": True,
        "num_threads": None,
        # map
        "num_classes": None,
        "pixel_per_meter": "???",
        "crop_size_meters": "???",
        "max_init_error": "???",
        "max_init_error_rotation": None,
        "init_from_gps": False,
        "return_gps": False,
        "force_camera_height": None,
        # pose priors
        "add_map_mask": False,
        "mask_radius": None,
        "mask_pad": 1,
        "prior_range_rotation": None,
        # image preprocessing
        "target_focal_length": None,
        "reduce_fov": None,
        "resize_image": None,
        "pad_to_square": False,  # legacy
        "pad_to_multiple": 32,
        "rectify_pitch": True,
        "augmentation": {
            "rot90": False,
            "flip": False,
            "image": {
                "apply": False,
                "brightness": 0.5,
                "contrast": 0.4,
                "saturation": 0.4,
                "hue": 0.5 / 3.14,
            },
        },
    }

    def __init__(
        self,
        stage: str,
        cfg: DictConfig,
        names: List[str],
        data: Dict[str, Any],
        image_dirs: Dict[str, Path],
        tile_managers: Dict[str, TileManager],
        image_ext: str = "",
    ):
        self.stage = stage
        self.cfg = deepcopy(cfg)
        self.data = data
        self.image_dirs = image_dirs
        self.tile_managers = tile_managers
        self.names = names
        self.image_ext = image_ext

        tfs = []
        if stage == "train" and cfg.augmentation.image.apply:
            args = OmegaConf.masked_copy(
                cfg.augmentation.image, ["brightness", "contrast", "saturation", "hue"]
            )
            tfs.append(tvf.ColorJitter(**args))
        self.tfs = tvf.Compose(tfs)

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        if self.stage == "train" and self.cfg.random:
            seed = None
        else:
            seed = [self.cfg.seed, idx]
        (seed,) = np.random.SeedSequence(seed).generate_state(1)

        scene, seq, name = self.names[idx]
        if self.cfg.init_from_gps:
            latlon_gps = self.data["gps_position"][idx][:2].clone().numpy()
            xy_w_init = self.tile_managers[scene].projection.project(latlon_gps)
        else:
            xy_w_init = self.data["t_c2w"][idx][:2].clone().double().numpy()

        if "shifts" in self.data:
            yaw = self.data["roll_pitch_yaw"][idx][-1]
            R_c2w = rotmat2d((90 - yaw) / 180 * np.pi).float()
            error = (R_c2w @ self.data["shifts"][idx][:2]).numpy()
        else:
            error = np.random.RandomState(seed).uniform(-1, 1, size=2)
        xy_w_init += error * self.cfg.max_init_error

        bbox_tile = BoundaryBox(
            xy_w_init - self.cfg.crop_size_meters,
            xy_w_init + self.cfg.crop_size_meters,
        )
        return self.get_view(idx, scene, seq, name, seed, bbox_tile)

    def get_view(self, idx, scene, seq, name, seed, bbox_tile):
        data = {
            "index": idx,
            "name": name,
            "scene": scene,
            "sequence": seq,
        }
        cam_dict = self.data["cameras"][scene][seq][self.data["camera_id"][idx]]
        cam = Camera.from_dict(cam_dict).float()

        if "roll_pitch_yaw" in self.data:
            roll, pitch, yaw = self.data["roll_pitch_yaw"][idx].numpy()
        else:
            roll, pitch, yaw = decompose_rotmat(self.data["R_c2w"][idx].numpy())
        image = read_image(self.image_dirs[scene] / (name + self.image_ext))

        if "plane_params" in self.data:
            # transform the plane parameters from world to camera frames
            plane_w = self.data["plane_params"][idx]
            data["ground_plane"] = torch.cat(
                [rotmat2d(deg2rad(torch.tensor(yaw))) @ plane_w[:2], plane_w[2:]]
            )
        if self.cfg.force_camera_height is not None:
            data["camera_height"] = torch.tensor(self.cfg.force_camera_height)
        elif "camera_height" in self.data:
            data["camera_height"] = self.data["camera_height"][idx].clone()

        # raster extraction
        canvas = self.tile_managers[scene].query(bbox_tile)
        xy_w_gt = self.data["t_c2w"][idx][:2].numpy()
        uv_gt = canvas.to_uv(xy_w_gt)
        uv_init = canvas.to_uv(bbox_tile.center)
        raster = canvas.raster  # C, H, W

        # Map augmentations
        heading = np.deg2rad(90 - yaw)  # fixme
        if self.stage == "train":
            if self.cfg.augmentation.rot90:
                raster, uv_gt, heading = random_rot90(raster, uv_gt, heading, seed)
            if self.cfg.augmentation.flip:
                image, raster, uv_gt, heading = random_flip(
                    image, raster, uv_gt, heading, seed
                )
        yaw = 90 - np.rad2deg(heading)  # fixme

        image, valid, cam, roll, pitch = self.process_image(
            image, cam, roll, pitch, seed
        )

        # Create the mask for prior location
        if self.cfg.add_map_mask:
            data["map_mask"] = torch.from_numpy(self.create_map_mask(canvas))

        if self.cfg.max_init_error_rotation is not None:
            if "shifts" in self.data:
                error = self.data["shifts"][idx][-1]
            else:
                error = np.random.RandomState(seed + 1).uniform(-1, 1)
                error = torch.tensor(error, dtype=torch.float)
            yaw_init = yaw + error * self.cfg.max_init_error_rotation
            range_ = self.cfg.prior_range_rotation or self.cfg.max_init_error_rotation
            data["yaw_prior"] = torch.stack([yaw_init, torch.tensor(range_)])

        if self.cfg.return_gps:
            gps = self.data["gps_position"][idx][:2].numpy()
            xy_gps = self.tile_managers[scene].projection.project(gps)
            data["uv_gps"] = torch.from_numpy(canvas.to_uv(xy_gps)).float()
            data["accuracy_gps"] = torch.tensor(
                min(self.cfg.accuracy_gps, self.cfg.crop_size_meters)
            )

        if "chunk_index" in self.data:
            data["chunk_id"] = (scene, seq, self.data["chunk_index"][idx])

        return {
            **data,
            "image": image,
            "valid": valid,
            "camera": cam,
            "canvas": canvas,
            "map": torch.from_numpy(np.ascontiguousarray(raster)).long(),
            "uv": torch.from_numpy(uv_gt).float(),
            "uv_init": torch.from_numpy(uv_init).float(),
            "roll_pitch_yaw": torch.tensor((roll, pitch, yaw)).float(),
            "pixels_per_meter": torch.tensor(canvas.ppm).float(),
        }

    def process_image(self, image, cam, roll, pitch, seed):
        image = (
            torch.from_numpy(np.ascontiguousarray(image))
            .permute(2, 0, 1)
            .float()
            .div_(255)
        )
        image, valid = rectify_image(
            image, cam, roll, pitch if self.cfg.rectify_pitch else None
        )
        roll = 0.0
        if self.cfg.rectify_pitch:
            pitch = 0.0

        if self.cfg.target_focal_length is not None:
            # resize to a canonical focal length
            factor = self.cfg.target_focal_length / cam.f.numpy()
            size = (np.array(image.shape[-2:][::-1]) * factor).astype(int)
            image, _, cam, valid = resize_image(image, size, camera=cam, valid=valid)
            size_out = self.cfg.resize_image
            if size_out is None:
                # round the edges up such that they are multiple of a factor
                stride = self.cfg.pad_to_multiple
                size_out = (np.ceil(size / stride) * stride).astype(int)
            # crop or pad such that both edges are of the given size
            image, valid, cam = pad_image(
                image, size_out, cam, valid, crop_and_center=True
            )
        elif self.cfg.resize_image is not None:
            image, _, cam, valid = resize_image(
                image, self.cfg.resize_image, fn=max, camera=cam, valid=valid
            )
            if self.cfg.pad_to_square:
                # pad such that both edges are of the given size
                image, valid, cam = pad_image(image, self.cfg.resize_image, cam, valid)

        if self.cfg.reduce_fov is not None:
            h, w = image.shape[-2:]
            f = float(cam.f[0])
            fov = np.arctan(w / f / 2)
            w_new = round(2 * f * np.tan(self.cfg.reduce_fov * fov))
            image, valid, cam = pad_image(
                image, (w_new, h), cam, valid, crop_and_center=True
            )

        with torch.random.fork_rng(devices=[]):
            torch.manual_seed(seed)
            image = self.tfs(image)
        return image, valid, cam, roll, pitch

    def create_map_mask(self, canvas):
        map_mask = np.zeros(canvas.raster.shape[-2:], bool)
        radius = self.cfg.mask_radius or self.cfg.max_init_error
        mask_min, mask_max = np.round(
            canvas.to_uv(canvas.bbox.center)
            + np.array([[-1], [1]]) * (radius + self.cfg.mask_pad) * canvas.ppm
        ).astype(int)
        map_mask[mask_min[1] : mask_max[1], mask_min[0] : mask_max[0]] = True
        return map_mask
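`__getitem__` derives a per-sample seed from `[cfg.seed, idx]` through `np.random.SeedSequence`, so evaluation samples get the same perturbation on every run while training (with `random: true`) stays stochastic. A standalone sketch of the pattern:

```python
import numpy as np

def perturbation(cfg_seed: int, idx: int) -> np.ndarray:
    # Same derivation as the non-random branch of MapLocDataset.__getitem__.
    (seed,) = np.random.SeedSequence([cfg_seed, idx]).generate_state(1)
    return np.random.RandomState(seed).uniform(-1, 1, size=2)

assert np.allclose(perturbation(0, 42), perturbation(0, 42))  # stable across runs
print(perturbation(0, 42), perturbation(0, 43))  # different indices, different draws
```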
maploc/data/image.py
ADDED
@@ -0,0 +1,140 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import collections
from typing import Callable, Optional, Sequence, Union

import numpy as np
import torch
import torchvision.transforms.functional as tvf
from scipy.spatial.transform import Rotation

from ..utils.geometry import from_homogeneous, to_homogeneous
from ..utils.wrappers import Camera


def rectify_image(
    image: torch.Tensor,
    cam: Camera,
    roll: float,
    pitch: Optional[float] = None,
    valid: Optional[torch.Tensor] = None,
):
    *_, h, w = image.shape
    grid = torch.meshgrid(
        [torch.arange(w, device=image.device), torch.arange(h, device=image.device)],
        indexing="xy",
    )
    grid = torch.stack(grid, -1).to(image.dtype)

    if pitch is not None:
        args = ("ZX", (roll, pitch))
    else:
        args = ("Z", roll)
    R = Rotation.from_euler(*args, degrees=True).as_matrix()
    R = torch.from_numpy(R).to(image)

    grid_rect = to_homogeneous(cam.normalize(grid)) @ R.T
    grid_rect = cam.denormalize(from_homogeneous(grid_rect))
    grid_norm = (grid_rect + 0.5) / grid.new_tensor([w, h]) * 2 - 1
    rectified = torch.nn.functional.grid_sample(
        image[None],
        grid_norm[None],
        align_corners=False,
        mode="bilinear",
    ).squeeze(0)
    if valid is None:
        valid = torch.all((grid_norm >= -1) & (grid_norm <= 1), -1)
    else:
        valid = (
            torch.nn.functional.grid_sample(
                valid[None, None].float(),
                grid_norm[None],
                align_corners=False,
                mode="nearest",
            )[0, 0]
            > 0
        )
    return rectified, valid


def resize_image(
    image: torch.Tensor,
    size: Union[int, Sequence, np.ndarray],
    fn: Optional[Callable] = None,
    camera: Optional[Camera] = None,
    valid: Optional[torch.Tensor] = None,
):
    """Resize an image to a fixed size, or according to max or min edge."""
    *_, h, w = image.shape
    if fn is not None:
        assert isinstance(size, int)
        scale = size / fn(h, w)
        h_new, w_new = int(round(h * scale)), int(round(w * scale))
        scale = (scale, scale)
    else:
        if isinstance(size, (collections.abc.Sequence, np.ndarray)):
            w_new, h_new = size
        elif isinstance(size, int):
            w_new = h_new = size
        else:
            raise ValueError(f"Incorrect new size: {size}")
        scale = (w_new / w, h_new / h)
    if (w, h) != (w_new, h_new):
        mode = tvf.InterpolationMode.BILINEAR
        image = tvf.resize(image, (h_new, w_new), interpolation=mode, antialias=True)
        image.clip_(0, 1)
        if camera is not None:
            camera = camera.scale(scale)
        if valid is not None:
            valid = tvf.resize(
                valid.unsqueeze(0),
                (h_new, w_new),
                interpolation=tvf.InterpolationMode.NEAREST,
            ).squeeze(0)
    ret = [image, scale]
    if camera is not None:
        ret.append(camera)
    if valid is not None:
        ret.append(valid)
    return ret


def pad_image(
    image: torch.Tensor,
    size: Union[int, Sequence, np.ndarray],
    camera: Optional[Camera] = None,
    valid: Optional[torch.Tensor] = None,
    crop_and_center: bool = False,
):
    if isinstance(size, int):
        w_new = h_new = size
    elif isinstance(size, (collections.abc.Sequence, np.ndarray)):
        w_new, h_new = size
    else:
        raise ValueError(f"Incorrect new size: {size}")
    *c, h, w = image.shape
    if crop_and_center:
        diff = np.array([w - w_new, h - h_new])
        left, top = left_top = np.round(diff / 2).astype(int)
        right, bottom = diff - left_top
    else:
        assert h <= h_new
        assert w <= w_new
        top = bottom = left = right = 0
    slice_out = np.s_[..., : min(h, h_new), : min(w, w_new)]
    slice_in = np.s_[
        ..., max(top, 0) : h - max(bottom, 0), max(left, 0) : w - max(right, 0)
    ]
    if (w, h) == (w_new, h_new):
        out = image
    else:
        out = torch.zeros((*c, h_new, w_new), dtype=image.dtype)
        out[slice_out] = image[slice_in]
    if camera is not None:
        camera = camera.crop((max(left, 0), max(top, 0)), (w_new, h_new))
    out_valid = torch.zeros((h_new, w_new), dtype=torch.bool)
    out_valid[slice_out] = True if valid is None else valid[slice_in]
    if camera is not None:
        return out, out_valid, camera
    else:
        return out, out_valid
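A minimal usage sketch for the helpers above (the input shape is arbitrary): `resize_image` with `fn=max` scales the longest edge to the target, and `pad_image` zero-pads to the requested size while returning a validity mask:

```python
import torch
from maploc.data.image import pad_image, resize_image

image = torch.rand(3, 480, 640)                  # C, H, W in [0, 1]
image, scale = resize_image(image, 512, fn=max)  # -> (3, 384, 512), scale (0.8, 0.8)
image, valid = pad_image(image, 512)             # zero-pad to (3, 512, 512)
print(image.shape, valid.shape, valid.float().mean())  # mask marks the real pixels
```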
maploc/data/kitti/dataset.py
ADDED
@@ -0,0 +1,306 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import collections
import collections.abc
from collections import defaultdict
from pathlib import Path
from typing import Optional

import numpy as np
import pytorch_lightning as pl
import torch
import torch.utils.data as torchdata
from omegaconf import OmegaConf
from scipy.spatial.transform import Rotation

from ... import logger, DATASETS_PATH
from ...osm.tiling import TileManager
from ..dataset import MapLocDataset
from ..sequential import chunk_sequence
from ..torch import collate, worker_init_fn
from .utils import parse_split_file, parse_gps_file, get_camera_calibration


class KittiDataModule(pl.LightningDataModule):
    default_cfg = {
        **MapLocDataset.default_cfg,
        "name": "kitti",
        # paths and fetch
        "data_dir": DATASETS_PATH / "kitti",
        "tiles_filename": "tiles.pkl",
        "splits": {
            "train": "train_files.txt",
            "val": "test1_files.txt",
            "test": "test2_files.txt",
        },
        "loading": {
            "train": "???",
            "val": "${.test}",
            "test": {"batch_size": 1, "num_workers": 0},
        },
        "max_num_val": 500,
        "selection_subset_val": "furthest",
        "drop_train_too_close_to_val": 5.0,
        "skip_frames": 1,
        "camera_index": 2,
        # overwrite
        "crop_size_meters": 64,
        "max_init_error": 20,
        "max_init_error_rotation": 10,
        "add_map_mask": True,
        "mask_pad": 2,
        "target_focal_length": 256,
    }
    dummy_scene_name = "kitti"

    def __init__(self, cfg, tile_manager: Optional[TileManager] = None):
        super().__init__()
        default_cfg = OmegaConf.create(self.default_cfg)
        OmegaConf.set_struct(default_cfg, True)  # cannot add new keys
        self.cfg = OmegaConf.merge(default_cfg, cfg)
        self.root = Path(self.cfg.data_dir)
        self.tile_manager = tile_manager
        if self.cfg.crop_size_meters < self.cfg.max_init_error:
            raise ValueError("The ground truth location can be outside the map.")
        assert self.cfg.selection_subset_val in ["random", "furthest"]
        self.splits = {}
        self.shifts = {}
        self.calibrations = {}
        self.data = {}
        self.image_paths = {}

    def prepare_data(self):
        if not (self.root.exists() and (self.root / ".downloaded").exists()):
            raise FileNotFoundError(
                "Cannot find the KITTI dataset, run maploc.data.kitti.prepare"
            )

    def parse_split(self, split_arg):
        if isinstance(split_arg, str):
            names, shifts = parse_split_file(self.root / split_arg)
        elif isinstance(split_arg, collections.abc.Sequence):
            names = []
            shifts = None
            for date_drive in split_arg:
                data_dir = (
                    self.root / date_drive / f"image_{self.cfg.camera_index:02}/data"
                )
                assert data_dir.exists(), data_dir
                date_drive = tuple(date_drive.split("/"))
                n = sorted(date_drive + (p.name,) for p in data_dir.glob("*.png"))
                names.extend(n[:: self.cfg.skip_frames])
        else:
            raise ValueError(split_arg)
        return names, shifts

    def setup(self, stage: Optional[str] = None):
        if stage == "fit":
            stages = ["train", "val"]
        elif stage is None:
            stages = ["train", "val", "test"]
        else:
            stages = [stage]
        for stage in stages:
            self.splits[stage], self.shifts[stage] = self.parse_split(
                self.cfg.splits[stage]
            )
        do_val_subset = "val" in stages and self.cfg.max_num_val is not None
        if do_val_subset and self.cfg.selection_subset_val == "random":
            select = np.random.RandomState(self.cfg.seed).choice(
                len(self.splits["val"]), self.cfg.max_num_val, replace=False
            )
            self.splits["val"] = [self.splits["val"][i] for i in select]
            if self.shifts["val"] is not None:
                self.shifts["val"] = self.shifts["val"][select]
        dates = {d for ns in self.splits.values() for d, _, _ in ns}
        for d in dates:
            self.calibrations[d] = get_camera_calibration(
                self.root / d, self.cfg.camera_index
            )
        if self.tile_manager is None:
            logger.info("Loading the tile manager...")
            self.tile_manager = TileManager.load(self.root / self.cfg.tiles_filename)
        self.cfg.num_classes = {k: len(g) for k, g in self.tile_manager.groups.items()}
        self.cfg.pixel_per_meter = self.tile_manager.ppm

        # pack all attributes in a single tensor to optimize memory access
        self.pack_data(stages)

        dists = None
        if do_val_subset and self.cfg.selection_subset_val == "furthest":
            dists = torch.cdist(
                self.data["val"]["t_c2w"][:, :2].double(),
                self.data["train"]["t_c2w"][:, :2].double(),
            )
            min_dists = dists.min(1).values
            select = torch.argsort(min_dists)[-self.cfg.max_num_val :]
            dists = dists[select]
            self.splits["val"] = [self.splits["val"][i] for i in select]
            if self.shifts["val"] is not None:
                self.shifts["val"] = self.shifts["val"][select]
            for k in list(self.data["val"]):
                if k != "cameras":
                    self.data["val"][k] = self.data["val"][k][select]
            self.image_paths["val"] = self.image_paths["val"][select]

        if "train" in stages and self.cfg.drop_train_too_close_to_val is not None:
            if dists is None:
                dists = torch.cdist(
                    self.data["val"]["t_c2w"][:, :2].double(),
                    self.data["train"]["t_c2w"][:, :2].double(),
                )
            drop = torch.any(dists < self.cfg.drop_train_too_close_to_val, 0)
            select = torch.where(~drop)[0]
            logger.info(
                "Dropping %d (%.1f %%) images that are too close to validation images.",
                drop.sum(),
                100 * drop.float().mean(),
            )
            self.splits["train"] = [self.splits["train"][i] for i in select]
            if self.shifts["train"] is not None:
                self.shifts["train"] = self.shifts["train"][select]
            for k in list(self.data["train"]):
                if k != "cameras":
                    self.data["train"][k] = self.data["train"][k][select]
            self.image_paths["train"] = self.image_paths["train"][select]

    def pack_data(self, stages):
        for stage in stages:
            names = []
            data = {}
            for i, (date, drive, index) in enumerate(self.splits[stage]):
                d = self.get_frame_data(date, drive, index)
                for k, v in d.items():
                    if i == 0:
                        data[k] = []
                    data[k].append(v)
                path = f"{date}/{drive}/image_{self.cfg.camera_index:02}/data/{index}"
                names.append((self.dummy_scene_name, f"{date}/{drive}", path))
            for k in list(data):
                data[k] = torch.from_numpy(np.stack(data[k]))
            data["camera_id"] = np.full(len(names), self.cfg.camera_index)

            sequences = {date_drive for _, date_drive, _ in names}
            data["cameras"] = {
                self.dummy_scene_name: {
                    seq: {
                        self.cfg.camera_index: self.calibrations[seq.split("/")[0]][0]
                    }
                    for seq in sequences
                }
            }
            shifts = self.shifts[stage]
            if shifts is not None:
                data["shifts"] = torch.from_numpy(shifts.astype(np.float32))
            self.data[stage] = data
            self.image_paths[stage] = np.array(names)

    def get_frame_data(self, date, drive, index):
        _, R_cam_gps, t_cam_gps = self.calibrations[date]

        # Transform the GPS pose to the camera pose
        gps_path = (
            self.root / date / drive / "oxts/data" / Path(index).with_suffix(".txt")
        )
        _, R_world_gps, t_world_gps = parse_gps_file(
            gps_path, self.tile_manager.projection
        )
        R_world_cam = R_world_gps @ R_cam_gps.T
        t_world_cam = t_world_gps - R_world_gps @ R_cam_gps.T @ t_cam_gps
        # Some voodoo to extract correct Euler angles from R_world_cam
        R_cv_xyz = Rotation.from_euler("YX", [-90, 90], degrees=True).as_matrix()
        R_world_cam_xyz = R_world_cam @ R_cv_xyz
        y, p, r = Rotation.from_matrix(R_world_cam_xyz).as_euler("ZYX", degrees=True)
        roll, pitch, yaw = r, -p, 90 - y
        roll_pitch_yaw = np.array([-roll, -pitch, yaw], np.float32)  # for some reason

        return {
            "t_c2w": t_world_cam.astype(np.float32),
            "roll_pitch_yaw": roll_pitch_yaw,
            "index": int(index.split(".")[0]),
        }

    def dataset(self, stage: str):
        return MapLocDataset(
            stage,
            self.cfg,
            self.image_paths[stage],
            self.data[stage],
            {self.dummy_scene_name: self.root},
            {self.dummy_scene_name: self.tile_manager},
        )

    def dataloader(
        self,
        stage: str,
        shuffle: bool = False,
        num_workers: Optional[int] = None,
        sampler: Optional[torchdata.Sampler] = None,
    ):
        dataset = self.dataset(stage)
        cfg = self.cfg["loading"][stage]
        num_workers = cfg["num_workers"] if num_workers is None else num_workers
        loader = torchdata.DataLoader(
            dataset,
            batch_size=cfg["batch_size"],
            num_workers=num_workers,
            shuffle=shuffle or (stage == "train"),
            pin_memory=True,
            persistent_workers=num_workers > 0,
            worker_init_fn=worker_init_fn,
            collate_fn=collate,
            sampler=sampler,
        )
        return loader

    def train_dataloader(self, **kwargs):
        return self.dataloader("train", **kwargs)

    def val_dataloader(self, **kwargs):
        return self.dataloader("val", **kwargs)

    def test_dataloader(self, **kwargs):
        return self.dataloader("test", **kwargs)

    def sequence_dataset(self, stage: str, **kwargs):
        keys = self.image_paths[stage]
        # group images by sequence (date/drive)
        seq2indices = defaultdict(list)
        for index, (_, date_drive, _) in enumerate(keys):
            seq2indices[date_drive].append(index)
        # chunk the sequences to the required length
        chunk2indices = {}
        for seq, indices in seq2indices.items():
            chunks = chunk_sequence(
                self.data[stage], indices, names=self.image_paths[stage], **kwargs
            )
            for i, sub_indices in enumerate(chunks):
                chunk2indices[seq, i] = sub_indices
        # store the index of each chunk in its sequence
        chunk_indices = torch.full((len(keys),), -1)
        for (_, chunk_index), idx in chunk2indices.items():
            chunk_indices[idx] = chunk_index
        self.data[stage]["chunk_index"] = chunk_indices
|
284 |
+
dataset = self.dataset(stage)
|
285 |
+
return dataset, chunk2indices
|
286 |
+
|
287 |
+
def sequence_dataloader(self, stage: str, shuffle: bool = False, **kwargs):
|
288 |
+
dataset, chunk2idx = self.sequence_dataset(stage, **kwargs)
|
289 |
+
seq_keys = sorted(chunk2idx)
|
290 |
+
if shuffle:
|
291 |
+
perm = torch.randperm(len(seq_keys))
|
292 |
+
seq_keys = [seq_keys[i] for i in perm]
|
293 |
+
key_indices = [i for key in seq_keys for i in chunk2idx[key]]
|
294 |
+
num_workers = self.cfg["loading"][stage]["num_workers"]
|
295 |
+
loader = torchdata.DataLoader(
|
296 |
+
dataset,
|
297 |
+
batch_size=None,
|
298 |
+
sampler=key_indices,
|
299 |
+
num_workers=num_workers,
|
300 |
+
shuffle=False,
|
301 |
+
pin_memory=True,
|
302 |
+
persistent_workers=num_workers > 0,
|
303 |
+
worker_init_fn=worker_init_fn,
|
304 |
+
collate_fn=collate,
|
305 |
+
)
|
306 |
+
return loader, seq_keys, chunk2idx
|
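For context, a minimal sketch of driving these dataloaders; the config keys shown are illustrative, and the constructor is assumed to accept a config dict, like the Mapillary module below:

from omegaconf import OmegaConf
from maploc.data.kitti.dataset import KittiDataModule

# Hypothetical override; the full defaults live in KittiDataModule.default_cfg
# and maploc/conf/data/kitti.yaml.
dm = KittiDataModule(OmegaConf.create({"max_num_val": 500}))
dm.prepare_data()
dm.setup()
train_loader = dm.train_dataloader()
# Sequence-level evaluation; extra keyword arguments are forwarded
# to chunk_sequence.
loader, seq_keys, chunk2idx = dm.sequence_dataloader("test")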
maploc/data/kitti/prepare.py
ADDED
@@ -0,0 +1,123 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import argparse
from pathlib import Path
import shutil
import zipfile

import numpy as np
from tqdm.auto import tqdm

from ... import logger
from ...osm.tiling import TileManager
from ...osm.viz import GeoPlotter
from ...utils.geo import BoundaryBox, Projection
from ...utils.io import download_file, DATA_URL
from .utils import parse_gps_file
from .dataset import KittiDataModule

split_files = ["test1_files.txt", "test2_files.txt", "train_files.txt"]


def prepare_osm(
    data_dir,
    osm_path,
    output_path,
    tile_margin=512,
    ppm=2,
):
    all_latlon = []
    for gps_path in data_dir.glob("2011_*/*/oxts/data/*.txt"):
        all_latlon.append(parse_gps_file(gps_path)[0])
    if not all_latlon:
        raise ValueError(f"Cannot find any GPS file in {data_dir}.")
    all_latlon = np.stack(all_latlon)
    projection = Projection.from_points(all_latlon)
    all_xy = projection.project(all_latlon)
    bbox_map = BoundaryBox(all_xy.min(0), all_xy.max(0)) + tile_margin

    plotter = GeoPlotter()
    plotter.points(all_latlon, "red", name="GPS")
    plotter.bbox(projection.unproject(bbox_map), "blue", "tiling bounding box")
    plotter.fig.write_html(data_dir / "split_kitti.html")

    tile_manager = TileManager.from_bbox(
        projection,
        bbox_map,
        ppm,
        path=osm_path,
    )
    tile_manager.save(output_path)
    return tile_manager


def download(data_dir: Path):
    data_dir.mkdir(exist_ok=True, parents=True)
    this_dir = Path(__file__).parent

    seqs = set()
    for f in split_files:
        shutil.copy(this_dir / f, data_dir)
        with open(this_dir / f, "r") as fid:
            info = fid.read()
        for line in info.split("\n"):
            if line:
                seq = line.split()[0].split("/")[1][: -len("_sync")]
                seqs.add(seq)
    dates = {"_".join(s.split("_")[:3]) for s in seqs}
    logger.info("Downloading data for %d sequences in %d dates", len(seqs), len(dates))

    for seq in tqdm(seqs):
        logger.info("Working on %s.", seq)
        date = "_".join(seq.split("_")[:3])
        url = f"https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/{seq}/{seq}_sync.zip"
        seq_dir = data_dir / date / f"{seq}_sync"
        if seq_dir.exists():
            continue
        zip_path = download_file(url, data_dir)
        with zipfile.ZipFile(zip_path, "r") as z:
            z.extractall(data_dir)
        # Delete unused files to save space.
        for image_index in [0, 1, 3]:
            shutil.rmtree(seq_dir / f"image_0{image_index}")
        shutil.rmtree(seq_dir / "velodyne_points")
        Path(zip_path).unlink()

    for date in tqdm(dates):
        url = (
            f"https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/{date}_calib.zip"
        )
        zip_path = download_file(url, data_dir)
        with zipfile.ZipFile(zip_path, "r") as z:
            z.extractall(data_dir)
        Path(zip_path).unlink()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir", type=Path, default=Path(KittiDataModule.default_cfg["local_dir"])
    )
    parser.add_argument("--pixel_per_meter", type=int, default=2)
    parser.add_argument("--generate_tiles", action="store_true")
    args = parser.parse_args()

    args.data_dir.mkdir(exist_ok=True, parents=True)
    download(args.data_dir)

    tiles_path = args.data_dir / KittiDataModule.default_cfg["tiles_filename"]
    if args.generate_tiles:
        logger.info("Generating the map tiles.")
        osm_filename = "karlsruhe.osm"
        osm_path = args.data_dir / osm_filename
        if not osm_path.exists():
            logger.info("Downloading OSM raw data.")
            download_file(DATA_URL + f"/osm/{osm_filename}", osm_path)
        if not osm_path.exists():
            raise FileNotFoundError(f"No OSM data file at {osm_path}.")
        prepare_osm(args.data_dir, osm_path, tiles_path, ppm=args.pixel_per_meter)
        (args.data_dir / ".downloaded").touch()
    else:
        logger.info("Downloading pre-generated map tiles.")
        download_file(DATA_URL + "/tiles/kitti.pkl", tiles_path)
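A sketch of using prepare_osm programmatically instead of through the CLI above, assuming the raw KITTI sequences and an OSM extract are already on disk (all paths are illustrative):

from pathlib import Path
from maploc.data.kitti.prepare import prepare_osm

data_dir = Path("datasets/kitti")      # illustrative KITTI root
osm_path = data_dir / "karlsruhe.osm"  # raw OSM extract, as in __main__ above
tiles_path = data_dir / "tiles.pkl"    # illustrative output file

# Fits a projection to all GPS fixes, rasterizes the OSM data at
# 2 pixels/meter, writes a split_kitti.html visualization, and saves
# the resulting tile manager to tiles_path.
prepare_osm(data_dir, osm_path, tiles_path, ppm=2)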
maploc/data/kitti/test1_files.txt
ADDED
The diff for this file is too large to render.
maploc/data/kitti/test2_files.txt
ADDED
The diff for this file is too large to render.
maploc/data/kitti/train_files.txt
ADDED
The diff for this file is too large to render.
maploc/data/kitti/utils.py
ADDED
@@ -0,0 +1,79 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

from pathlib import Path

import numpy as np
from scipy.spatial.transform import Rotation

from ...utils.geo import Projection

split_files = ["test1_files.txt", "test2_files.txt", "train_files.txt"]


def parse_gps_file(path, projection: Projection = None):
    with open(path, "r") as fid:
        lat, lon, _, roll, pitch, yaw, *_ = map(float, fid.read().split())
    latlon = np.array([lat, lon])
    R_world_gps = Rotation.from_euler("ZYX", [yaw, pitch, roll]).as_matrix()
    t_world_gps = None if projection is None else np.r_[projection.project(latlon), 0]
    return latlon, R_world_gps, t_world_gps


def parse_split_file(path: Path):
    with open(path, "r") as fid:
        info = fid.read()
    names = []
    shifts = []
    for line in info.split("\n"):
        if not line:
            continue
        name, *shift = line.split()
        names.append(tuple(name.split("/")))
        if len(shift) > 0:
            assert len(shift) == 3
            shifts.append(np.array(shift, float))
    shifts = None if len(shifts) == 0 else np.stack(shifts)
    return names, shifts


def parse_calibration_file(path):
    calib = {}
    with open(path, "r") as fid:
        for line in fid.read().split("\n"):
            if not line:
                continue
            key, *data = line.split(" ")
            key = key.rstrip(":")
            if key.startswith("R"):
                data = np.array(data, float).reshape(3, 3)
            elif key.startswith("T"):
                data = np.array(data, float).reshape(3)
            elif key.startswith("P"):
                data = np.array(data, float).reshape(3, 4)
            calib[key] = data
    return calib


def get_camera_calibration(calib_dir, cam_index: int):
    calib_path = calib_dir / "calib_cam_to_cam.txt"
    calib_cam = parse_calibration_file(calib_path)
    P = calib_cam[f"P_rect_{cam_index:02}"]
    K = P[:3, :3]
    size = np.array(calib_cam[f"S_rect_{cam_index:02}"], float).astype(int)
    camera = {
        "model": "PINHOLE",
        "width": size[0],
        "height": size[1],
        "params": K[[0, 1, 0, 1], [0, 1, 2, 2]],
    }

    t_cam_cam0 = P[:3, 3] / K[[0, 1, 2], [0, 1, 2]]
    R_rect_cam0 = calib_cam["R_rect_00"]

    calib_gps_velo = parse_calibration_file(calib_dir / "calib_imu_to_velo.txt")
    calib_velo_cam0 = parse_calibration_file(calib_dir / "calib_velo_to_cam.txt")
    R_cam0_gps = calib_velo_cam0["R"] @ calib_gps_velo["R"]
    t_cam0_gps = calib_velo_cam0["R"] @ calib_gps_velo["T"] + calib_velo_cam0["T"]
    R_cam_gps = R_rect_cam0 @ R_cam0_gps
    t_cam_gps = t_cam_cam0 + R_rect_cam0 @ t_cam0_gps
    return camera, R_cam_gps, t_cam_gps
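To see how these helpers compose, here is a small sketch that chains a GPS fix with the GPS-to-camera calibration into a world-frame camera pose, mirroring KittiDataModule.get_frame_data above (paths are illustrative; Projection is assumed to take a latitude/longitude origin, as in prepare.py):

from pathlib import Path

from maploc.data.kitti.utils import get_camera_calibration, parse_gps_file
from maploc.utils.geo import Projection

date_dir = Path("datasets/kitti/2011_09_26")  # illustrative date directory
gps_path = date_dir / "2011_09_26_drive_0001_sync/oxts/data/0000000000.txt"

camera, R_cam_gps, t_cam_gps = get_camera_calibration(date_dir, 2)
latlon, _, _ = parse_gps_file(gps_path)
projection = Projection(*latlon)  # assumed lat/lon origin
_, R_world_gps, t_world_gps = parse_gps_file(gps_path, projection)

# Invert the rigid GPS->camera calibration and chain it with the GPS pose,
# exactly as in get_frame_data.
R_world_cam = R_world_gps @ R_cam_gps.T
t_world_cam = t_world_gps - R_world_gps @ R_cam_gps.T @ t_cam_gps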
maploc/data/mapillary/dataset.py
ADDED
@@ -0,0 +1,350 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import json
from collections import defaultdict
import os
import shutil
import tarfile
from pathlib import Path
from typing import Any, Dict, Optional

import numpy as np
import pytorch_lightning as pl
import torch
import torch.utils.data as torchdata
from omegaconf import DictConfig, OmegaConf

from ... import logger, DATASETS_PATH
from ...osm.tiling import TileManager
from ..dataset import MapLocDataset
from ..sequential import chunk_sequence
from ..torch import collate, worker_init_fn


def pack_dump_dict(dump):
    for per_seq in dump.values():
        if "points" in per_seq:
            for chunk in list(per_seq["points"]):
                points = per_seq["points"].pop(chunk)
                if points is not None:
                    per_seq["points"][chunk] = np.array(points, np.float64)
        for view in per_seq["views"].values():
            for k in ["R_c2w", "roll_pitch_yaw"]:
                view[k] = np.array(view[k], np.float32)
            for k in ["chunk_id"]:
                if k in view:
                    view.pop(k)
            if "observations" in view:
                view["observations"] = np.array(view["observations"])
        for camera in per_seq["cameras"].values():
            for k in ["params"]:
                camera[k] = np.array(camera[k], np.float32)
    return dump


class MapillaryDataModule(pl.LightningDataModule):
    dump_filename = "dump.json"
    images_archive = "images.tar.gz"
    images_dirname = "images/"

    default_cfg = {
        **MapLocDataset.default_cfg,
        "name": "mapillary",
        # paths and fetch
        "data_dir": DATASETS_PATH / "MGL",
        "local_dir": None,
        "tiles_filename": "tiles.pkl",
        "scenes": "???",
        "split": None,
        "loading": {
            "train": "???",
            "val": "${.test}",
            "test": {"batch_size": 1, "num_workers": 0},
        },
        "filter_for": None,
        "filter_by_ground_angle": None,
        "min_num_points": "???",
    }

    def __init__(self, cfg: Dict[str, Any]):
        super().__init__()
        default_cfg = OmegaConf.create(self.default_cfg)
        OmegaConf.set_struct(default_cfg, True)  # cannot add new keys
        self.cfg = OmegaConf.merge(default_cfg, cfg)
        self.root = Path(self.cfg.data_dir)
        self.local_dir = self.cfg.local_dir or os.environ.get("TMPDIR")
        if self.local_dir is not None:
            self.local_dir = Path(self.local_dir, "MGL")
        if self.cfg.crop_size_meters < self.cfg.max_init_error:
            raise ValueError("The ground truth location can be outside the map.")

    def prepare_data(self):
        for scene in self.cfg.scenes:
            dump_dir = self.root / scene
            assert (dump_dir / self.dump_filename).exists(), dump_dir
            assert (dump_dir / self.cfg.tiles_filename).exists(), dump_dir
            if self.local_dir is None:
                assert (dump_dir / self.images_dirname).exists(), dump_dir
                continue
            # Cache the folder of images locally to speed up reading
            local_dir = self.local_dir / scene
            if local_dir.exists():
                shutil.rmtree(local_dir)
            local_dir.mkdir(exist_ok=True, parents=True)
            images_archive = dump_dir / self.images_archive
            logger.info("Extracting the image archive %s.", images_archive)
            with tarfile.open(images_archive) as fp:
                fp.extractall(local_dir)

    def setup(self, stage: Optional[str] = None):
        self.dumps = {}
        self.tile_managers = {}
        self.image_dirs = {}
        names = []

        for scene in self.cfg.scenes:
            logger.info("Loading scene %s.", scene)
            dump_dir = self.root / scene
            logger.info("Loading map tiles %s.", self.cfg.tiles_filename)
            self.tile_managers[scene] = TileManager.load(
                dump_dir / self.cfg.tiles_filename
            )
            groups = self.tile_managers[scene].groups
            if self.cfg.num_classes:  # check consistency
                if set(groups.keys()) != set(self.cfg.num_classes.keys()):
                    raise ValueError(
                        f"Inconsistent groups: {groups.keys()} {self.cfg.num_classes.keys()}"
                    )
                for k in groups:
                    if len(groups[k]) != self.cfg.num_classes[k]:
                        raise ValueError(
                            f"{k}: {len(groups[k])} vs {self.cfg.num_classes[k]}"
                        )
            ppm = self.tile_managers[scene].ppm
            if ppm != self.cfg.pixel_per_meter:
                raise ValueError(
                    "The tile manager and the config/model have different ground resolutions: "
                    f"{ppm} vs {self.cfg.pixel_per_meter}"
                )

            logger.info("Loading dump json file %s.", self.dump_filename)
            with (dump_dir / self.dump_filename).open("r") as fp:
                self.dumps[scene] = pack_dump_dict(json.load(fp))
            for seq, per_seq in self.dumps[scene].items():
                for cam_id, cam_dict in per_seq["cameras"].items():
                    if cam_dict["model"] != "PINHOLE":
                        raise ValueError(
                            f"Unsupported camera model: {cam_dict['model']} for {scene},{seq},{cam_id}"
                        )

            self.image_dirs[scene] = (
                (self.local_dir or self.root) / scene / self.images_dirname
            )
            assert self.image_dirs[scene].exists(), self.image_dirs[scene]

            for seq, data in self.dumps[scene].items():
                for name in data["views"]:
                    names.append((scene, seq, name))

        self.parse_splits(self.cfg.split, names)
        if self.cfg.filter_for is not None:
            self.filter_elements()
        self.pack_data()

    def pack_data(self):
        # We pack the data into compact tensors that can be shared across processes without copy
        exclude = {
            "compass_angle",
            "compass_accuracy",
            "gps_accuracy",
            "chunk_key",
            "panorama_offset",
        }
        cameras = {
            scene: {seq: per_seq["cameras"] for seq, per_seq in per_scene.items()}
            for scene, per_scene in self.dumps.items()
        }
        points = {
            scene: {
                seq: {
                    i: torch.from_numpy(p) for i, p in per_seq.get("points", {}).items()
                }
                for seq, per_seq in per_scene.items()
            }
            for scene, per_scene in self.dumps.items()
        }
        self.data = {}
        for stage, names in self.splits.items():
            view = self.dumps[names[0][0]][names[0][1]]["views"][names[0][2]]
            data = {k: [] for k in view.keys() - exclude}
            for scene, seq, name in names:
                for k in data:
                    data[k].append(self.dumps[scene][seq]["views"][name].get(k, None))
            for k in data:
                v = np.array(data[k])
                if np.issubdtype(v.dtype, np.integer) or np.issubdtype(
                    v.dtype, np.floating
                ):
                    v = torch.from_numpy(v)
                data[k] = v
            data["cameras"] = cameras
            data["points"] = points
            self.data[stage] = data
            self.splits[stage] = np.array(names)

    def filter_elements(self):
        for stage, names in self.splits.items():
            names_select = []
            for scene, seq, name in names:
                view = self.dumps[scene][seq]["views"][name]
                if self.cfg.filter_for == "ground_plane":
                    if not (1.0 <= view["height"] <= 3.0):
                        continue
                    planes = self.dumps[scene][seq].get("plane")
                    if planes is not None:
                        inliers = planes[str(view["chunk_id"])][-1]
                        if inliers < 10:
                            continue
                    if self.cfg.filter_by_ground_angle is not None:
                        plane = np.array(view["plane_params"])
                        normal = plane[:3] / np.linalg.norm(plane[:3])
                        angle = np.rad2deg(np.arccos(np.abs(normal[-1])))
                        if angle > self.cfg.filter_by_ground_angle:
                            continue
                elif self.cfg.filter_for == "pointcloud":
                    if len(view["observations"]) < self.cfg.min_num_points:
                        continue
                elif self.cfg.filter_for is not None:
                    raise ValueError(f"Unknown filtering: {self.cfg.filter_for}")
                names_select.append((scene, seq, name))
            logger.info(
                "%s: Keep %d/%d images after filtering for %s.",
                stage,
                len(names_select),
                len(names),
                self.cfg.filter_for,
            )
            self.splits[stage] = names_select

    def parse_splits(self, split_arg, names):
        if split_arg is None:
            self.splits = {
                "train": names,
                "val": names,
            }
        elif isinstance(split_arg, int):
            names = np.random.RandomState(self.cfg.seed).permutation(names).tolist()
            self.splits = {
                "train": names[split_arg:],
                "val": names[:split_arg],
            }
        elif isinstance(split_arg, DictConfig):
            scenes_val = set(split_arg.val)
            scenes_train = set(split_arg.train)
            assert len(scenes_val - set(self.cfg.scenes)) == 0
            assert len(scenes_train - set(self.cfg.scenes)) == 0
            self.splits = {
                "train": [n for n in names if n[0] in scenes_train],
                "val": [n for n in names if n[0] in scenes_val],
            }
        elif isinstance(split_arg, str):
            with (self.root / split_arg).open("r") as fp:
                splits = json.load(fp)
            splits = {
                k: {loc: set(ids) for loc, ids in split.items()}
                for k, split in splits.items()
            }
            self.splits = {}
            for k, split in splits.items():
                self.splits[k] = [
                    n
                    for n in names
                    if n[0] in split and int(n[-1].rsplit("_", 1)[0]) in split[n[0]]
                ]
        else:
            raise ValueError(split_arg)

    def dataset(self, stage: str):
        return MapLocDataset(
            stage,
            self.cfg,
            self.splits[stage],
            self.data[stage],
            self.image_dirs,
            self.tile_managers,
            image_ext=".jpg",
        )

    def dataloader(
        self,
        stage: str,
        shuffle: bool = False,
        num_workers: int = None,
        sampler: Optional[torchdata.Sampler] = None,
    ):
        dataset = self.dataset(stage)
        cfg = self.cfg["loading"][stage]
        num_workers = cfg["num_workers"] if num_workers is None else num_workers
        loader = torchdata.DataLoader(
            dataset,
            batch_size=cfg["batch_size"],
            num_workers=num_workers,
            shuffle=shuffle or (stage == "train"),
            pin_memory=True,
            persistent_workers=num_workers > 0,
            worker_init_fn=worker_init_fn,
            collate_fn=collate,
            sampler=sampler,
        )
        return loader

    def train_dataloader(self, **kwargs):
        return self.dataloader("train", **kwargs)

    def val_dataloader(self, **kwargs):
        return self.dataloader("val", **kwargs)

    def test_dataloader(self, **kwargs):
        return self.dataloader("test", **kwargs)

    def sequence_dataset(self, stage: str, **kwargs):
        keys = self.splits[stage]
        seq2indices = defaultdict(list)
        for index, (_, seq, _) in enumerate(keys):
            seq2indices[seq].append(index)
        # chunk the sequences to the required length
        chunk2indices = {}
        for seq, indices in seq2indices.items():
            chunks = chunk_sequence(self.data[stage], indices, **kwargs)
            for i, sub_indices in enumerate(chunks):
                chunk2indices[seq, i] = sub_indices
        # store the index of each chunk in its sequence
        chunk_indices = torch.full((len(keys),), -1)
        for (_, chunk_index), idx in chunk2indices.items():
            chunk_indices[idx] = chunk_index
        self.data[stage]["chunk_index"] = chunk_indices
        dataset = self.dataset(stage)
        return dataset, chunk2indices

    def sequence_dataloader(self, stage: str, shuffle: bool = False, **kwargs):
        dataset, chunk2idx = self.sequence_dataset(stage, **kwargs)
        chunk_keys = sorted(chunk2idx)
        if shuffle:
            perm = torch.randperm(len(chunk_keys))
            chunk_keys = [chunk_keys[i] for i in perm]
        key_indices = [i for key in chunk_keys for i in chunk2idx[key]]
        num_workers = self.cfg["loading"][stage]["num_workers"]
        loader = torchdata.DataLoader(
            dataset,
            batch_size=None,
            sampler=key_indices,
            num_workers=num_workers,
            shuffle=False,
            pin_memory=True,
            persistent_workers=num_workers > 0,
            worker_init_fn=worker_init_fn,
            collate_fn=collate,
        )
        return loader, chunk_keys, chunk2idx
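For context, a minimal sketch of instantiating this data module; the scene name and loading parameters are illustrative, and the remaining mandatory keys (the "???" entries inherited from MapLocDataset.default_cfg) would also have to be provided in a real config:

from omegaconf import OmegaConf
from maploc.data.mapillary.dataset import MapillaryDataModule

cfg = OmegaConf.create(
    {
        "scenes": ["sanfrancisco_soma"],  # illustrative scene
        "loading": {"train": {"batch_size": 4, "num_workers": 2}},
        "min_num_points": 0,
    }
)
dm = MapillaryDataModule(cfg)
dm.prepare_data()  # checks dump.json/tiles.pkl, optionally caches images locally
dm.setup()
val_loader = dm.val_dataloader()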
maploc/data/mapillary/download.py
ADDED
@@ -0,0 +1,180 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import json
from pathlib import Path

import numpy as np
import httpx
import asyncio
from aiolimiter import AsyncLimiter
import tqdm.asyncio

from opensfm.pygeometry import Camera, Pose
from opensfm.pymap import Shot

from ... import logger
from ...utils.geo import Projection


semaphore = asyncio.Semaphore(100)  # number of parallel threads.
image_filename = "{image_id}.jpg"
info_filename = "{image_id}.json"


class MapillaryDownloader:
    image_fields = (
        "id",
        "height",
        "width",
        "camera_parameters",
        "camera_type",
        "captured_at",
        "compass_angle",
        "geometry",
        "altitude",
        "computed_compass_angle",
        "computed_geometry",
        "computed_altitude",
        "computed_rotation",
        "thumb_2048_url",
        "thumb_original_url",
        "sequence",
        "sfm_cluster",
    )
    image_info_url = (
        "https://graph.mapillary.com/{image_id}?access_token={token}&fields={fields}"
    )
    seq_info_url = "https://graph.mapillary.com/image_ids?access_token={token}&sequence_id={seq_id}"
    max_requests_per_minute = 50_000

    def __init__(self, token: str):
        self.token = token
        self.client = httpx.AsyncClient(
            transport=httpx.AsyncHTTPTransport(retries=20), timeout=20.0
        )
        self.limiter = AsyncLimiter(self.max_requests_per_minute // 2, time_period=60)

    async def call_api(self, url: str):
        async with self.limiter:
            r = await self.client.get(url)
        if not r.is_success:
            logger.error("Error in API call: %s", r.text)
        return r

    async def get_image_info(self, image_id: int):
        url = self.image_info_url.format(
            image_id=image_id,
            token=self.token,
            fields=",".join(self.image_fields),
        )
        r = await self.call_api(url)
        if r.is_success:
            return json.loads(r.text)

    async def get_sequence_info(self, seq_id: str):
        url = self.seq_info_url.format(seq_id=seq_id, token=self.token)
        r = await self.call_api(url)
        if r.is_success:
            return json.loads(r.text)

    async def download_image_pixels(self, url: str, path: Path):
        r = await self.call_api(url)
        if r.is_success:
            with open(path, "wb") as fid:
                fid.write(r.content)
        return r.is_success

    async def get_image_info_cached(self, image_id: int, path: Path):
        if path.exists():
            info = json.loads(path.read_text())
        else:
            info = await self.get_image_info(image_id)
            path.write_text(json.dumps(info))
        return info

    async def download_image_pixels_cached(self, url: str, path: Path):
        if path.exists():
            return True
        else:
            return await self.download_image_pixels(url, path)


async def fetch_images_in_sequence(i, downloader):
    async with semaphore:
        info = await downloader.get_sequence_info(i)
    image_ids = [int(d["id"]) for d in info["data"]]
    return i, image_ids


async def fetch_images_in_sequences(sequence_ids, downloader):
    seq_to_images_ids = {}
    tasks = [fetch_images_in_sequence(i, downloader) for i in sequence_ids]
    for task in tqdm.asyncio.tqdm.as_completed(tasks):
        i, image_ids = await task
        seq_to_images_ids[i] = image_ids
    return seq_to_images_ids


async def fetch_image_info(i, downloader, dir_):
    async with semaphore:
        path = dir_ / info_filename.format(image_id=i)
        info = await downloader.get_image_info_cached(i, path)
    return i, info


async def fetch_image_infos(image_ids, downloader, dir_):
    infos = {}
    num_fail = 0
    tasks = [fetch_image_info(i, downloader, dir_) for i in image_ids]
    for task in tqdm.asyncio.tqdm.as_completed(tasks):
        i, info = await task
        if info is None:
            num_fail += 1
        else:
            infos[i] = info
    return infos, num_fail


async def fetch_image_pixels(i, url, downloader, dir_, overwrite=False):
    async with semaphore:
        path = dir_ / image_filename.format(image_id=i)
        if overwrite:
            path.unlink(missing_ok=True)
        success = await downloader.download_image_pixels_cached(url, path)
    return i, success


async def fetch_images_pixels(image_urls, downloader, dir_):
    num_fail = 0
    tasks = [fetch_image_pixels(*id_url, downloader, dir_) for id_url in image_urls]
    for task in tqdm.asyncio.tqdm.as_completed(tasks):
        i, success = await task
        num_fail += not success
    return num_fail


def opensfm_camera_from_info(info: dict) -> Camera:
    cam_type = info["camera_type"]
    if cam_type == "perspective":
        camera = Camera.create_perspective(*info["camera_parameters"])
    elif cam_type == "fisheye":
        camera = Camera.create_fisheye(*info["camera_parameters"])
    elif Camera.is_panorama(cam_type):
        camera = Camera.create_spherical()
    else:
        raise ValueError(cam_type)
    camera.width = info["width"]
    camera.height = info["height"]
    camera.id = info["id"]
    return camera


def opensfm_shot_from_info(info: dict, projection: Projection) -> Shot:
    latlong = info["computed_geometry"]["coordinates"][::-1]
    alt = info["computed_altitude"]
    xyz = projection.project(np.array([*latlong, alt]), return_z=True)
    c_rotvec_w = np.array(info["computed_rotation"])
    pose = Pose()
    pose.set_from_cam_to_world(-c_rotvec_w, xyz)
    camera = opensfm_camera_from_info(info)
    return latlong, Shot(info["id"], camera, pose)
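A sketch of driving these helpers directly, assuming a valid Mapillary access token; this mirrors how process_location in prepare.py below consumes them:

import asyncio
from pathlib import Path

from maploc.data.mapillary.download import MapillaryDownloader, fetch_image_infos

downloader = MapillaryDownloader("MLY|...")  # placeholder token
infos_dir = Path("image_infos")
infos_dir.mkdir(exist_ok=True)
image_ids = [123456789]  # illustrative image ids

# Fetches (and caches on disk) the metadata of each image, rate-limited
# and with bounded concurrency via the module-level semaphore.
infos, num_fail = asyncio.get_event_loop().run_until_complete(
    fetch_image_infos(image_ids, downloader, infos_dir)
)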
maploc/data/mapillary/prepare.py
ADDED
@@ -0,0 +1,406 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import asyncio
import argparse
from collections import defaultdict
import json
import shutil
from pathlib import Path
from typing import List

import numpy as np
import cv2
from tqdm import tqdm
from tqdm.contrib.concurrent import thread_map
from omegaconf import DictConfig, OmegaConf
from opensfm.pygeometry import Camera
from opensfm.pymap import Shot
from opensfm.undistort import (
    perspective_camera_from_fisheye,
    perspective_camera_from_perspective,
)

from ... import logger
from ...osm.tiling import TileManager
from ...osm.viz import GeoPlotter
from ...utils.geo import BoundaryBox, Projection
from ...utils.io import write_json, download_file, DATA_URL
from ..utils import decompose_rotmat
from .utils import (
    keyframe_selection,
    perspective_camera_from_pano,
    scale_camera,
    CameraUndistorter,
    PanoramaUndistorter,
    undistort_shot,
)
from .download import (
    MapillaryDownloader,
    opensfm_shot_from_info,
    image_filename,
    fetch_image_infos,
    fetch_images_pixels,
)
from .dataset import MapillaryDataModule


location_to_params = {
    "sanfrancisco_soma": {
        "bbox": BoundaryBox(
            [-122.410307, 37.770364][::-1], [-122.388772, 37.795545][::-1]
        ),
        "camera_models": ["GoPro Max"],
        "osm_file": "sanfrancisco.osm",
    },
    "sanfrancisco_hayes": {
        "bbox": BoundaryBox(
            [-122.438415, 37.768634][::-1], [-122.410605, 37.783894][::-1]
        ),
        "camera_models": ["GoPro Max"],
        "osm_file": "sanfrancisco.osm",
    },
    "amsterdam": {
        "bbox": BoundaryBox([4.845284, 52.340679][::-1], [4.926147, 52.386299][::-1]),
        "camera_models": ["GoPro Max"],
        "osm_file": "amsterdam.osm",
    },
    "lemans": {
        "bbox": BoundaryBox([0.185752, 47.995125][::-1], [0.224088, 48.014209][::-1]),
        "owners": ["xXOocM1jUB4jaaeukKkmgw"],  # sogefi
        "osm_file": "lemans.osm",
    },
    "berlin": {
        "bbox": BoundaryBox([13.416271, 52.459656][::-1], [13.469829, 52.499195][::-1]),
        "owners": ["LT3ajUxH6qsosamrOHIrFw"],  # supaplex030
        "osm_file": "berlin.osm",
    },
    "montrouge": {
        "bbox": BoundaryBox([2.298958, 48.80874][::-1], [2.332989, 48.825276][::-1]),
        "owners": [
            "XtzGKZX2_VIJRoiJ8IWRNQ",
            "C4ENdWpJdFNf8CvnQd7NrQ",
            "e_ZBE6mFd7CYNjRSpLl-Lg",
        ],  # overflorian, phyks, francois2
        "camera_models": ["LG-R105"],
        "osm_file": "paris.osm",
    },
    "nantes": {
        "bbox": BoundaryBox([-1.585839, 47.198289][::-1], [-1.51318, 47.236161][::-1]),
        "owners": [
            "jGdq3CL-9N-Esvj3mtCWew",
            "s-j5BH9JRIzsgORgaJF3aA",
        ],  # c_mobilite, cartocite
        "osm_file": "nantes.osm",
    },
    "toulouse": {
        "bbox": BoundaryBox([1.429457, 43.591434][::-1], [1.456653, 43.61343][::-1]),
        "owners": ["MNkhq6MCoPsdQNGTMh3qsQ"],  # tyndare
        "osm_file": "toulouse.osm",
    },
    "vilnius": {
        "bbox": BoundaryBox([25.258633, 54.672956][::-1], [25.296094, 54.696755][::-1]),
        "owners": ["bClduFF6Gq16cfwCdhWivw", "u5ukBseATUS8jUbtE43fcO"],  # kedas, vms
        "osm_file": "vilnius.osm",
    },
    "helsinki": {
        "bbox": BoundaryBox(
            [24.8975480117, 60.1449128318][::-1], [24.9816543235, 60.1770977471][::-1]
        ),
        "camera_types": ["spherical", "equirectangular"],
        "osm_file": "helsinki.osm",
    },
    "milan": {
        "bbox": BoundaryBox(
            [9.1732723899, 45.4810977947][::-1],
            [9.2255987917, 45.5284238563][::-1],
        ),
        "camera_types": ["spherical", "equirectangular"],
        "osm_file": "milan.osm",
    },
    "avignon": {
        "bbox": BoundaryBox(
            [4.7887045302, 43.9416178156][::-1], [4.8227015622, 43.9584848909][::-1]
        ),
        "camera_types": ["spherical", "equirectangular"],
        "osm_file": "avignon.osm",
    },
    "paris": {
        "bbox": BoundaryBox([2.306823, 48.833827][::-1], [2.39067, 48.889335][::-1]),
        "camera_types": ["spherical", "equirectangular"],
        "osm_file": "paris.osm",
    },
}


cfg = OmegaConf.create(
    {
        "max_image_size": 512,
        "do_legacy_pano_offset": True,
        "min_dist_between_keyframes": 4,
        "tiling": {
            "tile_size": 128,
            "margin": 128,
            "ppm": 2,
        },
    }
)


def get_pano_offset(image_info: dict, do_legacy: bool = False) -> float:
    if do_legacy:
        seed = int(image_info["sfm_cluster"]["id"])
    else:
        seed = image_info["sequence"].__hash__()
    seed = seed % (2**32 - 1)
    return np.random.RandomState(seed).uniform(-45, 45)


def process_shot(
    shot: Shot, info: dict, image_path: Path, output_dir: Path, cfg: DictConfig
) -> List[Shot]:
    if not image_path.exists():
        return None

    image_orig = cv2.imread(str(image_path))
    max_size = cfg.max_image_size
    pano_offset = None

    camera = shot.camera
    camera.width, camera.height = image_orig.shape[:2][::-1]
    if camera.is_panorama(camera.projection_type):
        camera_new = perspective_camera_from_pano(camera, max_size)
        undistorter = PanoramaUndistorter(camera, camera_new)
        pano_offset = get_pano_offset(info, cfg.do_legacy_pano_offset)
    elif camera.projection_type in ["fisheye", "perspective"]:
        if camera.projection_type == "fisheye":
            camera_new = perspective_camera_from_fisheye(camera)
        else:
            camera_new = perspective_camera_from_perspective(camera)
        camera_new = scale_camera(camera_new, max_size)
        camera_new.id = camera.id + "_undistorted"
        undistorter = CameraUndistorter(camera, camera_new)
    else:
        raise NotImplementedError(camera.projection_type)

    shots_undist, images_undist = undistort_shot(
        image_orig, shot, undistorter, pano_offset
    )
    for shot, image in zip(shots_undist, images_undist):
        cv2.imwrite(str(output_dir / f"{shot.id}.jpg"), image)

    return shots_undist


def pack_shot_dict(shot: Shot, info: dict) -> dict:
    latlong = info["computed_geometry"]["coordinates"][::-1]
    latlong_gps = info["geometry"]["coordinates"][::-1]
    w_p_c = shot.pose.get_origin()
    w_r_c = shot.pose.get_R_cam_to_world()
    rpy = decompose_rotmat(w_r_c)
    return dict(
        camera_id=shot.camera.id,
        latlong=latlong,
        t_c2w=w_p_c,
        R_c2w=w_r_c,
        roll_pitch_yaw=rpy,
        capture_time=info["captured_at"],
        gps_position=np.r_[latlong_gps, info["altitude"]],
        compass_angle=info["compass_angle"],
        chunk_id=int(info["sfm_cluster"]["id"]),
    )


def pack_camera_dict(camera: Camera) -> dict:
    assert camera.projection_type == "perspective"
    K = camera.get_K_in_pixel_coordinates(camera.width, camera.height)
    return dict(
        id=camera.id,
        model="PINHOLE",
        width=camera.width,
        height=camera.height,
        params=K[[0, 1, 0, 1], [0, 1, 2, 2]],
    )


def process_sequence(
    image_ids: List[int],
    image_infos: dict,
    projection: Projection,
    cfg: DictConfig,
    raw_image_dir: Path,
    out_image_dir: Path,
):
    shots = []
    image_ids = sorted(image_ids, key=lambda i: image_infos[i]["captured_at"])
    for i in image_ids:
        _, shot = opensfm_shot_from_info(image_infos[i], projection)
        shots.append(shot)
    if not shots:
        return {}

    shot_idxs = keyframe_selection(shots, min_dist=cfg.min_dist_between_keyframes)
    shots = [shots[i] for i in shot_idxs]

    shots_out = thread_map(
        lambda shot: process_shot(
            shot,
            image_infos[int(shot.id)],
            raw_image_dir / image_filename.format(image_id=shot.id),
            out_image_dir,
            cfg,
        ),
        shots,
        disable=True,
    )
    shots_out = [(i, s) for i, ss in enumerate(shots_out) if ss is not None for s in ss]

    dump = {}
    for index, shot in shots_out:
        i, suffix = shot.id.rsplit("_", 1)
        info = image_infos[int(i)]
        seq_id = info["sequence"]
        is_pano = not suffix.endswith("undistorted")
        if is_pano:
            seq_id += f"_{suffix}"
        if seq_id not in dump:
            dump[seq_id] = dict(views={}, cameras={})

        view = pack_shot_dict(shot, info)
        view["index"] = index
        dump[seq_id]["views"][shot.id] = view
        dump[seq_id]["cameras"][shot.camera.id] = pack_camera_dict(shot.camera)
    return dump


def process_location(
    location: str,
    data_dir: Path,
    split_path: Path,
    token: str,
    generate_tiles: bool = False,
):
    params = location_to_params[location]
    bbox = params["bbox"]
    projection = Projection(*bbox.center)

    splits = json.loads(split_path.read_text())
    image_ids = [i for split in splits.values() for i in split[location]]

    loc_dir = data_dir / location
    infos_dir = loc_dir / "image_infos"
    raw_image_dir = loc_dir / "images_raw"
    out_image_dir = loc_dir / "images"
    for d in (infos_dir, raw_image_dir, out_image_dir):
        d.mkdir(parents=True, exist_ok=True)

    downloader = MapillaryDownloader(token)
    loop = asyncio.get_event_loop()

    logger.info("Fetching metadata for all images.")
    image_infos, num_fail = loop.run_until_complete(
        fetch_image_infos(image_ids, downloader, infos_dir)
    )
    logger.info("%d failures (%.1f%%).", num_fail, 100 * num_fail / len(image_ids))

    logger.info("Fetching image pixels.")
    image_urls = [(i, info["thumb_2048_url"]) for i, info in image_infos.items()]
    num_fail = loop.run_until_complete(
        fetch_images_pixels(image_urls, downloader, raw_image_dir)
    )
    logger.info("%d failures (%.1f%%).", num_fail, 100 * num_fail / len(image_urls))

    seq_to_image_ids = defaultdict(list)
    for i, info in image_infos.items():
        seq_to_image_ids[info["sequence"]].append(i)
    seq_to_image_ids = dict(seq_to_image_ids)

    dump = {}
    for seq_image_ids in tqdm(seq_to_image_ids.values()):
        dump.update(
            process_sequence(
                seq_image_ids,
                image_infos,
                projection,
                cfg,
                raw_image_dir,
                out_image_dir,
            )
        )
    write_json(loc_dir / "dump.json", dump)

    # Get the view locations
    view_ids = []
    views_latlon = []
    for seq in dump:
        for view_id, view in dump[seq]["views"].items():
            view_ids.append(view_id)
            views_latlon.append(view["latlong"])
    views_latlon = np.stack(views_latlon)
    view_ids = np.array(view_ids)
    views_xy = projection.project(views_latlon)

    tiles_path = loc_dir / MapillaryDataModule.default_cfg["tiles_filename"]
    # The tiling bounding box is also needed for the visualization below.
    bbox_data = BoundaryBox(views_xy.min(0), views_xy.max(0))
    bbox_tiling = bbox_data + cfg.tiling.margin
    if generate_tiles:
        logger.info("Creating the map tiles.")
        osm_dir = data_dir / "osm"
        osm_path = osm_dir / params["osm_file"]
        if not osm_path.exists():
            logger.info("Downloading OSM raw data.")
            download_file(DATA_URL + f"/osm/{params['osm_file']}", osm_path)
        if not osm_path.exists():
            raise FileNotFoundError(f"Cannot find OSM data file {osm_path}.")
        tile_manager = TileManager.from_bbox(
            projection,
            bbox_tiling,
            cfg.tiling.ppm,
            tile_size=cfg.tiling.tile_size,
            path=osm_path,
        )
        tile_manager.save(tiles_path)
    else:
        logger.info("Downloading pre-generated map tiles.")
        download_file(DATA_URL + f"/tiles/{location}.pkl", tiles_path)

    # Visualize the data split
    plotter = GeoPlotter()
    view_ids_val = set(splits["val"][location])
    is_val = np.array([int(i.rsplit("_", 1)[0]) in view_ids_val for i in view_ids])
    plotter.points(views_latlon[~is_val], "red", view_ids[~is_val], "train")
    plotter.points(views_latlon[is_val], "green", view_ids[is_val], "val")
    plotter.bbox(bbox, "blue", "query bounding box")
    plotter.bbox(projection.unproject(bbox_tiling), "black", "tiling bounding box")
    geo_viz_path = loc_dir / f"split_{location}.html"
    plotter.fig.write_html(geo_viz_path)
    logger.info("Wrote split visualization to %s.", geo_viz_path)

    shutil.rmtree(raw_image_dir)
    logger.info("Done processing for location %s.", location)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--locations", type=str, nargs="+", default=list(location_to_params)
    )
    parser.add_argument("--split_filename", type=str, default="splits_MGL_13loc.json")
    parser.add_argument("--token", type=str, required=True)
    parser.add_argument(
        "--data_dir", type=Path, default=MapillaryDataModule.default_cfg["data_dir"]
    )
    parser.add_argument("--generate_tiles", action="store_true")
    args = parser.parse_args()

    args.data_dir.mkdir(exist_ok=True, parents=True)
    shutil.copy(Path(__file__).parent / args.split_filename, args.data_dir)

    for location in args.locations:
        logger.info("Starting processing for location %s.", location)
        process_location(
            location,
            args.data_dir,
            args.data_dir / args.split_filename,
            args.token,
            args.generate_tiles,
        )
maploc/data/mapillary/splits_MGL_13loc.json
ADDED
The diff for this file is too large to render.
maploc/data/mapillary/utils.py
ADDED
@@ -0,0 +1,173 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.

import logging
from typing import List, Tuple

import cv2
import numpy as np
from opensfm import features
from opensfm.pygeometry import Camera, compute_camera_mapping, Pose
from opensfm.pymap import Shot
from scipy.spatial.transform import Rotation

logger = logging.getLogger(__name__)


def keyframe_selection(shots: List[Shot], min_dist: float = 4) -> List[int]:
    camera_centers = np.stack([shot.pose.get_origin() for shot in shots], 0)
    distances = np.linalg.norm(np.diff(camera_centers, axis=0), axis=1)
    selected = [0]
    cum = 0
    for i in range(1, len(camera_centers)):
        cum += distances[i - 1]
        if cum >= min_dist:
            selected.append(i)
            cum = 0
    return selected


def perspective_camera_from_pano(camera: Camera, size: int) -> Camera:
    camera_new = Camera.create_perspective(0.5, 0, 0)
    camera_new.height = camera_new.width = size
    camera_new.id = "perspective_from_pano"
    return camera_new


def scale_camera(camera: Camera, max_size: int) -> Camera:
    height = camera.height
    width = camera.width
    factor = max_size / float(max(height, width))
    if factor >= 1:
        return camera
    camera.width = int(round(width * factor))
    camera.height = int(round(height * factor))
    return camera


class PanoramaUndistorter:
    def __init__(self, camera_pano: Camera, camera_new: Camera):
        w, h = camera_new.width, camera_new.height
        self.shape = (h, w)

        dst_y, dst_x = np.indices(self.shape).astype(np.float32)
        dst_pixels_denormalized = np.column_stack([dst_x.ravel(), dst_y.ravel()])
        dst_pixels = features.normalized_image_coordinates(
            dst_pixels_denormalized, w, h
        )
        self.dst_bearings = camera_new.pixel_bearing_many(dst_pixels)

        self.camera_pano = camera_pano
        self.camera_perspective = camera_new

    def __call__(
        self, image: np.ndarray, panoshot: Shot, perspectiveshot: Shot
    ) -> np.ndarray:
        # Rotate to the panorama reference frame
        rotation = np.dot(
            panoshot.pose.get_rotation_matrix(),
            perspectiveshot.pose.get_rotation_matrix().T,
        )
        rotated_bearings = np.dot(self.dst_bearings, rotation.T)

        # Project to panorama pixels
        src_pixels = panoshot.camera.project_many(rotated_bearings)
        src_pixels_denormalized = features.denormalized_image_coordinates(
            src_pixels, image.shape[1], image.shape[0]
        )
        src_pixels_denormalized.shape = self.shape + (2,)

        # Sample colors
        x = src_pixels_denormalized[..., 0].astype(np.float32)
        y = src_pixels_denormalized[..., 1].astype(np.float32)
        colors = cv2.remap(image, x, y, cv2.INTER_LINEAR, borderMode=cv2.BORDER_WRAP)
        return colors


class CameraUndistorter:
    def __init__(self, camera_distorted: Camera, camera_new: Camera):
        self.maps = compute_camera_mapping(
            camera_distorted,
            camera_new,
            camera_distorted.width,
            camera_distorted.height,
        )
        self.camera_perspective = camera_new
        self.camera_distorted = camera_distorted

    def __call__(self, image: np.ndarray) -> np.ndarray:
        assert image.shape[:2] == (
            self.camera_distorted.height,
            self.camera_distorted.width,
        )
        undistorted = cv2.remap(image, *self.maps, cv2.INTER_LINEAR)
        resized = cv2.resize(
            undistorted,
            (self.camera_perspective.width, self.camera_perspective.height),
            interpolation=cv2.INTER_AREA,
        )
        return resized


def render_panorama(
    shot: Shot,
    pano: np.ndarray,
    undistorter: PanoramaUndistorter,
    offset: float = 0.0,
) -> Tuple[List[Shot], List[np.ndarray]]:
    yaws = [0, 90, 180, 270]
    suffixes = ["front", "left", "back", "right"]
    images = []
    shots = []

    # To reduce aliasing, since cv2.remap does not support area sampling,
    # we first resize with anti-aliasing.
    h, w = undistorter.shape
    h, w = (w * 2, w * 4)  # assuming 90deg FOV
    pano_resized = cv2.resize(pano, (w, h), interpolation=cv2.INTER_AREA)

    for yaw, suffix in zip(yaws, suffixes):
        R_pano2persp = Rotation.from_euler("Y", yaw + offset, degrees=True).as_matrix()
        name = f"{shot.id}_{suffix}"
        shot_new = Shot(
            name,
            undistorter.camera_perspective,
            Pose.compose(Pose(R_pano2persp), shot.pose),
        )
        shot_new.metadata = shot.metadata
        perspective = undistorter(pano_resized, shot, shot_new)
        images.append(perspective)
        shots.append(shot_new)
    return shots, images


def undistort_camera(
    shot: Shot, image: np.ndarray, undistorter: CameraUndistorter
) -> Tuple[Shot, np.ndarray]:
    name = f"{shot.id}_undistorted"
    shot_out = Shot(name, undistorter.camera_perspective, shot.pose)
    shot_out.metadata = shot.metadata
    undistorted = undistorter(image)
    return shot_out, undistorted


def undistort_shot(
    image_raw: np.ndarray,
    shot_orig: Shot,
    undistorter,
    pano_offset: float,
) -> Tuple[List[Shot], List[np.ndarray]]:
    camera = shot_orig.camera
    if image_raw.shape[:2] != (camera.height, camera.width):
        raise ValueError(
            shot_orig.id, image_raw.shape[:2], (camera.height, camera.width)
        )
    if camera.is_panorama(camera.projection_type):
        shots, undistorted = render_panorama(
            shot_orig, image_raw, undistorter, offset=pano_offset
        )
    elif camera.projection_type in ("perspective", "fisheye"):
        shot, undistorted = undistort_camera(shot_orig, image_raw, undistorter)
        shots, undistorted = [shot], [undistorted]
    else:
        raise NotImplementedError(camera.projection_type)
    return shots, undistorted
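
The undistorters are built once per camera and then applied per frame. A minimal usage sketch, assuming an OpenSfM reconstruction `rec` and an image `img` already loaded for one of its panoramic shots (these names are illustrative, not part of this file):

# Hypothetical usage sketch; `rec`, `img`, and the shot id are assumptions.
shot = rec.shots["some_image_id"]
camera_new = perspective_camera_from_pano(shot.camera, size=512)
undistorter = PanoramaUndistorter(shot.camera, camera_new)
shots, crops = undistort_shot(img, shot, undistorter, pano_offset=0.0)
# `shots` holds four new Shot objects (front/left/back/right) and
# `crops` the corresponding 90-degree perspective images.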
maploc/data/sequential.py
ADDED
@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import numpy as np
import torch


def chunk_sequence(
    data,
    indices,
    *,
    names=None,
    max_length=100,
    min_length=1,
    max_delay_s=None,
    max_inter_dist=None,
    max_total_dist=None,
):
    sort_array = data.get("capture_time", data.get("index", names or indices))
    indices = sorted(indices, key=lambda i: sort_array[i].tolist())
    centers = torch.stack([data["t_c2w"][i][:2] for i in indices]).numpy()
    dists = np.linalg.norm(np.diff(centers, axis=0), axis=-1)
    if "capture_time" in data:
        times = torch.stack([data["capture_time"][i] for i in indices])
        times = times.double() / 1e3  # ms to s
        delays = np.diff(times, axis=0)
    else:
        delays = np.zeros_like(dists)
    chunks = [[indices[0]]]
    dist_total = 0
    for dist, delay, idx in zip(dists, delays, indices[1:]):
        dist_total += dist
        if (
            (max_inter_dist is not None and dist > max_inter_dist)
            or (max_total_dist is not None and dist_total > max_total_dist)
            or (max_delay_s is not None and delay > max_delay_s)
            or len(chunks[-1]) >= max_length
        ):
            chunks.append([])
            dist_total = 0
        chunks[-1].append(idx)
    chunks = list(filter(lambda c: len(c) >= min_length, chunks))
    chunks = sorted(chunks, key=len, reverse=True)
    return chunks


def unpack_batches(batches):
    images = [b["image"].permute(1, 2, 0) for b in batches]
    canvas = [b["canvas"] for b in batches]
    rasters = [b["map"] for b in batches]
    yaws = torch.stack([b["roll_pitch_yaw"][-1] for b in batches])
    uv_gt = torch.stack([b["uv"] for b in batches])
    xy_gt = torch.stack(
        [canv.to_xy(uv.cpu().double()) for uv, canv in zip(uv_gt, canvas)]
    )
    ret = [images, canvas, rasters, yaws, uv_gt, xy_gt.to(uv_gt)]
    if "uv_gps" in batches[0]:
        xy_gps = torch.stack(
            [c.to_xy(b["uv_gps"].cpu().double()) for b, c in zip(batches, canvas)]
        )
        ret.append(xy_gps.to(uv_gt))
    return ret
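
A small, self-contained sketch of how `chunk_sequence` splits a trajectory; the toy tensors below are assumptions for illustration only:

import torch

# Toy trajectory: 5 frames on a line, 3 m apart, one per second.
data = {
    "t_c2w": [torch.tensor([3.0 * i, 0.0, 0.0]) for i in range(5)],
    "capture_time": [torch.tensor(1000 * i) for i in range(5)],  # in ms
}
chunks = chunk_sequence(data, list(range(5)), max_inter_dist=5, max_length=3)
# Consecutive distances (3 m) stay below max_inter_dist, so only the
# length limit splits the sequence: chunks == [[0, 1, 2], [3, 4]],
# sorted with the longest chunk first.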
maploc/data/torch.py
ADDED
@@ -0,0 +1,111 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import collections
import os

import torch
from torch.utils.data import get_worker_info
from torch.utils.data._utils.collate import (
    default_collate_err_msg_format,
    np_str_obj_array_pattern,
)
from lightning_fabric.utilities.seed import pl_worker_init_function
from lightning_utilities.core.apply_func import apply_to_collection
from lightning_fabric.utilities.apply_func import move_data_to_device


def collate(batch):
    """Difference with PyTorch default_collate: it can stack other tensor-like objects.
    Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich
    https://github.com/cvg/pixloc
    Released under the Apache License 2.0
    """
    if not isinstance(batch, list):  # no batching
        return batch
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel, device=elem.device)
            out = elem.new(storage).resize_(len(batch), *list(elem.size()))
        return torch.stack(batch, 0, out=out)
    elif (
        elem_type.__module__ == "numpy"
        and elem_type.__name__ != "str_"
        and elem_type.__name__ != "string_"
    ):
        if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap":
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, (str, bytes)):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
        return elem_type(*(collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError("each element in list of batch should be of equal size")
        transposed = zip(*batch)
        return [collate(samples) for samples in transposed]
    else:
        # try to stack anyway in case the object implements stacking.
        try:
            return torch.stack(batch, 0)
        except TypeError as e:
            if "expected Tensor as element" in str(e):
                return batch
            else:
                raise e


def set_num_threads(nt):
    """Force numpy and other libraries to use a limited number of threads."""
    try:
        import mkl
    except ImportError:
        pass
    else:
        mkl.set_num_threads(nt)
    torch.set_num_threads(1)
    os.environ["IPC_ENABLE"] = "1"
    for o in [
        "OPENBLAS_NUM_THREADS",
        "NUMEXPR_NUM_THREADS",
        "OMP_NUM_THREADS",
        "MKL_NUM_THREADS",
    ]:
        os.environ[o] = str(nt)


def worker_init_fn(i):
    info = get_worker_info()
    pl_worker_init_function(info.id)
    num_threads = info.dataset.cfg.get("num_threads")
    if num_threads is not None:
        set_num_threads(num_threads)


def unbatch_to_device(data, device="cpu"):
    data = move_data_to_device(data, device)
    data = apply_to_collection(data, torch.Tensor, lambda x: x.squeeze(0))
    data = apply_to_collection(
        data, list, lambda x: x[0] if len(x) == 1 and isinstance(x[0], str) else x
    )
    return data
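
A quick sketch of what this `collate` adds over the PyTorch default: nested dictionaries are collated recursively and string fields pass through as lists (toy data, for illustration only):

import torch

batch = [
    {"image": torch.rand(3, 4, 4), "name": "a", "meta": {"scale": 1.0}},
    {"image": torch.rand(3, 4, 4), "name": "b", "meta": {"scale": 2.0}},
]
out = collate(batch)
# out["image"].shape == (2, 3, 4, 4)
# out["name"] == ["a", "b"]           (strings are kept as a list)
# out["meta"]["scale"] == tensor([1., 2.], dtype=torch.float64)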
maploc/data/utils.py
ADDED
@@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import numpy as np
from scipy.spatial.transform import Rotation


def crop_map(raster, xy, size, seed=None):
    h, w = raster.shape[-2:]
    state = np.random.RandomState(seed)
    top = state.randint(0, h - size + 1)
    left = state.randint(0, w - size + 1)
    raster = raster[..., top : top + size, left : left + size]
    xy -= np.array([left, top])
    return raster, xy


def random_rot90(raster, xy, heading, seed=None):
    rot = np.random.RandomState(seed).randint(0, 4)
    heading = (heading + rot * np.pi / 2) % (2 * np.pi)
    h, w = raster.shape[-2:]
    if rot == 0:
        xy2 = xy
    elif rot == 2:
        xy2 = np.array([w, h]) - 1 - xy
    elif rot == 1:
        xy2 = np.array([xy[1], w - 1 - xy[0]])
    elif rot == 3:
        xy2 = np.array([h - 1 - xy[1], xy[0]])
    else:
        raise ValueError(rot)
    raster = np.rot90(raster, rot, axes=(-2, -1))
    return raster, xy2, heading


def random_flip(image, raster, xy, heading, seed=None):
    state = np.random.RandomState(seed)
    if state.rand() > 0.5:  # no flip
        return image, raster, xy, heading
    image = image[:, ::-1]
    h, w = raster.shape[-2:]
    if state.rand() > 0.5:  # flip x
        raster = raster[..., :, ::-1]
        xy = np.array([w - 1 - xy[0], xy[1]])
        heading = np.pi - heading
    else:  # flip y
        raster = raster[..., ::-1, :]
        xy = np.array([xy[0], h - 1 - xy[1]])
        heading = -heading
    heading = heading % (2 * np.pi)
    return image, raster, xy, heading


def decompose_rotmat(R_c2w):
    R_cv2xyz = Rotation.from_euler("X", -90, degrees=True)
    rot_w2c = R_cv2xyz * Rotation.from_matrix(R_c2w).inv()
    roll, pitch, yaw = rot_w2c.as_euler("YXZ", degrees=True)
    # rot_w2c_check = R_cv2xyz.inv() * Rotation.from_euler('YXZ', [roll, pitch, yaw], degrees=True)
    # np.testing.assert_allclose(rot_w2c_check.as_matrix(), R_c2w.T, rtol=1e-6, atol=1e-6)
    # R_plane2c = Rotation.from_euler("ZX", [roll, pitch], degrees=True).as_matrix()
    return roll, pitch, yaw
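
These augmentations keep the raster, the ground-truth position `xy`, and the heading consistent with each other. A minimal sketch on a toy raster (the array sizes are assumptions):

import numpy as np

raster = np.zeros((1, 8, 8))  # (channels, h, w)
xy = np.array([2.0, 1.0])  # (x, y) in pixels
raster2, xy2, heading2 = random_rot90(raster, xy, heading=0.0, seed=0)
# The point is rotated together with the raster and the heading is
# shifted by the same multiple of 90 degrees, so the triple stays
# geometrically consistent under the augmentation.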
maploc/demo.py
ADDED
@@ -0,0 +1,209 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

from typing import Optional, Tuple

import torch
import numpy as np

from . import logger
from .evaluation.run import resolve_checkpoint_path, pretrained_models
from .models.orienternet import OrienterNet
from .models.voting import fuse_gps, argmax_xyr
from .data.image import resize_image, pad_image, rectify_image
from .osm.raster import Canvas
from .utils.wrappers import Camera
from .utils.io import read_image
from .utils.geo import BoundaryBox, Projection
from .utils.exif import EXIF

try:
    from geopy.geocoders import Nominatim

    geolocator = Nominatim(user_agent="orienternet")
except ImportError:
    geolocator = None

try:
    from gradio_client import Client

    calibrator = Client("https://jinlinyi-perspectivefields.hf.space/")
except (ImportError, ValueError):
    calibrator = None


def image_calibration(image_path):
    logger.info("Calling the PerspectiveFields calibrator, this may take some time.")
    result = calibrator.predict(
        image_path, "NEW:Paramnet-360Cities-edina-centered", api_name="/predict"
    )
    result = dict(r.rsplit(" ", 1) for r in result[1].split("\n"))
    roll_pitch = float(result["roll"]), float(result["pitch"])
    return roll_pitch, float(result["vertical fov"])


def camera_from_exif(exif: EXIF, fov: Optional[float] = None) -> Camera:
    w, h = image_size = exif.extract_image_size()
    _, f_ratio = exif.extract_focal()
    if f_ratio == 0:
        if fov is not None:
            # This is the vertical FoV.
            f = h / 2 / np.tan(np.deg2rad(fov) / 2)
        else:
            return None
    else:
        f = f_ratio * max(image_size)
    return Camera.from_dict(
        dict(
            model="SIMPLE_PINHOLE",
            width=w,
            height=h,
            params=[f, w / 2 + 0.5, h / 2 + 0.5],
        )
    )


def read_input_image(
    image_path: str,
    prior_latlon: Optional[Tuple[float, float]] = None,
    prior_address: Optional[str] = None,
    fov: Optional[float] = None,
    tile_size_meters: int = 64,
):
    image = read_image(image_path)

    roll_pitch = None
    if calibrator is not None:
        roll_pitch, fov = image_calibration(image_path)
    else:
        logger.info("Could not call PerspectiveFields, maybe install gradio_client?")
    if roll_pitch is not None:
        logger.info("Using (roll, pitch) %s.", roll_pitch)

    with open(image_path, "rb") as fid:
        exif = EXIF(fid, lambda: image.shape[:2])
    camera = camera_from_exif(exif, fov)
    if camera is None:
        raise ValueError(
            "No camera intrinsics found in the EXIF, provide an FoV guess."
        )

    latlon = None
    if prior_latlon is not None:
        latlon = prior_latlon
        logger.info("Using prior latlon %s.", prior_latlon)
    if prior_address is not None:
        if geolocator is None:
            raise ValueError("Geocoding unavailable, install geopy.")
        location = geolocator.geocode(prior_address)
        if location is None:
            logger.info("Could not find any location for %s.", prior_address)
        else:
            logger.info("Using prior address: %s", location.address)
            latlon = (location.latitude, location.longitude)
    if latlon is None:
        geo = exif.extract_geo()
        if geo:
            alt = geo.get("altitude", 0)  # read if available
            latlon = (geo["latitude"], geo["longitude"], alt)
            logger.info("Using prior location from EXIF.")
        else:
            logger.info("Could not find any prior location in EXIF.")
    if latlon is None:
        raise ValueError(
            "Need a location prior: provide prior_latlon or prior_address."
        )
    latlon = np.array(latlon)

    proj = Projection(*latlon)
    center = proj.project(latlon)
    bbox = BoundaryBox(center, center) + tile_size_meters
    return image, camera, roll_pitch, proj, bbox, latlon


class Demo:
    def __init__(
        self,
        experiment_or_path: Optional[str] = "OrienterNet_MGL",
        device=None,
        **kwargs,
    ):
        if experiment_or_path in pretrained_models:
            experiment_or_path, _ = pretrained_models[experiment_or_path]
        path = resolve_checkpoint_path(experiment_or_path)
        ckpt = torch.load(path, map_location=(lambda storage, loc: storage))
        config = ckpt["hyper_parameters"]
        config.model.update(kwargs)
        config.model.image_encoder.backbone.pretrained = False

        model = OrienterNet(config.model).eval()
        state = {k[len("model.") :]: v for k, v in ckpt["state_dict"].items()}
        model.load_state_dict(state, strict=True)
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)

        self.model = model
        self.config = config
        self.device = device

    def prepare_data(
        self,
        image: np.ndarray,
        camera: Camera,
        canvas: Canvas,
        roll_pitch: Optional[Tuple[float]] = None,
    ):
        assert image.shape[:2][::-1] == tuple(camera.size.tolist())
        target_focal_length = self.config.data.resize_image / 2
        factor = target_focal_length / camera.f
        size = (camera.size * factor).round().int()

        image = torch.from_numpy(image).permute(2, 0, 1).float().div_(255)
        valid = None
        if roll_pitch is not None:
            roll, pitch = roll_pitch
            image, valid = rectify_image(
                image,
                camera.float(),
                roll=-roll,
                pitch=-pitch,
            )
        image, _, camera, *maybe_valid = resize_image(
            image, size.numpy(), camera=camera, valid=valid
        )
        valid = None if valid is None else maybe_valid

        max_stride = max(self.model.image_encoder.layer_strides)
        size = (np.ceil((size / max_stride)) * max_stride).int()
        image, valid, camera = pad_image(
            image, size.numpy(), camera, crop_and_center=True
        )

        return dict(
            image=image,
            map=torch.from_numpy(canvas.raster).long(),
            camera=camera.float(),
            valid=valid,
        )

    def localize(self, image: np.ndarray, camera: Camera, canvas: Canvas, **kwargs):
        data = self.prepare_data(image, camera, canvas, **kwargs)
        data_ = {k: v.to(self.device)[None] for k, v in data.items()}
        with torch.no_grad():
            pred = self.model(data_)

        xy_gps = canvas.bbox.center
        uv_gps = torch.from_numpy(canvas.to_uv(xy_gps))

        lp_xyr = pred["log_probs"].squeeze(0)
        tile_size = canvas.bbox.size.min() / 2
        sigma = tile_size - 20  # 20 meters margin
        lp_xyr = fuse_gps(
            lp_xyr,
            uv_gps.to(lp_xyr),
            self.config.model.pixel_per_meter,
            sigma=sigma,
        )
        xyr = argmax_xyr(lp_xyr).cpu()

        prob = lp_xyr.exp().cpu()
        neural_map = pred["map"]["map_features"][0].squeeze(0).cpu()
        # xyr is (x, y, yaw): return the position and the yaw angle.
        return xyr[:2], xyr[2], prob, neural_map, data["image"]
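
Taken together, a hedged end-to-end sketch of this module; the query path, the prior address, and the `canvas` rasterization are placeholders (in this repo the canvas would come from the OSM tooling under `maploc.osm`):

demo = Demo(experiment_or_path="OrienterNet_MGL")
image, camera, roll_pitch, proj, bbox, latlon = read_input_image(
    "query.jpg",  # placeholder path
    prior_address="Zurich Hauptbahnhof",  # placeholder prior
)
# Assumption: `canvas` is a maploc.osm.raster.Canvas covering `bbox`;
# its construction from OpenStreetMap data is omitted here.
uv, yaw, prob, neural_map, image_rectified = demo.localize(
    image, camera, canvas, roll_pitch=roll_pitch
)
xy = canvas.to_xy(uv.double())  # position in the local metric frame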
maploc/evaluation/kitti.py
ADDED
@@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import argparse
from pathlib import Path
from typing import Optional, Tuple

from omegaconf import OmegaConf, DictConfig

from .. import logger
from ..data import KittiDataModule
from .run import evaluate


default_cfg_single = OmegaConf.create({})
# For the sequential evaluation, we need to center the map around the GT location,
# since random offsets would accumulate and leave only the GT location with a valid mask.
# This should not have much impact on the results.
default_cfg_sequential = OmegaConf.create(
    {
        "data": {
            "mask_radius": KittiDataModule.default_cfg["max_init_error"],
            "prior_range_rotation": KittiDataModule.default_cfg[
                "max_init_error_rotation"
            ]
            + 1,
            "max_init_error": 0,
            "max_init_error_rotation": 0,
        },
        "chunking": {
            "max_length": 100,  # about 10s?
        },
    }
)


def run(
    split: str,
    experiment: str,
    cfg: Optional[DictConfig] = None,
    sequential: bool = False,
    thresholds: Tuple[int] = (1, 3, 5),
    **kwargs,
):
    cfg = cfg or {}
    if isinstance(cfg, dict):
        cfg = OmegaConf.create(cfg)
    default = default_cfg_sequential if sequential else default_cfg_single
    cfg = OmegaConf.merge(default, cfg)
    dataset = KittiDataModule(cfg.get("data", {}))

    metrics = evaluate(
        experiment,
        cfg,
        dataset,
        split=split,
        sequential=sequential,
        viz_kwargs=dict(show_dir_error=True, show_masked_prob=False),
        **kwargs,
    )

    keys = ["directional_error", "yaw_max_error"]
    if sequential:
        keys += ["directional_seq_error", "yaw_seq_error"]
    for k in keys:
        rec = metrics[k].recall(thresholds).double().numpy().round(2).tolist()
        logger.info("Recall %s: %s at %s m/°", k, rec, thresholds)
    return metrics


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--experiment", type=str, required=True)
    parser.add_argument(
        "--split", type=str, default="test", choices=["test", "val", "train"]
    )
    parser.add_argument("--sequential", action="store_true")
    parser.add_argument("--output_dir", type=Path)
    parser.add_argument("--num", type=int)
    parser.add_argument("dotlist", nargs="*")
    args = parser.parse_args()
    cfg = OmegaConf.from_cli(args.dotlist)
    run(
        args.split,
        args.experiment,
        cfg,
        args.sequential,
        output_dir=args.output_dir,
        num=args.num,
    )
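
The `run` entry point can also be called programmatically; a sketch with the pretrained model registered in evaluation/run.py (any trained experiment name would work the same way):

from omegaconf import OmegaConf

metrics = run(
    "test",
    "OrienterNet_MGL",
    cfg=OmegaConf.create({}),
    num=100,  # evaluate only the first 100 images
)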
maploc/evaluation/mapillary.py
ADDED
@@ -0,0 +1,111 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import argparse
from pathlib import Path
from typing import Optional, Tuple

from omegaconf import OmegaConf, DictConfig

from .. import logger
from ..conf import data as conf_data_dir
from ..data import MapillaryDataModule
from .run import evaluate


split_overrides = {
    "val": {
        "scenes": [
            "sanfrancisco_soma",
            "sanfrancisco_hayes",
            "amsterdam",
            "berlin",
            "lemans",
            "montrouge",
            "toulouse",
            "nantes",
            "vilnius",
            "avignon",
            "helsinki",
            "milan",
            "paris",
        ],
    },
}
data_cfg_train = OmegaConf.load(Path(conf_data_dir.__file__).parent / "mapillary.yaml")
data_cfg = OmegaConf.merge(
    data_cfg_train,
    {
        "return_gps": True,
        "add_map_mask": True,
        "max_init_error": 32,
        "loading": {"val": {"batch_size": 1, "num_workers": 0}},
    },
)
default_cfg_single = OmegaConf.create({"data": data_cfg})
default_cfg_sequential = OmegaConf.create(
    {
        **default_cfg_single,
        "chunking": {
            "max_length": 10,
        },
    }
)


def run(
    split: str,
    experiment: str,
    cfg: Optional[DictConfig] = None,
    sequential: bool = False,
    thresholds: Tuple[int] = (1, 3, 5),
    **kwargs,
):
    cfg = cfg or {}
    if isinstance(cfg, dict):
        cfg = OmegaConf.create(cfg)
    default = default_cfg_sequential if sequential else default_cfg_single
    default = OmegaConf.merge(default, split_overrides[split])
    cfg = OmegaConf.merge(default, cfg)
    dataset = MapillaryDataModule(cfg.get("data", {}))

    metrics = evaluate(experiment, cfg, dataset, split, sequential=sequential, **kwargs)

    keys = [
        "xy_max_error",
        "xy_gps_error",
        "yaw_max_error",
    ]
    if sequential:
        keys += [
            "xy_seq_error",
            "xy_gps_seq_error",
            "yaw_seq_error",
            "yaw_gps_seq_error",
        ]
    for k in keys:
        if k not in metrics:
            logger.warning("Key %s not in metrics.", k)
            continue
        rec = metrics[k].recall(thresholds).double().numpy().round(2).tolist()
        logger.info("Recall %s: %s at %s m/°", k, rec, thresholds)
    return metrics


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--experiment", type=str, required=True)
    parser.add_argument("--split", type=str, default="val", choices=["val"])
    parser.add_argument("--sequential", action="store_true")
    parser.add_argument("--output_dir", type=Path)
    parser.add_argument("--num", type=int)
    parser.add_argument("dotlist", nargs="*")
    args = parser.parse_args()
    cfg = OmegaConf.from_cli(args.dotlist)
    run(
        args.split,
        args.experiment,
        cfg,
        args.sequential,
        output_dir=args.output_dir,
        num=args.num,
    )
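
From the command line, the trailing dotlist arguments override the merged defaults above; a sketch of the same override done in Python:

from omegaconf import OmegaConf

# Equivalent to passing `data.loading.val.num_workers=4` as a CLI dotlist.
overrides = OmegaConf.from_dotlist(["data.loading.val.num_workers=4"])
metrics = run("val", "OrienterNet_MGL", cfg=overrides)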
maploc/evaluation/run.py
ADDED
@@ -0,0 +1,252 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import functools
from itertools import islice
from typing import Callable, Dict, Optional, Tuple
from pathlib import Path

import numpy as np
import torch
from omegaconf import DictConfig, OmegaConf
from torchmetrics import MetricCollection
from pytorch_lightning import seed_everything
from tqdm import tqdm

from .. import logger, EXPERIMENTS_PATH
from ..data.torch import collate, unbatch_to_device
from ..models.voting import argmax_xyr, fuse_gps
from ..models.metrics import AngleError, LateralLongitudinalError, Location2DError
from ..models.sequential import GPSAligner, RigidAligner
from ..module import GenericModule
from ..utils.io import download_file, DATA_URL
from .viz import plot_example_single, plot_example_sequential
from .utils import write_dump


pretrained_models = dict(
    OrienterNet_MGL=("orienternet_mgl.ckpt", dict(num_rotations=256)),
)


def resolve_checkpoint_path(experiment_or_path: str) -> Path:
    path = Path(experiment_or_path)
    if not path.exists():
        # provided the name of an experiment
        path = Path(EXPERIMENTS_PATH, *experiment_or_path.split("/"))
        if not path.exists():
            if experiment_or_path in set(p for p, _ in pretrained_models.values()):
                download_file(f"{DATA_URL}/{experiment_or_path}", path)
            else:
                raise FileNotFoundError(path)
    if path.is_file():
        return path
    # provided only the experiment name
    maybe_path = path / "last-step.ckpt"
    if not maybe_path.exists():
        maybe_path = path / "step.ckpt"
    if not maybe_path.exists():
        raise FileNotFoundError(f"Could not find any checkpoint in {path}.")
    return maybe_path


@torch.no_grad()
def evaluate_single_image(
    dataloader: torch.utils.data.DataLoader,
    model: GenericModule,
    num: Optional[int] = None,
    callback: Optional[Callable] = None,
    progress: bool = True,
    mask_index: Optional[Tuple[int]] = None,
    has_gps: bool = False,
):
    ppm = model.model.conf.pixel_per_meter
    metrics = MetricCollection(model.model.metrics())
    metrics["directional_error"] = LateralLongitudinalError(ppm)
    if has_gps:
        metrics["xy_gps_error"] = Location2DError("uv_gps", ppm)
        metrics["xy_fused_error"] = Location2DError("uv_fused", ppm)
        metrics["yaw_fused_error"] = AngleError("yaw_fused")
    metrics = metrics.to(model.device)

    for i, batch_ in enumerate(
        islice(tqdm(dataloader, total=num, disable=not progress), num)
    ):
        batch = model.transfer_batch_to_device(batch_, model.device, i)
        # Ablation: mask semantic classes
        if mask_index is not None:
            mask = batch["map"][0, mask_index[0]] == (mask_index[1] + 1)
            batch["map"][0, mask_index[0]][mask] = 0
        pred = model(batch)

        if has_gps:
            (uv_gps,) = pred["uv_gps"] = batch["uv_gps"]
            pred["log_probs_fused"] = fuse_gps(
                pred["log_probs"], uv_gps, ppm, sigma=batch["accuracy_gps"]
            )
            uvt_fused = argmax_xyr(pred["log_probs_fused"])
            pred["uv_fused"] = uvt_fused[..., :2]
            pred["yaw_fused"] = uvt_fused[..., -1]
            del uv_gps, uvt_fused

        results = metrics(pred, batch)
        if callback is not None:
            callback(
                i, model, unbatch_to_device(pred), unbatch_to_device(batch_), results
            )
        del batch_, batch, pred, results

    return metrics.cpu()


@torch.no_grad()
def evaluate_sequential(
    dataset: torch.utils.data.Dataset,
    chunk2idx: Dict,
    model: GenericModule,
    num: Optional[int] = None,
    shuffle: bool = False,
    callback: Optional[Callable] = None,
    progress: bool = True,
    num_rotations: int = 512,
    mask_index: Optional[Tuple[int]] = None,
    has_gps: bool = True,
):
    chunk_keys = list(chunk2idx)
    if shuffle:
        chunk_keys = [chunk_keys[i] for i in torch.randperm(len(chunk_keys))]
    if num is not None:
        chunk_keys = chunk_keys[:num]
    lengths = [len(chunk2idx[k]) for k in chunk_keys]
    logger.info(
        "Min/med/max lengths: %d/%d/%d, total number of images: %d",
        min(lengths),
        np.median(lengths),
        max(lengths),
        sum(lengths),
    )
    viz = callback is not None

    metrics = MetricCollection(model.model.metrics())
    ppm = model.model.conf.pixel_per_meter
    metrics["directional_error"] = LateralLongitudinalError(ppm)
    metrics["xy_seq_error"] = Location2DError("uv_seq", ppm)
    metrics["yaw_seq_error"] = AngleError("yaw_seq")
    metrics["directional_seq_error"] = LateralLongitudinalError(ppm, key="uv_seq")
    if has_gps:
        metrics["xy_gps_error"] = Location2DError("uv_gps", ppm)
        metrics["xy_gps_seq_error"] = Location2DError("uv_gps_seq", ppm)
        metrics["yaw_gps_seq_error"] = AngleError("yaw_gps_seq")
    metrics = metrics.to(model.device)

    keys_save = ["uvr_max", "uv_max", "yaw_max", "uv_expectation"]
    if has_gps:
        keys_save.append("uv_gps")
    if viz:
        keys_save.append("log_probs")

    for chunk_index, key in enumerate(tqdm(chunk_keys, disable=not progress)):
        indices = chunk2idx[key]
        aligner = RigidAligner(track_priors=viz, num_rotations=num_rotations)
        if has_gps:
            aligner_gps = GPSAligner(track_priors=viz, num_rotations=num_rotations)
        batches = []
        preds = []
        for i in indices:
            data = dataset[i]
            data = model.transfer_batch_to_device(data, model.device, 0)
            pred = model(collate([data]))

            canvas = data["canvas"]
            data["xy_geo"] = xy = canvas.to_xy(data["uv"].double())
            data["yaw"] = yaw = data["roll_pitch_yaw"][-1].double()
            aligner.update(pred["log_probs"][0], canvas, xy, yaw)

            if has_gps:
                uv_gps = pred["uv_gps"] = data["uv_gps"][None]
                xy_gps = canvas.to_xy(uv_gps.double())
                aligner_gps.update(xy_gps, data["accuracy_gps"], canvas, xy, yaw)

            if not viz:
                data.pop("image")
                data.pop("map")
            batches.append(data)
            preds.append({k: pred[k][0] for k in keys_save})
            del pred

        xy_gt = torch.stack([b["xy_geo"] for b in batches])
        yaw_gt = torch.stack([b["yaw"] for b in batches])
        aligner.compute()
        xy_seq, yaw_seq = aligner.transform(xy_gt, yaw_gt)
        if has_gps:
            aligner_gps.compute()
            xy_gps_seq, yaw_gps_seq = aligner_gps.transform(xy_gt, yaw_gt)
        results = []
        for i in range(len(indices)):
            preds[i]["uv_seq"] = batches[i]["canvas"].to_uv(xy_seq[i]).float()
            preds[i]["yaw_seq"] = yaw_seq[i].float()
            if has_gps:
                preds[i]["uv_gps_seq"] = (
                    batches[i]["canvas"].to_uv(xy_gps_seq[i]).float()
                )
                preds[i]["yaw_gps_seq"] = yaw_gps_seq[i].float()
            results.append(metrics(preds[i], batches[i]))
        if viz:
            callback(chunk_index, model, batches, preds, results, aligner)
        del aligner, preds, batches, results
    return metrics.cpu()


def evaluate(
    experiment: str,
    cfg: DictConfig,
    dataset,
    split: str,
    sequential: bool = False,
    output_dir: Optional[Path] = None,
    callback: Optional[Callable] = None,
    num_workers: int = 1,
    viz_kwargs=None,
    **kwargs,
):
    if experiment in pretrained_models:
        experiment, cfg_override = pretrained_models[experiment]
        cfg = OmegaConf.merge(OmegaConf.create(dict(model=cfg_override)), cfg)

    logger.info("Evaluating model %s with config %s", experiment, cfg)
    checkpoint_path = resolve_checkpoint_path(experiment)
    model = GenericModule.load_from_checkpoint(
        checkpoint_path, cfg=cfg, find_best=not experiment.endswith(".ckpt")
    )
    model = model.eval()
    if torch.cuda.is_available():
        model = model.cuda()

    dataset.prepare_data()
    dataset.setup()

    if output_dir is not None:
        output_dir.mkdir(exist_ok=True, parents=True)
        if callback is None:
            if sequential:
                callback = plot_example_sequential
            else:
                callback = plot_example_single
            callback = functools.partial(
                callback, out_dir=output_dir, **(viz_kwargs or {})
            )
    kwargs = {**kwargs, "callback": callback}

    seed_everything(dataset.cfg.seed)
    if sequential:
        dset, chunk2idx = dataset.sequence_dataset(split, **cfg.chunking)
        metrics = evaluate_sequential(dset, chunk2idx, model, **kwargs)
    else:
        loader = dataset.dataloader(split, shuffle=True, num_workers=num_workers)
        metrics = evaluate_single_image(loader, model, **kwargs)

    results = metrics.compute()
    logger.info("All results: %s", results)
    if output_dir is not None:
        write_dump(output_dir, experiment, cfg, results, metrics)
        logger.info("Outputs have been written to %s.", output_dir)
    return metrics
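
`evaluate` accepts any callable in place of the default plotting callbacks. A sketch of a minimal callback that only prints per-example errors, assuming `dataset` is an already-constructed data module such as the ones built in evaluation/kitti.py or evaluation/mapillary.py:

from omegaconf import OmegaConf


def print_errors(i, model, pred, data, results):
    # Called once per example with unbatched predictions and inputs.
    print(i, float(results["xy_max_error"]))


metrics = evaluate(
    "OrienterNet_MGL",
    OmegaConf.create({}),
    dataset,  # assumption: a data module instance
    split="val",
    callback=print_errors,
)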
maploc/evaluation/utils.py
ADDED
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import numpy as np
from omegaconf import OmegaConf

from ..utils.io import write_json


def compute_recall(errors):
    num_elements = len(errors)
    sort_idx = np.argsort(errors)
    errors = np.array(errors.copy())[sort_idx]
    recall = (np.arange(num_elements) + 1) / num_elements
    recall = np.r_[0, recall]
    errors = np.r_[0, errors]
    return errors, recall


def compute_auc(errors, recall, thresholds):
    aucs = []
    for t in thresholds:
        last_index = np.searchsorted(errors, t, side="right")
        r = np.r_[recall[:last_index], recall[last_index - 1]]
        e = np.r_[errors[:last_index], t]
        auc = np.trapz(r, x=e) / t
        aucs.append(auc * 100)
    return aucs


def write_dump(output_dir, experiment, cfg, results, metrics):
    dump = {
        "experiment": experiment,
        "cfg": OmegaConf.to_container(cfg),
        "results": results,
        "errors": {},
    }
    for k, m in metrics.items():
        if hasattr(m, "get_errors"):
            dump["errors"][k] = m.get_errors().numpy()
    write_json(output_dir / "log.json", dump)
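
A worked toy example of the recall/AUC helpers above (values chosen for illustration):

import numpy as np

errors = [0.5, 1.5, 2.0, 6.0]  # e.g. position errors in meters
errs, recall = compute_recall(errors)
# errs   == [0, 0.5, 1.5, 2.0, 6.0]
# recall == [0, 0.25, 0.5, 0.75, 1.0]
(auc_at_2m,) = compute_auc(errs, recall, thresholds=[2.0])
# Area under the recall curve up to 2 m, normalized to [0, 100]:
# here auc_at_2m == 37.5.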
maploc/evaluation/viz.py
ADDED
@@ -0,0 +1,178 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import numpy as np
import torch
import matplotlib.pyplot as plt

from ..utils.io import write_torch_image
from ..utils.viz_2d import plot_images, features_to_RGB, save_plot
from ..utils.viz_localization import (
    likelihood_overlay,
    plot_pose,
    plot_dense_rotations,
    add_circle_inset,
)
from ..osm.viz import Colormap, plot_nodes


def plot_example_single(
    idx,
    model,
    pred,
    data,
    results,
    plot_bev=True,
    out_dir=None,
    fig_for_paper=False,
    show_gps=False,
    show_fused=False,
    show_dir_error=False,
    show_masked_prob=False,
):
    scene, name, rasters, uv_gt = (data[k] for k in ("scene", "name", "map", "uv"))
    uv_gps = data.get("uv_gps")
    yaw_gt = data["roll_pitch_yaw"][-1].numpy()
    image = data["image"].permute(1, 2, 0)
    if "valid" in data:
        image = image.masked_fill(~data["valid"].unsqueeze(-1), 0.3)

    lp_uvt = lp_uv = pred["log_probs"]
    if show_fused and "log_probs_fused" in pred:
        lp_uvt = lp_uv = pred["log_probs_fused"]
    elif not show_masked_prob and "scores_unmasked" in pred:
        lp_uvt = lp_uv = pred["scores_unmasked"]
    has_rotation = lp_uvt.ndim == 3
    if has_rotation:
        lp_uv = lp_uvt.max(-1).values
    if lp_uv.min() > -np.inf:
        lp_uv = lp_uv.clip(min=np.percentile(lp_uv, 1))
    prob = lp_uv.exp()
    uv_p, yaw_p = pred["uv_max"], pred.get("yaw_max")
    if show_fused and "uv_fused" in pred:
        uv_p, yaw_p = pred["uv_fused"], pred.get("yaw_fused")
    feats_map = pred["map"]["map_features"][0]
    (feats_map_rgb,) = features_to_RGB(feats_map.numpy())

    text1 = rf'$\Delta xy$: {results["xy_max_error"]:.1f}m'
    if has_rotation:
        text1 += rf', $\Delta\theta$: {results["yaw_max_error"]:.1f}°'
    if show_fused and "xy_fused_error" in results:
        text1 += rf', $\Delta xy_{{fused}}$: {results["xy_fused_error"]:.1f}m'
        text1 += rf', $\Delta\theta_{{fused}}$: {results["yaw_fused_error"]:.1f}°'
    if show_dir_error and "directional_error" in results:
        err_lat, err_lon = results["directional_error"]
        text1 += rf", $\Delta$lateral/longitudinal={err_lat:.1f}m/{err_lon:.1f}m"
    if "xy_gps_error" in results:
        text1 += rf', $\Delta xy_{{GPS}}$: {results["xy_gps_error"]:.1f}m'

    map_viz = Colormap.apply(rasters)
    overlay = likelihood_overlay(prob.numpy(), map_viz.mean(-1, keepdims=True))
    plot_images(
        [image, map_viz, overlay, feats_map_rgb],
        titles=[text1, "map", "likelihood", "neural map"],
        dpi=75,
        cmaps="jet",
    )
    fig = plt.gcf()
    axes = fig.axes
    axes[1].images[0].set_interpolation("none")
    axes[2].images[0].set_interpolation("none")
    Colormap.add_colorbar()
    plot_nodes(1, rasters[2])

    if show_gps and uv_gps is not None:
        plot_pose([1], uv_gps, c="blue")
    plot_pose([1], uv_gt, yaw_gt, c="red")
    plot_pose([1], uv_p, yaw_p, c="k")
    plot_dense_rotations(2, lp_uvt.exp())
    inset_center = pred["uv_max"] if results["xy_max_error"] < 5 else uv_gt
    axins = add_circle_inset(axes[2], inset_center)
    axins.scatter(*uv_gt, lw=1, c="red", ec="k", s=50, zorder=15)
    axes[0].text(
        0.003,
        0.003,
        f"{scene}/{name}",
        transform=axes[0].transAxes,
        fontsize=3,
        va="bottom",
        ha="left",
        color="w",
    )
    plt.show()
    if out_dir is not None:
        name_ = name.replace("/", "_")
        p = str(out_dir / f"{scene}_{name_}_{{}}.pdf")
        save_plot(p.format("pred"))
        plt.close()

        if fig_for_paper:
            # !cp ../datasets/MGL/{scene}/images/{name}.jpg {out_dir}/{scene}_{name}.jpg
            plot_images([map_viz])
            plt.gca().images[0].set_interpolation("none")
            plot_nodes(0, rasters[2])
            plot_pose([0], uv_gt, yaw_gt, c="red")
            plot_pose([0], pred["uv_max"], pred["yaw_max"], c="k")
            save_plot(p.format("map"))
            plt.close()
            plot_images([lp_uv], cmaps="jet")
            plot_dense_rotations(0, lp_uvt.exp())
            save_plot(p.format("loglikelihood"), dpi=100)
            plt.close()
            plot_images([overlay])
            plt.gca().images[0].set_interpolation("none")
            axins = add_circle_inset(plt.gca(), inset_center)
            axins.scatter(*uv_gt, lw=1, c="red", ec="k", s=50)
            save_plot(p.format("likelihood"))
            plt.close()
            write_torch_image(
                p.format("neuralmap").replace("pdf", "jpg"), feats_map_rgb
            )
            write_torch_image(p.format("image").replace("pdf", "jpg"), image.numpy())

    if not plot_bev:
        return

    feats_q = pred["features_bev"]
    mask_bev = pred["valid_bev"]
    prior = None
    if "log_prior" in pred["map"]:
        prior = pred["map"]["log_prior"][0].sigmoid()
    if "bev" in pred and "confidence" in pred["bev"]:
        conf_q = pred["bev"]["confidence"]
    else:
        conf_q = torch.norm(feats_q, dim=0)
    conf_q = conf_q.masked_fill(~mask_bev, np.nan)
    (feats_q_rgb,) = features_to_RGB(feats_q.numpy(), masks=[mask_bev.numpy()])
    # feats_map_rgb, feats_q_rgb, = features_to_RGB(
    #     feats_map.numpy(), feats_q.numpy(), masks=[None, mask_bev])
    norm_map = torch.norm(feats_map, dim=0)

    plot_images(
        [conf_q, feats_q_rgb, norm_map] + ([] if prior is None else [prior]),
        titles=["BEV confidence", "BEV features", "map norm"]
        + ([] if prior is None else ["map prior"]),
        dpi=50,
        cmaps="jet",
    )
    plt.show()

    if out_dir is not None:
        save_plot(p.format("bev"))
        plt.close()


def plot_example_sequential(
    idx,
    model,
    pred,
    data,
    results,
    plot_bev=True,
    out_dir=None,
    fig_for_paper=False,
    show_gps=False,
    show_fused=False,
    show_dir_error=False,
    show_masked_prob=False,
):
    return
maploc/models/__init__.py
ADDED
@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich
# https://github.com/cvg/pixloc
# Released under the Apache License 2.0

import inspect

from .base import BaseModel


def get_class(mod_name, base_path, BaseClass):
    """Get the class object which inherits from BaseClass and is defined in
    the module named mod_name, child of base_path.
    """
    mod_path = "{}.{}".format(base_path, mod_name)
    mod = __import__(mod_path, fromlist=[""])
    classes = inspect.getmembers(mod, inspect.isclass)
    # Filter classes defined in the module
    classes = [c for c in classes if c[1].__module__ == mod_path]
    # Filter classes inherited from BaseModel
    classes = [c for c in classes if issubclass(c[1], BaseClass)]
    assert len(classes) == 1, classes
    return classes[0][1]


def get_model(name):
    if name == "localizer":
        name = "localizer_basic"
    elif name == "rotation_localizer":
        name = "localizer_basic_rotation"
    elif name == "bev_localizer":
        name = "localizer_bev_plane"
    return get_class(name, __name__, BaseModel)
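
A sketch of the dynamic lookup: `get_model` imports the submodule by name and returns the unique `BaseModel` subclass defined there. The model name below matches the `orienternet.py` module imported elsewhere in this repo:

Model = get_model("orienternet")  # imports maploc.models.orienternet
# `Model` is the single BaseModel subclass defined in that module;
# instantiating it still requires the full model configuration.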
maploc/models/base.py
ADDED
@@ -0,0 +1,123 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich
# https://github.com/cvg/pixloc
# Released under the Apache License 2.0

"""
Base class for trainable models.
"""

from abc import ABCMeta, abstractmethod
from copy import copy

import omegaconf
from omegaconf import OmegaConf
from torch import nn


class BaseModel(nn.Module, metaclass=ABCMeta):
    """
    What the child model is expected to declare:
        default_conf: dictionary of the default configuration of the model.
        It recursively updates the default_conf of all parent classes, and
        it is updated by the user-provided configuration passed to __init__.
        Configurations can be nested.

        required_data_keys: list of expected keys in the input data dictionary.

        strict_conf (optional): boolean. If false, BaseModel does not raise
        an error when the user provides an unknown configuration entry.

        _init(self, conf): initialization method, where conf is the final
        configuration object (also accessible with `self.conf`). Accessing
        unknown configuration entries will raise an error.

        _forward(self, data): method that returns a dictionary of batched
        prediction tensors based on a dictionary of batched input data tensors.

        loss(self, pred, data): method that returns a dictionary of losses,
        computed from model predictions and input data. Each loss is a batch
        of scalars, i.e. a torch.Tensor of shape (B,).
        The total loss to be optimized has the key `'total'`.

        metrics(self, pred, data): method that returns a dictionary of metrics,
        each as a batch of scalars.
    """

    base_default_conf = {
        "name": None,
        "trainable": True,  # if false: do not optimize the model parameters
        "freeze_batch_normalization": False,  # use test-time statistics
    }
    default_conf = {}
    required_data_keys = []
    strict_conf = True

    def __init__(self, conf):
        """Perform some logic and call the _init method of the child model."""
        super().__init__()
        default_conf = OmegaConf.merge(
            self.base_default_conf, OmegaConf.create(self.default_conf)
        )
        if self.strict_conf:
            OmegaConf.set_struct(default_conf, True)

        # fixme: backward compatibility
        if "pad" in conf and "pad" not in default_conf:  # backward compat.
            with omegaconf.read_write(conf):
                with omegaconf.open_dict(conf):
                    conf["interpolation"] = {"pad": conf.pop("pad")}

        if isinstance(conf, dict):
            conf = OmegaConf.create(conf)
        self.conf = conf = OmegaConf.merge(default_conf, conf)
        OmegaConf.set_readonly(conf, True)
        OmegaConf.set_struct(conf, True)
        self.required_data_keys = copy(self.required_data_keys)
        self._init(conf)

        if not conf.trainable:
            for p in self.parameters():
                p.requires_grad = False

    def train(self, mode=True):
        super().train(mode)

        def freeze_bn(module):
            if isinstance(module, nn.modules.batchnorm._BatchNorm):
                module.eval()

        if self.conf.freeze_batch_normalization:
            self.apply(freeze_bn)

        return self

    def forward(self, data):
        """Check the data and call the _forward method of the child model."""

        def recursive_key_check(expected, given):
            for key in expected:
                assert key in given, f"Missing key {key} in data"
                if isinstance(expected, dict):
                    recursive_key_check(expected[key], given[key])

        recursive_key_check(self.required_data_keys, data)
        return self._forward(data)

    @abstractmethod
    def _init(self, conf):
        """To be implemented by the child class."""
        raise NotImplementedError

    @abstractmethod
    def _forward(self, data):
        """To be implemented by the child class."""
        raise NotImplementedError

    def loss(self, pred, data):
        """To be implemented by the child class."""
        raise NotImplementedError

    def metrics(self):
        return {}  # no metrics
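
A minimal subclass following the contract documented above (a toy model, for illustration only):

import torch


class Identity(BaseModel):
    default_conf = {"scale": 1.0}
    required_data_keys = ["input"]

    def _init(self, conf):
        pass  # no parameters to build

    def _forward(self, data):
        return {"output": data["input"] * self.conf.scale}


model = Identity({"scale": 2.0})
pred = model({"input": torch.ones(1, 3)})  # pred["output"] is a tensor of twos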
maploc/models/bev_net.py
ADDED
@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import torch.nn as nn
from torchvision.models.resnet import Bottleneck

from .base import BaseModel
from .feature_extractor import AdaptationBlock
from .utils import checkpointed


class BEVNet(BaseModel):
    default_conf = {
        "pretrained": True,
        "num_blocks": "???",
        "latent_dim": "???",
        "input_dim": "${.latent_dim}",
        "output_dim": "${.latent_dim}",
        "confidence": False,
        "norm_layer": "nn.BatchNorm2d",  # normalization in the residual blocks
        "checkpointed": False,  # whether to use gradient checkpointing
        "padding": "zeros",
    }

    def _init(self, conf):
        blocks = []
        Block = checkpointed(Bottleneck, do=conf.checkpointed)
        for i in range(conf.num_blocks):
            dim = conf.input_dim if i == 0 else conf.latent_dim
            blocks.append(
                Block(
                    dim,
                    conf.latent_dim // Bottleneck.expansion,
                    norm_layer=eval(conf.norm_layer),  # noqa
                )
            )
        self.blocks = nn.Sequential(*blocks)
        self.output_layer = AdaptationBlock(conf.latent_dim, conf.output_dim)
        if conf.confidence:
            self.confidence_layer = AdaptationBlock(conf.latent_dim, 1)

        def update_padding(module):
            if isinstance(module, nn.Conv2d):
                module.padding_mode = conf.padding

        if conf.padding != "zeros":
            self.blocks.apply(update_padding)

    def _forward(self, data):
        features = self.blocks(data["input"])
        pred = {
            "output": self.output_layer(features),
        }
        if self.conf.confidence:
            pred["confidence"] = self.confidence_layer(features).squeeze(1).sigmoid()
        return pred

    def loss(self, pred, data):
        raise NotImplementedError

    def metrics(self, pred, data):
        raise NotImplementedError
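A shape-level usage sketch (the sizes are illustrative, not from this release): num_blocks and latent_dim are mandatory ("???"), and the "${.latent_dim}" interpolations make input_dim and output_dim default to latent_dim unless overridden.

import torch
from maploc.models.bev_net import BEVNet

# num_blocks and latent_dim must be provided; input/output dims resolve to 64
net = BEVNet({"num_blocks": 2, "latent_dim": 64, "confidence": True})
pred = net({"input": torch.randn(1, 64, 32, 32)})  # BEV features from the projection
print(pred["output"].shape)      # torch.Size([1, 64, 32, 32])
print(pred["confidence"].shape)  # torch.Size([1, 32, 32])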
maploc/models/bev_projection.py
ADDED
@@ -0,0 +1,91 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import torch
from torch.nn.functional import grid_sample

from ..utils.geometry import from_homogeneous
from .utils import make_grid


class PolarProjectionDepth(torch.nn.Module):
    def __init__(self, z_max, ppm, scale_range, z_min=None):
        super().__init__()
        self.z_max = z_max
        self.Δ = Δ = 1 / ppm
        self.z_min = z_min = Δ if z_min is None else z_min
        self.scale_range = scale_range
        z_steps = torch.arange(z_min, z_max + Δ, Δ)
        self.register_buffer("depth_steps", z_steps, persistent=False)

    def sample_depth_scores(self, pixel_scales, camera):
        scale_steps = camera.f[..., None, 1] / self.depth_steps.flip(-1)
        log_scale_steps = torch.log2(scale_steps)
        scale_min, scale_max = self.scale_range
        log_scale_norm = (log_scale_steps - scale_min) / (scale_max - scale_min)
        log_scale_norm = log_scale_norm * 2 - 1  # in [-1, 1]

        values = pixel_scales.flatten(1, 2).unsqueeze(-1)
        indices = log_scale_norm.unsqueeze(-1)
        indices = torch.stack([torch.zeros_like(indices), indices], -1)
        depth_scores = grid_sample(values, indices, align_corners=True)
        depth_scores = depth_scores.reshape(
            pixel_scales.shape[:-1] + (len(self.depth_steps),)
        )
        return depth_scores

    def forward(
        self,
        image,
        pixel_scales,
        camera,
        return_total_score=False,
    ):
        depth_scores = self.sample_depth_scores(pixel_scales, camera)
        # softmax over the image rows (dim=1 of B,H,W,Z): each (column, depth)
        # cell gets a distribution over rows
        depth_prob = torch.softmax(depth_scores, dim=1)
        # collapse the image height: (B,D,H,W) x (B,H,W,Z) -> (B,D,Z,W) polar map
        image_polar = torch.einsum("...dhw,...hwz->...dzw", image, depth_prob)
        if return_total_score:
            cell_score = torch.logsumexp(depth_scores, dim=1, keepdim=True)
            return image_polar, cell_score.squeeze(1)
        return image_polar


class CartesianProjection(torch.nn.Module):
    def __init__(self, z_max, x_max, ppm, z_min=None):
        super().__init__()
        self.z_max = z_max
        self.x_max = x_max
        self.Δ = Δ = 1 / ppm
        self.z_min = z_min = Δ if z_min is None else z_min

        grid_xz = make_grid(
            x_max * 2 + Δ, z_max, step_y=Δ, step_x=Δ, orig_y=Δ, orig_x=-x_max, y_up=True
        )
        self.register_buffer("grid_xz", grid_xz, persistent=False)

    def grid_to_polar(self, cam):
        f, c = cam.f[..., 0][..., None, None], cam.c[..., 0][..., None, None]
        u = from_homogeneous(self.grid_xz).squeeze(-1) * f + c
        z_idx = (self.grid_xz[..., 1] - self.z_min) / self.Δ  # convert z value to index
        z_idx = z_idx[None].expand_as(u)
        grid_polar = torch.stack([u, z_idx], -1)
        return grid_polar

    def sample_from_polar(self, image_polar, valid_polar, grid_uz):
        size = grid_uz.new_tensor(image_polar.shape[-2:][::-1])
        grid_uz_norm = (grid_uz + 0.5) / size * 2 - 1
        grid_uz_norm = grid_uz_norm * grid_uz.new_tensor([1, -1])  # y axis is up
        image_bev = grid_sample(image_polar, grid_uz_norm, align_corners=False)

        if valid_polar is None:
            valid = torch.ones_like(image_polar[..., :1, :, :])
        else:
            valid = valid_polar.to(image_polar)[:, None]
        valid = grid_sample(valid, grid_uz_norm, align_corners=False)
        valid = valid.squeeze(1) > (1 - 1e-4)

        return image_bev, valid

    def forward(self, image_polar, valid_polar, cam):
        grid_uz = self.grid_to_polar(cam)
        image, valid = self.sample_from_polar(image_polar, valid_polar, grid_uz)
        return image, valid, grid_uz
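A shape-level sketch of the two projections chained together (all sizes illustrative; the SimpleNamespace only stands in for the repository's Camera wrapper, of which just the .f and .c tensors are used here):

import torch
from types import SimpleNamespace
from maploc.models.bev_projection import PolarProjectionDepth, CartesianProjection

# stand-in camera: (B, 2) focal lengths and principal point, in pixels
cam = SimpleNamespace(f=torch.tensor([[256.0, 256.0]]), c=torch.tensor([[64.0, 32.0]]))

polar = PolarProjectionDepth(z_max=32.0, ppm=2, scale_range=[0.0, 9.0])
cartesian = CartesianProjection(z_max=32.0, x_max=32.0, ppm=2)

feats = torch.randn(1, 8, 64, 128)       # (B, D, H, W) image features
scales = torch.randn(1, 64, 128, 32)     # (B, H, W, S) per-pixel scale scores
image_polar = polar(feats, scales, cam)  # (B, D, Z, W): height folded into depth
bev, valid, _ = cartesian(image_polar, None, cam)  # (B, D, Z, X) metric BEV grid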
maploc/models/feature_extractor.py
ADDED
@@ -0,0 +1,231 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich
# https://github.com/cvg/pixloc
# Released under the Apache License 2.0

"""
Flexible UNet model which takes any Torchvision backbone as encoder.
Predicts multi-level features and makes sure that they are well aligned.
"""

import torch
import torch.nn as nn
import torchvision

from .base import BaseModel
from .utils import checkpointed


class DecoderBlock(nn.Module):
    def __init__(
        self, previous, skip, out, num_convs=1, norm=nn.BatchNorm2d, padding="zeros"
    ):
        super().__init__()

        self.upsample = nn.Upsample(
            scale_factor=2, mode="bilinear", align_corners=False
        )

        layers = []
        for i in range(num_convs):
            conv = nn.Conv2d(
                previous + skip if i == 0 else out,
                out,
                kernel_size=3,
                padding=1,
                bias=norm is None,
                padding_mode=padding,
            )
            layers.append(conv)
            if norm is not None:
                layers.append(norm(out))
            layers.append(nn.ReLU(inplace=True))
        self.layers = nn.Sequential(*layers)

    def forward(self, previous, skip):
        upsampled = self.upsample(previous)
        # If the shape of the input map `skip` is not a multiple of 2,
        # it will not match the shape of the upsampled map `upsampled`.
        # If the downsampling uses ceil_mode=False, we need to crop `skip`.
        # If it uses ceil_mode=True (not supported here), we should pad it.
        _, _, hu, wu = upsampled.shape
        _, _, hs, ws = skip.shape
        assert (hu <= hs) and (wu <= ws), "Using ceil_mode=True in pooling?"
        # assert (hu == hs) and (wu == ws), 'Careful about padding'
        skip = skip[:, :, :hu, :wu]
        return self.layers(torch.cat([upsampled, skip], dim=1))


class AdaptationBlock(nn.Sequential):
    def __init__(self, inp, out):
        conv = nn.Conv2d(inp, out, kernel_size=1, padding=0, bias=True)
        super().__init__(conv)


class FeatureExtractor(BaseModel):
    default_conf = {
        "pretrained": True,
        "input_dim": 3,
        "output_scales": [0, 2, 4],  # what scales to adapt and output
        "output_dim": 128,  # number of channels in output feature maps
        "encoder": "vgg16",  # string (torchvision net) or list of channels
        "num_downsample": 4,  # how many downsampling blocks (if VGG-style net)
        "decoder": [64, 64, 64, 64],  # list of channels of decoder
        "decoder_norm": "nn.BatchNorm2d",  # normalization in decoder blocks
        "do_average_pooling": False,
        "checkpointed": False,  # whether to use gradient checkpointing
        "padding": "zeros",
    }
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    def build_encoder(self, conf):
        assert isinstance(conf.encoder, str)
        if conf.pretrained:
            assert conf.input_dim == 3
        Encoder = getattr(torchvision.models, conf.encoder)
        encoder = Encoder(weights="DEFAULT" if conf.pretrained else None)
        Block = checkpointed(torch.nn.Sequential, do=conf.checkpointed)
        assert max(conf.output_scales) <= conf.num_downsample

        if conf.encoder.startswith("vgg"):
            # Parse the layers and pack them into downsampling blocks.
            # It's easy for VGG-style nets because of their linear structure.
            # This does not handle strided convs and residual connections.
            skip_dims = []
            previous_dim = None
            blocks = [[]]
            for i, layer in enumerate(encoder.features):
                if isinstance(layer, torch.nn.Conv2d):
                    # Change the first conv layer if the input dim mismatches
                    if i == 0 and conf.input_dim != layer.in_channels:
                        args = {k: getattr(layer, k) for k in layer.__constants__}
                        args.pop("output_padding")
                        layer = torch.nn.Conv2d(
                            **{**args, "in_channels": conf.input_dim}
                        )
                    previous_dim = layer.out_channels
                elif isinstance(layer, torch.nn.MaxPool2d):
                    assert previous_dim is not None
                    skip_dims.append(previous_dim)
                    if (conf.num_downsample + 1) == len(blocks):
                        break
                    blocks.append([])  # start a new block
                    if conf.do_average_pooling:
                        assert layer.dilation == 1
                        layer = torch.nn.AvgPool2d(
                            kernel_size=layer.kernel_size,
                            stride=layer.stride,
                            padding=layer.padding,
                            ceil_mode=layer.ceil_mode,
                            count_include_pad=False,
                        )
                blocks[-1].append(layer)
            encoder = [Block(*b) for b in blocks]
        elif conf.encoder.startswith("resnet"):
            # Manually define the ResNet blocks such that the downsampling comes first
            assert conf.encoder[len("resnet") :] in ["18", "34", "50", "101"]
            assert conf.input_dim == 3, "Unsupported for now."
            block1 = torch.nn.Sequential(encoder.conv1, encoder.bn1, encoder.relu)
            block2 = torch.nn.Sequential(encoder.maxpool, encoder.layer1)
            block3 = encoder.layer2
            block4 = encoder.layer3
            block5 = encoder.layer4
            blocks = [block1, block2, block3, block4, block5]
            # Extract the output dimension of each block
            skip_dims = [encoder.conv1.out_channels]
            for i in range(1, 5):
                modules = getattr(encoder, f"layer{i}")[-1]._modules
                conv = sorted(k for k in modules if k.startswith("conv"))[-1]
                skip_dims.append(modules[conv].out_channels)
            # Add a dummy block such that the first one does not downsample
            encoder = [torch.nn.Identity()] + [Block(b) for b in blocks]
            skip_dims = [3] + skip_dims
            # Trim based on the requested encoder size
            encoder = encoder[: conf.num_downsample + 1]
            skip_dims = skip_dims[: conf.num_downsample + 1]
        else:
            raise NotImplementedError(conf.encoder)

        assert (conf.num_downsample + 1) == len(encoder)
        encoder = nn.ModuleList(encoder)

        return encoder, skip_dims

    def _init(self, conf):
        # Encoder
        self.encoder, skip_dims = self.build_encoder(conf)
        self.skip_dims = skip_dims

        def update_padding(module):
            if isinstance(module, nn.Conv2d):
                module.padding_mode = conf.padding

        if conf.padding != "zeros":
            self.encoder.apply(update_padding)

        # Decoder
        if conf.decoder is not None:
            assert len(conf.decoder) == (len(skip_dims) - 1)
            Block = checkpointed(DecoderBlock, do=conf.checkpointed)
            norm = eval(conf.decoder_norm) if conf.decoder_norm else None  # noqa

            previous = skip_dims[-1]
            decoder = []
            for out, skip in zip(conf.decoder, skip_dims[:-1][::-1]):
                decoder.append(
                    Block(previous, skip, out, norm=norm, padding=conf.padding)
                )
                previous = out
            self.decoder = nn.ModuleList(decoder)

        # Adaptation layers
        adaptation = []
        for idx, i in enumerate(conf.output_scales):
            if conf.decoder is None or i == (len(self.encoder) - 1):
                input_ = skip_dims[i]
            else:
                input_ = conf.decoder[-1 - i]

            # out_dim can be an int (same for all scales) or a list (per scale)
            dim = conf.output_dim
            if not isinstance(dim, int):
                dim = dim[idx]

            block = AdaptationBlock(input_, dim)
            adaptation.append(block)
        self.adaptation = nn.ModuleList(adaptation)
        self.scales = [2**s for s in conf.output_scales]

    def _forward(self, data):
        image = data["image"]
        if self.conf.pretrained:
            mean, std = image.new_tensor(self.mean), image.new_tensor(self.std)
            image = (image - mean[:, None, None]) / std[:, None, None]

        skip_features = []
        features = image
        for block in self.encoder:
            features = block(features)
            skip_features.append(features)

        if self.conf.decoder:
            pre_features = [skip_features[-1]]
            for block, skip in zip(self.decoder, skip_features[:-1][::-1]):
                pre_features.append(block(pre_features[-1], skip))
            pre_features = pre_features[::-1]  # fine to coarse
        else:
            pre_features = skip_features

        out_features = []
        for adapt, i in zip(self.adaptation, self.conf.output_scales):
            out_features.append(adapt(pre_features[i]))
        pred = {"feature_maps": out_features, "skip_features": skip_features}
        return pred

    def loss(self, pred, data):
        raise NotImplementedError

    def metrics(self, pred, data):
        raise NotImplementedError
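A configuration sketch with illustrative numbers: num_downsample=2 cuts the VGG encoder after two pooling stages, decoder then needs exactly one block per skip connection, and output_scales picks which pyramid levels receive an adaptation layer.

import torch
from maploc.models.feature_extractor import FeatureExtractor

extractor = FeatureExtractor(
    {
        "encoder": "vgg16",
        "pretrained": False,      # skip the ImageNet weight download
        "num_downsample": 2,      # keep conv blocks up to the second pooling
        "decoder": [64, 64],      # one decoder block per skip connection
        "output_scales": [0, 2],  # adapt the stride-1 and stride-4 maps
    }
)
pred = extractor({"image": torch.randn(1, 3, 128, 160)})
for scale, feat in zip(extractor.scales, pred["feature_maps"]):
    print(scale, feat.shape)  # 1: (1, 128, 128, 160), then 4: (1, 128, 32, 40)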
maploc/models/feature_extractor_v2.py
ADDED
@@ -0,0 +1,192 @@
import logging

import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor

from .base import BaseModel

logger = logging.getLogger(__name__)


class DecoderBlock(nn.Module):
    def __init__(
        self, previous, out, ksize=3, num_convs=1, norm=nn.BatchNorm2d, padding="zeros"
    ):
        super().__init__()
        layers = []
        for i in range(num_convs):
            conv = nn.Conv2d(
                previous if i == 0 else out,
                out,
                kernel_size=ksize,
                padding=ksize // 2,
                bias=norm is None,
                padding_mode=padding,
            )
            layers.append(conv)
            if norm is not None:
                layers.append(norm(out))
            layers.append(nn.ReLU(inplace=True))
        self.layers = nn.Sequential(*layers)

    def forward(self, previous, skip):
        _, _, hp, wp = previous.shape
        _, _, hs, ws = skip.shape
        scale = 2 ** np.round(np.log2(np.array([hs / hp, ws / wp])))
        upsampled = nn.functional.interpolate(
            previous, scale_factor=scale.tolist(), mode="bilinear", align_corners=False
        )
        # If the shape of the input map `skip` is not a multiple of 2,
        # it will not match the shape of the upsampled map `upsampled`.
        # If the downsampling uses ceil_mode=False, we need to crop `skip`.
        # If it uses ceil_mode=True (not supported here), we should pad it.
        _, _, hu, wu = upsampled.shape
        _, _, hs, ws = skip.shape
        if (hu <= hs) and (wu <= ws):
            skip = skip[:, :, :hu, :wu]
        elif (hu >= hs) and (wu >= ws):
            skip = nn.functional.pad(skip, [0, wu - ws, 0, hu - hs])
        else:
            raise ValueError(
                f"Inconsistent skip vs upsampled shapes: {(hs, ws)}, {(hu, wu)}"
            )

        return self.layers(skip) + upsampled


class FPN(nn.Module):
    def __init__(self, in_channels_list, out_channels, **kw):
        super().__init__()
        self.first = nn.Conv2d(
            in_channels_list[-1], out_channels, 1, padding=0, bias=True
        )
        self.blocks = nn.ModuleList(
            [
                DecoderBlock(c, out_channels, ksize=1, **kw)
                for c in in_channels_list[::-1][1:]
            ]
        )
        self.out = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, layers):
        feats = None
        for idx, x in enumerate(reversed(layers.values())):
            if feats is None:
                feats = self.first(x)
            else:
                feats = self.blocks[idx - 1](feats, x)
        out = self.out(feats)
        return out


def remove_conv_stride(conv):
    conv_new = nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        conv.kernel_size,
        bias=conv.bias is not None,
        stride=1,
        padding=conv.padding,
    )
    conv_new.weight = conv.weight
    conv_new.bias = conv.bias
    return conv_new


class FeatureExtractor(BaseModel):
    default_conf = {
        "pretrained": True,
        "input_dim": 3,
        "output_dim": 128,  # number of channels in output feature maps
        "encoder": "resnet50",  # torchvision net as string
        "remove_stride_from_first_conv": False,
        "num_downsample": None,  # how many downsampling blocks to keep
        "decoder_norm": "nn.BatchNorm2d",  # normalization in decoder blocks
        "do_average_pooling": False,
        "checkpointed": False,  # whether to use gradient checkpointing
    }
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    def build_encoder(self, conf):
        assert isinstance(conf.encoder, str)
        if conf.pretrained:
            assert conf.input_dim == 3
        Encoder = getattr(torchvision.models, conf.encoder)

        kw = {}
        if conf.encoder.startswith("resnet"):
            layers = ["relu", "layer1", "layer2", "layer3", "layer4"]
            kw["replace_stride_with_dilation"] = [False, False, False]
        elif conf.encoder == "vgg13":
            layers = [
                "features.3",
                "features.8",
                "features.13",
                "features.18",
                "features.23",
            ]
        elif conf.encoder == "vgg16":
            layers = [
                "features.3",
                "features.8",
                "features.15",
                "features.22",
                "features.29",
            ]
        else:
            raise NotImplementedError(conf.encoder)

        if conf.num_downsample is not None:
            layers = layers[: conf.num_downsample]
        encoder = Encoder(weights="DEFAULT" if conf.pretrained else None, **kw)
        encoder = create_feature_extractor(encoder, return_nodes=layers)
        if conf.encoder.startswith("resnet") and conf.remove_stride_from_first_conv:
            encoder.conv1 = remove_conv_stride(encoder.conv1)

        if conf.do_average_pooling:
            raise NotImplementedError
        if conf.checkpointed:
            raise NotImplementedError

        return encoder, layers

    def _init(self, conf):
        # Preprocessing
        self.register_buffer("mean_", torch.tensor(self.mean), persistent=False)
        self.register_buffer("std_", torch.tensor(self.std), persistent=False)

        # Encoder
        self.encoder, self.layers = self.build_encoder(conf)
        s = 128
        inp = torch.zeros(1, 3, s, s)
        features = list(self.encoder(inp).values())
        self.skip_dims = [x.shape[1] for x in features]
        self.layer_strides = [s / f.shape[-1] for f in features]
        self.scales = [self.layer_strides[0]]

        # Decoder
        norm = eval(conf.decoder_norm) if conf.decoder_norm else None  # noqa
        self.decoder = FPN(self.skip_dims, out_channels=conf.output_dim, norm=norm)

        logger.debug(
            "Built feature extractor with layers {name:dim:stride}:\n"
            f"{list(zip(self.layers, self.skip_dims, self.layer_strides))}\n"
            f"and output scales {self.scales}."
        )

    def _forward(self, data):
        image = data["image"]
        image = (image - self.mean_[:, None, None]) / self.std_[:, None, None]

        skip_features = self.encoder(image)
        output = self.decoder(skip_features)
        pred = {"feature_maps": [output], "skip_features": skip_features}
        return pred
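Unlike the v1 extractor, this version infers channel dimensions and strides by tracing a dummy 128x128 input through the torchvision feature extractor, and returns a single FPN map at the finest stride. A sketch with illustrative sizes:

import torch
from maploc.models.feature_extractor_v2 import FeatureExtractor

extractor = FeatureExtractor(
    {"encoder": "resnet18", "output_dim": 64, "pretrained": False}
)
pred = extractor({"image": torch.randn(1, 3, 128, 160)})
print(extractor.scales)               # [2.0]: single output at the finest stride
print(pred["feature_maps"][0].shape)  # torch.Size([1, 64, 64, 80])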
maploc/models/map_encoder.py
ADDED
@@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import torch
import torch.nn as nn

from .base import BaseModel
from .feature_extractor import FeatureExtractor


class MapEncoder(BaseModel):
    default_conf = {
        "embedding_dim": "???",
        "output_dim": None,
        "num_classes": "???",
        "backbone": "???",
        "unary_prior": False,
    }

    def _init(self, conf):
        self.embeddings = torch.nn.ModuleDict(
            {
                k: torch.nn.Embedding(n + 1, conf.embedding_dim)
                for k, n in conf.num_classes.items()
            }
        )
        input_dim = len(conf.num_classes) * conf.embedding_dim
        output_dim = conf.output_dim
        if output_dim is None:
            output_dim = conf.backbone.output_dim
        if conf.unary_prior:
            output_dim += 1
        if conf.backbone is None:
            self.encoder = nn.Conv2d(input_dim, output_dim, 1)
        elif conf.backbone == "simple":
            self.encoder = nn.Sequential(
                nn.Conv2d(input_dim, 128, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, 128, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, output_dim, 3, padding=1),
            )
        else:
            self.encoder = FeatureExtractor(
                {
                    **conf.backbone,
                    "input_dim": input_dim,
                    "output_dim": output_dim,
                }
            )

    def _forward(self, data):
        embeddings = [
            self.embeddings[k](data["map"][:, i])
            for i, k in enumerate(("areas", "ways", "nodes"))
        ]
        embeddings = torch.cat(embeddings, dim=-1).permute(0, 3, 1, 2)
        if isinstance(self.encoder, BaseModel):
            features = self.encoder({"image": embeddings})["feature_maps"]
        else:
            features = [self.encoder(embeddings)]
        pred = {}
        if self.conf.unary_prior:
            pred["log_prior"] = [f[:, -1] for f in features]
            features = [f[:, :-1] for f in features]
        pred["map_features"] = features
        return pred
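Finally, a sketch of the map encoder with illustrative class counts: each of the three rasterized map layers (areas, ways, nodes) is an integer class map that is embedded per cell, concatenated, and encoded; with unary_prior enabled, the last channel is split off as a log-prior over locations.

import torch
from maploc.models.map_encoder import MapEncoder

encoder = MapEncoder(
    {
        "embedding_dim": 16,
        "output_dim": 8,
        "num_classes": {"areas": 7, "ways": 10, "nodes": 33},  # illustrative counts
        "backbone": "simple",
        "unary_prior": True,
    }
)
# (B, 3, H, W) integer rasters with values in [0, n]; the n + 1 embedding rows
# leave room for one extra index (assumed here to mark "no class")
rasters = torch.stack(
    [
        torch.randint(0, 8, (1, 64, 64)),   # areas
        torch.randint(0, 11, (1, 64, 64)),  # ways
        torch.randint(0, 34, (1, 64, 64)),  # nodes
    ],
    dim=1,
)
pred = encoder({"map": rasters})
print(pred["map_features"][0].shape)  # torch.Size([1, 8, 64, 64])
print(pred["log_prior"][0].shape)     # torch.Size([1, 64, 64])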