sarlinpe committed on
Commit
9665c2c
0 Parent(s):
This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .flake8 +3 -0
  2. .gitignore +138 -0
  3. CODE_OF_CONDUCT.md +80 -0
  4. CONTRIBUTING.md +31 -0
  5. LICENSE +1 -0
  6. README.md +229 -0
  7. assets/demo.jpg +0 -0
  8. assets/teaser.svg +0 -0
  9. demo.ipynb +0 -0
  10. maploc/__init__.py +28 -0
  11. maploc/conf/__init__.py +0 -0
  12. maploc/conf/data/__init__.py +0 -0
  13. maploc/conf/data/kitti.yaml +29 -0
  14. maploc/conf/data/mapillary.yaml +40 -0
  15. maploc/conf/model/image_encoder/global.yaml +9 -0
  16. maploc/conf/model/image_encoder/resnet_fpn.yaml +7 -0
  17. maploc/conf/model/image_encoder/vgg_unet.yaml +8 -0
  18. maploc/conf/orienternet.yaml +34 -0
  19. maploc/conf/overfit.yaml +17 -0
  20. maploc/conf/training.yaml +22 -0
  21. maploc/data/__init__.py +4 -0
  22. maploc/data/dataset.py +264 -0
  23. maploc/data/image.py +140 -0
  24. maploc/data/kitti/dataset.py +306 -0
  25. maploc/data/kitti/prepare.py +123 -0
  26. maploc/data/kitti/test1_files.txt +0 -0
  27. maploc/data/kitti/test2_files.txt +0 -0
  28. maploc/data/kitti/train_files.txt +0 -0
  29. maploc/data/kitti/utils.py +79 -0
  30. maploc/data/mapillary/dataset.py +350 -0
  31. maploc/data/mapillary/download.py +180 -0
  32. maploc/data/mapillary/prepare.py +406 -0
  33. maploc/data/mapillary/splits_MGL_13loc.json +0 -0
  34. maploc/data/mapillary/utils.py +173 -0
  35. maploc/data/sequential.py +61 -0
  36. maploc/data/torch.py +111 -0
  37. maploc/data/utils.py +60 -0
  38. maploc/demo.py +209 -0
  39. maploc/evaluation/kitti.py +89 -0
  40. maploc/evaluation/mapillary.py +111 -0
  41. maploc/evaluation/run.py +252 -0
  42. maploc/evaluation/utils.py +40 -0
  43. maploc/evaluation/viz.py +178 -0
  44. maploc/models/__init__.py +34 -0
  45. maploc/models/base.py +123 -0
  46. maploc/models/bev_net.py +61 -0
  47. maploc/models/bev_projection.py +91 -0
  48. maploc/models/feature_extractor.py +231 -0
  49. maploc/models/feature_extractor_v2.py +192 -0
  50. maploc/models/map_encoder.py +66 -0
.flake8 ADDED
@@ -0,0 +1,3 @@
1
+ [flake8]
2
+ max-line-length = 88
3
+ extend-ignore = E203
.gitignore ADDED
@@ -0,0 +1,138 @@
1
+ datasets/
2
+ experiments/
3
+ outputs/
4
+ *.mp4
5
+ lsf*
6
+ .DS_Store
7
+
8
+ # Byte-compiled / optimized / DLL files
9
+ __pycache__/
10
+ *.py[cod]
11
+ *$py.class
12
+
13
+ # C extensions
14
+ *.so
15
+
16
+ # Distribution / packaging
17
+ .Python
18
+ build/
19
+ develop-eggs/
20
+ dist/
21
+ downloads/
22
+ eggs/
23
+ .eggs/
24
+ lib64/
25
+ parts/
26
+ sdist/
27
+ var/
28
+ wheels/
29
+ pip-wheel-metadata/
30
+ share/python-wheels/
31
+ *.egg-info/
32
+ .installed.cfg
33
+ *.egg
34
+ MANIFEST
35
+
36
+ # PyInstaller
37
+ # Usually these files are written by a python script from a template
38
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
39
+ *.manifest
40
+ *.spec
41
+
42
+ # Installer logs
43
+ pip-log.txt
44
+ pip-delete-this-directory.txt
45
+
46
+ # Unit test / coverage reports
47
+ htmlcov/
48
+ .tox/
49
+ .nox/
50
+ .coverage
51
+ .coverage.*
52
+ .cache
53
+ nosetests.xml
54
+ coverage.xml
55
+ *.cover
56
+ *.py,cover
57
+ .hypothesis/
58
+ .pytest_cache/
59
+
60
+ # Translations
61
+ *.mo
62
+ *.pot
63
+
64
+ # Django stuff:
65
+ *.log
66
+ local_settings.py
67
+ db.sqlite3
68
+ db.sqlite3-journal
69
+
70
+ # Flask stuff:
71
+ instance/
72
+ .webassets-cache
73
+
74
+ # Scrapy stuff:
75
+ .scrapy
76
+
77
+ # Sphinx documentation
78
+ docs/_build/
79
+
80
+ # PyBuilder
81
+ target/
82
+
83
+ # Jupyter Notebook
84
+ .ipynb_checkpoints
85
+
86
+ # IPython
87
+ profile_default/
88
+ ipython_config.py
89
+
90
+ # pyenv
91
+ .python-version
92
+
93
+ # pipenv
94
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
96
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
97
+ # install all needed dependencies.
98
+ #Pipfile.lock
99
+
100
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
101
+ __pypackages__/
102
+
103
+ # Celery stuff
104
+ celerybeat-schedule
105
+ celerybeat.pid
106
+
107
+ # SageMath parsed files
108
+ *.sage.py
109
+
110
+ # Environments
111
+ .env
112
+ .venv
113
+ env/
114
+ venv/
115
+ ENV/
116
+ env.bak/
117
+ venv.bak/
118
+
119
+ # Spyder project settings
120
+ .spyderproject
121
+ .spyproject
122
+
123
+ # Rope project settings
124
+ .ropeproject
125
+
126
+ # mkdocs documentation
127
+ /site
128
+
129
+ # mypy
130
+ .mypy_cache/
131
+ .dmypy.json
132
+ dmypy.json
133
+
134
+ # Pyre type checker
135
+ .pyre/
136
+
137
+ # vscode
138
+ .vscode
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,80 @@
1
+ # Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ In the interest of fostering an open and welcoming environment, we as
6
+ contributors and maintainers pledge to make participation in our project and
7
+ our community a harassment-free experience for everyone, regardless of age, body
8
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
9
+ level of experience, education, socio-economic status, nationality, personal
10
+ appearance, race, religion, or sexual identity and orientation.
11
+
12
+ ## Our Standards
13
+
14
+ Examples of behavior that contributes to creating a positive environment
15
+ include:
16
+
17
+ * Using welcoming and inclusive language
18
+ * Being respectful of differing viewpoints and experiences
19
+ * Gracefully accepting constructive criticism
20
+ * Focusing on what is best for the community
21
+ * Showing empathy towards other community members
22
+
23
+ Examples of unacceptable behavior by participants include:
24
+
25
+ * The use of sexualized language or imagery and unwelcome sexual attention or
26
+ advances
27
+ * Trolling, insulting/derogatory comments, and personal or political attacks
28
+ * Public or private harassment
29
+ * Publishing others' private information, such as a physical or electronic
30
+ address, without explicit permission
31
+ * Other conduct which could reasonably be considered inappropriate in a
32
+ professional setting
33
+
34
+ ## Our Responsibilities
35
+
36
+ Project maintainers are responsible for clarifying the standards of acceptable
37
+ behavior and are expected to take appropriate and fair corrective action in
38
+ response to any instances of unacceptable behavior.
39
+
40
+ Project maintainers have the right and responsibility to remove, edit, or
41
+ reject comments, commits, code, wiki edits, issues, and other contributions
42
+ that are not aligned to this Code of Conduct, or to ban temporarily or
43
+ permanently any contributor for other behaviors that they deem inappropriate,
44
+ threatening, offensive, or harmful.
45
+
46
+ ## Scope
47
+
48
+ This Code of Conduct applies within all project spaces, and it also applies when
49
+ an individual is representing the project or its community in public spaces.
50
+ Examples of representing a project or community include using an official
51
+ project e-mail address, posting via an official social media account, or acting
52
+ as an appointed representative at an online or offline event. Representation of
53
+ a project may be further defined and clarified by project maintainers.
54
+
55
+ This Code of Conduct also applies outside the project spaces when there is a
56
+ reasonable belief that an individual's behavior may have a negative impact on
57
+ the project or its community.
58
+
59
+ ## Enforcement
60
+
61
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
62
+ reported by contacting the project team at <opensource-conduct@fb.com>. All
63
+ complaints will be reviewed and investigated and will result in a response that
64
+ is deemed necessary and appropriate to the circumstances. The project team is
65
+ obligated to maintain confidentiality with regard to the reporter of an incident.
66
+ Further details of specific enforcement policies may be posted separately.
67
+
68
+ Project maintainers who do not follow or enforce the Code of Conduct in good
69
+ faith may face temporary or permanent repercussions as determined by other
70
+ members of the project's leadership.
71
+
72
+ ## Attribution
73
+
74
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76
+
77
+ [homepage]: https://www.contributor-covenant.org
78
+
79
+ For answers to common questions about this code of conduct, see
80
+ https://www.contributor-covenant.org/faq
CONTRIBUTING.md ADDED
@@ -0,0 +1,31 @@
1
+ # Contributing to OrienterNet
2
+ We want to make contributing to this project as easy and transparent as
3
+ possible.
4
+
5
+ ## Pull Requests
6
+ We actively welcome your pull requests.
7
+
8
+ 1. Fork the repo and create your branch from `main`.
9
+ 2. If you've added code that should be tested, add tests.
10
+ 3. If you've changed APIs, update the documentation.
11
+ 4. Ensure the test suite passes.
12
+ 5. Make sure your code lints.
13
+ 6. If you haven't already, complete the Contributor License Agreement ("CLA").
14
+
15
+ ## Contributor License Agreement ("CLA")
16
+ In order to accept your pull request, we need you to submit a CLA. You only need
17
+ to do this once to work on any of Facebook's open source projects.
18
+
19
+ Complete your CLA here: <https://code.facebook.com/cla>
20
+
21
+ ## Issues
22
+ We use GitHub issues to track public bugs. Please ensure your description is
23
+ clear and has sufficient instructions to be able to reproduce the issue.
24
+
25
+ Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
26
+ disclosure of security bugs. In those cases, please go through the process
27
+ outlined on that page and do not file a public issue.
28
+
29
+ ## License
30
+ By contributing to OrienterNet, you agree that your contributions will be licensed
31
+ under the LICENSE file in the root directory of this source tree.
LICENSE ADDED
@@ -0,0 +1 @@
1
+ The MGL dataset is made available under the CC-BY-SA license following the data available on the Mapillary platform. The model implementation and the pre-trained weights follow a CC-BY-NC license. OpenStreetMap data follows its own license.
README.md ADDED
@@ -0,0 +1,229 @@
1
+ <p align="center">
2
+ <h1 align="center"><ins>OrienterNet</ins><br>Visual Localization in 2D Public Maps<br>with Neural Matching</h1>
3
+ <p align="center">
4
+ <a href="https://psarlin.com/">Paul-Edouard&nbsp;Sarlin</a>
5
+ ·
6
+ <a href="https://danieldetone.com/">Daniel&nbsp;DeTone</a>
7
+ ·
8
+ <a href="https://scholar.google.com/citations?user=WhISCE4AAAAJ&hl=en">Tsun-Yi&nbsp;Yang</a>
9
+ ·
10
+ <a href="https://scholar.google.com/citations?user=Ta4TDJoAAAAJ&hl=en">Armen&nbsp;Avetisyan</a>
11
+ ·
12
+ <a href="https://scholar.google.com/citations?hl=en&user=49_cCT8AAAAJ">Julian&nbsp;Straub</a>
13
+ <br>
14
+ <a href="https://tom.ai/">Tomasz&nbsp;Malisiewicz</a>
15
+ ·
16
+ <a href="https://scholar.google.com/citations?user=484sccEAAAAJ&hl=en">Samuel&nbsp;Rota&nbsp;Bulo</a>
17
+ ·
18
+ <a href="https://scholar.google.com/citations?hl=en&user=MhowvPkAAAAJ">Richard&nbsp;Newcombe</a>
19
+ ·
20
+ <a href="https://scholar.google.com/citations?hl=en&user=CxbDDRMAAAAJ">Peter&nbsp;Kontschieder</a>
21
+ ·
22
+ <a href="https://scholar.google.com/citations?user=AGoNHcsAAAAJ&hl=en">Vasileios&nbsp;Balntas</a>
23
+ </p>
24
+ <h2 align="center">CVPR 2023</h2>
25
+ <h3 align="center"><a href="https://arxiv.org/pdf/2304.02009.pdf">Paper</a> | <a href="https://psarlin.com/orienternet">Project Page</a> | <a href="https://youtu.be/wglW8jnupSs">Video</a></h3>
26
+ <div align="center"></div>
27
+ </p>
28
+ <p align="center">
29
+ <a href="https://psarlin.com/orienternet"><img src="assets/teaser.svg" alt="teaser" width="60%"></a>
30
+ <br>
31
+ <em>OrienterNet is a deep neural network that can accurately localize an image<br>using the same 2D semantic maps that humans use to orient themselves.</em>
32
+ </p>
33
+
34
+ ##
35
+
36
+ This repository hosts the source code for OrienterNet, a research project by Meta Reality Labs. OrienterNet uses deep learning to accurately localize images using free and globally available maps from OpenStreetMap. Unlike existing algorithms that rely on complex 3D point clouds, OrienterNet estimates a position and orientation by matching a neural Bird's-Eye View against 2D maps.
37
+
38
+ ## Installation
39
+
40
+ OrienterNet requires Python >= 3.8 and [PyTorch](https://pytorch.org/). To run the demo, clone this repo and install the minimal requirements:
41
+
42
+ ```bash
43
+ git clone https://github.com/facebookresearch/OrienterNet
44
+ python -m pip install -r requirements/demo.txt
45
+ ```
46
+
47
+ To run the evaluation and training, install the full requirements:
48
+
49
+ ```bash
50
+ python -m pip install -r requirements/full.txt
51
+ ```
52
+
53
+ ## Demo ➡️ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1zH_2mzdB18BnJVq48ZvJhMorcRjrWAXI?usp=sharing)
54
+
55
+ Check out the Jupyter notebook [`demo.ipynb`](./demo.ipynb) ([run it on Colab!](https://colab.research.google.com/drive/1zH_2mzdB18BnJVq48ZvJhMorcRjrWAXI?usp=sharing)) for a minimal demo - take a picture with your phone in any city and find its exact location in a few seconds!
56
+
57
+ <p align="center">
58
+ <a href="./demo.ipynb"><img src="assets/demo.jpg" alt="demo" width="60%"></a>
59
+ <br>
60
+ <em>OrienterNet positions any image within a large area - try it with your own images!</em>
61
+ </p>
62
+
63
+ ## Evaluation
64
+
65
+ #### Mapillary Geo-Localization dataset
66
+
67
+ <details>
68
+ <summary>[Click to expand]</summary>
69
+
70
+ To obtain the dataset:
71
+
72
+ 1. Create a developer account at [mapillary.com](https://www.mapillary.com/dashboard/developers) and obtain a free access token.
73
+ 2. Run the following script to download the data from Mapillary and prepare it:
74
+
75
+ ```bash
76
+ python -m maploc.data.mapillary.prepare --token $YOUR_ACCESS_TOKEN
77
+ ```
78
+
79
+ By default the data is written to the directory `./datasets/MGL/`. Then run the evaluation with the pre-trained model:
80
+
81
+ ```bash
82
+ python -m maploc.evaluation.mapillary --experiment OrienterNet_MGL model.num_rotations=256
83
+ ```
84
+
85
+ This downloads the pre-trained models if necessary. The results should be close to the following:
86
+
87
+ ```
88
+ Recall xy_max_error: [14.37, 48.69, 61.7] at (1, 3, 5) m/°
89
+ Recall yaw_max_error: [20.95, 54.96, 70.17] at (1, 3, 5) m/°
90
+ ```
91
+
92
+ Evaluating with 256 rotations requires a GPU with 11 GB of memory. If you run into out-of-memory (OOM) issues, consider reducing the number of rotations:
93
+
94
+ ```bash
95
+ python -m maploc.evaluation.mapillary --experiment OrienterNet_MGL \
96
+ model.num_rotations=128
97
+ ```
98
+
99
+ To export visualizations for the first 100 examples:
100
+
101
+ ```bash
102
+ python -m maploc.evaluation.mapillary --experiment OrienterNet_MGL \
103
+ --output_dir ./viz_MGL/ --num 100
104
+ ```
105
+
106
+ To run the evaluation in sequential mode (by default with 10 frames):
107
+
108
+ ```bash
109
+ python -m maploc.evaluation.mapillary --experiment OrienterNet_MGL --sequential
110
+ ```
111
+
112
+ </details>
113
+
114
+ #### KITTI dataset
115
+
116
+ <details>
117
+ <summary>[Click to expand]</summary>
118
+
119
+ 1. Download and prepare the dataset to `./datasets/kitti/`:
120
+
121
+ ```bash
122
+ python -m maploc.data.kitti.prepare
123
+ ```
124
+
125
+ 2. Run the evaluation with the model trained on MGL:
126
+
127
+ ```bash
128
+ python -m maploc.evaluation.kitti --experiment OrienterNet_MGL
129
+ ```
130
+
131
+ You should expect the following results:
132
+
133
+ ```
134
+ Recall directional_error: [[50.33, 85.18, 92.73], [24.38, 56.13, 67.98]] at (1, 3, 5) m/°
135
+ Recall yaw_max_error: [29.22, 68.2, 84.49] at (1, 3, 5) m/°
136
+ ```
137
+
138
+ You can similarly export some visual examples:
139
+
140
+ ```bash
141
+ python -m maploc.evaluation.kitti --experiment OrienterNet_MGL \
142
+ --output_dir ./viz_KITTI/ --num 100
143
+ ```
144
+
145
+ </details>
146
+
147
+ #### Aria Detroit & Seattle
148
+
149
+ We are currently unable to release the dataset used to evaluate OrienterNet in the CVPR 2023 paper.
150
+
151
+ ## Training
152
+
153
+ #### MGL dataset
154
+
155
+ We trained the model on the MGL dataset using 3x 3090 GPUs (24GB VRAM each) and a total batch size of 12 for 340k iterations (about 3-4 days) with the following command:
156
+
157
+ ```bash
158
+ python -m maploc.train experiment.name=OrienterNet_MGL_reproduce
159
+ ```
160
+
161
+ Feel free to use any other experiment name. Configurations are managed by [Hydra](https://hydra.cc/) and [OmegaConf](https://omegaconf.readthedocs.io) so any entry can be overridden from the command line. You may thus reduce the number of GPUs and the batch size via:
162
+
163
+ ```bash
164
+ python -m maploc.train experiment.name=OrienterNet_MGL_reproduce \
165
+ experiment.gpus=1 data.loading.train.batch_size=4
166
+ ```
167
+
168
+ Be aware that this can reduce the overall performance. The checkpoints are written to `./experiments/experiment_name/`. Then run the evaluation:
169
+
170
+ ```bash
171
+ # the best checkpoint:
172
+ python -m maploc.evaluation.mapillary --experiment OrienterNet_MGL_reproduce
173
+ # a specific checkpoint:
174
+ python -m maploc.evaluation.mapillary \
175
+ --experiment OrienterNet_MGL_reproduce/checkpoint-step=340000.ckpt
176
+ ```
177
+
178
+ #### KITTI
179
+
180
+ To fine-tune a trained model on the KITTI dataset:
181
+
182
+ ```bash
183
+ python -m maploc.train experiment.name=OrienterNet_MGL_kitti data=kitti \
184
+ training.finetune_from_checkpoint='"experiments/OrienterNet_MGL_reproduce/checkpoint-step=340000.ckpt"'
185
+ ```
186
+
187
+ ## Interactive development
188
+
189
+ We provide several visualization notebooks:
190
+
191
+ - [Visualize predictions on the MGL dataset](./notebooks/visualize_predictions_mgl.ipynb)
192
+ - [Visualize predictions on the KITTI dataset](./notebooks/visualize_predictions_kitti.ipynb)
193
+ - [Visualize sequential predictions](./notebooks/visualize_predictions_sequences.ipynb)
194
+
195
+ ## OpenStreetMap data
196
+
197
+ <details>
198
+ <summary>[Click to expand]</summary>
199
+
200
+ To make sure that the results are consistent over time, we used OSM data downloaded from [Geofabrik](https://download.geofabrik.de/) in November 2021. By default, the dataset scripts `maploc.data.[mapillary,kitti].prepare` download pre-generated raster tiles. If you wish to use different OSM classes, you can pass `--generate_tiles`, which will download and use our prepared raw `.osm` XML files. You may alternatively download more recent files.
201
+
202
+ </details>
203
+
204
+ ## License
205
+
206
+ The MGL dataset is made available under the [CC-BY-SA](https://creativecommons.org/licenses/by-sa/4.0/) license following the data available on the Mapillary platform. The model implementation and the pre-trained weights follow a [CC-BY-NC](https://creativecommons.org/licenses/by-nc/2.0/) license. Keep in mind that OpenStreetMap [follows a different license](https://www.openstreetmap.org/copyright).
207
+
208
+ ## BibTex citation
209
+
210
+ Please consider citing our work if you use any code from this repo or ideas presented in the paper:
211
+ ```
212
+ @inproceedings{sarlin2023orienternet,
213
+ author = {Paul-Edouard Sarlin and
214
+ Daniel DeTone and
215
+ Tsun-Yi Yang and
216
+ Armen Avetisyan and
217
+ Julian Straub and
218
+ Tomasz Malisiewicz and
219
+ Samuel Rota Bulo and
220
+ Richard Newcombe and
221
+ Peter Kontschieder and
222
+ Vasileios Balntas},
223
+ title = {{OrienterNet: Visual Localization in 2D Public Maps with Neural Matching}},
224
+ booktitle = {CVPR},
225
+ year = {2023},
226
+ }
227
+ ```
228
+
229
+
assets/demo.jpg ADDED
assets/teaser.svg ADDED
demo.ipynb ADDED
The diff for this file is too large to render.
maploc/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from pathlib import Path
4
+ import logging
5
+
6
+ import pytorch_lightning # noqa: F401
7
+
8
+
9
+ formatter = logging.Formatter(
10
+ fmt="[%(asctime)s %(name)s %(levelname)s] %(message)s",
11
+ datefmt="%Y-%m-%d %H:%M:%S",
12
+ )
13
+ handler = logging.StreamHandler()
14
+ handler.setFormatter(formatter)
15
+ handler.setLevel(logging.INFO)
16
+
17
+ logger = logging.getLogger("maploc")
18
+ logger.setLevel(logging.INFO)
19
+ logger.addHandler(handler)
20
+ logger.propagate = False
21
+
22
+ pl_logger = logging.getLogger("pytorch_lightning")
23
+ if len(pl_logger.handlers):
24
+ pl_logger.handlers[0].setFormatter(formatter)
25
+
26
+ repo_dir = Path(__file__).parent.parent
27
+ EXPERIMENTS_PATH = repo_dir / "experiments/"
28
+ DATASETS_PATH = repo_dir / "datasets/"
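*Illustrative note (not part of the commit):* the package module above exposes a shared logger and the default `experiments/` and `datasets/` roots. A minimal sketch of how downstream code consumes them, assuming the repository and its dependencies are installed or on `PYTHONPATH`:

```python
# Minimal sketch (assumes maploc is importable): reuse the package-level
# logger and dataset root defined in maploc/__init__.py.
from maploc import DATASETS_PATH, logger

kitti_root = DATASETS_PATH / "kitti"  # resolves to <repo>/datasets/kitti
logger.info("Expecting the KITTI data under %s", kitti_root)
```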
maploc/conf/__init__.py ADDED
File without changes
maploc/conf/data/__init__.py ADDED
File without changes
maploc/conf/data/kitti.yaml ADDED
@@ -0,0 +1,29 @@
1
+ name: kitti
2
+ loading:
3
+ train:
4
+ batch_size: 9
5
+ num_workers: ${.batch_size}
6
+ val:
7
+ batch_size: ${..train.batch_size}
8
+ num_workers: ${.batch_size}
9
+ # make sure train and val locations are at least 5m apart
10
+ selection_subset_val: furthest
11
+ max_num_val: 500
12
+ drop_train_too_close_to_val: 5.0
13
+ # map data
14
+ num_classes:
15
+ areas: 7
16
+ ways: 10
17
+ nodes: 33
18
+ pixel_per_meter: 2
19
+ crop_size_meters: 64
20
+ max_init_error: 32
21
+ # preprocessing
22
+ target_focal_length: 256
23
+ resize_image: [448, 160] # multiple of 32 at f=256px
24
+ # pad_to_multiple: 32
25
+ rectify_pitch: true
26
+ augmentation:
27
+ rot90: true
28
+ flip: true
29
+ image: {apply: true}
maploc/conf/data/mapillary.yaml ADDED
@@ -0,0 +1,40 @@
1
+ name: mapillary
2
+ scenes:
3
+ - sanfrancisco_soma
4
+ - sanfrancisco_hayes
5
+ - amsterdam
6
+ - berlin
7
+ - lemans
8
+ - montrouge
9
+ - toulouse
10
+ - nantes
11
+ - vilnius
12
+ - avignon
13
+ - helsinki
14
+ - milan
15
+ - paris
16
+ split: splits_MGL_13loc.json
17
+ loading:
18
+ train:
19
+ batch_size: 12
20
+ num_workers: ${.batch_size}
21
+ val:
22
+ batch_size: ${..train.batch_size}
23
+ num_workers: ${.batch_size}
24
+ # map data
25
+ num_classes:
26
+ areas: 7
27
+ ways: 10
28
+ nodes: 33
29
+ pixel_per_meter: 2
30
+ crop_size_meters: 64
31
+ max_init_error: 48
32
+ add_map_mask: true
33
+ # preprocessing
34
+ resize_image: 512
35
+ pad_to_square: true
36
+ rectify_pitch: true
37
+ augmentation:
38
+ rot90: true
39
+ flip: true
40
+ image: {apply: true}
maploc/conf/model/image_encoder/global.yaml ADDED
@@ -0,0 +1,9 @@
1
+ name: feature_extractor
2
+ backbone:
3
+ encoder: resnet18
4
+ pretrained: true
5
+ output_dim: ${...latent_dim}
6
+ output_scales: [5]
7
+ num_downsample: 5
8
+ decoder: null
9
+ pooling: mean
maploc/conf/model/image_encoder/resnet_fpn.yaml ADDED
@@ -0,0 +1,7 @@
1
+ name: feature_extractor_v2
2
+ backbone:
3
+ encoder: resnet50
4
+ pretrained: true
5
+ output_dim: ${...latent_dim}
6
+ num_downsample: null
7
+ remove_stride_from_first_conv: false
maploc/conf/model/image_encoder/vgg_unet.yaml ADDED
@@ -0,0 +1,8 @@
1
+ name: feature_extractor
2
+ backbone:
3
+ encoder: vgg16
4
+ pretrained: true
5
+ output_dim: ${...latent_dim}
6
+ output_scales: [0]
7
+ num_downsample: 4
8
+ decoder: [512, 256, 256, 128]
maploc/conf/orienternet.yaml ADDED
@@ -0,0 +1,34 @@
1
+ defaults:
2
+ - data: mapillary
3
+ - model/image_encoder: resnet_fpn
4
+ - training
5
+ - _self_
6
+ model:
7
+ name: orienternet
8
+ latent_dim: 128
9
+ matching_dim: 8
10
+ z_max: 32
11
+ x_max: 32
12
+ pixel_per_meter: ${data.pixel_per_meter}
13
+ num_scale_bins: 33
14
+ num_rotations: 64
15
+ image_encoder:
16
+ backbone:
17
+ encoder: resnet101
18
+ map_encoder:
19
+ embedding_dim: 16
20
+ output_dim: ${..matching_dim}
21
+ num_classes: ${data.num_classes}
22
+ backbone:
23
+ encoder: vgg19
24
+ pretrained: false
25
+ output_scales: [0]
26
+ num_downsample: 3
27
+ decoder: [128, 64, 64]
28
+ padding: replicate
29
+ unary_prior: false
30
+ bev_net:
31
+ num_blocks: 4
32
+ latent_dim: ${..latent_dim}
33
+ output_dim: ${..matching_dim}
34
+ confidence: true
maploc/conf/overfit.yaml ADDED
@@ -0,0 +1,17 @@
1
+ defaults:
2
+ - orienternet
3
+ - _self_
4
+ data:
5
+ loading:
6
+ train:
7
+ batch_size: 6
8
+ random: false
9
+ split: null
10
+ model:
11
+ freeze_batch_normalization: true
12
+ training:
13
+ trainer:
14
+ overfit_batches: 1
15
+ val_check_interval: 1
16
+ log_every_n_steps: 1
17
+ limit_val_batches: 1
maploc/conf/training.yaml ADDED
@@ -0,0 +1,22 @@
1
+ experiment:
2
+ name: ???
3
+ gpus: 3
4
+ seed: 0
5
+ training:
6
+ lr: 1e-4
7
+ lr_scheduler: null
8
+ finetune_from_checkpoint: null
9
+ trainer:
10
+ val_check_interval: 5000
11
+ log_every_n_steps: 100
12
+ limit_val_batches: 1000
13
+ max_steps: 500000
14
+ devices: ${experiment.gpus}
15
+ checkpointing:
16
+ monitor: "loss/total/val"
17
+ save_top_k: 5
18
+ mode: "min"
19
+ hydra:
20
+ job:
21
+ name: ${experiment.name}
22
+ chdir: false
maploc/data/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .kitti.dataset import KittiDataModule
2
+ from .mapillary.dataset import MapillaryDataModule
3
+
4
+ modules = {"mapillary": MapillaryDataModule, "kitti": KittiDataModule}
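*Illustrative note (not part of the commit):* a hypothetical sketch of how this registry can build a datamodule from a data config. The config values below are made up, and the KITTI data is assumed to have already been prepared with `maploc.data.kitti.prepare`.

```python
# Hypothetical usage of the `modules` registry: the `name` entry of a data
# config selects the corresponding LightningDataModule.
from omegaconf import OmegaConf

from maploc.data import modules

data_cfg = OmegaConf.create(
    {"name": "kitti", "loading": {"train": {"batch_size": 2, "num_workers": 0}}}
)
datamodule = modules[data_cfg.name](data_cfg)
datamodule.prepare_data()  # raises if ./datasets/kitti has not been downloaded
datamodule.setup("fit")
train_loader = datamodule.train_dataloader()
```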
maploc/data/dataset.py ADDED
@@ -0,0 +1,264 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from copy import deepcopy
4
+ from pathlib import Path
5
+ from typing import Any, Dict, List
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.utils.data as torchdata
10
+ import torchvision.transforms as tvf
11
+ from omegaconf import DictConfig, OmegaConf
12
+
13
+ from ..models.utils import deg2rad, rotmat2d
14
+ from ..osm.tiling import TileManager
15
+ from ..utils.geo import BoundaryBox
16
+ from ..utils.io import read_image
17
+ from ..utils.wrappers import Camera
18
+ from .image import pad_image, rectify_image, resize_image
19
+ from .utils import decompose_rotmat, random_flip, random_rot90
20
+
21
+
22
+ class MapLocDataset(torchdata.Dataset):
23
+ default_cfg = {
24
+ "seed": 0,
25
+ "accuracy_gps": 15,
26
+ "random": True,
27
+ "num_threads": None,
28
+ # map
29
+ "num_classes": None,
30
+ "pixel_per_meter": "???",
31
+ "crop_size_meters": "???",
32
+ "max_init_error": "???",
33
+ "max_init_error_rotation": None,
34
+ "init_from_gps": False,
35
+ "return_gps": False,
36
+ "force_camera_height": None,
37
+ # pose priors
38
+ "add_map_mask": False,
39
+ "mask_radius": None,
40
+ "mask_pad": 1,
41
+ "prior_range_rotation": None,
42
+ # image preprocessing
43
+ "target_focal_length": None,
44
+ "reduce_fov": None,
45
+ "resize_image": None,
46
+ "pad_to_square": False, # legacy
47
+ "pad_to_multiple": 32,
48
+ "rectify_pitch": True,
49
+ "augmentation": {
50
+ "rot90": False,
51
+ "flip": False,
52
+ "image": {
53
+ "apply": False,
54
+ "brightness": 0.5,
55
+ "contrast": 0.4,
56
+ "saturation": 0.4,
57
+ "hue": 0.5 / 3.14,
58
+ },
59
+ },
60
+ }
61
+
62
+ def __init__(
63
+ self,
64
+ stage: str,
65
+ cfg: DictConfig,
66
+ names: List[str],
67
+ data: Dict[str, Any],
68
+ image_dirs: Dict[str, Path],
69
+ tile_managers: Dict[str, TileManager],
70
+ image_ext: str = "",
71
+ ):
72
+ self.stage = stage
73
+ self.cfg = deepcopy(cfg)
74
+ self.data = data
75
+ self.image_dirs = image_dirs
76
+ self.tile_managers = tile_managers
77
+ self.names = names
78
+ self.image_ext = image_ext
79
+
80
+ tfs = []
81
+ if stage == "train" and cfg.augmentation.image.apply:
82
+ args = OmegaConf.masked_copy(
83
+ cfg.augmentation.image, ["brightness", "contrast", "saturation", "hue"]
84
+ )
85
+ tfs.append(tvf.ColorJitter(**args))
86
+ self.tfs = tvf.Compose(tfs)
87
+
88
+ def __len__(self):
89
+ return len(self.names)
90
+
91
+ def __getitem__(self, idx):
92
+ if self.stage == "train" and self.cfg.random:
93
+ seed = None
94
+ else:
95
+ seed = [self.cfg.seed, idx]
96
+ (seed,) = np.random.SeedSequence(seed).generate_state(1)
97
+
98
+ scene, seq, name = self.names[idx]
99
+ if self.cfg.init_from_gps:
100
+ latlon_gps = self.data["gps_position"][idx][:2].clone().numpy()
101
+ xy_w_init = self.tile_managers[scene].projection.project(latlon_gps)
102
+ else:
103
+ xy_w_init = self.data["t_c2w"][idx][:2].clone().double().numpy()
104
+
105
+ if "shifts" in self.data:
106
+ yaw = self.data["roll_pitch_yaw"][idx][-1]
107
+ R_c2w = rotmat2d((90 - yaw) / 180 * np.pi).float()
108
+ error = (R_c2w @ self.data["shifts"][idx][:2]).numpy()
109
+ else:
110
+ error = np.random.RandomState(seed).uniform(-1, 1, size=2)
111
+ xy_w_init += error * self.cfg.max_init_error
112
+
113
+ bbox_tile = BoundaryBox(
114
+ xy_w_init - self.cfg.crop_size_meters,
115
+ xy_w_init + self.cfg.crop_size_meters,
116
+ )
117
+ return self.get_view(idx, scene, seq, name, seed, bbox_tile)
118
+
119
+ def get_view(self, idx, scene, seq, name, seed, bbox_tile):
120
+ data = {
121
+ "index": idx,
122
+ "name": name,
123
+ "scene": scene,
124
+ "sequence": seq,
125
+ }
126
+ cam_dict = self.data["cameras"][scene][seq][self.data["camera_id"][idx]]
127
+ cam = Camera.from_dict(cam_dict).float()
128
+
129
+ if "roll_pitch_yaw" in self.data:
130
+ roll, pitch, yaw = self.data["roll_pitch_yaw"][idx].numpy()
131
+ else:
132
+ roll, pitch, yaw = decompose_rotmat(self.data["R_c2w"][idx].numpy())
133
+ image = read_image(self.image_dirs[scene] / (name + self.image_ext))
134
+
135
+ if "plane_params" in self.data:
136
+ # transform the plane parameters from world to camera frames
137
+ plane_w = self.data["plane_params"][idx]
138
+ data["ground_plane"] = torch.cat(
139
+ [rotmat2d(deg2rad(torch.tensor(yaw))) @ plane_w[:2], plane_w[2:]]
140
+ )
141
+ if self.cfg.force_camera_height is not None:
142
+ data["camera_height"] = torch.tensor(self.cfg.force_camera_height)
143
+ elif "camera_height" in self.data:
144
+ data["camera_height"] = self.data["height"][idx].clone()
145
+
146
+ # raster extraction
147
+ canvas = self.tile_managers[scene].query(bbox_tile)
148
+ xy_w_gt = self.data["t_c2w"][idx][:2].numpy()
149
+ uv_gt = canvas.to_uv(xy_w_gt)
150
+ uv_init = canvas.to_uv(bbox_tile.center)
151
+ raster = canvas.raster # C, H, W
152
+
153
+ # Map augmentations
154
+ heading = np.deg2rad(90 - yaw) # fixme
155
+ if self.stage == "train":
156
+ if self.cfg.augmentation.rot90:
157
+ raster, uv_gt, heading = random_rot90(raster, uv_gt, heading, seed)
158
+ if self.cfg.augmentation.flip:
159
+ image, raster, uv_gt, heading = random_flip(
160
+ image, raster, uv_gt, heading, seed
161
+ )
162
+ yaw = 90 - np.rad2deg(heading) # fixme
163
+
164
+ image, valid, cam, roll, pitch = self.process_image(
165
+ image, cam, roll, pitch, seed
166
+ )
167
+
168
+ # Create the mask for prior location
169
+ if self.cfg.add_map_mask:
170
+ data["map_mask"] = torch.from_numpy(self.create_map_mask(canvas))
171
+
172
+ if self.cfg.max_init_error_rotation is not None:
173
+ if "shifts" in self.data:
174
+ error = self.data["shifts"][idx][-1]
175
+ else:
176
+ error = np.random.RandomState(seed + 1).uniform(-1, 1)
177
+ error = torch.tensor(error, dtype=torch.float)
178
+ yaw_init = yaw + error * self.cfg.max_init_error_rotation
179
+ range_ = self.cfg.prior_range_rotation or self.cfg.max_init_error_rotation
180
+ data["yaw_prior"] = torch.stack([yaw_init, torch.tensor(range_)])
181
+
182
+ if self.cfg.return_gps:
183
+ gps = self.data["gps_position"][idx][:2].numpy()
184
+ xy_gps = self.tile_managers[scene].projection.project(gps)
185
+ data["uv_gps"] = torch.from_numpy(canvas.to_uv(xy_gps)).float()
186
+ data["accuracy_gps"] = torch.tensor(
187
+ min(self.cfg.accuracy_gps, self.cfg.crop_size_meters)
188
+ )
189
+
190
+ if "chunk_index" in self.data:
191
+ data["chunk_id"] = (scene, seq, self.data["chunk_index"][idx])
192
+
193
+ return {
194
+ **data,
195
+ "image": image,
196
+ "valid": valid,
197
+ "camera": cam,
198
+ "canvas": canvas,
199
+ "map": torch.from_numpy(np.ascontiguousarray(raster)).long(),
200
+ "uv": torch.from_numpy(uv_gt).float(), # TODO: maybe rename to uv?
201
+ "uv_init": torch.from_numpy(uv_init).float(), # TODO: maybe rename to uv?
202
+ "roll_pitch_yaw": torch.tensor((roll, pitch, yaw)).float(),
203
+ "pixels_per_meter": torch.tensor(canvas.ppm).float(),
204
+ }
205
+
206
+ def process_image(self, image, cam, roll, pitch, seed):
207
+ image = (
208
+ torch.from_numpy(np.ascontiguousarray(image))
209
+ .permute(2, 0, 1)
210
+ .float()
211
+ .div_(255)
212
+ )
213
+ image, valid = rectify_image(
214
+ image, cam, roll, pitch if self.cfg.rectify_pitch else None
215
+ )
216
+ roll = 0.0
217
+ if self.cfg.rectify_pitch:
218
+ pitch = 0.0
219
+
220
+ if self.cfg.target_focal_length is not None:
221
+ # resize to a canonical focal length
222
+ factor = self.cfg.target_focal_length / cam.f.numpy()
223
+ size = (np.array(image.shape[-2:][::-1]) * factor).astype(int)
224
+ image, _, cam, valid = resize_image(image, size, camera=cam, valid=valid)
225
+ size_out = self.cfg.resize_image
226
+ if size_out is None:
227
+ # round the edges up such that they are multiple of a factor
228
+ stride = self.cfg.pad_to_multiple
229
+ size_out = (np.ceil((size / stride)) * stride).astype(int)
230
+ # crop or pad such that both edges are of the given size
231
+ image, valid, cam = pad_image(
232
+ image, size_out, cam, valid, crop_and_center=True
233
+ )
234
+ elif self.cfg.resize_image is not None:
235
+ image, _, cam, valid = resize_image(
236
+ image, self.cfg.resize_image, fn=max, camera=cam, valid=valid
237
+ )
238
+ if self.cfg.pad_to_square:
239
+ # pad such that both edges are of the given size
240
+ image, valid, cam = pad_image(image, self.cfg.resize_image, cam, valid)
241
+
242
+ if self.cfg.reduce_fov is not None:
243
+ h, w = image.shape[-2:]
244
+ f = float(cam.f[0])
245
+ fov = np.arctan(w / f / 2)
246
+ w_new = round(2 * f * np.tan(self.cfg.reduce_fov * fov))
247
+ image, valid, cam = pad_image(
248
+ image, (w_new, h), cam, valid, crop_and_center=True
249
+ )
250
+
251
+ with torch.random.fork_rng(devices=[]):
252
+ torch.manual_seed(seed)
253
+ image = self.tfs(image)
254
+ return image, valid, cam, roll, pitch
255
+
256
+ def create_map_mask(self, canvas):
257
+ map_mask = np.zeros(canvas.raster.shape[-2:], bool)
258
+ radius = self.cfg.mask_radius or self.cfg.max_init_error
259
+ mask_min, mask_max = np.round(
260
+ canvas.to_uv(canvas.bbox.center)
261
+ + np.array([[-1], [1]]) * (radius + self.cfg.mask_pad) * canvas.ppm
262
+ ).astype(int)
263
+ map_mask[mask_min[1] : mask_max[1], mask_min[0] : mask_max[0]] = True
264
+ return map_mask
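*Illustrative note (not part of the commit):* a standalone numeric example of the geometry in `create_map_mask` above. The canvas size, crop center, and 32 m error radius are made-up values; `mask_pad = 1` and 2 pixels per meter match the defaults in this commit.

```python
# With an error radius of 32 m, mask_pad = 1 and ppm = 2 px/m, the prior mask
# extends (32 + 1) * 2 = 66 px on each side of the crop center.
import numpy as np

center_uv = np.array([128.0, 128.0])  # crop center in pixels (hypothetical canvas)
radius_px = (32 + 1) * 2
mask_min, mask_max = np.round(
    center_uv + np.array([[-1], [1]]) * radius_px
).astype(int)
mask = np.zeros((256, 256), bool)
mask[mask_min[1]:mask_max[1], mask_min[0]:mask_max[0]] = True
print(mask.sum())  # 132 * 132 = 17424 valid prior locations
```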
maploc/data/image.py ADDED
@@ -0,0 +1,140 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from typing import Callable, Optional, Union, Sequence
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torchvision.transforms.functional as tvf
8
+ import collections
9
+ from scipy.spatial.transform import Rotation
10
+
11
+ from ..utils.geometry import from_homogeneous, to_homogeneous
12
+ from ..utils.wrappers import Camera
13
+
14
+
15
+ def rectify_image(
16
+ image: torch.Tensor,
17
+ cam: Camera,
18
+ roll: float,
19
+ pitch: Optional[float] = None,
20
+ valid: Optional[torch.Tensor] = None,
21
+ ):
22
+ *_, h, w = image.shape
23
+ grid = torch.meshgrid(
24
+ [torch.arange(w, device=image.device), torch.arange(h, device=image.device)],
25
+ indexing="xy",
26
+ )
27
+ grid = torch.stack(grid, -1).to(image.dtype)
28
+
29
+ if pitch is not None:
30
+ args = ("ZX", (roll, pitch))
31
+ else:
32
+ args = ("Z", roll)
33
+ R = Rotation.from_euler(*args, degrees=True).as_matrix()
34
+ R = torch.from_numpy(R).to(image)
35
+
36
+ grid_rect = to_homogeneous(cam.normalize(grid)) @ R.T
37
+ grid_rect = cam.denormalize(from_homogeneous(grid_rect))
38
+ grid_norm = (grid_rect + 0.5) / grid.new_tensor([w, h]) * 2 - 1
39
+ rectified = torch.nn.functional.grid_sample(
40
+ image[None],
41
+ grid_norm[None],
42
+ align_corners=False,
43
+ mode="bilinear",
44
+ ).squeeze(0)
45
+ if valid is None:
46
+ valid = torch.all((grid_norm >= -1) & (grid_norm <= 1), -1)
47
+ else:
48
+ valid = (
49
+ torch.nn.functional.grid_sample(
50
+ valid[None, None].float(),
51
+ grid_norm[None],
52
+ align_corners=False,
53
+ mode="nearest",
54
+ )[0, 0]
55
+ > 0
56
+ )
57
+ return rectified, valid
58
+
59
+
60
+ def resize_image(
61
+ image: torch.Tensor,
62
+ size: Union[int, Sequence, np.ndarray],
63
+ fn: Optional[Callable] = None,
64
+ camera: Optional[Camera] = None,
65
+ valid: np.ndarray = None,
66
+ ):
67
+ """Resize an image to a fixed size, or according to max or min edge."""
68
+ *_, h, w = image.shape
69
+ if fn is not None:
70
+ assert isinstance(size, int)
71
+ scale = size / fn(h, w)
72
+ h_new, w_new = int(round(h * scale)), int(round(w * scale))
73
+ scale = (scale, scale)
74
+ else:
75
+ if isinstance(size, (collections.abc.Sequence, np.ndarray)):
76
+ w_new, h_new = size
77
+ elif isinstance(size, int):
78
+ w_new = h_new = size
79
+ else:
80
+ raise ValueError(f"Incorrect new size: {size}")
81
+ scale = (w_new / w, h_new / h)
82
+ if (w, h) != (w_new, h_new):
83
+ mode = tvf.InterpolationMode.BILINEAR
84
+ image = tvf.resize(image, (h_new, w_new), interpolation=mode, antialias=True)
85
+ image.clip_(0, 1)
86
+ if camera is not None:
87
+ camera = camera.scale(scale)
88
+ if valid is not None:
89
+ valid = tvf.resize(
90
+ valid.unsqueeze(0),
91
+ (h_new, w_new),
92
+ interpolation=tvf.InterpolationMode.NEAREST,
93
+ ).squeeze(0)
94
+ ret = [image, scale]
95
+ if camera is not None:
96
+ ret.append(camera)
97
+ if valid is not None:
98
+ ret.append(valid)
99
+ return ret
100
+
101
+
102
+ def pad_image(
103
+ image: torch.Tensor,
104
+ size: Union[int, Sequence, np.ndarray],
105
+ camera: Optional[Camera] = None,
106
+ valid: torch.Tensor = None,
107
+ crop_and_center: bool = False,
108
+ ):
109
+ if isinstance(size, int):
110
+ w_new = h_new = size
111
+ elif isinstance(size, (collections.abc.Sequence, np.ndarray)):
112
+ w_new, h_new = size
113
+ else:
114
+ raise ValueError(f"Incorrect new size: {size}")
115
+ *c, h, w = image.shape
116
+ if crop_and_center:
117
+ diff = np.array([w - w_new, h - h_new])
118
+ left, top = left_top = np.round(diff / 2).astype(int)
119
+ right, bottom = diff - left_top
120
+ else:
121
+ assert h <= h_new
122
+ assert w <= w_new
123
+ top = bottom = left = right = 0
124
+ slice_out = np.s_[..., : min(h, h_new), : min(w, w_new)]
125
+ slice_in = np.s_[
126
+ ..., max(top, 0) : h - max(bottom, 0), max(left, 0) : w - max(right, 0)
127
+ ]
128
+ if (w, h) == (w_new, h_new):
129
+ out = image
130
+ else:
131
+ out = torch.zeros((*c, h_new, w_new), dtype=image.dtype)
132
+ out[slice_out] = image[slice_in]
133
+ if camera is not None:
134
+ camera = camera.crop((max(left, 0), max(top, 0)), (w_new, h_new))
135
+ out_valid = torch.zeros((h_new, w_new), dtype=torch.bool)
136
+ out_valid[slice_out] = True if valid is None else valid[slice_in]
137
+ if camera is not None:
138
+ return out, out_valid, camera
139
+ else:
140
+ return out, out_valid
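*Illustrative note (not part of the commit):* a minimal sketch of chaining the helpers above on a dummy tensor, mirroring the `pad_to_square` preprocessing path without a `Camera`; the input shape and target size are arbitrary.

```python
# Resize a dummy image so its longest edge is 512 px, then zero-pad it to a
# 512x512 square; pad_image also returns the validity mask of the original pixels.
import torch

from maploc.data.image import pad_image, resize_image

image = torch.rand(3, 480, 640)
image, scale = resize_image(image, 512, fn=max)  # -> (3, 384, 512), scale = (0.8, 0.8)
image, valid = pad_image(image, 512)             # -> (3, 512, 512), valid: (512, 512)
print(image.shape, valid.shape)
```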
maploc/data/kitti/dataset.py ADDED
@@ -0,0 +1,306 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import collections
4
+ import collections.abc
5
+ from collections import defaultdict
6
+ from pathlib import Path
7
+ from typing import Optional
8
+
9
+ import numpy as np
10
+ import pytorch_lightning as pl
11
+ import torch
12
+ import torch.utils.data as torchdata
13
+ from omegaconf import OmegaConf
14
+ from scipy.spatial.transform import Rotation
15
+
16
+ from ... import logger, DATASETS_PATH
17
+ from ...osm.tiling import TileManager
18
+ from ..dataset import MapLocDataset
19
+ from ..sequential import chunk_sequence
20
+ from ..torch import collate, worker_init_fn
21
+ from .utils import parse_split_file, parse_gps_file, get_camera_calibration
22
+
23
+
24
+ class KittiDataModule(pl.LightningDataModule):
25
+ default_cfg = {
26
+ **MapLocDataset.default_cfg,
27
+ "name": "kitti",
28
+ # paths and fetch
29
+ "data_dir": DATASETS_PATH / "kitti",
30
+ "tiles_filename": "tiles.pkl",
31
+ "splits": {
32
+ "train": "train_files.txt",
33
+ "val": "test1_files.txt",
34
+ "test": "test2_files.txt",
35
+ },
36
+ "loading": {
37
+ "train": "???",
38
+ "val": "${.test}",
39
+ "test": {"batch_size": 1, "num_workers": 0},
40
+ },
41
+ "max_num_val": 500,
42
+ "selection_subset_val": "furthest",
43
+ "drop_train_too_close_to_val": 5.0,
44
+ "skip_frames": 1,
45
+ "camera_index": 2,
46
+ # overwrite
47
+ "crop_size_meters": 64,
48
+ "max_init_error": 20,
49
+ "max_init_error_rotation": 10,
50
+ "add_map_mask": True,
51
+ "mask_pad": 2,
52
+ "target_focal_length": 256,
53
+ }
54
+ dummy_scene_name = "kitti"
55
+
56
+ def __init__(self, cfg, tile_manager: Optional[TileManager] = None):
57
+ super().__init__()
58
+ default_cfg = OmegaConf.create(self.default_cfg)
59
+ OmegaConf.set_struct(default_cfg, True) # cannot add new keys
60
+ self.cfg = OmegaConf.merge(default_cfg, cfg)
61
+ self.root = Path(self.cfg.data_dir)
62
+ self.tile_manager = tile_manager
63
+ if self.cfg.crop_size_meters < self.cfg.max_init_error:
64
+ raise ValueError("The ground truth location can be outside the map.")
65
+ assert self.cfg.selection_subset_val in ["random", "furthest"]
66
+ self.splits = {}
67
+ self.shifts = {}
68
+ self.calibrations = {}
69
+ self.data = {}
70
+ self.image_paths = {}
71
+
72
+ def prepare_data(self):
73
+ if not (self.root.exists() and (self.root / ".downloaded").exists()):
74
+ raise FileNotFoundError(
75
+ "Cannot find the KITTI dataset, run maploc.data.kitti.prepare"
76
+ )
77
+
78
+ def parse_split(self, split_arg):
79
+ if isinstance(split_arg, str):
80
+ names, shifts = parse_split_file(self.root / split_arg)
81
+ elif isinstance(split_arg, collections.abc.Sequence):
82
+ names = []
83
+ shifts = None
84
+ for date_drive in split_arg:
85
+ data_dir = (
86
+ self.root / date_drive / f"image_{self.cfg.camera_index:02}/data"
87
+ )
88
+ assert data_dir.exists(), data_dir
89
+ date_drive = tuple(date_drive.split("/"))
90
+ n = sorted(date_drive + (p.name,) for p in data_dir.glob("*.png"))
91
+ names.extend(n[:: self.cfg.skip_frames])
92
+ else:
93
+ raise ValueError(split_arg)
94
+ return names, shifts
95
+
96
+ def setup(self, stage: Optional[str] = None):
97
+ if stage == "fit":
98
+ stages = ["train", "val"]
99
+ elif stage is None:
100
+ stages = ["train", "val", "test"]
101
+ else:
102
+ stages = [stage]
103
+ for stage in stages:
104
+ self.splits[stage], self.shifts[stage] = self.parse_split(
105
+ self.cfg.splits[stage]
106
+ )
107
+ do_val_subset = "val" in stages and self.cfg.max_num_val is not None
108
+ if do_val_subset and self.cfg.selection_subset_val == "random":
109
+ select = np.random.RandomState(self.cfg.seed).choice(
110
+ len(self.splits["val"]), self.cfg.max_num_val, replace=False
111
+ )
112
+ self.splits["val"] = [self.splits["val"][i] for i in select]
113
+ if self.shifts["val"] is not None:
114
+ self.shifts["val"] = self.shifts["val"][select]
115
+ dates = {d for ns in self.splits.values() for d, _, _ in ns}
116
+ for d in dates:
117
+ self.calibrations[d] = get_camera_calibration(
118
+ self.root / d, self.cfg.camera_index
119
+ )
120
+ if self.tile_manager is None:
121
+ logger.info("Loading the tile manager...")
122
+ self.tile_manager = TileManager.load(self.root / self.cfg.tiles_filename)
123
+ self.cfg.num_classes = {k: len(g) for k, g in self.tile_manager.groups.items()}
124
+ self.cfg.pixel_per_meter = self.tile_manager.ppm
125
+
126
+ # pack all attributes in a single tensor to optimize memory access
127
+ self.pack_data(stages)
128
+
129
+ dists = None
130
+ if do_val_subset and self.cfg.selection_subset_val == "furthest":
131
+ dists = torch.cdist(
132
+ self.data["val"]["t_c2w"][:, :2].double(),
133
+ self.data["train"]["t_c2w"][:, :2].double(),
134
+ )
135
+ min_dists = dists.min(1).values
136
+ select = torch.argsort(min_dists)[-self.cfg.max_num_val :]
137
+ dists = dists[select]
138
+ self.splits["val"] = [self.splits["val"][i] for i in select]
139
+ if self.shifts["val"] is not None:
140
+ self.shifts["val"] = self.shifts["val"][select]
141
+ for k in list(self.data["val"]):
142
+ if k != "cameras":
143
+ self.data["val"][k] = self.data["val"][k][select]
144
+ self.image_paths["val"] = self.image_paths["val"][select]
145
+
146
+ if "train" in stages and self.cfg.drop_train_too_close_to_val is not None:
147
+ if dists is None:
148
+ dists = torch.cdist(
149
+ self.data["val"]["t_c2w"][:, :2].double(),
150
+ self.data["train"]["t_c2w"][:, :2].double(),
151
+ )
152
+ drop = torch.any(dists < self.cfg.drop_train_too_close_to_val, 0)
153
+ select = torch.where(~drop)[0]
154
+ logger.info(
155
+ "Dropping %d (%f %%) images that are too close to validation images.",
156
+ drop.sum(),
157
+ drop.float().mean(),
158
+ )
159
+ self.splits["train"] = [self.splits["train"][i] for i in select]
160
+ if self.shifts["train"] is not None:
161
+ self.shifts["train"] = self.shifts["train"][select]
162
+ for k in list(self.data["train"]):
163
+ if k != "cameras":
164
+ self.data["train"][k] = self.data["train"][k][select]
165
+ self.image_paths["train"] = self.image_paths["train"][select]
166
+
167
+ def pack_data(self, stages):
168
+ for stage in stages:
169
+ names = []
170
+ data = {}
171
+ for i, (date, drive, index) in enumerate(self.splits[stage]):
172
+ d = self.get_frame_data(date, drive, index)
173
+ for k, v in d.items():
174
+ if i == 0:
175
+ data[k] = []
176
+ data[k].append(v)
177
+ path = f"{date}/{drive}/image_{self.cfg.camera_index:02}/data/{index}"
178
+ names.append((self.dummy_scene_name, f"{date}/{drive}", path))
179
+ for k in list(data):
180
+ data[k] = torch.from_numpy(np.stack(data[k]))
181
+ data["camera_id"] = np.full(len(names), self.cfg.camera_index)
182
+
183
+ sequences = {date_drive for _, date_drive, _ in names}
184
+ data["cameras"] = {
185
+ self.dummy_scene_name: {
186
+ seq: {
187
+ self.cfg.camera_index: self.calibrations[seq.split("/")[0]][0]
188
+ }
189
+ for seq in sequences
190
+ }
191
+ }
192
+ shifts = self.shifts[stage]
193
+ if shifts is not None:
194
+ data["shifts"] = torch.from_numpy(shifts.astype(np.float32))
195
+ self.data[stage] = data
196
+ self.image_paths[stage] = np.array(names)
197
+
198
+ def get_frame_data(self, date, drive, index):
199
+ _, R_cam_gps, t_cam_gps = self.calibrations[date]
200
+
201
+ # Transform the GPS pose to the camera pose
202
+ gps_path = (
203
+ self.root / date / drive / "oxts/data" / Path(index).with_suffix(".txt")
204
+ )
205
+ _, R_world_gps, t_world_gps = parse_gps_file(
206
+ gps_path, self.tile_manager.projection
207
+ )
208
+ R_world_cam = R_world_gps @ R_cam_gps.T
209
+ t_world_cam = t_world_gps - R_world_gps @ R_cam_gps.T @ t_cam_gps
210
+ # Some voodoo to extract correct Euler angles from R_world_cam
211
+ R_cv_xyz = Rotation.from_euler("YX", [-90, 90], degrees=True).as_matrix()
212
+ R_world_cam_xyz = R_world_cam @ R_cv_xyz
213
+ y, p, r = Rotation.from_matrix(R_world_cam_xyz).as_euler("ZYX", degrees=True)
214
+ roll, pitch, yaw = r, -p, 90 - y
215
+ roll_pitch_yaw = np.array([-roll, -pitch, yaw], np.float32) # for some reason
216
+
217
+ return {
218
+ "t_c2w": t_world_cam.astype(np.float32),
219
+ "roll_pitch_yaw": roll_pitch_yaw,
220
+ "index": int(index.split(".")[0]),
221
+ }
222
+
223
+ def dataset(self, stage: str):
224
+ return MapLocDataset(
225
+ stage,
226
+ self.cfg,
227
+ self.image_paths[stage],
228
+ self.data[stage],
229
+ {self.dummy_scene_name: self.root},
230
+ {self.dummy_scene_name: self.tile_manager},
231
+ )
232
+
233
+ def dataloader(
234
+ self,
235
+ stage: str,
236
+ shuffle: bool = False,
237
+ num_workers: int = None,
238
+ sampler: Optional[torchdata.Sampler] = None,
239
+ ):
240
+ dataset = self.dataset(stage)
241
+ cfg = self.cfg["loading"][stage]
242
+ num_workers = cfg["num_workers"] if num_workers is None else num_workers
243
+ loader = torchdata.DataLoader(
244
+ dataset,
245
+ batch_size=cfg["batch_size"],
246
+ num_workers=num_workers,
247
+ shuffle=shuffle or (stage == "train"),
248
+ pin_memory=True,
249
+ persistent_workers=num_workers > 0,
250
+ worker_init_fn=worker_init_fn,
251
+ collate_fn=collate,
252
+ sampler=sampler,
253
+ )
254
+ return loader
255
+
256
+ def train_dataloader(self, **kwargs):
257
+ return self.dataloader("train", **kwargs)
258
+
259
+ def val_dataloader(self, **kwargs):
260
+ return self.dataloader("val", **kwargs)
261
+
262
+ def test_dataloader(self, **kwargs):
263
+ return self.dataloader("test", **kwargs)
264
+
265
+ def sequence_dataset(self, stage: str, **kwargs):
266
+ keys = self.image_paths[stage]
267
+ # group images by sequence (date/drive)
268
+ seq2indices = defaultdict(list)
269
+ for index, (_, date_drive, _) in enumerate(keys):
270
+ seq2indices[date_drive].append(index)
271
+ # chunk the sequences to the required length
272
+ chunk2indices = {}
273
+ for seq, indices in seq2indices.items():
274
+ chunks = chunk_sequence(
275
+ self.data[stage], indices, names=self.image_paths[stage], **kwargs
276
+ )
277
+ for i, sub_indices in enumerate(chunks):
278
+ chunk2indices[seq, i] = sub_indices
279
+ # store the index of each chunk in its sequence
280
+ chunk_indices = torch.full((len(keys),), -1)
281
+ for (_, chunk_index), idx in chunk2indices.items():
282
+ chunk_indices[idx] = chunk_index
283
+ self.data[stage]["chunk_index"] = chunk_indices
284
+ dataset = self.dataset(stage)
285
+ return dataset, chunk2indices
286
+
287
+ def sequence_dataloader(self, stage: str, shuffle: bool = False, **kwargs):
288
+ dataset, chunk2idx = self.sequence_dataset(stage, **kwargs)
289
+ seq_keys = sorted(chunk2idx)
290
+ if shuffle:
291
+ perm = torch.randperm(len(seq_keys))
292
+ seq_keys = [seq_keys[i] for i in perm]
293
+ key_indices = [i for key in seq_keys for i in chunk2idx[key]]
294
+ num_workers = self.cfg["loading"][stage]["num_workers"]
295
+ loader = torchdata.DataLoader(
296
+ dataset,
297
+ batch_size=None,
298
+ sampler=key_indices,
299
+ num_workers=num_workers,
300
+ shuffle=False,
301
+ pin_memory=True,
302
+ persistent_workers=num_workers > 0,
303
+ worker_init_fn=worker_init_fn,
304
+ collate_fn=collate,
305
+ )
306
+ return loader, seq_keys, chunk2idx
maploc/data/kitti/prepare.py ADDED
@@ -0,0 +1,123 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import argparse
4
+ from pathlib import Path
5
+ import shutil
6
+ import zipfile
7
+
8
+ import numpy as np
9
+ from tqdm.auto import tqdm
10
+
11
+ from ... import logger
12
+ from ...osm.tiling import TileManager
13
+ from ...osm.viz import GeoPlotter
14
+ from ...utils.geo import BoundaryBox, Projection
15
+ from ...utils.io import download_file, DATA_URL
16
+ from .utils import parse_gps_file
17
+ from .dataset import KittiDataModule
18
+
19
+ split_files = ["test1_files.txt", "test2_files.txt", "train_files.txt"]
20
+
21
+
22
+ def prepare_osm(
23
+ data_dir,
24
+ osm_path,
25
+ output_path,
26
+ tile_margin=512,
27
+ ppm=2,
28
+ ):
29
+ all_latlon = []
30
+ for gps_path in data_dir.glob("2011_*/*/oxts/data/*.txt"):
31
+ all_latlon.append(parse_gps_file(gps_path)[0])
32
+ if not all_latlon:
33
+ raise ValueError(f"Cannot find any GPS file in {data_dir}.")
34
+ all_latlon = np.stack(all_latlon)
35
+ projection = Projection.from_points(all_latlon)
36
+ all_xy = projection.project(all_latlon)
37
+ bbox_map = BoundaryBox(all_xy.min(0), all_xy.max(0)) + tile_margin
38
+
39
+ plotter = GeoPlotter()
40
+ plotter.points(all_latlon, "red", name="GPS")
41
+ plotter.bbox(projection.unproject(bbox_map), "blue", "tiling bounding box")
42
+ plotter.fig.write_html(data_dir / "split_kitti.html")
43
+
44
+ tile_manager = TileManager.from_bbox(
45
+ projection,
46
+ bbox_map,
47
+ ppm,
48
+ path=osm_path,
49
+ )
50
+ tile_manager.save(output_path)
51
+ return tile_manager
52
+
53
+
54
+ def download(data_dir: Path):
55
+ data_dir.mkdir(exist_ok=True, parents=True)
56
+ this_dir = Path(__file__).parent
57
+
58
+ seqs = set()
59
+ for f in split_files:
60
+ shutil.copy(this_dir / f, data_dir)
61
+ with open(this_dir / f, "r") as fid:
62
+ info = fid.read()
63
+ for line in info.split("\n"):
64
+ if line:
65
+ seq = line.split()[0].split("/")[1][: -len("_sync")]
66
+ seqs.add(seq)
67
+ dates = {"_".join(s.split("_")[:3]) for s in seqs}
68
+ logger.info("Downloading data for %d sequences in %d dates", len(seqs), len(dates))
69
+
70
+ for seq in tqdm(seqs):
71
+ logger.info("Working on %s.", seq)
72
+ date = "_".join(seq.split("_")[:3])
73
+ url = f"https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/{seq}/{seq}_sync.zip"
74
+ seq_dir = data_dir / date / f"{seq}_sync"
75
+ if seq_dir.exists():
76
+ continue
77
+ zip_path = download_file(url, data_dir)
78
+ with zipfile.ZipFile(zip_path, "r") as z:
79
+ z.extractall(data_dir)
80
+ # Delete unused files to save space.
81
+ for image_index in [0, 1, 3]:
82
+ shutil.rmtree(seq_dir / f"image_0{image_index}")
83
+ shutil.rmtree(seq_dir / "velodyne_points")
84
+ Path(zip_path).unlink()
86
+
87
+ for date in tqdm(dates):
88
+ url = (
89
+ f"https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/{date}_calib.zip"
90
+ )
91
+ zip_path = download_file(url, data_dir)
92
+ with zipfile.ZipFile(zip_path, "r") as z:
93
+ z.extractall(data_dir)
94
+ Path(zip_path).unlink()
95
+
96
+
97
+ if __name__ == "__main__":
98
+ parser = argparse.ArgumentParser()
99
+ parser.add_argument(
100
+ "--data_dir", type=Path, default=Path(KittiDataModule.default_cfg["local_dir"])
101
+ )
102
+ parser.add_argument("--pixel_per_meter", type=int, default=2)
103
+ parser.add_argument("--generate_tiles", action="store_true")
104
+ args = parser.parse_args()
105
+
106
+ args.data_dir.mkdir(exist_ok=True, parents=True)
107
+ download(args.data_dir)
108
+
109
+ tiles_path = args.data_dir / KittiDataModule.default_cfg["tiles_filename"]
110
+ if args.generate_tiles:
111
+ logger.info("Generating the map tiles.")
112
+ osm_filename = "karlsruhe.osm"
113
+ osm_path = args.data_dir / osm_filename
114
+ if not osm_path.exists():
115
+ logger.info("Downloading OSM raw data.")
116
+ download_file(DATA_URL + f"/osm/{osm_filename}", osm_path)
117
+ if not osm_path.exists():
118
+ raise FileNotFoundError(f"No OSM data file at {osm_path}.")
119
+ prepare_osm(args.data_dir, osm_path, tiles_path, ppm=args.pixel_per_meter)
120
+ (args.data_dir / ".downloaded").touch()
121
+ else:
122
+ logger.info("Downloading pre-generated map tiles.")
123
+ download_file(DATA_URL + "/tiles/kitti.pkl", tiles_path)
maploc/data/kitti/test1_files.txt ADDED
The diff for this file is too large to render. See raw diff
 
maploc/data/kitti/test2_files.txt ADDED
The diff for this file is too large to render. See raw diff
 
maploc/data/kitti/train_files.txt ADDED
The diff for this file is too large to render. See raw diff
 
maploc/data/kitti/utils.py ADDED
@@ -0,0 +1,79 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ from scipy.spatial.transform import Rotation
7
+
8
+ from ...utils.geo import Projection
9
+
10
+ split_files = ["test1_files.txt", "test2_files.txt", "train_files.txt"]
11
+
12
+
13
+ def parse_gps_file(path, projection: Projection = None):
14
+ with open(path, "r") as fid:
15
+ lat, lon, _, roll, pitch, yaw, *_ = map(float, fid.read().split())
16
+ latlon = np.array([lat, lon])
17
+ R_world_gps = Rotation.from_euler("ZYX", [yaw, pitch, roll]).as_matrix()
18
+ t_world_gps = None if projection is None else np.r_[projection.project(latlon), 0]
19
+ return latlon, R_world_gps, t_world_gps
20
+
21
+
22
+ def parse_split_file(path: Path):
23
+ with open(path, "r") as fid:
24
+ info = fid.read()
25
+ names = []
26
+ shifts = []
27
+ for line in info.split("\n"):
28
+ if not line:
29
+ continue
30
+ name, *shift = line.split()
31
+ names.append(tuple(name.split("/")))
32
+ if len(shift) > 0:
33
+ assert len(shift) == 3
34
+ shifts.append(np.array(shift, float))
35
+ shifts = None if len(shifts) == 0 else np.stack(shifts)
36
+ return names, shifts
37
+
38
+
39
+ def parse_calibration_file(path):
40
+ calib = {}
41
+ with open(path, "r") as fid:
42
+ for line in fid.read().split("\n"):
43
+ if not line:
44
+ continue
45
+ key, *data = line.split(" ")
46
+ key = key.rstrip(":")
47
+ if key.startswith("R"):
48
+ data = np.array(data, float).reshape(3, 3)
49
+ elif key.startswith("T"):
50
+ data = np.array(data, float).reshape(3)
51
+ elif key.startswith("P"):
52
+ data = np.array(data, float).reshape(3, 4)
53
+ calib[key] = data
54
+ return calib
55
+
56
+
57
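+ # Compose the GPS/IMU-to-camera transform for camera cam_index by chaining the
+ # imu-to-velo and velo-to-cam0 calibrations with the rectification and baseline.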
+ def get_camera_calibration(calib_dir, cam_index: int):
58
+ calib_path = calib_dir / "calib_cam_to_cam.txt"
59
+ calib_cam = parse_calibration_file(calib_path)
60
+ P = calib_cam[f"P_rect_{cam_index:02}"]
61
+ K = P[:3, :3]
62
+ size = np.array(calib_cam[f"S_rect_{cam_index:02}"], float).astype(int)
63
+ camera = {
64
+ "model": "PINHOLE",
65
+ "width": size[0],
66
+ "height": size[1],
67
+ "params": K[[0, 1, 0, 1], [0, 1, 2, 2]],
68
+ }
69
+
70
+ t_cam_cam0 = P[:3, 3] / K[[0, 1, 2], [0, 1, 2]]
71
+ R_rect_cam0 = calib_cam["R_rect_00"]
72
+
73
+ calib_gps_velo = parse_calibration_file(calib_dir / "calib_imu_to_velo.txt")
74
+ calib_velo_cam0 = parse_calibration_file(calib_dir / "calib_velo_to_cam.txt")
75
+ R_cam0_gps = calib_velo_cam0["R"] @ calib_gps_velo["R"]
76
+ t_cam0_gps = calib_velo_cam0["R"] @ calib_gps_velo["T"] + calib_velo_cam0["T"]
77
+ R_cam_gps = R_rect_cam0 @ R_cam0_gps
78
+ t_cam_gps = t_cam_cam0 + R_rect_cam0 @ t_cam0_gps
79
+ return camera, R_cam_gps, t_cam_gps
maploc/data/mapillary/dataset.py ADDED
@@ -0,0 +1,350 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import json
4
+ from collections import defaultdict
5
+ import os
6
+ import shutil
7
+ import tarfile
8
+ from pathlib import Path
9
+ from typing import Any, Dict, Optional
10
+
11
+ import numpy as np
12
+ import pytorch_lightning as pl
13
+ import torch
14
+ import torch.utils.data as torchdata
15
+ from omegaconf import DictConfig, OmegaConf
16
+
17
+ from ... import logger, DATASETS_PATH
18
+ from ...osm.tiling import TileManager
19
+ from ..dataset import MapLocDataset
20
+ from ..sequential import chunk_sequence
21
+ from ..torch import collate, worker_init_fn
22
+
23
+
24
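+ # Convert the lists loaded from JSON into numpy arrays and drop per-view chunk ids.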
+ def pack_dump_dict(dump):
25
+ for per_seq in dump.values():
26
+ if "points" in per_seq:
27
+ for chunk in list(per_seq["points"]):
28
+ points = per_seq["points"].pop(chunk)
29
+ if points is not None:
30
+ per_seq["points"][chunk] = np.array(
31
+ points, np.float64
32
+ )
33
+ for view in per_seq["views"].values():
34
+ for k in ["R_c2w", "roll_pitch_yaw"]:
35
+ view[k] = np.array(view[k], np.float32)
36
+ for k in ["chunk_id"]:
37
+ if k in view:
38
+ view.pop(k)
39
+ if "observations" in view:
40
+ view["observations"] = np.array(view["observations"])
41
+ for camera in per_seq["cameras"].values():
42
+ for k in ["params"]:
43
+ camera[k] = np.array(camera[k], np.float32)
44
+ return dump
45
+
46
+
47
+ class MapillaryDataModule(pl.LightningDataModule):
48
+ dump_filename = "dump.json"
49
+ images_archive = "images.tar.gz"
50
+ images_dirname = "images/"
51
+
52
+ default_cfg = {
53
+ **MapLocDataset.default_cfg,
54
+ "name": "mapillary",
55
+ # paths and fetch
56
+ "data_dir": DATASETS_PATH / "MGL",
57
+ "local_dir": None,
58
+ "tiles_filename": "tiles.pkl",
59
+ "scenes": "???",
60
+ "split": None,
61
+ "loading": {
62
+ "train": "???",
63
+ "val": "${.test}",
64
+ "test": {"batch_size": 1, "num_workers": 0},
65
+ },
66
+ "filter_for": None,
67
+ "filter_by_ground_angle": None,
68
+ "min_num_points": "???",
69
+ }
70
+
71
+ def __init__(self, cfg: Dict[str, Any]):
72
+ super().__init__()
73
+ default_cfg = OmegaConf.create(self.default_cfg)
74
+ OmegaConf.set_struct(default_cfg, True) # cannot add new keys
75
+ self.cfg = OmegaConf.merge(default_cfg, cfg)
76
+ self.root = Path(self.cfg.data_dir)
77
+ self.local_dir = self.cfg.local_dir or os.environ.get("TMPDIR")
78
+ if self.local_dir is not None:
79
+ self.local_dir = Path(self.local_dir, "MGL")
80
+ if self.cfg.crop_size_meters < self.cfg.max_init_error:
81
+ raise ValueError("The ground truth location can be outside the map.")
82
+
83
+ def prepare_data(self):
84
+ for scene in self.cfg.scenes:
85
+ dump_dir = self.root / scene
86
+ assert (dump_dir / self.dump_filename).exists(), dump_dir
87
+ assert (dump_dir / self.cfg.tiles_filename).exists(), dump_dir
88
+ if self.local_dir is None:
89
+ assert (dump_dir / self.images_dirname).exists(), dump_dir
90
+ continue
91
+ # Cache the folder of images locally to speed up reading
92
+ local_dir = self.local_dir / scene
93
+ if local_dir.exists():
94
+ shutil.rmtree(local_dir)
95
+ local_dir.mkdir(exist_ok=True, parents=True)
96
+ images_archive = dump_dir / self.images_archive
97
+ logger.info("Extracting the image archive %s.", images_archive)
98
+ with tarfile.open(images_archive) as fp:
99
+ fp.extractall(local_dir)
100
+
101
+ def setup(self, stage: Optional[str] = None):
102
+ self.dumps = {}
103
+ self.tile_managers = {}
104
+ self.image_dirs = {}
105
+ names = []
106
+
107
+ for scene in self.cfg.scenes:
108
+ logger.info("Loading scene %s.", scene)
109
+ dump_dir = self.root / scene
110
+ logger.info("Loading map tiles %s.", self.cfg.tiles_filename)
111
+ self.tile_managers[scene] = TileManager.load(
112
+ dump_dir / self.cfg.tiles_filename
113
+ )
114
+ groups = self.tile_managers[scene].groups
115
+ if self.cfg.num_classes: # check consistency
116
+ if set(groups.keys()) != set(self.cfg.num_classes.keys()):
117
+ raise ValueError(
118
+ f"Inconsistent groups: {groups.keys()} {self.cfg.num_classes.keys()}"
119
+ )
120
+ for k in groups:
121
+ if len(groups[k]) != self.cfg.num_classes[k]:
122
+ raise ValueError(
123
+ f"{k}: {len(groups[k])} vs {self.cfg.num_classes[k]}"
124
+ )
125
+ ppm = self.tile_managers[scene].ppm
126
+ if ppm != self.cfg.pixel_per_meter:
127
+ raise ValueError(
128
+ "The tile manager and the config/model have different ground resolutions: "
129
+ f"{ppm} vs {self.cfg.pixel_per_meter}"
130
+ )
131
+
132
+ logger.info("Loading dump json file %s.", self.dump_filename)
133
+ with (dump_dir / self.dump_filename).open("r") as fp:
134
+ self.dumps[scene] = pack_dump_dict(json.load(fp))
135
+ for seq, per_seq in self.dumps[scene].items():
136
+ for cam_id, cam_dict in per_seq["cameras"].items():
137
+ if cam_dict["model"] != "PINHOLE":
138
+ raise ValueError(
139
+ f"Unsupported camera model: {cam_dict['model']} for {scene},{seq},{cam_id}"
140
+ )
141
+
142
+ self.image_dirs[scene] = (
143
+ (self.local_dir or self.root) / scene / self.images_dirname
144
+ )
145
+ assert self.image_dirs[scene].exists(), self.image_dirs[scene]
146
+
147
+ for seq, data in self.dumps[scene].items():
148
+ for name in data["views"]:
149
+ names.append((scene, seq, name))
150
+
151
+ self.parse_splits(self.cfg.split, names)
152
+ if self.cfg.filter_for is not None:
153
+ self.filter_elements()
154
+ self.pack_data()
155
+
156
+ def pack_data(self):
157
+ # We pack the data into compact tensors that can be shared across processes without copy
158
+ exclude = {
159
+ "compass_angle",
160
+ "compass_accuracy",
161
+ "gps_accuracy",
162
+ "chunk_key",
163
+ "panorama_offset",
164
+ }
165
+ cameras = {
166
+ scene: {seq: per_seq["cameras"] for seq, per_seq in per_scene.items()}
167
+ for scene, per_scene in self.dumps.items()
168
+ }
169
+ points = {
170
+ scene: {
171
+ seq: {
172
+ i: torch.from_numpy(p) for i, p in per_seq.get("points", {}).items()
173
+ }
174
+ for seq, per_seq in per_scene.items()
175
+ }
176
+ for scene, per_scene in self.dumps.items()
177
+ }
178
+ self.data = {}
179
+ for stage, names in self.splits.items():
180
+ view = self.dumps[names[0][0]][names[0][1]]["views"][names[0][2]]
181
+ data = {k: [] for k in view.keys() - exclude}
182
+ for scene, seq, name in names:
183
+ for k in data:
184
+ data[k].append(self.dumps[scene][seq]["views"][name].get(k, None))
185
+ for k in data:
186
+ v = np.array(data[k])
187
+ if np.issubdtype(v.dtype, np.integer) or np.issubdtype(
188
+ v.dtype, np.floating
189
+ ):
190
+ v = torch.from_numpy(v)
191
+ data[k] = v
192
+ data["cameras"] = cameras
193
+ data["points"] = points
194
+ self.data[stage] = data
195
+ self.splits[stage] = np.array(names)
196
+
197
+ def filter_elements(self):
198
+ for stage, names in self.splits.items():
199
+ names_select = []
200
+ for scene, seq, name in names:
201
+ view = self.dumps[scene][seq]["views"][name]
202
+ if self.cfg.filter_for == "ground_plane":
203
+ if not (1.0 <= view["height"] <= 3.0):
204
+ continue
205
+ planes = self.dumps[scene][seq].get("plane")
206
+ if planes is not None:
207
+ inliers = planes[str(view["chunk_id"])][-1]
208
+ if inliers < 10:
209
+ continue
210
+ if self.cfg.filter_by_ground_angle is not None:
211
+ plane = np.array(view["plane_params"])
212
+ normal = plane[:3] / np.linalg.norm(plane[:3])
213
+ angle = np.rad2deg(np.arccos(np.abs(normal[-1])))
214
+ if angle > self.cfg.filter_by_ground_angle:
215
+ continue
216
+ elif self.cfg.filter_for == "pointcloud":
217
+ if len(view["observations"]) < self.cfg.min_num_points:
218
+ continue
219
+ elif self.cfg.filter_for is not None:
220
+ raise ValueError(f"Unknown filtering: {self.cfg.filter_for}")
221
+ names_select.append((scene, seq, name))
222
+ logger.info(
223
+ "%s: Keep %d/%d images after filtering for %s.",
224
+ stage,
225
+ len(names_select),
226
+ len(names),
227
+ self.cfg.filter_for,
228
+ )
229
+ self.splits[stage] = names_select
230
+
231
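+ # The split can be None (train = val = all images), an integer (size of a random
+ # validation hold-out), per-stage lists of scenes, or the name of a JSON file
+ # listing the image ids of each location per stage.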
+ def parse_splits(self, split_arg, names):
232
+ if split_arg is None:
233
+ self.splits = {
234
+ "train": names,
235
+ "val": names,
236
+ }
237
+ elif isinstance(split_arg, int):
238
+ names = np.random.RandomState(self.cfg.seed).permutation(names).tolist()
239
+ self.splits = {
240
+ "train": names[split_arg:],
241
+ "val": names[:split_arg],
242
+ }
243
+ elif isinstance(split_arg, DictConfig):
244
+ scenes_val = set(split_arg.val)
245
+ scenes_train = set(split_arg.train)
246
+ assert len(scenes_val - set(self.cfg.scenes)) == 0
247
+ assert len(scenes_train - set(self.cfg.scenes)) == 0
248
+ self.splits = {
249
+ "train": [n for n in names if n[0] in scenes_train],
250
+ "val": [n for n in names if n[0] in scenes_val],
251
+ }
252
+ elif isinstance(split_arg, str):
253
+ with (self.root / split_arg).open("r") as fp:
254
+ splits = json.load(fp)
255
+ splits = {
256
+ k: {loc: set(ids) for loc, ids in split.items()}
257
+ for k, split in splits.items()
258
+ }
259
+ self.splits = {}
260
+ for k, split in splits.items():
261
+ self.splits[k] = [
262
+ n
263
+ for n in names
264
+ if n[0] in split and int(n[-1].rsplit("_", 1)[0]) in split[n[0]]
265
+ ]
266
+ else:
267
+ raise ValueError(split_arg)
268
+
269
+ def dataset(self, stage: str):
270
+ return MapLocDataset(
271
+ stage,
272
+ self.cfg,
273
+ self.splits[stage],
274
+ self.data[stage],
275
+ self.image_dirs,
276
+ self.tile_managers,
277
+ image_ext=".jpg",
278
+ )
279
+
280
+ def dataloader(
281
+ self,
282
+ stage: str,
283
+ shuffle: bool = False,
284
+ num_workers: int = None,
285
+ sampler: Optional[torchdata.Sampler] = None,
286
+ ):
287
+ dataset = self.dataset(stage)
288
+ cfg = self.cfg["loading"][stage]
289
+ num_workers = cfg["num_workers"] if num_workers is None else num_workers
290
+ loader = torchdata.DataLoader(
291
+ dataset,
292
+ batch_size=cfg["batch_size"],
293
+ num_workers=num_workers,
294
+ shuffle=shuffle or (stage == "train"),
295
+ pin_memory=True,
296
+ persistent_workers=num_workers > 0,
297
+ worker_init_fn=worker_init_fn,
298
+ collate_fn=collate,
299
+ sampler=sampler,
300
+ )
301
+ return loader
302
+
303
+ def train_dataloader(self, **kwargs):
304
+ return self.dataloader("train", **kwargs)
305
+
306
+ def val_dataloader(self, **kwargs):
307
+ return self.dataloader("val", **kwargs)
308
+
309
+ def test_dataloader(self, **kwargs):
310
+ return self.dataloader("test", **kwargs)
311
+
312
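+ # Group the views of each sequence into chunks (see chunk_sequence) and store the
+ # index of each view's chunk in the packed data.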
+ def sequence_dataset(self, stage: str, **kwargs):
313
+ keys = self.splits[stage]
314
+ seq2indices = defaultdict(list)
315
+ for index, (_, seq, _) in enumerate(keys):
316
+ seq2indices[seq].append(index)
317
+ # chunk the sequences to the required length
318
+ chunk2indices = {}
319
+ for seq, indices in seq2indices.items():
320
+ chunks = chunk_sequence(self.data[stage], indices, **kwargs)
321
+ for i, sub_indices in enumerate(chunks):
322
+ chunk2indices[seq, i] = sub_indices
323
+ # store the index of each chunk in its sequence
324
+ chunk_indices = torch.full((len(keys),), -1)
325
+ for (_, chunk_index), idx in chunk2indices.items():
326
+ chunk_indices[idx] = chunk_index
327
+ self.data[stage]["chunk_index"] = chunk_indices
328
+ dataset = self.dataset(stage)
329
+ return dataset, chunk2indices
330
+
331
+ def sequence_dataloader(self, stage: str, shuffle: bool = False, **kwargs):
332
+ dataset, chunk2idx = self.sequence_dataset(stage, **kwargs)
333
+ chunk_keys = sorted(chunk2idx)
334
+ if shuffle:
335
+ perm = torch.randperm(len(chunk_keys))
336
+ chunk_keys = [chunk_keys[i] for i in perm]
337
+ key_indices = [i for key in chunk_keys for i in chunk2idx[key]]
338
+ num_workers = self.cfg["loading"][stage]["num_workers"]
339
+ loader = torchdata.DataLoader(
340
+ dataset,
341
+ batch_size=None,
342
+ sampler=key_indices,
343
+ num_workers=num_workers,
344
+ shuffle=False,
345
+ pin_memory=True,
346
+ persistent_workers=num_workers > 0,
347
+ worker_init_fn=worker_init_fn,
348
+ collate_fn=collate,
349
+ )
350
+ return loader, chunk_keys, chunk2idx
maploc/data/mapillary/download.py ADDED
@@ -0,0 +1,180 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import json
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import httpx
8
+ import asyncio
9
+ from aiolimiter import AsyncLimiter
10
+ import tqdm
11
+
12
+ from opensfm.pygeometry import Camera, Pose
13
+ from opensfm.pymap import Shot
14
+
15
+ from ... import logger
16
+ from ...utils.geo import Projection
17
+
18
+
19
+ semaphore = asyncio.Semaphore(100)  # maximum number of concurrent requests.
20
+ image_filename = "{image_id}.jpg"
21
+ info_filename = "{image_id}.json"
22
+
23
+
24
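+ # Minimal asynchronous client for the Mapillary Graph API, rate-limited to half of
+ # the allowed number of requests per minute.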
+ class MapillaryDownloader:
25
+ image_fields = (
26
+ "id",
27
+ "height",
28
+ "width",
29
+ "camera_parameters",
30
+ "camera_type",
31
+ "captured_at",
32
+ "compass_angle",
33
+ "geometry",
34
+ "altitude",
35
+ "computed_compass_angle",
36
+ "computed_geometry",
37
+ "computed_altitude",
38
+ "computed_rotation",
39
+ "thumb_2048_url",
40
+ "thumb_original_url",
41
+ "sequence",
42
+ "sfm_cluster",
43
+ )
44
+ image_info_url = (
45
+ "https://graph.mapillary.com/{image_id}?access_token={token}&fields={fields}"
46
+ )
47
+ seq_info_url = "https://graph.mapillary.com/image_ids?access_token={token}&sequence_id={seq_id}"
48
+ max_requests_per_minute = 50_000
49
+
50
+ def __init__(self, token: str):
51
+ self.token = token
52
+ self.client = httpx.AsyncClient(
53
+ transport=httpx.AsyncHTTPTransport(retries=20), timeout=20.0
54
+ )
55
+ self.limiter = AsyncLimiter(self.max_requests_per_minute // 2, time_period=60)
56
+
57
+ async def call_api(self, url: str):
58
+ async with self.limiter:
59
+ r = await self.client.get(url)
60
+ if not r.is_success:
61
+ logger.error("Error in API call: %s", r.text)
62
+ return r
63
+
64
+ async def get_image_info(self, image_id: int):
65
+ url = self.image_info_url.format(
66
+ image_id=image_id,
67
+ token=self.token,
68
+ fields=",".join(self.image_fields),
69
+ )
70
+ r = await self.call_api(url)
71
+ if r.is_success:
72
+ return json.loads(r.text)
73
+
74
+ async def get_sequence_info(self, seq_id: str):
75
+ url = self.seq_info_url.format(seq_id=seq_id, token=self.token)
76
+ r = await self.call_api(url)
77
+ if r.is_success:
78
+ return json.loads(r.text)
79
+
80
+ async def download_image_pixels(self, url: str, path: Path):
81
+ r = await self.call_api(url)
82
+ if r.is_success:
83
+ with open(path, "wb") as fid:
84
+ fid.write(r.content)
85
+ return r.is_success
86
+
87
+ async def get_image_info_cached(self, image_id: int, path: Path):
88
+ if path.exists():
89
+ info = json.loads(path.read_text())
90
+ else:
91
+ info = await self.get_image_info(image_id)
92
+ path.write_text(json.dumps(info))
93
+ return info
94
+
95
+ async def download_image_pixels_cached(self, url: str, path: Path):
96
+ if path.exists():
97
+ return True
98
+ else:
99
+ return await self.download_image_pixels(url, path)
100
+
101
+
102
+ async def fetch_images_in_sequence(i, downloader):
103
+ async with semaphore:
104
+ info = await downloader.get_sequence_info(i)
105
+ image_ids = [int(d["id"]) for d in info["data"]]
106
+ return i, image_ids
107
+
108
+
109
+ async def fetch_images_in_sequences(sequence_ids, downloader):
110
+ seq_to_images_ids = {}
111
+ tasks = [fetch_images_in_sequence(i, downloader) for i in sequence_ids]
112
+ for task in tqdm.asyncio.tqdm.as_completed(tasks):
113
+ i, image_ids = await task
114
+ seq_to_images_ids[i] = image_ids
115
+ return seq_to_images_ids
116
+
117
+
118
+ async def fetch_image_info(i, downloader, dir_):
119
+ async with semaphore:
120
+ path = dir_ / info_filename.format(image_id=i)
121
+ info = await downloader.get_image_info_cached(i, path)
122
+ return i, info
123
+
124
+
125
+ async def fetch_image_infos(image_ids, downloader, dir_):
126
+ infos = {}
127
+ num_fail = 0
128
+ tasks = [fetch_image_info(i, downloader, dir_) for i in image_ids]
129
+ for task in tqdm.asyncio.tqdm.as_completed(tasks):
130
+ i, info = await task
131
+ if info is None:
132
+ num_fail += 1
133
+ else:
134
+ infos[i] = info
135
+ return infos, num_fail
136
+
137
+
138
+ async def fetch_image_pixels(i, url, downloader, dir_, overwrite=False):
139
+ async with semaphore:
140
+ path = dir_ / image_filename.format(image_id=i)
141
+ if overwrite:
142
+ path.unlink(missing_ok=True)
143
+ success = await downloader.download_image_pixels_cached(url, path)
144
+ return i, success
145
+
146
+
147
+ async def fetch_images_pixels(image_urls, downloader, dir_):
148
+ num_fail = 0
149
+ tasks = [fetch_image_pixels(*id_url, downloader, dir_) for id_url in image_urls]
150
+ for task in tqdm.asyncio.tqdm.as_completed(tasks):
151
+ i, success = await task
152
+ num_fail += not success
153
+ return num_fail
154
+
155
+
156
+ def opensfm_camera_from_info(info: dict) -> Camera:
157
+ cam_type = info["camera_type"]
158
+ if cam_type == "perspective":
159
+ camera = Camera.create_perspective(*info["camera_parameters"])
160
+ elif cam_type == "fisheye":
161
+ camera = Camera.create_fisheye(*info["camera_parameters"])
162
+ elif Camera.is_panorama(cam_type):
163
+ camera = Camera.create_spherical()
164
+ else:
165
+ raise ValueError(cam_type)
166
+ camera.width = info["width"]
167
+ camera.height = info["height"]
168
+ camera.id = info["id"]
169
+ return camera
170
+
171
+
172
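+ # Build an OpenSfM camera and shot from the Mapillary metadata, with the computed
+ # geolocation projected into the local XYZ frame.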
+ def opensfm_shot_from_info(info: dict, projection: Projection) -> Shot:
173
+ latlong = info["computed_geometry"]["coordinates"][::-1]
174
+ alt = info["computed_altitude"]
175
+ xyz = projection.project(np.array([*latlong, alt]), return_z=True)
176
+ c_rotvec_w = np.array(info["computed_rotation"])
177
+ pose = Pose()
178
+ pose.set_from_cam_to_world(-c_rotvec_w, xyz)
179
+ camera = opensfm_camera_from_info(info)
180
+ return latlong, Shot(info["id"], camera, pose)
maploc/data/mapillary/prepare.py ADDED
@@ -0,0 +1,406 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import asyncio
4
+ import argparse
5
+ from collections import defaultdict
6
+ import json
7
+ import shutil
8
+ from pathlib import Path
9
+ from typing import List
10
+
11
+ import numpy as np
12
+ import cv2
13
+ from tqdm import tqdm
14
+ from tqdm.contrib.concurrent import thread_map
15
+ from omegaconf import DictConfig, OmegaConf
16
+ from opensfm.pygeometry import Camera
17
+ from opensfm.pymap import Shot
18
+ from opensfm.undistort import (
19
+ perspective_camera_from_fisheye,
20
+ perspective_camera_from_perspective,
21
+ )
22
+
23
+ from ... import logger
24
+ from ...osm.tiling import TileManager
25
+ from ...osm.viz import GeoPlotter
26
+ from ...utils.geo import BoundaryBox, Projection
27
+ from ...utils.io import write_json, download_file, DATA_URL
28
+ from ..utils import decompose_rotmat
29
+ from .utils import (
30
+ keyframe_selection,
31
+ perspective_camera_from_pano,
32
+ scale_camera,
33
+ CameraUndistorter,
34
+ PanoramaUndistorter,
35
+ undistort_shot,
36
+ )
37
+ from .download import (
38
+ MapillaryDownloader,
39
+ opensfm_shot_from_info,
40
+ image_filename,
41
+ fetch_image_infos,
42
+ fetch_images_pixels,
43
+ )
44
+ from .dataset import MapillaryDataModule
45
+
46
+
47
+ location_to_params = {
48
+ "sanfrancisco_soma": {
49
+ "bbox": BoundaryBox(
50
+ [-122.410307, 37.770364][::-1], [-122.388772, 37.795545][::-1]
51
+ ),
52
+ "camera_models": ["GoPro Max"],
53
+ "osm_file": "sanfrancisco.osm",
54
+ },
55
+ "sanfrancisco_hayes": {
56
+ "bbox": BoundaryBox(
57
+ [-122.438415, 37.768634][::-1], [-122.410605, 37.783894][::-1]
58
+ ),
59
+ "camera_models": ["GoPro Max"],
60
+ "osm_file": "sanfrancisco.osm",
61
+ },
62
+ "amsterdam": {
63
+ "bbox": BoundaryBox([4.845284, 52.340679][::-1], [4.926147, 52.386299][::-1]),
64
+ "camera_models": ["GoPro Max"],
65
+ "osm_file": "amsterdam.osm",
66
+ },
67
+ "lemans": {
68
+ "bbox": BoundaryBox([0.185752, 47.995125][::-1], [0.224088, 48.014209][::-1]),
69
+ "owners": ["xXOocM1jUB4jaaeukKkmgw"], # sogefi
70
+ "osm_file": "lemans.osm",
71
+ },
72
+ "berlin": {
73
+ "bbox": BoundaryBox([13.416271, 52.459656][::-1], [13.469829, 52.499195][::-1]),
74
+ "owners": ["LT3ajUxH6qsosamrOHIrFw"], # supaplex030
75
+ "osm_file": "berlin.osm",
76
+ },
77
+ "montrouge": {
78
+ "bbox": BoundaryBox([2.298958, 48.80874][::-1], [2.332989, 48.825276][::-1]),
79
+ "owners": [
80
+ "XtzGKZX2_VIJRoiJ8IWRNQ",
81
+ "C4ENdWpJdFNf8CvnQd7NrQ",
82
+ "e_ZBE6mFd7CYNjRSpLl-Lg",
83
+ ], # overflorian, phyks, francois2
84
+ "camera_models": ["LG-R105"],
85
+ "osm_file": "paris.osm",
86
+ },
87
+ "nantes": {
88
+ "bbox": BoundaryBox([-1.585839, 47.198289][::-1], [-1.51318, 47.236161][::-1]),
89
+ "owners": [
90
+ "jGdq3CL-9N-Esvj3mtCWew",
91
+ "s-j5BH9JRIzsgORgaJF3aA",
92
+ ], # c_mobilite, cartocite
93
+ "osm_file": "nantes.osm",
94
+ },
95
+ "toulouse": {
96
+ "bbox": BoundaryBox([1.429457, 43.591434][::-1], [1.456653, 43.61343][::-1]),
97
+ "owners": ["MNkhq6MCoPsdQNGTMh3qsQ"], # tyndare
98
+ "osm_file": "toulouse.osm",
99
+ },
100
+ "vilnius": {
101
+ "bbox": BoundaryBox([25.258633, 54.672956][::-1], [25.296094, 54.696755][::-1]),
102
+ "owners": ["bClduFF6Gq16cfwCdhWivw", "u5ukBseATUS8jUbtE43fcO"], # kedas, vms
103
+ "osm_file": "vilnius.osm",
104
+ },
105
+ "helsinki": {
106
+ "bbox": BoundaryBox(
107
+ [24.8975480117, 60.1449128318][::-1], [24.9816543235, 60.1770977471][::-1]
108
+ ),
109
+ "camera_types": ["spherical", "equirectangular"],
110
+ "osm_file": "helsinki.osm",
111
+ },
112
+ "milan": {
113
+ "bbox": BoundaryBox(
114
+ [9.1732723899, 45.4810977947][::-1],
115
+ [9.2255987917, 45.5284238563][::-1],
116
+ ),
117
+ "camera_types": ["spherical", "equirectangular"],
118
+ "osm_file": "milan.osm",
119
+ },
120
+ "avignon": {
121
+ "bbox": BoundaryBox(
122
+ [4.7887045302, 43.9416178156][::-1], [4.8227015622, 43.9584848909][::-1]
123
+ ),
124
+ "camera_types": ["spherical", "equirectangular"],
125
+ "osm_file": "avignon.osm",
126
+ },
127
+ "paris": {
128
+ "bbox": BoundaryBox([2.306823, 48.833827][::-1], [2.39067, 48.889335][::-1]),
129
+ "camera_types": ["spherical", "equirectangular"],
130
+ "osm_file": "paris.osm",
131
+ },
132
+ }
133
+
134
+
135
+ cfg = OmegaConf.create(
136
+ {
137
+ "max_image_size": 512,
138
+ "do_legacy_pano_offset": True,
139
+ "min_dist_between_keyframes": 4,
140
+ "tiling": {
141
+ "tile_size": 128,
142
+ "margin": 128,
143
+ "ppm": 2,
144
+ },
145
+ }
146
+ )
147
+
148
+
149
+ def get_pano_offset(image_info: dict, do_legacy: bool = False) -> float:
150
+ if do_legacy:
151
+ seed = int(image_info["sfm_cluster"]["id"])
152
+ else:
153
+ seed = image_info["sequence"].__hash__()
154
+ seed = seed % (2**32 - 1)
155
+ return np.random.RandomState(seed).uniform(-45, 45)
156
+
157
+
158
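+ # Undistort a single view and write the results to output_dir: panoramas are
+ # rendered into perspective crops, fisheye and perspective images are undistorted
+ # and resized to the maximum image size.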
+ def process_shot(
159
+ shot: Shot, info: dict, image_path: Path, output_dir: Path, cfg: DictConfig
160
+ ) -> List[Shot]:
161
+ if not image_path.exists():
162
+ return None
163
+
164
+ image_orig = cv2.imread(str(image_path))
165
+ max_size = cfg.max_image_size
166
+ pano_offset = None
167
+
168
+ camera = shot.camera
169
+ camera.width, camera.height = image_orig.shape[:2][::-1]
170
+ if camera.is_panorama(camera.projection_type):
171
+ camera_new = perspective_camera_from_pano(camera, max_size)
172
+ undistorter = PanoramaUndistorter(camera, camera_new)
173
+ pano_offset = get_pano_offset(info, cfg.do_legacy_pano_offset)
174
+ elif camera.projection_type in ["fisheye", "perspective"]:
175
+ if camera.projection_type == "fisheye":
176
+ camera_new = perspective_camera_from_fisheye(camera)
177
+ else:
178
+ camera_new = perspective_camera_from_perspective(camera)
179
+ camera_new = scale_camera(camera_new, max_size)
180
+ camera_new.id = camera.id + "_undistorted"
181
+ undistorter = CameraUndistorter(camera, camera_new)
182
+ else:
183
+ raise NotImplementedError(camera.projection_type)
184
+
185
+ shots_undist, images_undist = undistort_shot(
186
+ image_orig, shot, undistorter, pano_offset
187
+ )
188
+ for shot, image in zip(shots_undist, images_undist):
189
+ cv2.imwrite(str(output_dir / f"{shot.id}.jpg"), image)
190
+
191
+ return shots_undist
192
+
193
+
194
+ def pack_shot_dict(shot: Shot, info: dict) -> dict:
195
+ latlong = info["computed_geometry"]["coordinates"][::-1]
196
+ latlong_gps = info["geometry"]["coordinates"][::-1]
197
+ w_p_c = shot.pose.get_origin()
198
+ w_r_c = shot.pose.get_R_cam_to_world()
199
+ rpy = decompose_rotmat(w_r_c)
200
+ return dict(
201
+ camera_id=shot.camera.id,
202
+ latlong=latlong,
203
+ t_c2w=w_p_c,
204
+ R_c2w=w_r_c,
205
+ roll_pitch_yaw=rpy,
206
+ capture_time=info["captured_at"],
207
+ gps_position=np.r_[latlong_gps, info["altitude"]],
208
+ compass_angle=info["compass_angle"],
209
+ chunk_id=int(info["sfm_cluster"]["id"]),
210
+ )
211
+
212
+
213
+ def pack_camera_dict(camera: Camera) -> dict:
214
+ assert camera.projection_type == "perspective"
215
+ K = camera.get_K_in_pixel_coordinates(camera.width, camera.height)
216
+ return dict(
217
+ id=camera.id,
218
+ model="PINHOLE",
219
+ width=camera.width,
220
+ height=camera.height,
221
+ params=K[[0, 1, 0, 1], [0, 1, 2, 2]],
222
+ )
223
+
224
+
225
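+ # Sort the views of a sequence by capture time, select keyframes by travelled
+ # distance, undistort them, and pack the results into per-sequence view and camera
+ # dictionaries.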
+ def process_sequence(
226
+ image_ids: List[int],
227
+ image_infos: dict,
228
+ projection: Projection,
229
+ cfg: DictConfig,
230
+ raw_image_dir: Path,
231
+ out_image_dir: Path,
232
+ ):
233
+ shots = []
234
+ image_ids = sorted(image_ids, key=lambda i: image_infos[i]["captured_at"])
235
+ for i in image_ids:
236
+ _, shot = opensfm_shot_from_info(image_infos[i], projection)
237
+ shots.append(shot)
238
+ if not shots:
239
+ return {}
240
+
241
+ shot_idxs = keyframe_selection(shots, min_dist=cfg.min_dist_between_keyframes)
242
+ shots = [shots[i] for i in shot_idxs]
243
+
244
+ shots_out = thread_map(
245
+ lambda shot: process_shot(
246
+ shot,
247
+ image_infos[int(shot.id)],
248
+ raw_image_dir / image_filename.format(image_id=shot.id),
249
+ out_image_dir,
250
+ cfg,
251
+ ),
252
+ shots,
253
+ disable=True,
254
+ )
255
+ shots_out = [(i, s) for i, ss in enumerate(shots_out) if ss is not None for s in ss]
256
+
257
+ dump = {}
258
+ for index, shot in shots_out:
259
+ i, suffix = shot.id.rsplit("_", 1)
260
+ info = image_infos[int(i)]
261
+ seq_id = info["sequence"]
262
+ is_pano = not suffix.endswith("undistorted")
263
+ if is_pano:
264
+ seq_id += f"_{suffix}"
265
+ if seq_id not in dump:
266
+ dump[seq_id] = dict(views={}, cameras={})
267
+
268
+ view = pack_shot_dict(shot, info)
269
+ view["index"] = index
270
+ dump[seq_id]["views"][shot.id] = view
271
+ dump[seq_id]["cameras"][shot.camera.id] = pack_camera_dict(shot.camera)
272
+ return dump
273
+
274
+
275
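+ # Full pipeline for one location: fetch the image metadata and pixels, undistort
+ # and dump all views, create or download the map tiles, and plot the data split.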
+ def process_location(
276
+ location: str,
277
+ data_dir: Path,
278
+ split_path: Path,
279
+ token: str,
280
+ generate_tiles: bool = False,
281
+ ):
282
+ params = location_to_params[location]
283
+ bbox = params["bbox"]
284
+ projection = Projection(*bbox.center)
285
+
286
+ splits = json.loads(split_path.read_text())
287
+ image_ids = [i for split in splits.values() for i in split[location]]
288
+
289
+ loc_dir = data_dir / location
290
+ infos_dir = loc_dir / "image_infos"
291
+ raw_image_dir = loc_dir / "images_raw"
292
+ out_image_dir = loc_dir / "images"
293
+ for d in (infos_dir, raw_image_dir, out_image_dir):
294
+ d.mkdir(parents=True, exist_ok=True)
295
+
296
+ downloader = MapillaryDownloader(token)
297
+ loop = asyncio.get_event_loop()
298
+
299
+ logger.info("Fetching metadata for all images.")
300
+ image_infos, num_fail = loop.run_until_complete(
301
+ fetch_image_infos(image_ids, downloader, infos_dir)
302
+ )
303
+ logger.info("%d failures (%.1f%%).", num_fail, 100 * num_fail / len(image_ids))
304
+
305
+ logger.info("Fetching image pixels.")
306
+ image_urls = [(i, info["thumb_2048_url"]) for i, info in image_infos.items()]
307
+ num_fail = loop.run_until_complete(
308
+ fetch_images_pixels(image_urls, downloader, raw_image_dir)
309
+ )
310
+ logger.info("%d failures (%.1f%%).", num_fail, 100 * num_fail / len(image_urls))
311
+
312
+ seq_to_image_ids = defaultdict(list)
313
+ for i, info in image_infos.items():
314
+ seq_to_image_ids[info["sequence"]].append(i)
315
+ seq_to_image_ids = dict(seq_to_image_ids)
316
+
317
+ dump = {}
318
+ for seq_image_ids in tqdm(seq_to_image_ids.values()):
319
+ dump.update(
320
+ process_sequence(
321
+ seq_image_ids,
322
+ image_infos,
323
+ projection,
324
+ cfg,
325
+ raw_image_dir,
326
+ out_image_dir,
327
+ )
328
+ )
329
+ write_json(loc_dir / "dump.json", dump)
330
+
331
+ # Get the view locations
332
+ view_ids = []
333
+ views_latlon = []
334
+ for seq in dump:
335
+ for view_id, view in dump[seq]["views"].items():
336
+ view_ids.append(view_id)
337
+ views_latlon.append(view["latlong"])
338
+ views_latlon = np.stack(views_latlon)
339
+ view_ids = np.array(view_ids)
340
+ views_xy = projection.project(views_latlon)
341
+
342
+ tiles_path = loc_dir / MapillaryDataModule.default_cfg["tiles_filename"]
343
+ if generate_tiles:
344
+ logger.info("Creating the map tiles.")
345
+ bbox_data = BoundaryBox(views_xy.min(0), views_xy.max(0))
346
+ bbox_tiling = bbox_data + cfg.tiling.margin
347
+ osm_dir = data_dir / "osm"
348
+ osm_path = osm_dir / params["osm_file"]
349
+ if not osm_path.exists():
350
+ logger.info("Downloading OSM raw data.")
351
+ download_file(DATA_URL + f"/osm/{params['osm_file']}", osm_path)
352
+ if not osm_path.exists():
353
+ raise FileNotFoundError(f"Cannot find OSM data file {osm_path}.")
354
+ tile_manager = TileManager.from_bbox(
355
+ projection,
356
+ bbox_tiling,
357
+ cfg.tiling.ppm,
358
+ tile_size=cfg.tiling.tile_size,
359
+ path=osm_path,
360
+ )
361
+ tile_manager.save(tiles_path)
362
+ else:
363
+ logger.info("Downloading pre-generated map tiles.")
364
+ download_file(DATA_URL + f"/tiles/{location}.pkl", tiles_path)
365
+
366
+ # Visualize the data split
367
+ plotter = GeoPlotter()
368
+ view_ids_val = set(splits["val"][location])
369
+ is_val = np.array([int(i.rsplit("_", 1)[0]) in view_ids_val for i in view_ids])
370
+ plotter.points(views_latlon[~is_val], "red", view_ids[~is_val], "train")
371
+ plotter.points(views_latlon[is_val], "green", view_ids[is_val], "val")
372
+ plotter.bbox(bbox, "blue", "query bounding box")
373
+ plotter.bbox(projection.unproject(bbox_tiling), "black", "tiling bounding box")
374
+ geo_viz_path = loc_dir / f"split_{location}.html"
375
+ plotter.fig.write_html(geo_viz_path)
376
+ logger.info("Wrote split visualization to %s.", geo_viz_path)
377
+
378
+ shutil.rmtree(raw_image_dir)
379
+ logger.info("Done processing for location %s.", location)
380
+
381
+
382
+ if __name__ == "__main__":
383
+ parser = argparse.ArgumentParser()
384
+ parser.add_argument(
385
+ "--locations", type=str, nargs="+", default=list(location_to_params)
386
+ )
387
+ parser.add_argument("--split_filename", type=str, default="splits_MGL_13loc.json")
388
+ parser.add_argument("--token", type=str, required=True)
389
+ parser.add_argument(
390
+ "--data_dir", type=Path, default=MapillaryDataModule.default_cfg["data_dir"]
391
+ )
392
+ parser.add_argument("--generate_tiles", action="store_true")
393
+ args = parser.parse_args()
394
+
395
+ args.data_dir.mkdir(exist_ok=True, parents=True)
396
+ shutil.copy(Path(__file__).parent / args.split_filename, args.data_dir)
397
+
398
+ for location in args.locations:
399
+ logger.info("Starting processing for location %s.", location)
400
+ process_location(
401
+ location,
402
+ args.data_dir,
403
+ args.data_dir / args.split_filename,
404
+ args.token,
405
+ args.generate_tiles,
406
+ )
maploc/data/mapillary/splits_MGL_13loc.json ADDED
The diff for this file is too large to render. See raw diff
 
maploc/data/mapillary/utils.py ADDED
@@ -0,0 +1,173 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import logging
4
+ from typing import List, Tuple
5
+
6
+ import cv2
7
+ import numpy as np
8
+ from opensfm import features
9
+ from opensfm.pygeometry import Camera, compute_camera_mapping, Pose
10
+ from opensfm.pymap import Shot
11
+ from scipy.spatial.transform import Rotation
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
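+ # Select a keyframe whenever the distance travelled since the last selected frame
+ # reaches min_dist.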
+ def keyframe_selection(shots: List[Shot], min_dist: float = 4) -> List[int]:
17
+ camera_centers = np.stack([shot.pose.get_origin() for shot in shots], 0)
18
+ distances = np.linalg.norm(np.diff(camera_centers, axis=0), axis=1)
19
+ selected = [0]
20
+ cum = 0
21
+ for i in range(1, len(camera_centers)):
22
+ cum += distances[i - 1]
23
+ if cum >= min_dist:
24
+ selected.append(i)
25
+ cum = 0
26
+ return selected
27
+
28
+
29
+ def perspective_camera_from_pano(camera: Camera, size: int) -> Camera:
30
+ camera_new = Camera.create_perspective(0.5, 0, 0)
31
+ camera_new.height = camera_new.width = size
32
+ camera_new.id = "perspective_from_pano"
33
+ return camera_new
34
+
35
+
36
+ def scale_camera(camera: Camera, max_size: int) -> Camera:
37
+ height = camera.height
38
+ width = camera.width
39
+ factor = max_size / float(max(height, width))
40
+ if factor >= 1:
41
+ return camera
42
+ camera.width = int(round(width * factor))
43
+ camera.height = int(round(height * factor))
44
+ return camera
45
+
46
+
47
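+ # Precompute the bearings of the target perspective pixels; __call__ rotates them
+ # into the panorama frame and samples the colors with cv2.remap.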
+ class PanoramaUndistorter:
48
+ def __init__(self, camera_pano: Camera, camera_new: Camera):
49
+ w, h = camera_new.width, camera_new.height
50
+ self.shape = (h, w)
51
+
52
+ dst_y, dst_x = np.indices(self.shape).astype(np.float32)
53
+ dst_pixels_denormalized = np.column_stack([dst_x.ravel(), dst_y.ravel()])
54
+ dst_pixels = features.normalized_image_coordinates(
55
+ dst_pixels_denormalized, w, h
56
+ )
57
+ self.dst_bearings = camera_new.pixel_bearing_many(dst_pixels)
58
+
59
+ self.camera_pano = camera_pano
60
+ self.camera_perspective = camera_new
61
+
62
+ def __call__(
63
+ self, image: np.ndarray, panoshot: Shot, perspectiveshot: Shot
64
+ ) -> np.ndarray:
65
+ # Rotate to panorama reference frame
66
+ rotation = np.dot(
67
+ panoshot.pose.get_rotation_matrix(),
68
+ perspectiveshot.pose.get_rotation_matrix().T,
69
+ )
70
+ rotated_bearings = np.dot(self.dst_bearings, rotation.T)
71
+
72
+ # Project to panorama pixels
73
+ src_pixels = panoshot.camera.project_many(rotated_bearings)
74
+ src_pixels_denormalized = features.denormalized_image_coordinates(
75
+ src_pixels, image.shape[1], image.shape[0]
76
+ )
77
+ src_pixels_denormalized.shape = self.shape + (2,)
78
+
79
+ # Sample color
80
+ x = src_pixels_denormalized[..., 0].astype(np.float32)
81
+ y = src_pixels_denormalized[..., 1].astype(np.float32)
82
+ colors = cv2.remap(image, x, y, cv2.INTER_LINEAR, borderMode=cv2.BORDER_WRAP)
83
+ return colors
84
+
85
+
86
+ class CameraUndistorter:
87
+ def __init__(self, camera_distorted: Camera, camera_new: Camera):
88
+ self.maps = compute_camera_mapping(
89
+ camera_distorted,
90
+ camera_new,
91
+ camera_distorted.width,
92
+ camera_distorted.height,
93
+ )
94
+ self.camera_perspective = camera_new
95
+ self.camera_distorted = camera_distorted
96
+
97
+ def __call__(self, image: np.ndarray) -> np.ndarray:
98
+ assert image.shape[:2] == (
99
+ self.camera_distorted.height,
100
+ self.camera_distorted.width,
101
+ )
102
+ undistorted = cv2.remap(image, *self.maps, cv2.INTER_LINEAR)
103
+ resized = cv2.resize(
104
+ undistorted,
105
+ (self.camera_perspective.width, self.camera_perspective.height),
106
+ interpolation=cv2.INTER_AREA,
107
+ )
108
+ return resized
109
+
110
+
111
+ def render_panorama(
112
+ shot: Shot,
113
+ pano: np.ndarray,
114
+ undistorter: PanoramaUndistorter,
115
+ offset: float = 0.0,
116
+ ) -> Tuple[List[Shot], List[np.ndarray]]:
117
+ yaws = [0, 90, 180, 270]
118
+ suffixes = ["front", "left", "back", "right"]
119
+ images = []
120
+ shots = []
121
+
122
+ # To reduce aliasing, since cv2.remap does not support area sampling,
123
+ # we first resize with anti-aliasing.
124
+ h, w = undistorter.shape
125
+ h, w = (w * 2, w * 4) # assuming 90deg FOV
126
+ pano_resized = cv2.resize(pano, (w, h), interpolation=cv2.INTER_AREA)
127
+
128
+ for yaw, suffix in zip(yaws, suffixes):
129
+ R_pano2persp = Rotation.from_euler("Y", yaw + offset, degrees=True).as_matrix()
130
+ name = f"{shot.id}_{suffix}"
131
+ shot_new = Shot(
132
+ name,
133
+ undistorter.camera_perspective,
134
+ Pose.compose(Pose(R_pano2persp), shot.pose),
135
+ )
136
+ shot_new.metadata = shot.metadata
137
+ perspective = undistorter(pano_resized, shot, shot_new)
138
+ images.append(perspective)
139
+ shots.append(shot_new)
140
+ return shots, images
141
+
142
+
143
+ def undistort_camera(
144
+ shot: Shot, image: np.ndarray, undistorter: CameraUndistorter
145
+ ) -> Tuple[Shot, np.ndarray]:
146
+ name = f"{shot.id}_undistorted"
147
+ shot_out = Shot(name, undistorter.camera_perspective, shot.pose)
148
+ shot_out.metadata = shot.metadata
149
+ undistorted = undistorter(image)
150
+ return shot_out, undistorted
151
+
152
+
153
+ def undistort_shot(
154
+ image_raw: np.ndarray,
155
+ shot_orig: Shot,
156
+ undistorter,
157
+ pano_offset: float,
158
+ ) -> Tuple[List[Shot], List[np.ndarray]]:
159
+ camera = shot_orig.camera
160
+ if image_raw.shape[:2] != (camera.height, camera.width):
161
+ raise ValueError(
162
+ shot_orig.id, image_raw.shape[:2], (camera.height, camera.width)
163
+ )
164
+ if camera.is_panorama(camera.projection_type):
165
+ shots, undistorted = render_panorama(
166
+ shot_orig, image_raw, undistorter, offset=pano_offset
167
+ )
168
+ elif camera.projection_type in ("perspective", "fisheye"):
169
+ shot, undistorted = undistort_camera(shot_orig, image_raw, undistorter)
170
+ shots, undistorted = [shot], [undistorted]
171
+ else:
172
+ raise NotImplementedError(camera.projection_type)
173
+ return shots, undistorted
maploc/data/sequential.py ADDED
@@ -0,0 +1,61 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+
7
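+ # Split a sequence into chunks bounded by length, inter-frame distance, total
+ # distance, and time gap between frames; the chunks are returned longest first.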
+ def chunk_sequence(
8
+ data,
9
+ indices,
10
+ *,
11
+ names=None,
12
+ max_length=100,
13
+ min_length=1,
14
+ max_delay_s=None,
15
+ max_inter_dist=None,
16
+ max_total_dist=None,
17
+ ):
18
+ sort_array = data.get("capture_time", data.get("index", names or indices))
19
+ indices = sorted(indices, key=lambda i: sort_array[i].tolist())
20
+ centers = torch.stack([data["t_c2w"][i][:2] for i in indices]).numpy()
21
+ dists = np.linalg.norm(np.diff(centers, axis=0), axis=-1)
22
+ if "capture_time" in data:
23
+ times = torch.stack([data["capture_time"][i] for i in indices])
24
+ times = times.double() / 1e3 # ms to s
25
+ delays = np.diff(times, axis=0)
26
+ else:
27
+ delays = np.zeros_like(dists)
28
+ chunks = [[indices[0]]]
29
+ dist_total = 0
30
+ for dist, delay, idx in zip(dists, delays, indices[1:]):
31
+ dist_total += dist
32
+ if (
33
+ (max_inter_dist is not None and dist > max_inter_dist)
34
+ or (max_total_dist is not None and dist_total > max_total_dist)
35
+ or (max_delay_s is not None and delay > max_delay_s)
36
+ or len(chunks[-1]) >= max_length
37
+ ):
38
+ chunks.append([])
39
+ dist_total = 0
40
+ chunks[-1].append(idx)
41
+ chunks = list(filter(lambda c: len(c) >= min_length, chunks))
42
+ chunks = sorted(chunks, key=len, reverse=True)
43
+ return chunks
44
+
45
+
46
+ def unpack_batches(batches):
47
+ images = [b["image"].permute(1, 2, 0) for b in batches]
48
+ canvas = [b["canvas"] for b in batches]
49
+ rasters = [b["map"] for b in batches]
50
+ yaws = torch.stack([b["roll_pitch_yaw"][-1] for b in batches])
51
+ uv_gt = torch.stack([b["uv"] for b in batches])
52
+ xy_gt = torch.stack(
53
+ [canv.to_xy(uv.cpu().double()) for uv, canv in zip(uv_gt, canvas)]
54
+ )
55
+ ret = [images, canvas, rasters, yaws, uv_gt, xy_gt.to(uv_gt)]
56
+ if "uv_gps" in batches[0]:
57
+ xy_gps = torch.stack(
58
+ [c.to_xy(b["uv_gps"].cpu().double()) for b, c in zip(batches, canvas)]
59
+ )
60
+ ret.append(xy_gps.to(uv_gt))
61
+ return ret
maploc/data/torch.py ADDED
@@ -0,0 +1,111 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import collections
4
+ import os
5
+
6
+ import torch
7
+ from torch.utils.data import get_worker_info
8
+ from torch.utils.data._utils.collate import (
9
+ default_collate_err_msg_format,
10
+ np_str_obj_array_pattern,
11
+ )
12
+ from lightning_fabric.utilities.seed import pl_worker_init_function
13
+ from lightning_utilities.core.apply_func import apply_to_collection
14
+ from lightning_fabric.utilities.apply_func import move_data_to_device
15
+
16
+
17
+ def collate(batch):
18
+ """Difference with PyTorch default_collate: it can stack other tensor-like objects.
19
+ Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich
20
+ https://github.com/cvg/pixloc
21
+ Released under the Apache License 2.0
22
+ """
23
+ if not isinstance(batch, list): # no batching
24
+ return batch
25
+ elem = batch[0]
26
+ elem_type = type(elem)
27
+ if isinstance(elem, torch.Tensor):
28
+ out = None
29
+ if torch.utils.data.get_worker_info() is not None:
30
+ # If we're in a background process, concatenate directly into a
31
+ # shared memory tensor to avoid an extra copy
32
+ numel = sum(x.numel() for x in batch)
33
+ storage = elem.storage()._new_shared(numel, device=elem.device)
34
+ out = elem.new(storage).resize_(len(batch), *list(elem.size()))
35
+ return torch.stack(batch, 0, out=out)
36
+ elif (
37
+ elem_type.__module__ == "numpy"
38
+ and elem_type.__name__ != "str_"
39
+ and elem_type.__name__ != "string_"
40
+ ):
41
+ if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap":
42
+ # array of string classes and object
43
+ if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
44
+ raise TypeError(default_collate_err_msg_format.format(elem.dtype))
45
+
46
+ return collate([torch.as_tensor(b) for b in batch])
47
+ elif elem.shape == (): # scalars
48
+ return torch.as_tensor(batch)
49
+ elif isinstance(elem, float):
50
+ return torch.tensor(batch, dtype=torch.float64)
51
+ elif isinstance(elem, int):
52
+ return torch.tensor(batch)
53
+ elif isinstance(elem, (str, bytes)):
54
+ return batch
55
+ elif isinstance(elem, collections.abc.Mapping):
56
+ return {key: collate([d[key] for d in batch]) for key in elem}
57
+ elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple
58
+ return elem_type(*(collate(samples) for samples in zip(*batch)))
59
+ elif isinstance(elem, collections.abc.Sequence):
60
+ # check to make sure that the elements in batch have consistent size
61
+ it = iter(batch)
62
+ elem_size = len(next(it))
63
+ if not all(len(elem) == elem_size for elem in it):
64
+ raise RuntimeError("each element in list of batch should be of equal size")
65
+ transposed = zip(*batch)
66
+ return [collate(samples) for samples in transposed]
67
+ else:
68
+ # try to stack anyway in case the object implements stacking.
69
+ try:
70
+ return torch.stack(batch, 0)
71
+ except TypeError as e:
72
+ if "expected Tensor as element" in str(e):
73
+ return batch
74
+ else:
75
+ raise e
76
+
77
+
78
+ def set_num_threads(nt):
79
+ """Force numpy and other libraries to use a limited number of threads."""
80
+ try:
81
+ import mkl
82
+ except ImportError:
83
+ pass
84
+ else:
85
+ mkl.set_num_threads(nt)
86
+ torch.set_num_threads(1)
87
+ os.environ["IPC_ENABLE"] = "1"
88
+ for o in [
89
+ "OPENBLAS_NUM_THREADS",
90
+ "NUMEXPR_NUM_THREADS",
91
+ "OMP_NUM_THREADS",
92
+ "MKL_NUM_THREADS",
93
+ ]:
94
+ os.environ[o] = str(nt)
95
+
96
+
97
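+ # Seed each dataloader worker and optionally limit the number of threads it uses.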
+ def worker_init_fn(i):
98
+ info = get_worker_info()
99
+ pl_worker_init_function(info.id)
100
+ num_threads = info.dataset.cfg.get("num_threads")
101
+ if num_threads is not None:
102
+ set_num_threads(num_threads)
103
+
104
+
105
+ def unbatch_to_device(data, device="cpu"):
106
+ data = move_data_to_device(data, device)
107
+ data = apply_to_collection(data, torch.Tensor, lambda x: x.squeeze(0))
108
+ data = apply_to_collection(
109
+ data, list, lambda x: x[0] if len(x) == 1 and isinstance(x[0], str) else x
110
+ )
111
+ return data
maploc/data/utils.py ADDED
@@ -0,0 +1,60 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import numpy as np
4
+ from scipy.spatial.transform import Rotation
5
+
6
+
7
+ def crop_map(raster, xy, size, seed=None):
8
+ h, w = raster.shape[-2:]
9
+ state = np.random.RandomState(seed)
10
+ top = state.randint(0, h - size + 1)
11
+ left = state.randint(0, w - size + 1)
12
+ raster = raster[..., top : top + size, left : left + size]
13
+ xy -= np.array([left, top])
14
+ return raster, xy
15
+
16
+
17
+ def random_rot90(raster, xy, heading, seed=None):
18
+ rot = np.random.RandomState(seed).randint(0, 4)
19
+ heading = (heading + rot * np.pi / 2) % (2 * np.pi)
20
+ h, w = raster.shape[-2:]
21
+ if rot == 0:
22
+ xy2 = xy
23
+ elif rot == 2:
24
+ xy2 = np.array([w, h]) - 1 - xy
25
+ elif rot == 1:
26
+ xy2 = np.array([xy[1], w - 1 - xy[0]])
27
+ elif rot == 3:
28
+ xy2 = np.array([h - 1 - xy[1], xy[0]])
29
+ else:
30
+ raise ValueError(rot)
31
+ raster = np.rot90(raster, rot, axes=(-2, -1))
32
+ return raster, xy2, heading
33
+
34
+
35
+ def random_flip(image, raster, xy, heading, seed=None):
36
+ state = np.random.RandomState(seed)
37
+ if state.rand() > 0.5: # no flip
38
+ return image, raster, xy, heading
39
+ image = image[:, ::-1]
40
+ h, w = raster.shape[-2:]
41
+ if state.rand() > 0.5: # flip x
42
+ raster = raster[..., :, ::-1]
43
+ xy = np.array([w - 1 - xy[0], xy[1]])
44
+ heading = np.pi - heading
45
+ else: # flip y
46
+ raster = raster[..., ::-1, :]
47
+ xy = np.array([xy[0], h - 1 - xy[1]])
48
+ heading = -heading
49
+ heading = heading % (2 * np.pi)
50
+ return image, raster, xy, heading
51
+
52
+
53
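+ # Decompose a camera-to-world rotation matrix into roll, pitch, yaw angles in degrees.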
+ def decompose_rotmat(R_c2w):
54
+ R_cv2xyz = Rotation.from_euler("X", -90, degrees=True)
55
+ rot_w2c = R_cv2xyz * Rotation.from_matrix(R_c2w).inv()
56
+ roll, pitch, yaw = rot_w2c.as_euler("YXZ", degrees=True)
57
+ # rot_w2c_check = R_cv2xyz.inv() * Rotation.from_euler('YXZ', [roll, pitch, yaw], degrees=True)
58
+ # np.testing.assert_allclose(rot_w2c_check.as_matrix(), R_c2w.T, rtol=1e-6, atol=1e-6)
59
+ # R_plane2c = Rotation.from_euler("ZX", [roll, pitch], degrees=True).as_matrix()
60
+ return roll, pitch, yaw
maploc/demo.py ADDED
@@ -0,0 +1,209 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from typing import Optional, Tuple
4
+
5
+ import torch
6
+ import numpy as np
7
+
8
+ from . import logger
9
+ from .evaluation.run import resolve_checkpoint_path, pretrained_models
10
+ from .models.orienternet import OrienterNet
11
+ from .models.voting import fuse_gps, argmax_xyr
12
+ from .data.image import resize_image, pad_image, rectify_image
13
+ from .osm.raster import Canvas
14
+ from .utils.wrappers import Camera
15
+ from .utils.io import read_image
16
+ from .utils.geo import BoundaryBox, Projection
17
+ from .utils.exif import EXIF
18
+
19
+ try:
20
+ from geopy.geocoders import Nominatim
21
+
22
+ geolocator = Nominatim(user_agent="orienternet")
23
+ except ImportError:
24
+ geolocator = None
25
+
26
+ try:
27
+ from gradio_client import Client
28
+
29
+ calibrator = Client("https://jinlinyi-perspectivefields.hf.space/")
30
+ except (ImportError, ValueError):
31
+ calibrator = None
32
+
33
+
34
+ def image_calibration(image_path):
35
+ logger.info("Calling the PerspectiveFields calibrator, this may take some time.")
36
+ result = calibrator.predict(
37
+ image_path, "NEW:Paramnet-360Cities-edina-centered", api_name="/predict"
38
+ )
39
+ result = dict(r.rsplit(" ", 1) for r in result[1].split("\n"))
40
+ roll_pitch = float(result["roll"]), float(result["pitch"])
41
+ return roll_pitch, float(result["vertical fov"])
42
+
43
+
44
+ def camera_from_exif(exif: EXIF, fov: Optional[float] = None) -> Camera:
45
+ w, h = image_size = exif.extract_image_size()
46
+ _, f_ratio = exif.extract_focal()
47
+ if f_ratio == 0:
48
+ if fov is not None:
49
+ # This is the vertical FoV.
50
+ f = h / 2 / np.tan(np.deg2rad(fov) / 2)
51
+ else:
52
+ return None
53
+ else:
54
+ f = f_ratio * max(image_size)
55
+ return Camera.from_dict(
56
+ dict(
57
+ model="SIMPLE_PINHOLE",
58
+ width=w,
59
+ height=h,
60
+ params=[f, w / 2 + 0.5, h / 2 + 0.5],
61
+ )
62
+ )
63
+
64
+
65
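+ # Load the query image, estimate roll/pitch and FoV with PerspectiveFields if it is
+ # available, build the camera from the EXIF or the FoV guess, and resolve a location
+ # prior from the arguments, a geocoded address, or the EXIF GPS tag.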
+ def read_input_image(
66
+ image_path: str,
67
+ prior_latlon: Optional[Tuple[float, float]] = None,
68
+ prior_address: Optional[str] = None,
69
+ fov: Optional[float] = None,
70
+ tile_size_meters: int = 64,
71
+ ):
72
+ image = read_image(image_path)
73
+
74
+ roll_pitch = None
75
+ if calibrator is not None:
76
+ roll_pitch, fov = image_calibration(image_path)
77
+ else:
78
+ logger.info("Could not call PerspectiveFields, maybe install gradio_client?")
79
+ if roll_pitch is not None:
80
+ logger.info("Using (roll, pitch) %s.", roll_pitch)
81
+
82
+ with open(image_path, "rb") as fid:
83
+ exif = EXIF(fid, lambda: image.shape[:2])
84
+ camera = camera_from_exif(exif, fov)
85
+ if camera is None:
86
+ raise ValueError(
87
+ "No camera intrinsics found in the EXIF, provide an FoV guess."
88
+ )
89
+
90
+ latlon = None
91
+ if prior_latlon is not None:
92
+ latlon = prior_latlon
93
+ logger.info("Using prior latlon %s.", prior_latlon)
94
+ if prior_address is not None:
95
+ if geolocator is None:
96
+ raise ValueError("geocoding unavailable, install geopy.")
97
+ location = geolocator.geocode(prior_address)
98
+ if location is None:
99
+ logger.info("Could not find any location for %s.", prior_address)
100
+ else:
101
+ logger.info("Using prior address: %s", location.address)
102
+ latlon = (location.latitude, location.longitude)
103
+ if latlon is None:
104
+ geo = exif.extract_geo()
105
+ if geo:
106
+ alt = geo.get("altitude", 0) # read if available
107
+ latlon = (geo["latitude"], geo["longitude"], alt)
108
+ logger.info("Using prior location from EXIF.")
109
+ else:
110
+ logger.info("Could not find any prior location in EXIF.")
111
+ if latlon is None:
112
+ raise ValueError("Need prior latlon")
113
+ latlon = np.array(latlon)
114
+
115
+ proj = Projection(*latlon)
116
+ center = proj.project(latlon)
117
+ bbox = BoundaryBox(center, center) + tile_size_meters
118
+ return image, camera, roll_pitch, proj, bbox, latlon
119
+
120
+
121
+ class Demo:
122
+ def __init__(
123
+ self,
124
+ experiment_or_path: Optional[str] = "OrienterNet_MGL",
125
+ device=None,
126
+ **kwargs
127
+ ):
128
+ if experiment_or_path in pretrained_models:
129
+ experiment_or_path, _ = pretrained_models[experiment_or_path]
130
+ path = resolve_checkpoint_path(experiment_or_path)
131
+ ckpt = torch.load(path, map_location=(lambda storage, loc: storage))
132
+ config = ckpt["hyper_parameters"]
133
+ config.model.update(kwargs)
134
+ config.model.image_encoder.backbone.pretrained = False
135
+
136
+ model = OrienterNet(config.model).eval()
137
+ state = {k[len("model.") :]: v for k, v in ckpt["state_dict"].items()}
138
+ model.load_state_dict(state, strict=True)
139
+ if device is None:
140
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
141
+ model = model.to(device)
142
+
143
+ self.model = model
144
+ self.config = config
145
+ self.device = device
146
+
147
+ def prepare_data(
148
+ self,
149
+ image: np.ndarray,
150
+ camera: Camera,
151
+ canvas: Canvas,
152
+ roll_pitch: Optional[Tuple[float]] = None,
153
+ ):
154
+ assert image.shape[:2][::-1] == tuple(camera.size.tolist())
155
+ target_focal_length = self.config.data.resize_image / 2
156
+ factor = target_focal_length / camera.f
157
+ size = (camera.size * factor).round().int()
158
+
159
+ image = torch.from_numpy(image).permute(2, 0, 1).float().div_(255)
160
+ valid = None
161
+ if roll_pitch is not None:
162
+ roll, pitch = roll_pitch
163
+ image, valid = rectify_image(
164
+ image,
165
+ camera.float(),
166
+ roll=-roll,
167
+ pitch=-pitch,
168
+ )
169
+ image, _, camera, *maybe_valid = resize_image(
170
+ image, size.numpy(), camera=camera, valid=valid
171
+ )
172
+ valid = None if valid is None else maybe_valid
173
+
174
+ max_stride = max(self.model.image_encoder.layer_strides)
175
+ size = (np.ceil((size / max_stride)) * max_stride).int()
176
+ image, valid, camera = pad_image(
177
+ image, size.numpy(), camera, crop_and_center=True
178
+ )
179
+
180
+ return dict(
181
+ image=image,
182
+ map=torch.from_numpy(canvas.raster).long(),
183
+ camera=camera.float(),
184
+ valid=valid,
185
+ )
186
+
187
+ def localize(self, image: np.ndarray, camera: Camera, canvas: Canvas, **kwargs):
188
+ data = self.prepare_data(image, camera, canvas, **kwargs)
189
+ data_ = {k: v.to(self.device)[None] for k, v in data.items()}
190
+ with torch.no_grad():
191
+ pred = self.model(data_)
192
+
193
+ xy_gps = canvas.bbox.center
194
+ uv_gps = torch.from_numpy(canvas.to_uv(xy_gps))
195
+
196
+ lp_xyr = pred["log_probs"].squeeze(0)
197
+ tile_size = canvas.bbox.size.min() / 2
198
+ sigma = tile_size - 20 # 20 meters margin
199
+ lp_xyr = fuse_gps(
200
+ lp_xyr,
201
+ uv_gps.to(lp_xyr),
202
+ self.config.model.pixel_per_meter,
203
+ sigma=sigma,
204
+ )
205
+ xyr = argmax_xyr(lp_xyr).cpu()
206
+
207
+ prob = lp_xyr.exp().cpu()
208
+ neural_map = pred["map"]["map_features"][0].squeeze(0).cpu()
209
+ return xyr[:2], xyr[2], prob, neural_map, data["image"]
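A minimal usage sketch of the demo pipeline above (the file name "query.jpg" is illustrative, and the image is assumed to carry EXIF intrinsics and GPS; building the OSM canvas for the returned bounding box is done with the repo's maploc.osm utilities and is elided here):

from maploc.demo import Demo, read_input_image

demo = Demo(experiment_or_path="OrienterNet_MGL")  # downloads the checkpoint on first use
image, camera, roll_pitch, proj, bbox, latlon = read_input_image(
    "query.jpg", tile_size_meters=64
)
# canvas = ...  # rasterize the OSM tile covering `bbox` (see maploc.osm.raster.Canvas)
# xy, yaw, prob, neural_map, image_rect = demo.localize(
#     image, camera, canvas, roll_pitch=roll_pitch
# )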
maploc/evaluation/kitti.py ADDED
@@ -0,0 +1,89 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import argparse
4
+ from pathlib import Path
5
+ from typing import Optional, Tuple
6
+
7
+ from omegaconf import OmegaConf, DictConfig
8
+
9
+ from .. import logger
10
+ from ..data import KittiDataModule
11
+ from .run import evaluate
12
+
13
+
14
+ default_cfg_single = OmegaConf.create({})
15
+ # For the sequential evaluation, we need to center the map around the GT location,
16
+ # since random offsets would accumulate and leave only the GT location with a valid mask.
17
+ # This should not have much impact on the results.
18
+ default_cfg_sequential = OmegaConf.create(
19
+ {
20
+ "data": {
21
+ "mask_radius": KittiDataModule.default_cfg["max_init_error"],
22
+ "prior_range_rotation": KittiDataModule.default_cfg[
23
+ "max_init_error_rotation"
24
+ ]
25
+ + 1,
26
+ "max_init_error": 0,
27
+ "max_init_error_rotation": 0,
28
+ },
29
+ "chunking": {
30
+ "max_length": 100, # about 10s?
31
+ },
32
+ }
33
+ )
34
+
35
+
36
+ def run(
37
+ split: str,
38
+ experiment: str,
39
+ cfg: Optional[DictConfig] = None,
40
+ sequential: bool = False,
41
+ thresholds: Tuple[int] = (1, 3, 5),
42
+ **kwargs,
43
+ ):
44
+ cfg = cfg or {}
45
+ if isinstance(cfg, dict):
46
+ cfg = OmegaConf.create(cfg)
47
+ default = default_cfg_sequential if sequential else default_cfg_single
48
+ cfg = OmegaConf.merge(default, cfg)
49
+ dataset = KittiDataModule(cfg.get("data", {}))
50
+
51
+ metrics = evaluate(
52
+ experiment,
53
+ cfg,
54
+ dataset,
55
+ split=split,
56
+ sequential=sequential,
57
+ viz_kwargs=dict(show_dir_error=True, show_masked_prob=False),
58
+ **kwargs,
59
+ )
60
+
61
+ keys = ["directional_error", "yaw_max_error"]
62
+ if sequential:
63
+ keys += ["directional_seq_error", "yaw_seq_error"]
64
+ for k in keys:
65
+ rec = metrics[k].recall(thresholds).double().numpy().round(2).tolist()
66
+ logger.info("Recall %s: %s at %s m/°", k, rec, thresholds)
67
+ return metrics
68
+
69
+
70
+ if __name__ == "__main__":
71
+ parser = argparse.ArgumentParser()
72
+ parser.add_argument("--experiment", type=str, required=True)
73
+ parser.add_argument(
74
+ "--split", type=str, default="test", choices=["test", "val", "train"]
75
+ )
76
+ parser.add_argument("--sequential", action="store_true")
77
+ parser.add_argument("--output_dir", type=Path)
78
+ parser.add_argument("--num", type=int)
79
+ parser.add_argument("dotlist", nargs="*")
80
+ args = parser.parse_args()
81
+ cfg = OmegaConf.from_cli(args.dotlist)
82
+ run(
83
+ args.split,
84
+ args.experiment,
85
+ cfg,
86
+ args.sequential,
87
+ output_dir=args.output_dir,
88
+ num=args.num,
89
+ )
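For reference, the same evaluation can also be launched programmatically; a small sketch, assuming the KITTI data has been prepared as described in the repository and that "OrienterNet_MGL" resolves to the released checkpoint via resolve_checkpoint_path:

from maploc.evaluation.kitti import run

metrics = run(split="test", experiment="OrienterNet_MGL", sequential=False)
# recall at 1/3/5 m (and degrees) is logged for the directional and yaw errors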
maploc/evaluation/mapillary.py ADDED
@@ -0,0 +1,111 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import argparse
4
+ from pathlib import Path
5
+ from typing import Optional, Tuple
6
+
7
+ from omegaconf import OmegaConf, DictConfig
8
+
9
+ from .. import logger
10
+ from ..conf import data as conf_data_dir
11
+ from ..data import MapillaryDataModule
12
+ from .run import evaluate
13
+
14
+
15
+ split_overrides = {
16
+ "val": {
17
+ "scenes": [
18
+ "sanfrancisco_soma",
19
+ "sanfrancisco_hayes",
20
+ "amsterdam",
21
+ "berlin",
22
+ "lemans",
23
+ "montrouge",
24
+ "toulouse",
25
+ "nantes",
26
+ "vilnius",
27
+ "avignon",
28
+ "helsinki",
29
+ "milan",
30
+ "paris",
31
+ ],
32
+ },
33
+ }
34
+ data_cfg_train = OmegaConf.load(Path(conf_data_dir.__file__).parent / "mapillary.yaml")
35
+ data_cfg = OmegaConf.merge(
36
+ data_cfg_train,
37
+ {
38
+ "return_gps": True,
39
+ "add_map_mask": True,
40
+ "max_init_error": 32,
41
+ "loading": {"val": {"batch_size": 1, "num_workers": 0}},
42
+ },
43
+ )
44
+ default_cfg_single = OmegaConf.create({"data": data_cfg})
45
+ default_cfg_sequential = OmegaConf.create(
46
+ {
47
+ **default_cfg_single,
48
+ "chunking": {
49
+ "max_length": 10,
50
+ },
51
+ }
52
+ )
53
+
54
+
55
+ def run(
56
+ split: str,
57
+ experiment: str,
58
+ cfg: Optional[DictConfig] = None,
59
+ sequential: bool = False,
60
+ thresholds: Tuple[int] = (1, 3, 5),
61
+ **kwargs,
62
+ ):
63
+ cfg = cfg or {}
64
+ if isinstance(cfg, dict):
65
+ cfg = OmegaConf.create(cfg)
66
+ default = default_cfg_sequential if sequential else default_cfg_single
67
+ default = OmegaConf.merge(default, split_overrides[split])
68
+ cfg = OmegaConf.merge(default, cfg)
69
+ dataset = MapillaryDataModule(cfg.get("data", {}))
70
+
71
+ metrics = evaluate(experiment, cfg, dataset, split, sequential=sequential, **kwargs)
72
+
73
+ keys = [
74
+ "xy_max_error",
75
+ "xy_gps_error",
76
+ "yaw_max_error",
77
+ ]
78
+ if sequential:
79
+ keys += [
80
+ "xy_seq_error",
81
+ "xy_gps_seq_error",
82
+ "yaw_seq_error",
83
+ "yaw_gps_seq_error",
84
+ ]
85
+ for k in keys:
86
+ if k not in metrics:
87
+ logger.warning("Key %s not in metrics.", k)
88
+ continue
89
+ rec = metrics[k].recall(thresholds).double().numpy().round(2).tolist()
90
+ logger.info("Recall %s: %s at %s m/°", k, rec, thresholds)
91
+ return metrics
92
+
93
+
94
+ if __name__ == "__main__":
95
+ parser = argparse.ArgumentParser()
96
+ parser.add_argument("--experiment", type=str, required=True)
97
+ parser.add_argument("--split", type=str, default="val", choices=["val"])
98
+ parser.add_argument("--sequential", action="store_true")
99
+ parser.add_argument("--output_dir", type=Path)
100
+ parser.add_argument("--num", type=int)
101
+ parser.add_argument("dotlist", nargs="*")
102
+ args = parser.parse_args()
103
+ cfg = OmegaConf.from_cli(args.dotlist)
104
+ run(
105
+ args.split,
106
+ args.experiment,
107
+ cfg,
108
+ args.sequential,
109
+ output_dir=args.output_dir,
110
+ num=args.num,
111
+ )
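Configuration overrides can be passed programmatically instead of through the CLI dotlist; a sketch with illustrative values, assuming the MGL dataset is available locally:

from omegaconf import OmegaConf
from maploc.evaluation.mapillary import run

cfg = OmegaConf.create({"data": {"loading": {"val": {"num_workers": 4}}}})
metrics = run("val", "OrienterNet_MGL", cfg=cfg, sequential=False)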
maploc/evaluation/run.py ADDED
@@ -0,0 +1,252 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import functools
4
+ from itertools import islice
5
+ from typing import Callable, Dict, Optional, Tuple
6
+ from pathlib import Path
7
+
8
+ import numpy as np
9
+ import torch
10
+ from omegaconf import DictConfig, OmegaConf
11
+ from torchmetrics import MetricCollection
12
+ from pytorch_lightning import seed_everything
13
+ from tqdm import tqdm
14
+
15
+ from .. import logger, EXPERIMENTS_PATH
16
+ from ..data.torch import collate, unbatch_to_device
17
+ from ..models.voting import argmax_xyr, fuse_gps
18
+ from ..models.metrics import AngleError, LateralLongitudinalError, Location2DError
19
+ from ..models.sequential import GPSAligner, RigidAligner
20
+ from ..module import GenericModule
21
+ from ..utils.io import download_file, DATA_URL
22
+ from .viz import plot_example_single, plot_example_sequential
23
+ from .utils import write_dump
24
+
25
+
26
+ pretrained_models = dict(
27
+ OrienterNet_MGL=("orienternet_mgl.ckpt", dict(num_rotations=256)),
28
+ )
29
+
30
+
31
+ def resolve_checkpoint_path(experiment_or_path: str) -> Path:
32
+ path = Path(experiment_or_path)
33
+ if not path.exists():
34
+ # provided name of experiment
35
+ path = Path(EXPERIMENTS_PATH, *experiment_or_path.split("/"))
36
+ if not path.exists():
37
+ if experiment_or_path in set(p for p, _ in pretrained_models.values()):
38
+ download_file(f"{DATA_URL}/{experiment_or_path}", path)
39
+ else:
40
+ raise FileNotFoundError(path)
41
+ if path.is_file():
42
+ return path
43
+ # provided only the experiment name
44
+ maybe_path = path / "last-step.ckpt"
45
+ if not maybe_path.exists():
46
+ maybe_path = path / "step.ckpt"
47
+ if not maybe_path.exists():
48
+ raise FileNotFoundError(f"Could not find any checkpoint in {path}.")
49
+ return maybe_path
50
+
51
+
52
+ @torch.no_grad()
53
+ def evaluate_single_image(
54
+ dataloader: torch.utils.data.DataLoader,
55
+ model: GenericModule,
56
+ num: Optional[int] = None,
57
+ callback: Optional[Callable] = None,
58
+ progress: bool = True,
59
+ mask_index: Optional[Tuple[int]] = None,
60
+ has_gps: bool = False,
61
+ ):
62
+ ppm = model.model.conf.pixel_per_meter
63
+ metrics = MetricCollection(model.model.metrics())
64
+ metrics["directional_error"] = LateralLongitudinalError(ppm)
65
+ if has_gps:
66
+ metrics["xy_gps_error"] = Location2DError("uv_gps", ppm)
67
+ metrics["xy_fused_error"] = Location2DError("uv_fused", ppm)
68
+ metrics["yaw_fused_error"] = AngleError("yaw_fused")
69
+ metrics = metrics.to(model.device)
70
+
71
+ for i, batch_ in enumerate(
72
+ islice(tqdm(dataloader, total=num, disable=not progress), num)
73
+ ):
74
+ batch = model.transfer_batch_to_device(batch_, model.device, i)
75
+ # Ablation: mask semantic classes
76
+ if mask_index is not None:
77
+ mask = batch["map"][0, mask_index[0]] == (mask_index[1] + 1)
78
+ batch["map"][0, mask_index[0]][mask] = 0
79
+ pred = model(batch)
80
+
81
+ if has_gps:
82
+ (uv_gps,) = pred["uv_gps"] = batch["uv_gps"]
83
+ pred["log_probs_fused"] = fuse_gps(
84
+ pred["log_probs"], uv_gps, ppm, sigma=batch["accuracy_gps"]
85
+ )
86
+ uvt_fused = argmax_xyr(pred["log_probs_fused"])
87
+ pred["uv_fused"] = uvt_fused[..., :2]
88
+ pred["yaw_fused"] = uvt_fused[..., -1]
89
+ del uv_gps, uvt_fused
90
+
91
+ results = metrics(pred, batch)
92
+ if callback is not None:
93
+ callback(
94
+ i, model, unbatch_to_device(pred), unbatch_to_device(batch_), results
95
+ )
96
+ del batch_, batch, pred, results
97
+
98
+ return metrics.cpu()
99
+
100
+
101
+ @torch.no_grad()
102
+ def evaluate_sequential(
103
+ dataset: torch.utils.data.Dataset,
104
+ chunk2idx: Dict,
105
+ model: GenericModule,
106
+ num: Optional[int] = None,
107
+ shuffle: bool = False,
108
+ callback: Optional[Callable] = None,
109
+ progress: bool = True,
110
+ num_rotations: int = 512,
111
+ mask_index: Optional[Tuple[int]] = None,
112
+ has_gps: bool = True,
113
+ ):
114
+ chunk_keys = list(chunk2idx)
115
+ if shuffle:
116
+ chunk_keys = [chunk_keys[i] for i in torch.randperm(len(chunk_keys))]
117
+ if num is not None:
118
+ chunk_keys = chunk_keys[:num]
119
+ lengths = [len(chunk2idx[k]) for k in chunk_keys]
120
+ logger.info(
121
+ "Min/max/med lengths: %d/%d/%d, total number of images: %d",
122
+ min(lengths),
123
+ np.median(lengths),
124
+ max(lengths),
125
+ sum(lengths),
126
+ )
127
+ viz = callback is not None
128
+
129
+ metrics = MetricCollection(model.model.metrics())
130
+ ppm = model.model.conf.pixel_per_meter
131
+ metrics["directional_error"] = LateralLongitudinalError(ppm)
132
+ metrics["xy_seq_error"] = Location2DError("uv_seq", ppm)
133
+ metrics["yaw_seq_error"] = AngleError("yaw_seq")
134
+ metrics["directional_seq_error"] = LateralLongitudinalError(ppm, key="uv_seq")
135
+ if has_gps:
136
+ metrics["xy_gps_error"] = Location2DError("uv_gps", ppm)
137
+ metrics["xy_gps_seq_error"] = Location2DError("uv_gps_seq", ppm)
138
+ metrics["yaw_gps_seq_error"] = AngleError("yaw_gps_seq")
139
+ metrics = metrics.to(model.device)
140
+
141
+ keys_save = ["uvr_max", "uv_max", "yaw_max", "uv_expectation"]
142
+ if has_gps:
143
+ keys_save.append("uv_gps")
144
+ if viz:
145
+ keys_save.append("log_probs")
146
+
147
+ for chunk_index, key in enumerate(tqdm(chunk_keys, disable=not progress)):
148
+ indices = chunk2idx[key]
149
+ aligner = RigidAligner(track_priors=viz, num_rotations=num_rotations)
150
+ if has_gps:
151
+ aligner_gps = GPSAligner(track_priors=viz, num_rotations=num_rotations)
152
+ batches = []
153
+ preds = []
154
+ for i in indices:
155
+ data = dataset[i]
156
+ data = model.transfer_batch_to_device(data, model.device, 0)
157
+ pred = model(collate([data]))
158
+
159
+ canvas = data["canvas"]
160
+ data["xy_geo"] = xy = canvas.to_xy(data["uv"].double())
161
+ data["yaw"] = yaw = data["roll_pitch_yaw"][-1].double()
162
+ aligner.update(pred["log_probs"][0], canvas, xy, yaw)
163
+
164
+ if has_gps:
165
+ (uv_gps) = pred["uv_gps"] = data["uv_gps"][None]
166
+ xy_gps = canvas.to_xy(uv_gps.double())
167
+ aligner_gps.update(xy_gps, data["accuracy_gps"], canvas, xy, yaw)
168
+
169
+ if not viz:
170
+ data.pop("image")
171
+ data.pop("map")
172
+ batches.append(data)
173
+ preds.append({k: pred[k][0] for k in keys_save})
174
+ del pred
175
+
176
+ xy_gt = torch.stack([b["xy_geo"] for b in batches])
177
+ yaw_gt = torch.stack([b["yaw"] for b in batches])
178
+ aligner.compute()
179
+ xy_seq, yaw_seq = aligner.transform(xy_gt, yaw_gt)
180
+ if has_gps:
181
+ aligner_gps.compute()
182
+ xy_gps_seq, yaw_gps_seq = aligner_gps.transform(xy_gt, yaw_gt)
183
+ results = []
184
+ for i in range(len(indices)):
185
+ preds[i]["uv_seq"] = batches[i]["canvas"].to_uv(xy_seq[i]).float()
186
+ preds[i]["yaw_seq"] = yaw_seq[i].float()
187
+ if has_gps:
188
+ preds[i]["uv_gps_seq"] = (
189
+ batches[i]["canvas"].to_uv(xy_gps_seq[i]).float()
190
+ )
191
+ preds[i]["yaw_gps_seq"] = yaw_gps_seq[i].float()
192
+ results.append(metrics(preds[i], batches[i]))
193
+ if viz:
194
+ callback(chunk_index, model, batches, preds, results, aligner)
195
+ del aligner, preds, batches, results
196
+ return metrics.cpu()
197
+
198
+
199
+ def evaluate(
200
+ experiment: str,
201
+ cfg: DictConfig,
202
+ dataset,
203
+ split: str,
204
+ sequential: bool = False,
205
+ output_dir: Optional[Path] = None,
206
+ callback: Optional[Callable] = None,
207
+ num_workers: int = 1,
208
+ viz_kwargs=None,
209
+ **kwargs,
210
+ ):
211
+ if experiment in pretrained_models:
212
+ experiment, cfg_override = pretrained_models[experiment]
213
+ cfg = OmegaConf.merge(OmegaConf.create(dict(model=cfg_override)), cfg)
214
+
215
+ logger.info("Evaluating model %s with config %s", experiment, cfg)
216
+ checkpoint_path = resolve_checkpoint_path(experiment)
217
+ model = GenericModule.load_from_checkpoint(
218
+ checkpoint_path, cfg=cfg, find_best=not experiment.endswith(".ckpt")
219
+ )
220
+ model = model.eval()
221
+ if torch.cuda.is_available():
222
+ model = model.cuda()
223
+
224
+ dataset.prepare_data()
225
+ dataset.setup()
226
+
227
+ if output_dir is not None:
228
+ output_dir.mkdir(exist_ok=True, parents=True)
229
+ if callback is None:
230
+ if sequential:
231
+ callback = plot_example_sequential
232
+ else:
233
+ callback = plot_example_single
234
+ callback = functools.partial(
235
+ callback, out_dir=output_dir, **(viz_kwargs or {})
236
+ )
237
+ kwargs = {**kwargs, "callback": callback}
238
+
239
+ seed_everything(dataset.cfg.seed)
240
+ if sequential:
241
+ dset, chunk2idx = dataset.sequence_dataset(split, **cfg.chunking)
242
+ metrics = evaluate_sequential(dset, chunk2idx, model, **kwargs)
243
+ else:
244
+ loader = dataset.dataloader(split, shuffle=True, num_workers=num_workers)
245
+ metrics = evaluate_single_image(loader, model, **kwargs)
246
+
247
+ results = metrics.compute()
248
+ logger.info("All results: %s", results)
249
+ if output_dir is not None:
250
+ write_dump(output_dir, experiment, cfg, results, metrics)
251
+ logger.info("Outputs have been written to %s.", output_dir)
252
+ return metrics
maploc/evaluation/utils.py ADDED
@@ -0,0 +1,40 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import numpy as np
4
+ from omegaconf import OmegaConf
5
+
6
+ from ..utils.io import write_json
7
+
8
+
9
+ def compute_recall(errors):
10
+ num_elements = len(errors)
11
+ sort_idx = np.argsort(errors)
12
+ errors = np.array(errors.copy())[sort_idx]
13
+ recall = (np.arange(num_elements) + 1) / num_elements
14
+ recall = np.r_[0, recall]
15
+ errors = np.r_[0, errors]
16
+ return errors, recall
17
+
18
+
19
+ def compute_auc(errors, recall, thresholds):
20
+ aucs = []
21
+ for t in thresholds:
22
+ last_index = np.searchsorted(errors, t, side="right")
23
+ r = np.r_[recall[:last_index], recall[last_index - 1]]
24
+ e = np.r_[errors[:last_index], t]
25
+ auc = np.trapz(r, x=e) / t
26
+ aucs.append(auc * 100)
27
+ return aucs
28
+
29
+
30
+ def write_dump(output_dir, experiment, cfg, results, metrics):
31
+ dump = {
32
+ "experiment": experiment,
33
+ "cfg": OmegaConf.to_container(cfg),
34
+ "results": results,
35
+ "errors": {},
36
+ }
37
+ for k, m in metrics.items():
38
+ if hasattr(m, "get_errors"):
39
+ dump["errors"][k] = m.get_errors().numpy()
40
+ write_json(output_dir / "log.json", dump)
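A small worked example of the two helpers above (the error values are made up): compute_recall turns a list of errors into a cumulative recall curve, and compute_auc integrates that curve up to each threshold and normalizes it to a percentage.

from maploc.evaluation.utils import compute_auc, compute_recall

errors = [0.4, 0.9, 2.5, 4.0, 10.0]  # localization errors in meters (illustrative)
errs, recall = compute_recall(errors)
aucs = compute_auc(errs, recall, thresholds=[1.0, 3.0, 5.0])
# each entry is the area under the cumulative recall curve up to the
# corresponding threshold, expressed in percent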
maploc/evaluation/viz.py ADDED
@@ -0,0 +1,178 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import numpy as np
4
+ import torch
5
+ import matplotlib.pyplot as plt
6
+
7
+ from ..utils.io import write_torch_image
8
+ from ..utils.viz_2d import plot_images, features_to_RGB, save_plot
9
+ from ..utils.viz_localization import (
10
+ likelihood_overlay,
11
+ plot_pose,
12
+ plot_dense_rotations,
13
+ add_circle_inset,
14
+ )
15
+ from ..osm.viz import Colormap, plot_nodes
16
+
17
+
18
+ def plot_example_single(
19
+ idx,
20
+ model,
21
+ pred,
22
+ data,
23
+ results,
24
+ plot_bev=True,
25
+ out_dir=None,
26
+ fig_for_paper=False,
27
+ show_gps=False,
28
+ show_fused=False,
29
+ show_dir_error=False,
30
+ show_masked_prob=False,
31
+ ):
32
+ scene, name, rasters, uv_gt = (data[k] for k in ("scene", "name", "map", "uv"))
33
+ uv_gps = data.get("uv_gps")
34
+ yaw_gt = data["roll_pitch_yaw"][-1].numpy()
35
+ image = data["image"].permute(1, 2, 0)
36
+ if "valid" in data:
37
+ image = image.masked_fill(~data["valid"].unsqueeze(-1), 0.3)
38
+
39
+ lp_uvt = lp_uv = pred["log_probs"]
40
+ if show_fused and "log_probs_fused" in pred:
41
+ lp_uvt = lp_uv = pred["log_probs_fused"]
42
+ elif not show_masked_prob and "scores_unmasked" in pred:
43
+ lp_uvt = lp_uv = pred["scores_unmasked"]
44
+ has_rotation = lp_uvt.ndim == 3
45
+ if has_rotation:
46
+ lp_uv = lp_uvt.max(-1).values
47
+ if lp_uv.min() > -np.inf:
48
+ lp_uv = lp_uv.clip(min=np.percentile(lp_uv, 1))
49
+ prob = lp_uv.exp()
50
+ uv_p, yaw_p = pred["uv_max"], pred.get("yaw_max")
51
+ if show_fused and "uv_fused" in pred:
52
+ uv_p, yaw_p = pred["uv_fused"], pred.get("yaw_fused")
53
+ feats_map = pred["map"]["map_features"][0]
54
+ (feats_map_rgb,) = features_to_RGB(feats_map.numpy())
55
+
56
+ text1 = rf'$\Delta xy$: {results["xy_max_error"]:.1f}m'
57
+ if has_rotation:
58
+ text1 += rf', $\Delta\theta$: {results["yaw_max_error"]:.1f}°'
59
+ if show_fused and "xy_fused_error" in results:
60
+ text1 += rf', $\Delta xy_{{fused}}$: {results["xy_fused_error"]:.1f}m'
61
+ text1 += rf', $\Delta\theta_{{fused}}$: {results["yaw_fused_error"]:.1f}°'
62
+ if show_dir_error and "directional_error" in results:
63
+ err_lat, err_lon = results["directional_error"]
64
+ text1 += rf", $\Delta$lateral/longitundinal={err_lat:.1f}m/{err_lon:.1f}m"
65
+ if "xy_gps_error" in results:
66
+ text1 += rf', $\Delta xy_{{GPS}}$: {results["xy_gps_error"]:.1f}m'
67
+
68
+ map_viz = Colormap.apply(rasters)
69
+ overlay = likelihood_overlay(prob.numpy(), map_viz.mean(-1, keepdims=True))
70
+ plot_images(
71
+ [image, map_viz, overlay, feats_map_rgb],
72
+ titles=[text1, "map", "likelihood", "neural map"],
73
+ dpi=75,
74
+ cmaps="jet",
75
+ )
76
+ fig = plt.gcf()
77
+ axes = fig.axes
78
+ axes[1].images[0].set_interpolation("none")
79
+ axes[2].images[0].set_interpolation("none")
80
+ Colormap.add_colorbar()
81
+ plot_nodes(1, rasters[2])
82
+
83
+ if show_gps and uv_gps is not None:
84
+ plot_pose([1], uv_gps, c="blue")
85
+ plot_pose([1], uv_gt, yaw_gt, c="red")
86
+ plot_pose([1], uv_p, yaw_p, c="k")
87
+ plot_dense_rotations(2, lp_uvt.exp())
88
+ inset_center = pred["uv_max"] if results["xy_max_error"] < 5 else uv_gt
89
+ axins = add_circle_inset(axes[2], inset_center)
90
+ axins.scatter(*uv_gt, lw=1, c="red", ec="k", s=50, zorder=15)
91
+ axes[0].text(
92
+ 0.003,
93
+ 0.003,
94
+ f"{scene}/{name}",
95
+ transform=axes[0].transAxes,
96
+ fontsize=3,
97
+ va="bottom",
98
+ ha="left",
99
+ color="w",
100
+ )
101
+ plt.show()
102
+ if out_dir is not None:
103
+ name_ = name.replace("/", "_")
104
+ p = str(out_dir / f"{scene}_{name_}_{{}}.pdf")
105
+ save_plot(p.format("pred"))
106
+ plt.close()
107
+
108
+ if fig_for_paper:
109
+ # !cp ../datasets/MGL/{scene}/images/{name}.jpg {out_dir}/{scene}_{name}.jpg
110
+ plot_images([map_viz])
111
+ plt.gca().images[0].set_interpolation("none")
112
+ plot_nodes(0, rasters[2])
113
+ plot_pose([0], uv_gt, yaw_gt, c="red")
114
+ plot_pose([0], pred["uv_max"], pred["yaw_max"], c="k")
115
+ save_plot(p.format("map"))
116
+ plt.close()
117
+ plot_images([lp_uv], cmaps="jet")
118
+ plot_dense_rotations(0, lp_uvt.exp())
119
+ save_plot(p.format("loglikelihood"), dpi=100)
120
+ plt.close()
121
+ plot_images([overlay])
122
+ plt.gca().images[0].set_interpolation("none")
123
+ axins = add_circle_inset(plt.gca(), inset_center)
124
+ axins.scatter(*uv_gt, lw=1, c="red", ec="k", s=50)
125
+ save_plot(p.format("likelihood"))
126
+ plt.close()
127
+ write_torch_image(
128
+ p.format("neuralmap").replace("pdf", "jpg"), feats_map_rgb
129
+ )
130
+ write_torch_image(p.format("image").replace("pdf", "jpg"), image.numpy())
131
+
132
+ if not plot_bev:
133
+ return
134
+
135
+ feats_q = pred["features_bev"]
136
+ mask_bev = pred["valid_bev"]
137
+ prior = None
138
+ if "log_prior" in pred["map"]:
139
+ prior = pred["map"]["log_prior"][0].sigmoid()
140
+ if "bev" in pred and "confidence" in pred["bev"]:
141
+ conf_q = pred["bev"]["confidence"]
142
+ else:
143
+ conf_q = torch.norm(feats_q, dim=0)
144
+ conf_q = conf_q.masked_fill(~mask_bev, np.nan)
145
+ (feats_q_rgb,) = features_to_RGB(feats_q.numpy(), masks=[mask_bev.numpy()])
146
+ # feats_map_rgb, feats_q_rgb, = features_to_RGB(
147
+ # feats_map.numpy(), feats_q.numpy(), masks=[None, mask_bev])
148
+ norm_map = torch.norm(feats_map, dim=0)
149
+
150
+ plot_images(
151
+ [conf_q, feats_q_rgb, norm_map] + ([] if prior is None else [prior]),
152
+ titles=["BEV confidence", "BEV features", "map norm"]
153
+ + ([] if prior is None else ["map prior"]),
154
+ dpi=50,
155
+ cmaps="jet",
156
+ )
157
+ plt.show()
158
+
159
+ if out_dir is not None:
160
+ save_plot(p.format("bev"))
161
+ plt.close()
162
+
163
+
164
+ def plot_example_sequential(
165
+ idx,
166
+ model,
167
+ pred,
168
+ data,
169
+ results,
170
+ plot_bev=True,
171
+ out_dir=None,
172
+ fig_for_paper=False,
173
+ show_gps=False,
174
+ show_fused=False,
175
+ show_dir_error=False,
176
+ show_masked_prob=False,
177
+ ):
178
+ return
maploc/models/__init__.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich
4
+ # https://github.com/cvg/pixloc
5
+ # Released under the Apache License 2.0
6
+
7
+ import inspect
8
+
9
+ from .base import BaseModel
10
+
11
+
12
+ def get_class(mod_name, base_path, BaseClass):
13
+ """Get the class object which inherits from BaseClass and is defined in
14
+ the module named mod_name, child of base_path.
15
+ """
16
+ mod_path = "{}.{}".format(base_path, mod_name)
17
+ mod = __import__(mod_path, fromlist=[""])
18
+ classes = inspect.getmembers(mod, inspect.isclass)
19
+ # Filter classes defined in the module
20
+ classes = [c for c in classes if c[1].__module__ == mod_path]
21
+ # Filter classes inherited from BaseModel
22
+ classes = [c for c in classes if issubclass(c[1], BaseClass)]
23
+ assert len(classes) == 1, classes
24
+ return classes[0][1]
25
+
26
+
27
+ def get_model(name):
28
+ if name == "localizer":
29
+ name = "localizer_basic"
30
+ elif name == "rotation_localizer":
31
+ name = "localizer_basic_rotation"
32
+ elif name == "bev_localizer":
33
+ name = "localizer_bev_plane"
34
+ return get_class(name, __name__, BaseModel)
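A usage sketch of the dynamic lookup, assuming (as in this repository) that maploc/models/orienternet.py defines a single BaseModel subclass:

from maploc.models import get_model

OrienterNet = get_model("orienternet")  # imports maploc.models.orienternet and returns its model class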
maploc/models/base.py ADDED
@@ -0,0 +1,123 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich
4
+ # https://github.com/cvg/pixloc
5
+ # Released under the Apache License 2.0
6
+
7
+ """
8
+ Base class for trainable models.
9
+ """
10
+
11
+ from abc import ABCMeta, abstractmethod
12
+ from copy import copy
13
+
14
+ import omegaconf
15
+ from omegaconf import OmegaConf
16
+ from torch import nn
17
+
18
+
19
+ class BaseModel(nn.Module, metaclass=ABCMeta):
20
+ """
21
+ What the child model is expected to declare:
22
+ default_conf: dictionary of the default configuration of the model.
23
+ It recursively updates the default_conf of all parent classes, and
24
+ it is updated by the user-provided configuration passed to __init__.
25
+ Configurations can be nested.
26
+
27
+ required_data_keys: list of expected keys in the input data dictionary.
28
+
29
+ strict_conf (optional): boolean. If false, BaseModel does not raise
30
+ an error when the user provides an unknown configuration entry.
31
+
32
+ _init(self, conf): initialization method, where conf is the final
33
+ configuration object (also accessible with `self.conf`). Accessing
34
+ unknown configuration entries will raise an error.
35
+
36
+ _forward(self, data): method that returns a dictionary of batched
37
+ prediction tensors based on a dictionary of batched input data tensors.
38
+
39
+ loss(self, pred, data): method that returns a dictionary of losses,
40
+ computed from model predictions and input data. Each loss is a batch
41
+ of scalars, i.e. a torch.Tensor of shape (B,).
42
+ The total loss to be optimized has the key `'total'`.
43
+
44
+ metrics(self, pred, data): method that returns a dictionary of metrics,
45
+ each as a batch of scalars.
46
+ """
47
+
48
+ base_default_conf = {
49
+ "name": None,
50
+ "trainable": True, # if false: do not optimize this model's parameters
51
+ "freeze_batch_normalization": False, # use test-time statistics
52
+ }
53
+ default_conf = {}
54
+ required_data_keys = []
55
+ strict_conf = True
56
+
57
+ def __init__(self, conf):
58
+ """Perform some logic and call the _init method of the child model."""
59
+ super().__init__()
60
+ default_conf = OmegaConf.merge(
61
+ self.base_default_conf, OmegaConf.create(self.default_conf)
62
+ )
63
+ if self.strict_conf:
64
+ OmegaConf.set_struct(default_conf, True)
65
+
66
+ # fixme: backward compatibility
67
+ if "pad" in conf and "pad" not in default_conf: # backward compat.
68
+ with omegaconf.read_write(conf):
69
+ with omegaconf.open_dict(conf):
70
+ conf["interpolation"] = {"pad": conf.pop("pad")}
71
+
72
+ if isinstance(conf, dict):
73
+ conf = OmegaConf.create(conf)
74
+ self.conf = conf = OmegaConf.merge(default_conf, conf)
75
+ OmegaConf.set_readonly(conf, True)
76
+ OmegaConf.set_struct(conf, True)
77
+ self.required_data_keys = copy(self.required_data_keys)
78
+ self._init(conf)
79
+
80
+ if not conf.trainable:
81
+ for p in self.parameters():
82
+ p.requires_grad = False
83
+
84
+ def train(self, mode=True):
85
+ super().train(mode)
86
+
87
+ def freeze_bn(module):
88
+ if isinstance(module, nn.modules.batchnorm._BatchNorm):
89
+ module.eval()
90
+
91
+ if self.conf.freeze_batch_normalization:
92
+ self.apply(freeze_bn)
93
+
94
+ return self
95
+
96
+ def forward(self, data):
97
+ """Check the data and call the _forward method of the child model."""
98
+
99
+ def recursive_key_check(expected, given):
100
+ for key in expected:
101
+ assert key in given, f"Missing key {key} in data"
102
+ if isinstance(expected, dict):
103
+ recursive_key_check(expected[key], given[key])
104
+
105
+ recursive_key_check(self.required_data_keys, data)
106
+ return self._forward(data)
107
+
108
+ @abstractmethod
109
+ def _init(self, conf):
110
+ """To be implemented by the child class."""
111
+ raise NotImplementedError
112
+
113
+ @abstractmethod
114
+ def _forward(self, data):
115
+ """To be implemented by the child class."""
116
+ raise NotImplementedError
117
+
118
+ def loss(self, pred, data):
119
+ """To be implemented by the child class."""
120
+ raise NotImplementedError
121
+
122
+ def metrics(self):
123
+ return {} # no metrics
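A minimal sketch of the contract described in the BaseModel docstring above (a toy model for illustration, not part of the repository):

import torch
from torch import nn

from maploc.models.base import BaseModel


class TinyRegressor(BaseModel):
    default_conf = {"dim": 16}
    required_data_keys = ["image"]

    def _init(self, conf):
        self.head = nn.Conv2d(3, conf.dim, kernel_size=1)

    def _forward(self, data):
        return {"features": self.head(data["image"])}

    def loss(self, pred, data):
        # one scalar per batch element, with the total under the key "total"
        return {"total": pred["features"].abs().mean(dim=(1, 2, 3))}


model = TinyRegressor({"name": "tiny"})
pred = model({"image": torch.zeros(2, 3, 8, 8)})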
maploc/models/bev_net.py ADDED
@@ -0,0 +1,61 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import torch.nn as nn
4
+ from torchvision.models.resnet import Bottleneck
5
+
6
+ from .base import BaseModel
7
+ from .feature_extractor import AdaptationBlock
8
+ from .utils import checkpointed
9
+
10
+
11
+ class BEVNet(BaseModel):
12
+ default_conf = {
13
+ "pretrained": True,
14
+ "num_blocks": "???",
15
+ "latent_dim": "???",
16
+ "input_dim": "${.latent_dim}",
17
+ "output_dim": "${.latent_dim}",
18
+ "confidence": False,
19
+ "norm_layer": "nn.BatchNorm2d", # normalization in decoder blocks
20
+ "checkpointed": False, # whether to use gradient checkpointing
21
+ "padding": "zeros",
22
+ }
23
+
24
+ def _init(self, conf):
25
+ blocks = []
26
+ Block = checkpointed(Bottleneck, do=conf.checkpointed)
27
+ for i in range(conf.num_blocks):
28
+ dim = conf.input_dim if i == 0 else conf.latent_dim
29
+ blocks.append(
30
+ Block(
31
+ dim,
32
+ conf.latent_dim // Bottleneck.expansion,
33
+ norm_layer=eval(conf.norm_layer),
34
+ )
35
+ )
36
+ self.blocks = nn.Sequential(*blocks)
37
+ self.output_layer = AdaptationBlock(conf.latent_dim, conf.output_dim)
38
+ if conf.confidence:
39
+ self.confidence_layer = AdaptationBlock(conf.latent_dim, 1)
40
+
41
+ def update_padding(module):
42
+ if isinstance(module, nn.Conv2d):
43
+ module.padding_mode = conf.padding
44
+
45
+ if conf.padding != "zeros":
46
+ self.blocks.apply(update_padding)
47
+
48
+ def _forward(self, data):
49
+ features = self.blocks(data["input"])
50
+ pred = {
51
+ "output": self.output_layer(features),
52
+ }
53
+ if self.conf.confidence:
54
+ pred["confidence"] = self.confidence_layer(features).squeeze(1).sigmoid()
55
+ return pred
56
+
57
+ def loss(self, pred, data):
58
+ raise NotImplementedError
59
+
60
+ def metrics(self, pred, data):
61
+ raise NotImplementedError
maploc/models/bev_projection.py ADDED
@@ -0,0 +1,91 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import torch
4
+ from torch.nn.functional import grid_sample
5
+
6
+ from ..utils.geometry import from_homogeneous
7
+ from .utils import make_grid
8
+
9
+
10
+ class PolarProjectionDepth(torch.nn.Module):
11
+ def __init__(self, z_max, ppm, scale_range, z_min=None):
12
+ super().__init__()
13
+ self.z_max = z_max
14
+ self.Δ = Δ = 1 / ppm
15
+ self.z_min = z_min = Δ if z_min is None else z_min
16
+ self.scale_range = scale_range
17
+ z_steps = torch.arange(z_min, z_max + Δ, Δ)
18
+ self.register_buffer("depth_steps", z_steps, persistent=False)
19
+
20
+ def sample_depth_scores(self, pixel_scales, camera):
21
+ scale_steps = camera.f[..., None, 1] / self.depth_steps.flip(-1)
22
+ log_scale_steps = torch.log2(scale_steps)
23
+ scale_min, scale_max = self.scale_range
24
+ log_scale_norm = (log_scale_steps - scale_min) / (scale_max - scale_min)
25
+ log_scale_norm = log_scale_norm * 2 - 1 # in [-1, 1]
26
+
27
+ values = pixel_scales.flatten(1, 2).unsqueeze(-1)
28
+ indices = log_scale_norm.unsqueeze(-1)
29
+ indices = torch.stack([torch.zeros_like(indices), indices], -1)
30
+ depth_scores = grid_sample(values, indices, align_corners=True)
31
+ depth_scores = depth_scores.reshape(
32
+ pixel_scales.shape[:-1] + (len(self.depth_steps),)
33
+ )
34
+ return depth_scores
35
+
36
+ def forward(
37
+ self,
38
+ image,
39
+ pixel_scales,
40
+ camera,
41
+ return_total_score=False,
42
+ ):
43
+ depth_scores = self.sample_depth_scores(pixel_scales, camera)
44
+ depth_prob = torch.softmax(depth_scores, dim=1)
45
+ image_polar = torch.einsum("...dhw,...hwz->...dzw", image, depth_prob)
46
+ if return_total_score:
47
+ cell_score = torch.logsumexp(depth_scores, dim=1, keepdim=True)
48
+ return image_polar, cell_score.squeeze(1)
49
+ return image_polar
50
+
51
+
52
+ class CartesianProjection(torch.nn.Module):
53
+ def __init__(self, z_max, x_max, ppm, z_min=None):
54
+ super().__init__()
55
+ self.z_max = z_max
56
+ self.x_max = x_max
57
+ self.Δ = Δ = 1 / ppm
58
+ self.z_min = z_min = Δ if z_min is None else z_min
59
+
60
+ grid_xz = make_grid(
61
+ x_max * 2 + Δ, z_max, step_y=Δ, step_x=Δ, orig_y=Δ, orig_x=-x_max, y_up=True
62
+ )
63
+ self.register_buffer("grid_xz", grid_xz, persistent=False)
64
+
65
+ def grid_to_polar(self, cam):
66
+ f, c = cam.f[..., 0][..., None, None], cam.c[..., 0][..., None, None]
67
+ u = from_homogeneous(self.grid_xz).squeeze(-1) * f + c
68
+ z_idx = (self.grid_xz[..., 1] - self.z_min) / self.Δ # convert z value to index
69
+ z_idx = z_idx[None].expand_as(u)
70
+ grid_polar = torch.stack([u, z_idx], -1)
71
+ return grid_polar
72
+
73
+ def sample_from_polar(self, image_polar, valid_polar, grid_uz):
74
+ size = grid_uz.new_tensor(image_polar.shape[-2:][::-1])
75
+ grid_uz_norm = (grid_uz + 0.5) / size * 2 - 1
76
+ grid_uz_norm = grid_uz_norm * grid_uz.new_tensor([1, -1]) # y axis is up
77
+ image_bev = grid_sample(image_polar, grid_uz_norm, align_corners=False)
78
+
79
+ if valid_polar is None:
80
+ valid = torch.ones_like(image_polar[..., :1, :, :])
81
+ else:
82
+ valid = valid_polar.to(image_polar)[:, None]
83
+ valid = grid_sample(valid, grid_uz_norm, align_corners=False)
84
+ valid = valid.squeeze(1) > (1 - 1e-4)
85
+
86
+ return image_bev, valid
87
+
88
+ def forward(self, image_polar, valid_polar, cam):
89
+ grid_uz = self.grid_to_polar(cam)
90
+ image, valid = self.sample_from_polar(image_polar, valid_polar, grid_uz)
91
+ return image, valid, grid_uz
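For reference, grid_to_polar above is the pinhole projection restricted to the ground plane: a BEV cell at camera-frame coordinates (x, z) maps to the image column u = f_x * x / z + c_x, and its depth is converted to a plane index (z - z_min) / Δ; sample_from_polar then bilinearly samples the polar feature volume at (u, z_index) to produce the Cartesian BEV features.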
maploc/models/feature_extractor.py ADDED
@@ -0,0 +1,231 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich
4
+ # https://github.com/cvg/pixloc
5
+ # Released under the Apache License 2.0
6
+
7
+ """
8
+ Flexible UNet model which takes any Torchvision backbone as encoder.
9
+ Predicts multi-level features and makes sure that they are well aligned.
10
+ """
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torchvision
15
+
16
+ from .base import BaseModel
17
+ from .utils import checkpointed
18
+
19
+
20
+ class DecoderBlock(nn.Module):
21
+ def __init__(
22
+ self, previous, skip, out, num_convs=1, norm=nn.BatchNorm2d, padding="zeros"
23
+ ):
24
+ super().__init__()
25
+
26
+ self.upsample = nn.Upsample(
27
+ scale_factor=2, mode="bilinear", align_corners=False
28
+ )
29
+
30
+ layers = []
31
+ for i in range(num_convs):
32
+ conv = nn.Conv2d(
33
+ previous + skip if i == 0 else out,
34
+ out,
35
+ kernel_size=3,
36
+ padding=1,
37
+ bias=norm is None,
38
+ padding_mode=padding,
39
+ )
40
+ layers.append(conv)
41
+ if norm is not None:
42
+ layers.append(norm(out))
43
+ layers.append(nn.ReLU(inplace=True))
44
+ self.layers = nn.Sequential(*layers)
45
+
46
+ def forward(self, previous, skip):
47
+ upsampled = self.upsample(previous)
48
+ # If the shape of the input map `skip` is not a multiple of 2,
49
+ # it will not match the shape of the upsampled map `upsampled`.
50
+ # If the downsampling uses ceil_mode=False, we need to crop `skip`.
51
+ # If it uses ceil_mode=True (not supported here), we should pad it.
52
+ _, _, hu, wu = upsampled.shape
53
+ _, _, hs, ws = skip.shape
54
+ assert (hu <= hs) and (wu <= ws), "Using ceil_mode=True in pooling?"
55
+ # assert (hu == hs) and (wu == ws), 'Careful about padding'
56
+ skip = skip[:, :, :hu, :wu]
57
+ return self.layers(torch.cat([upsampled, skip], dim=1))
58
+
59
+
60
+ class AdaptationBlock(nn.Sequential):
61
+ def __init__(self, inp, out):
62
+ conv = nn.Conv2d(inp, out, kernel_size=1, padding=0, bias=True)
63
+ super().__init__(conv)
64
+
65
+
66
+ class FeatureExtractor(BaseModel):
67
+ default_conf = {
68
+ "pretrained": True,
69
+ "input_dim": 3,
70
+ "output_scales": [0, 2, 4], # what scales to adapt and output
71
+ "output_dim": 128, # # of channels in output feature maps
72
+ "encoder": "vgg16", # string (torchvision net) or list of channels
73
+ "num_downsample": 4, # how many downsample blocks (if VGG-style net)
74
+ "decoder": [64, 64, 64, 64], # list of channels of decoder
75
+ "decoder_norm": "nn.BatchNorm2d", # normalization in decoder blocks
76
+ "do_average_pooling": False,
77
+ "checkpointed": False, # whether to use gradient checkpointing
78
+ "padding": "zeros",
79
+ }
80
+ mean = [0.485, 0.456, 0.406]
81
+ std = [0.229, 0.224, 0.225]
82
+
83
+ def build_encoder(self, conf):
84
+ assert isinstance(conf.encoder, str)
85
+ if conf.pretrained:
86
+ assert conf.input_dim == 3
87
+ Encoder = getattr(torchvision.models, conf.encoder)
88
+ encoder = Encoder(weights="DEFAULT" if conf.pretrained else None)
89
+ Block = checkpointed(torch.nn.Sequential, do=conf.checkpointed)
90
+ assert max(conf.output_scales) <= conf.num_downsample
91
+
92
+ if conf.encoder.startswith("vgg"):
93
+ # Parse the layers and pack them into downsampling blocks
94
+ # It's easy for VGG-style nets because of their linear structure.
95
+ # This does not handle strided convs and residual connections
96
+ skip_dims = []
97
+ previous_dim = None
98
+ blocks = [[]]
99
+ for i, layer in enumerate(encoder.features):
100
+ if isinstance(layer, torch.nn.Conv2d):
101
+ # Change the first conv layer if the input dim mismatches
102
+ if i == 0 and conf.input_dim != layer.in_channels:
103
+ args = {k: getattr(layer, k) for k in layer.__constants__}
104
+ args.pop("output_padding")
105
+ layer = torch.nn.Conv2d(
106
+ **{**args, "in_channels": conf.input_dim}
107
+ )
108
+ previous_dim = layer.out_channels
109
+ elif isinstance(layer, torch.nn.MaxPool2d):
110
+ assert previous_dim is not None
111
+ skip_dims.append(previous_dim)
112
+ if (conf.num_downsample + 1) == len(blocks):
113
+ break
114
+ blocks.append([]) # start a new block
115
+ if conf.do_average_pooling:
116
+ assert layer.dilation == 1
117
+ layer = torch.nn.AvgPool2d(
118
+ kernel_size=layer.kernel_size,
119
+ stride=layer.stride,
120
+ padding=layer.padding,
121
+ ceil_mode=layer.ceil_mode,
122
+ count_include_pad=False,
123
+ )
124
+ blocks[-1].append(layer)
125
+ encoder = [Block(*b) for b in blocks]
126
+ elif conf.encoder.startswith("resnet"):
127
+ # Manually define the ResNet blocks such that the downsampling comes first
128
+ assert conf.encoder[len("resnet") :] in ["18", "34", "50", "101"]
129
+ assert conf.input_dim == 3, "Unsupported for now."
130
+ block1 = torch.nn.Sequential(encoder.conv1, encoder.bn1, encoder.relu)
131
+ block2 = torch.nn.Sequential(encoder.maxpool, encoder.layer1)
132
+ block3 = encoder.layer2
133
+ block4 = encoder.layer3
134
+ block5 = encoder.layer4
135
+ blocks = [block1, block2, block3, block4, block5]
136
+ # Extract the output dimension of each block
137
+ skip_dims = [encoder.conv1.out_channels]
138
+ for i in range(1, 5):
139
+ modules = getattr(encoder, f"layer{i}")[-1]._modules
140
+ conv = sorted(k for k in modules if k.startswith("conv"))[-1]
141
+ skip_dims.append(modules[conv].out_channels)
142
+ # Add a dummy block such that the first one does not downsample
143
+ encoder = [torch.nn.Identity()] + [Block(b) for b in blocks]
144
+ skip_dims = [3] + skip_dims
145
+ # Trim based on the requested encoder size
146
+ encoder = encoder[: conf.num_downsample + 1]
147
+ skip_dims = skip_dims[: conf.num_downsample + 1]
148
+ else:
149
+ raise NotImplementedError(conf.encoder)
150
+
151
+ assert (conf.num_downsample + 1) == len(encoder)
152
+ encoder = nn.ModuleList(encoder)
153
+
154
+ return encoder, skip_dims
155
+
156
+ def _init(self, conf):
157
+ # Encoder
158
+ self.encoder, skip_dims = self.build_encoder(conf)
159
+ self.skip_dims = skip_dims
160
+
161
+ def update_padding(module):
162
+ if isinstance(module, nn.Conv2d):
163
+ module.padding_mode = conf.padding
164
+
165
+ if conf.padding != "zeros":
166
+ self.encoder.apply(update_padding)
167
+
168
+ # Decoder
169
+ if conf.decoder is not None:
170
+ assert len(conf.decoder) == (len(skip_dims) - 1)
171
+ Block = checkpointed(DecoderBlock, do=conf.checkpointed)
172
+ norm = eval(conf.decoder_norm) if conf.decoder_norm else None # noqa
173
+
174
+ previous = skip_dims[-1]
175
+ decoder = []
176
+ for out, skip in zip(conf.decoder, skip_dims[:-1][::-1]):
177
+ decoder.append(
178
+ Block(previous, skip, out, norm=norm, padding=conf.padding)
179
+ )
180
+ previous = out
181
+ self.decoder = nn.ModuleList(decoder)
182
+
183
+ # Adaptation layers
184
+ adaptation = []
185
+ for idx, i in enumerate(conf.output_scales):
186
+ if conf.decoder is None or i == (len(self.encoder) - 1):
187
+ input_ = skip_dims[i]
188
+ else:
189
+ input_ = conf.decoder[-1 - i]
190
+
191
+ # out_dim can be an int (same for all scales) or a list (per scale)
192
+ dim = conf.output_dim
193
+ if not isinstance(dim, int):
194
+ dim = dim[idx]
195
+
196
+ block = AdaptationBlock(input_, dim)
197
+ adaptation.append(block)
198
+ self.adaptation = nn.ModuleList(adaptation)
199
+ self.scales = [2**s for s in conf.output_scales]
200
+
201
+ def _forward(self, data):
202
+ image = data["image"]
203
+ if self.conf.pretrained:
204
+ mean, std = image.new_tensor(self.mean), image.new_tensor(self.std)
205
+ image = (image - mean[:, None, None]) / std[:, None, None]
206
+
207
+ skip_features = []
208
+ features = image
209
+ for block in self.encoder:
210
+ features = block(features)
211
+ skip_features.append(features)
212
+
213
+ if self.conf.decoder:
214
+ pre_features = [skip_features[-1]]
215
+ for block, skip in zip(self.decoder, skip_features[:-1][::-1]):
216
+ pre_features.append(block(pre_features[-1], skip))
217
+ pre_features = pre_features[::-1] # fine to coarse
218
+ else:
219
+ pre_features = skip_features
220
+
221
+ out_features = []
222
+ for adapt, i in zip(self.adaptation, self.conf.output_scales):
223
+ out_features.append(adapt(pre_features[i]))
224
+ pred = {"feature_maps": out_features, "skip_features": skip_features}
225
+ return pred
226
+
227
+ def loss(self, pred, data):
228
+ raise NotImplementedError
229
+
230
+ def metrics(self, pred, data):
231
+ raise NotImplementedError
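A configuration sketch for the extractor above (the values are illustrative, and pretrained is disabled here only to avoid a weight download):

import torch

from maploc.models.feature_extractor import FeatureExtractor

net = FeatureExtractor(
    {
        "encoder": "vgg16",
        "num_downsample": 4,
        "decoder": [64, 64, 64, 64],
        "output_scales": [0, 2, 4],
        "output_dim": 128,
        "pretrained": False,
    }
)
pred = net({"image": torch.zeros(1, 3, 256, 256)})
# three aligned feature maps at strides 1, 4 and 16, each with 128 channels
print([f.shape for f in pred["feature_maps"]])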
maploc/models/feature_extractor_v2.py ADDED
@@ -0,0 +1,192 @@
1
+ import logging
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torchvision
7
+ from torchvision.models.feature_extraction import create_feature_extractor
8
+
9
+ from .base import BaseModel
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ class DecoderBlock(nn.Module):
15
+ def __init__(
16
+ self, previous, out, ksize=3, num_convs=1, norm=nn.BatchNorm2d, padding="zeros"
17
+ ):
18
+ super().__init__()
19
+ layers = []
20
+ for i in range(num_convs):
21
+ conv = nn.Conv2d(
22
+ previous if i == 0 else out,
23
+ out,
24
+ kernel_size=ksize,
25
+ padding=ksize // 2,
26
+ bias=norm is None,
27
+ padding_mode=padding,
28
+ )
29
+ layers.append(conv)
30
+ if norm is not None:
31
+ layers.append(norm(out))
32
+ layers.append(nn.ReLU(inplace=True))
33
+ self.layers = nn.Sequential(*layers)
34
+
35
+ def forward(self, previous, skip):
36
+ _, _, hp, wp = previous.shape
37
+ _, _, hs, ws = skip.shape
38
+ scale = 2 ** np.round(np.log2(np.array([hs / hp, ws / wp])))
39
+ upsampled = nn.functional.interpolate(
40
+ previous, scale_factor=scale.tolist(), mode="bilinear", align_corners=False
41
+ )
42
+ # If the shape of the input map `skip` is not a multiple of 2,
43
+ # it will not match the shape of the upsampled map `upsampled`.
44
+ # If the downsampling uses ceil_mode=False, we nedd to crop `skip`.
45
+ # If it uses ceil_mode=True (not supported here), we should pad it.
46
+ _, _, hu, wu = upsampled.shape
47
+ _, _, hs, ws = skip.shape
48
+ if (hu <= hs) and (wu <= ws):
49
+ skip = skip[:, :, :hu, :wu]
50
+ elif (hu >= hs) and (wu >= ws):
51
+ skip = nn.functional.pad(skip, [0, wu - ws, 0, hu - hs])
52
+ else:
53
+ raise ValueError(
54
+ f"Inconsistent skip vs upsampled shapes: {(hs, ws)}, {(hu, wu)}"
55
+ )
56
+
57
+ return self.layers(skip) + upsampled
58
+
59
+
60
+ class FPN(nn.Module):
61
+ def __init__(self, in_channels_list, out_channels, **kw):
62
+ super().__init__()
63
+ self.first = nn.Conv2d(
64
+ in_channels_list[-1], out_channels, 1, padding=0, bias=True
65
+ )
66
+ self.blocks = nn.ModuleList(
67
+ [
68
+ DecoderBlock(c, out_channels, ksize=1, **kw)
69
+ for c in in_channels_list[::-1][1:]
70
+ ]
71
+ )
72
+ self.out = nn.Sequential(
73
+ nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
74
+ nn.BatchNorm2d(out_channels),
75
+ nn.ReLU(inplace=True),
76
+ )
77
+
78
+ def forward(self, layers):
79
+ feats = None
80
+ for idx, x in enumerate(reversed(layers.values())):
81
+ if feats is None:
82
+ feats = self.first(x)
83
+ else:
84
+ feats = self.blocks[idx - 1](feats, x)
85
+ out = self.out(feats)
86
+ return out
87
+
88
+
89
+ def remove_conv_stride(conv):
90
+ conv_new = nn.Conv2d(
91
+ conv.in_channels,
92
+ conv.out_channels,
93
+ conv.kernel_size,
94
+ bias=conv.bias is not None,
95
+ stride=1,
96
+ padding=conv.padding,
97
+ )
98
+ conv_new.weight = conv.weight
99
+ conv_new.bias = conv.bias
100
+ return conv_new
101
+
102
+
103
+ class FeatureExtractor(BaseModel):
104
+ default_conf = {
105
+ "pretrained": True,
106
+ "input_dim": 3,
107
+ "output_dim": 128, # # of channels in output feature maps
108
+ "encoder": "resnet50", # torchvision net as string
109
+ "remove_stride_from_first_conv": False,
110
+ "num_downsample": None, # how many downsample block
111
+ "decoder_norm": "nn.BatchNorm2d", # normalization ind decoder blocks
112
+ "do_average_pooling": False,
113
+ "checkpointed": False, # whether to use gradient checkpointing
114
+ }
115
+ mean = [0.485, 0.456, 0.406]
116
+ std = [0.229, 0.224, 0.225]
117
+
118
+ def build_encoder(self, conf):
119
+ assert isinstance(conf.encoder, str)
120
+ if conf.pretrained:
121
+ assert conf.input_dim == 3
122
+ Encoder = getattr(torchvision.models, conf.encoder)
123
+
124
+ kw = {}
125
+ if conf.encoder.startswith("resnet"):
126
+ layers = ["relu", "layer1", "layer2", "layer3", "layer4"]
127
+ kw["replace_stride_with_dilation"] = [False, False, False]
128
+ elif conf.encoder == "vgg13":
129
+ layers = [
130
+ "features.3",
131
+ "features.8",
132
+ "features.13",
133
+ "features.18",
134
+ "features.23",
135
+ ]
136
+ elif conf.encoder == "vgg16":
137
+ layers = [
138
+ "features.3",
139
+ "features.8",
140
+ "features.15",
141
+ "features.22",
142
+ "features.29",
143
+ ]
144
+ else:
145
+ raise NotImplementedError(conf.encoder)
146
+
147
+ if conf.num_downsample is not None:
148
+ layers = layers[: conf.num_downsample]
149
+ encoder = Encoder(weights="DEFAULT" if conf.pretrained else None, **kw)
150
+ encoder = create_feature_extractor(encoder, return_nodes=layers)
151
+ if conf.encoder.startswith("resnet") and conf.remove_stride_from_first_conv:
152
+ encoder.conv1 = remove_conv_stride(encoder.conv1)
153
+
154
+ if conf.do_average_pooling:
155
+ raise NotImplementedError
156
+ if conf.checkpointed:
157
+ raise NotImplementedError
158
+
159
+ return encoder, layers
160
+
161
+ def _init(self, conf):
162
+ # Preprocessing
163
+ self.register_buffer("mean_", torch.tensor(self.mean), persistent=False)
164
+ self.register_buffer("std_", torch.tensor(self.std), persistent=False)
165
+
166
+ # Encoder
167
+ self.encoder, self.layers = self.build_encoder(conf)
168
+ s = 128
169
+ inp = torch.zeros(1, 3, s, s)
170
+ features = list(self.encoder(inp).values())
171
+ self.skip_dims = [x.shape[1] for x in features]
172
+ self.layer_strides = [s / f.shape[-1] for f in features]
173
+ self.scales = [self.layer_strides[0]]
174
+
175
+ # Decoder
176
+ norm = eval(conf.decoder_norm) if conf.decoder_norm else None # noqa
177
+ self.decoder = FPN(self.skip_dims, out_channels=conf.output_dim, norm=norm)
178
+
179
+ logger.debug(
180
+ "Built feature extractor with layers {name:dim:stride}:\n"
181
+ f"{list(zip(self.layers, self.skip_dims, self.layer_strides))}\n"
182
+ f"and output scales {self.scales}."
183
+ )
184
+
185
+ def _forward(self, data):
186
+ image = data["image"]
187
+ image = (image - self.mean_[:, None, None]) / self.std_[:, None, None]
188
+
189
+ skip_features = self.encoder(image)
190
+ output = self.decoder(skip_features)
191
+ pred = {"feature_maps": [output], "skip_features": skip_features}
192
+ return pred
maploc/models/map_encoder.py ADDED
@@ -0,0 +1,66 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from .base import BaseModel
7
+ from .feature_extractor import FeatureExtractor
8
+
9
+
10
+ class MapEncoder(BaseModel):
11
+ default_conf = {
12
+ "embedding_dim": "???",
13
+ "output_dim": None,
14
+ "num_classes": "???",
15
+ "backbone": "???",
16
+ "unary_prior": False,
17
+ }
18
+
19
+ def _init(self, conf):
20
+ self.embeddings = torch.nn.ModuleDict(
21
+ {
22
+ k: torch.nn.Embedding(n + 1, conf.embedding_dim)
23
+ for k, n in conf.num_classes.items()
24
+ }
25
+ )
26
+ input_dim = len(conf.num_classes) * conf.embedding_dim
27
+ output_dim = conf.output_dim
28
+ if output_dim is None:
29
+ output_dim = conf.backbone.output_dim
30
+ if conf.unary_prior:
31
+ output_dim += 1
32
+ if conf.backbone is None:
33
+ self.encoder = nn.Conv2d(input_dim, output_dim, 1)
34
+ elif conf.backbone == "simple":
35
+ self.encoder = nn.Sequential(
36
+ nn.Conv2d(input_dim, 128, 3, padding=1),
37
+ nn.ReLU(inplace=True),
38
+ nn.Conv2d(128, 128, 3, padding=1),
39
+ nn.ReLU(inplace=True),
40
+ nn.Conv2d(128, output_dim, 3, padding=1),
41
+ )
42
+ else:
43
+ self.encoder = FeatureExtractor(
44
+ {
45
+ **conf.backbone,
46
+ "input_dim": input_dim,
47
+ "output_dim": output_dim,
48
+ }
49
+ )
50
+
51
+ def _forward(self, data):
52
+ embeddings = [
53
+ self.embeddings[k](data["map"][:, i])
54
+ for i, k in enumerate(("areas", "ways", "nodes"))
55
+ ]
56
+ embeddings = torch.cat(embeddings, dim=-1).permute(0, 3, 1, 2)
57
+ if isinstance(self.encoder, BaseModel):
58
+ features = self.encoder({"image": embeddings})["feature_maps"]
59
+ else:
60
+ features = [self.encoder(embeddings)]
61
+ pred = {}
62
+ if self.conf.unary_prior:
63
+ pred["log_prior"] = [f[:, -1] for f in features]
64
+ features = [f[:, :-1] for f in features]
65
+ pred["map_features"] = features
66
+ return pred
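A shape sketch for the map encoder above (the class counts and dimensions are illustrative, not the values used by the released configs):

import torch

from maploc.models.map_encoder import MapEncoder

encoder = MapEncoder(
    {
        "embedding_dim": 16,
        "num_classes": {"areas": 7, "ways": 10, "nodes": 33},
        "backbone": None,  # plain 1x1 convolution instead of a CNN backbone
        "output_dim": 8,
        "unary_prior": False,
    }
)
# data["map"] is a LongTensor of shape (B, 3, H, W) holding class indices for
# the areas/ways/nodes channels, with 0 meaning "no class"
pred = encoder({"map": torch.zeros(1, 3, 64, 64, dtype=torch.long)})
print(pred["map_features"][0].shape)  # torch.Size([1, 8, 64, 64])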