qubvel-hf HF Staff committed on
Commit
9b7fcdb
·
0 Parent(s):

Clean proj with LFS

.gitattributes ADDED
@@ -0,0 +1,5 @@
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.webp filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.gif filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,161 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
161
+ /gradio*
.vscode/settings.json ADDED
@@ -0,0 +1,3 @@
+ {
+     "python.pythonPath": "/conda/install/envs/policygrad/bin/python"
+ }
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,171 @@
1
+ ## XFeat: Accelerated Features for Lightweight Image Matching
2
+ [Guilherme Potje](https://guipotje.github.io/) · [Felipe Cadar](https://eucadar.com/) · [Andre Araujo](https://andrefaraujo.github.io/) · [Renato Martins](https://renatojmsdh.github.io/) · [Erickson R. Nascimento](https://homepages.dcc.ufmg.br/~erickson/)
3
+
4
+ [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](LICENSE)
5
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/verlab/accelerated_features/blob/main/notebooks/xfeat_matching.ipynb)
6
+
7
+ ### [[ArXiv]](https://arxiv.org/abs/2404.19174) | [[Project Page]](https://www.verlab.dcc.ufmg.br/descriptors/xfeat_cvpr24/) | [[CVPR'24 Paper]](https://cvpr.thecvf.com/)
8
+
9
+ <div align="center" style="display: flex; justify-content: center; align-items: center; flex-direction: column;">
10
+ <div style="display: flex; justify-content: space-around; width: 100%;">
11
+ <img src='./figs/xfeat.gif' width="400"/>
12
+ <img src='./figs/sift.gif' width="400"/>
13
+ </div>
14
+
15
+ Real-time XFeat demonstration (left) compared to SIFT (right) on a textureless scene. SIFT cannot handle fast camera movements, whereas XFeat provides robust matches under these adverse conditions while also being faster than SIFT on CPU.
16
+
17
+ </div>
18
+
19
+ **TL;DR**: Really fast learned keypoint detector and descriptor. Supports sparse and semi-dense matching.
20
+
21
+ Just wanna quickly try it on your own images? Check this out: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/verlab/accelerated_features/blob/main/notebooks/xfeat_torch_hub.ipynb)
22
+
23
+ ## Table of Contents
24
+ - [Introduction](#introduction) <img align="right" src='./figs/xfeat_quali.jpg' width=360 />
25
+ - [Installation](#installation)
26
+ - [Usage](#usage)
27
+ - [Inference](#inference)
28
+ - [Training](#training)
29
+ - [Evaluation](#evaluation)
30
+ - [Real-time demo app](#real-time-demo)
31
+ - [Contribute](#contributing)
32
+ - [Citation](#citation)
33
+ - [License](#license)
34
+ - [Acknowledgements](#acknowledgements)
35
+
36
+ ## Introduction
37
+ This repository contains the official implementation of the paper: *[XFeat: Accelerated Features for Lightweight Image Matching](https://arxiv.org/abs/2404.19174)*, to be presented at CVPR 2024.
38
+
39
+ **Motivation.** Why another keypoint detector and descriptor among dozens of existing ones? We noticed that the current trend in the literature focuses on accuracy but often neglects compute efficiency, especially when these solutions are deployed in the real world. For applications in mobile robotics and augmented reality, it is critical that models can run on hardware-constrained computers. To this end, XFeat was designed as a hardware-agnostic solution that targets both accuracy and efficiency in an image matching pipeline.
40
+
41
+ **Capabilities.**
42
+ - Real-time sparse inference on CPU for VGA images (tested on a laptop with an i5 CPU and vanilla PyTorch);
+ - Simple architecture components, which facilitate deployment on embedded devices (Jetson, Raspberry Pi, custom AI chips, etc.);
+ - Supports both sparse and semi-dense matching of local features;
+ - Compact descriptors (64D);
+ - Performance comparable to known deep local features such as SuperPoint, while being significantly faster and more lightweight. XFeat also exhibits much better robustness to viewpoint and illumination changes than classic local features such as ORB and SIFT;
+ - Supports batched inference if you want ridiculously fast feature extraction. In the VGA sparse setting, we achieved about 1,400 FPS using an RTX 4090 (see the sketch after this list);
+ - For single-batch inference on GPU (VGA), one can easily achieve over 150 FPS while leaving plenty of room on the GPU for other concurrent tasks.
49
+
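+ For reference, here is a minimal sketch of the batched, sparse, and semi-dense entry points, mirroring the bundled `minimal_example.py` (random tensors stand in for real images):
+ ```python
+ import torch
+ from modules.xfeat import XFeat
+ 
+ xfeat = XFeat()
+ 
+ # Batched sparse extraction: returns one dict per image in the batch
+ batch = torch.randn(4, 3, 480, 640)
+ outputs = xfeat.detectAndCompute(batch, top_k=4096)
+ print([len(o['keypoints']) for o in outputs])
+ 
+ # Sparse matching (MNN) vs. semi-dense matching with refinement, single pair
+ im1, im2 = torch.randn(1, 3, 480, 640), torch.randn(1, 3, 480, 640)
+ mkpts_0, mkpts_1 = xfeat.match_xfeat(im1, im2)       # sparse keypoints + descriptors
+ mkpts_0, mkpts_1 = xfeat.match_xfeat_star(im1, im2)  # semi-dense, coarse-to-fine
+ ```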
50
+ ##
51
+
52
+ **Paper Abstract.** We introduce a lightweight and accurate architecture for resource-efficient visual correspondence. Our method, dubbed XFeat (Accelerated Features), revisits fundamental design choices in convolutional neural networks for detecting, extracting, and matching local features. Our new model satisfies a critical need for fast and robust algorithms suitable to resource-limited devices. In particular, accurate image matching requires sufficiently large image resolutions -- for this reason, we keep the resolution as large as possible while limiting the number of channels in the network. Besides, our model is designed to offer the choice of matching at the sparse or semi-dense levels, each of which may be more suitable for different downstream applications, such as visual navigation and augmented reality. Our model is the first to offer semi-dense matching efficiently, leveraging a novel match refinement module that relies on coarse local descriptors. XFeat is versatile and hardware-independent, surpassing current deep learning-based local features in speed (up to 5x faster) with comparable or better accuracy, proven in pose estimation and visual localization. We showcase it running in real-time on an inexpensive laptop CPU without specialized hardware optimizations.
53
+
54
+ **Overview of XFeat's architecture.**
55
+ XFeat extracts a keypoint heatmap $\mathbf{K}$, a compact 64-D dense descriptor map $\mathbf{F}$, and a reliability heatmap $\mathbf{R}$. It achieves unparalleled speed via early downsampling and shallow convolutions, followed by deeper convolutions in later encoders for robustness. Contrary to typical methods, it separates keypoint detection into a distinct branch, using $1 \times 1$ convolutions on an $8 \times 8$ tensor-block-transformed image for fast processing; it is one of the few current learned methods that decouples detection and description so that each can be processed independently (see the sketch below the architecture diagram).
56
+
57
+ <img align="center" src="./figs/xfeat_arq.png" width=1000 />
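+ For illustration, here is a minimal sketch of this 8×8 block transform, mirroring `_unfold2d` in `modules/model.py`: each 8×8 pixel block is folded into 64 channels so that the keypoint head can classify blocks with cheap $1 \times 1$ convolutions.
+ ```python
+ import torch
+ 
+ def block_transform(x, ws=8):
+     # Space-to-depth: (B, C, H, W) -> (B, C*ws*ws, H/ws, W/ws), as in _unfold2d
+     B, C, H, W = x.shape
+     x = x.unfold(2, ws, ws).unfold(3, ws, ws).reshape(B, C, H // ws, W // ws, ws * ws)
+     return x.permute(0, 1, 4, 2, 3).reshape(B, -1, H // ws, W // ws)
+ 
+ img = torch.randn(1, 1, 480, 640)   # grayscale VGA image
+ print(block_transform(img).shape)   # torch.Size([1, 64, 60, 80])
+ ```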
58
+
59
+ ## Installation
60
+ XFeat has minimal dependencies, relying only on PyTorch. XFeat also does not need a GPU for real-time sparse inference (vanilla PyTorch without any special optimization), unless you run it on high-resolution images. If you want to run the real-time matching demo, you will also need OpenCV.
61
+ We recommend using conda, but you can use any virtualenv of your choice.
62
+ If you use conda, just create a new env with:
63
+ ```bash
64
+ git clone https://github.com/verlab/accelerated_features.git
65
+ cd accelerated_features
66
+
67
+ #Create conda env
68
+ conda create -n xfeat python=3.8
69
+ conda activate xfeat
70
+ ```
71
+
72
+ Then, install [pytorch (>=1.10)](https://pytorch.org/get-started/previous-versions/), followed by the remaining dependencies if you want to run the demos:
73
+ ```bash
74
+
75
+ #CPU only; for GPU, check the PyTorch website for the version best suited to your GPU.
76
+ pip install torch==1.10.1+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
77
+ # CPU only for MacOS
78
+ # pip install torch==1.10.1 -f https://download.pytorch.org/whl/cpu/torch_stable.html
79
+
80
+ #Install dependencies for the demo
81
+ pip install opencv-contrib-python tqdm
82
+ ```
83
+
84
+ ## Usage
85
+
86
+ For your convenience, we provide ready-to-use notebooks for some examples.
87
+
88
+ | **Description** | **Notebook** |
89
+ |--------------------------------|-------------------------------|
90
+ | Minimal example | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/verlab/accelerated_features/blob/main/notebooks/minimal_example.ipynb) |
91
+ | Matching & registration example | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/verlab/accelerated_features/blob/main/notebooks/xfeat_matching.ipynb) |
92
+ | Torch hub example | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/verlab/accelerated_features/blob/main/notebooks/xfeat_torch_hub.ipynb) |
93
+
94
+
95
+ ### Inference
96
+ To run XFeat on an image, a few lines of code are enough:
97
+ ```python
98
+ import torch
+ from modules.xfeat import XFeat
99
+
100
+ xfeat = XFeat()
101
+
102
+ #Simple inference with batch sz = 1
103
+ output = xfeat.detectAndCompute(torch.randn(1,3,480,640), top_k = 4096)[0]
104
+ ```
105
+ Or you can use this [script](./minimal_example.py) in the root folder:
106
+ ```bash
107
+ python3 minimal_example.py
108
+ ```
109
+
110
+ If you already have PyTorch installed, you can simply load XFeat via torch hub:
111
+ ```python
112
+ import torch
113
+
114
+ xfeat = torch.hub.load('verlab/accelerated_features', 'XFeat', pretrained = True, top_k = 4096)
115
+
116
+ #Simple inference with batch sz = 1
117
+ output = xfeat.detectAndCompute(torch.randn(1,3,480,640), top_k = 4096)[0]
118
+ ```
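+ In both cases, `detectAndCompute` returns one dictionary per batch item; the expected tensor shapes (see `minimal_example.py`) are:
+ ```python
+ out = xfeat.detectAndCompute(torch.randn(1, 3, 480, 640), top_k=4096)[0]
+ print(out['keypoints'].shape)    # (N, 2)  keypoint (x, y) coordinates, N <= top_k
+ print(out['scores'].shape)       # (N,)    keypoint reliability scores
+ print(out['descriptors'].shape)  # (N, 64) L2-normalized local descriptors
+ ```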
119
+
120
+ ### Training
121
+ XFeat training code will be released soon. Please stay tuned.
122
+
123
+ ### Evaluation
124
+ XFeat evaluation code will be released soon, alongside the training scripts. Please stay tuned.
125
+
126
+ ## Real-time Demo
127
+ To demonstrate the capabilities of XFeat, we provide a real-time matching demo with homography registration. Currently, you can experiment with XFeat, ORB, and SIFT. You will need a working webcam. To run the demo and list the available input flags, please run:
128
+ ```bash
129
+ python3 realtime_demo.py -h
130
+ ```
131
+
132
+ Don't forget to press 's' to set the desired reference image. Note that the demo only works correctly for planar scenes or rotation-only motion, because we're using a homography model.
133
+
134
+ If you want to run the demo with XFeat, please run:
135
+ ```bash
136
+ python3 realtime_demo.py --method XFeat
137
+ ```
138
+
139
+ Or test with SIFT or ORB:
140
+ ```bash
141
+ python3 realtime_demo.py --method SIFT
142
+ python3 realtime_demo.py --method ORB
143
+ ```
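+ The bundled Gradio app (`app.py`) filters raw XFeat matches with a robust homography fit, and the real-time demo relies on the same homography model for registration. A minimal sketch of that filtering step, using OpenCV's MAGSAC estimator as in `app.py` (assuming `img1` and `img2` are two BGR images loaded with OpenCV):
+ ```python
+ import cv2
+ 
+ mkpts_0, mkpts_1 = xfeat.match_xfeat(img1, img2)  # (N, 2) match coordinates
+ # Keep only matches consistent with a single homography
+ H, mask = cv2.findHomography(mkpts_0, mkpts_1, cv2.USAC_MAGSAC, 3.5,
+                              maxIters=1_000, confidence=0.999)
+ inliers = mask.flatten().astype(bool)
+ mkpts_0, mkpts_1 = mkpts_0[inliers], mkpts_1[inliers]
+ print(f"{inliers.sum()} inlier matches survive the homography check")
+ ```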
144
+
145
+ ## Contributing
146
+ Contributions to XFeat are welcome!
147
+ Currently, it would be nice to have an export script for efficient deployment engines such as TensorRT and ONNX. It would also be cool to train a lightweight learned matcher on top of XFeat local features.
148
+
149
+ ## Citation
150
+ If you find this code useful for your research, please cite the paper:
151
+
152
+ ```bibtex
153
+ @INPROCEEDINGS{potje2024cvpr,
154
+ author={Guilherme {Potje} and Felipe {Cadar} and Andre {Araujo} and Renato {Martins} and Erickson R. {Nascimento}},
155
+ booktitle={2024 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
156
+ title={XFeat: Accelerated Features for Lightweight Image Matching},
157
+ year={2024}}
158
+ ```
159
+
160
+ ## License
161
+ [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](LICENSE)
162
+
163
+ ## Acknowledgements
164
+ - We thank the agencies CAPES, CNPq, and Google for funding different parts of this work.
165
+ - We thank the developers of Kornia for the [kornia library](https://github.com/kornia/kornia)!
166
+
167
+ **VeRLab:** Laboratory of Computer Vision and Robotics https://www.verlab.dcc.ufmg.br
168
+ <br>
169
+ <img align="left" width="auto" height="50" src="./figs/ufmg.png">
170
+ <img align="right" width="auto" height="50" src="./figs/verlab.png">
171
+ <br/>
app.py ADDED
@@ -0,0 +1,94 @@
+ import cv2
+ import numpy as np
+ import gradio as gr
+ 
+ from modules.xfeat import XFeat
+ from utils import visualize_matches
+ 
+ 
+ HEADER = """
+ <div align="center">
+     <p>
+         <span style="font-size: 30px; vertical-align: bottom;"> XFeat: Accelerated Features for Lightweight Image Matching</span>
+     </p>
+     <p style="margin-top: -15px;">
+         <a href="https://arxiv.org/abs/2404.19174" target="_blank" style="color: grey;">ArXiv Paper</a>
+         &nbsp;
+         <a href="https://github.com/verlab/accelerated_features" target="_blank" style="color: grey;">GitHub Repository</a>
+     </p>
+     <p>
+         Upload two images 🖼️ of the object and identify matches between them 🚀
+     </p>
+ </div>
+ """
+ 
+ ABSTRACT = """
+ We introduce a lightweight and accurate architecture for resource-efficient visual correspondence. Our method, dubbed XFeat (Accelerated Features), revisits fundamental design choices in convolutional neural networks for detecting, extracting, and matching local features. Our new model satisfies a critical need for fast and robust algorithms suitable to resource-limited devices. In particular, accurate image matching requires sufficiently large image resolutions -- for this reason, we keep the resolution as large as possible while limiting the number of channels in the network. Besides, our model is designed to offer the choice of matching at the sparse or semi-dense levels, each of which may be more suitable for different downstream applications, such as visual navigation and augmented reality. Our model is the first to offer semi-dense matching efficiently, leveraging a novel match refinement module that relies on coarse local descriptors. XFeat is versatile and hardware-independent, surpassing current deep learning-based local features in speed (up to 5x faster) with comparable or better accuracy, proven in pose estimation and visual localization. We showcase it running in real-time on an inexpensive laptop CPU without specialized hardware optimizations.
+ """
+ 
+ def find_matches(image_0, image_1):
+ 
+     image_0_bgr = cv2.cvtColor(image_0, cv2.COLOR_RGB2BGR)
+     image_1_bgr = cv2.cvtColor(image_1, cv2.COLOR_RGB2BGR)
+ 
+     xfeat = XFeat(weights="weights/xfeat.pt", top_k=4096)
+ 
+     #Use out-of-the-box function for extraction + MNN matching
+     match_kp0, match_kp1 = xfeat.match_xfeat(image_0_bgr, image_1_bgr, top_k = 4096)
+ 
+     # canvas = warp_corners_and_draw_matches(mkpts_0, mkpts_1, image_0, image_1)
+ 
+     _, mask = cv2.findHomography(match_kp0, match_kp1, cv2.USAC_MAGSAC, 3.5, maxIters=1_000, confidence=0.999)
+     keep = mask.flatten().astype(bool)
+ 
+     match_kp0 = match_kp0[keep]
+     match_kp1 = match_kp1[keep]
+ 
+     num_filtered_matches = len(match_kp0)
+ 
+     viz = visualize_matches(
+         image_0,
+         image_1,
+         match_kp0,
+         match_kp1,
+         np.eye(num_filtered_matches),
+         show_keypoints=True,
+         highlight_unmatched=True,
+         title=f"{num_filtered_matches} matches",
+         line_width=2,
+     )
+ 
+     return viz
+ 
+ 
+ with gr.Blocks() as demo:
+ 
+     gr.Markdown(HEADER)
+     with gr.Accordion("Abstract (click to open)", open=False):
+         gr.Image("assets/xfeat_arq.png")
+         gr.Markdown(ABSTRACT)
+ 
+     with gr.Row():
+         image_1 = gr.Image()
+         image_2 = gr.Image()
+     with gr.Row():
+         button = gr.Button(value="Find Matches")
+         clear = gr.ClearButton(value="Clear")
+     output = gr.Image()
+     button.click(find_matches, [image_1, image_2], output)
+     clear.add([image_1, image_2, output])
+ 
+     gr.Examples(
+         examples=[
+             ["assets/ref.png", "assets/tgt.png"],
+             ["assets/demo1.jpg", "assets/demo2.jpg"],
+             ["assets/tower-1.webp", "assets/tower-2.jpeg"],
+         ],
+         inputs=[image_1, image_2],
+         outputs=[output],
+         fn=find_matches,
+         cache_examples=None,
+     )
+ 
+ if __name__ == "__main__":
+     demo.launch()
assets/demo1.jpg ADDED

Git LFS Details

  • SHA256: 0c3719183ae9139e45569e16861f42ac8e47b46c86f3536fdc52b22011f31871
  • Pointer size: 130 Bytes
  • Size of remote file: 85.3 kB
assets/demo2.jpg ADDED

Git LFS Details

  • SHA256: 24dbe3a2ee909002b265e647b96a7141419c954a2a90b235699c186f927705c4
  • Pointer size: 131 Bytes
  • Size of remote file: 114 kB
assets/ref.png ADDED

Git LFS Details

  • SHA256: 1292e2ba509b338a05820e7bf62dcda0b26688a4a4307996ee7c295e6627bee2
  • Pointer size: 132 Bytes
  • Size of remote file: 1.09 MB
assets/tgt.png ADDED

Git LFS Details

  • SHA256: ffd0fe993cd7967f2e0e44495c9f60828f8e2d6a61440ede4a951ee67b865613
  • Pointer size: 131 Bytes
  • Size of remote file: 946 kB
assets/tower-1.webp ADDED

Git LFS Details

  • SHA256: ee7721d6a79481e2b255826a2f294a6d01f1a8f9b58fbe1253c4ae5028cf69c2
  • Pointer size: 131 Bytes
  • Size of remote file: 256 kB
assets/tower-2.jpeg ADDED
assets/xfeat_arq.png ADDED

Git LFS Details

  • SHA256: 071bf66baf111568ce8a2b27879f1f0b4b0e27552845db35ecd44c9d202cf5ab
  • Pointer size: 131 Bytes
  • Size of remote file: 191 kB
figs/sift.gif ADDED

Git LFS Details

  • SHA256: 8059f5d11b8cfc01fa96894a7a2aba1c4c5079d808531a4b39d74538e2f5f312
  • Pointer size: 132 Bytes
  • Size of remote file: 2.57 MB
figs/ufmg.png ADDED

Git LFS Details

  • SHA256: 1210e1ef2125305677de696645609625313d3575e9c73f78ddf0998da37f598a
  • Pointer size: 129 Bytes
  • Size of remote file: 7.61 kB
figs/verlab.png ADDED

Git LFS Details

  • SHA256: 12d6e4bfc62dc503311c927360d01b9c506df4df4853fad930fb6b5b9480e86a
  • Pointer size: 130 Bytes
  • Size of remote file: 16.2 kB
figs/xfeat.gif ADDED

Git LFS Details

  • SHA256: 1dff14e4d08150f0000735b88dd961d0b78983ffc00f2d32ff1a1a9d826c3a3d
  • Pointer size: 132 Bytes
  • Size of remote file: 3.2 MB
figs/xfeat_arq.png ADDED

Git LFS Details

  • SHA256: 071bf66baf111568ce8a2b27879f1f0b4b0e27552845db35ecd44c9d202cf5ab
  • Pointer size: 131 Bytes
  • Size of remote file: 191 kB
figs/xfeat_quali.jpg ADDED

Git LFS Details

  • SHA256: 8c5f7981649d80ab757d232bc755b459c396eff2a9f4572e393aecc834128588
  • Pointer size: 131 Bytes
  • Size of remote file: 165 kB
hubconf.py ADDED
@@ -0,0 +1,15 @@
+ dependencies = ['torch']
+ from modules.xfeat import XFeat as _XFeat
+ import torch
+ 
+ def XFeat(pretrained=True, top_k=4096):
+     """
+     XFeat model
+     pretrained (bool): load pretrained weights into the model
+     """
+     weights = None
+     if pretrained:
+         weights = torch.hub.load_state_dict_from_url("https://github.com/verlab/accelerated_features/raw/main/weights/xfeat.pt")
+ 
+     model = _XFeat(weights, top_k=top_k)
+     return model
minimal_example.py ADDED
@@ -0,0 +1,49 @@
+ """
+     "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
+     https://www.verlab.dcc.ufmg.br/descriptors/xfeat_cvpr24/
+ 
+     Minimal example of how to use XFeat.
+ """
+ 
+ import numpy as np
+ import os
+ import torch
+ import tqdm
+ 
+ from modules.xfeat import XFeat
+ 
+ os.environ['CUDA_VISIBLE_DEVICES'] = '' #Force CPU, comment for GPU
+ 
+ xfeat = XFeat()
+ 
+ #Random input
+ x = torch.randn(1,3,480,640)
+ 
+ #Simple inference with batch = 1
+ output = xfeat.detectAndCompute(x, top_k = 4096)[0]
+ print("----------------")
+ print("keypoints: ", output['keypoints'].shape)
+ print("descriptors: ", output['descriptors'].shape)
+ print("scores: ", output['scores'].shape)
+ print("----------------\n")
+ 
+ x = torch.randn(1,3,480,640)
+ # Stress test
+ for i in tqdm.tqdm(range(100), desc="Stress test on VGA resolution"):
+     output = xfeat.detectAndCompute(x, top_k = 4096)
+ 
+ # Batched mode
+ x = torch.randn(4,3,480,640)
+ outputs = xfeat.detectAndCompute(x, top_k = 4096)
+ print("# detected features on each batch item:", [len(o['keypoints']) for o in outputs])
+ 
+ # Match two images with sparse features
+ x1 = torch.randn(1,3,480,640)
+ x2 = torch.randn(1,3,480,640)
+ mkpts_0, mkpts_1 = xfeat.match_xfeat(x1, x2)
+ 
+ # Match two images with semi-dense approach -- batched mode with batch size 4
+ x1 = torch.randn(4,3,480,640)
+ x2 = torch.randn(4,3,480,640)
+ matches_list = xfeat.match_xfeat_star(x1, x2)
+ print(matches_list[0].shape)
modules/__init__.py ADDED
@@ -0,0 +1,4 @@
+ """
+     "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
+     https://www.verlab.dcc.ufmg.br/descriptors/xfeat_cvpr24/
+ """
modules/interpolator.py ADDED
@@ -0,0 +1,33 @@
+ """
+     "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
+     https://www.verlab.dcc.ufmg.br/descriptors/xfeat_cvpr24/
+ """
+ 
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ 
+ class InterpolateSparse2d(nn.Module):
+     """ Efficiently interpolate tensor at given sparse 2D positions. """
+     def __init__(self, mode = 'bicubic', align_corners = False):
+         super().__init__()
+         self.mode = mode
+         self.align_corners = align_corners
+ 
+     def normgrid(self, x, H, W):
+         """ Normalize coords to [-1,1]. """
+         return 2. * (x/(torch.tensor([W-1, H-1], device = x.device, dtype = x.dtype))) - 1.
+ 
+     def forward(self, x, pos, H, W):
+         """
+         Input
+             x: [B, C, H, W] feature tensor
+             pos: [B, N, 2] tensor of positions
+             H, W: int, original resolution of input 2d positions -- used in normalization [-1,1]
+ 
+         Returns
+             [B, N, C] sampled channels at 2d positions
+         """
+         grid = self.normgrid(pos, H, W).unsqueeze(-2).to(x.dtype)
+         x = F.grid_sample(x, grid, mode = self.mode, align_corners = False)
+         return x.permute(0,2,3,1).squeeze(-2)
modules/model.py ADDED
@@ -0,0 +1,154 @@
1
+ """
2
+ "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
3
+ https://www.verlab.dcc.ufmg.br/descriptors/xfeat_cvpr24/
4
+ """
5
+
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import time
11
+
12
+ class BasicLayer(nn.Module):
13
+ """
14
+ Basic Convolutional Layer: Conv2d -> BatchNorm -> ReLU
15
+ """
16
+ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=False):
17
+ super().__init__()
18
+ self.layer = nn.Sequential(
19
+ nn.Conv2d( in_channels, out_channels, kernel_size, padding = padding, stride=stride, dilation=dilation, bias = bias),
20
+ nn.BatchNorm2d(out_channels, affine=False),
21
+ nn.ReLU(inplace = True),
22
+ )
23
+
24
+ def forward(self, x):
25
+ return self.layer(x)
26
+
27
+ class XFeatModel(nn.Module):
28
+ """
29
+ Implementation of architecture described in
30
+ "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
31
+ """
32
+
33
+ def __init__(self):
34
+ super().__init__()
35
+ self.norm = nn.InstanceNorm2d(1)
36
+
37
+
38
+ ########### ⬇️ CNN Backbone & Heads ⬇️ ###########
39
+
40
+ self.skip1 = nn.Sequential( nn.AvgPool2d(4, stride = 4),
41
+ nn.Conv2d (1, 24, 1, stride = 1, padding=0) )
42
+
43
+ self.block1 = nn.Sequential(
44
+ BasicLayer( 1, 4, stride=1),
45
+ BasicLayer( 4, 8, stride=2),
46
+ BasicLayer( 8, 8, stride=1),
47
+ BasicLayer( 8, 24, stride=2),
48
+ )
49
+
50
+ self.block2 = nn.Sequential(
51
+ BasicLayer(24, 24, stride=1),
52
+ BasicLayer(24, 24, stride=1),
53
+ )
54
+
55
+ self.block3 = nn.Sequential(
56
+ BasicLayer(24, 64, stride=2),
57
+ BasicLayer(64, 64, stride=1),
58
+ BasicLayer(64, 64, 1, padding=0),
59
+ )
60
+ self.block4 = nn.Sequential(
61
+ BasicLayer(64, 64, stride=2),
62
+ BasicLayer(64, 64, stride=1),
63
+ BasicLayer(64, 64, stride=1),
64
+ )
65
+
66
+ self.block5 = nn.Sequential(
67
+ BasicLayer( 64, 128, stride=2),
68
+ BasicLayer(128, 128, stride=1),
69
+ BasicLayer(128, 128, stride=1),
70
+ BasicLayer(128, 64, 1, padding=0),
71
+ )
72
+
73
+ self.block_fusion = nn.Sequential(
74
+ BasicLayer(64, 64, stride=1),
75
+ BasicLayer(64, 64, stride=1),
76
+ nn.Conv2d (64, 64, 1, padding=0)
77
+ )
78
+
79
+ self.heatmap_head = nn.Sequential(
80
+ BasicLayer(64, 64, 1, padding=0),
81
+ BasicLayer(64, 64, 1, padding=0),
82
+ nn.Conv2d (64, 1, 1),
83
+ nn.Sigmoid()
84
+ )
85
+
86
+
87
+ self.keypoint_head = nn.Sequential(
88
+ BasicLayer(64, 64, 1, padding=0),
89
+ BasicLayer(64, 64, 1, padding=0),
90
+ BasicLayer(64, 64, 1, padding=0),
91
+ nn.Conv2d (64, 65, 1),
92
+ )
93
+
94
+
95
+ ########### ⬇️ Fine Matcher MLP ⬇️ ###########
96
+
97
+ self.fine_matcher = nn.Sequential(
98
+ nn.Linear(128, 512),
99
+ nn.BatchNorm1d(512, affine=False),
100
+ nn.ReLU(inplace = True),
101
+ nn.Linear(512, 512),
102
+ nn.BatchNorm1d(512, affine=False),
103
+ nn.ReLU(inplace = True),
104
+ nn.Linear(512, 512),
105
+ nn.BatchNorm1d(512, affine=False),
106
+ nn.ReLU(inplace = True),
107
+ nn.Linear(512, 512),
108
+ nn.BatchNorm1d(512, affine=False),
109
+ nn.ReLU(inplace = True),
110
+ nn.Linear(512, 64),
111
+ )
112
+
113
+ def _unfold2d(self, x, ws = 2):
114
+ """
115
+ Unfolds tensor in 2D with desired ws (window size) and concat the channels
116
+ """
117
+ B, C, H, W = x.shape
118
+ x = x.unfold(2, ws , ws).unfold(3, ws,ws) \
119
+ .reshape(B, C, H//ws, W//ws, ws**2)
120
+ return x.permute(0, 1, 4, 2, 3).reshape(B, -1, H//ws, W//ws)
121
+
122
+
123
+ def forward(self, x):
124
+ """
125
+ input:
126
+ x -> torch.Tensor(B, C, H, W) grayscale or rgb images
127
+ return:
128
+ feats -> torch.Tensor(B, 64, H/8, W/8) dense local features
129
+ keypoints -> torch.Tensor(B, 65, H/8, W/8) keypoint logit map
130
+ heatmap -> torch.Tensor(B, 1, H/8, W/8) reliability map
131
+
132
+ """
133
+ #dont backprop through normalization
134
+ with torch.no_grad():
135
+ x = x.mean(dim=1, keepdim = True)
136
+ x = self.norm(x)
137
+
138
+ #main backbone
139
+ x1 = self.block1(x)
140
+ x2 = self.block2(x1 + self.skip1(x))
141
+ x3 = self.block3(x2)
142
+ x4 = self.block4(x3)
143
+ x5 = self.block5(x4)
144
+
145
+ #pyramid fusion
146
+ x4 = F.interpolate(x4, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
147
+ x5 = F.interpolate(x5, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
148
+ feats = self.block_fusion( x3 + x4 + x5 )
149
+
150
+ #heads
151
+ heatmap = self.heatmap_head(feats) # Reliability map
152
+ keypoints = self.keypoint_head(self._unfold2d(x, ws=8)) #Keypoint map logits
153
+
154
+ return feats, keypoints, heatmap
modules/xfeat.py ADDED
@@ -0,0 +1,346 @@
1
+
2
+ """
3
+ "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
4
+ https://www.verlab.dcc.ufmg.br/descriptors/xfeat_cvpr24/
5
+ """
6
+
7
+ import numpy as np
8
+ import os
9
+ import torch
10
+ import torch.nn.functional as F
11
+
12
+ import tqdm
13
+
14
+ from modules.model import *
15
+ from modules.interpolator import InterpolateSparse2d
16
+
17
+ class XFeat(nn.Module):
18
+ """
19
+ Implements the inference module for XFeat.
20
+ It supports inference for both sparse and semi-dense feature extraction & matching.
21
+ """
22
+
23
+ def __init__(self, weights = os.path.abspath(os.path.dirname(__file__)) + '/../weights/xfeat.pt', top_k = 4096):
24
+ super().__init__()
25
+ self.dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
26
+ self.net = XFeatModel().to(self.dev).eval()
27
+ self.top_k = top_k
28
+
29
+ if weights is not None:
30
+ if isinstance(weights, str):
31
+ print('loading weights from: ' + weights)
32
+ self.net.load_state_dict(torch.load(weights, map_location=self.dev))
33
+ else:
34
+ self.net.load_state_dict(weights)
35
+
36
+ self.interpolator = InterpolateSparse2d('bicubic')
37
+
38
+ @torch.inference_mode()
39
+ def detectAndCompute(self, x, top_k = None):
40
+ """
41
+ Compute sparse keypoints & descriptors. Supports batched mode.
42
+
43
+ input:
44
+ x -> torch.Tensor(B, C, H, W): grayscale or rgb image
45
+ top_k -> int: keep best k features
46
+ return:
47
+ List[Dict]:
48
+ 'keypoints' -> torch.Tensor(N, 2): keypoints (x,y)
49
+ 'scores' -> torch.Tensor(N,): keypoint scores
50
+ 'descriptors' -> torch.Tensor(N, 64): local features
51
+ """
52
+ if top_k is None: top_k = self.top_k
53
+ x, rh1, rw1 = self.preprocess_tensor(x)
54
+
55
+ B, _, _H1, _W1 = x.shape
56
+
57
+ M1, K1, H1 = self.net(x)
58
+ M1 = F.normalize(M1, dim=1)
59
+
60
+ #Convert logits to heatmap and extract kpts
61
+ K1h = self.get_kpts_heatmap(K1)
62
+ mkpts = self.NMS(K1h, threshold=0.05, kernel_size=5)
63
+
64
+ #Compute reliability scores
65
+ _nearest = InterpolateSparse2d('nearest')
66
+ _bilinear = InterpolateSparse2d('bilinear')
67
+ scores = (_nearest(K1h, mkpts, _H1, _W1) * _bilinear(H1, mkpts, _H1, _W1)).squeeze(-1)
68
+ scores[torch.all(mkpts == 0, dim=-1)] = -1
69
+
70
+ #Select top-k features
71
+ idxs = torch.argsort(-scores)
72
+ mkpts_x = torch.gather(mkpts[...,0], -1, idxs)[:, :top_k]
73
+ mkpts_y = torch.gather(mkpts[...,1], -1, idxs)[:, :top_k]
74
+ mkpts = torch.cat([mkpts_x[...,None], mkpts_y[...,None]], dim=-1)
75
+ scores = torch.gather(scores, -1, idxs)[:, :top_k]
76
+
77
+ #Interpolate descriptors at kpts positions
78
+ feats = self.interpolator(M1, mkpts, H = _H1, W = _W1)
79
+
80
+ #L2-Normalize
81
+ feats = F.normalize(feats, dim=-1)
82
+
83
+ #Correct kpt scale
84
+ mkpts = mkpts * torch.tensor([rw1,rh1], device=mkpts.device).view(1, 1, -1)
85
+
86
+ valid = scores > 0
87
+ return [
88
+ {'keypoints': mkpts[b][valid[b]],
89
+ 'scores': scores[b][valid[b]],
90
+ 'descriptors': feats[b][valid[b]]} for b in range(B)
91
+ ]
92
+
93
+ @torch.inference_mode()
94
+ def detectAndComputeDense(self, x, top_k = None, multiscale = True):
95
+ """
96
+ Compute dense *and coarse* descriptors. Supports batched mode.
97
+
98
+ input:
99
+ x -> torch.Tensor(B, C, H, W): grayscale or rgb image
100
+ top_k -> int: keep best k features
101
+ return: features sorted by their reliability score -- from most to least
102
+ List[Dict]:
103
+ 'keypoints' -> torch.Tensor(top_k, 2): coarse keypoints
104
+ 'scales' -> torch.Tensor(top_k,): extraction scale
105
+ 'descriptors' -> torch.Tensor(top_k, 64): coarse local features
106
+ """
107
+ if top_k is None: top_k = self.top_k
108
+ if multiscale:
109
+ mkpts, sc, feats = self.extract_dualscale(x, top_k)
110
+ else:
111
+ mkpts, feats = self.extractDense(x, top_k)
112
+ sc = torch.ones(mkpts.shape[:2], device=mkpts.device)
113
+
114
+ return {'keypoints': mkpts,
115
+ 'descriptors': feats,
116
+ 'scales': sc }
117
+
118
+ @torch.inference_mode()
119
+ def match_xfeat(self, img1, img2, top_k = None, min_cossim = -1):
120
+ """
121
+ Simple extractor and MNN matcher.
122
+ For simplicity it does not support batched mode due to possibly different number of kpts.
123
+ input:
124
+ img1 -> torch.Tensor (1,C,H,W) or np.ndarray (H,W,C): grayscale or rgb image.
125
+ img2 -> torch.Tensor (1,C,H,W) or np.ndarray (H,W,C): grayscale or rgb image.
126
+ top_k -> int: keep best k features
127
+ returns:
128
+ mkpts_0, mkpts_1 -> np.ndarray (N,2) xy coordinate matches from image1 to image2
129
+ """
130
+ if top_k is None: top_k = self.top_k
131
+ img1 = self.parse_input(img1)
132
+ img2 = self.parse_input(img2)
133
+
134
+ out1 = self.detectAndCompute(img1, top_k=top_k)[0]
135
+ out2 = self.detectAndCompute(img2, top_k=top_k)[0]
136
+
137
+ idxs0, idxs1 = self.match(out1['descriptors'], out2['descriptors'], min_cossim=min_cossim )
138
+
139
+ return out1['keypoints'][idxs0].cpu().numpy(), out2['keypoints'][idxs1].cpu().numpy()
140
+
141
+ @torch.inference_mode()
142
+ def match_xfeat_star(self, im_set1, im_set2, top_k = None):
143
+ """
144
+ Extracts coarse feats, then match pairs and finally refine matches, currently supports batched mode.
145
+ input:
146
+ im_set1 -> torch.Tensor(B, C, H, W) or np.ndarray (H,W,C): grayscale or rgb images.
147
+ im_set2 -> torch.Tensor(B, C, H, W) or np.ndarray (H,W,C): grayscale or rgb images.
148
+ top_k -> int: keep best k features
149
+ returns:
150
+ matches -> List[torch.Tensor(N, 4)]: List of size B containing tensor of pairwise matches (x1,y1,x2,y2)
151
+ """
152
+ if top_k is None: top_k = self.top_k
153
+ im_set1 = self.parse_input(im_set1)
154
+ im_set2 = self.parse_input(im_set2)
155
+
156
+ #Compute coarse feats
157
+ out1 = self.detectAndComputeDense(im_set1, top_k=top_k)
158
+ out2 = self.detectAndComputeDense(im_set2, top_k=top_k)
159
+
160
+ #Match batches of pairs
161
+ idxs_list = self.batch_match(out1['descriptors'], out2['descriptors'] )
162
+ B = len(im_set1)
163
+
164
+ #Refine coarse matches
165
+ #this part is harder to batch, currently iterate
166
+ matches = []
167
+ for b in range(B):
168
+ matches.append(self.refine_matches(out1, out2, matches = idxs_list, batch_idx=b))
169
+
170
+ return matches if B > 1 else (matches[0][:, :2].cpu().numpy(), matches[0][:, 2:].cpu().numpy())
171
+
172
+ def preprocess_tensor(self, x):
173
+ """ Guarantee that image is divisible by 32 to avoid aliasing artifacts. """
174
+ if isinstance(x, np.ndarray) and len(x.shape) == 3:
175
+ x = torch.tensor(x).permute(2,0,1)[None]
176
+ x = x.to(self.dev).float()
177
+
178
+ H, W = x.shape[-2:]
179
+ _H, _W = (H//32) * 32, (W//32) * 32
180
+ rh, rw = H/_H, W/_W
181
+
182
+ x = F.interpolate(x, (_H, _W), mode='bilinear', align_corners=False)
183
+ return x, rh, rw
184
+
185
+ def get_kpts_heatmap(self, kpts, softmax_temp = 1.0):
186
+ scores = F.softmax(kpts*softmax_temp, 1)[:, :64]
187
+ B, _, H, W = scores.shape
188
+ heatmap = scores.permute(0, 2, 3, 1).reshape(B, H, W, 8, 8)
189
+ heatmap = heatmap.permute(0, 1, 3, 2, 4).reshape(B, 1, H*8, W*8)
190
+ return heatmap
191
+
192
+ def NMS(self, x, threshold = 0.05, kernel_size = 5):
193
+ B, _, H, W = x.shape
194
+ pad=kernel_size//2
195
+ local_max = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=pad)(x)
196
+ pos = (x == local_max) & (x > threshold)
197
+ pos_batched = [k.nonzero()[..., 1:].flip(-1) for k in pos]
198
+
199
+ pad_val = max([len(x) for x in pos_batched])
200
+ pos = torch.zeros((B, pad_val, 2), dtype=torch.long, device=x.device)
201
+
202
+ #Pad kpts and build (B, N, 2) tensor
203
+ for b in range(len(pos_batched)):
204
+ pos[b, :len(pos_batched[b]), :] = pos_batched[b]
205
+
206
+ return pos
207
+
208
+ @torch.inference_mode()
209
+ def batch_match(self, feats1, feats2, min_cossim = -1):
210
+ B = len(feats1)
211
+ cossim = torch.bmm(feats1, feats2.permute(0,2,1))
212
+ match12 = torch.argmax(cossim, dim=-1)
213
+ match21 = torch.argmax(cossim.permute(0,2,1), dim=-1)
214
+
215
+ idx0 = torch.arange(len(match12[0]), device=match12.device)
216
+
217
+ batched_matches = []
218
+
219
+ for b in range(B):
220
+ mutual = match21[b][match12[b]] == idx0
221
+
222
+ if min_cossim > 0:
223
+ cossim_max, _ = cossim[b].max(dim=1)
224
+ good = cossim_max > min_cossim
225
+ idx0_b = idx0[mutual & good]
226
+ idx1_b = match12[b][mutual & good]
227
+ else:
228
+ idx0_b = idx0[mutual]
229
+ idx1_b = match12[b][mutual]
230
+
231
+ batched_matches.append((idx0_b, idx1_b))
232
+
233
+ return batched_matches
234
+
235
+ def subpix_softmax2d(self, heatmaps, temp = 3):
236
+ N, H, W = heatmaps.shape
237
+ heatmaps = torch.softmax(temp * heatmaps.view(-1, H*W), -1).view(-1, H, W)
238
+ x, y = torch.meshgrid(torch.arange(W, device = heatmaps.device ), torch.arange(H, device = heatmaps.device ), indexing = 'xy')
239
+ x = x - (W//2)
240
+ y = y - (H//2)
241
+
242
+ coords_x = (x[None, ...] * heatmaps)
243
+ coords_y = (y[None, ...] * heatmaps)
244
+ coords = torch.cat([coords_x[..., None], coords_y[..., None]], -1).view(N, H*W, 2)
245
+ coords = coords.sum(1)
246
+
247
+ return coords
248
+
249
+ def refine_matches(self, d0, d1, matches, batch_idx, fine_conf = 0.25):
250
+ idx0, idx1 = matches[batch_idx]
251
+ feats1 = d0['descriptors'][batch_idx][idx0]
252
+ feats2 = d1['descriptors'][batch_idx][idx1]
253
+ mkpts_0 = d0['keypoints'][batch_idx][idx0]
254
+ mkpts_1 = d1['keypoints'][batch_idx][idx1]
255
+ sc0 = d0['scales'][batch_idx][idx0]
256
+
257
+ #Compute fine offsets
258
+ offsets = self.net.fine_matcher(torch.cat([feats1, feats2],dim=-1))
259
+ conf = F.softmax(offsets*3, dim=-1).max(dim=-1)[0]
260
+ offsets = self.subpix_softmax2d(offsets.view(-1,8,8))
261
+
262
+ mkpts_0 += offsets* (sc0[:,None]) #*0.9 #* (sc0[:,None])
263
+
264
+ mask_good = conf > fine_conf
265
+ mkpts_0 = mkpts_0[mask_good]
266
+ mkpts_1 = mkpts_1[mask_good]
267
+
268
+ return torch.cat([mkpts_0, mkpts_1], dim=-1)
269
+
270
+ @torch.inference_mode()
271
+ def match(self, feats1, feats2, min_cossim = 0.82):
272
+
273
+ cossim = feats1 @ feats2.t()
274
+ cossim_t = feats2 @ feats1.t()
275
+
276
+ _, match12 = cossim.max(dim=1)
277
+ _, match21 = cossim_t.max(dim=1)
278
+
279
+ idx0 = torch.arange(len(match12), device=match12.device)
280
+ mutual = match21[match12] == idx0
281
+
282
+ if min_cossim > 0:
283
+ cossim, _ = cossim.max(dim=1)
284
+ good = cossim > min_cossim
285
+ idx0 = idx0[mutual & good]
286
+ idx1 = match12[mutual & good]
287
+ else:
288
+ idx0 = idx0[mutual]
289
+ idx1 = match12[mutual]
290
+
291
+ return idx0, idx1
292
+
293
+ def create_xy(self, h, w, dev):
294
+ y, x = torch.meshgrid(torch.arange(h, device = dev),
295
+ torch.arange(w, device = dev), indexing='ij')
296
+ xy = torch.cat([x[..., None],y[..., None]], -1).reshape(-1,2)
297
+ return xy
298
+
299
+ def extractDense(self, x, top_k = 8_000):
300
+ if top_k < 1:
301
+ top_k = 100_000_000
302
+
303
+ x, rh1, rw1 = self.preprocess_tensor(x)
304
+
305
+ M1, K1, H1 = self.net(x)
306
+
307
+ B, C, _H1, _W1 = M1.shape
308
+
309
+ xy1 = (self.create_xy(_H1, _W1, M1.device) * 8).expand(B,-1,-1)
310
+
311
+ M1 = M1.permute(0,2,3,1).reshape(B, -1, C)
312
+ H1 = H1.permute(0,2,3,1).reshape(B, -1)
313
+
314
+ _, top_k = torch.topk(H1, k = min(len(H1[0]), top_k), dim=-1)
315
+
316
+ feats = torch.gather( M1, 1, top_k[...,None].expand(-1, -1, 64))
317
+ mkpts = torch.gather(xy1, 1, top_k[...,None].expand(-1, -1, 2))
318
+ mkpts = mkpts * torch.tensor([rw1, rh1], device=mkpts.device).view(1,-1)
319
+
320
+ return mkpts, feats
321
+
322
+ def extract_dualscale(self, x, top_k, s1 = 0.6, s2 = 1.3):
323
+ x1 = F.interpolate(x, scale_factor=s1, align_corners=False, mode='bilinear')
324
+ x2 = F.interpolate(x, scale_factor=s2, align_corners=False, mode='bilinear')
325
+
326
+ B, _, _, _ = x.shape
327
+
328
+ mkpts_1, feats_1 = self.extractDense(x1, int(top_k*0.20))
329
+ mkpts_2, feats_2 = self.extractDense(x2, int(top_k*0.80))
330
+
331
+ mkpts = torch.cat([mkpts_1/s1, mkpts_2/s2], dim=1)
332
+ sc1 = torch.ones(mkpts_1.shape[:2], device=mkpts_1.device) * (1/s1)
333
+ sc2 = torch.ones(mkpts_2.shape[:2], device=mkpts_2.device) * (1/s2)
334
+ sc = torch.cat([sc1, sc2],dim=1)
335
+ feats = torch.cat([feats_1, feats_2], dim=1)
336
+
337
+ return mkpts, sc, feats
338
+
339
+ def parse_input(self, x):
340
+ if len(x.shape) == 3:
341
+ x = x[None, ...]
342
+
343
+ if isinstance(x, np.ndarray):
344
+ x = torch.tensor(x).permute(0,3,1,2)/255
345
+
346
+ return x
notebooks/minimal_example.ipynb ADDED
@@ -0,0 +1,256 @@
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": []
7
+ },
8
+ "kernelspec": {
9
+ "name": "python3",
10
+ "display_name": "Python 3"
11
+ },
12
+ "language_info": {
13
+ "name": "python"
14
+ }
15
+ },
16
+ "cells": [
17
+ {
18
+ "cell_type": "markdown",
19
+ "source": [
20
+ "#XFeat minimal inference example"
21
+ ],
22
+ "metadata": {
23
+ "id": "2tDj94al5GAJ"
24
+ }
25
+ },
26
+ {
27
+ "cell_type": "markdown",
28
+ "source": [
29
+ "## Clone repository"
30
+ ],
31
+ "metadata": {
32
+ "id": "X8MPXBro5IFv"
33
+ }
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": 1,
38
+ "metadata": {
39
+ "colab": {
40
+ "base_uri": "https://localhost:8080/"
41
+ },
42
+ "id": "tVkH1ChzNcLW",
43
+ "outputId": "da9a9474-76bd-4b66-8ecd-8ba0022f030e"
44
+ },
45
+ "outputs": [
46
+ {
47
+ "output_type": "stream",
48
+ "name": "stdout",
49
+ "text": [
50
+ "Cloning into 'accelerated_features'...\n",
51
+ "remote: Enumerating objects: 27, done.\u001b[K\n",
52
+ "remote: Counting objects: 100% (11/11), done.\u001b[K\n",
53
+ "remote: Compressing objects: 100% (10/10), done.\u001b[K\n",
54
+ "remote: Total 27 (delta 0), reused 5 (delta 0), pack-reused 16\u001b[K\n",
55
+ "Receiving objects: 100% (27/27), 13.29 MiB | 23.03 MiB/s, done.\n",
56
+ "Resolving deltas: 100% (1/1), done.\n",
57
+ "/content/accelerated_features\n"
58
+ ]
59
+ }
60
+ ],
61
+ "source": [
62
+ "!cd /content && git clone 'https://github.com/verlab/accelerated_features.git'\n",
63
+ "%cd /content/accelerated_features"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "markdown",
68
+ "source": [
69
+ "## Test on simple input (sparse setting)"
70
+ ],
71
+ "metadata": {
72
+ "id": "32T-WzfU5NRH"
73
+ }
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "source": [
78
+ "import numpy as np\n",
79
+ "import os\n",
80
+ "import torch\n",
81
+ "import tqdm\n",
82
+ "\n",
83
+ "from modules.xfeat import XFeat\n",
84
+ "\n",
85
+ "xfeat = XFeat()\n",
86
+ "\n",
87
+ "#Random input\n",
88
+ "x = torch.randn(1,3,480,640)\n",
89
+ "\n",
90
+ "#Simple inference with batch = 1\n",
91
+ "output = xfeat.detectAndCompute(x, top_k = 4096)[0]\n",
92
+ "print(\"----------------\")\n",
93
+ "print(\"keypoints: \", output['keypoints'].shape)\n",
94
+ "print(\"descriptors: \", output['descriptors'].shape)\n",
95
+ "print(\"scores: \", output['scores'].shape)\n",
96
+ "print(\"----------------\\n\")"
97
+ ],
98
+ "metadata": {
99
+ "colab": {
100
+ "base_uri": "https://localhost:8080/"
101
+ },
102
+ "id": "o1TMnCEfNfvD",
103
+ "outputId": "f59757f5-477a-4642-e955-7a5abefe3c21"
104
+ },
105
+ "execution_count": 2,
106
+ "outputs": [
107
+ {
108
+ "output_type": "stream",
109
+ "name": "stdout",
110
+ "text": [
111
+ "loading weights from: /content/accelerated_features/modules/../weights/xfeat.pt\n",
112
+ "----------------\n",
113
+ "keypoints: torch.Size([4096, 2])\n",
114
+ "descriptors: torch.Size([4096, 64])\n",
115
+ "scores: torch.Size([4096])\n",
116
+ "----------------\n",
117
+ "\n"
118
+ ]
119
+ }
120
+ ]
121
+ },
122
+ {
123
+ "cell_type": "markdown",
124
+ "source": [
125
+ "## Stress test to check FPS on VGA (sparse setting)"
126
+ ],
127
+ "metadata": {
128
+ "id": "8b9C09ya5UwL"
129
+ }
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "source": [
134
+ "x = torch.randn(1,3,480,640)\n",
135
+ "# Stress test\n",
136
+ "for i in tqdm.tqdm(range(100), desc=\"Stress test on VGA resolution\"):\n",
137
+ "\toutput = xfeat.detectAndCompute(x, top_k = 4096)\n"
138
+ ],
139
+ "metadata": {
140
+ "colab": {
141
+ "base_uri": "https://localhost:8080/"
142
+ },
143
+ "id": "Zsjz-QT95ZrM",
144
+ "outputId": "2df6f545-419f-4cc3-ad8b-bf5e12741dba"
145
+ },
146
+ "execution_count": 7,
147
+ "outputs": [
148
+ {
149
+ "output_type": "stream",
150
+ "name": "stderr",
151
+ "text": [
152
+ "Stress test on VGA resolution: 100%|██████████| 100/100 [00:14<00:00, 6.74it/s]\n"
153
+ ]
154
+ }
155
+ ]
156
+ },
157
+ {
158
+ "cell_type": "markdown",
159
+ "source": [
160
+ "## Test with batched mode (sparse)"
161
+ ],
162
+ "metadata": {
163
+ "id": "1jAl-ejS5du_"
164
+ }
165
+ },
166
+ {
167
+ "cell_type": "code",
168
+ "source": [
169
+ "# Batched mode\n",
170
+ "x = torch.randn(4,3,480,640)\n",
171
+ "outputs = xfeat.detectAndCompute(x, top_k = 4096)\n",
172
+ "print(\"# detected features on each batch item:\", [len(o['keypoints']) for o in outputs])"
173
+ ],
174
+ "metadata": {
175
+ "colab": {
176
+ "base_uri": "https://localhost:8080/"
177
+ },
178
+ "id": "lAarS8UH5gyg",
179
+ "outputId": "883f13f8-3fac-48f2-c0a3-656a81b57f2c"
180
+ },
181
+ "execution_count": 4,
182
+ "outputs": [
183
+ {
184
+ "output_type": "stream",
185
+ "name": "stdout",
186
+ "text": [
187
+ "# detected features on each batch item: [4096, 4096, 4096, 4096]\n"
188
+ ]
189
+ }
190
+ ]
191
+ },
192
+ {
193
+ "cell_type": "markdown",
194
+ "source": [
195
+ "## Matches two images with built-in MNN matcher (sparse mode)"
196
+ ],
197
+ "metadata": {
198
+ "id": "H60iMAlh5nqP"
199
+ }
200
+ },
201
+ {
202
+ "cell_type": "code",
203
+ "source": [
204
+ "# Match two images with sparse features\n",
205
+ "x1 = torch.randn(1,3,480,640)\n",
206
+ "x2 = torch.randn(1,3,480,640)\n",
207
+ "mkpts_0, mkpts_1 = xfeat.match_xfeat(x1, x2)"
208
+ ],
209
+ "metadata": {
210
+ "id": "6N-ZqoMZ5syf"
211
+ },
212
+ "execution_count": 5,
213
+ "outputs": []
214
+ },
215
+ {
216
+ "cell_type": "markdown",
217
+ "source": [
218
+ "## Matches two images with semi-dense matching, and batched mode (batch size = 4) for demonstration purpose"
219
+ ],
220
+ "metadata": {
221
+ "id": "MOV4vZDp5v9_"
222
+ }
223
+ },
224
+ {
225
+ "cell_type": "code",
226
+ "source": [
227
+ "# Create 4 image pairs\n",
228
+ "x1 = torch.randn(4,3,480,640)\n",
229
+ "x2 = torch.randn(4,3,480,640)\n",
230
+ "\n",
231
+ "#Obtain matches for each batch item\n",
232
+ "matches_list = xfeat.match_xfeat_star(x1, x2, top_k = 5000)\n",
233
+ "print('number of img pairs', len(matches_list))\n",
234
+ "print(matches_list[0].shape) # -> output is (x1,y1,x2,y2)"
235
+ ],
236
+ "metadata": {
237
+ "colab": {
238
+ "base_uri": "https://localhost:8080/"
239
+ },
240
+ "id": "Axe0o6U85zGV",
241
+ "outputId": "e1257959-24fc-4194-b2f1-ee06cf450b24"
242
+ },
243
+ "execution_count": 6,
244
+ "outputs": [
245
+ {
246
+ "output_type": "stream",
247
+ "name": "stdout",
248
+ "text": [
249
+ "number of img pairs 4\n",
250
+ "torch.Size([182, 4])\n"
251
+ ]
252
+ }
253
+ ]
254
+ }
255
+ ]
256
+ }
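The last cell returns matches as (x1, y1, x2, y2) rows; a common next step, and the one realtime_demo.py below performs, is robust homography estimation. A minimal sketch, assuming matches_list from the cell above and standard OpenCV calls:

import cv2
import numpy as np

matches = matches_list[0].cpu().numpy()   # (N, 4) rows of (x1, y1, x2, y2)
pts1, pts2 = matches[:, :2], matches[:, 2:]

# MAGSAC++ homography with the same thresholds the realtime demo uses.
H, inliers = cv2.findHomography(pts1, pts2, cv2.USAC_MAGSAC, 4.0,
                                maxIters=700, confidence=0.995)
print("inliers:", 0 if inliers is None else int(inliers.sum()))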
notebooks/xfeat_matching.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
notebooks/xfeat_torch_hub.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
packages.txt ADDED
@@ -0,0 +1 @@
1
+ python3-opencv
realtime_demo.py ADDED
@@ -0,0 +1,295 @@
1
+ """
2
+ "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
3
+ https://www.verlab.dcc.ufmg.br/descriptors/xfeat_cvpr24/
4
+
5
+ Real-time homography estimation demo. Note that the scene has to be planar, or the camera motion restricted to pure rotation, for the estimation to work properly.
6
+ """
7
+
8
+ import cv2
9
+ import numpy as np
10
+ import torch
11
+
12
+ from time import time, sleep
13
+ import argparse, sys, tqdm
14
+ import threading
15
+
16
+ from modules.xfeat import XFeat
17
+
18
+ def argparser():
19
+ parser = argparse.ArgumentParser(description="Configurations for the real-time matching demo.")
20
+ parser.add_argument('--width', type=int, default=640, help='Width of the video capture stream.')
21
+ parser.add_argument('--height', type=int, default=480, help='Height of the video capture stream.')
22
+ parser.add_argument('--max_kpts', type=int, default=3_000, help='Maximum number of keypoints.')
23
+ parser.add_argument('--method', type=str, choices=['ORB', 'SIFT', 'XFeat'], default='XFeat', help='Local feature detection method to use.')
24
+ parser.add_argument('--cam', type=int, default=0, help='Webcam device number.')
25
+ return parser.parse_args()
26
+
27
+
28
+ class FrameGrabber(threading.Thread):
29
+ def __init__(self, cap):
30
+ super().__init__()
31
+ self.cap = cap
32
+ _, self.frame = self.cap.read()
33
+ self.running = False
34
+
35
+ def run(self):
36
+ self.running = True
37
+ while self.running:
38
+ ret, frame = self.cap.read()
39
+ if not ret:
40
+ print("Can't receive frame (stream ended?).")
41
+ self.frame = frame
42
+ sleep(0.01)
43
+
44
+ def stop(self):
45
+ self.running = False
46
+ self.cap.release()
47
+
48
+ def get_last_frame(self):
49
+ return self.frame
50
+
51
+ class CVWrapper():
52
+ def __init__(self, mtd):
53
+ self.mtd = mtd
54
+ def detectAndCompute(self, x, mask=None):
55
+ return self.mtd.detectAndCompute(torch.tensor(x).permute(2,0,1).float()[None])[0]
56
+
57
+ class Method:
58
+ def __init__(self, descriptor, matcher):
59
+ self.descriptor = descriptor
60
+ self.matcher = matcher
61
+
62
+ def init_method(method, max_kpts):
63
+ if method == "ORB":
64
+ return Method(descriptor=cv2.ORB_create(max_kpts, fastThreshold=10), matcher=cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True))
65
+ elif method == "SIFT":
66
+ return Method(descriptor=cv2.SIFT_create(max_kpts, contrastThreshold=-1, edgeThreshold=1000), matcher=cv2.BFMatcher(cv2.NORM_L2, crossCheck=True))
67
+ elif method == "XFeat":
68
+ return Method(descriptor=CVWrapper(XFeat(top_k = max_kpts)), matcher=XFeat())
69
+ else:
70
+ raise RuntimeError("Invalid Method.")
71
+
72
+
73
+ class MatchingDemo:
74
+ def __init__(self, args):
75
+ self.args = args
76
+ self.cap = cv2.VideoCapture(args.cam)
77
+ self.width = args.width
78
+ self.height = args.height
79
+ self.ref_frame = None
80
+ self.ref_precomp = [[],[]]
81
+ self.corners = [[50, 50], [640-50, 50], [640-50, 480-50], [50, 480-50]]
82
+ self.current_frame = None
83
+ self.H = None
84
+ self.setup_camera()
85
+
86
+ #Init frame grabber thread
87
+ self.frame_grabber = FrameGrabber(self.cap)
88
+ self.frame_grabber.start()
89
+
90
+ #Homography params
91
+ self.min_inliers = 50
92
+ self.ransac_thr = 4.0
93
+
94
+ #FPS check
95
+ self.FPS = 0
96
+ self.time_list = []
97
+ self.max_cnt = 30 #avg FPS over this number of frames
98
+
99
+ #Set local feature method here -- we expect cv2 or Kornia convention
100
+ self.method = init_method(args.method, max_kpts=args.max_kpts)
101
+
102
+ # Setting up font for captions
103
+ self.font = cv2.FONT_HERSHEY_SIMPLEX
104
+ self.font_scale = 0.9
105
+ self.line_type = cv2.LINE_AA
106
+ self.line_color = (0,255,0)
107
+ self.line_thickness = 3
108
+
109
+ self.window_name = "Real-time matching - Press 's' to set the reference frame."
110
+
111
+ # Removes toolbar and status bar
112
+ cv2.namedWindow(self.window_name, flags=cv2.WINDOW_GUI_NORMAL)
113
+ # Set the window size
114
+ cv2.resizeWindow(self.window_name, self.width*2, self.height*2)
115
+ #Set Mouse Callback
116
+ cv2.setMouseCallback(self.window_name, self.mouse_callback)
117
+
118
+ def setup_camera(self):
119
+ self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, self.width)
120
+ self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, self.height)
121
+ self.cap.set(cv2.CAP_PROP_AUTO_EXPOSURE, 3)
122
+ #self.cap.set(cv2.CAP_PROP_EXPOSURE, 200)
123
+ self.cap.set(cv2.CAP_PROP_FPS, 30)
124
+
125
+ if not self.cap.isOpened():
126
+ print("Cannot open camera")
127
+ exit()
128
+
129
+ def draw_quad(self, frame, point_list):
130
+ if len(self.corners) > 1:
131
+ for i in range(len(self.corners) - 1):
132
+ cv2.line(frame, tuple(point_list[i]), tuple(point_list[i + 1]), self.line_color, self.line_thickness, lineType = self.line_type)
133
+ if len(self.corners) == 4: # Close the quadrilateral if 4 corners are defined
134
+ cv2.line(frame, tuple(point_list[3]), tuple(point_list[0]), self.line_color, self.line_thickness, lineType = self.line_type)
135
+
136
+ def mouse_callback(self, event, x, y, flags, param):
137
+ if event == cv2.EVENT_LBUTTONDOWN:
138
+ if len(self.corners) >= 4:
139
+ self.corners = [] # Reset corners if already 4 points were clicked
140
+ self.corners.append((x, y))
141
+
142
+ def putText(self, canvas, text, org, fontFace, fontScale, textColor, borderColor, thickness, lineType):
143
+ # Draw the border
144
+ cv2.putText(img=canvas, text=text, org=org, fontFace=fontFace, fontScale=fontScale,
145
+ color=borderColor, thickness=thickness+2, lineType=lineType)
146
+ # Draw the text
147
+ cv2.putText(img=canvas, text=text, org=org, fontFace=fontFace, fontScale=fontScale,
148
+ color=textColor, thickness=thickness, lineType=lineType)
149
+
150
+ def warp_points(self, points, H, x_offset = 0):
151
+ points_np = np.array(points, dtype='float32').reshape(-1,1,2)
152
+
153
+ warped_points_np = cv2.perspectiveTransform(points_np, H).reshape(-1, 2)
154
+ warped_points_np[:, 0] += x_offset
155
+ warped_points = warped_points_np.astype(int).tolist()
156
+
157
+ return warped_points
158
+
159
+ def create_top_frame(self):
160
+ top_frame_canvas = np.zeros((480, 1280, 3), dtype=np.uint8)
161
+ top_frame = np.hstack((self.ref_frame, self.current_frame))
162
+ color = (3, 186, 252)
163
+ cv2.rectangle(top_frame, (2, 2), (self.width*2-2, self.height-2), color, 5) # Orange color line as a separator
164
+ top_frame_canvas[0:self.height, 0:self.width*2] = top_frame
165
+
166
+ # Adding captions on the top frame canvas
167
+ self.putText(canvas=top_frame_canvas, text="Reference Frame:", org=(10, 30), fontFace=self.font,
168
+ fontScale=self.font_scale, textColor=(0,0,0), borderColor=color, thickness=1, lineType=self.line_type)
169
+
170
+ self.putText(canvas=top_frame_canvas, text="Target Frame:", org=(650, 30), fontFace=self.font,
171
+ fontScale=self.font_scale, textColor=(0,0,0), borderColor=color, thickness=1, lineType=self.line_type)
172
+
173
+ self.draw_quad(top_frame_canvas, self.corners)
174
+
175
+ return top_frame_canvas
176
+
177
+ def process(self):
178
+ # Create a blank canvas for the top frame
179
+ top_frame_canvas = self.create_top_frame()
180
+
181
+ # Match features and draw matches on the bottom frame
182
+ bottom_frame = self.match_and_draw(self.ref_frame, self.current_frame)
183
+
184
+ # Draw warped corners
185
+ if self.H is not None and len(self.corners) > 1:
186
+ self.draw_quad(top_frame_canvas, self.warp_points(self.corners, self.H, self.width))
187
+
188
+ # Stack top and bottom frames vertically on the final canvas
189
+ canvas = np.vstack((top_frame_canvas, bottom_frame))
190
+
191
+ cv2.imshow(self.window_name, canvas)
192
+
193
+ def match_and_draw(self, ref_frame, current_frame):
194
+
195
+ matches, good_matches = [], []
196
+ kp1, kp2 = [], []
197
+ points1, points2 = [], []
198
+
199
+ # Detect and compute features
200
+ if self.args.method in ['SIFT', 'ORB']:
201
+ kp1, des1 = self.ref_precomp
202
+ kp2, des2 = self.method.descriptor.detectAndCompute(current_frame, None)
203
+ else:
204
+ current = self.method.descriptor.detectAndCompute(current_frame)
205
+ kpts1, descs1 = self.ref_precomp['keypoints'], self.ref_precomp['descriptors']
206
+ kpts2, descs2 = current['keypoints'], current['descriptors']
207
+ idx0, idx1 = self.method.matcher.match(descs1, descs2, 0.82)
208
+ points1 = kpts1[idx0].cpu().numpy()
209
+ points2 = kpts2[idx1].cpu().numpy()
210
+
211
+ if len(kp1) > 10 and len(kp2) > 10 and self.args.method in ['SIFT', 'ORB']:
212
+ # Match descriptors
213
+ matches = self.method.matcher.match(des1, des2)
214
+
215
+ if len(matches) > 10:
216
+ points1 = np.zeros((len(matches), 2), dtype=np.float32)
217
+ points2 = np.zeros((len(matches), 2), dtype=np.float32)
218
+
219
+ for i, match in enumerate(matches):
220
+ points1[i, :] = kp1[match.queryIdx].pt
221
+ points2[i, :] = kp2[match.trainIdx].pt
222
+
223
+ if len(points1) > 10 and len(points2) > 10:
224
+ # Find homography
225
+ self.H, inliers = cv2.findHomography(points1, points2, cv2.USAC_MAGSAC, self.ransac_thr, maxIters=700, confidence=0.995)
226
+ inliers = inliers.flatten() > 0
227
+
228
+ if inliers.sum() < self.min_inliers:
229
+ self.H = None
230
+
231
+ if self.args.method in ["SIFT", "ORB"]:
232
+ good_matches = [m for i,m in enumerate(matches) if inliers[i]]
233
+ else:
234
+ kp1 = [cv2.KeyPoint(p[0],p[1], 5) for p in points1[inliers]]
235
+ kp2 = [cv2.KeyPoint(p[0],p[1], 5) for p in points2[inliers]]
236
+ good_matches = [cv2.DMatch(i,i,0) for i in range(len(kp1))]
237
+
238
+ # Draw matches
239
+ matched_frame = cv2.drawMatches(ref_frame, kp1, current_frame, kp2, good_matches, None, matchColor=(0, 200, 0), flags=2)
240
+
241
+ else:
242
+ matched_frame = np.hstack([ref_frame, current_frame])
243
+
244
+ color = (240, 89, 169)
245
+
246
+ # Add a colored rectangle to separate from the top frame
247
+ cv2.rectangle(matched_frame, (2, 2), (self.width*2-2, self.height-2), color, 5)
248
+
249
+ # Caption: number of matches, drawn on the matched (bottom) frame
250
+ self.putText(canvas=matched_frame, text="%s Matches: %d"%(self.args.method, len(good_matches)), org=(10, 30), fontFace=self.font,
251
+ fontScale=self.font_scale, textColor=(0,0,0), borderColor=color, thickness=1, lineType=self.line_type)
252
+
253
+ # Caption: registration FPS, drawn on the matched (bottom) frame
254
+ self.putText(canvas=matched_frame, text="FPS (registration): {:.1f}".format(self.FPS), org=(650, 30), fontFace=self.font,
255
+ fontScale=self.font_scale, textColor=(0,0,0), borderColor=color, thickness=1, lineType=self.line_type)
256
+
257
+ return matched_frame
258
+
259
+ def main_loop(self):
260
+ self.current_frame = self.frame_grabber.get_last_frame()
261
+ self.ref_frame = self.current_frame.copy()
262
+ self.ref_precomp = self.method.descriptor.detectAndCompute(self.ref_frame, None) #Cache ref features
263
+
264
+ while True:
265
+ if self.current_frame is None:
266
+ break
267
+
268
+ t0 = time()
269
+ self.process()
270
+
271
+ key = cv2.waitKey(1)
272
+ if key == ord('q'):
273
+ break
274
+ elif key == ord('s'):
275
+ self.ref_frame = self.current_frame.copy() # Update reference frame
276
+ self.ref_precomp = self.method.descriptor.detectAndCompute(self.ref_frame, None) #Cache ref features
277
+
278
+ self.current_frame = self.frame_grabber.get_last_frame()
279
+
280
+ #Measure avg. FPS
281
+ self.time_list.append(time()-t0)
282
+ if len(self.time_list) > self.max_cnt:
283
+ self.time_list.pop(0)
284
+ self.FPS = 1.0 / np.array(self.time_list).mean()
285
+
286
+ self.cleanup()
287
+
288
+ def cleanup(self):
289
+ self.frame_grabber.stop()
290
+ self.cap.release()
291
+ cv2.destroyAllWindows()
292
+
293
+ if __name__ == "__main__":
294
+ demo = MatchingDemo(args = argparser())
295
+ demo.main_loop()
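The demo is configured entirely through argparser(); it can be launched as python realtime_demo.py --method XFeat --cam 0, or programmatically with an equivalent namespace (a sketch; the field names mirror the parser arguments):

import argparse
from realtime_demo import MatchingDemo

# Defaults copied from argparser(); press 's' to set the reference frame, 'q' to quit.
args = argparse.Namespace(width=640, height=480, max_kpts=3000,
                          method='XFeat', cam=0)
MatchingDemo(args).main_loop()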
requirements.txt ADDED
@@ -0,0 +1,3 @@
1
+ numpy
2
+ opencv-python-headless
3
+ tqdm
utils.py ADDED
@@ -0,0 +1,166 @@
1
+ # Copyright 2024 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Shared utility functions for OmniGlue."""
16
+
17
+ import cv2
18
+ import numpy as np
19
+
20
+ from typing import Optional
21
+
22
+
23
+ def visualize_matches(
24
+ image0: np.ndarray,
25
+ image1: np.ndarray,
26
+ kp0: np.ndarray,
27
+ kp1: np.ndarray,
28
+ match_matrix: np.ndarray,
29
+ match_labels: Optional[np.ndarray] = None,
30
+ show_keypoints: bool = False,
31
+ highlight_unmatched: bool = False,
32
+ title: Optional[str] = None,
33
+ line_width: int = 1,
34
+ circle_radius: int = 4,
35
+ circle_thickness: int = 2,
36
+ rng: Optional['np.random.Generator'] = None,
37
+ ):
38
+ """Generates visualization of keypoints and matches for two images.
39
+
40
+ Stacks image0 and image1 horizontally. In case the two images have different
41
+ heights, scales image1 (and its keypoints) to match image0's height. Note
42
+ that keypoints must be in (x, y) format, NOT (row, col). If match_matrix
43
+ includes unmatched dustbins, the dustbins will be removed before visualizing
44
+ matches.
45
+
46
+ Args:
47
+ image0: (H, W, 3) array containing image0 contents.
48
+ image1: (H, W, 3) array containing image1 contents.
49
+ kp0: (N, 2) array where each row represents (x, y) coordinates of keypoints
50
+ in image0.
51
+ kp1: (M, 2) array, where each row represents (x, y) coordinates of keypoints
52
+ in image1.
53
+ match_matrix: (N, M) binary array, where values are non-zero for keypoint
54
+ indices making up a match.
55
+ match_labels: (N, M) binary array, where values are non-zero for keypoint
56
+ indices making up a ground-truth match. When None, matches from
57
+ 'match_matrix' are colored randomly. Otherwise, matches from
58
+ 'match_matrix' are colored according to accuracy (compared to labels).
59
+ show_keypoints: if True, all image0 and image1 keypoints (including
60
+ unmatched ones) are visualized.
61
+ highlight_unmatched: if True, highlights unmatched keypoints in blue.
62
+ title: if not None, adds title text to top left of visualization.
63
+ line_width: width of correspondence line, in pixels.
64
+ circle_radius: radius of keypoint circles, if visualized.
65
+ circle_thickness: thickness of keypoint circles, if visualized.
66
+ rng: np random number generator to generate the line colors.
67
+
68
+ Returns:
69
+ Numpy array of image0 and image1 side-by-side, with lines between matches
70
+ according to match_matrix. If show_keypoints is True, keypoints from both
71
+ images are also visualized.
72
+ """
73
+ # initialize RNG
74
+ if rng is None:
75
+ rng = np.random.default_rng()
76
+
77
+ # Make copy of input param that may be modified in this function.
78
+ kp1 = np.copy(kp1)
79
+
80
+ # Detect unmatched dustbins.
81
+ has_unmatched_dustbins = (match_matrix.shape[0] == kp0.shape[0] + 1) and (
82
+ match_matrix.shape[1] == kp1.shape[0] + 1
83
+ )
84
+
85
+ # If necessary, resize image1 so that the pair can be stacked horizontally.
86
+ height0 = image0.shape[0]
87
+ height1 = image1.shape[0]
88
+ if height0 != height1:
89
+ scale_factor = height0 / height1
90
+ if scale_factor <= 1.0:
91
+ interp_method = cv2.INTER_AREA
92
+ else:
93
+ interp_method = cv2.INTER_LINEAR
94
+ new_dim1 = (int(image1.shape[1] * scale_factor), height0)
95
+ image1 = cv2.resize(image1, new_dim1, interpolation=interp_method)
96
+ kp1 *= scale_factor
97
+
98
+ # Create side-by-side image and add lines for all matches.
99
+ viz = cv2.hconcat([image0, image1])
100
+ w0 = image0.shape[1]
101
+ matches = np.argwhere(
102
+ match_matrix[:-1, :-1] if has_unmatched_dustbins else match_matrix
103
+ )
104
+ for match in matches:
105
+ pt0 = (int(kp0[match[0], 0]), int(kp0[match[0], 1]))
106
+ pt1 = (int(kp1[match[1], 0] + w0), int(kp1[match[1], 1]))
107
+ if match_labels is None:
108
+ color = tuple(rng.integers(0, 255, size=3).tolist())
109
+ else:
110
+ if match_labels[match[0], match[1]]:
111
+ color = (0, 255, 0)
112
+ else:
113
+ color = (255, 0, 0)
114
+ cv2.line(viz, pt0, pt1, color, line_width)
115
+
116
+ # Optionally, add circles to output image to represent each keypoint.
117
+ if show_keypoints:
118
+ for i in range(np.shape(kp0)[0]):
119
+ kp = kp0[i, :]
120
+ if highlight_unmatched and has_unmatched_dustbins and match_matrix[i, -1]:
121
+ cv2.circle(
122
+ viz,
123
+ tuple(kp.astype(np.int32).tolist()),
124
+ circle_radius,
125
+ (255, 0, 0),
126
+ circle_thickness,
127
+ )
128
+ else:
129
+ cv2.circle(
130
+ viz,
131
+ tuple(kp.astype(np.int32).tolist()),
132
+ circle_radius,
133
+ (0, 0, 255),
134
+ circle_thickness,
135
+ )
136
+ for j in range(np.shape(kp1)[0]):
137
+ kp = kp1[j, :]
138
+ kp[0] += w0
139
+ if highlight_unmatched and has_unmatched_dustbins and match_matrix[-1, j]:
140
+ cv2.circle(
141
+ viz,
142
+ tuple(kp.astype(np.int32).tolist()),
143
+ circle_radius,
144
+ (255, 0, 0),
145
+ circle_thickness,
146
+ )
147
+ else:
148
+ cv2.circle(
149
+ viz,
150
+ tuple(kp.astype(np.int32).tolist()),
151
+ circle_radius,
152
+ (0, 0, 255),
153
+ circle_thickness,
154
+ )
155
+ if title is not None:
156
+ viz = cv2.putText(
157
+ viz,
158
+ title,
159
+ (5, 30),
160
+ cv2.FONT_HERSHEY_SIMPLEX,
161
+ 1,
162
+ (0, 0, 255),
163
+ 2,
164
+ cv2.LINE_AA,
165
+ )
166
+ return viz
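visualize_matches expects a dense binary match_matrix rather than index pairs; a small sketch of bridging the two, where the helper name and the idx0/idx1, kp0/kp1, image0/image1 inputs are illustrative (e.g. index pairs from XFeat's match()):

import numpy as np
from utils import visualize_matches

def indices_to_match_matrix(idx0, idx1, n_kp0, n_kp1):
    """Binary (N, M) matrix with ones at matched keypoint index pairs."""
    m = np.zeros((n_kp0, n_kp1), dtype=np.uint8)
    m[np.asarray(idx0), np.asarray(idx1)] = 1
    return m

# kp0/kp1: (N, 2) and (M, 2) arrays of (x, y) keypoints; image0/image1: (H, W, 3) arrays.
match_matrix = indices_to_match_matrix(idx0, idx1, len(kp0), len(kp1))
viz = visualize_matches(image0, image1, kp0, kp1, match_matrix,
                        show_keypoints=True, title="XFeat matches")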
weights/xfeat.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f5187fd7bedd26c7fe6acc9685444493a165a35ecc087b33c2db3627f3ea10b
3
+ size 6247949
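The file above is a Git LFS pointer rather than the checkpoint itself; after git lfs pull (or cloning with LFS enabled) it resolves to a regular PyTorch checkpoint of roughly 6.2 MB. A quick, non-authoritative sanity check:

import torch

# Inspect the downloaded checkpoint; requires the resolved LFS object, not the pointer file.
state = torch.load('weights/xfeat.pt', map_location='cpu')
print(type(state), len(state) if hasattr(state, '__len__') else '')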