Upload folder using huggingface_hub
This view is limited to 50 files because it contains too many changes.
- .gitignore +6 -0
- .gradio/certificate.pem +31 -0
- LICENSE +674 -0
- README.md +157 -12
- README_DEMO.md +76 -0
- app.py +247 -0
- backbone.sh +179 -0
- config_ssl_upload.py +177 -0
- data/data_processing.ipynb +0 -0
- dataloaders/GenericSuperDatasetv2.py +445 -0
- dataloaders/ManualAnnoDatasetv2.py +756 -0
- dataloaders/PolypDataset.py +548 -0
- dataloaders/PolypTransforms.py +626 -0
- dataloaders/SimpleDataset.py +61 -0
- dataloaders/__init__.py +0 -0
- dataloaders/augutils.py +224 -0
- dataloaders/common.py +263 -0
- dataloaders/dataset_utils.py +128 -0
- dataloaders/dev_customized_med.py +250 -0
- dataloaders/image_transforms.py +362 -0
- dataloaders/niftiio.py +48 -0
- models/ProtoMedSAM.py +267 -0
- models/ProtoSAM.py +708 -0
- models/SamWrapper.py +68 -0
- models/__init__.py +0 -0
- models/__pycache__/ProtoSAM.cpython-312.pyc +0 -0
- models/__pycache__/SamWrapper.cpython-312.pyc +0 -0
- models/__pycache__/__init__.cpython-312.pyc +0 -0
- models/__pycache__/alpmodule.cpython-312.pyc +0 -0
- models/__pycache__/grid_proto_fewshot.cpython-312.pyc +0 -0
- models/alpmodule.py +199 -0
- models/backbone/__init__.py +0 -0
- models/backbone/__pycache__/__init__.cpython-312.pyc +0 -0
- models/backbone/__pycache__/torchvision_backbones.cpython-312.pyc +0 -0
- models/backbone/torchvision_backbones.py +58 -0
- models/grid_proto_fewshot.py +427 -0
- models/segment_anything/__init__.py +15 -0
- models/segment_anything/__pycache__/__init__.cpython-312.pyc +0 -0
- models/segment_anything/__pycache__/automatic_mask_generator.cpython-312.pyc +0 -0
- models/segment_anything/__pycache__/build_sam.cpython-312.pyc +0 -0
- models/segment_anything/__pycache__/predictor.cpython-312.pyc +0 -0
- models/segment_anything/automatic_mask_generator.py +380 -0
- models/segment_anything/build_sam.py +107 -0
- models/segment_anything/modeling/__init__.py +11 -0
- models/segment_anything/modeling/__pycache__/__init__.cpython-312.pyc +0 -0
- models/segment_anything/modeling/__pycache__/common.cpython-312.pyc +0 -0
- models/segment_anything/modeling/__pycache__/image_encoder.cpython-312.pyc +0 -0
- models/segment_anything/modeling/__pycache__/mask_decoder.cpython-312.pyc +0 -0
- models/segment_anything/modeling/__pycache__/prompt_encoder.cpython-312.pyc +0 -0
- models/segment_anything/modeling/__pycache__/sam.cpython-312.pyc +0 -0
.gitignore
ADDED
@@ -0,0 +1,6 @@
+__pycache__/
+exps/
+.vscode/
+debug/
+test_*/
+pretrained_model/*
.gradio/certificate.pem
ADDED
@@ -0,0 +1,31 @@
+-----BEGIN CERTIFICATE-----
+MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
+TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
+cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
+WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
+ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
+MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
+h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
+0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
+A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
+T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
+B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
+B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
+KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
+OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
+jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
+qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
+rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
+HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
+hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
+ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
+3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
+NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
+ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
+TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
+jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
+oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
+4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
+mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
+emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
+-----END CERTIFICATE-----
LICENSE
ADDED
@@ -0,0 +1,674 @@
+                    GNU GENERAL PUBLIC LICENSE
+                       Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ [... remainder of the standard GNU GPL v3 text: Preamble, Terms and
+ Conditions (sections 0-17), and "How to Apply These Terms to Your New
+ Programs" ...]
README.md
CHANGED
@@ -1,12 +1,157 @@
----
-title:

---
title: LoGoSAM_demo
app_file: app.py
sdk: gradio
sdk_version: 5.29.0
---

# ProtoSAM - One-shot segmentation with foundational models

Link to our paper [here](https://arxiv.org/abs/2407.07042). \
This work is the successor of [DINOv2-based-Self-Supervised-Learning](https://github.com/levayz/DINOv2-based-Self-Supervised-Learning) (link to the [paper](https://arxiv.org/abs/2403.03273)).

## Demo Application

A Gradio-based demo application is now available for interactive inference with ProtoSAM. You can upload your own images and masks to test the model. See [README_DEMO.md](README_DEMO.md) for instructions on running the demo.

## Abstract
This work introduces a new framework, ProtoSAM, for one-shot image segmentation. It combines DINOv2, a vision transformer that extracts features from images, with an Adaptive Local Prototype Pooling (ALP) layer, which generates prototypes from a support image and its mask. These prototypes are used to create an initial coarse segmentation mask by comparing the query image's features with the prototypes.
Following the extraction of an initial mask, we use numerical methods to generate prompts, such as points and bounding boxes, which are then input into the Segment Anything Model (SAM), a prompt-based segmentation model trained on natural images. This allows new classes to be segmented automatically and effectively, without the need for additional training.
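To make the prompt-generation step concrete, here is a minimal sketch of how a coarse binary mask can be turned into point and box prompts and refined with SAM. It is an illustration only, not the repository's implementation (see `models/ProtoSAM.py` for the real logic); `query_img`, `coarse_mask`, the use of OpenCV, and the checkpoint path are assumptions, and the `segment_anything` calls follow the public API of the package vendored under `models/segment_anything`.

```python
import numpy as np
import cv2
from segment_anything import sam_model_registry, SamPredictor

def mask_to_prompts(coarse_mask: np.ndarray):
    """Turn a binary coarse mask into one foreground point and one xyxy box prompt."""
    # Keep only the largest connected component (a stand-in for the CCA option).
    num, labels, stats, centroids = cv2.connectedComponentsWithStats(coarse_mask.astype(np.uint8))
    largest = 1 + int(np.argmax(stats[1:, cv2.CC_STAT_AREA]))  # label 0 is the background
    x, y, w, h = stats[largest, :4]
    box = np.array([x, y, x + w, y + h])         # xyxy bounding box around the component
    point = centroids[largest][None, :]          # (1, 2) centroid used as a foreground point
    return point, box

# Hypothetical usage: query_img is an (H, W, 3) uint8 image, coarse_mask its coarse prediction.
sam = sam_model_registry["vit_h"](checkpoint="pretrained_model/sam_vit_h.pth").cuda()
predictor = SamPredictor(sam)
predictor.set_image(query_img)
point, box = mask_to_prompts(coarse_mask)
masks, scores, _ = predictor.predict(
    point_coords=point,
    point_labels=np.ones(len(point)),
    box=box,
    multimask_output=False,
)
refined_mask = masks[0]  # final binary segmentation
```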
## How To Run
### 1. Data preprocessing
#### 1.1 CT and MRI Dataset
Please see the notebook `data/data_processing.ipynb` for instructions.
For convenience, the data processing instructions from https://github.com/cheng-01037/Self-supervised-Fewshot-Medical-Image-Segmentation have been compiled into a single notebook. \
The CT dataset is available here: https://www.synapse.org/Synapse:syn3553734 \
The MRI dataset is available here: https://chaos.grand-challenge.org

Run `./data/CHAOST2/dcm_img_to_nii.sh` to convert the DICOM images to NIfTI files.
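As a quick sanity check of the conversion, the resulting volumes can be inspected with `nibabel` (a small sketch; it assumes `nibabel` is installed, and the file path below is a placeholder that depends on where the conversion script writes its output):

```python
import nibabel as nib

# Placeholder path -- point this at one of the volumes produced by dcm_img_to_nii.sh.
nii = nib.load("data/CHAOST2/niis/image_1.nii.gz")
vol = nii.get_fdata()
print("shape:", vol.shape)                         # e.g. (rows, cols, num_slices)
print("voxel spacing (mm):", nii.header.get_zooms())
print("intensity range:", vol.min(), vol.max())
```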
#### 1.2 Polyp Dataset
Data is available here: https://www.kaggle.com/datasets/hngphmv/polypdataset?select=train.csv

Put the dataset in `data/PolypDataset/`.

### 2. Running
#### 2.1 (Optional) Training and Validation of the coarse segmentation networks
```
./backbone.sh [MODE] [MODALITY] [LABEL_SET]
```
MODE - validation or training \
MODALITY - ct or mri \
LABEL_SET - 0 (kidneys), 1 (liver spleen)

For example:
```
./backbone.sh training mri 1
```
Please refer to `backbone.sh` for further configurations.

#### 2.2 Running ProtoSAM
Put all SAM checkpoints, such as `sam_vit_b.pth`, `sam_vit_h.pth`, and `medsam_vit_b.pth`, into the `pretrained_model` directory. \
Checkpoints are available from [SAM](https://github.com/facebookresearch/segment-anything) and [MedSAM](https://github.com/bowang-lab/MedSAM).

```
./run_protosam.sh [MODALITY] [LABEL_SET]
```
MODALITY - ct, mri or polyp \
LABEL_SET (only relevant for ct or mri) - 0 (kidneys), 1 (liver spleen) \
Please refer to the `run_protosam.sh` script for further configurations.


## Acknowledgements
This work is largely based on [ALPNet](https://github.com/cheng-01037/Self-supervised-Fewshot-Medical-Image-Segmentation), [DINOv2](https://github.com/facebookresearch/dinov2), and [SAM](https://github.com/facebookresearch/segment-anything), and is a continuation of [DINOv2-based-Self-Supervised-Learning](https://github.com/levayz/DINOv2-based-Self-Supervised-Learning).

## Cite
If you found this repo useful, please consider giving us a citation and a star!

```bibtex
@article{ayzenberg2024protosam,
  title={ProtoSAM - One Shot Medical Image Segmentation With Foundational Models},
  author={Ayzenberg, Lev and Giryes, Raja and Greenspan, Hayit},
  journal={arXiv preprint arXiv:2407.07042},
  year={2024}
}

@misc{ayzenberg2024dinov2,
  title={DINOv2 based Self Supervised Learning For Few Shot Medical Image Segmentation},
  author={Lev Ayzenberg and Raja Giryes and Hayit Greenspan},
  year={2024},
  eprint={2403.03273},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```

# ProtoSAM Segmentation Demo

This Gradio application demonstrates the capabilities of the ProtoSAM model for few-shot segmentation. Users can upload a query image, support image, and support mask to generate a segmentation prediction.

## Requirements

- Python 3.8 or higher
- CUDA-compatible GPU
- Required Python packages (see `requirements.txt`)

## Setup Instructions

1. Clone this repository:
```bash
git clone <your-repository-url>
cd <repository-name>
```

2. Create and activate a virtual environment (optional but recommended):
```bash
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate
```

3. Install the required dependencies:
```bash
pip install -r requirements.txt
```

4. Download the pretrained models:
```bash
mkdir -p pretrained_model
# Download SAM ViT-H model
wget -P pretrained_model https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
mv pretrained_model/sam_vit_h_4b8939.pth pretrained_model/sam_vit_h.pth
```

5. Update the model path in `app.py`:
   - Set the `reload_model_path` in the config dictionary to the path of your trained ProtoSAM model.

## Running the App

Start the app with:
```bash
python app.py
```

This will start the Gradio server and provide a link to access the demo in your browser.

## Usage

1. Upload a query image (the image you want to segment)
2. Upload a support image (an example image with a similar object)
3. Upload a support mask (the segmentation mask for the support image; see the mask-preparation sketch below)
4. Configure the model parameters using the checkboxes if needed
5. Click "Run Inference" to generate the segmentation result
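The support mask should be a binary image aligned with the support image. If your annotation is grayscale or soft, a quick way to binarize it before uploading is sketched below (an illustrative snippet; the file names are placeholders):

```python
import numpy as np
from PIL import Image

# Placeholder file names -- replace with your own annotation and output paths.
anno = np.array(Image.open("support_annotation.png").convert("L"))
mask = (anno > 127).astype(np.uint8) * 255   # threshold to a clean 0/255 binary mask
Image.fromarray(mask).save("support_mask.png")
```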
144 |
+
## Model Configuration
|
145 |
+
|
146 |
+
The app allows you to configure several model parameters via the sidebar:
|
147 |
+
- Use Bounding Box: Enable/disable bounding box input
|
148 |
+
- Use Points: Enable/disable point input
|
149 |
+
- Use Mask: Enable/disable mask input
|
150 |
+
- Use CCA: Enable/disable Connected Component Analysis
|
151 |
+
- Coarse Prediction Only: Use only the coarse segmentation model without SAM refinement
|
152 |
+
|
153 |
+

## Notes

- This demo requires a GPU with CUDA support
- Large images may require more GPU memory; a sketch for downscaling inputs follows this list
- For optimal results, use high-quality support images and masks
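
Since the app rescales all inputs to 1024x1024 internally, downscaling oversized images before upload loses little and saves memory. A minimal sketch (file names are placeholders):

```python
from PIL import Image

img = Image.open("large_query.png")
img.thumbnail((1024, 1024))   # keeps aspect ratio; the app rescales to 1024x1024 anyway
img.save("query.png")
```
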
README_DEMO.md
ADDED
@@ -0,0 +1,76 @@
# ProtoSAM Segmentation Demo

This Gradio application demonstrates the capabilities of the ProtoSAM model for few-shot segmentation. Users can upload a query image, support image, and support mask to generate a segmentation prediction.

## Requirements

- Python 3.8 or higher
- CUDA-compatible GPU
- Required Python packages (see `requirements.txt`)

## Setup Instructions

1. Clone this repository:
```bash
git clone <your-repository-url>
cd <repository-name>
```

2. Create and activate a virtual environment (optional but recommended):
```bash
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate
```

3. Install the required dependencies:
```bash
pip install -r requirements.txt
```

4. Download the pretrained models:
```bash
mkdir -p pretrained_model
# Download SAM ViT-H model
wget -P pretrained_model https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
mv pretrained_model/sam_vit_h_4b8939.pth pretrained_model/sam_vit_h.pth
```

5. Update the model path in `app.py`:
   - Set `reload_model_path` in the config dictionary to the path of your trained ProtoSAM model.

## Running the App

Start the app with:
```bash
./run_demo.sh
```

Or run it directly with:
```bash
python app.py
```

This will start the server and provide a link to access the demo in your browser.
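
Gradio picks its default host and port; if you need a fixed address (for example to reach the demo from another machine), the final call in `app.py` can be adjusted with standard Gradio arguments, e.g.:

```python
# At the bottom of app.py, instead of demo.launch(share=True):
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
```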

## Usage

1. Upload a query image (the image you want to segment)
2. Upload a support image (an example image with a similar object)
3. Upload a support mask (the segmentation mask for the support image)
4. Configure the model parameters using the checkboxes
5. Click "Run Inference" to generate the segmentation result

## Model Configuration

The app allows you to configure several model parameters:
- Use Bounding Box: Enable/disable bounding box input
- Use Points: Enable/disable point input
- Use Mask: Enable/disable mask input
- Use CCA: Enable/disable Connected Component Analysis
- Coarse Prediction Only: Use only the coarse segmentation model without SAM refinement

## Notes

- This demo requires a GPU with CUDA support
- Large images may require more GPU memory
- For optimal results, use high-quality support images and masks
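
## Programmatic Use

The same pipeline can be driven without the UI by calling `run_inference` from `app.py` directly. A minimal sketch, assuming `reload_model_path` in `app.py` points to a real checkpoint and the image files below exist (all file names are placeholders):

```python
from PIL import Image
import app  # the Gradio demo module in this repo

query = Image.open("query.png").convert("RGB")
support = Image.open("support.png").convert("RGB")
support_mask = Image.open("support_mask.png").convert("L")

# Same entry point the "Run Inference" button uses.
overlay, info = app.run_inference(
    query, support, support_mask,
    use_bbox=True, use_points=True, use_mask=True,
    use_cca=True, coarse_pred_only=False,
)
print(info)
if overlay is not None:
    Image.fromarray(overlay).save("prediction_overlay.png")
```
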
app.py
ADDED
@@ -0,0 +1,247 @@
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import gradio as gr
from PIL import Image
import torchvision.transforms as transforms
from models.ProtoSAM import ProtoSAM, ALPNetWrapper, InputFactory, TYPE_ALPNET
from models.grid_proto_fewshot import FewShotSeg
from models.segment_anything.utils.transforms import ResizeLongestSide

# Set environment variables for model caching
os.environ['TORCH_HOME'] = "./pretrained_model"

# Function to load the model
def load_model(config):
    # Initial segmentation model
    alpnet = FewShotSeg(
        config["input_size"][0],
        config["reload_model_path"],
        config["model"]
    )
    alpnet.cuda()
    base_model = ALPNetWrapper(alpnet)

    # ProtoSAM model
    sam_checkpoint = "pretrained_model/sam_vit_h.pth"
    model = ProtoSAM(
        image_size=(1024, 1024),
        coarse_segmentation_model=base_model,
        use_bbox=config["use_bbox"],
        use_points=config["use_points"],
        use_mask=config["use_mask"],
        debug=False,
        num_points_for_sam=1,
        use_cca=config["do_cca"],
        point_mode=config["point_mode"],
        use_sam_trans=True,
        coarse_pred_only=config["coarse_pred_only"],
        sam_pretrained_path=sam_checkpoint,
        use_neg_points=config["use_neg_points"],
    )
    model = model.to(torch.device("cuda"))
    model.eval()
    return model

# Function to preprocess images
def preprocess_image(image, transform):
    if isinstance(image, np.ndarray):
        image_np = image
    else:
        # Convert PIL Image to numpy array
        image_np = np.array(image)

    # Convert to RGB if grayscale
    if len(image_np.shape) == 2:
        image_np = np.stack([image_np] * 3, axis=2)
    elif image_np.shape[2] == 1:
        image_np = np.concatenate([image_np] * 3, axis=2)

    # Apply transforms
    image_tensor = transform(image_np).unsqueeze(0)
    return image_tensor

# Function to create overlay visualization
def create_overlay(query_image, prediction, colormap='YlOrRd'):
    """
    Create an overlay of the prediction on the query image
    """
    # Convert tensors to numpy arrays for visualization
    if isinstance(query_image, torch.Tensor):
        query_image = query_image.cpu().squeeze().numpy()

    if isinstance(prediction, torch.Tensor):
        prediction = prediction.cpu().squeeze().numpy()

    # Normalize image for visualization
    query_image = (query_image - query_image.min()) / (query_image.max() - query_image.min() + 1e-8)

    # Ensure binary mask
    prediction = (prediction > 0).astype(np.float32)

    # Create mask overlay
    mask_cmap = plt.cm.get_cmap(colormap)
    pred_rgba = mask_cmap(prediction)
    pred_rgba[..., 3] = prediction * 0.7  # Set alpha channel

    # Create matplotlib figure for overlay
    fig, ax = plt.subplots(figsize=(10, 10))

    # Handle grayscale vs RGB images
    if len(query_image.shape) == 2:
        ax.imshow(query_image, cmap='gray')
    else:
        if query_image.shape[0] == 3:  # Channel-first format
            query_image = np.transpose(query_image, (1, 2, 0))
        ax.imshow(query_image)

    ax.imshow(pred_rgba)
    ax.axis('off')
    plt.tight_layout()

    # Convert to PIL Image
    fig.canvas.draw()
    img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close(fig)

    return img

# Model configuration
config = {
    "input_size": [224],
    "reload_model_path": "path/to/your/model.pth",  # Update with your model path
    "model": {"encoder": "resnet50", "decoder": "pspnet"},
    "use_bbox": True,
    "use_points": True,
    "use_mask": True,
    "do_cca": True,
    "point_mode": "extreme",
    "coarse_pred_only": False,
    "use_neg_points": False,
    "base_model": TYPE_ALPNET
}

# Function to run inference
def run_inference(query_image, support_image, support_mask, use_bbox, use_points, use_mask, use_cca, coarse_pred_only):
    try:
        # Update config based on user selections
        config["use_bbox"] = use_bbox
        config["use_points"] = use_points
        config["use_mask"] = use_mask
        config["do_cca"] = use_cca
        config["coarse_pred_only"] = coarse_pred_only

        # Check if CUDA is available
        if not torch.cuda.is_available():
            return None, "CUDA is not available. This demo requires GPU support."

        # Load the model
        model = load_model(config)

        # Preprocess images
        sam_trans = ResizeLongestSide(1024)

        # Transform for images
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((1024, 1024), antialias=True),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # Process query image
        query_img_tensor = preprocess_image(query_image, transform)

        # Process support image
        support_img_tensor = preprocess_image(support_image, transform)

        # Process support mask (should be binary)
        support_mask_np = np.array(support_mask)
        support_mask_np = (support_mask_np > 127).astype(np.float32)  # Binarize mask
        support_mask_tensor = torch.from_numpy(support_mask_np).unsqueeze(0).unsqueeze(0)
        support_mask_tensor = torch.nn.functional.interpolate(
            support_mask_tensor, size=(1024, 1024), mode='nearest'
        )

        # Prepare model inputs
        support_images = [support_img_tensor.cuda()]
        support_masks = [support_mask_tensor.cuda()]

        # Create model input
        coarse_model_input = InputFactory.create_input(
            input_type=config["base_model"],
            query_image=query_img_tensor.cuda(),
            support_images=support_images,
            support_labels=support_masks,
            isval=True,
            val_wsize=3,
            original_sz=query_img_tensor.shape[-2:],
            img_sz=query_img_tensor.shape[-2:],
            gts=None,
        )
        coarse_model_input.to(torch.device("cuda"))

        # Run inference
        with torch.no_grad():
            query_pred, scores = model(
                query_img_tensor.cuda(), coarse_model_input, degrees_rotate=0
            )

        # Create overlay visualization
        result_image = create_overlay(query_img_tensor, query_pred)

        confidence_score = np.mean(scores)
        return result_image, f"Confidence Score: {confidence_score:.4f}"

    except Exception as e:
        return None, f"Error during inference: {str(e)}"

# Define the Gradio interface
def create_interface():
    with gr.Blocks(title="ProtoSAM Segmentation Demo") as demo:
        gr.Markdown("# ProtoSAM Segmentation Demo")
        gr.Markdown("Upload a query image, support image, and support mask to generate a segmentation prediction.")

        with gr.Row():
            with gr.Column():
                query_image = gr.Image(label="Query Image", type="pil")
                support_image = gr.Image(label="Support Image", type="pil")
                support_mask = gr.Image(label="Support Mask", type="pil")

            with gr.Column():
                result_image = gr.Image(label="Prediction Result")
                result_text = gr.Textbox(label="Result Information")

        with gr.Row():
            with gr.Column():
                use_bbox = gr.Checkbox(label="Use Bounding Box", value=True)
                use_points = gr.Checkbox(label="Use Points", value=True)
                use_mask = gr.Checkbox(label="Use Mask", value=True)

            with gr.Column():
                use_cca = gr.Checkbox(label="Use CCA", value=True)
                coarse_pred_only = gr.Checkbox(label="Coarse Prediction Only", value=False)
                run_button = gr.Button("Run Inference")

        run_button.click(
            fn=run_inference,
            inputs=[
                query_image,
                support_image,
                support_mask,
                use_bbox,
                use_points,
                use_mask,
                use_cca,
                coarse_pred_only
            ],
            outputs=[result_image, result_text]
        )

    return demo

# Create and launch the interface
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(share=True)
backbone.sh
ADDED
@@ -0,0 +1,179 @@
#!/bin/bash
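# Usage (inferred from the positional arguments parsed below):
#   ./backbone.sh <training|validation> <ct|mri> <label_set: 0 or 1>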
|
2 |
+
set -e
|
3 |
+
GPUID1=0
|
4 |
+
export CUDA_VISIBLE_DEVICES=$GPUID1
|
5 |
+
|
6 |
+
MODE=$1
|
7 |
+
if [ $MODE != "validation" ] && [ $MODE != "training" ]
|
8 |
+
then
|
9 |
+
echo "mode must be either validation or training"
|
10 |
+
exit 1
|
11 |
+
fi
|
12 |
+
|
13 |
+
# get modality as arg
|
14 |
+
MODALITY=$2
|
15 |
+
# make sure modality is either ct or mri
|
16 |
+
if [ $MODALITY != "ct" ] && [ $MODALITY != "mri" ]
|
17 |
+
then
|
18 |
+
echo "modality must be either ct or mri"
|
19 |
+
exit 1
|
20 |
+
fi
|
21 |
+
|
22 |
+
####### Shared configs ######
|
23 |
+
PROTO_GRID=8 # using 32 / 8 = 4, 4-by-4 prototype pooling window during training
|
24 |
+
INPUT_SIZE=256
|
25 |
+
ALL_EV=( 0 ) # 5-fold cross validation (0, 1, 2, 3, 4)
|
26 |
+
if [ $MODALITY == "ct" ]
|
27 |
+
then
|
28 |
+
DATASET='SABS_Superpix'
|
29 |
+
else
|
30 |
+
DATASET='CHAOST2_Superpix'
|
31 |
+
fi
|
32 |
+
|
33 |
+
if [ $INPUT_SIZE -gt 256 ]
|
34 |
+
then
|
35 |
+
DATASET=${DATASET}'_672'
|
36 |
+
fi
|
37 |
+
|
38 |
+
NWORKER=4
|
39 |
+
MODEL_NAME='dinov2_l14'
|
40 |
+
LORA=0
|
41 |
+
RELOAD_PATH=( "None" )
|
42 |
+
SKIP_SLICES="True"
|
43 |
+
DO_CCA="True"
|
44 |
+
TTT="False"
|
45 |
+
NSTEP=100000
|
46 |
+
RESET_AFTER_SLICE="True"
|
47 |
+
FINETUNE_ON_SUPPORT="False"
|
48 |
+
USE_SLICE_ADAPTER="False"
|
49 |
+
ADAPTER_LAYERS=1
|
50 |
+
CLAHE=False
|
51 |
+
ALL_SCALE=( "MIDDLE") # config of pseudolabels
|
52 |
+
|
53 |
+
LABEL_SETS=$3
|
54 |
+
EXCLU='[2,3]'
|
55 |
+
|
56 |
+
if [[ $MODALITY == "mri" && $LABEL_SETS -eq 1 ]]
|
57 |
+
then
|
58 |
+
echo "exluding 1, 4"
|
59 |
+
EXCLU='[1,4]' # liver(1), spleen(4)
|
60 |
+
fi
|
61 |
+
|
62 |
+
ORGANS='kidneys'
|
63 |
+
if [ $LABEL_SETS -eq 1 ]
|
64 |
+
then
|
65 |
+
ORGANS='liver_spleen'
|
66 |
+
fi
|
67 |
+
|
68 |
+
|
69 |
+
FREE_DESC=""
|
70 |
+
CPT="${MODE}_${MODEL_NAME}_${MODALITY}"
|
71 |
+
if [ -n "$FREE_DESC" ]
|
72 |
+
then
|
73 |
+
CPT="${CPT}_${FREE_DESC}"
|
74 |
+
fi
|
75 |
+
|
76 |
+
if [[ $TTT == "True" ]]
|
77 |
+
then
|
78 |
+
CPT="${CPT}_ttt_nstep_${NSTEP}"
|
79 |
+
if [ $RESET_AFTER_SLICE == "True" ]
|
80 |
+
then
|
81 |
+
CPT="${CPT}_reset_after_slice"
|
82 |
+
fi
|
83 |
+
fi
|
84 |
+
|
85 |
+
if [ $USE_SLICE_ADAPTER == "True" ]
|
86 |
+
then
|
87 |
+
CPT="${CPT}_w_adapter_${ADAPTER_LAYERS}_layers"
|
88 |
+
fi
|
89 |
+
|
90 |
+
if [ $LORA -ne 0 ]
|
91 |
+
then
|
92 |
+
CPT="${CPT}_lora_${LORA}"
|
93 |
+
fi
|
94 |
+
|
95 |
+
if [ $CLAHE == "True" ]
|
96 |
+
then
|
97 |
+
CPT="${CPT}_w_clahe"
|
98 |
+
fi
|
99 |
+
|
100 |
+
if [ $DO_CCA = "True" ]
|
101 |
+
then
|
102 |
+
CPT="${CPT}_cca"
|
103 |
+
fi
|
104 |
+
|
105 |
+
CPT="${CPT}_grid_${PROTO_GRID}_res_${INPUT_SIZE}"
|
106 |
+
|
107 |
+
if [ ${EXCLU} = "[]" ]
|
108 |
+
then
|
109 |
+
CPT="${CPT}_setting1"
|
110 |
+
else
|
111 |
+
CPT="${CPT}_setting2"
|
112 |
+
fi
|
113 |
+
|
114 |
+
CPT="${CPT}_${ORGANS}_fold"
|
115 |
+
|
116 |
+
###### Training configs (irrelevant in testing) ######
|
117 |
+
DECAY=0.95
|
118 |
+
|
119 |
+
MAX_ITER=1000 # defines the size of an epoch
|
120 |
+
SNAPSHOT_INTERVAL=25000 # interval for saving snapshot
|
121 |
+
SEED='1234'
|
122 |
+
|
123 |
+
###### Validation configs ######
|
124 |
+
SUPP_ID='[6]' # using the additionally loaded scan as support
|
125 |
+
if [ $MODALITY == "mri" ]
|
126 |
+
then
|
127 |
+
SUPP_ID='[4]'
|
128 |
+
fi
|
129 |
+
|
130 |
+
echo ===================================
|
131 |
+
|
132 |
+
for ((i=0; i<${#ALL_EV[@]}; i++))
|
133 |
+
do
|
134 |
+
EVAL_FOLD=${ALL_EV[i]}
|
135 |
+
CPT_W_FOLD="${CPT}_${EVAL_FOLD}"
|
136 |
+
echo $CPT_W_FOLD on GPU $GPUID1
|
137 |
+
for SUPERPIX_SCALE in "${ALL_SCALE[@]}"
|
138 |
+
do
|
139 |
+
PREFIX="test_vfold${EVAL_FOLD}"
|
140 |
+
echo $PREFIX
|
141 |
+
LOGDIR="./test_${MODALITY}/${CPT_W_FOLD}"
|
142 |
+
|
143 |
+
if [ ! -d $LOGDIR ]
|
144 |
+
then
|
145 |
+
mkdir -p $LOGDIR
|
146 |
+
fi
|
147 |
+
|
148 |
+
python3 $MODE.py with \
|
149 |
+
"modelname=$MODEL_NAME" \
|
150 |
+
'usealign=True' \
|
151 |
+
'optim_type=sgd' \
|
152 |
+
reload_model_path=${RELOAD_PATH[i]} \
|
153 |
+
num_workers=$NWORKER \
|
154 |
+
scan_per_load=-1 \
|
155 |
+
label_sets=$LABEL_SETS \
|
156 |
+
'use_wce=True' \
|
157 |
+
exp_prefix=$PREFIX \
|
158 |
+
'clsname=grid_proto' \
|
159 |
+
n_steps=$NSTEP \
|
160 |
+
exclude_cls_list=$EXCLU \
|
161 |
+
eval_fold=$EVAL_FOLD \
|
162 |
+
dataset=$DATASET \
|
163 |
+
proto_grid_size=$PROTO_GRID \
|
164 |
+
max_iters_per_load=$MAX_ITER \
|
165 |
+
min_fg_data=1 seed=$SEED \
|
166 |
+
save_snapshot_every=$SNAPSHOT_INTERVAL \
|
167 |
+
superpix_scale=$SUPERPIX_SCALE \
|
168 |
+
lr_step_gamma=$DECAY \
|
169 |
+
path.log_dir=$LOGDIR \
|
170 |
+
support_idx=$SUPP_ID \
|
171 |
+
lora=$LORA \
|
172 |
+
do_cca=$DO_CCA \
|
173 |
+
ttt=$TTT \
|
174 |
+
adapter_layers=$ADAPTER_LAYERS \
|
175 |
+
use_slice_adapter=$USE_SLICE_ADAPTER \
|
176 |
+
reset_after_slice=$RESET_AFTER_SLICE \
|
177 |
+
"input_size=($INPUT_SIZE, $INPUT_SIZE)"
|
178 |
+
done
|
179 |
+
done
|
config_ssl_upload.py
ADDED
@@ -0,0 +1,177 @@
"""
|
2 |
+
Experiment configuration file
|
3 |
+
Extended from config file from original PANet Repository
|
4 |
+
"""
|
5 |
+
import os
|
6 |
+
import re
|
7 |
+
import glob
|
8 |
+
import itertools
|
9 |
+
|
10 |
+
import sacred
|
11 |
+
from sacred import Experiment
|
12 |
+
from sacred.observers import FileStorageObserver
|
13 |
+
from sacred.utils import apply_backspaces_and_linefeeds
|
14 |
+
|
15 |
+
from platform import node
|
16 |
+
from datetime import datetime
|
17 |
+
|
18 |
+
from util.consts import IMG_SIZE
|
19 |
+
|
20 |
+
sacred.SETTINGS['CONFIG']['READ_ONLY_CONFIG'] = False
|
21 |
+
sacred.SETTINGS.CAPTURE_MODE = 'no'
|
22 |
+
|
23 |
+
ex = Experiment('mySSL')
|
24 |
+
ex.captured_out_filter = apply_backspaces_and_linefeeds
|
25 |
+
|
26 |
+
source_folders = ['.', './dataloaders', './models', './util']
|
27 |
+
sources_to_save = list(itertools.chain.from_iterable(
|
28 |
+
[glob.glob(f'{folder}/*.py') for folder in source_folders]))
|
29 |
+
for source_file in sources_to_save:
|
30 |
+
ex.add_source_file(source_file)
|
31 |
+
|
32 |
+
@ex.config
|
33 |
+
def cfg():
|
34 |
+
"""Default configurations"""
|
35 |
+
seed = 1234
|
36 |
+
gpu_id = 0
|
37 |
+
mode = 'train' # for now only allows 'train'
|
38 |
+
do_validation=False
|
39 |
+
num_workers = 4 # 0 for debugging.
|
40 |
+
|
41 |
+
dataset = 'CHAOST2' # i.e. abdominal MRI
|
42 |
+
use_coco_init = True # initialize the backbone with MS-COCO pretrained weights (note that COCO contains no medical images)
|
43 |
+
|
44 |
+
### Training
|
45 |
+
n_steps = 100100
|
46 |
+
batch_size = 1
|
47 |
+
lr_milestones = [ (ii + 1) * 1000 for ii in range(n_steps // 1000 - 1)]
|
48 |
+
lr_step_gamma = 0.95
|
49 |
+
ignore_label = 255
|
50 |
+
print_interval = 100
|
51 |
+
save_snapshot_every = 25000
|
52 |
+
max_iters_per_load = 1000 # epoch size, interval for reloading the dataset
|
53 |
+
epochs=1
|
54 |
+
scan_per_load = -1 # numbers of 3d scans per load for saving memory. If -1, load the entire dataset to the memory
|
55 |
+
which_aug = 'sabs_aug' # standard data augmentation with intensity and geometric transforms
|
56 |
+
input_size = (IMG_SIZE, IMG_SIZE)
|
57 |
+
min_fg_data='100' # when training with manual annotations, the minimum number of foreground pixels required in a single-class 2D slice; this empirically stabilizes the training process
|
58 |
+
label_sets = 0 # which group of labels taking as training (the rest are for testing)
|
59 |
+
curr_cls = "" # choose between rk, lk, spleen and liver
|
60 |
+
exclude_cls_list = [2, 3] # testing classes to be excluded in training. Set to [] if testing under setting 1
|
61 |
+
usealign = True # see vanilla PANet
|
62 |
+
use_wce = True
|
63 |
+
use_dinov2_loss = False
|
64 |
+
dice_loss = False
|
65 |
+
### Validation
|
66 |
+
z_margin = 0
|
67 |
+
eval_fold = 0 # which fold for 5 fold cross validation
|
68 |
+
support_idx=[-1] # indicating which scan is used as support in testing.
|
69 |
+
val_wsize=2 # L_H, L_W in testing
|
70 |
+
n_sup_part = 3 # number of chunks in testing
|
71 |
+
use_clahe = False
|
72 |
+
use_slice_adapter = False
|
73 |
+
adapter_layers=3
|
74 |
+
debug=True
|
75 |
+
skip_no_organ_slices=True
|
76 |
+
# Network
|
77 |
+
modelname = 'dlfcn_res101' # resnet 101 backbone from torchvision fcn-deeplab
|
78 |
+
clsname = None #
|
79 |
+
reload_model_path = None # path for reloading a trained model (overrides ms-coco initialization)
|
80 |
+
proto_grid_size = 8 # L_H, L_W = (32, 32) / 8 = (4, 4) in training
|
81 |
+
feature_hw = [input_size[0]//8, input_size[0]//8] # feature map size, should couple this with backbone in future
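# e.g. input_size=(256, 256) gives feature_hw = [32, 32]; with proto_grid_size = 8 this yields
# the 4-by-4 prototype pooling window described in backbone.sh (32 / 8 = 4)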
|
82 |
+
lora = 0
|
83 |
+
use_3_slices=False
|
84 |
+
do_cca=False
|
85 |
+
use_edge_detector=False
|
86 |
+
finetune_on_support=False
|
87 |
+
sliding_window_confidence_segmentation=False
|
88 |
+
finetune_model_on_single_slice=False
|
89 |
+
online_finetuning=True
|
90 |
+
|
91 |
+
use_bbox=True # for SAM
|
92 |
+
use_points=True # for SAM
|
93 |
+
use_mask=False # for SAM
|
94 |
+
base_model="alpnet" # or "SAM"
|
95 |
+
# SSL
|
96 |
+
superpix_scale = 'MIDDLE' #MIDDLE/ LARGE
|
97 |
+
use_pos_enc=False
|
98 |
+
support_txt_file = None # path to a txt file containing support slices
|
99 |
+
augment_support_set=False
|
100 |
+
coarse_pred_only=False # for ProtoSAM
|
101 |
+
point_mode="both" # for ProtoSAM, choose: both, conf, centroid
|
102 |
+
use_neg_points=False
|
103 |
+
n_support=1 # num support images
|
104 |
+
protosam_sam_ver="sam_h" # or medsam
|
105 |
+
grad_accumulation_steps=1
|
106 |
+
ttt=False
|
107 |
+
reset_after_slice=True # for TTT, if to reset the model after finetuning on each slice
|
108 |
+
model = {
|
109 |
+
'align': usealign,
|
110 |
+
'dinov2_loss': use_dinov2_loss,
|
111 |
+
'use_coco_init': use_coco_init,
|
112 |
+
'which_model': modelname,
|
113 |
+
'cls_name': clsname,
|
114 |
+
'proto_grid_size' : proto_grid_size,
|
115 |
+
'feature_hw': feature_hw,
|
116 |
+
'reload_model_path': reload_model_path,
|
117 |
+
'lora': lora,
|
118 |
+
'use_slice_adapter': use_slice_adapter,
|
119 |
+
'adapter_layers': adapter_layers,
|
120 |
+
'debug': debug,
|
121 |
+
'use_pos_enc': use_pos_enc
|
122 |
+
}
|
123 |
+
|
124 |
+
task = {
|
125 |
+
'n_ways': 1,
|
126 |
+
'n_shots': 1,
|
127 |
+
'n_queries': 1,
|
128 |
+
'npart': n_sup_part
|
129 |
+
}
|
130 |
+
|
131 |
+
optim_type = 'sgd'
|
132 |
+
lr=1e-3
|
133 |
+
momentum=0.9
|
134 |
+
weight_decay=0.0005
|
135 |
+
optim = {
|
136 |
+
'lr': lr,
|
137 |
+
'momentum': momentum,
|
138 |
+
'weight_decay': weight_decay
|
139 |
+
}
|
140 |
+
|
141 |
+
exp_prefix = ''
|
142 |
+
|
143 |
+
exp_str = '_'.join(
|
144 |
+
[exp_prefix]
|
145 |
+
+ [dataset,]
|
146 |
+
+ [f'sets_{label_sets}_{task["n_shots"]}shot'])
|
147 |
+
|
148 |
+
path = {
|
149 |
+
'log_dir': './runs',
|
150 |
+
'SABS':{'data_dir': "/kaggle/input/preprocessed-data/sabs_CT_normalized/sabs_CT_normalized"
|
151 |
+
},
|
152 |
+
'SABS_448':{'data_dir': "./data/SABS/sabs_CT_normalized_448"
|
153 |
+
},
|
154 |
+
'SABS_672':{'data_dir': "./data/SABS/sabs_CT_normalized_672"
|
155 |
+
},
|
156 |
+
'C0':{'data_dir': "feed your dataset path here"
|
157 |
+
},
|
158 |
+
'CHAOST2':{'data_dir': "/kaggle/input/preprocessed-data/chaos_MR_T2_normalized/chaos_MR_T2_normalized"
|
159 |
+
},
|
160 |
+
'CHAOST2_672':{'data_dir': "./data/CHAOST2/chaos_MR_T2_normalized_672/"
|
161 |
+
},
|
162 |
+
'SABS_Superpix':{'data_dir': "/kaggle/input/preprocessed-data/sabs_CT_normalized/sabs_CT_normalized"},
|
163 |
+
'C0_Superpix':{'data_dir': "feed your dataset path here"},
|
164 |
+
'CHAOST2_Superpix':{'data_dir': "/kaggle/input/preprocessed-data/chaos_MR_T2_normalized/chaos_MR_T2_normalized"},
|
165 |
+
'CHAOST2_Superpix_672':{'data_dir': "./data/CHAOST2/chaos_MR_T2_normalized_672/"},
|
166 |
+
'SABS_Superpix_448':{'data_dir': "./data/SABS/sabs_CT_normalized_448"},
|
167 |
+
'SABS_Superpix_672':{'data_dir': "./data/SABS/sabs_CT_normalized_672"},
|
168 |
+
}
|
169 |
+
|
170 |
+
|
171 |
+
@ex.config_hook
|
172 |
+
def add_observer(config, command_name, logger):
|
173 |
+
"""A hook fucntion to add observer"""
|
174 |
+
exp_name = f'{ex.path}_{config["exp_str"]}'
|
175 |
+
observer = FileStorageObserver.create(os.path.join(config['path']['log_dir'], exp_name))
|
176 |
+
ex.observers.append(observer)
|
177 |
+
return config
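# Note: any default in cfg() can be overridden at run time via Sacred's command-line
# syntax, as backbone.sh does, e.g.:
#   python3 validation.py with modelname=dinov2_l14 eval_fold=0 "input_size=(256, 256)"
# (illustrative values copied from backbone.sh)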
|
data/data_processing.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
dataloaders/GenericSuperDatasetv2.py
ADDED
@@ -0,0 +1,445 @@
"""
|
2 |
+
Dataset for training with pseudolabels
|
3 |
+
TODO:
|
4 |
+
1. Merge with the manually annotated dataset
|
5 |
+
2. superpixel_scale -> superpix_config, feed like a dict
|
6 |
+
"""
|
7 |
+
import glob
|
8 |
+
import numpy as np
|
9 |
+
import dataloaders.augutils as myaug
|
10 |
+
import torch
|
11 |
+
import random
|
12 |
+
import os
|
13 |
+
import copy
|
14 |
+
import platform
|
15 |
+
import json
|
16 |
+
import re
|
17 |
+
import cv2
|
18 |
+
from dataloaders.common import BaseDataset, Subset
|
19 |
+
from dataloaders.dataset_utils import*
|
20 |
+
from pdb import set_trace
|
21 |
+
from util.utils import CircularList
|
22 |
+
from util.consts import IMG_SIZE
|
23 |
+
|
24 |
+
class SuperpixelDataset(BaseDataset):
|
25 |
+
def __init__(self, which_dataset, base_dir, idx_split, mode, image_size, transforms, scan_per_load, num_rep = 2, min_fg = '', nsup = 1, fix_length = None, tile_z_dim = 3, exclude_list = [], train_list = [], superpix_scale = 'SMALL', norm_mean=None, norm_std=None, supervised_train=False, use_3_slices=False, **kwargs):
|
26 |
+
"""
|
27 |
+
Pseudolabel dataset
|
28 |
+
Args:
|
29 |
+
which_dataset: name of the dataset to use
|
30 |
+
base_dir: directory of dataset
|
31 |
+
idx_split: index of data split as we will do cross validation
|
32 |
+
mode: 'train', 'val'.
|
33 |
+
nsup: number of scans used as support. currently idle for superpixel dataset
|
34 |
+
transforms: data transform (augmentation) function
|
35 |
+
scan_per_load: loading a portion of the entire dataset, in case that the dataset is too large to fit into the memory. Set to -1 if loading the entire dataset at one time
|
36 |
+
num_rep: Number of augmentation applied for a same pseudolabel
|
37 |
+
tile_z_dim: number of identical slices to tile along channel dimension, for fitting 2D single-channel medical images into off-the-shelf networks designed for RGB natural images
|
38 |
+
fix_length: fix the length of dataset
|
39 |
+
exclude_list: Labels to be excluded
|
40 |
+
superpix_scale: config of superpixels
|
41 |
+
"""
|
42 |
+
super(SuperpixelDataset, self).__init__(base_dir)
|
43 |
+
|
44 |
+
self.img_modality = DATASET_INFO[which_dataset]['MODALITY']
|
45 |
+
self.sep = DATASET_INFO[which_dataset]['_SEP']
|
46 |
+
self.pseu_label_name = DATASET_INFO[which_dataset]['PSEU_LABEL_NAME']
|
47 |
+
self.real_label_name = DATASET_INFO[which_dataset]['REAL_LABEL_NAME']
|
48 |
+
|
49 |
+
self.image_size = image_size
|
50 |
+
self.transforms = transforms
|
51 |
+
self.is_train = True if mode == 'train' else False
|
52 |
+
self.supervised_train = supervised_train
|
53 |
+
if self.supervised_train and len(train_list) == 0:
|
54 |
+
raise Exception('Please provide training labels')
|
55 |
+
# assert mode == 'train'
|
56 |
+
self.fix_length = fix_length
|
57 |
+
if self.supervised_train:
|
58 |
+
# self.nclass = len(self.real_label_name)
|
59 |
+
self.nclass = len(self.pseu_label_name)
|
60 |
+
else:
|
61 |
+
self.nclass = len(self.pseu_label_name)
|
62 |
+
self.num_rep = num_rep
|
63 |
+
self.tile_z_dim = tile_z_dim
|
64 |
+
self.use_3_slices = use_3_slices
|
65 |
+
if tile_z_dim > 1 and self.use_3_slices:
|
66 |
+
raise Exception("tile_z_dim and use_3_slices shouldn't be used together")
|
67 |
+
|
68 |
+
# find scans in the data folder
|
69 |
+
self.nsup = nsup
|
70 |
+
self.base_dir = base_dir
|
71 |
+
self.img_pids = [ re.findall('\d+', fid)[-1] for fid in glob.glob(self.base_dir + "/image_*.nii") ]
|
72 |
+
self.img_pids = CircularList(sorted( self.img_pids, key = lambda x: int(x)))
|
73 |
+
|
74 |
+
# experiment configs
|
75 |
+
self.exclude_lbs = exclude_list
|
76 |
+
self.train_list = train_list
|
77 |
+
self.superpix_scale = superpix_scale
|
78 |
+
if len(exclude_list) > 0:
|
79 |
+
print(f'###### Dataset: the following classes have been excluded {exclude_list} ######')
|
80 |
+
self.idx_split = idx_split
|
81 |
+
self.scan_ids = self.get_scanids(mode, idx_split) # patient ids of the entire fold
|
82 |
+
self.min_fg = min_fg if isinstance(min_fg, str) else str(min_fg)
|
83 |
+
self.scan_per_load = scan_per_load
|
84 |
+
|
85 |
+
self.info_by_scan = None
|
86 |
+
self.img_lb_fids = self.organize_sample_fids() # information of scans of the entire fold
|
87 |
+
self.norm_func = get_normalize_op(self.img_modality, [ fid_pair['img_fid'] for _, fid_pair in self.img_lb_fids.items()], ct_mean=norm_mean, ct_std=norm_std)
|
88 |
+
|
89 |
+
if self.is_train:
|
90 |
+
if scan_per_load > 0: # if the dataset is too large, only reload a subset in each sub-epoch
|
91 |
+
self.pid_curr_load = np.random.choice( self.scan_ids, replace = False, size = self.scan_per_load)
|
92 |
+
else: # load the entire set without a buffer
|
93 |
+
self.pid_curr_load = self.scan_ids
|
94 |
+
elif mode == 'val':
|
95 |
+
self.pid_curr_load = self.scan_ids
|
96 |
+
else:
|
97 |
+
raise Exception
|
98 |
+
|
99 |
+
self.use_clahe = False
|
100 |
+
if kwargs['use_clahe']:
|
101 |
+
self.use_clahe = True
|
102 |
+
clip_limit = 4.0 if self.img_modality == 'MR' else 2.0
|
103 |
+
self.clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=(7,7))
|
104 |
+
|
105 |
+
self.actual_dataset = self.read_dataset()
|
106 |
+
self.size = len(self.actual_dataset)
|
107 |
+
self.overall_slice_by_cls = self.read_classfiles()
|
108 |
+
|
109 |
+
print("###### Initial scans loaded: ######")
|
110 |
+
print(self.pid_curr_load)
|
111 |
+
|
112 |
+
def get_scanids(self, mode, idx_split):
|
113 |
+
"""
|
114 |
+
Load scans by train-test split
|
115 |
+
leaving one additional scan as the support scan. if the last fold, taking scan 0 as the additional one
|
116 |
+
Args:
|
117 |
+
idx_split: index for spliting cross-validation folds
|
118 |
+
"""
|
119 |
+
val_ids = copy.deepcopy(self.img_pids[self.sep[idx_split]: self.sep[idx_split + 1] + self.nsup])
|
120 |
+
if mode == 'train':
|
121 |
+
return [ ii for ii in self.img_pids if ii not in val_ids ]
|
122 |
+
elif mode == 'val':
|
123 |
+
return val_ids
|
124 |
+
|
125 |
+
def reload_buffer(self):
|
126 |
+
"""
|
127 |
+
Reload only a portion of the entire dataset, if the dataset is too large
|
128 |
+
1. delete original buffer
|
129 |
+
2. update self.ids_this_batch
|
130 |
+
3. update other internal variables like __len__
|
131 |
+
"""
|
132 |
+
if self.scan_per_load <= 0:
|
133 |
+
print("We are not using the reload buffer, doing notiong")
|
134 |
+
return -1
|
135 |
+
|
136 |
+
del self.actual_dataset
|
137 |
+
del self.info_by_scan
|
138 |
+
|
139 |
+
self.pid_curr_load = np.random.choice( self.scan_ids, size = self.scan_per_load, replace = False )
|
140 |
+
self.actual_dataset = self.read_dataset()
|
141 |
+
self.size = len(self.actual_dataset)
|
142 |
+
self.update_subclass_lookup()
|
143 |
+
print(f'Loader buffer reloaded with a new size of {self.size} slices')
|
144 |
+
|
145 |
+
def organize_sample_fids(self):
|
146 |
+
out_list = {}
|
147 |
+
for curr_id in self.scan_ids:
|
148 |
+
curr_dict = {}
|
149 |
+
|
150 |
+
_img_fid = os.path.join(self.base_dir, f'image_{curr_id}.nii.gz')
|
151 |
+
_lb_fid = os.path.join(self.base_dir, f'superpix-{self.superpix_scale}_{curr_id}.nii.gz')
|
152 |
+
_gt_lb_fid = os.path.join(self.base_dir, f'label_{curr_id}.nii.gz')
|
153 |
+
|
154 |
+
curr_dict["img_fid"] = _img_fid
|
155 |
+
curr_dict["lbs_fid"] = _lb_fid
|
156 |
+
curr_dict["gt_lbs_fid"] = _gt_lb_fid
|
157 |
+
out_list[str(curr_id)] = curr_dict
|
158 |
+
return out_list
|
159 |
+
|
160 |
+
def read_dataset(self):
|
161 |
+
"""
|
162 |
+
Read images into memory and store them in 2D
|
163 |
+
Build tables for the position of an individual 2D slice in the entire dataset
|
164 |
+
"""
|
165 |
+
out_list = []
|
166 |
+
self.scan_z_idx = {}
|
167 |
+
self.info_by_scan = {} # meta data of each scan
|
168 |
+
glb_idx = 0 # global index of a certain slice in a certain scan in entire dataset
|
169 |
+
|
170 |
+
for scan_id, itm in self.img_lb_fids.items():
|
171 |
+
if scan_id not in self.pid_curr_load:
|
172 |
+
continue
|
173 |
+
|
174 |
+
img, _info = read_nii_bysitk(itm["img_fid"], peel_info = True) # get the meta information out
|
175 |
+
# read connected graph of labels
|
176 |
+
if self.use_clahe:
|
177 |
+
# img = nself.clahe.apply(img.astype(np.uint8))
|
178 |
+
if self.img_modality == 'MR':
|
179 |
+
img = np.stack([((slice - slice.min()) / (slice.max() - slice.min())) * 255 for slice in img], axis=0)
|
180 |
+
img = np.stack([self.clahe.apply(slice.astype(np.uint8)) for slice in img], axis=0)
|
181 |
+
|
182 |
+
img = img.transpose(1,2,0)
|
183 |
+
self.info_by_scan[scan_id] = _info
|
184 |
+
|
185 |
+
img = np.float32(img)
|
186 |
+
img = self.norm_func(img)
|
187 |
+
self.scan_z_idx[scan_id] = [-1 for _ in range(img.shape[-1])]
|
188 |
+
|
189 |
+
if self.supervised_train:
|
190 |
+
lb = read_nii_bysitk(itm["gt_lbs_fid"])
|
191 |
+
else:
|
192 |
+
lb = read_nii_bysitk(itm["lbs_fid"])
|
193 |
+
lb = lb.transpose(1,2,0)
|
194 |
+
lb = np.int32(lb)
|
195 |
+
|
196 |
+
# resize img and lb to self.image_size
|
197 |
+
img = cv2.resize(img, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
|
198 |
+
lb = cv2.resize(lb, (self.image_size, self.image_size), interpolation=cv2.INTER_NEAREST)
|
199 |
+
|
200 |
+
# format of slices: [axial_H x axial_W x Z]
|
201 |
+
if self.supervised_train:
|
202 |
+
# remove all slices that don't contain any of the training labels
|
203 |
+
del_indices = [i for i in range(img.shape[-1]) if not np.any(np.isin(lb[..., i], self.train_list))]
|
204 |
+
# create an new img and lb without indices in del_indices
|
205 |
+
new_img = np.zeros((img.shape[0], img.shape[1], img.shape[2] - len(del_indices)))
|
206 |
+
new_lb = np.zeros((lb.shape[0], lb.shape[1], lb.shape[2] - len(del_indices)))
|
207 |
+
new_img = img[..., ~np.isin(np.arange(img.shape[-1]), del_indices)]
|
208 |
+
new_lb = lb[..., ~np.isin(np.arange(lb.shape[-1]), del_indices)]
|
209 |
+
|
210 |
+
img = new_img
|
211 |
+
lb = new_lb
|
212 |
+
a = [i for i in range(img.shape[-1]) if lb[...,i].max() == 0]
|
213 |
+
|
214 |
+
nframes = img.shape[-1]
|
215 |
+
assert img.shape[-1] == lb.shape[-1]
|
216 |
+
base_idx = img.shape[-1] // 2 # index of the middle slice
|
217 |
+
|
218 |
+
# re-organize 3D images into 2D slices and record essential information for each slice
|
219 |
+
out_list.append( {"img": img[..., 0: 1],
|
220 |
+
"lb":lb[..., 0: 0 + 1],
|
221 |
+
"sup_max_cls": lb[..., 0: 0 + 1].max(),
|
222 |
+
"is_start": True,
|
223 |
+
"is_end": False,
|
224 |
+
"nframe": nframes,
|
225 |
+
"scan_id": scan_id,
|
226 |
+
"z_id":0,
|
227 |
+
})
|
228 |
+
|
229 |
+
self.scan_z_idx[scan_id][0] = glb_idx
|
230 |
+
glb_idx += 1
|
231 |
+
|
232 |
+
for ii in range(1, img.shape[-1] - 1):
|
233 |
+
out_list.append( {"img": img[..., ii: ii + 1],
|
234 |
+
"lb":lb[..., ii: ii + 1],
|
235 |
+
"is_start": False,
|
236 |
+
"is_end": False,
|
237 |
+
"sup_max_cls": lb[..., ii: ii + 1].max(),
|
238 |
+
"nframe": nframes,
|
239 |
+
"scan_id": scan_id,
|
240 |
+
"z_id": ii,
|
241 |
+
})
|
242 |
+
self.scan_z_idx[scan_id][ii] = glb_idx
|
243 |
+
glb_idx += 1
|
244 |
+
|
245 |
+
ii += 1 # last slice of a 3D volume
|
246 |
+
out_list.append( {"img": img[..., ii: ii + 1],
|
247 |
+
"lb":lb[..., ii: ii+ 1],
|
248 |
+
"is_start": False,
|
249 |
+
"is_end": True,
|
250 |
+
"sup_max_cls": lb[..., ii: ii + 1].max(),
|
251 |
+
"nframe": nframes,
|
252 |
+
"scan_id": scan_id,
|
253 |
+
"z_id": ii,
|
254 |
+
})
|
255 |
+
|
256 |
+
self.scan_z_idx[scan_id][ii] = glb_idx
|
257 |
+
glb_idx += 1
|
258 |
+
|
259 |
+
return out_list
|
260 |
+
|
261 |
+
def read_classfiles(self):
|
262 |
+
"""
|
263 |
+
Load the scan-slice-class indexing file
|
264 |
+
"""
|
265 |
+
with open( os.path.join(self.base_dir, f'.classmap_{self.min_fg}.json') , 'r' ) as fopen:
|
266 |
+
cls_map = json.load( fopen)
|
267 |
+
fopen.close()
|
268 |
+
|
269 |
+
with open( os.path.join(self.base_dir, '.classmap_1.json') , 'r' ) as fopen:
|
270 |
+
self.tp1_cls_map = json.load( fopen)
|
271 |
+
fopen.close()
|
272 |
+
|
273 |
+
return cls_map
|
274 |
+
|
275 |
+
def get_superpixels_similarity(self, sp1, sp2):
|
276 |
+
pass
|
277 |
+
|
278 |
+
def supcls_pick_binarize(self, super_map, sup_max_cls, bi_val=None, conn_graph=None, img=None):
|
279 |
+
if bi_val is None:
|
280 |
+
# bi_val = np.random.randint(1, sup_max_cls)
|
281 |
+
bi_val = random.choice(list(np.unique(super_map)))
|
282 |
+
if conn_graph is not None and img is not None:
|
283 |
+
# get number of neighbors of bi_val
|
284 |
+
neighbors = conn_graph[bi_val]
|
285 |
+
# pick a random number of neighbors and merge them
|
286 |
+
n_neighbors = np.random.randint(0, len(neighbors))
|
287 |
+
try:
|
288 |
+
neighbors = random.sample(neighbors, n_neighbors)
|
289 |
+
except TypeError:
|
290 |
+
neighbors = []
|
291 |
+
# merge neighbors
|
292 |
+
super_map = np.where(np.isin(super_map, neighbors), bi_val, super_map)
|
293 |
+
return np.float32(super_map == bi_val)
|
294 |
+
|
295 |
+
def supcls_pick(self, super_map):
|
296 |
+
return random.choice(list(np.unique(super_map)))
|
297 |
+
|
298 |
+
def get_3_slice_adjacent_image(self, image_t, index):
|
299 |
+
curr_dict = self.actual_dataset[index]
|
300 |
+
prev_image = np.zeros_like(image_t)
|
301 |
+
|
302 |
+
if index > 1 and not curr_dict["is_start"]:
|
303 |
+
prev_dict = self.actual_dataset[index - 1]
|
304 |
+
prev_image = prev_dict["img"]
|
305 |
+
|
306 |
+
next_image = np.zeros_like(image_t)
|
307 |
+
if index < len(self.actual_dataset) - 1 and not curr_dict["is_end"]:
|
308 |
+
next_dict = self.actual_dataset[index + 1]
|
309 |
+
next_image = next_dict["img"]
|
310 |
+
|
311 |
+
image_t = np.concatenate([prev_image, image_t, next_image], axis=-1)
|
312 |
+
|
313 |
+
return image_t
|
314 |
+
|
315 |
+
def __getitem__(self, index):
|
316 |
+
index = index % len(self.actual_dataset)
|
317 |
+
curr_dict = self.actual_dataset[index]
|
318 |
+
sup_max_cls = curr_dict['sup_max_cls']
|
319 |
+
if sup_max_cls < 1:
|
320 |
+
return self.__getitem__(index + 1)
|
321 |
+
|
322 |
+
image_t = curr_dict["img"]
|
323 |
+
label_raw = curr_dict["lb"]
|
324 |
+
|
325 |
+
if self.use_3_slices:
|
326 |
+
image_t = self.get_3_slice_adjacent_image(image_t, index)
|
327 |
+
|
328 |
+
for _ex_cls in self.exclude_lbs:
|
329 |
+
if curr_dict["z_id"] in self.tp1_cls_map[self.real_label_name[_ex_cls]][curr_dict["scan_id"]]: # if using setting 1, this slice need to be excluded since it contains label which is supposed to be unseen
|
330 |
+
return self.__getitem__(torch.randint(low = 0, high = self.__len__() - 1, size = (1,)))
|
331 |
+
|
332 |
+
if self.supervised_train:
|
333 |
+
superpix_label = -1
|
334 |
+
label_t = np.float32(label_raw)
|
335 |
+
|
336 |
+
lb_id = random.choice(list(set(np.unique(label_raw)) & set(self.train_list)))
|
337 |
+
label_t[label_t != lb_id] = 0
|
338 |
+
label_t[label_t == lb_id] = 1
|
339 |
+
|
340 |
+
else:
|
341 |
+
superpix_label = self.supcls_pick(label_raw)
|
342 |
+
label_t = np.float32(label_raw == superpix_label)
|
343 |
+
|
344 |
+
pair_buffer = []
|
345 |
+
|
346 |
+
comp = np.concatenate( [image_t, label_t], axis = -1 )
|
347 |
+
|
348 |
+
for ii in range(self.num_rep):
|
349 |
+
if self.transforms is not None:
|
350 |
+
img, lb = self.transforms(comp, c_img = image_t.shape[-1], c_label = 1, nclass = self.nclass, is_train = True, use_onehot = False)
|
351 |
+
else:
|
352 |
+
img, lb = comp[:, :, 0:1], comp[:, :, 1:2]
|
353 |
+
# if ii % 2 == 0:
|
354 |
+
# label_raw = lb
|
355 |
+
# lb = lb == superpix_label
|
356 |
+
|
357 |
+
img = torch.from_numpy( np.transpose( img, (2, 0, 1)) ).float()
|
358 |
+
lb = torch.from_numpy( lb.squeeze(-1)).float()
|
359 |
+
|
360 |
+
img = img.repeat( [ self.tile_z_dim, 1, 1] )
|
361 |
+
|
362 |
+
is_start = curr_dict["is_start"]
|
363 |
+
is_end = curr_dict["is_end"]
|
364 |
+
nframe = np.int32(curr_dict["nframe"])
|
365 |
+
scan_id = curr_dict["scan_id"]
|
366 |
+
z_id = curr_dict["z_id"]
|
367 |
+
|
368 |
+
sample = {"image": img,
|
369 |
+
"label":lb,
|
370 |
+
"is_start": is_start,
|
371 |
+
"is_end": is_end,
|
372 |
+
"nframe": nframe,
|
373 |
+
"scan_id": scan_id,
|
374 |
+
"z_id": z_id
|
375 |
+
}
|
376 |
+
|
377 |
+
# Add auxiliary attributes
|
378 |
+
if self.aux_attrib is not None:
|
379 |
+
for key_prefix in self.aux_attrib:
|
380 |
+
# Process the data sample, create new attributes and save them in a dictionary
|
381 |
+
aux_attrib_val = self.aux_attrib[key_prefix](sample, **self.aux_attrib_args[key_prefix])
|
382 |
+
for key_suffix in aux_attrib_val:
|
383 |
+
# one function may create multiple attributes, so we need suffix to distinguish them
|
384 |
+
sample[key_prefix + '_' + key_suffix] = aux_attrib_val[key_suffix]
|
385 |
+
pair_buffer.append(sample)
|
386 |
+
|
387 |
+
support_images = []
|
388 |
+
support_mask = []
|
389 |
+
support_class = []
|
390 |
+
|
391 |
+
query_images = []
|
392 |
+
query_labels = []
|
393 |
+
query_class = []
|
394 |
+
|
395 |
+
for idx, itm in enumerate(pair_buffer):
|
396 |
+
if idx % 2 == 0:
|
397 |
+
support_images.append(itm["image"])
|
398 |
+
support_class.append(1) # pseudolabel class
|
399 |
+
support_mask.append( self.getMaskMedImg( itm["label"], 1, [1] ))
|
400 |
+
else:
|
401 |
+
query_images.append(itm["image"])
|
402 |
+
query_class.append(1)
|
403 |
+
query_labels.append( itm["label"])
|
404 |
+
|
405 |
+
return {'class_ids': [support_class],
|
406 |
+
'support_images': [support_images], #
|
407 |
+
'superpix_label': superpix_label,
|
408 |
+
'superpix_label_raw': label_raw[:,:,0],
|
409 |
+
'support_mask': [support_mask],
|
410 |
+
'query_images': query_images, #
|
411 |
+
'query_labels': query_labels,
|
412 |
+
'scan_id': scan_id,
|
413 |
+
'z_id': z_id,
|
414 |
+
'nframe': nframe,
|
415 |
+
}
|
416 |
+
|
417 |
+
|
418 |
+
def __len__(self):
|
419 |
+
"""
|
420 |
+
copy-paste from basic naive dataset configuration
|
421 |
+
"""
|
422 |
+
if self.fix_length != None:
|
423 |
+
assert self.fix_length >= len(self.actual_dataset)
|
424 |
+
return self.fix_length
|
425 |
+
else:
|
426 |
+
return len(self.actual_dataset)
|
427 |
+
|
428 |
+
def getMaskMedImg(self, label, class_id, class_ids):
|
429 |
+
"""
|
430 |
+
Generate FG/BG mask from the segmentation mask
|
431 |
+
|
432 |
+
Args:
|
433 |
+
label: semantic mask
|
434 |
+
class_id: semantic class of interest
|
435 |
+
class_ids: all class id in this episode
|
436 |
+
"""
|
437 |
+
fg_mask = torch.where(label == class_id,
|
438 |
+
torch.ones_like(label), torch.zeros_like(label))
|
439 |
+
bg_mask = torch.where(label != class_id,
|
440 |
+
torch.ones_like(label), torch.zeros_like(label))
|
441 |
+
for class_id in class_ids:
|
442 |
+
bg_mask[label == class_id] = 0
|
443 |
+
|
444 |
+
return {'fg_mask': fg_mask,
|
445 |
+
'bg_mask': bg_mask}
|
dataloaders/ManualAnnoDatasetv2.py
ADDED
@@ -0,0 +1,756 @@
"""
|
2 |
+
Manually labeled dataset
|
3 |
+
TODO:
|
4 |
+
1. Merge with superpixel dataset
|
5 |
+
"""
|
6 |
+
import glob
|
7 |
+
import numpy as np
|
8 |
+
import dataloaders.augutils as myaug
|
9 |
+
import torch
|
10 |
+
import random
|
11 |
+
import os
|
12 |
+
import copy
|
13 |
+
import platform
|
14 |
+
import json
|
15 |
+
import re
|
16 |
+
import cv2
|
17 |
+
from dataloaders.common import BaseDataset, Subset, ValidationDataset
|
18 |
+
# from common import BaseDataset, Subset
|
19 |
+
from dataloaders.dataset_utils import*
|
20 |
+
from pdb import set_trace
|
21 |
+
from util.utils import CircularList
|
22 |
+
from util.consts import IMG_SIZE
|
23 |
+
|
24 |
+
MODE_DEFAULT = "default"
|
25 |
+
MODE_FULL_SCAN = "full_scan"
|
26 |
+
|
27 |
+
class ManualAnnoDataset(BaseDataset):
|
28 |
+
def __init__(self, which_dataset, base_dir, idx_split, mode, image_size, transforms, scan_per_load, min_fg = '', fix_length = None, tile_z_dim = 3, nsup = 1, exclude_list = [], extern_normalize_func = None, **kwargs):
|
29 |
+
"""
|
30 |
+
Manually labeled dataset
|
31 |
+
Args:
|
32 |
+
which_dataset: name of the dataset to use
|
33 |
+
base_dir: directory of dataset
|
34 |
+
idx_split: index of data split as we will do cross validation
|
35 |
+
mode: 'train', 'val'.
|
36 |
+
transforms: data transform (augmentation) function
|
37 |
+
min_fg: minimum number of positive pixels in a 2D slice, mainly to stabilize training on the manually labeled dataset
|
38 |
+
scan_per_load: loading a portion of the entire dataset, in case that the dataset is too large to fit into the memory. Set to -1 if loading the entire dataset at one time
|
39 |
+
tile_z_dim: number of identical slices to tile along channel dimension, for fitting 2D single-channel medical images into off-the-shelf networks designed for RGB natural images
|
40 |
+
nsup: number of support scans
|
41 |
+
fix_length: fix the length of dataset
|
42 |
+
exclude_list: Labels to be excluded
|
43 |
+
extern_normalize_function: normalization function used for data pre-processing
|
44 |
+
"""
|
45 |
+
super(ManualAnnoDataset, self).__init__(base_dir)
|
46 |
+
self.img_modality = DATASET_INFO[which_dataset]['MODALITY']
|
47 |
+
self.sep = DATASET_INFO[which_dataset]['_SEP']
|
48 |
+
self.label_name = DATASET_INFO[which_dataset]['REAL_LABEL_NAME']
|
49 |
+
self.image_size = image_size
|
50 |
+
self.transforms = transforms
|
51 |
+
self.is_train = True if mode == 'train' else False
|
52 |
+
self.phase = mode
|
53 |
+
self.fix_length = fix_length
|
54 |
+
self.all_label_names = self.label_name
|
55 |
+
self.nclass = len(self.label_name)
|
56 |
+
self.tile_z_dim = tile_z_dim
|
57 |
+
self.base_dir = base_dir
|
58 |
+
self.nsup = nsup
|
59 |
+
self.img_pids = [ re.findall('\d+', fid)[-1] for fid in glob.glob(self.base_dir + "/image_*.nii") ]
|
60 |
+
self.img_pids = CircularList(sorted( self.img_pids, key = lambda x: int(x))) # make it circular for ease of splitting folds
|
61 |
+
if 'use_clahe' not in kwargs:
|
62 |
+
self.use_clahe = False
|
63 |
+
else:
|
64 |
+
self.use_clahe = kwargs['use_clahe']
|
65 |
+
if self.use_clahe:
|
66 |
+
self.clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(7,7))
|
67 |
+
|
68 |
+
self.use_3_slices = kwargs["use_3_slices"] if 'use_3_slices' in kwargs else False
|
69 |
+
if self.use_3_slices:
|
70 |
+
self.tile_z_dim=1
|
71 |
+
|
72 |
+
self.get_item_mode = MODE_DEFAULT
|
73 |
+
if 'get_item_mode' in kwargs:
|
74 |
+
self.get_item_mode = kwargs['get_item_mode']
|
75 |
+
|
76 |
+
self.exclude_lbs = exclude_list
|
77 |
+
if len(exclude_list) > 0:
|
78 |
+
print(f'###### Dataset: the following classes have been excluded {exclude_list} ######')
|
79 |
+
|
80 |
+
self.idx_split = idx_split
|
81 |
+
self.scan_ids = self.get_scanids(mode, idx_split) # patient ids of the entire fold
|
82 |
+
self.min_fg = min_fg if isinstance(min_fg, str) else str(min_fg)
|
83 |
+
|
84 |
+
self.scan_per_load = scan_per_load
|
85 |
+
|
86 |
+
self.info_by_scan = None
|
87 |
+
self.img_lb_fids = self.organize_sample_fids() # information of scans of the entire fold
|
88 |
+
|
89 |
+
if extern_normalize_func is not None: # helps to keep consistent between training and testing dataset.
|
90 |
+
self.norm_func = extern_normalize_func
|
91 |
+
print(f'###### Dataset: using external normalization statistics ######')
|
92 |
+
else:
|
93 |
+
self.norm_func = get_normalize_op(self.img_modality, [ fid_pair['img_fid'] for _, fid_pair in self.img_lb_fids.items()])
|
94 |
+
print(f'###### Dataset: using normalization statistics calculated from loaded data ######')
|
95 |
+
|
96 |
+
if self.is_train:
|
97 |
+
if scan_per_load > 0: # buffer needed
|
98 |
+
self.pid_curr_load = np.random.choice( self.scan_ids, replace = False, size = self.scan_per_load)
|
99 |
+
else: # load the entire set without a buffer
|
100 |
+
self.pid_curr_load = self.scan_ids
|
101 |
+
elif mode == 'val':
|
102 |
+
self.pid_curr_load = self.scan_ids
|
103 |
+
self.potential_support_sid = []
|
104 |
+
else:
|
105 |
+
raise Exception
|
106 |
+
self.actual_dataset = self.read_dataset()
|
107 |
+
self.size = len(self.actual_dataset)
|
108 |
+
self.overall_slice_by_cls = self.read_classfiles()
|
109 |
+
self.update_subclass_lookup()
|
110 |
+
|
111 |
+
def get_scanids(self, mode, idx_split):
|
112 |
+
val_ids = copy.deepcopy(self.img_pids[self.sep[idx_split]: self.sep[idx_split + 1] + self.nsup])
|
113 |
+
self.potential_support_sid = val_ids[-self.nsup:] # this is actual file scan id, not index
|
114 |
+
if mode == 'train':
|
115 |
+
return [ ii for ii in self.img_pids if ii not in val_ids ]
|
116 |
+
elif mode == 'val':
|
117 |
+
return val_ids
|
118 |
+
|
119 |
+
def reload_buffer(self):
|
120 |
+
"""
|
121 |
+
Reload a portion of the entire dataset, if the dataset is too large
|
122 |
+
1. delete original buffer
|
123 |
+
2. update self.ids_this_batch
|
124 |
+
3. update other internal variables like __len__
|
125 |
+
"""
|
126 |
+
if self.scan_per_load <= 0:
|
127 |
+
print("We are not using the reload buffer, doing notiong")
|
128 |
+
return -1
|
129 |
+
|
130 |
+
del self.actual_dataset
|
131 |
+
del self.info_by_scan
|
132 |
+
self.pid_curr_load = np.random.choice( self.scan_ids, size = self.scan_per_load, replace = False )
|
133 |
+
self.actual_dataset = self.read_dataset()
|
134 |
+
self.size = len(self.actual_dataset)
|
135 |
+
self.update_subclass_lookup()
|
136 |
+
print(f'Loader buffer reloaded with a new size of {self.size} slices')
|
137 |
+
|
138 |
+
def organize_sample_fids(self):
|
139 |
+
out_list = {}
|
140 |
+
for curr_id in self.scan_ids:
|
141 |
+
curr_dict = {}
|
142 |
+
|
143 |
+
_img_fid = os.path.join(self.base_dir, f'image_{curr_id}.nii.gz')
|
144 |
+
_lb_fid = os.path.join(self.base_dir, f'label_{curr_id}.nii.gz')
|
145 |
+
|
146 |
+
curr_dict["img_fid"] = _img_fid
|
147 |
+
curr_dict["lbs_fid"] = _lb_fid
|
148 |
+
out_list[str(curr_id)] = curr_dict
|
149 |
+
return out_list
|
150 |
+
|
151 |
+
def read_dataset(self):
|
152 |
+
"""
|
153 |
+
Build index pointers to individual slices
|
154 |
+
Also keep a look-up table from scan_id, slice to index
|
155 |
+
"""
|
156 |
+
out_list = []
|
157 |
+
self.scan_z_idx = {}
|
158 |
+
self.info_by_scan = {} # meta data of each scan
|
159 |
+
glb_idx = 0 # global index of a certain slice in a certain scan in entire dataset
|
160 |
+
|
161 |
+
for scan_id, itm in self.img_lb_fids.items():
|
162 |
+
if scan_id not in self.pid_curr_load:
|
163 |
+
continue
|
164 |
+
|
165 |
+
img, _info = read_nii_bysitk(itm["img_fid"], peel_info = True) # get the meta information out
|
166 |
+
|
167 |
+
img = img.transpose(1,2,0)
|
168 |
+
|
169 |
+
self.info_by_scan[scan_id] = _info
|
170 |
+
|
171 |
+
if self.use_clahe:
|
172 |
+
img = np.stack([self.clahe.apply(img[..., zz].astype(np.uint8)) for zz in range(img.shape[-1])], axis=-1) # apply CLAHE to each axial slice along the z axis
|
173 |
+
|
174 |
+
img = np.float32(img)
|
175 |
+
img = self.norm_func(img)
|
176 |
+
|
177 |
+
self.scan_z_idx[scan_id] = [-1 for _ in range(img.shape[-1])]
|
178 |
+
|
179 |
+
lb = read_nii_bysitk(itm["lbs_fid"])
|
180 |
+
lb = lb.transpose(1,2,0)
|
181 |
+
|
182 |
+
lb = np.float32(lb)
|
183 |
+
|
184 |
+
img = cv2.resize(img, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
|
185 |
+
lb = cv2.resize(lb, (self.image_size, self.image_size), interpolation=cv2.INTER_NEAREST)
|
186 |
+
|
187 |
+
assert img.shape[-1] == lb.shape[-1]
|
188 |
+
base_idx = img.shape[-1] // 2 # index of the middle slice
|
189 |
+
|
190 |
+
# write the beginning frame
|
191 |
+
out_list.append( {"img": img[..., 0: 1],
|
192 |
+
"lb":lb[..., 0: 0 + 1],
|
193 |
+
"is_start": True,
|
194 |
+
"is_end": False,
|
195 |
+
"nframe": img.shape[-1],
|
196 |
+
"scan_id": scan_id,
|
197 |
+
"z_id":0})
|
198 |
+
|
199 |
+
self.scan_z_idx[scan_id][0] = glb_idx
|
200 |
+
glb_idx += 1
|
201 |
+
|
202 |
+
for ii in range(1, img.shape[-1] - 1):
|
203 |
+
out_list.append( {"img": img[..., ii: ii + 1],
|
204 |
+
"lb":lb[..., ii: ii + 1],
|
205 |
+
"is_start": False,
|
206 |
+
"is_end": False,
|
207 |
+
"nframe": -1,
|
208 |
+
"scan_id": scan_id,
|
209 |
+
"z_id": ii
|
210 |
+
})
|
211 |
+
self.scan_z_idx[scan_id][ii] = glb_idx
|
212 |
+
glb_idx += 1
|
213 |
+
|
214 |
+
ii += 1 # last frame, note the is_end flag
|
215 |
+
out_list.append( {"img": img[..., ii: ii + 1],
|
216 |
+
"lb":lb[..., ii: ii+ 1],
|
217 |
+
"is_start": False,
|
218 |
+
"is_end": True,
|
219 |
+
"nframe": -1,
|
220 |
+
"scan_id": scan_id,
|
221 |
+
"z_id": ii
|
222 |
+
})
|
223 |
+
|
224 |
+
self.scan_z_idx[scan_id][ii] = glb_idx
|
225 |
+
glb_idx += 1
|
226 |
+
|
227 |
+
return out_list
|
228 |
+
|
229 |
+
def read_classfiles(self):
|
230 |
+
with open( os.path.join(self.base_dir, f'.classmap_{self.min_fg}.json') , 'r' ) as fopen:
|
231 |
+
cls_map = json.load( fopen)
|
232 |
+
fopen.close()
|
233 |
+
|
234 |
+
with open( os.path.join(self.base_dir, '.classmap_1.json') , 'r' ) as fopen:
|
235 |
+
self.tp1_cls_map = json.load( fopen)
|
236 |
+
fopen.close()
|
237 |
+
|
238 |
+
return cls_map
|
239 |
+
|
240 |
+
def __getitem__(self, index):
|
241 |
+
if self.get_item_mode == MODE_DEFAULT:
|
242 |
+
return self.__getitem_default__(index)
|
243 |
+
elif self.get_item_mode == MODE_FULL_SCAN:
|
244 |
+
return self.__get_ct_scan___(index)
|
245 |
+
else:
|
246 |
+
raise NotImplementedError("Unknown mode")
|
247 |
+
|
248 |
+
|
249 |
+
def __get_ct_scan___(self, index):
|
250 |
+
scan_n = index % len(self.scan_z_idx)
|
251 |
+
scan_id = list(self.scan_z_idx.keys())[scan_n]
|
252 |
+
scan_slices = self.scan_z_idx[scan_id]
|
253 |
+
|
254 |
+
scan_imgs = np.concatenate([self.actual_dataset[_idx]["img"] for _idx in scan_slices], axis = -1).transpose(2, 0, 1)
|
255 |
+
|
256 |
+
scan_lbs = np.concatenate([self.actual_dataset[_idx]["lb"] for _idx in scan_slices], axis = -1).transpose(2, 0, 1)
|
257 |
+
|
258 |
+
scan_imgs = np.float32(scan_imgs)
|
259 |
+
scan_lbs = np.float32(scan_lbs)
|
260 |
+
|
261 |
+
scan_imgs = torch.from_numpy(scan_imgs).unsqueeze(0)
|
262 |
+
scan_lbs = torch.from_numpy(scan_lbs)
|
263 |
+
|
264 |
+
if self.tile_z_dim:
|
265 |
+
scan_imgs = scan_imgs.repeat(self.tile_z_dim, 1, 1, 1)
|
266 |
+
assert scan_imgs.ndimension() == 4, f'actual dim {scan_imgs.ndimension()}'
|
267 |
+
|
268 |
+
# # reshape to C, D, H, W
|
269 |
+
# scan_imgs = scan_imgs.permute(1, 0, 2, 3)
|
270 |
+
# scan_lbs = scan_lbs.permute(1, 0, 2, 3)
|
271 |
+
|
272 |
+
sample = {"image": scan_imgs,
|
273 |
+
"label":scan_lbs,
|
274 |
+
"scan_id": scan_id,
|
275 |
+
}
|
276 |
+
|
277 |
+
return sample
|
278 |
+
|
279 |
+
|
280 |
+
def get_3_slice_adjacent_image(self, image_t, index):
|
281 |
+
curr_dict = self.actual_dataset[index]
|
282 |
+
prev_image = np.zeros_like(image_t)
|
283 |
+
|
284 |
+
if index > 1 and not curr_dict["is_start"]:
|
285 |
+
prev_dict = self.actual_dataset[index - 1]
|
286 |
+
prev_image = prev_dict["img"]
|
287 |
+
|
288 |
+
next_image = np.zeros_like(image_t)
|
289 |
+
if index < len(self.actual_dataset) - 1 and not curr_dict["is_end"]:
|
290 |
+
next_dict = self.actual_dataset[index + 1]
|
291 |
+
next_image = next_dict["img"]
|
292 |
+
|
293 |
+
image_t = np.concatenate([prev_image, image_t, next_image], axis=-1)
|
294 |
+
|
295 |
+
return image_t
|
296 |
+
|
297 |
+
|
298 |
+
def __getitem_default__(self, index):
|
299 |
+
index = index % len(self.actual_dataset)
|
300 |
+
curr_dict = self.actual_dataset[index]
|
301 |
+
if self.is_train:
|
302 |
+
if len(self.exclude_lbs) > 0:
|
303 |
+
for _ex_cls in self.exclude_lbs:
|
304 |
+
if curr_dict["z_id"] in self.tp1_cls_map[self.label_name[_ex_cls]][curr_dict["scan_id"]]: # this slice need to be excluded since it contains label which is supposed to be unseen
|
305 |
+
return self.__getitem__(index + torch.randint(low = 0, high = self.__len__() - 1, size = (1,)))
|
306 |
+
|
307 |
+
comp = np.concatenate( [curr_dict["img"], curr_dict["lb"]], axis = -1 )
|
308 |
+
if self.transforms is not None:
|
309 |
+
img, lb = self.transforms(comp, c_img = 1, c_label = 1, nclass = self.nclass, use_onehot = False)
|
310 |
+
else:
|
311 |
+
raise Exception("No transform function is provided")
|
312 |
+
|
313 |
+
else:
|
314 |
+
img = curr_dict['img']
|
315 |
+
lb = curr_dict['lb']
|
316 |
+
|
317 |
+
|
318 |
+
img = np.float32(img)
|
319 |
+
lb = np.float32(lb).squeeze(-1) # NOTE: to be suitable for the PANet structure
|
320 |
+
if self.use_3_slices:
|
321 |
+
img = self.get_3_slice_adjacent_image(img, index)
|
322 |
+
|
323 |
+
img = torch.from_numpy( np.transpose(img, (2, 0, 1)) )
|
324 |
+
lb = torch.from_numpy( lb)
|
325 |
+
|
326 |
+
if self.tile_z_dim:
|
327 |
+
img = img.repeat( [ self.tile_z_dim, 1, 1] )
|
328 |
+
assert img.ndimension() == 3, f'actual dim {img.ndimension()}'
|
329 |
+
|
330 |
+
is_start = curr_dict["is_start"]
|
331 |
+
is_end = curr_dict["is_end"]
|
332 |
+
nframe = np.int32(curr_dict["nframe"])
|
333 |
+
scan_id = curr_dict["scan_id"]
|
334 |
+
z_id = curr_dict["z_id"]
|
335 |
+
|
336 |
+
sample = {"image": img,
|
337 |
+
"label":lb,
|
338 |
+
"is_start": is_start,
|
339 |
+
"is_end": is_end,
|
340 |
+
"nframe": nframe,
|
341 |
+
"scan_id": scan_id,
|
342 |
+
"z_id": z_id
|
343 |
+
}
|
344 |
+
# Add auxiliary attributes
|
345 |
+
if self.aux_attrib is not None:
|
346 |
+
for key_prefix in self.aux_attrib:
|
347 |
+
# Process the data sample, create new attributes and save them in a dictionary
|
348 |
+
aux_attrib_val = self.aux_attrib[key_prefix](sample, **self.aux_attrib_args[key_prefix])
|
349 |
+
for key_suffix in aux_attrib_val:
|
350 |
+
# one function may create multiple attributes, so we need suffix to distinguish them
|
351 |
+
sample[key_prefix + '_' + key_suffix] = aux_attrib_val[key_suffix]
|
352 |
+
|
353 |
+
return sample
|
354 |
+
|
355 |
+
def __len__(self):
|
356 |
+
"""
|
357 |
+
copy-paste from basic naive dataset configuration
|
358 |
+
"""
|
359 |
+
if self.get_item_mode == MODE_FULL_SCAN:
|
360 |
+
return len(self.scan_z_idx)
|
361 |
+
|
362 |
+
if self.fix_length is not None:
|
363 |
+
assert self.fix_length >= len(self.actual_dataset)
|
364 |
+
return self.fix_length
|
365 |
+
else:
|
366 |
+
return len(self.actual_dataset)
|
367 |
+
|
368 |
+
def update_subclass_lookup(self):
|
369 |
+
"""
|
370 |
+
Updating the class-slice indexing list
|
371 |
+
Args:
|
372 |
+
[internal] overall_slice_by_cls:
|
373 |
+
{
|
374 |
+
class1: {pid1: [slice1, slice2, ....],
|
375 |
+
pid2: [slice1, slice2]},
|
376 |
+
...}
|
377 |
+
class2:
|
378 |
+
...
|
379 |
+
}
|
380 |
+
out[internal]:
|
381 |
+
{
|
382 |
+
class1: [ idx1, idx2, ... ],
|
383 |
+
class2: [ idx1, idx2, ... ],
|
384 |
+
...
|
385 |
+
}
|
386 |
+
|
387 |
+
"""
|
388 |
+
# delete previous ones if any
|
389 |
+
assert self.overall_slice_by_cls is not None
|
390 |
+
|
391 |
+
if not hasattr(self, 'idx_by_class'):
|
392 |
+
self.idx_by_class = {}
|
393 |
+
# filter the new one given the actual list
|
394 |
+
for cls in self.label_name:
|
395 |
+
if cls not in self.idx_by_class.keys():
|
396 |
+
self.idx_by_class[cls] = []
|
397 |
+
else:
|
398 |
+
del self.idx_by_class[cls][:]
|
399 |
+
for cls, dict_by_pid in self.overall_slice_by_cls.items():
|
400 |
+
for pid, slice_list in dict_by_pid.items():
|
401 |
+
if pid not in self.pid_curr_load:
|
402 |
+
continue
|
403 |
+
self.idx_by_class[cls] += [ self.scan_z_idx[pid][_sli] for _sli in slice_list ]
|
404 |
+
print("###### index-by-class table has been reloaded ######")
|
405 |
+
|
406 |
+
def getMaskMedImg(self, label, class_id, class_ids):
|
407 |
+
"""
|
408 |
+
Generate FG/BG mask from the segmentation mask. Used when getting the support
|
409 |
+
"""
|
410 |
+
# Dense Mask
|
411 |
+
fg_mask = torch.where(label == class_id,
|
412 |
+
torch.ones_like(label), torch.zeros_like(label))
|
413 |
+
bg_mask = torch.where(label != class_id,
|
414 |
+
torch.ones_like(label), torch.zeros_like(label))
|
415 |
+
for class_id in class_ids:
|
416 |
+
bg_mask[label == class_id] = 0
|
417 |
+
|
418 |
+
return {'fg_mask': fg_mask,
|
419 |
+
'bg_mask': bg_mask}
|
420 |
+
|
421 |
+
def subsets(self, sub_args_lst=None):
|
422 |
+
"""
|
423 |
+
Override base-class subset method
|
424 |
+
Create subsets by scan_ids
|
425 |
+
|
426 |
+
output: list [[<fid in each class>] <class1>, <class2> ]
|
427 |
+
"""
|
428 |
+
|
429 |
+
if sub_args_lst is not None:
|
430 |
+
subsets = []
|
431 |
+
ii = 0
|
432 |
+
for cls_name, index_list in self.idx_by_class.items():
|
433 |
+
subsets.append( Subset(dataset = self, indices = index_list, sub_attrib_args = sub_args_lst[ii]) )
|
434 |
+
ii += 1
|
435 |
+
else:
|
436 |
+
subsets = [Subset(dataset=self, indices=index_list) for _, index_list in self.idx_by_class.items()]
|
437 |
+
return subsets
|
438 |
+
|
439 |
+
def get_support(self, curr_class: int, class_idx: list, scan_idx: list, npart: int):
|
440 |
+
"""
|
441 |
+
getting (probably multi-shot) support set for evaluation
|
442 |
+
sample from 50% (1shot) or 20 35 50 65 80 (5shot)
|
443 |
+
Args:
|
444 |
+
curr_class: current class to segment, starting from 1
|
445 |
+
class_idx: a list of all foreground classes in the n-way task, starting from 1
|
446 |
+
npart: how many chunks are used to split the support
|
447 |
+
scan_idx: a list, indicating the current **i_th** (note this is idx not pid) training scan
|
448 |
+
being served as support, in self.pid_curr_load
|
449 |
+
"""
|
450 |
+
assert npart % 2 == 1
|
451 |
+
assert curr_class != 0; assert 0 not in class_idx
|
452 |
+
# assert not self.is_train
|
453 |
+
|
454 |
+
self.potential_support_sid = [self.pid_curr_load[ii] for ii in scan_idx ]
|
455 |
+
# print(f'###### Using {len(scan_idx)} shot evaluation!')
|
456 |
+
|
457 |
+
if npart == 1:
|
458 |
+
pcts = [0.5]
|
459 |
+
else:
|
460 |
+
half_part = 1 / (npart * 2)
|
461 |
+
part_interval = (1.0 - 1.0 / npart) / (npart - 1)
|
462 |
+
pcts = [ half_part + part_interval * ii for ii in range(npart) ]
|
463 |
+
|
464 |
+
# print(f'###### Parts percentage: {pcts} ######')
|
465 |
+
|
466 |
+
# norm_func = get_normalize_op(modality='MR', fids=None)
|
467 |
+
out_buffer = [] # [{scanid, img, lb}]
|
468 |
+
for _part in range(npart):
|
469 |
+
concat_buffer = [] # for each fold do a concat in image and mask in batch dimension
|
470 |
+
for scan_order in scan_idx:
|
471 |
+
_scan_id = self.pid_curr_load[ scan_order ]
|
472 |
+
print(f'Using scan {_scan_id} as support!')
|
473 |
+
|
474 |
+
# for _pc in pcts:
|
475 |
+
_zlist = self.tp1_cls_map[self.label_name[curr_class]][_scan_id] # list of indices
|
476 |
+
_zid = _zlist[int(pcts[_part] * len(_zlist))]
|
477 |
+
_glb_idx = self.scan_z_idx[_scan_id][_zid]
|
478 |
+
|
479 |
+
# almost copy-paste __getitem__ but no augmentation
|
480 |
+
curr_dict = self.actual_dataset[_glb_idx]
|
481 |
+
img = curr_dict['img']
|
482 |
+
lb = curr_dict['lb']
|
483 |
+
|
484 |
+
if self.use_3_slices:
|
485 |
+
prev_image = np.zeros_like(img)
|
486 |
+
if _glb_idx > 1 and not curr_dict["is_start"]:
|
487 |
+
prev_dict = self.actual_dataset[_glb_idx - 1]
|
488 |
+
prev_image = prev_dict["img"]
|
489 |
+
|
490 |
+
next_image = np.zeros_like(img)
|
491 |
+
if _glb_idx < len(self.actual_dataset) - 1 and not curr_dict["is_end"]:
|
492 |
+
next_dict = self.actual_dataset[_glb_idx + 1]
|
493 |
+
next_image = next_dict["img"]
|
494 |
+
|
495 |
+
img = np.concatenate([prev_image, img, next_image], axis=-1)
|
496 |
+
|
497 |
+
img = np.float32(img)
|
498 |
+
lb = np.float32(lb).squeeze(-1) # NOTE: to be suitable for the PANet structure
|
499 |
+
|
500 |
+
img = torch.from_numpy( np.transpose(img, (2, 0, 1)) )
|
501 |
+
lb = torch.from_numpy( lb )
|
502 |
+
|
503 |
+
if self.tile_z_dim:
|
504 |
+
img = img.repeat( [ self.tile_z_dim, 1, 1] )
|
505 |
+
assert img.ndimension() == 3, f'actual dim {img.ndimension()}'
|
506 |
+
|
507 |
+
is_start = curr_dict["is_start"]
|
508 |
+
is_end = curr_dict["is_end"]
|
509 |
+
nframe = np.int32(curr_dict["nframe"])
|
510 |
+
scan_id = curr_dict["scan_id"]
|
511 |
+
z_id = curr_dict["z_id"]
|
512 |
+
|
513 |
+
sample = {"image": img,
|
514 |
+
"label":lb,
|
515 |
+
"is_start": is_start,
|
516 |
+
"inst": None,
|
517 |
+
"scribble": None,
|
518 |
+
"is_end": is_end,
|
519 |
+
"nframe": nframe,
|
520 |
+
"scan_id": scan_id,
|
521 |
+
"z_id": z_id
|
522 |
+
}
|
523 |
+
|
524 |
+
concat_buffer.append(sample)
|
525 |
+
out_buffer.append({
|
526 |
+
"image": torch.stack([itm["image"] for itm in concat_buffer], dim = 0),
|
527 |
+
"label": torch.stack([itm["label"] for itm in concat_buffer], dim = 0),
|
528 |
+
|
529 |
+
})
|
530 |
+
|
531 |
+
# do the concat, and add to output_buffer
|
532 |
+
|
533 |
+
# post-processing, including keeping the foreground and suppressing background.
|
534 |
+
support_images = []
|
535 |
+
support_mask = []
|
536 |
+
support_class = []
|
537 |
+
for itm in out_buffer:
|
538 |
+
support_images.append(itm["image"])
|
539 |
+
support_class.append(curr_class)
|
540 |
+
support_mask.append( self.getMaskMedImg( itm["label"], curr_class, class_idx ))
|
541 |
+
|
542 |
+
return {'class_ids': [support_class],
|
543 |
+
'support_images': [support_images], #
|
544 |
+
'support_mask': [support_mask],
|
545 |
+
}
|
546 |
+
|
547 |
+
def get_support_scan(self, curr_class: int, class_idx: list, scan_idx: list):
|
548 |
+
self.potential_support_sid = [self.pid_curr_load[ii] for ii in scan_idx ]
|
549 |
+
# print(f'###### Using {len(scan_idx)} shot evaluation!')
|
550 |
+
scan_slices = self.scan_z_idx[self.potential_support_sid[0]]
|
551 |
+
scan_imgs = np.concatenate([self.actual_dataset[_idx]["img"] for _idx in scan_slices], axis = -1).transpose(2, 0, 1)
|
552 |
+
|
553 |
+
scan_lbs = np.concatenate([self.actual_dataset[_idx]["lb"] for _idx in scan_slices], axis = -1).transpose(2, 0, 1)
|
554 |
+
# binarize the labels
|
555 |
+
scan_lbs[scan_lbs != curr_class] = 0
|
556 |
+
scan_lbs[scan_lbs == curr_class] = 1
|
557 |
+
|
558 |
+
scan_imgs = torch.from_numpy(np.float32(scan_imgs)).unsqueeze(0)
|
559 |
+
scan_lbs = torch.from_numpy(np.float32(scan_lbs))
|
560 |
+
|
561 |
+
if self.tile_z_dim:
|
562 |
+
scan_imgs = scan_imgs.repeat(self.tile_z_dim, 1, 1, 1)
|
563 |
+
assert scan_imgs.ndimension() == 4, f'actual dim {scan_imgs.ndimension()}'
|
564 |
+
|
565 |
+
# reshape to C, D, H, W
|
566 |
+
sample = {"scan": scan_imgs,
|
567 |
+
"labels":scan_lbs,
|
568 |
+
}
|
569 |
+
|
570 |
+
return sample
|
571 |
+
|
572 |
+
|
573 |
+
def get_support_multiple_classes(self, classes: list, scan_idx: list, npart: int, use_3_slices=False):
|
574 |
+
"""
|
575 |
+
getting (probably multi-shot) support set for evaluation
|
576 |
+
sample from 50% (1shot) or 20 35 50 65 80 (5shot)
|
577 |
+
Args:
|
578 |
+
classes: a list of the foreground classes to segment, starting from 1
|
579 |
+
use_3_slices: if True, stack each support slice with its two adjacent slices
|
580 |
+
npart: how many chunks are used to split the support
|
581 |
+
scan_idx: a list, indicating the current **i_th** (note this is idx not pid) training scan
|
582 |
+
being served as support, in self.pid_curr_load
|
583 |
+
"""
|
584 |
+
assert npart % 2 == 1
|
585 |
+
# assert curr_class != 0; assert 0 not in class_idx
|
586 |
+
# assert not self.is_train
|
587 |
+
|
588 |
+
self.potential_support_sid = [self.pid_curr_load[ii] for ii in scan_idx ]
|
589 |
+
# print(f'###### Using {len(scan_idx)} shot evaluation!')
|
590 |
+
|
591 |
+
if npart == 1:
|
592 |
+
pcts = [0.5]
|
593 |
+
else:
|
594 |
+
half_part = 1 / (npart * 2)
|
595 |
+
part_interval = (1.0 - 1.0 / npart) / (npart - 1)
|
596 |
+
pcts = [ half_part + part_interval * ii for ii in range(npart) ]
|
597 |
+
|
598 |
+
# print(f'###### Parts percentage: {pcts} ######')
|
599 |
+
|
600 |
+
out_buffer = [] # [{scanid, img, lb}]
|
601 |
+
for _part in range(npart):
|
602 |
+
concat_buffer = [] # for each fold do a concat in image and mask in batch dimension
|
603 |
+
for scan_order in scan_idx:
|
604 |
+
_scan_id = self.pid_curr_load[ scan_order ]
|
605 |
+
print(f'Using scan {_scan_id} as support!')
|
606 |
+
|
607 |
+
# for _pc in pcts:
|
608 |
+
zlist = []
|
609 |
+
for curr_class in classes:
|
610 |
+
zlist.append(self.tp1_cls_map[self.label_name[curr_class]][_scan_id]) # list of indices
|
611 |
+
# merge all the lists in zlist and keep only the unique elements
|
612 |
+
# _zlist = sorted(list(set([item for sublist in zlist for item in sublist])))
|
613 |
+
# take only the indices that appear in all of the sublist
|
614 |
+
_zlist = sorted(list(set.intersection(*map(set, zlist))))
|
615 |
+
_zid = _zlist[int(pcts[_part] * len(_zlist))]
|
616 |
+
_glb_idx = self.scan_z_idx[_scan_id][_zid]
|
617 |
+
|
618 |
+
# almost copy-paste __getitem__ but no augmentation
|
619 |
+
curr_dict = self.actual_dataset[_glb_idx]
|
620 |
+
img = curr_dict['img']
|
621 |
+
lb = curr_dict['lb']
|
622 |
+
|
623 |
+
if use_3_slices:
|
624 |
+
prev_image = np.zeros_like(img)
|
625 |
+
if _glb_idx > 1 and not curr_dict["is_start"]:
|
626 |
+
prev_dict = self.actual_dataset[_glb_idx - 1]
|
627 |
+
assert prev_dict["scan_id"] == curr_dict["scan_id"]
|
628 |
+
assert prev_dict["z_id"] == curr_dict["z_id"] - 1
|
629 |
+
prev_image = prev_dict["img"]
|
630 |
+
|
631 |
+
next_image = np.zeros_like(img)
|
632 |
+
if _glb_idx < len(self.actual_dataset) - 1 and not curr_dict["is_end"]:
|
633 |
+
next_dict = self.actual_dataset[_glb_idx + 1]
|
634 |
+
assert next_dict["scan_id"] == curr_dict["scan_id"]
|
635 |
+
assert next_dict["z_id"] == curr_dict["z_id"] + 1
|
636 |
+
next_image = next_dict["img"]
|
637 |
+
|
638 |
+
img = np.concatenate([prev_image, img, next_image], axis=-1)
|
639 |
+
|
640 |
+
img = np.float32(img)
|
641 |
+
lb = np.float32(lb).squeeze(-1) # NOTE: to be suitable for the PANet structure
|
642 |
+
# zero all labels that are not in the classes arg
|
643 |
+
mask = np.zeros_like(lb)
|
644 |
+
for cls in classes:
|
645 |
+
mask[lb == cls] = 1
|
646 |
+
lb[~mask.astype(bool)] = 0
|
647 |
+
|
648 |
+
img = torch.from_numpy( np.transpose(img, (2, 0, 1)) )
|
649 |
+
lb = torch.from_numpy( lb )
|
650 |
+
|
651 |
+
if self.tile_z_dim:
|
652 |
+
img = img.repeat( [ self.tile_z_dim, 1, 1] )
|
653 |
+
assert img.ndimension() == 3, f'actual dim {img.ndimension()}'
|
654 |
+
|
655 |
+
is_start = curr_dict["is_start"]
|
656 |
+
is_end = curr_dict["is_end"]
|
657 |
+
nframe = np.int32(curr_dict["nframe"])
|
658 |
+
scan_id = curr_dict["scan_id"]
|
659 |
+
z_id = curr_dict["z_id"]
|
660 |
+
|
661 |
+
sample = {"image": img,
|
662 |
+
"label":lb,
|
663 |
+
"is_start": is_start,
|
664 |
+
"inst": None,
|
665 |
+
"scribble": None,
|
666 |
+
"is_end": is_end,
|
667 |
+
"nframe": nframe,
|
668 |
+
"scan_id": scan_id,
|
669 |
+
"z_id": z_id
|
670 |
+
}
|
671 |
+
|
672 |
+
concat_buffer.append(sample)
|
673 |
+
out_buffer.append({
|
674 |
+
"image": torch.stack([itm["image"] for itm in concat_buffer], dim = 0),
|
675 |
+
"label": torch.stack([itm["label"] for itm in concat_buffer], dim = 0),
|
676 |
+
|
677 |
+
})
|
678 |
+
|
679 |
+
# do the concat, and add to output_buffer
|
680 |
+
|
681 |
+
# post-processing, including keeping the foreground and suppressing background.
|
682 |
+
support_images = []
|
683 |
+
support_mask = []
|
684 |
+
support_class = []
|
685 |
+
for itm in out_buffer:
|
686 |
+
support_images.append(itm["image"])
|
687 |
+
support_class.append(curr_class)
|
688 |
+
# support_mask.append( self.getMaskMedImg( itm["label"], curr_class, class_idx ))
|
689 |
+
support_mask.append(itm["label"])
|
690 |
+
|
691 |
+
return {'class_ids': [support_class],
|
692 |
+
'support_images': [support_images], #
|
693 |
+
'support_mask': [support_mask],
|
694 |
+
'scan_id': scan_id
|
695 |
+
}
|
696 |
+
|
697 |
+
def get_nii_dataset(config, image_size, **kwargs):
|
698 |
+
print(f"Check config: {config}")
|
699 |
+
organ_mapping = {
|
700 |
+
"sabs":{
|
701 |
+
"rk": 2,
|
702 |
+
"lk": 3,
|
703 |
+
"liver": 6,
|
704 |
+
"spleen": 1
|
705 |
+
},
|
706 |
+
"chaost2":{
|
707 |
+
"liver": 1,
|
708 |
+
"rk": 2,
|
709 |
+
"lk": 3,
|
710 |
+
"spleen": 4
|
711 |
+
}}
|
712 |
+
|
713 |
+
transforms = None
|
714 |
+
data_name = config['dataset']
|
715 |
+
if data_name == 'SABS_Superpix' or data_name == 'SABS_Superpix_448' or data_name == 'SABS_Superpix_672':
|
716 |
+
baseset_name = 'SABS'
|
717 |
+
max_label = 13
|
718 |
+
modality="CT"
|
719 |
+
elif data_name == 'C0_Superpix':
|
720 |
+
raise NotImplementedError
|
721 |
+
baseset_name = 'C0'
|
722 |
+
max_label = 3
|
723 |
+
elif data_name == 'CHAOST2_Superpix' or data_name == 'CHAOST2_Superpix_672':
|
724 |
+
baseset_name = 'CHAOST2'
|
725 |
+
max_label = 4
|
726 |
+
modality="MR"
|
727 |
+
elif 'lits' in data_name.lower():
|
728 |
+
baseset_name = 'LITS17'
|
729 |
+
max_label = 4
|
730 |
+
else:
|
731 |
+
raise ValueError(f'Dataset: {data_name} not found')
|
732 |
+
|
733 |
+
# norm_func = get_normalize_op(modality=modality, fids=None) # TODO add global statistics
|
734 |
+
# norm_func = None
|
735 |
+
|
736 |
+
test_label = organ_mapping[baseset_name.lower()][config["curr_cls"]]
|
737 |
+
base_dir = config['path'][data_name]['data_dir']
|
738 |
+
testdataset = ManualAnnoDataset(which_dataset=baseset_name,
|
739 |
+
base_dir=base_dir,
|
740 |
+
idx_split = config['eval_fold'],
|
741 |
+
mode = 'val',
|
742 |
+
scan_per_load = 1,
|
743 |
+
transforms=transforms,
|
744 |
+
min_fg=1,
|
745 |
+
nsup = config["task"]["n_shots"],
|
746 |
+
fix_length=None,
|
747 |
+
image_size=image_size,
|
748 |
+
# extern_normalize_func=norm_func
|
749 |
+
**kwargs)
|
750 |
+
|
751 |
+
testdataset = ValidationDataset(testdataset, test_classes = [test_label], npart = config["task"]["npart"])
|
752 |
+
testdataset.set_curr_cls(test_label)
|
753 |
+
|
754 |
+
traindataset = None # TODO make this the support set later
|
755 |
+
|
756 |
+
return traindataset, testdataset
|
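Below is a minimal usage sketch (not part of the uploaded file) of the get_nii_dataset helper defined above. The config keys mirror the ones the function reads; the data directory, fold index and shot count are hypothetical placeholders, and the returned test set is the ValidationDataset wrapper built inside the function.

config = {
    "dataset": "CHAOST2_Superpix",                                  # picks baseset_name 'CHAOST2' (MR)
    "curr_cls": "liver",                                            # mapped to a label id via organ_mapping
    "eval_fold": 0,                                                 # hypothetical evaluation fold
    "path": {"CHAOST2_Superpix": {"data_dir": "./data/CHAOST2"}},   # hypothetical data directory
    "task": {"n_shots": 1, "npart": 3},
}
_, test_ds = get_nii_dataset(config, image_size=256, use_3_slices=True)
sample = test_ds[0]   # item structure comes from the ValidationDataset wrapper defined elsewhere in this repo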
dataloaders/PolypDataset.py
ADDED
@@ -0,0 +1,548 @@
1 |
+
"""
|
2 |
+
Copied from https://github.com/talshaharabany/AutoSAM
|
3 |
+
"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
from PIL import Image
|
7 |
+
import torch.utils.data as data
|
8 |
+
import torchvision.transforms as transforms
|
9 |
+
import numpy as np
|
10 |
+
import random
|
11 |
+
import torch
|
12 |
+
from dataloaders.PolypTransforms import get_polyp_transform
|
13 |
+
import cv2
|
14 |
+
KVASIR = "Kvasir"
|
15 |
+
CLINIC_DB = "CVC-ClinicDB"
|
16 |
+
COLON_DB = "CVC-ColonDB"
|
17 |
+
ETIS_DB = "ETIS-LaribPolypDB"
|
18 |
+
CVC300 = "CVC-300"
|
19 |
+
|
20 |
+
DATASETS = (KVASIR, CLINIC_DB, COLON_DB, ETIS_DB)
|
21 |
+
EXCLUDE_DS = (CVC300, )
|
22 |
+
|
23 |
+
|
24 |
+
def create_suppport_set_for_polyps(n_support=10):
|
25 |
+
"""
|
26 |
+
create a text file containing n_support images for each dataset
|
27 |
+
"""
|
28 |
+
root_dir = "/disk4/Lev/Projects/Self-supervised-Fewshot-Medical-Image-Segmentation/data/PolypDataset/TrainDataset"
|
29 |
+
supp_images = []
|
30 |
+
supp_masks = []
|
31 |
+
|
32 |
+
image_dir = os.path.join(root_dir, "images")
|
33 |
+
mask_dir = os.path.join(root_dir, "masks")
|
34 |
+
# randomly sample n_support images and masks
|
35 |
+
image_paths = sorted([os.path.join(image_dir, f) for f in os.listdir(
|
36 |
+
image_dir) if f.endswith('.jpg') or f.endswith('.png')])
|
37 |
+
mask_paths = sorted([os.path.join(mask_dir, f) for f in os.listdir(
|
38 |
+
mask_dir) if f.endswith('.png')])
|
39 |
+
|
40 |
+
while len(supp_images) < n_support:
|
41 |
+
index = random.randint(0, len(image_paths) - 1)
|
42 |
+
# check that the index is not already in the support set
|
43 |
+
if image_paths[index] in supp_images:
|
44 |
+
continue
|
45 |
+
supp_images.append(image_paths[index])
|
46 |
+
supp_masks.append(mask_paths[index])
|
47 |
+
|
48 |
+
with open(os.path.join(root_dir, "support.txt"), 'w') as file:
|
49 |
+
for image_path, mask_path in zip(supp_images, supp_masks):
|
50 |
+
file.write(f"{image_path} {mask_path}\n")
|
51 |
+
|
52 |
+
def create_train_val_test_split_for_polyps():
|
53 |
+
root_dir = "/disk4/Lev/Projects/Self-supervised-Fewshot-Medical-Image-Segmentation/data/PolypDataset/"
|
54 |
+
# for each subdir in root_dir, create a split file
|
55 |
+
num_train_images_per_dataset = {
|
56 |
+
"CVC-ClinicDB": 548, "Kvasir": 900, "CVC-300": 0, "CVC-ColonDB": 0}
|
57 |
+
|
58 |
+
num_test_images_per_dataset = {
|
59 |
+
"CVC-ClinicDB": 64, "Kvasir": 100, "CVC-300": 60, "CVC-ColonDB": 380}
|
60 |
+
|
61 |
+
for subdir in os.listdir(root_dir):
|
62 |
+
subdir_path = os.path.join(root_dir, subdir)
|
63 |
+
if os.path.isdir(subdir_path):
|
64 |
+
split_file = os.path.join(subdir_path, "split.txt")
|
65 |
+
image_dir = os.path.join(subdir_path, "images")
|
66 |
+
create_train_val_test_split(
|
67 |
+
image_dir, split_file, train_number=num_train_images_per_dataset[subdir], test_number=num_test_images_per_dataset[subdir])
|
68 |
+
|
69 |
+
|
70 |
+
def create_train_val_test_split(root, split_file, train_number=100, test_number=20):
|
71 |
+
"""
|
72 |
+
Create a train, val, test split file for a dataset
|
73 |
+
root: root directory of dataset
|
74 |
+
split_file: name of split file to create
|
75 |
+
train_number: number of images in the train split
|
76 |
+
test_number: number of images in the test split
|
77 |
+
the remaining images are assigned to the val split
|
78 |
+
"""
|
79 |
+
# Get all files in root directory
|
80 |
+
files = os.listdir(root)
|
81 |
+
# Filter out non-image files, remove suffix
|
82 |
+
files = [f.split('.')[0]
|
83 |
+
for f in files if f.endswith('.jpg') or f.endswith('.png')]
|
84 |
+
# Shuffle files
|
85 |
+
random.shuffle(files)
|
86 |
+
|
87 |
+
# Calculate number of files for each split
|
88 |
+
num_files = len(files)
|
89 |
+
num_train = train_number
|
90 |
+
num_test = test_number
|
91 |
+
num_val = num_files - num_train - num_test
|
92 |
+
print(f"num_train: {num_train}, num_val: {num_val}, num_test: {num_test}")
|
93 |
+
# Create splits
|
94 |
+
train = files[:num_train]
|
95 |
+
val = files[num_train:num_train + num_val]
|
96 |
+
test = files[num_train + num_val:]
|
97 |
+
|
98 |
+
# Write splits to file
|
99 |
+
with open(split_file, 'w') as file:
|
100 |
+
file.write("train\n")
|
101 |
+
for f in train:
|
102 |
+
file.write(f + "\n")
|
103 |
+
file.write("val\n")
|
104 |
+
for f in val:
|
105 |
+
file.write(f + "\n")
|
106 |
+
file.write("test\n")
|
107 |
+
for f in test:
|
108 |
+
file.write(f + "\n")
|
109 |
+
|
110 |
+
|
111 |
+
class PolypDataset(data.Dataset):
|
112 |
+
"""
|
113 |
+
dataloader for polyp segmentation tasks
|
114 |
+
"""
|
115 |
+
|
116 |
+
def __init__(self, root, image_root=None, gt_root=None, trainsize=352, augmentations=None, train=True, sam_trans=None, datasets=DATASETS, image_size=(1024, 1024), ds_mean=None, ds_std=None):
|
117 |
+
self.trainsize = trainsize
|
118 |
+
self.augmentations = augmentations
|
119 |
+
self.datasets = datasets
|
120 |
+
if isinstance(image_size, int):
|
121 |
+
image_size = (image_size, image_size)
|
122 |
+
self.image_size = image_size
|
123 |
+
if image_root is not None and gt_root is not None:
|
124 |
+
self.images = [
|
125 |
+
os.path.join(image_root, f) for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')]
|
126 |
+
self.gts = [
|
127 |
+
os.path.join(gt_root, f) for f in os.listdir(gt_root) if f.endswith('.png')]
|
128 |
+
# also look in subdirectories
|
129 |
+
for subdir in os.listdir(image_root):
|
130 |
+
# if not dir, continue
|
131 |
+
if not os.path.isdir(os.path.join(image_root, subdir)):
|
132 |
+
continue
|
133 |
+
subdir_image_root = os.path.join(image_root, subdir)
|
134 |
+
subdir_gt_root = os.path.join(gt_root, subdir)
|
135 |
+
self.images.extend([os.path.join(subdir_image_root, f) for f in os.listdir(
|
136 |
+
subdir_image_root) if f.endswith('.jpg') or f.endswith('.png')])
|
137 |
+
self.gts.extend([os.path.join(subdir_gt_root, f) for f in os.listdir(
|
138 |
+
subdir_gt_root) if f.endswith('.png')])
|
139 |
+
|
140 |
+
else:
|
141 |
+
self.images, self.gts = self.get_image_gt_pairs(
|
142 |
+
root, split="train" if train else "test", datasets=self.datasets)
|
143 |
+
self.images = sorted(self.images)
|
144 |
+
self.gts = sorted(self.gts)
|
145 |
+
if 'VPS' not in root:
|
146 |
+
self.filter_files_and_get_ds_mean_and_std()
|
147 |
+
if ds_mean is not None and ds_std is not None:
|
148 |
+
self.mean, self.std = ds_mean, ds_std
|
149 |
+
self.size = len(self.images)
|
150 |
+
self.train = train
|
151 |
+
self.sam_trans = sam_trans
|
152 |
+
if self.sam_trans is not None:
|
153 |
+
# sam trans takes care of norm
|
154 |
+
self.mean, self.std = 0 , 1
|
155 |
+
|
156 |
+
def get_image_gt_pairs(self, dir_root: str, split="train", datasets: tuple = DATASETS):
|
157 |
+
"""
|
158 |
+
for each folder in dir root, get all image-gt pairs. Assumes each subdir has a split.txt file
|
159 |
+
dir_root: root directory of all subdirectories, each subdirectory contains images and masks folders
|
160 |
+
split: train, val, or test
|
161 |
+
"""
|
162 |
+
image_paths = []
|
163 |
+
gt_paths = []
|
164 |
+
for folder in os.listdir(dir_root):
|
165 |
+
if folder not in datasets:
|
166 |
+
continue
|
167 |
+
split_file = os.path.join(dir_root, folder, "split.txt")
|
168 |
+
if os.path.isfile(split_file):
|
169 |
+
image_root = os.path.join(dir_root, folder, "images")
|
170 |
+
gt_root = os.path.join(dir_root, folder, "masks")
|
171 |
+
image_paths_tmp, gt_paths_tmp = self.get_image_gt_pairs_from_text_file(
|
172 |
+
image_root, gt_root, split_file, split=split)
|
173 |
+
image_paths.extend(image_paths_tmp)
|
174 |
+
gt_paths.extend(gt_paths_tmp)
|
175 |
+
else:
|
176 |
+
print(
|
177 |
+
f"No split.txt file found in {os.path.join(dir_root, folder)}")
|
178 |
+
|
179 |
+
return image_paths, gt_paths
|
180 |
+
|
181 |
+
def get_image_gt_pairs_from_text_file(self, image_root: str, gt_root: str, text_file: str, split: str = "train"):
|
182 |
+
"""
|
183 |
+
image_root: root directory of images
|
184 |
+
gt_root: root directory of ground truth
|
185 |
+
text_file: text file containing train, val, test split with the following format:
|
186 |
+
train:
|
187 |
+
image1
|
188 |
+
image2
|
189 |
+
...
|
190 |
+
val:
|
191 |
+
image1
|
192 |
+
image2
|
193 |
+
...
|
194 |
+
test:
|
195 |
+
image1
|
196 |
+
image2
|
197 |
+
...
|
198 |
+
|
199 |
+
split: train, val, or test
|
200 |
+
"""
|
201 |
+
# Initialize a dictionary to hold file names for each split
|
202 |
+
splits = {"train": [], "val": [], "test": []}
|
203 |
+
current_split = None
|
204 |
+
|
205 |
+
# Read the text file and categorize file names under each split
|
206 |
+
with open(text_file, 'r') as file:
|
207 |
+
for line in file:
|
208 |
+
line = line.strip()
|
209 |
+
if line in splits:
|
210 |
+
current_split = line
|
211 |
+
elif line and current_split:
|
212 |
+
splits[current_split].append(line)
|
213 |
+
|
214 |
+
# Get the file names for the requested split
|
215 |
+
file_names = splits[split]
|
216 |
+
|
217 |
+
# Create image-ground truth pairs
|
218 |
+
image_paths = []
|
219 |
+
gt_paths = []
|
220 |
+
for name in file_names:
|
221 |
+
image_path = os.path.join(image_root, name + '.png')
|
222 |
+
gt_path = os.path.join(gt_root, name + '.png')
|
223 |
+
image_paths.append(image_path)
|
224 |
+
gt_paths.append(gt_path)
|
225 |
+
|
226 |
+
return image_paths, gt_paths
|
227 |
+
|
228 |
+
def get_support_from_dirs(self, support_image_dir, support_mask_dir, n_support=1):
|
229 |
+
support_images = []
|
230 |
+
support_labels = []
|
231 |
+
# get all images and masks
|
232 |
+
support_image_paths = sorted([os.path.join(support_image_dir, f) for f in os.listdir(
|
233 |
+
support_image_dir) if f.endswith('.jpg') or f.endswith('.png')])
|
234 |
+
support_mask_paths = sorted([os.path.join(support_mask_dir, f) for f in os.listdir(
|
235 |
+
support_mask_dir) if f.endswith('.png')])
|
236 |
+
# sample n_support images and masks
|
237 |
+
for i in range(n_support):
|
238 |
+
index = random.randint(0, len(support_image_paths) - 1)
|
239 |
+
support_img = self.cv2_loader(
|
240 |
+
support_image_paths[index], is_mask=False)
|
241 |
+
support_mask = self.cv2_loader(
|
242 |
+
support_mask_paths[index], is_mask=True)
|
243 |
+
support_images.append(support_img)
|
244 |
+
support_labels.append(support_mask)
|
245 |
+
|
246 |
+
if self.augmentations:
|
247 |
+
support_images = [self.augmentations(
|
248 |
+
img, mask)[0] for img, mask in zip(support_images, support_labels)]
|
249 |
+
support_labels = [self.augmentations(
|
250 |
+
img, mask)[1] for img, mask in zip(support_images, support_labels)]
|
251 |
+
|
252 |
+
support_images = [(support_image - self.mean) / self.std if support_image.max() == 255 and support_image.min() == 0 else support_image for support_image in support_images]
|
253 |
+
|
254 |
+
if self.sam_trans is not None:
|
255 |
+
support_images = [self.sam_trans.preprocess(
|
256 |
+
img).squeeze(0) for img in support_images]
|
257 |
+
support_labels = [self.sam_trans.preprocess(
|
258 |
+
mask) for mask in support_labels]
|
259 |
+
else:
|
260 |
+
image_size = self.image_size
|
261 |
+
support_images = [torch.nn.functional.interpolate(img.unsqueeze(
|
262 |
+
0), size=image_size, mode='bilinear', align_corners=False).squeeze(0) for img in support_images]
|
263 |
+
support_labels = [torch.nn.functional.interpolate(mask.unsqueeze(0).unsqueeze(
|
264 |
+
0), size=image_size, mode='nearest').squeeze(0).squeeze(0) for mask in support_labels]
|
265 |
+
|
266 |
+
return torch.stack(support_images), torch.stack(support_labels)
|
267 |
+
|
268 |
+
def get_support_from_text_file(self, text_file, n_support=1):
|
269 |
+
"""
|
270 |
+
each row in the file has two paths separated by a space: the first is the image path and the second is the mask path
|
271 |
+
"""
|
272 |
+
support_images = []
|
273 |
+
support_labels = []
|
274 |
+
with open(text_file, 'r') as file:
|
275 |
+
for line in file:
|
276 |
+
image_path, mask_path = line.strip().split()
|
277 |
+
support_images.append(image_path)
|
278 |
+
support_labels.append(mask_path)
|
279 |
+
|
280 |
+
# indices = random.choices(range(len(support_images)), k=n_support)
|
281 |
+
if n_support > len(support_images):
|
282 |
+
raise ValueError(f"n_support ({n_support}) is larger than the number of images in the text file ({len(support_images)})")
|
283 |
+
|
284 |
+
n_support_images = support_images[:n_support]
|
285 |
+
n_support_labels = support_labels[:n_support]
|
286 |
+
|
287 |
+
return n_support_images, n_support_labels
|
288 |
+
|
289 |
+
def get_support(self, n_support=1, support_image_dir=None, support_mask_dir=None, text_file=None):
|
290 |
+
"""
|
291 |
+
Get support set from specified directories, text file or from the dataset itself
|
292 |
+
"""
|
293 |
+
if support_image_dir is not None and support_mask_dir:
|
294 |
+
return self.get_support_from_dirs(support_image_dir, support_mask_dir, n_support=n_support)
|
295 |
+
elif text_file is not None:
|
296 |
+
support_image_paths, support_gt_paths = self.get_support_from_text_file(text_file, n_support=n_support)
|
297 |
+
else:
|
298 |
+
# randomly sample n_support images and masks from the dataset
|
299 |
+
indices = random.choices(range(self.size), k=n_support)
|
300 |
+
# indices = list(range(n_support))
|
301 |
+
print(f"support indices:{indices}")
|
302 |
+
support_image_paths = [self.images[index] for index in indices]
|
303 |
+
support_gt_paths = [self.gts[index] for index in indices]
|
304 |
+
|
305 |
+
support_images = []
|
306 |
+
support_gts = []
|
307 |
+
|
308 |
+
for image_path, gt_path in zip(support_image_paths, support_gt_paths):
|
309 |
+
support_img = self.cv2_loader(image_path, is_mask=False)
|
310 |
+
support_mask = self.cv2_loader(gt_path, is_mask=True)
|
311 |
+
out = self.process_image_gt(support_img, support_mask)
|
312 |
+
support_images.append(out['image'].unsqueeze(0))
|
313 |
+
support_gts.append(out['label'].unsqueeze(0))
|
314 |
+
if len(support_images) >= n_support:
|
315 |
+
break
|
316 |
+
return support_images, support_gts, out['case']
|
317 |
+
# return torch.stack(support_images), torch.stack(support_gts), out['case']
|
318 |
+
|
319 |
+
def process_image_gt(self, image, gt, dataset=""):
|
320 |
+
"""
|
321 |
+
image and gt are expected to be output from self.cv2_loader
|
322 |
+
"""
|
323 |
+
original_size = tuple(image.shape[-2:])
|
324 |
+
if self.augmentations:
|
325 |
+
image, mask = self.augmentations(image, gt)
|
326 |
+
|
327 |
+
if self.sam_trans:
|
328 |
+
image, mask = self.sam_trans.apply_image_torch(
|
329 |
+
image.unsqueeze(0)), self.sam_trans.apply_image_torch(mask)
|
330 |
+
elif image.max() <= 255 and image.min() >= 0:
|
331 |
+
image = (image - self.mean) / self.std
|
332 |
+
mask[mask > 0.5] = 1
|
333 |
+
mask[mask <= 0.5] = 0
|
334 |
+
# image_size = tuple(img.shape[-2:])
|
335 |
+
|
336 |
+
image_size = self.image_size
|
337 |
+
if self.sam_trans is None:
|
338 |
+
image = torch.nn.functional.interpolate(image.unsqueeze(
|
339 |
+
0), size=image_size, mode='bilinear', align_corners=False).squeeze(0)
|
340 |
+
mask = torch.nn.functional.interpolate(mask.unsqueeze(0).unsqueeze(
|
341 |
+
0), size=image_size, mode='nearest').squeeze(0).squeeze(0)
|
342 |
+
# img = (img - img.min()) / (img.max() - img.min()) # TODO uncomment this if results get worse
|
343 |
+
|
344 |
+
return {'image': self.sam_trans.preprocess(image).squeeze(0) if self.sam_trans else image,
|
345 |
+
'label': self.sam_trans.preprocess(mask) if self.sam_trans else mask,
|
346 |
+
'original_size': torch.Tensor(original_size),
|
347 |
+
'image_size': torch.Tensor(image_size),
|
348 |
+
'case': dataset} # case to be compatible with polyp video dataset
|
349 |
+
|
350 |
+
def get_dataset_name_from_path(self, path):
|
351 |
+
for dataset in self.datasets:
|
352 |
+
if dataset in path:
|
353 |
+
return dataset
|
354 |
+
return ""
|
355 |
+
|
356 |
+
def __getitem__(self, index):
|
357 |
+
image = self.cv2_loader(self.images[index], is_mask=False)
|
358 |
+
gt = self.cv2_loader(self.gts[index], is_mask=True)
|
359 |
+
dataset = self.get_dataset_name_from_path(self.images[index])
|
360 |
+
return self.process_image_gt(image, gt, dataset)
|
361 |
+
|
362 |
+
def filter_files_and_get_ds_mean_and_std(self):
|
363 |
+
assert len(self.images) == len(self.gts)
|
364 |
+
images = []
|
365 |
+
gts = []
|
366 |
+
ds_mean = 0
|
367 |
+
ds_std = 0
|
368 |
+
for img_path, gt_path in zip(self.images, self.gts):
|
369 |
+
if any([ex_ds in img_path for ex_ds in EXCLUDE_DS]):
|
370 |
+
continue
|
371 |
+
img = Image.open(img_path)
|
372 |
+
gt = Image.open(gt_path)
|
373 |
+
if img.size == gt.size:
|
374 |
+
images.append(img_path)
|
375 |
+
gts.append(gt_path)
|
376 |
+
ds_mean += np.array(img).mean()
|
377 |
+
ds_std += np.array(img).std()
|
378 |
+
self.images = images
|
379 |
+
self.gts = gts
|
380 |
+
self.mean = ds_mean / len(self.images)
|
381 |
+
self.std = ds_std / len(self.images)
|
382 |
+
|
383 |
+
def rgb_loader(self, path):
|
384 |
+
with open(path, 'rb') as f:
|
385 |
+
img = Image.open(f)
|
386 |
+
return img.convert('RGB')
|
387 |
+
|
388 |
+
def binary_loader(self, path):
|
389 |
+
# with open(path, 'rb') as f:
|
390 |
+
# img = Image.open(f)
|
391 |
+
# return img.convert('1')
|
392 |
+
img = cv2.imread(path, 0)
|
393 |
+
return img
|
394 |
+
|
395 |
+
def cv2_loader(self, path, is_mask):
|
396 |
+
if is_mask:
|
397 |
+
img = cv2.imread(path, 0)
|
398 |
+
img[img > 0] = 1
|
399 |
+
else:
|
400 |
+
img = cv2.cvtColor(cv2.imread(
|
401 |
+
path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
|
402 |
+
return img
|
403 |
+
|
404 |
+
def resize(self, img, gt):
|
405 |
+
assert img.size == gt.size
|
406 |
+
w, h = img.size
|
407 |
+
if h < self.trainsize or w < self.trainsize:
|
408 |
+
h = max(h, self.trainsize)
|
409 |
+
w = max(w, self.trainsize)
|
410 |
+
return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST)
|
411 |
+
else:
|
412 |
+
return img, gt
|
413 |
+
|
414 |
+
def __len__(self):
|
415 |
+
# return 32
|
416 |
+
return self.size
|
417 |
+
|
418 |
+
|
419 |
+
class SuperpixPolypDataset(PolypDataset):
|
420 |
+
def __init__(self, root, image_root=None, gt_root=None, trainsize=352, augmentations=None, train=True, sam_trans=None, datasets=DATASETS, image_size=(1024, 1024), ds_mean=None, ds_std=None):
|
421 |
+
self.trainsize = trainsize
|
422 |
+
self.augmentations = augmentations
|
423 |
+
self.datasets = datasets
|
424 |
+
self.image_size = image_size
|
425 |
+
# print(self.augmentations)
|
426 |
+
if image_root is not None and gt_root is not None:
|
427 |
+
self.images = [
|
428 |
+
os.path.join(image_root, f) for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')]
|
429 |
+
self.gts = [
|
430 |
+
os.path.join(gt_root, f) for f in os.listdir(gt_root) if f.endswith('.png') and 'superpix' in f]
|
431 |
+
# also look in subdirectories
|
432 |
+
for subdir in os.listdir(image_root):
|
433 |
+
# if not dir, continue
|
434 |
+
if not os.path.isdir(os.path.join(image_root, subdir)):
|
435 |
+
continue
|
436 |
+
subdir_image_root = os.path.join(image_root, subdir)
|
437 |
+
subdir_gt_root = os.path.join(gt_root, subdir)
|
438 |
+
self.images.extend([os.path.join(subdir_image_root, f) for f in os.listdir(
|
439 |
+
subdir_image_root) if f.endswith('.jpg') or f.endswith('.png')])
|
440 |
+
self.gts.extend([os.path.join(subdir_gt_root, f) for f in os.listdir(
|
441 |
+
subdir_gt_root) if f.endswith('.png')])
|
442 |
+
|
443 |
+
else:
|
444 |
+
self.images, self.gts = self.get_image_gt_pairs(
|
445 |
+
root, split="train" if train else "test", datasets=self.datasets)
|
446 |
+
self.images = sorted(self.images)
|
447 |
+
self.gts = sorted(self.gts)
|
448 |
+
if 'VPS' not in root:
|
449 |
+
self.filter_files_and_get_ds_mean_and_std()
|
450 |
+
if ds_mean is not None and ds_std is not None:
|
451 |
+
self.mean, self.std = ds_mean, ds_std
|
452 |
+
self.size = len(self.images)
|
453 |
+
self.train = train
|
454 |
+
self.sam_trans = sam_trans
|
455 |
+
if self.sam_trans is not None:
|
456 |
+
# sam trans takes care of norm
|
457 |
+
self.mean, self.std = 0 , 1
|
458 |
+
|
459 |
+
|
460 |
+
def __getitem__(self, index):
|
461 |
+
image = self.cv2_loader(self.images[index], is_mask=False)
|
462 |
+
gt = self.cv2_loader(self.gts[index], is_mask=False)
|
463 |
+
gt = gt[:, :, 0]
|
464 |
+
fgpath = os.path.basename(self.gts[index]).split('.png')[0].split('superpix-MIDDLE_')
|
465 |
+
fgpath = os.path.join(os.path.dirname(self.gts[index]), 'fgmask_' + fgpath[1] + '.png')
|
466 |
+
fg = self.cv2_loader(fgpath, is_mask=True)
|
467 |
+
dataset = self.get_dataset_name_from_path(self.images[index])
|
468 |
+
|
469 |
+
# randomly choose a superpixel from the gt (restricted to the foreground mask)
|
470 |
+
gt[fg == 0] = 0
|
471 |
+
sp_id = random.choice(np.unique(gt)[1:])
|
472 |
+
sp = (gt == sp_id).astype(np.uint8)
|
473 |
+
|
474 |
+
|
475 |
+
out = self.process_image_gt(image, gt, dataset)
|
476 |
+
support_image, support_sp, dataset = out["image"], out["label"], out["case"]
|
477 |
+
|
478 |
+
out = self.process_image_gt(image, sp, dataset)
|
479 |
+
query_image, query_sp, dataset = out["image"], out["label"], out["case"]
|
480 |
+
|
481 |
+
# TODO tile the masks to have 3 channels?
|
482 |
+
|
483 |
+
support_bg_mask = 1 - support_sp
|
484 |
+
support_masks = {"fg_mask": support_sp, "bg_mask": support_bg_mask}
|
485 |
+
|
486 |
+
batch = {"support_images" : [[support_image]],
|
487 |
+
"support_mask" : [[support_masks]],
|
488 |
+
"query_images" : [query_image],
|
489 |
+
"query_labels" : [query_sp],
|
490 |
+
"scan_id" : [dataset]
|
491 |
+
}
|
492 |
+
|
493 |
+
return batch
|
494 |
+
|
495 |
+
|
496 |
+
def get_superpix_polyp_dataset(image_size:tuple=(1024,1024), sam_trans=None):
|
497 |
+
transform_train, transform_test = get_polyp_transform()
|
498 |
+
image_root = './data/PolypDataset/TrainDataset/images/'
|
499 |
+
gt_root = './data/PolypDataset/TrainDataset/superpixels/'
|
500 |
+
ds_train = SuperpixPolypDataset(root=image_root, image_root=image_root, gt_root=gt_root,
|
501 |
+
augmentations=transform_train,
|
502 |
+
sam_trans=sam_trans,
|
503 |
+
image_size=image_size)
|
504 |
+
|
505 |
+
return ds_train
|
506 |
+
|
507 |
+
def get_polyp_dataset(image_size, sam_trans=None):
|
508 |
+
transform_train, transform_test = get_polyp_transform()
|
509 |
+
image_root = './data/PolypDataset/TrainDataset/images/'
|
510 |
+
gt_root = './data/PolypDataset/TrainDataset/masks/'
|
511 |
+
ds_train = PolypDataset(root=image_root, image_root=image_root, gt_root=gt_root,
|
512 |
+
augmentations=transform_test, sam_trans=sam_trans, train=True, image_size=image_size)
|
513 |
+
image_root = './data/PolypDataset/TestDataset/test/images/'
|
514 |
+
gt_root = './data/PolypDataset/TestDataset/test/masks/'
|
515 |
+
ds_test = PolypDataset(root=image_root, image_root=image_root, gt_root=gt_root, train=False,
|
516 |
+
augmentations=transform_test, sam_trans=sam_trans, image_size=image_size)
|
517 |
+
return ds_train, ds_test
|
518 |
+
|
519 |
+
|
520 |
+
def get_tests_polyp_dataset(sam_trans):
|
521 |
+
transform_train, transform_test = get_polyp_transform()
|
522 |
+
|
523 |
+
image_root = './data/polyp/TestDataset/Kvasir/images/'
|
524 |
+
gt_root = './data/polyp/TestDataset/Kvasir/masks/'
|
525 |
+
ds_Kvasir = PolypDataset(
|
526 |
+
image_root, gt_root, augmentations=transform_test, train=False, sam_trans=sam_trans)
|
527 |
+
|
528 |
+
image_root = './data/polyp/TestDataset/CVC-ClinicDB/images/'
|
529 |
+
gt_root = './data/polyp/TestDataset/CVC-ClinicDB/masks/'
|
530 |
+
ds_ClinicDB = PolypDataset(
|
531 |
+
image_root, gt_root, augmentations=transform_test, train=False, sam_trans=sam_trans)
|
532 |
+
|
533 |
+
image_root = './data/polyp/TestDataset/CVC-ColonDB/images/'
|
534 |
+
gt_root = './data/polyp/TestDataset/CVC-ColonDB/masks/'
|
535 |
+
ds_ColonDB = PolypDataset(
|
536 |
+
image_root, gt_root, augmentations=transform_test, train=False, sam_trans=sam_trans)
|
537 |
+
|
538 |
+
image_root = './data/polyp/TestDataset/ETIS-LaribPolypDB/images/'
|
539 |
+
gt_root = './data/polyp/TestDataset/ETIS-LaribPolypDB/masks/'
|
540 |
+
ds_ETIS = PolypDataset(
|
541 |
+
image_root, gt_root, augmentations=transform_test, train=False, sam_trans=sam_trans)
|
542 |
+
|
543 |
+
return ds_Kvasir, ds_ClinicDB, ds_ColonDB, ds_ETIS
|
544 |
+
|
545 |
+
|
546 |
+
if __name__ == '__main__':
|
547 |
+
# create_train_val_test_split_for_polyps()
|
548 |
+
create_suppport_set_for_polyps()
|
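Below is a minimal usage sketch (not part of the uploaded file) of the polyp dataset helpers defined above, assuming the ./data/PolypDataset layout that get_polyp_dataset points at; the image size, batch size and shot count are illustrative only.

from torch.utils.data import DataLoader

ds_train, ds_test = get_polyp_dataset(image_size=(1024, 1024), sam_trans=None)
support_images, support_gts, case = ds_train.get_support(n_support=1)   # 1-shot support drawn from the train split
loader = DataLoader(ds_test, batch_size=1, shuffle=False)
batch = next(iter(loader))
image, label = batch["image"], batch["label"]   # keys produced by process_image_gt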
dataloaders/PolypTransforms.py
ADDED
@@ -0,0 +1,626 @@
1 |
+
from __future__ import division
|
2 |
+
import torch
|
3 |
+
import math
|
4 |
+
import sys
|
5 |
+
import random
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
try:
|
9 |
+
import accimage
|
10 |
+
except ImportError:
|
11 |
+
accimage = None
|
12 |
+
import numpy as np
|
13 |
+
import numbers
|
14 |
+
import types
|
15 |
+
import collections
|
16 |
+
import warnings
|
17 |
+
|
18 |
+
from torchvision.transforms import functional as F
|
19 |
+
|
20 |
+
if sys.version_info < (3, 3):
|
21 |
+
Sequence = collections.Sequence
|
22 |
+
Iterable = collections.Iterable
|
23 |
+
else:
|
24 |
+
Sequence = collections.abc.Sequence
|
25 |
+
Iterable = collections.abc.Iterable
|
26 |
+
|
27 |
+
__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "CenterCrop", "Pad",
|
28 |
+
"Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
|
29 |
+
"RandomVerticalFlip", "RandomResizedCrop", "FiveCrop", "TenCrop",
|
30 |
+
"ColorJitter", "RandomRotation", "RandomAffine",
|
31 |
+
"RandomPerspective"]
|
32 |
+
|
33 |
+
_pil_interpolation_to_str = {
|
34 |
+
Image.NEAREST: 'PIL.Image.NEAREST',
|
35 |
+
Image.BILINEAR: 'PIL.Image.BILINEAR',
|
36 |
+
Image.BICUBIC: 'PIL.Image.BICUBIC',
|
37 |
+
Image.LANCZOS: 'PIL.Image.LANCZOS',
|
38 |
+
Image.HAMMING: 'PIL.Image.HAMMING',
|
39 |
+
Image.BOX: 'PIL.Image.BOX',
|
40 |
+
}
|
41 |
+
|
42 |
+
|
43 |
+
class Compose(object):
|
44 |
+
def __init__(self, transforms):
|
45 |
+
self.transforms = transforms
|
46 |
+
|
47 |
+
def __call__(self, img, mask):
|
48 |
+
for t in self.transforms:
|
49 |
+
img, mask = t(img, mask)
|
50 |
+
return img, mask
|
51 |
+
|
52 |
+
|
53 |
+
class ToTensor(object):
|
54 |
+
def __call__(self, img, mask):
|
55 |
+
# return F.to_tensor(img), F.to_tensor(mask)
|
56 |
+
img = np.array(img)
|
57 |
+
img = torch.from_numpy(img).permute(2, 0, 1).float() # TODO add division by 255 to match torch.ToTensor()?
|
58 |
+
mask = torch.from_numpy(np.array(mask)).float()
|
59 |
+
return img, mask
|
60 |
+
|
61 |
+
|
62 |
+
class ToPILImage(object):
|
63 |
+
def __init__(self, mode=None):
|
64 |
+
self.mode = mode
|
65 |
+
|
66 |
+
def __call__(self, img, mask):
|
67 |
+
return F.to_pil_image(img, self.mode), F.to_pil_image(mask, self.mode)
|
68 |
+
|
69 |
+
|
70 |
+
class Normalize(object):
|
71 |
+
def __init__(self, mean, std, inplace=False):
|
72 |
+
self.mean = mean
|
73 |
+
self.std = std
|
74 |
+
self.inplace = inplace
|
75 |
+
|
76 |
+
def __call__(self, img, mask):
|
77 |
+
return F.normalize(img, self.mean, self.std, self.inplace), mask
|
78 |
+
|
79 |
+
|
80 |
+
class Resize(object):
|
81 |
+
def __init__(self, size, interpolation=Image.BILINEAR, do_mask=True):
|
82 |
+
assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
|
83 |
+
self.size = size
|
84 |
+
self.interpolation = interpolation
|
85 |
+
self.do_mask = do_mask
|
86 |
+
|
87 |
+
def __call__(self, img, mask):
|
88 |
+
if self.do_mask:
|
89 |
+
return F.resize(img, self.size, Image.BICUBIC), F.resize(mask, self.size, Image.BICUBIC)
|
90 |
+
else:
|
91 |
+
return F.resize(img, self.size, self.interpolation), mask
|
92 |
+
|
93 |
+
|
94 |
+
class CenterCrop(object):
|
95 |
+
def __init__(self, size):
|
96 |
+
if isinstance(size, numbers.Number):
|
97 |
+
self.size = (int(size), int(size))
|
98 |
+
else:
|
99 |
+
self.size = size
|
100 |
+
|
101 |
+
def __call__(self, img, mask):
|
102 |
+
return F.center_crop(img, self.size), F.center_crop(mask, self.size)
|
103 |
+
|
104 |
+
|
105 |
+
class Pad(object):
|
106 |
+
def __init__(self, padding, fill=0, padding_mode='constant'):
|
107 |
+
assert isinstance(padding, (numbers.Number, tuple))
|
108 |
+
assert isinstance(fill, (numbers.Number, str, tuple))
|
109 |
+
assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
|
110 |
+
if isinstance(padding, Sequence) and len(padding) not in [2, 4]:
|
111 |
+
raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
|
112 |
+
"{} element tuple".format(len(padding)))
|
113 |
+
|
114 |
+
self.padding = padding
|
115 |
+
self.fill = fill
|
116 |
+
self.padding_mode = padding_mode
|
117 |
+
|
118 |
+
def __call__(self, img, mask):
|
119 |
+
return F.pad(img, self.padding, self.fill, self.padding_mode), \
|
120 |
+
F.pad(mask, self.padding, self.fill, self.padding_mode)
|
121 |
+
|
122 |
+
|
123 |
+
class Lambda(object):
|
124 |
+
def __init__(self, lambd):
|
125 |
+
assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
|
126 |
+
self.lambd = lambd
|
127 |
+
|
128 |
+
def __call__(self, img, mask):
|
129 |
+
return self.lambd(img), self.lambd(mask)
|
130 |
+
|
131 |
+
|
132 |
+
class Lambda_image(object):
|
133 |
+
def __init__(self, lambd):
|
134 |
+
assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
|
135 |
+
self.lambd = lambd
|
136 |
+
|
137 |
+
def __call__(self, img, mask):
|
138 |
+
return self.lambd(img), mask
|
139 |
+
|
140 |
+
|
141 |
+
class RandomTransforms(object):
|
142 |
+
def __init__(self, transforms):
|
143 |
+
assert isinstance(transforms, (list, tuple))
|
144 |
+
self.transforms = transforms
|
145 |
+
|
146 |
+
def __call__(self, *args, **kwargs):
|
147 |
+
raise NotImplementedError()
|
148 |
+
|
149 |
+
|
150 |
+
class RandomApply(RandomTransforms):
|
151 |
+
def __init__(self, transforms, p=0.5):
|
152 |
+
super(RandomApply, self).__init__(transforms)
|
153 |
+
self.p = p
|
154 |
+
|
155 |
+
def __call__(self, img, mask):
|
156 |
+
if self.p < random.random():
|
157 |
+
return img, mask
|
158 |
+
for t in self.transforms:
|
159 |
+
img, mask = t(img, mask)
|
160 |
+
return img, mask
|
161 |
+
|
162 |
+
|
163 |
+
class RandomOrder(RandomTransforms):
|
164 |
+
def __call__(self, img, mask):
|
165 |
+
order = list(range(len(self.transforms)))
|
166 |
+
random.shuffle(order)
|
167 |
+
for i in order:
|
168 |
+
img, mask = self.transforms[i](img, mask)
|
169 |
+
return img, mask
|
170 |
+
|
171 |
+
|
172 |
+
class RandomChoice(RandomTransforms):
|
173 |
+
def __call__(self, img, mask):
|
174 |
+
t = random.choice(self.transforms)
|
175 |
+
return t(img, mask)
|
176 |
+
|
177 |
+
|
178 |
+
class RandomCrop(object):
|
179 |
+
def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'):
|
180 |
+
if isinstance(size, numbers.Number):
|
181 |
+
self.size = (int(size), int(size))
|
182 |
+
else:
|
183 |
+
self.size = size
|
184 |
+
self.padding = padding
|
185 |
+
self.pad_if_needed = pad_if_needed
|
186 |
+
self.fill = fill
|
187 |
+
self.padding_mode = padding_mode
|
188 |
+
|
189 |
+
@staticmethod
|
190 |
+
def get_params(img, output_size):
|
191 |
+
w, h = img.size
|
192 |
+
th, tw = output_size
|
193 |
+
if w == tw and h == th:
|
194 |
+
return 0, 0, h, w
|
195 |
+
|
196 |
+
i = random.randint(0, h - th)
|
197 |
+
j = random.randint(0, w - tw)
|
198 |
+
return i, j, th, tw
|
199 |
+
|
200 |
+
def __call__(self, img, mask):
|
201 |
+
if self.padding is not None:
|
202 |
+
img = F.pad(img, self.padding, self.fill, self.padding_mode)
|
203 |
+
|
204 |
+
# pad the width if needed
|
205 |
+
if self.pad_if_needed and img.size[0] < self.size[1]:
|
206 |
+
img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode)
|
207 |
+
# pad the height if needed
|
208 |
+
if self.pad_if_needed and img.size[1] < self.size[0]:
|
209 |
+
img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode)
|
210 |
+
|
211 |
+
i, j, h, w = self.get_params(img, self.size)
|
212 |
+
|
213 |
+
return F.crop(img, i, j, h, w), F.crop(mask, i, j, h, w)
|
214 |
+
|
215 |
+
|
216 |
+
class RandomHorizontalFlip(object):
|
217 |
+
def __init__(self, p=0.5):
|
218 |
+
self.p = p
|
219 |
+
|
220 |
+
def __call__(self, img, mask):
|
221 |
+
if random.random() < self.p:
|
222 |
+
return F.hflip(img), F.hflip(mask)
|
223 |
+
return img, mask
|
224 |
+
|
225 |
+
|
226 |
+
class RandomVerticalFlip(object):
|
227 |
+
def __init__(self, p=0.5):
|
228 |
+
self.p = p
|
229 |
+
|
230 |
+
def __call__(self, img, mask):
|
231 |
+
if random.random() < self.p:
|
232 |
+
return F.vflip(img), F.vflip(mask)
|
233 |
+
return img, mask
|
234 |
+
|
235 |
+
|
236 |
+
class RandomPerspective(object):
|
237 |
+
def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BICUBIC):
|
238 |
+
self.p = p
|
239 |
+
self.interpolation = interpolation
|
240 |
+
self.distortion_scale = distortion_scale
|
241 |
+
|
242 |
+
def __call__(self, img, mask):
|
243 |
+
if not F._is_pil_image(img):
|
244 |
+
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
|
245 |
+
|
246 |
+
if random.random() < self.p:
|
247 |
+
width, height = img.size
|
248 |
+
startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
|
249 |
+
return F.perspective(img, startpoints, endpoints, self.interpolation), \
|
250 |
+
F.perspective(mask, startpoints, endpoints, Image.NEAREST)
|
251 |
+
return img, mask
|
252 |
+
|
253 |
+
@staticmethod
|
254 |
+
def get_params(width, height, distortion_scale):
|
255 |
+
half_height = int(height / 2)
|
256 |
+
half_width = int(width / 2)
|
257 |
+
topleft = (random.randint(0, int(distortion_scale * half_width)),
|
258 |
+
random.randint(0, int(distortion_scale * half_height)))
|
259 |
+
topright = (random.randint(width - int(distortion_scale * half_width) - 1, width - 1),
|
260 |
+
random.randint(0, int(distortion_scale * half_height)))
|
261 |
+
botright = (random.randint(width - int(distortion_scale * half_width) - 1, width - 1),
|
262 |
+
random.randint(height - int(distortion_scale * half_height) - 1, height - 1))
|
263 |
+
botleft = (random.randint(0, int(distortion_scale * half_width)),
|
264 |
+
random.randint(height - int(distortion_scale * half_height) - 1, height - 1))
|
265 |
+
startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1), (0, height - 1)]
|
266 |
+
endpoints = [topleft, topright, botright, botleft]
|
267 |
+
return startpoints, endpoints
|
268 |
+
|
269 |
+
|
270 |
+
class RandomResizedCrop(object):
|
271 |
+
def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR):
|
272 |
+
if isinstance(size, tuple):
|
273 |
+
self.size = size
|
274 |
+
else:
|
275 |
+
self.size = (size, size)
|
276 |
+
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
|
277 |
+
warnings.warn("range should be of kind (min, max)")
|
278 |
+
|
279 |
+
self.interpolation = interpolation
|
280 |
+
self.scale = scale
|
281 |
+
self.ratio = ratio
|
282 |
+
|
283 |
+
@staticmethod
|
284 |
+
def get_params(img, scale, ratio):
|
285 |
+
area = img.size[0] * img.size[1]
|
286 |
+
|
287 |
+
for attempt in range(10):
|
288 |
+
target_area = random.uniform(*scale) * area
|
289 |
+
log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
|
290 |
+
aspect_ratio = math.exp(random.uniform(*log_ratio))
|
291 |
+
|
292 |
+
w = int(round(math.sqrt(target_area * aspect_ratio)))
|
293 |
+
h = int(round(math.sqrt(target_area / aspect_ratio)))
|
294 |
+
|
295 |
+
if w <= img.size[0] and h <= img.size[1]:
|
296 |
+
i = random.randint(0, img.size[1] - h)
|
297 |
+
j = random.randint(0, img.size[0] - w)
|
298 |
+
return i, j, h, w
|
299 |
+
|
300 |
+
# Fallback to central crop
|
301 |
+
in_ratio = img.size[0] / img.size[1]
|
302 |
+
if (in_ratio < min(ratio)):
|
303 |
+
w = img.size[0]
|
304 |
+
h = w / min(ratio)
|
305 |
+
elif (in_ratio > max(ratio)):
|
306 |
+
h = img.size[1]
|
307 |
+
w = h * max(ratio)
|
308 |
+
else: # whole image
|
309 |
+
w = img.size[0]
|
310 |
+
h = img.size[1]
|
311 |
+
i = (img.size[1] - h) // 2
|
312 |
+
j = (img.size[0] - w) // 2
|
313 |
+
return i, j, h, w
|
314 |
+
|
315 |
+
def __call__(self, img, mask):
|
316 |
+
i, j, h, w = self.get_params(img, self.scale, self.ratio)
|
317 |
+
return F.resized_crop(img, i, j, h, w, self.size, self.interpolation), \
|
318 |
+
F.resized_crop(mask, i, j, h, w, self.size, Image.NEAREST)
|
319 |
+
|
320 |
+
|
321 |
+
class FiveCrop(object):
|
322 |
+
def __init__(self, size):
|
323 |
+
self.size = size
|
324 |
+
if isinstance(size, numbers.Number):
|
325 |
+
self.size = (int(size), int(size))
|
326 |
+
else:
|
327 |
+
assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
|
328 |
+
self.size = size
|
329 |
+
|
330 |
+
def __call__(self, img, mask):
|
331 |
+
return F.five_crop(img, self.size), F.five_crop(mask, self.size)
|
332 |
+
|
333 |
+
|
334 |
+
class TenCrop(object):
|
335 |
+
def __init__(self, size, vertical_flip=False):
|
336 |
+
self.size = size
|
337 |
+
if isinstance(size, numbers.Number):
|
338 |
+
self.size = (int(size), int(size))
|
339 |
+
else:
|
340 |
+
assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
|
341 |
+
self.size = size
|
342 |
+
self.vertical_flip = vertical_flip
|
343 |
+
|
344 |
+
def __call__(self, img, mask):
|
345 |
+
return F.ten_crop(img, self.size, self.vertical_flip), F.ten_crop(mask, self.size, self.vertical_flip)
|
346 |
+
|
347 |
+
|
348 |
+
class ColorJitter(object):
|
349 |
+
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
|
350 |
+
self.brightness = self._check_input(brightness, 'brightness')
|
351 |
+
self.contrast = self._check_input(contrast, 'contrast')
|
352 |
+
self.saturation = self._check_input(saturation, 'saturation')
|
353 |
+
self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
|
354 |
+
clip_first_on_zero=False)
|
355 |
+
|
356 |
+
def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
|
357 |
+
if isinstance(value, numbers.Number):
|
358 |
+
if value < 0:
|
359 |
+
raise ValueError("If {} is a single number, it must be non negative.".format(name))
|
360 |
+
value = [center - value, center + value]
|
361 |
+
if clip_first_on_zero:
|
362 |
+
value[0] = max(value[0], 0)
|
363 |
+
elif isinstance(value, (tuple, list)) and len(value) == 2:
|
364 |
+
if not bound[0] <= value[0] <= value[1] <= bound[1]:
|
365 |
+
raise ValueError("{} values should be between {}".format(name, bound))
|
366 |
+
else:
|
367 |
+
raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name))
|
368 |
+
|
369 |
+
# if value is 0 or (1., 1.) for brightness/contrast/saturation
|
370 |
+
# or (0., 0.) for hue, do nothing
|
371 |
+
if value[0] == value[1] == center:
|
372 |
+
value = None
|
373 |
+
return value
|
374 |
+
|
375 |
+
@staticmethod
|
376 |
+
def get_params(brightness, contrast, saturation, hue):
|
377 |
+
transforms = []
|
378 |
+
|
379 |
+
if brightness is not None:
|
380 |
+
brightness_factor = random.uniform(brightness[0], brightness[1])
|
381 |
+
transforms.append(Lambda_image(lambda img: F.adjust_brightness(img, brightness_factor)))
|
382 |
+
|
383 |
+
if contrast is not None:
|
384 |
+
contrast_factor = random.uniform(contrast[0], contrast[1])
|
385 |
+
transforms.append(Lambda_image(lambda img: F.adjust_contrast(img, contrast_factor)))
|
386 |
+
|
387 |
+
if saturation is not None:
|
388 |
+
saturation_factor = random.uniform(saturation[0], saturation[1])
|
389 |
+
transforms.append(Lambda_image(lambda img: F.adjust_saturation(img, saturation_factor)))
|
390 |
+
|
391 |
+
if hue is not None:
|
392 |
+
hue_factor = random.uniform(hue[0], hue[1])
|
393 |
+
transforms.append(Lambda_image(lambda img: F.adjust_hue(img, hue_factor)))
|
394 |
+
|
395 |
+
random.shuffle(transforms)
|
396 |
+
transform = Compose(transforms)
|
397 |
+
|
398 |
+
return transform
|
399 |
+
|
400 |
+
def __call__(self, img, mask):
|
401 |
+
transform = self.get_params(self.brightness, self.contrast,
|
402 |
+
self.saturation, self.hue)
|
403 |
+
return transform(img, mask)
|
404 |
+
|
405 |
+
|
406 |
+
class RandomRotation(object):
|
407 |
+
def __init__(self, degrees, resample=False, expand=False, center=None):
|
408 |
+
if isinstance(degrees, numbers.Number):
|
409 |
+
if degrees < 0:
|
410 |
+
raise ValueError("If degrees is a single number, it must be positive.")
|
411 |
+
self.degrees = (-degrees, degrees)
|
412 |
+
else:
|
413 |
+
if len(degrees) != 2:
|
414 |
+
raise ValueError("If degrees is a sequence, it must be of len 2.")
|
415 |
+
self.degrees = degrees
|
416 |
+
|
417 |
+
self.resample = resample
|
418 |
+
self.expand = expand
|
419 |
+
self.center = center
|
420 |
+
|
421 |
+
@staticmethod
|
422 |
+
def get_params(degrees):
|
423 |
+
angle = random.uniform(degrees[0], degrees[1])
|
424 |
+
|
425 |
+
return angle
|
426 |
+
|
427 |
+
def __call__(self, img, mask):
|
428 |
+
angle = self.get_params(self.degrees)
|
429 |
+
|
430 |
+
return F.rotate(img, angle, Image.BILINEAR, self.expand, self.center), \
|
431 |
+
F.rotate(mask, angle, Image.NEAREST, self.expand, self.center)
|
432 |
+
|
433 |
+
|
434 |
+
class RandomAffine(object):
|
435 |
+
def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0):
|
436 |
+
if isinstance(degrees, numbers.Number):
|
437 |
+
if degrees < 0:
|
438 |
+
raise ValueError("If degrees is a single number, it must be positive.")
|
439 |
+
self.degrees = (-degrees, degrees)
|
440 |
+
else:
|
441 |
+
assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
|
442 |
+
"degrees should be a list or tuple and it must be of length 2."
|
443 |
+
self.degrees = degrees
|
444 |
+
|
445 |
+
if translate is not None:
|
446 |
+
assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
|
447 |
+
"translate should be a list or tuple and it must be of length 2."
|
448 |
+
for t in translate:
|
449 |
+
if not (0.0 <= t <= 1.0):
|
450 |
+
raise ValueError("translation values should be between 0 and 1")
|
451 |
+
self.translate = translate
|
452 |
+
|
453 |
+
if scale is not None:
|
454 |
+
assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
|
455 |
+
"scale should be a list or tuple and it must be of length 2."
|
456 |
+
for s in scale:
|
457 |
+
if s <= 0:
|
458 |
+
raise ValueError("scale values should be positive")
|
459 |
+
self.scale = scale
|
460 |
+
|
461 |
+
if shear is not None:
|
462 |
+
if isinstance(shear, numbers.Number):
|
463 |
+
if shear < 0:
|
464 |
+
raise ValueError("If shear is a single number, it must be positive.")
|
465 |
+
self.shear = (-shear, shear)
|
466 |
+
else:
|
467 |
+
assert isinstance(shear, (tuple, list)) and len(shear) == 2, \
|
468 |
+
"shear should be a list or tuple and it must be of length 2."
|
469 |
+
self.shear = shear
|
470 |
+
else:
|
471 |
+
self.shear = shear
|
472 |
+
|
473 |
+
self.resample = resample
|
474 |
+
self.fillcolor = fillcolor
|
475 |
+
|
476 |
+
@staticmethod
|
477 |
+
def get_params(degrees, translate, scale_ranges, shears, img_size):
|
478 |
+
angle = random.uniform(degrees[0], degrees[1])
|
479 |
+
if translate is not None:
|
480 |
+
max_dx = translate[0] * img_size[0]
|
481 |
+
max_dy = translate[1] * img_size[1]
|
482 |
+
translations = (np.round(random.uniform(-max_dx, max_dx)),
|
483 |
+
np.round(random.uniform(-max_dy, max_dy)))
|
484 |
+
else:
|
485 |
+
translations = (0, 0)
|
486 |
+
|
487 |
+
if scale_ranges is not None:
|
488 |
+
scale = random.uniform(scale_ranges[0], scale_ranges[1])
|
489 |
+
else:
|
490 |
+
scale = 1.0
|
491 |
+
|
492 |
+
if shears is not None:
|
493 |
+
shear = random.uniform(shears[0], shears[1])
|
494 |
+
else:
|
495 |
+
shear = 0.0
|
496 |
+
|
497 |
+
return angle, translations, scale, shear
|
498 |
+
|
499 |
+
def __call__(self, img, mask):
|
500 |
+
ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size)
|
501 |
+
return F.affine(img, *ret, interpolation=Image.BILINEAR, fill=self.fillcolor), \
|
502 |
+
F.affine(mask, *ret, interpolation=Image.NEAREST, fill=self.fillcolor)
|
503 |
+
|
504 |
+
|
505 |
+
|
506 |
+
def get_cub_transform():
|
507 |
+
transform_train = Compose([
|
508 |
+
ToPILImage(),
|
509 |
+
Resize((256, 256)),
|
510 |
+
RandomHorizontalFlip(),
|
511 |
+
RandomAffine(22, scale=(0.75, 1.25)),
|
512 |
+
ToTensor(),
|
513 |
+
Normalize(mean=[255*0.485, 255*0.456, 255*0.406], std=[255*0.229, 255*0.224, 255*0.225])
|
514 |
+
])
|
515 |
+
transform_test = Compose([
|
516 |
+
ToPILImage(),
|
517 |
+
Resize((256, 256)),
|
518 |
+
ToTensor(),
|
519 |
+
Normalize(mean=[255*0.485, 255*0.456, 255*0.406], std=[255*0.229, 255*0.224, 255*0.225])
|
520 |
+
])
|
521 |
+
return transform_train, transform_test
|
522 |
+
|
523 |
+
|
524 |
+
def get_glas_transform():
|
525 |
+
transform_train = Compose([
|
526 |
+
ToPILImage(),
|
527 |
+
# Resize((256, 256)),
|
528 |
+
ColorJitter(brightness=0.2,
|
529 |
+
contrast=0.2,
|
530 |
+
saturation=0.2,
|
531 |
+
hue=0.1),
|
532 |
+
RandomHorizontalFlip(),
|
533 |
+
RandomAffine(5, scale=(0.75, 1.25)),
|
534 |
+
ToTensor(),
|
535 |
+
# Normalize(mean=[255*0.485, 255*0.456, 255*0.406], std=[255*0.229, 255*0.224, 255*0.225])
|
536 |
+
])
|
537 |
+
transform_test = Compose([
|
538 |
+
ToPILImage(),
|
539 |
+
# Resize((256, 256)),
|
540 |
+
ToTensor(),
|
541 |
+
# Normalize(mean=[255*0.485, 255*0.456, 255*0.406], std=[255*0.229, 255*0.224, 255*0.225])
|
542 |
+
])
|
543 |
+
return transform_train, transform_test
|
544 |
+
|
545 |
+
# def get_glas_transform():
|
546 |
+
# transform_train = Compose([
|
547 |
+
# ToPILImage(),
|
548 |
+
# Resize((256, 256)),
|
549 |
+
# ColorJitter(brightness=0.2,
|
550 |
+
# contrast=0.2,
|
551 |
+
# saturation=0.2,
|
552 |
+
# hue=0.1),
|
553 |
+
# RandomHorizontalFlip(),
|
554 |
+
# RandomAffine(5, scale=(0.75, 1.25)),
|
555 |
+
# ToTensor(),
|
556 |
+
# Normalize(mean=[255*0.485, 255*0.456, 255*0.406], std=[255*0.229, 255*0.224, 255*0.225])
|
557 |
+
# ])
|
558 |
+
# transform_test = Compose([
|
559 |
+
# ToPILImage(),
|
560 |
+
# Resize((256, 256)),
|
561 |
+
# ToTensor(),
|
562 |
+
# Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375])
|
563 |
+
# ])
|
564 |
+
# return transform_train, transform_test
|
565 |
+
|
566 |
+
|
567 |
+
def get_monu_transform(args):
|
568 |
+
Idim = int(args['Idim'])
|
569 |
+
transform_train = Compose([
|
570 |
+
ToPILImage(),
|
571 |
+
# Resize((Idim, Idim)),
|
572 |
+
ColorJitter(brightness=0.4,
|
573 |
+
contrast=0.4,
|
574 |
+
saturation=0.4,
|
575 |
+
hue=0.1),
|
576 |
+
RandomHorizontalFlip(),
|
577 |
+
RandomAffine(int(args['rotate']), scale=(float(args['scale1']), float(args['scale2']))),
|
578 |
+
ToTensor(),
|
579 |
+
# Normalize(mean=[142.07, 98.48, 132.96], std=[65.78, 57.05, 57.78])
|
580 |
+
])
|
581 |
+
transform_test = Compose([
|
582 |
+
ToPILImage(),
|
583 |
+
# Resize((Idim, Idim)),
|
584 |
+
ToTensor(),
|
585 |
+
# Normalize(mean=[142.07, 98.48, 132.96], std=[65.78, 57.05, 57.78])
|
586 |
+
])
|
587 |
+
return transform_train, transform_test
|
588 |
+
|
589 |
+
|
590 |
+
def get_polyp_transform():
|
591 |
+
transform_train = Compose([
|
592 |
+
# Resize((352, 352)),
|
593 |
+
ToPILImage(),
|
594 |
+
ColorJitter(brightness=0.4,
|
595 |
+
contrast=0.4,
|
596 |
+
saturation=0.4,
|
597 |
+
hue=0.1),
|
598 |
+
RandomVerticalFlip(),
|
599 |
+
RandomHorizontalFlip(),
|
600 |
+
RandomAffine(90, scale=(0.75, 1.25)),
|
601 |
+
ToTensor(),
|
602 |
+
# Normalize([105.61, 63.69, 45.67],
|
603 |
+
# [83.08, 55.86, 42.59])
|
604 |
+
])
|
605 |
+
transform_test = Compose([
|
606 |
+
# Resize((352, 352)),
|
607 |
+
ToPILImage(),
|
608 |
+
ToTensor(),
|
609 |
+
# Normalize([105.61, 63.69, 45.67],
|
610 |
+
# [83.08, 55.86, 42.59])
|
611 |
+
])
|
612 |
+
return transform_train, transform_test
|
613 |
+
|
614 |
+
|
615 |
+
def get_polyp_support_train_transform():
|
616 |
+
transform_train = Compose([
|
617 |
+
ColorJitter(brightness=0.4,
|
618 |
+
contrast=0.4,
|
619 |
+
saturation=0.4,
|
620 |
+
hue=0.1),
|
621 |
+
RandomVerticalFlip(),
|
622 |
+
RandomHorizontalFlip(),
|
623 |
+
RandomAffine(90, scale=(0.75, 1.25)),
|
624 |
+
])
|
625 |
+
|
626 |
+
return transform_train
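To make the calling convention of these factory functions concrete, here is a small hypothetical usage of get_polyp_transform; the frame size and dtypes are assumptions for illustration.

    train_tf, test_tf = get_polyp_transform()
    image = np.random.randint(0, 256, (352, 352, 3), dtype=np.uint8)  # assumed H x W x 3 frame
    mask = (np.random.rand(352, 352) > 0.5).astype(np.uint8)          # assumed binary mask
    img_t, msk_t = train_tf(image, mask)  # jitter/flips/affine applied jointly, then ToTensor
    # img_t: 3 x 352 x 352 float tensor, msk_t: 352 x 352 float tensor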
|
dataloaders/SimpleDataset.py
ADDED
@@ -0,0 +1,61 @@
import torch
import numpy as np
import matplotlib.pyplot as plt

"""
Simple dataset: takes the images and masks as lists, together with a transform function that
should receive both the image and the mask.
`loops` controls how many times the dataset is repeated per epoch.
"""

class SimpleDataset(torch.utils.data.Dataset):
    def __init__(self, image_list, mask_list, transform=None, norm_func=None, loops=10, modality="", debug=False, image_size=None):
        self.image_list = image_list
        if image_size is not None:
            if len(image_size) == 1:
                image_size = (image_size[0], image_size[0])
            self.image_size = image_size
        else:
            self.image_size = image_list[0].shape[-2:]
        self.mask_list = mask_list
        self.transform = transform
        self.norm_func = norm_func
        self.loops = loops
        self.modality = modality
        self.debug = debug

    def __len__(self):
        return len(self.image_list) * self.loops

    def __getitem__(self, idx):
        idx = idx % (len(self.image_list))
        image = self.image_list[idx].numpy()
        mask = self.mask_list[idx].to(dtype=torch.uint8).numpy()
        if self.modality == "CT":
            image = image.astype(np.uint8)
            if self.transform:
                image, mask = self.transform(image, mask)
        else:
            # mask = np.repeat(mask[..., np.newaxis], 3, axis=-1)
            if self.transform:
                image, mask = self.transform(image, mask)

        if self.norm_func:
            image = self.norm_func(image)

        mask[mask != 0] = 1

        if self.image_size != image.shape[-2:]:
            image = torch.nn.functional.interpolate(torch.tensor(image).unsqueeze(0), self.image_size, mode='bilinear').squeeze(0)
            mask = torch.nn.functional.interpolate(torch.tensor(mask).unsqueeze(0).unsqueeze(0), self.image_size, mode='nearest').squeeze(0).squeeze(0)

        # plot image and mask for debugging
        if self.debug:
            fig = plt.figure()
            plt.imshow((image[0] - image.min()) / (image.max() - image.min()))
            plt.imshow(mask, alpha=0.5)
            plt.savefig("debug/support_image_mask.png")
            plt.close(fig)

        image_size = torch.tensor(tuple(image.shape[-2:]))  # currently unused
        return image, mask
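A hypothetical way to wire SimpleDataset to a standard DataLoader together with the paired polyp transforms; the toy tensors and sizes below are assumptions for illustration.

    import torch
    from torch.utils.data import DataLoader
    from dataloaders.SimpleDataset import SimpleDataset
    from dataloaders.PolypTransforms import get_polyp_transform

    # assumed toy data: H x W x 3 uint8 images and H x W binary masks, stored as tensors
    images = [torch.randint(0, 256, (256, 256, 3), dtype=torch.uint8) for _ in range(4)]
    masks = [torch.randint(0, 2, (256, 256), dtype=torch.uint8) for _ in range(4)]

    train_tf, _ = get_polyp_transform()
    ds = SimpleDataset(images, masks, transform=train_tf, loops=10, image_size=(256, 256))
    loader = DataLoader(ds, batch_size=2, shuffle=True)
    image_batch, mask_batch = next(iter(loader))   # [2, 3, 256, 256] and [2, 256, 256]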
dataloaders/__init__.py
ADDED
File without changes
|
dataloaders/augutils.py
ADDED
@@ -0,0 +1,224 @@
1 |
+
'''
|
2 |
+
Utilities for augmentation. Partly credit to Dr. Jo Schlemper
|
3 |
+
'''
|
4 |
+
from os.path import join
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
import torchvision.transforms as deftfx
|
9 |
+
import dataloaders.image_transforms as myit
|
10 |
+
import copy
|
11 |
+
from util.consts import IMG_SIZE
|
12 |
+
import time
|
13 |
+
import functools
|
14 |
+
|
15 |
+
|
16 |
+
def get_sabs_aug(input_size, use_3d=False):
|
17 |
+
sabs_aug = {
|
18 |
+
# turn flipping off as medical data has fixed orientations
|
19 |
+
'flip': {'v': False, 'h': False, 't': False, 'p': 0.25},
|
20 |
+
'affine': {
|
21 |
+
'rotate': 5,
|
22 |
+
'shift': (5, 5),
|
23 |
+
'shear': 5,
|
24 |
+
'scale': (0.9, 1.2),
|
25 |
+
},
|
26 |
+
'elastic': {'alpha': 10, 'sigma': 5},
|
27 |
+
'patch': input_size,
|
28 |
+
'reduce_2d': True,
|
29 |
+
'3d': use_3d,
|
30 |
+
'gamma_range': (0.5, 1.5)
|
31 |
+
}
|
32 |
+
return sabs_aug
|
33 |
+
|
34 |
+
|
35 |
+
def get_sabs_augv3(input_size):
|
36 |
+
sabs_augv3 = {
|
37 |
+
'flip': {'v': False, 'h': False, 't': False, 'p': 0.25},
|
38 |
+
'affine': {
|
39 |
+
'rotate': 30,
|
40 |
+
'shift': (30, 30),
|
41 |
+
'shear': 30,
|
42 |
+
'scale': (0.8, 1.3),
|
43 |
+
},
|
44 |
+
'elastic': {'alpha': 20, 'sigma': 5},
|
45 |
+
'patch': input_size,
|
46 |
+
'reduce_2d': True,
|
47 |
+
'gamma_range': (0.2, 1.8)
|
48 |
+
}
|
49 |
+
return sabs_augv3
|
50 |
+
|
51 |
+
|
52 |
+
def get_aug(which_aug, input_size):
|
53 |
+
if which_aug == 'sabs_aug':
|
54 |
+
return get_sabs_aug(input_size)
|
55 |
+
elif which_aug == 'aug_v3':
|
56 |
+
return get_sabs_augv3(input_size)
|
57 |
+
else:
|
58 |
+
raise NotImplementedError
|
59 |
+
|
60 |
+
# augs = {
|
61 |
+
# 'sabs_aug': get_sabs_aug,
|
62 |
+
# 'aug_v3': get_sabs_augv3, # more aggresive
|
63 |
+
# }
|
64 |
+
|
65 |
+
|
66 |
+
def get_geometric_transformer(aug, order=3):
|
67 |
+
"""order: interpolation degree. Select order=0 for augmenting segmentation """
|
68 |
+
affine = aug['aug'].get('affine', 0)
|
69 |
+
alpha = aug['aug'].get('elastic', {'alpha': 0})['alpha']
|
70 |
+
sigma = aug['aug'].get('elastic', {'sigma': 0})['sigma']
|
71 |
+
flip = aug['aug'].get(
|
72 |
+
'flip', {'v': True, 'h': True, 't': True, 'p': 0.125})
|
73 |
+
|
74 |
+
tfx = []
|
75 |
+
if 'flip' in aug['aug']:
|
76 |
+
tfx.append(myit.RandomFlip3D(**flip))
|
77 |
+
|
78 |
+
if 'affine' in aug['aug']:
|
79 |
+
tfx.append(myit.RandomAffine(affine.get('rotate'),
|
80 |
+
affine.get('shift'),
|
81 |
+
affine.get('shear'),
|
82 |
+
affine.get('scale'),
|
83 |
+
affine.get('scale_iso', True),
|
84 |
+
order=order))
|
85 |
+
|
86 |
+
if 'elastic' in aug['aug']:
|
87 |
+
tfx.append(myit.ElasticTransform(alpha, sigma))
|
88 |
+
input_transform = deftfx.Compose(tfx)
|
89 |
+
return input_transform
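Note that this function reads its options from aug['aug'], so a config produced by get_aug above has to be wrapped one level deeper before being passed in; a short sketch under that assumption:

    my_aug = {'aug': get_aug('sabs_aug', input_size=(256, 256))}  # extra 'aug' nesting level
    geo_tfx = get_geometric_transformer(my_aug, order=3)          # order=0 for label maps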
|
90 |
+
|
91 |
+
|
92 |
+
def get_geometric_transformer_3d(aug, order=3):
|
93 |
+
"""order: interpolation degree. Select order=0 for augmenting segmentation """
|
94 |
+
affine = aug['aug'].get('affine', 0)
|
95 |
+
alpha = aug['aug'].get('elastic', {'alpha': 0})['alpha']
|
96 |
+
sigma = aug['aug'].get('elastic', {'sigma': 0})['sigma']
|
97 |
+
flip = aug['aug'].get(
|
98 |
+
'flip', {'v': True, 'h': True, 't': True, 'p': 0.125})
|
99 |
+
|
100 |
+
tfx = []
|
101 |
+
if 'flip' in aug['aug']:
|
102 |
+
tfx.append(myit.RandomFlip3D(**flip))
|
103 |
+
|
104 |
+
if 'affine' in aug['aug']:
|
105 |
+
tfx.append(myit.RandomAffine(affine.get('rotate'),
|
106 |
+
affine.get('shift'),
|
107 |
+
affine.get('shear'),
|
108 |
+
affine.get('scale'),
|
109 |
+
affine.get('scale_iso', True),
|
110 |
+
order=order,
|
111 |
+
use_3d=True))
|
112 |
+
|
113 |
+
if 'elastic' in aug['aug']:
|
114 |
+
tfx.append(myit.ElasticTransform(alpha, sigma))
|
115 |
+
input_transform = deftfx.Compose(tfx)
|
116 |
+
return input_transform
|
117 |
+
|
118 |
+
|
119 |
+
def gamma_transform(img, aug):
|
120 |
+
gamma_range = aug['aug']['gamma_range']
|
121 |
+
if isinstance(gamma_range, tuple):
|
122 |
+
gamma = np.random.rand() * \
|
123 |
+
(gamma_range[1] - gamma_range[0]) + gamma_range[0]
|
124 |
+
cmin = img.min()
|
125 |
+
irange = (img.max() - cmin + 1e-5)
|
126 |
+
|
127 |
+
img = img - cmin + 1e-5
|
128 |
+
img = irange * np.power(img * 1.0 / irange, gamma)
|
129 |
+
img = img + cmin
|
130 |
+
|
131 |
+
elif gamma_range == False:
|
132 |
+
pass
|
133 |
+
else:
|
134 |
+
raise ValueError(
|
135 |
+
"Cannot identify gamma transform range {}".format(gamma_range))
|
136 |
+
return img
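A quick numerical sketch of gamma_transform's behaviour on an array (the nesting of the config mirrors the 'aug' key used above; the values are arbitrary):

    arr = np.linspace(0.0, 1.0, 5, dtype=np.float32)
    out = gamma_transform(arr, {'aug': {'gamma_range': (0.5, 1.5)}})
    # the minimum stays (approximately) fixed while mid-range values are pushed up or
    # down depending on the randomly drawn gamma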
|
137 |
+
|
138 |
+
|
139 |
+
def get_intensity_transformer(aug):
|
140 |
+
"""some basic intensity transforms"""
|
141 |
+
return functools.partial(gamma_transform, aug=aug)
|
142 |
+
|
143 |
+
|
144 |
+
def transform_with_label(aug):
|
145 |
+
"""
|
146 |
+
Doing image geometric transform
|
147 |
+
Proposed image to have the following configurations
|
148 |
+
[H x W x C + CL]
|
149 |
+
Where CL is the number of channels for the label. It is NOT in one-hot form
|
150 |
+
"""
|
151 |
+
|
152 |
+
geometric_tfx = get_geometric_transformer(aug)
|
153 |
+
intensity_tfx = get_intensity_transformer(aug)
|
154 |
+
|
155 |
+
def transform(comp, c_label, c_img, use_onehot, nclass, **kwargs):
|
156 |
+
"""
|
157 |
+
Args
|
158 |
+
comp: a numpy array with shape [H x W x C + c_label]
|
159 |
+
c_label: number of channels for a compact label. Note that the current version only supports 1 slice (H x W x 1)
|
160 |
+
nc_onehot: -1 for not using one-hot representation of mask. otherwise, specify number of classes in the label
|
161 |
+
|
162 |
+
"""
|
163 |
+
comp = copy.deepcopy(comp)
|
164 |
+
if (use_onehot is True) and (c_label != 1):
|
165 |
+
raise NotImplementedError(
|
166 |
+
"Only allow compact label, also the label can only be 2d")
|
167 |
+
assert c_img + 1 == comp.shape[-1], "only allow single slice 2D label"
|
168 |
+
|
169 |
+
# geometric transform
|
170 |
+
_label = comp[..., c_img]
|
171 |
+
_h_label = np.float32(np.arange(nclass) == (_label[..., None]))
|
172 |
+
# _h_label = np.float32(_label[..., None])
|
173 |
+
comp = np.concatenate([comp[..., :c_img], _h_label], -1)
|
174 |
+
comp = geometric_tfx(comp)
|
175 |
+
# round one_hot labels to 0 or 1
|
176 |
+
t_label_h = comp[..., c_img:]
|
177 |
+
t_label_h = np.rint(t_label_h)
|
178 |
+
assert t_label_h.max() <= 1
|
179 |
+
t_img = comp[..., 0: c_img]
|
180 |
+
|
181 |
+
# intensity transform
|
182 |
+
t_img = intensity_tfx(t_img)
|
183 |
+
|
184 |
+
if use_onehot is True:
|
185 |
+
t_label = t_label_h
|
186 |
+
else:
|
187 |
+
t_label = np.expand_dims(np.argmax(t_label_h, axis=-1), -1)
|
188 |
+
return t_img, t_label
|
189 |
+
|
190 |
+
return transform
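A hedged sketch of how the returned closure is meant to be called: image and compact label are stacked along the channel axis, and the nclass/c_img values below are assumptions for illustration.

    aug_cfg = {'aug': get_sabs_aug((256, 256))}    # same nesting assumption as above
    tr = transform_with_label(aug_cfg)

    img = np.random.rand(256, 256, 3).astype(np.float32)               # H x W x C image
    lbl = np.random.randint(0, 5, (256, 256, 1)).astype(np.float32)    # compact label
    comp = np.concatenate([img, lbl], axis=-1)                         # H x W x (C + 1)
    t_img, t_lbl = tr(comp, c_label=1, c_img=3, use_onehot=False, nclass=5)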
|
191 |
+
|
192 |
+
|
193 |
+
def transform(scan, label, nclass, geometric_tfx, intensity_tfx):
|
194 |
+
"""
|
195 |
+
Args
|
196 |
+
scan: a numpy array with shape [D x H x W x C]
|
197 |
+
label: a numpy array with shape [D x H x W x 1]
|
198 |
+
"""
|
199 |
+
assert len(scan.shape) == 4, "Input scan must be 4D"
|
200 |
+
if len(label.shape) == 3:
|
201 |
+
label = np.expand_dims(label, -1)
|
202 |
+
|
203 |
+
# geometric transform
|
204 |
+
comp = copy.deepcopy(np.concatenate(
|
205 |
+
[scan, label], -1)) # [D x H x W x C + 1]
|
206 |
+
_label = comp[..., -1]
|
207 |
+
_h_label = np.float32(np.arange(nclass) == (_label[..., None]))
|
208 |
+
comp = np.concatenate([comp[..., :-1], _h_label], -1)
|
209 |
+
# change comp to be H x W x D x C + 1
|
210 |
+
comp = np.transpose(comp, (1, 2, 0, 3))
|
211 |
+
comp = geometric_tfx(comp)
|
212 |
+
t_label_h = comp[..., 1:]
|
213 |
+
t_label_h = np.rint(t_label_h)
|
214 |
+
assert t_label_h.max() <= 1
|
215 |
+
t_img = comp[..., 0:1]
|
216 |
+
|
217 |
+
# intensity transform
|
218 |
+
t_img = intensity_tfx(t_img)
|
219 |
+
return t_img, t_label_h
|
220 |
+
|
221 |
+
|
222 |
+
def transform_wrapper(scan, label, nclass, geometric_tfx, intensity_tfx):
|
223 |
+
return transform(scan, label, nclass, geometric_tfx, intensity_tfx)
|
224 |
+
|
dataloaders/common.py
ADDED
@@ -0,0 +1,263 @@
1 |
+
"""
|
2 |
+
Dataset classes for common uses
|
3 |
+
Extended from vanilla PANet code by Wang et al.
|
4 |
+
"""
|
5 |
+
import random
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from torch.utils.data import Dataset
|
9 |
+
|
10 |
+
class BaseDataset(Dataset):
|
11 |
+
"""
|
12 |
+
Base Dataset
|
13 |
+
Args:
|
14 |
+
base_dir:
|
15 |
+
dataset directory
|
16 |
+
"""
|
17 |
+
def __init__(self, base_dir):
|
18 |
+
self._base_dir = base_dir
|
19 |
+
self.aux_attrib = {}
|
20 |
+
self.aux_attrib_args = {}
|
21 |
+
self.ids = [] # must be overloaded in subclass
|
22 |
+
|
23 |
+
def add_attrib(self, key, func, func_args):
|
24 |
+
"""
|
25 |
+
Add attribute to the data sample dict
|
26 |
+
|
27 |
+
Args:
|
28 |
+
key:
|
29 |
+
key in the data sample dict for the new attribute
|
30 |
+
e.g. sample['click_map'], sample['depth_map']
|
31 |
+
func:
|
32 |
+
function to process a data sample and create an attribute (e.g. user clicks)
|
33 |
+
func_args:
|
34 |
+
extra arguments to pass, expected a dict
|
35 |
+
"""
|
36 |
+
if key in self.aux_attrib:
|
37 |
+
raise KeyError("Attribute '{0}' already exists, please use 'set_attrib'.".format(key))
|
38 |
+
else:
|
39 |
+
self.set_attrib(key, func, func_args)
|
40 |
+
|
41 |
+
def set_attrib(self, key, func, func_args):
|
42 |
+
"""
|
43 |
+
Set attribute in the data sample dict
|
44 |
+
|
45 |
+
Args:
|
46 |
+
key:
|
47 |
+
key in the data sample dict for the new attribute
|
48 |
+
e.g. sample['click_map'], sample['depth_map']
|
49 |
+
func:
|
50 |
+
function to process a data sample and create an attribute (e.g. user clicks)
|
51 |
+
func_args:
|
52 |
+
extra arguments to pass, expected a dict
|
53 |
+
"""
|
54 |
+
self.aux_attrib[key] = func
|
55 |
+
self.aux_attrib_args[key] = func_args
|
56 |
+
|
57 |
+
def del_attrib(self, key):
|
58 |
+
"""
|
59 |
+
Remove attribute in the data sample dict
|
60 |
+
|
61 |
+
Args:
|
62 |
+
key:
|
63 |
+
key in the data sample dict
|
64 |
+
"""
|
65 |
+
self.aux_attrib.pop(key)
|
66 |
+
self.aux_attrib_args.pop(key)
|
67 |
+
|
68 |
+
def subsets(self, sub_ids, sub_args_lst=None):
|
69 |
+
"""
|
70 |
+
Create subsets by ids
|
71 |
+
|
72 |
+
Args:
|
73 |
+
sub_ids:
|
74 |
+
a sequence of sequences, each sequence contains data ids for one subset
|
75 |
+
sub_args_lst:
|
76 |
+
a list of args for some subset-specific auxiliary attribute function
|
77 |
+
"""
|
78 |
+
|
79 |
+
indices = [[self.ids.index(id_) for id_ in ids] for ids in sub_ids]
|
80 |
+
if sub_args_lst is not None:
|
81 |
+
subsets = [Subset(dataset=self, indices=index, sub_attrib_args=args)
|
82 |
+
for index, args in zip(indices, sub_args_lst)]
|
83 |
+
else:
|
84 |
+
subsets = [Subset(dataset=self, indices=index) for index in indices]
|
85 |
+
return subsets
|
86 |
+
|
87 |
+
def __len__(self):
|
88 |
+
pass
|
89 |
+
|
90 |
+
def __getitem__(self, idx):
|
91 |
+
pass
|
92 |
+
|
93 |
+
|
94 |
+
class ReloadPairedDataset(Dataset):
|
95 |
+
"""
|
96 |
+
Make pairs of data from dataset
|
97 |
+
Enable loading only part of the entire dataset in each epoch, then reloading the next part
|
98 |
+
Args:
|
99 |
+
datasets:
|
100 |
+
source datasets, expect a list of Dataset.
|
101 |
+
Each dataset indexes a certain class. It contains a list of all z-indices of this class for each scan
|
102 |
+
n_elements:
|
103 |
+
number of elements in a pair
|
104 |
+
curr_max_iters:
|
105 |
+
number of pairs in an epoch
|
106 |
+
pair_based_transforms:
|
107 |
+
some transformation performed on a pair basis, expect a list of functions,
|
108 |
+
each function takes a pair sample and return a transformed one.
|
109 |
+
"""
|
110 |
+
def __init__(self, datasets, n_elements, curr_max_iters,
|
111 |
+
pair_based_transforms=None):
|
112 |
+
super().__init__()
|
113 |
+
self.datasets = datasets
|
114 |
+
self.n_datasets = len(self.datasets)
|
115 |
+
self.n_data = [len(dataset) for dataset in self.datasets]
|
116 |
+
self.n_elements = n_elements
|
117 |
+
self.curr_max_iters = curr_max_iters
|
118 |
+
self.pair_based_transforms = pair_based_transforms
|
119 |
+
self.update_index()
|
120 |
+
|
121 |
+
def update_index(self):
|
122 |
+
"""
|
123 |
+
update the order of batches for the next episode
|
124 |
+
"""
|
125 |
+
|
126 |
+
# update number of elements for each subset
|
127 |
+
if hasattr(self, 'indices'):
|
128 |
+
n_data_old = self.n_data # DEBUG
|
129 |
+
self.n_data = [len(dataset) for dataset in self.datasets]
|
130 |
+
|
131 |
+
if isinstance(self.n_elements, list):
|
132 |
+
self.indices = [[(dataset_idx, data_idx) for i, dataset_idx in enumerate(random.sample(range(self.n_datasets), k=len(self.n_elements))) # select which way(s) to use
|
133 |
+
for data_idx in random.sample(range(self.n_data[dataset_idx]), k=self.n_elements[i])] # for each way, which sample to use
|
134 |
+
for i_iter in range(self.curr_max_iters)] # sample <self.curr_max_iters> iterations
|
135 |
+
|
136 |
+
elif self.n_elements > self.n_datasets:
|
137 |
+
raise ValueError("When 'same=False', 'n_element' should be no more than n_datasets")
|
138 |
+
else:
|
139 |
+
self.indices = [[(dataset_idx, random.randrange(self.n_data[dataset_idx]))
|
140 |
+
for dataset_idx in random.sample(range(self.n_datasets),
|
141 |
+
k=self.n_elements)]
|
142 |
+
for i in range(self.curr_max_iters)]
|
143 |
+
|
144 |
+
def __len__(self):
|
145 |
+
return self.curr_max_iters
|
146 |
+
|
147 |
+
def __getitem__(self, idx):
|
148 |
+
sample = [self.datasets[dataset_idx][data_idx]
|
149 |
+
for dataset_idx, data_idx in self.indices[idx]]
|
150 |
+
if self.pair_based_transforms is not None:
|
151 |
+
for transform, args in self.pair_based_transforms:
|
152 |
+
sample = transform(sample, **args)
|
153 |
+
return sample
|
154 |
+
|
155 |
+
class Subset(Dataset):
|
156 |
+
"""
|
157 |
+
Subset of a dataset at specified indices. Used for separating a dataset by class in our context
|
158 |
+
|
159 |
+
Args:
|
160 |
+
dataset:
|
161 |
+
The whole Dataset
|
162 |
+
indices:
|
163 |
+
Indices of samples of the current class in the entire dataset
|
164 |
+
sub_attrib_args:
|
165 |
+
Subset-specific arguments for attribute functions, expected a dict
|
166 |
+
"""
|
167 |
+
def __init__(self, dataset, indices, sub_attrib_args=None):
|
168 |
+
self.dataset = dataset
|
169 |
+
self.indices = indices
|
170 |
+
self.sub_attrib_args = sub_attrib_args
|
171 |
+
|
172 |
+
def __getitem__(self, idx):
|
173 |
+
if self.sub_attrib_args is not None:
|
174 |
+
for key in self.sub_attrib_args:
|
175 |
+
# Make sure the dataset already has the corresponding attributes
|
176 |
+
# Here we only make the arguments subset dependent
|
177 |
+
# (i.e. pass different arguments for each subset)
|
178 |
+
self.dataset.aux_attrib_args[key].update(self.sub_attrib_args[key])
|
179 |
+
return self.dataset[self.indices[idx]]
|
180 |
+
|
181 |
+
def __len__(self):
|
182 |
+
return len(self.indices)
|
183 |
+
|
184 |
+
class ValidationDataset(Dataset):
|
185 |
+
"""
|
186 |
+
Dataset for validation
|
187 |
+
|
188 |
+
Args:
|
189 |
+
dataset:
|
190 |
+
source dataset with a __getitem__ method
|
191 |
+
test_classes:
|
192 |
+
test classes
|
193 |
+
npart: int. number of parts, used for evaluation when assigning support images
|
194 |
+
|
195 |
+
"""
|
196 |
+
def __init__(self, dataset, test_classes: list, npart: int):
|
197 |
+
super().__init__()
|
198 |
+
self.dataset = dataset
|
199 |
+
self.__curr_cls = None
|
200 |
+
self.test_classes = test_classes
|
201 |
+
self.dataset.aux_attrib = None
|
202 |
+
self.npart = npart
|
203 |
+
|
204 |
+
def set_curr_cls(self, curr_cls):
|
205 |
+
assert curr_cls in self.test_classes
|
206 |
+
self.__curr_cls = curr_cls
|
207 |
+
|
208 |
+
def get_curr_cls(self):
|
209 |
+
return self.__curr_cls
|
210 |
+
|
211 |
+
def read_dataset(self):
|
212 |
+
"""
|
213 |
+
override original read_dataset to allow reading with z_margin
|
214 |
+
"""
|
215 |
+
raise NotImplementedError
|
216 |
+
|
217 |
+
def __len__(self):
|
218 |
+
return len(self.dataset)
|
219 |
+
|
220 |
+
def label_strip(self, label):
|
221 |
+
"""
|
222 |
+
mask unrelated labels out
|
223 |
+
"""
|
224 |
+
out = torch.where(label == self.__curr_cls,
|
225 |
+
torch.ones_like(label), torch.zeros_like(label))
|
226 |
+
return out
|
227 |
+
|
228 |
+
def __getitem__(self, idx):
|
229 |
+
if self.__curr_cls is None:
|
230 |
+
raise Exception("Please initialize current class first")
|
231 |
+
|
232 |
+
sample = self.dataset[idx]
|
233 |
+
sample["label"] = self.label_strip( sample["label"] )
|
234 |
+
sample["label_t"] = sample["label"].unsqueeze(-1).data.numpy()
|
235 |
+
|
236 |
+
labelname = self.dataset.all_label_names[self.__curr_cls]
|
237 |
+
z_min = min(self.dataset.tp1_cls_map[labelname][sample['scan_id']])
|
238 |
+
z_max = max(self.dataset.tp1_cls_map[labelname][sample['scan_id']])
|
239 |
+
sample["z_min"], sample["z_max"] = z_min, z_max
|
240 |
+
try:
|
241 |
+
part_assign = int((sample["z_id"] - z_min) // ((z_max - z_min) / self.npart))
|
242 |
+
except:
|
243 |
+
part_assign = 0
|
244 |
+
# print("###### DATASET: support only has one valid slice ######")
|
245 |
+
if part_assign < 0:
|
246 |
+
part_assign = 0
|
247 |
+
elif part_assign >= self.npart:
|
248 |
+
part_assign = self.npart - 1
|
249 |
+
sample["part_assign"] = part_assign
|
250 |
+
sample["case"] = sample["scan_id"]
|
251 |
+
|
252 |
+
return sample
|
253 |
+
|
254 |
+
def get_support_set(self, config, n_support=3):
|
255 |
+
support_batched = self.dataset.get_support(curr_class=self.__curr_cls, class_idx= [self.__curr_cls], scan_idx=config["support_idx"], npart=config["task"]["npart"])
|
256 |
+
|
257 |
+
support_images = [img for way in support_batched["support_images"] for img in way]
|
258 |
+
support_labels = [fgmask['fg_mask'] for way in support_batched["support_mask"] for fgmask in way]
|
259 |
+
support_scan_id = self.dataset.potential_support_sid
|
260 |
+
return {"support_images": support_images, "support_labels": support_labels, "support_scan_id": support_scan_id}
|
261 |
+
|
262 |
+
|
263 |
+
|
dataloaders/dataset_utils.py
ADDED
@@ -0,0 +1,128 @@
1 |
+
"""
|
2 |
+
Utils for datasets
|
3 |
+
"""
|
4 |
+
import functools
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
import os
|
8 |
+
import sys
|
9 |
+
import nibabel as nib
|
10 |
+
import numpy as np
|
11 |
+
import pdb
|
12 |
+
import SimpleITK as sitk
|
13 |
+
|
14 |
+
DATASET_INFO = {
|
15 |
+
"CHAOST2": {
|
16 |
+
'PSEU_LABEL_NAME': ["BGD", "SUPFG"],
|
17 |
+
'REAL_LABEL_NAME': ["BG", "LIVER", "RK", "LK", "SPLEEN"],
|
18 |
+
'_SEP': [0, 4, 8, 12, 16, 20],
|
19 |
+
'MODALITY': 'MR',
|
20 |
+
'LABEL_GROUP': {
|
21 |
+
'pa_all': set(range(1, 5)),
|
22 |
+
0: set([1, 4]), # upper_abdomen, leaving kidneys as testing classes
|
23 |
+
1: set([2, 3]), # lower_abdomen
|
24 |
+
},
|
25 |
+
},
|
26 |
+
|
27 |
+
"SABS": {
|
28 |
+
'PSEU_LABEL_NAME': ["BGD", "SUPFG"],
|
29 |
+
|
30 |
+
'REAL_LABEL_NAME': ["BGD", "SPLEEN", "KID_R", "KID_l", "GALLBLADDER", "ESOPHAGUS", "LIVER", "STOMACH", "AORTA", "IVC",\
|
31 |
+
"PS_VEIN", "PANCREAS", "AG_R", "AG_L"],
|
32 |
+
'_SEP': [0, 6, 12, 18, 24, 30],
|
33 |
+
'MODALITY': 'CT',
|
34 |
+
'LABEL_GROUP':{
|
35 |
+
'pa_all': set( [1,2,3,6] ),
|
36 |
+
0: set([1,6] ), # upper_abdomen: spleen + liver as training, kidneys are testing
|
37 |
+
1: set( [2,3] ), # lower_abdomen
|
38 |
+
}
|
39 |
+
},
|
40 |
+
"LITS17": {
|
41 |
+
'PSEU_LABEL_NAME': ["BGD", "SUPFG"],
|
42 |
+
|
43 |
+
'REAL_LABEL_NAME': ["BGD", "LIVER", "TUMOR"],
|
44 |
+
'_SEP': [0, 26, 52, 78, 104],
|
45 |
+
'MODALITY': 'CT',
|
46 |
+
'LABEL_GROUP':{
|
47 |
+
'pa_all': set( [1 , 2] ),
|
48 |
+
0: set([1 ] ), # liver
|
49 |
+
1: set( [ 2] ), # tumor
|
50 |
+
2: set([1,2]) # liver + tumor
|
51 |
+
}
|
52 |
+
|
53 |
+
}
|
54 |
+
|
55 |
+
}
|
56 |
+
|
57 |
+
def read_nii_bysitk(input_fid, peel_info = False):
|
58 |
+
""" read nii to numpy through simpleitk
|
59 |
+
|
60 |
+
peelinfo: taking direction, origin, spacing and metadata out
|
61 |
+
"""
|
62 |
+
img_obj = sitk.ReadImage(input_fid)
|
63 |
+
img_np = sitk.GetArrayFromImage(img_obj)
|
64 |
+
if peel_info:
|
65 |
+
info_obj = {
|
66 |
+
"spacing": img_obj.GetSpacing(),
|
67 |
+
"origin": img_obj.GetOrigin(),
|
68 |
+
"direction": img_obj.GetDirection(),
|
69 |
+
"array_size": img_np.shape
|
70 |
+
}
|
71 |
+
return img_np, info_obj
|
72 |
+
else:
|
73 |
+
return img_np
|
74 |
+
|
75 |
+
|
76 |
+
def get_CT_statistics(scan_fids):
|
77 |
+
"""
|
78 |
+
As CT are quantitative, get mean and std for CT images for image normalizing
|
79 |
+
As in reality we might not be able to load all images at a time, we would better detach statistics calculation with actual data loading
|
80 |
+
"""
|
81 |
+
total_val = 0
|
82 |
+
n_pix = 0
|
83 |
+
for fid in scan_fids:
|
84 |
+
in_img = read_nii_bysitk(fid)
|
85 |
+
total_val += in_img.sum()
|
86 |
+
n_pix += np.prod(in_img.shape)
|
87 |
+
del in_img
|
88 |
+
meanval = total_val / n_pix
|
89 |
+
|
90 |
+
total_var = 0
|
91 |
+
for fid in scan_fids:
|
92 |
+
in_img = read_nii_bysitk(fid)
|
93 |
+
total_var += np.sum((in_img - meanval) ** 2 )
|
94 |
+
del in_img
|
95 |
+
var_all = total_var / n_pix
|
96 |
+
|
97 |
+
global_std = var_all ** 0.5
|
98 |
+
|
99 |
+
return meanval, global_std
|
100 |
+
|
101 |
+
def MR_normalize(x_in):
|
102 |
+
return (x_in - x_in.mean()) / x_in.std()
|
103 |
+
|
104 |
+
def CT_normalize(x_in, ct_mean, ct_std):
|
105 |
+
"""
|
106 |
+
Normalizing CT images, based on global statistics
|
107 |
+
"""
|
108 |
+
return (x_in - ct_mean) / ct_std
|
109 |
+
|
110 |
+
def get_normalize_op(modality, fids, ct_mean=None, ct_std=None):
|
111 |
+
"""
|
112 |
+
As title
|
113 |
+
Args:
|
114 |
+
modality: CT or MR
|
115 |
+
fids: fids for the fold
|
116 |
+
"""
|
117 |
+
if modality == 'MR':
|
118 |
+
return MR_normalize
|
119 |
+
|
120 |
+
elif modality == 'CT':
|
121 |
+
if ct_mean is None or ct_std is None:
|
122 |
+
ct_mean, ct_std = get_CT_statistics(fids)
|
123 |
+
# debug
|
124 |
+
print(f'###### DEBUG_DATASET CT_STATS NORMALIZED MEAN {ct_mean} STD {ct_std} ######')
|
125 |
+
|
126 |
+
return functools.partial(CT_normalize, ct_mean=ct_mean, ct_std=ct_std)
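How the normalization helper is typically obtained per modality; the CT statistics below are assumed numbers, and passing them explicitly skips the global statistics computation over fids.

    norm_mr = get_normalize_op('MR', fids=None)                               # per-scan z-score
    norm_ct = get_normalize_op('CT', fids=None, ct_mean=40.0, ct_std=120.0)   # assumed stats
    img_norm = norm_ct(img_np)   # img_np: a numpy volume, e.g. from read_nii_bysitk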
|
127 |
+
|
128 |
+
|
dataloaders/dev_customized_med.py
ADDED
@@ -0,0 +1,250 @@
1 |
+
"""
|
2 |
+
Customized dataset. Extended from vanilla PANet script by Wang et al.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
import random
|
7 |
+
import torch
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
from dataloaders.common import ReloadPairedDataset, ValidationDataset
|
11 |
+
from dataloaders.ManualAnnoDatasetv2 import ManualAnnoDataset
|
12 |
+
|
13 |
+
def attrib_basic(_sample, class_id):
|
14 |
+
"""
|
15 |
+
Add basic attribute
|
16 |
+
Args:
|
17 |
+
_sample: data sample
|
18 |
+
class_id: class label associated with the data
|
19 |
+
(sometimes indicating from which subset the data are drawn)
|
20 |
+
"""
|
21 |
+
return {'class_id': class_id}
|
22 |
+
|
23 |
+
def getMaskOnly(label, class_id, class_ids):
|
24 |
+
"""
|
25 |
+
Generate FG/BG mask from the segmentation mask
|
26 |
+
|
27 |
+
Args:
|
28 |
+
label:
|
29 |
+
semantic mask
|
30 |
+
scribble:
|
31 |
+
scribble mask
|
32 |
+
class_id:
|
33 |
+
semantic class of interest
|
34 |
+
class_ids:
|
35 |
+
all class id in this episode
|
36 |
+
"""
|
37 |
+
# Dense Mask
|
38 |
+
fg_mask = torch.where(label == class_id,
|
39 |
+
torch.ones_like(label), torch.zeros_like(label))
|
40 |
+
bg_mask = torch.where(label != class_id,
|
41 |
+
torch.ones_like(label), torch.zeros_like(label))
|
42 |
+
for class_id in class_ids:
|
43 |
+
bg_mask[label == class_id] = 0
|
44 |
+
|
45 |
+
return {'fg_mask': fg_mask,
|
46 |
+
'bg_mask': bg_mask}
|
47 |
+
|
48 |
+
def getMasks(*args, **kwargs):
|
49 |
+
raise NotImplementedError
|
50 |
+
|
51 |
+
def fewshot_pairing(paired_sample, n_ways, n_shots, cnt_query, coco=False, mask_only = True):
|
52 |
+
"""
|
53 |
+
Postprocess paired sample for fewshot settings
|
54 |
+
For now only 1-way is tested but we leave multi-way possible (inherited from original PANet)
|
55 |
+
|
56 |
+
Args:
|
57 |
+
paired_sample:
|
58 |
+
data sample from a PairedDataset
|
59 |
+
n_ways:
|
60 |
+
n-way few-shot learning
|
61 |
+
n_shots:
|
62 |
+
n-shot few-shot learning
|
63 |
+
cnt_query:
|
64 |
+
number of query images for each class in the support set
|
65 |
+
coco:
|
66 |
+
MS COCO dataset. This is from the original PANet dataset but lets keep it for further extension
|
67 |
+
mask_only:
|
68 |
+
only give masks and no scribbles/ instances. Suitable for medical images (for now)
|
69 |
+
"""
|
70 |
+
if not mask_only:
|
71 |
+
raise NotImplementedError
|
72 |
+
###### Compose the support and query image list ######
|
73 |
+
cumsum_idx = np.cumsum([0,] + [n_shots + x for x in cnt_query]) # separation for supports and queries
|
74 |
+
|
75 |
+
# support class ids
|
76 |
+
class_ids = [paired_sample[cumsum_idx[i]]['basic_class_id'] for i in range(n_ways)] # class ids for each image (support and query)
|
77 |
+
|
78 |
+
# support images
|
79 |
+
support_images = [[paired_sample[cumsum_idx[i] + j]['image'] for j in range(n_shots)]
|
80 |
+
for i in range(n_ways)] # fetch support images for each class
|
81 |
+
|
82 |
+
# support image labels
|
83 |
+
if coco:
|
84 |
+
support_labels = [[paired_sample[cumsum_idx[i] + j]['label'][class_ids[i]]
|
85 |
+
for j in range(n_shots)] for i in range(n_ways)]
|
86 |
+
else:
|
87 |
+
support_labels = [[paired_sample[cumsum_idx[i] + j]['label'] for j in range(n_shots)]
|
88 |
+
for i in range(n_ways)]
|
89 |
+
|
90 |
+
if not mask_only:
|
91 |
+
support_scribbles = [[paired_sample[cumsum_idx[i] + j]['scribble'] for j in range(n_shots)]
|
92 |
+
for i in range(n_ways)]
|
93 |
+
support_insts = [[paired_sample[cumsum_idx[i] + j]['inst'] for j in range(n_shots)]
|
94 |
+
for i in range(n_ways)]
|
95 |
+
else:
|
96 |
+
support_insts = []
support_scribbles = []  # also defined here so the return dict below does not raise a NameError
|
97 |
+
|
98 |
+
# query images, masks and class indices
|
99 |
+
query_images = [paired_sample[cumsum_idx[i+1] - j - 1]['image'] for i in range(n_ways)
|
100 |
+
for j in range(cnt_query[i])]
|
101 |
+
if coco:
|
102 |
+
query_labels = [paired_sample[cumsum_idx[i+1] - j - 1]['label'][class_ids[i]]
|
103 |
+
for i in range(n_ways) for j in range(cnt_query[i])]
|
104 |
+
else:
|
105 |
+
query_labels = [paired_sample[cumsum_idx[i+1] - j - 1]['label'] for i in range(n_ways)
|
106 |
+
for j in range(cnt_query[i])]
|
107 |
+
query_cls_idx = [sorted([0,] + [class_ids.index(x) + 1
|
108 |
+
for x in set(np.unique(query_label)) & set(class_ids)])
|
109 |
+
                     for query_label in query_labels]

    ###### Generate support image masks ######
    if not mask_only:
        support_mask = [[getMasks(support_labels[way][shot], support_scribbles[way][shot],
                                  class_ids[way], class_ids)
                         for shot in range(n_shots)] for way in range(n_ways)]
    else:
        support_mask = [[getMaskOnly(support_labels[way][shot],
                                     class_ids[way], class_ids)
                         for shot in range(n_shots)] for way in range(n_ways)]

    ###### Generate query label (class indices in one episode, i.e. the ground truth) ######
    query_labels_tmp = [torch.zeros_like(x) for x in query_labels]
    for i, query_label_tmp in enumerate(query_labels_tmp):
        query_label_tmp[query_labels[i] == 255] = 255
        for j in range(n_ways):
            query_label_tmp[query_labels[i] == class_ids[j]] = j + 1

    ###### Generate query mask for each semantic class (including BG) ######
    # BG class
    query_masks = [[torch.where(query_label == 0,
                                torch.ones_like(query_label),
                                torch.zeros_like(query_label))[None, ...], ]
                   for query_label in query_labels]
    # Other classes in query image
    for i, query_label in enumerate(query_labels):
        for idx in query_cls_idx[i][1:]:
            mask = torch.where(query_label == class_ids[idx - 1],
                               torch.ones_like(query_label),
                               torch.zeros_like(query_label))[None, ...]
            query_masks[i].append(mask)

    return {'class_ids': class_ids,
            'support_images': support_images,
            'support_mask': support_mask,
            'support_inst': support_insts,  # leave these interfaces
            'support_scribbles': support_scribbles,

            'query_images': query_images,
            'query_labels': query_labels_tmp,
            'query_masks': query_masks,
            'query_cls_idx': query_cls_idx,
            }


def med_fewshot(dataset_name, base_dir, idx_split, mode, scan_per_load,
                transforms, act_labels, n_ways, n_shots, max_iters_per_load, min_fg='', n_queries=1, fix_parent_len=None, exclude_list=[], **kwargs):
    """
    Dataset wrapper
    Args:
        dataset_name:
            indicates which dataset to use
        base_dir:
            dataset directory
        mode:
            which mode to use
            choose from ('train', 'val', 'trainval', 'trainaug')
        idx_split:
            index of split
        scan_per_load:
            number of scans to load into memory, as the dataset is large;
            use this together with reload_buffer
        transforms:
            transformations to be performed on images/masks
        act_labels:
            active labels involved in the training process; should be a subset of all labels
        n_ways:
            n-way few-shot learning, should be no more than the number of object class labels
        n_shots:
            n-shot few-shot learning
        max_iters_per_load:
            number of pairs per load (epoch size)
        n_queries:
            number of query images
        fix_parent_len:
            fixed length of the parent dataset
    """
    med_set = ManualAnnoDataset

    mydataset = med_set(which_dataset=dataset_name, base_dir=base_dir, idx_split=idx_split, mode=mode,
                        scan_per_load=scan_per_load, transforms=transforms, min_fg=min_fg, fix_length=fix_parent_len,
                        exclude_list=exclude_list, **kwargs)

    mydataset.add_attrib('basic', attrib_basic, {})

    # Create sub-datasets and add the class_id attribute. The class file is loaded and reloaded internally.
    subsets = mydataset.subsets([{'basic': {'class_id': ii}}
                                 for ii, _ in enumerate(mydataset.label_name)])

    # Choose the classes of queries (number of queries for each way)
    cnt_query = np.bincount(random.choices(population=range(n_ways), k=n_queries), minlength=n_ways)
    # Set the number of images for each class
    n_elements = [n_shots + x for x in cnt_query]  # <n_shots> supports + <cnt_query>[i] queries
    # Create the paired dataset. We do not include background.
    paired_data = ReloadPairedDataset([subsets[ii] for ii in act_labels], n_elements=n_elements, curr_max_iters=max_iters_per_load,
                                      pair_based_transforms=[
                                          (fewshot_pairing, {'n_ways': n_ways, 'n_shots': n_shots,
                                                             'cnt_query': cnt_query, 'mask_only': True})])
    return paired_data, mydataset


def update_loader_dset(loader, parent_set):
    """
    Update the data loader and the parent dataset behind it
    Args:
        loader:     actual dataloader
        parent_set: parent dataset which actually stores the data
    """
    parent_set.reload_buffer()
    loader.dataset.update_index()
    print('###### Loader and dataset have been updated ######')


def med_fewshot_val(dataset_name, base_dir, idx_split, scan_per_load, act_labels, npart, fix_length=None, nsup=1, transforms=None, mode='val', **kwargs):
    """
    Validation set for medical images
    Args:
        dataset_name:
            indicates which dataset to use
        base_dir:
            SABS dataset directory
        mode: (original split)
            which split to use
            choose from ('train', 'val', 'trainval', 'trainaug')
        idx_split:
            index of split
        scan_per_load:
            number of scans to load into memory, as the dataset is large;
            use this together with reload_buffer
        act_labels:
            actual labels involved in the validation process; should be a subset of all labels
        npart:  number of chunks for splitting a 3d volume
        nsup:   number of support scans, equivalent to nshot
    """
    mydataset = ManualAnnoDataset(which_dataset=dataset_name, base_dir=base_dir, idx_split=idx_split, mode=mode, scan_per_load=scan_per_load, transforms=transforms, min_fg=1, fix_length=fix_length, nsup=nsup, **kwargs)
    mydataset.add_attrib('basic', attrib_basic, {})

    valset = ValidationDataset(mydataset, test_classes=act_labels, npart=npart)

    return valset, mydataset

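Editor's note: the helpers above are typically chained — med_fewshot builds the episodic training loader, update_loader_dset refreshes the in-memory scan buffer between epochs, and med_fewshot_val builds the validation set. A minimal wiring sketch follows; it assumes the two values returned by med_fewshot(...), and the helper names build_training_loader / run_epochs are illustrative, not part of this repository.

from torch.utils.data import DataLoader

def build_training_loader(paired_data, parent_set, batch_size=1, num_workers=2):
    """Wrap the episodic paired dataset returned by med_fewshot(...) in a DataLoader."""
    return DataLoader(paired_data, batch_size=batch_size, shuffle=True,
                      num_workers=num_workers, pin_memory=True, drop_last=True)

def run_epochs(loader, parent_set, n_epochs, train_step):
    """Iterate over episodes and refresh the scan buffer between epochs."""
    for _ in range(n_epochs):
        for episode in loader:
            # each episode carries 'support_images', 'support_mask', 'query_images', ...
            train_step(episode)
        # reload scans into memory and re-index the paired dataset
        update_loader_dset(loader, parent_set)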
dataloaders/image_transforms.py
ADDED
@@ -0,0 +1,362 @@
"""
Image transforms functions for data augmentation
Credit to Dr. Jo Schlemper
"""

from collections.abc import Sequence
import cv2
import numpy as np
import scipy
from scipy.ndimage.filters import gaussian_filter
from scipy.ndimage.interpolation import map_coordinates
from numpy.lib.stride_tricks import as_strided
import numpy as np
import cv2
from scipy.ndimage import map_coordinates
from numpy.lib.stride_tricks import as_strided
from multiprocessing import Pool
import albumentations as A
import time

###### UTILITIES ######
def random_num_generator(config, random_state=np.random):
    if config[0] == 'uniform':
        ret = random_state.uniform(config[1], config[2], 1)[0]
    elif config[0] == 'lognormal':
        ret = random_state.lognormal(config[1], config[2], 1)[0]
    else:
        # print(config)
        raise Exception('unsupported format')
    return ret

def get_translation_matrix(translation):
    """ translation: [tx, ty] """
    tx, ty = translation
    translation_matrix = np.array([[1, 0, tx],
                                   [0, 1, ty],
                                   [0, 0, 1]])
    return translation_matrix


def get_rotation_matrix(rotation, input_shape, centred=True):
    theta = np.pi / 180 * np.array(rotation)
    if centred:
        rotation_matrix = cv2.getRotationMatrix2D((input_shape[0]/2, input_shape[1]//2), rotation, 1)
        rotation_matrix = np.vstack([rotation_matrix, [0, 0, 1]])
    else:
        rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0],
                                    [np.sin(theta), np.cos(theta), 0],
                                    [0, 0, 1]])
    return rotation_matrix

def get_zoom_matrix(zoom, input_shape, centred=True):
    zx, zy = zoom
    if centred:
        zoom_matrix = cv2.getRotationMatrix2D((input_shape[0]/2, input_shape[1]//2), 0, zoom[0])
        zoom_matrix = np.vstack([zoom_matrix, [0, 0, 1]])
    else:
        zoom_matrix = np.array([[zx, 0, 0],
                                [0, zy, 0],
                                [0, 0, 1]])
    return zoom_matrix

def get_shear_matrix(shear_angle):
    theta = (np.pi * shear_angle) / 180
    shear_matrix = np.array([[1, -np.sin(theta), 0],
                             [0, np.cos(theta), 0],
                             [0, 0, 1]])
    return shear_matrix

###### AFFINE TRANSFORM ######
class RandomAffine(object):
    """Apply random affine transformation on a numpy.ndarray (H x W x C)
    Comment by co1818: this is still doing affine on 2d (H x W plane).
    The same transform is applied to all C channels

    Parameter:
    ----------

    alpha: Range [0, 4] seems good for small images

    order: interpolation method (c.f. opencv)
    """

    def __init__(self,
                 rotation_range=None,
                 translation_range=None,
                 shear_range=None,
                 zoom_range=None,
                 zoom_keep_aspect=False,
                 interp='bilinear',
                 use_3d=False,
                 order=3):
        """
        Perform an affine transform.

        Arguments
        ---------
        rotation_range : one integer or float
            image will be rotated randomly between (-degrees, degrees)

        translation_range : (x_shift, y_shift)
            shifts in pixels

        *NOT TESTED* shear_range : float
            image will be sheared randomly between (-degrees, degrees)

        zoom_range : (zoom_min, zoom_max)
            list/tuple with two floats between [0, infinity).
            first float should be less than the second
            lower and upper bounds on percent zoom.
            Anything less than 1.0 will zoom in on the image,
            anything greater than 1.0 will zoom out on the image.
            e.g. (0.7, 1.0) will only zoom in,
            (1.0, 1.4) will only zoom out,
            (0.7, 1.4) will randomly zoom in or out
        """

        self.rotation_range = rotation_range
        self.translation_range = translation_range
        self.shear_range = shear_range
        self.zoom_range = zoom_range
        self.zoom_keep_aspect = zoom_keep_aspect
        self.interp = interp
        self.order = order
        self.use_3d = use_3d

    def build_M(self, input_shape):
        tfx = []
        final_tfx = np.eye(3)
        if self.rotation_range:
            rot = np.random.uniform(-self.rotation_range, self.rotation_range)
            tfx.append(get_rotation_matrix(rot, input_shape))
        if self.translation_range:
            tx = np.random.uniform(-self.translation_range[0], self.translation_range[0])
            ty = np.random.uniform(-self.translation_range[1], self.translation_range[1])
            tfx.append(get_translation_matrix((tx, ty)))
        if self.shear_range:
            rot = np.random.uniform(-self.shear_range, self.shear_range)
            tfx.append(get_shear_matrix(rot))
        if self.zoom_range:
            sx = np.random.uniform(self.zoom_range[0], self.zoom_range[1])
            if self.zoom_keep_aspect:
                sy = sx
            else:
                sy = np.random.uniform(self.zoom_range[0], self.zoom_range[1])

            tfx.append(get_zoom_matrix((sx, sy), input_shape))

        for tfx_mat in tfx:
            final_tfx = np.dot(tfx_mat, final_tfx)

        return final_tfx.astype(np.float32)

    def __call__(self, image):
        # build matrix
        input_shape = image.shape[:2]
        M = self.build_M(input_shape)

        res = np.zeros_like(image)
        # if isinstance(self.interp, Sequence):
        if type(self.order) is list or type(self.order) is tuple:
            for i, intp in enumerate(self.order):
                if self.use_3d:
                    res[..., i] = affine_transform_3d_via_M(image[..., i], M[:2], interp=intp)
                else:
                    res[..., i] = affine_transform_via_M(image[..., i], M[:2], interp=intp)
        else:
            # squeeze if needed
            orig_shape = image.shape
            image_s = np.squeeze(image)
            if self.use_3d:
                res = affine_transform_3d_via_M(image_s, M[:2], interp=self.order)
            else:
                res = affine_transform_via_M(image_s, M[:2], interp=self.order)
            res = res.reshape(orig_shape)

        # res = affine_transform_via_M(image, M[:2], interp=self.order)

        return res

def affine_transform_via_M(image, M, borderMode=cv2.BORDER_CONSTANT, interp=cv2.INTER_NEAREST):
    imshape = image.shape
    shape_size = imshape[:2]

    # Random affine
    warped = cv2.warpAffine(image.reshape(shape_size + (-1,)), M, shape_size[::-1],
                            flags=interp, borderMode=borderMode)

    # print(imshape, warped.shape)

    warped = warped[..., np.newaxis].reshape(imshape)

    return warped

def affine_transform_3d_via_M(vol, M, borderMode=cv2.BORDER_CONSTANT, interp=cv2.INTER_NEAREST):
    """
    vol should be of shape (nx, ny, n1, ..., nm)
    """
    # go over the volume slice by slice
    res = np.zeros_like(vol)
    for i in range(vol.shape[2]):
        res[:, :, i] = affine_transform_via_M(vol[:, :, i], M, borderMode=borderMode, interp=interp)

    return res


###### ELASTIC TRANSFORM ######
def elastic_transform(image, alpha=1000, sigma=30, spline_order=1, mode='nearest', random_state=np.random):
    """Elastic deformation of image as described in [Simard2003]_.
    .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
       Convolutional Neural Networks applied to Visual Document Analysis", in
       Proc. of the International Conference on Document Analysis and
       Recognition, 2003.
    """
    assert image.ndim == 3
    shape = image.shape[:2]

    dx = gaussian_filter((random_state.rand(*shape) * 2 - 1),
                         sigma, mode="constant", cval=0) * alpha
    dy = gaussian_filter((random_state.rand(*shape) * 2 - 1),
                         sigma, mode="constant", cval=0) * alpha

    x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij')
    indices = [np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1))]
    result = np.empty_like(image)
    for i in range(image.shape[2]):
        result[:, :, i] = map_coordinates(
            image[:, :, i], indices, order=spline_order, mode=mode).reshape(shape)
    return result

def elastic_transform_nd_3d(image, **kwargs):
    """
    image_w_mask should be of shape (nx, ny, nz, 3)
    """
    image_w_mask = image
    start_time = time.time()
    elastic_transform = A.ElasticTransform(alpha=10, sigma=20, alpha_affine=15, interpolation=1, border_mode=4, always_apply=True, p=0.5)
    # print(f"elastic transform initialization took {time.time() - start_time} seconds")
    img = image_w_mask[..., 0]
    label = image_w_mask[..., -1]
    transformed = elastic_transform(image=img, mask=label)
    t_img = transformed['image'][..., np.newaxis]
    t_mask = transformed['mask'][..., np.newaxis]
    t_mask_bg = 1 - t_mask
    t_mask = np.concatenate([t_mask_bg, t_mask], axis=-1)

    comp = np.concatenate([t_img, t_mask], axis=-1)
    return comp

def elastic_transform_nd(image, alpha, sigma, random_state=None, order=1, lazy=False):
    """Expects data to be (nx, ny, n1 ,..., nm)
    params:
    ------

    alpha:
        the scaling parameter.
        E.g.: alpha=2 => distorts images up to 2x scaling

    sigma:
        standard deviation of the gaussian filter.
        E.g.
         low (sig~=1e-3) => no smoothing, pixelated.
         high (1/5 * imsize) => smooth, more like affine.
         very high (1/2 * im_size) => translation
    """

    if random_state is None:
        random_state = np.random.RandomState(None)

    shape = image.shape
    imsize = shape[:2]
    dim = shape[2:]

    # Random affine
    blur_size = int(4*sigma) | 1
    dx = cv2.GaussianBlur(random_state.rand(*imsize)*2-1,
                          ksize=(blur_size, blur_size), sigmaX=sigma) * alpha
    dy = cv2.GaussianBlur(random_state.rand(*imsize)*2-1,
                          ksize=(blur_size, blur_size), sigmaX=sigma) * alpha

    # use as_strided to copy things over across n1...nn channels
    dx = as_strided(dx.astype(np.float32),
                    strides=(0,) * len(dim) + (4*shape[1], 4),
                    shape=dim+(shape[0], shape[1]))
    dx = np.transpose(dx, axes=(-2, -1) + tuple(range(len(dim))))

    dy = as_strided(dy.astype(np.float32),
                    strides=(0,) * len(dim) + (4*shape[1], 4),
                    shape=dim+(shape[0], shape[1]))
    dy = np.transpose(dy, axes=(-2, -1) + tuple(range(len(dim))))

    coord = np.meshgrid(*[np.arange(shape_i) for shape_i in (shape[1], shape[0]) + dim])
    indices = [np.reshape(e+de, (-1, 1)) for e, de in zip([coord[1], coord[0]] + coord[2:],
                                                          [dy, dx] + [0] * len(dim))]

    if lazy:
        return indices
    res = map_coordinates(image, indices, order=order, mode='reflect').reshape(shape)
    return res

class ElasticTransform(object):
    """Apply elastic transformation on a numpy.ndarray (H x W x C)
    """

    def __init__(self, alpha, sigma, order=1):
        self.alpha = alpha
        self.sigma = sigma
        self.order = order

    def __call__(self, image):
        if isinstance(self.alpha, Sequence):
            alpha = random_num_generator(self.alpha)
        else:
            alpha = self.alpha
        if isinstance(self.sigma, Sequence):
            sigma = random_num_generator(self.sigma)
        else:
            sigma = self.sigma
        return elastic_transform_nd(image, alpha=alpha, sigma=sigma, order=self.order)

class RandomFlip3D(object):

    def __init__(self, h=True, v=True, t=True, p=0.5):
        """
        Randomly flip an image horizontally and/or vertically with
        some probability.

        Arguments
        ---------
        h : boolean
            whether to horizontally flip w/ probability p

        v : boolean
            whether to vertically flip w/ probability p

        p : float between [0,1]
            probability with which to apply allowed flipping operations
        """
        self.horizontal = h
        self.vertical = v
        self.depth = t
        self.p = p

    def __call__(self, x, y=None):
        # horizontal flip with p = self.p
        if self.horizontal:
            if np.random.random() < self.p:
                x = x[::-1, ...]

        # vertical flip with p = self.p
        if self.vertical:
            if np.random.random() < self.p:
                x = x[:, ::-1, ...]

        if self.depth:
            if np.random.random() < self.p:
                x = x[..., ::-1]

        return x

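Editor's note: a rough usage sketch for the transforms above, composing RandomAffine and ElasticTransform on an image stacked with its mask along the channel axis. The parameter values and array shapes are illustrative assumptions, not the ones used elsewhere in this repository.

import numpy as np

# grayscale slice plus a binary mask, stacked as (H, W, 2)
img = np.random.rand(256, 256, 1).astype(np.float32)
mask = (np.random.rand(256, 256, 1) > 0.5).astype(np.float32)
comp = np.concatenate([img, mask], axis=-1)

affine = RandomAffine(rotation_range=15,
                      translation_range=(10, 10),
                      zoom_range=(0.9, 1.1),
                      order=[1, 0])   # cv2 flags: bilinear for the image channel, nearest for the mask
elastic = ElasticTransform(alpha=20, sigma=20, order=1)

out = affine(comp)
out = elastic(out)
img_aug, mask_aug = out[..., :1], out[..., 1:]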
dataloaders/niftiio.py
ADDED
@@ -0,0 +1,48 @@
"""
Utils for datasets
"""
import numpy as np

import numpy as np
import SimpleITK as sitk


def read_nii_bysitk(input_fid, peel_info=False):
    """ read nii to numpy through simpleitk
        peel_info: taking direction, origin, spacing and metadata out
    """
    img_obj = sitk.ReadImage(input_fid)
    img_np = sitk.GetArrayFromImage(img_obj)
    if peel_info:
        info_obj = {
            "spacing": img_obj.GetSpacing(),
            "origin": img_obj.GetOrigin(),
            "direction": img_obj.GetDirection(),
            "array_size": img_np.shape
        }
        return img_np, info_obj
    else:
        return img_np

def convert_to_sitk(input_mat, peeled_info):
    """
    write a numpy array to an sitk image object with essential meta-data
    """
    nii_obj = sitk.GetImageFromArray(input_mat)
    if peeled_info:
        nii_obj.SetSpacing(peeled_info["spacing"])
        nii_obj.SetOrigin(peeled_info["origin"])
        nii_obj.SetDirection(peeled_info["direction"])
    return nii_obj

def np2itk(img, ref_obj):
    """
    img: numpy array
    ref_obj: reference sitk object for copying information from
    """
    itk_obj = sitk.GetImageFromArray(img)
    itk_obj.SetSpacing(ref_obj.GetSpacing())
    itk_obj.SetOrigin(ref_obj.GetOrigin())
    itk_obj.SetDirection(ref_obj.GetDirection())
    return itk_obj

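Editor's note: a minimal round-trip sketch for the NIfTI helpers above — read a scan with its metadata, edit the array, and write it back with the same geometry. The file names and the intensity window are placeholders.

import numpy as np
import SimpleITK as sitk

vol, info = read_nii_bysitk("example_scan.nii.gz", peel_info=True)
vol = np.clip(vol, -200, 400)                 # e.g. a simple CT intensity window
out_obj = convert_to_sitk(vol, info)          # restore spacing / origin / direction
sitk.WriteImage(out_obj, "example_scan_windowed.nii.gz")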
models/ProtoMedSAM.py
ADDED
@@ -0,0 +1,267 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from models.ProtoSAM import ModelWrapper
from segment_anything import sam_model_registry
from util.utils import rotate_tensor_no_crop, reverse_tensor, need_softmax, get_confidence_from_logits, get_connected_components, cca, plot_connected_components

class ProtoMedSAM(nn.Module):
    def __init__(self, image_size, coarse_segmentation_model: ModelWrapper, sam_pretrained_path="pretrained_model/medsam_vit_b.pth", debug=False, use_cca=False, coarse_pred_only=False):
        super().__init__()
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        self.image_size = image_size
        self.coarse_segmentation_model = coarse_segmentation_model
        self.get_sam(sam_pretrained_path)
        self.coarse_pred_only = coarse_pred_only
        self.debug = debug
        self.use_cca = use_cca

    def get_sam(self, checkpoint_path):
        model_type = "vit_b"  # TODO make generic?
        if 'vit_h' in checkpoint_path:
            model_type = "vit_h"
        self.medsam = sam_model_registry[model_type](checkpoint=checkpoint_path).eval()

    @torch.no_grad()
    def medsam_inference(self, img_embed, box_1024, H, W, query_label=None):
        box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
        if len(box_torch.shape) == 2:
            box_torch = box_torch[:, None, :]  # (B, 1, 4)

        sparse_embeddings, dense_embeddings = self.medsam.prompt_encoder(
            points=None,
            boxes=box_torch,
            masks=None,
        )
        low_res_logits, conf = self.medsam.mask_decoder(
            image_embeddings=img_embed,  # (B, 256, 64, 64)
            image_pe=self.medsam.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
            sparse_prompt_embeddings=sparse_embeddings,  # (B, 2, 256)
            dense_prompt_embeddings=dense_embeddings,  # (B, 256, 64, 64)
            multimask_output=True if query_label is not None else False,
        )

        low_res_pred = torch.sigmoid(low_res_logits)  # (1, 1, 256, 256)

        low_res_pred = F.interpolate(
            low_res_pred,
            size=(H, W),
            mode="bilinear",
            align_corners=False,
        )  # (1, 1, gt.shape)
        low_res_pred = low_res_pred.squeeze().cpu()  # (256, 256)

        low_res_pred = low_res_pred.numpy()
        medsam_seg = (low_res_pred > 0.5).astype(np.uint8)

        if query_label is not None:
            medsam_seg = self.get_best_mask(medsam_seg, query_label)[None, :]

        return medsam_seg, conf.cpu().detach().numpy()

    def get_iou(self, pred, label):
        """
        pred: np array of shape (h, w), type uint8
        label: np array of shape (h, w), type uint8
        """
        tp = np.logical_and(pred, label).sum()
        fp = np.logical_and(pred, 1 - label).sum()
        fn = np.logical_and(1 - pred, label).sum()
        iou = tp / (tp + fp + fn)
        return iou

    def get_best_mask(self, masks, labels):
        """
        masks: np array of shape (B, h, w)
        labels: torch tensor of shape (1, H, W)
        """
        np_labels = labels[0].clone().detach().cpu().numpy()
        best_iou, best_mask = 0, None
        for mask in masks:
            iou = self.get_iou(mask, np_labels)
            if iou > best_iou:
                best_iou = iou
                best_mask = mask

        return best_mask

    def get_bbox(self, pred):
        """
        pred is a tensor of shape (H, W) - 1 is fg, 0 is bg.
        return bbox of pred s.t. np.array([xmin, ymin, xmax, ymax])
        """
        if isinstance(pred, np.ndarray):
            pred = torch.from_numpy(pred)
        if pred.max() == 0:
            return None
        indices = torch.nonzero(pred)
        ymin, xmin = indices.min(dim=0)[0]
        ymax, xmax = indices.max(dim=0)[0]
        return np.array([xmin, ymin, xmax, ymax])


    def get_bbox_per_cc(self, conn_components):
        """
        conn_components: output of the cca function
        return a list of bboxes per connected component, each bbox is a list of 2d points
        """
        bboxes = []
        for i in range(1, conn_components[0]):
            # get the indices of the foreground points
            pred = torch.tensor(conn_components[1] == i, dtype=torch.uint8)
            bboxes.append(self.get_bbox(pred))

        bboxes = np.array(bboxes)
        return bboxes

    def forward(self, query_image, coarse_model_input, degrees_rotate=0):
        """
        query_image: 3d tensor of shape (1, 3, H, W)
        images should be normalized with mean and std but not to [0, 1]?
        """
        original_size = query_image.shape[-2]
        # rotate query_image by degrees_rotate
        rotated_img, (rot_h, rot_w) = rotate_tensor_no_crop(query_image, degrees_rotate)
        # print(f"rotating query image took {time.time() - start_time} seconds")
        coarse_model_input.set_query_images(rotated_img)
        output_logits_rot = self.coarse_segmentation_model(coarse_model_input)
        # print(f"ALPNet took {time.time() - start_time} seconds")

        if degrees_rotate != 0:
            output_logits = reverse_tensor(output_logits_rot, rot_h, rot_w, -degrees_rotate)
            # print(f"reversing rotated output_logits took {time.time() - start_time} seconds")
        else:
            output_logits = output_logits_rot

        # check if softmax is needed
        # output_p = output_logits.softmax(dim=1)
        output_p = output_logits
        pred = output_logits.argmax(dim=1)[0]
        if self.debug:
            _pred = np.array(output_logits.argmax(dim=1)[0].detach().cpu())
            plt.subplot(132)
            plt.imshow(query_image[0, 0].detach().cpu())
            plt.imshow(_pred, alpha=0.5)
            plt.subplot(131)
            # plot heatmap of prob of being fg
            plt.imshow(output_p[0, 1].detach().cpu())
            # plot rotated query image and rotated pred
            output_p_rot = output_logits_rot.softmax(dim=1)
            _pred_rot = np.array(output_p_rot.argmax(dim=1)[0].detach().cpu())
            _pred_rot = F.interpolate(torch.tensor(_pred_rot).unsqueeze(0).unsqueeze(0).float(), size=original_size, mode='nearest')[0][0]
            plt.subplot(133)
            plt.imshow(rotated_img[0, 0].detach().cpu())
            plt.imshow(_pred_rot, alpha=0.5)
            plt.savefig('debug/coarse_pred.png')
            plt.close()

        if self.coarse_pred_only:
            output_logits = F.interpolate(output_logits, size=original_size, mode='bilinear') if output_logits.shape[-2:] != original_size else output_logits
            pred = output_logits.argmax(dim=1)[0]
            conf = get_confidence_from_logits(output_logits)
            if self.use_cca:
                _pred = np.array(pred.detach().cpu())
                _pred, conf = cca(_pred, output_logits, return_conf=True)
                pred = torch.from_numpy(_pred)
            if self.training:
                return output_logits, [conf]
            return pred, [conf]

        if query_image.shape[-2:] != self.image_size:
            query_image = F.interpolate(query_image, size=self.image_size, mode='bilinear')
            output_logits = F.interpolate(output_logits, size=self.image_size, mode='bilinear')
        if need_softmax(output_logits):
            output_logits = output_logits.softmax(dim=1)

        output_p = output_logits
        pred = output_p.argmax(dim=1)[0]

        _pred = np.array(output_p.argmax(dim=1)[0].detach().cpu())
        if self.use_cca:
            conn_components = cca(_pred, output_logits, return_cc=True)
            conf = None
        else:
            conn_components, conf = get_connected_components(_pred, output_logits, return_conf=True)
        if self.debug:
            plot_connected_components(conn_components, query_image[0, 0].detach().cpu(), conf)
        # print(f"connected components took {time.time() - start_time} seconds")

        if _pred.max() == 0:
            if output_p.shape[-2:] != original_size:
                output_p = F.interpolate(output_p, size=original_size, mode='bilinear')
            return output_p.argmax(dim=1)[0], [0]

        H, W = query_image.shape[-2:]
        # bbox = self.get_bbox(_pred)
        bbox = self.get_bbox_per_cc(conn_components)
        bbox = bbox / np.array([W, H, W, H]) * max(self.image_size)
        query_image = (query_image - query_image.min()) / (query_image.max() - query_image.min())
        with torch.no_grad():
            image_embedding = self.medsam.image_encoder(query_image)

        medsam_seg, conf = self.medsam_inference(image_embedding, bbox, H, W)

        if self.debug:
            fig, ax = plt.subplots(1, 2)
            ax[0].imshow(query_image[0].permute(1, 2, 0).detach().cpu())
            show_mask(medsam_seg, ax[0])
            ax[1].imshow(query_image[0].permute(1, 2, 0).detach().cpu())
            show_box(bbox[0], ax[1])
            plt.savefig('debug/medsam_pred.png')
            plt.close()

        medsam_seg = torch.tensor(medsam_seg, device=image_embedding.device)
        if medsam_seg.shape[-2:] != original_size:
            medsam_seg = F.interpolate(medsam_seg.unsqueeze(0).unsqueeze(0), size=original_size, mode='nearest')[0][0]

        return medsam_seg, [conf]

    def segment_all(self, query_image, query_label):
        H, W = query_image.shape[-2:]
        # bbox = self.get_bbox(_pred)
        # bbox = self.get_bbox_per_cc(conn_components)
        # bbox = bbox / np.array([W, H, W, H]) * max(self.image_size)
        bbox = np.array([[0, 0, W, H]])
        query_image = (query_image - query_image.min()) / (query_image.max() - query_image.min())
        with torch.no_grad():
            image_embedding = self.medsam.image_encoder(query_image)

        medsam_seg, conf = self.medsam_inference(image_embedding, bbox, H, W, query_label)

        if self.debug:
            fig, ax = plt.subplots(1, 2)
            ax[0].imshow(query_image[0].permute(1, 2, 0).detach().cpu())
            show_mask(medsam_seg, ax[0])
            ax[1].imshow(query_image[0].permute(1, 2, 0).detach().cpu())
            show_box(bbox[0], ax[1])
            plt.savefig('debug/medsam_pred.png')
            plt.close()

        medsam_seg = torch.tensor(medsam_seg, device=image_embedding.device)
        if medsam_seg.shape[-2:] != (H, W):
            medsam_seg = F.interpolate(medsam_seg.unsqueeze(0).unsqueeze(0), size=(H, W), mode='nearest')[0][0]

        return medsam_seg.view(H, W), [conf]


def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([251 / 255, 252 / 255, 30 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(
        plt.Rectangle((x0, y0), w, h, edgecolor="blue", facecolor=(0, 0, 0, 0), lw=2)
    )

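Editor's note: a condensed view of the box-prompt step in ProtoMedSAM.forward above — an XYXY box is extracted from the binary coarse mask and rescaled before prompting the MedSAM decoder. The helper name and the 1024 target size are illustrative assumptions; the class above rescales to max(self.image_size).

import numpy as np
import torch

def mask_to_medsam_box(mask: torch.Tensor, orig_hw, medsam_size=1024):
    """mask: (H, W) binary tensor; returns a (1, 4) float array in XYXY order,
    rescaled from the original resolution to the (assumed) MedSAM input size."""
    ys, xs = torch.nonzero(mask, as_tuple=True)
    box = np.array([[int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())]],
                   dtype=np.float32)
    h, w = orig_hw
    return box / np.array([w, h, w, h], dtype=np.float32) * medsam_size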
models/ProtoSAM.py
ADDED
@@ -0,0 +1,708 @@
1 |
+
import warnings
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
import numpy as np
|
7 |
+
from models.grid_proto_fewshot import FewShotSeg
|
8 |
+
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
|
9 |
+
from models.SamWrapper import SamWrapper
|
10 |
+
from util.utils import cca, get_connected_components, rotate_tensor_no_crop, reverse_tensor, get_confidence_from_logits
|
11 |
+
from util.lora import inject_trainable_lora
|
12 |
+
from models.segment_anything.utils.transforms import ResizeLongestSide
|
13 |
+
import cv2
|
14 |
+
import time
|
15 |
+
from abc import ABC, abstractmethod
|
16 |
+
|
17 |
+
CONF_MODE="conf"
|
18 |
+
CENTROID_MODE="centroid"
|
19 |
+
BOTH_MODE="both"
|
20 |
+
POINT_MODES=(CONF_MODE, CENTROID_MODE, BOTH_MODE)
|
21 |
+
|
22 |
+
TYPE_ALPNET="alpnet"
|
23 |
+
TYPE_SAM="sam"
|
24 |
+
|
25 |
+
def plot_connected_components(cca_output, original_image, confidences:dict=None, title="debug/connected_components.png"):
|
26 |
+
num_labels, labels, stats, centroids = cca_output
|
27 |
+
# Create an output image with random colors for each component
|
28 |
+
output_image = np.zeros((labels.shape[0], labels.shape[1], 3), np.uint8)
|
29 |
+
for label in range(1, num_labels): # Start from 1 to skip the background
|
30 |
+
mask = labels == label
|
31 |
+
output_image[mask] = np.random.randint(0, 255, size=3)
|
32 |
+
|
33 |
+
# Plotting the original and the colored components image
|
34 |
+
plt.figure(figsize=(10, 5))
|
35 |
+
plt.subplot(121), plt.imshow(original_image), plt.title('Original Image')
|
36 |
+
plt.subplot(122), plt.imshow(cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)), plt.title('Connected Components')
|
37 |
+
if confidences is not None:
|
38 |
+
# Plot the axes color chart with the confidences, use the same colors as the connected components
|
39 |
+
plt.subplot(122)
|
40 |
+
scatter = plt.scatter(centroids[:, 0], centroids[:, 1], c=list(confidences.values()), cmap='jet')
|
41 |
+
plt.colorbar(scatter)
|
42 |
+
|
43 |
+
plt.savefig(title)
|
44 |
+
plt.close()
|
45 |
+
|
46 |
+
class SegmentationInput(ABC):
|
47 |
+
@abstractmethod
|
48 |
+
def set_query_images(self, query_images):
|
49 |
+
pass
|
50 |
+
|
51 |
+
def to(self, device):
|
52 |
+
pass
|
53 |
+
|
54 |
+
class SegmentationOutput(ABC):
|
55 |
+
@abstractmethod
|
56 |
+
def get_prediction(self):
|
57 |
+
pass
|
58 |
+
|
59 |
+
class ALPNetInput(SegmentationInput): # for alpnet
|
60 |
+
def __init__(self, support_images:list, support_labels:list, query_images:torch.Tensor, isval, val_wsize, show_viz=False, supp_fts=None):
|
61 |
+
self.supp_imgs = [support_images]
|
62 |
+
self.fore_mask = [support_labels]
|
63 |
+
self.back_mask = [[1 - sup_labels for sup_labels in support_labels]]
|
64 |
+
self.qry_imgs = [query_images]
|
65 |
+
self.isval = isval
|
66 |
+
self.val_wsize = val_wsize
|
67 |
+
self.show_viz = show_viz
|
68 |
+
self.supp_fts = supp_fts
|
69 |
+
|
70 |
+
def set_query_images(self, query_images):
|
71 |
+
self.qry_imgs = [query_images]
|
72 |
+
|
73 |
+
def to(self, device):
|
74 |
+
self.supp_imgs = [[supp_img.to(device) for way in self.supp_imgs for supp_img in way]]
|
75 |
+
self.fore_mask = [[fore_mask.to(device) for way in self.fore_mask for fore_mask in way]]
|
76 |
+
self.back_mask = [[back_mask.to(device) for way in self.back_mask for back_mask in way]]
|
77 |
+
self.qry_imgs = [qry_img.to(device) for qry_img in self.qry_imgs]
|
78 |
+
if self.supp_fts is not None:
|
79 |
+
self.supp_fts = self.supp_fts.to(device)
|
80 |
+
|
81 |
+
class ALPNetOutput(SegmentationOutput):
|
82 |
+
def __init__(self, pred, align_loss, sim_maps, assign_maps, proto_grid, supp_fts, qry_fts):
|
83 |
+
self.pred = pred
|
84 |
+
self.align_loss = align_loss
|
85 |
+
self.sim_maps = sim_maps
|
86 |
+
self.assign_maps = assign_maps
|
87 |
+
self.proto_grid = proto_grid
|
88 |
+
self.supp_fts = supp_fts
|
89 |
+
self.qry_fts = qry_fts
|
90 |
+
|
91 |
+
def get_prediction(self):
|
92 |
+
return self.pred
|
93 |
+
|
94 |
+
class SAMWrapperInput(SegmentationInput):
|
95 |
+
def __init__(self, image, image_labels):
|
96 |
+
self.image = image
|
97 |
+
self.image_labels = image_labels
|
98 |
+
|
99 |
+
def set_query_images(self, query_images):
|
100 |
+
B, C, H, W = query_images.shape
|
101 |
+
if isinstance(query_images, torch.Tensor):
|
102 |
+
query_images = query_images.cpu().detach().numpy()
|
103 |
+
assert B == 1, "batch size must be 1"
|
104 |
+
query_images = (query_images - query_images.min()) / (query_images.max() - query_images.min()) * 255
|
105 |
+
query_images = query_images.astype(np.uint8)
|
106 |
+
self.image = np.transpose(query_images[0], (1, 2, 0))
|
107 |
+
|
108 |
+
def to(self, device):
|
109 |
+
pass
|
110 |
+
|
111 |
+
|
112 |
+
class InputFactory(ABC):
|
113 |
+
@staticmethod
|
114 |
+
def create_input(input_type, query_image, support_images=None, support_labels=None, isval=False, val_wsize=None, show_viz=False, supp_fts=None, original_sz=None, img_sz=None, gts=None):
|
115 |
+
|
116 |
+
if input_type == TYPE_ALPNET:
|
117 |
+
return ALPNetInput(support_images, support_labels, query_image, isval, val_wsize, show_viz, supp_fts)
|
118 |
+
elif input_type == TYPE_SAM:
|
119 |
+
qimg = np.array(query_image.detach().cpu())
|
120 |
+
B,C,H,W = qimg.shape
|
121 |
+
assert B == 1, "batch size must be 1"
|
122 |
+
gts = np.array(gts.detach().cpu()).astype(np.uint8).reshape(H,W)
|
123 |
+
assert np.unique(gts).shape[0] <= 2, "support labels must be binary"
|
124 |
+
gts[gts > 0] = 1
|
125 |
+
qimg = qimg.reshape(H,W,C)
|
126 |
+
qimg = (qimg - qimg.min()) / (qimg.max() - qimg.min()) * 255
|
127 |
+
qimg = qimg.astype(np.uint8)
|
128 |
+
return SAMWrapperInput(qimg, gts)
|
129 |
+
else:
|
130 |
+
raise ValueError(f"input_type not supported")
|
131 |
+
|
132 |
+
|
133 |
+
class ModelWrapper(ABC):
|
134 |
+
def __init__(self, model):
|
135 |
+
self.model = model
|
136 |
+
|
137 |
+
def __call__(self, input_data: SegmentationInput)->SegmentationOutput:
|
138 |
+
pass
|
139 |
+
|
140 |
+
def state_dict(self):
|
141 |
+
return self.model.state_dict()
|
142 |
+
|
143 |
+
def load_state_dict(self, state_dict):
|
144 |
+
self.model.load_state_dict(state_dict)
|
145 |
+
|
146 |
+
def eval(self):
|
147 |
+
self.model.eval()
|
148 |
+
|
149 |
+
def train(self):
|
150 |
+
self.model.train()
|
151 |
+
|
152 |
+
def parameters(self):
|
153 |
+
pass
|
154 |
+
|
155 |
+
class ALPNetWrapper(ModelWrapper):
|
156 |
+
def __init__(self, model: FewShotSeg):
|
157 |
+
super().__init__(model)
|
158 |
+
|
159 |
+
def __call__(self, input_data: ALPNetInput):
|
160 |
+
output = self.model(**input_data.__dict__)
|
161 |
+
output = ALPNetOutput(*output)
|
162 |
+
return output.pred
|
163 |
+
|
164 |
+
def parameters(self):
|
165 |
+
return self.model.encoder.parameters()
|
166 |
+
|
167 |
+
def train(self):
|
168 |
+
self.model.encoder.train()
|
169 |
+
|
170 |
+
class SamWrapperWrapper(ModelWrapper):
|
171 |
+
def __init__(self, model:SamWrapper):
|
172 |
+
super().__init__(model)
|
173 |
+
|
174 |
+
def __call__(self, input_data: SAMWrapperInput):
|
175 |
+
pred = self.model(**input_data.__dict__)
|
176 |
+
# make pred look like logits
|
177 |
+
pred = torch.tensor(pred).float()[None, None, ...]
|
178 |
+
pred = torch.cat([1-pred, pred], dim=1)
|
179 |
+
return pred
|
180 |
+
|
181 |
+
def to(self, device):
|
182 |
+
self.model.sam.to(device)
|
183 |
+
|
184 |
+
class ProtoSAM(nn.Module):
|
185 |
+
def __init__(self, image_size, coarse_segmentation_model:ModelWrapper, sam_pretrained_path="pretrained_model/sam_default.pth", num_points_for_sam=1, use_points=True, use_bbox=False, use_mask=False, debug=False, use_cca=False, point_mode=CONF_MODE, use_sam_trans=True, coarse_pred_only=False, alpnet_image_size=None, use_neg_points=False, ):
|
186 |
+
super().__init__()
|
187 |
+
if isinstance(image_size, int):
|
188 |
+
image_size = (image_size, image_size)
|
189 |
+
self.image_size = image_size
|
190 |
+
self.coarse_segmentation_model = coarse_segmentation_model
|
191 |
+
self.get_sam(sam_pretrained_path, use_sam_trans)
|
192 |
+
self.num_points_for_sam = num_points_for_sam
|
193 |
+
self.use_points = use_points
|
194 |
+
self.use_bbox = use_bbox # if False then uses points
|
195 |
+
self.use_mask = use_mask
|
196 |
+
self.use_neg_points = use_neg_points
|
197 |
+
assert self.use_bbox or self.use_points or self.use_mask, "must use at least one of bbox, points, or mask"
|
198 |
+
self.use_cca = use_cca
|
199 |
+
self.point_mode = point_mode
|
200 |
+
if self.point_mode not in POINT_MODES:
|
201 |
+
raise ValueError(f"point mode must be one of {POINT_MODES}")
|
202 |
+
self.debug=debug
|
203 |
+
self.coarse_pred_only = coarse_pred_only
|
204 |
+
|
205 |
+
def get_sam(self, checkpoint_path, use_sam_trans):
|
206 |
+
model_type="vit_b" # TODO make generic?
|
207 |
+
if 'vit_h' in checkpoint_path:
|
208 |
+
model_type = "vit_h"
|
209 |
+
self.sam = sam_model_registry[model_type](checkpoint=checkpoint_path).eval()
|
210 |
+
self.predictor = SamPredictor(self.sam)
|
211 |
+
self.sam.requires_grad_(False)
|
212 |
+
if use_sam_trans:
|
213 |
+
# sam_trans = ResizeLongestSide(self.sam.image_encoder.img_size, pixel_mean=[0], pixel_std=[1])
|
214 |
+
sam_trans = ResizeLongestSide(self.sam.image_encoder.img_size)
|
215 |
+
sam_trans.pixel_mean = torch.tensor([0, 0, 0]).view(3, 1, 1)
|
216 |
+
sam_trans.pixel_std = torch.tensor([1, 1, 1]).view(3, 1, 1)
|
217 |
+
else:
|
218 |
+
sam_trans = None
|
219 |
+
|
220 |
+
self.sam_trans = sam_trans
|
221 |
+
|
222 |
+
def get_bbox(self, pred):
|
223 |
+
'''
|
224 |
+
pred tensor of shape (H, W) where 1 represents foreground and 0 represents background
|
225 |
+
returns a list of 2d points representing the bbox
|
226 |
+
'''
|
227 |
+
if isinstance(pred, np.ndarray):
|
228 |
+
pred = torch.tensor(pred)
|
229 |
+
# get the indices of the foreground points
|
230 |
+
indices = torch.nonzero(pred)
|
231 |
+
# get the min and max of the indices
|
232 |
+
min_x = indices[:, 1].min()
|
233 |
+
max_x = indices[:, 1].max()
|
234 |
+
min_y = indices[:, 0].min()
|
235 |
+
max_y = indices[:, 0].max()
|
236 |
+
# get the bbox
|
237 |
+
bbox = [[min_y, min_x], [min_y, max_x], [max_y, max_x], [max_y, min_x]]
|
238 |
+
|
239 |
+
|
240 |
+
return bbox
|
241 |
+
|
242 |
+
def get_bbox_per_cc(self, conn_components):
|
243 |
+
"""
|
244 |
+
conn_components: output of cca function
|
245 |
+
return list of bboxes per connected component, each bbox is a list of 2d points
|
246 |
+
"""
|
247 |
+
bboxes = []
|
248 |
+
for i in range(1, conn_components[0]):
|
249 |
+
# get the indices of the foreground points
|
250 |
+
indices = torch.nonzero(torch.tensor(conn_components[1] == i))
|
251 |
+
# get the min and max of the indices
|
252 |
+
min_x = indices[:, 1].min()
|
253 |
+
max_x = indices[:, 1].max()
|
254 |
+
min_y = indices[:, 0].min()
|
255 |
+
max_y = indices[:, 0].max()
|
256 |
+
# get the bbox
|
257 |
+
# bbox = [[min_y, min_x], [min_y, max_x], [max_y, max_x], [max_y, min_x]]
|
258 |
+
# bbox = [[min_x, min_y], [max_x, min_y], [max_x, max_y], [min_x, max_y]]
|
259 |
+
# bbox should be in a XYXY format
|
260 |
+
bbox = [min_x, min_y, max_x, max_y]
|
261 |
+
bboxes.append(bbox)
|
262 |
+
|
263 |
+
bboxes = np.array(bboxes)
|
264 |
+
return bboxes
|
265 |
+
|
266 |
+
def get_most_conf_points(self, output_p_fg, pred, k):
|
267 |
+
'''
|
268 |
+
get the k most confident points from pred
|
269 |
+
output_p: 3d tensor of shape (H, W)
|
270 |
+
pred: 2d tensor of shape (H, W) where 1 represents foreground and 0 represents background
|
271 |
+
'''
|
272 |
+
# Create a mask where pred is 1
|
273 |
+
mask = pred.bool()
|
274 |
+
|
275 |
+
# Apply the mask to output_p_fg
|
276 |
+
masked_output_p_fg = output_p_fg[mask]
|
277 |
+
if masked_output_p_fg.numel() == 0:
|
278 |
+
return None, None
|
279 |
+
# Get the top k probabilities and their indices
|
280 |
+
confidences, indices = torch.topk(masked_output_p_fg, k)
|
281 |
+
|
282 |
+
# Get the locations of the top k points in xy format
|
283 |
+
locations = torch.nonzero(mask)[indices]
|
284 |
+
# convert locations to xy format
|
285 |
+
locations = locations[:, [1, 0]]
|
286 |
+
# convert locations to list of lists
|
287 |
+
# points = [loc.tolist() for loc in locations]
|
288 |
+
|
289 |
+
return locations.numpy(), [float(conf.item()) for conf in confidences]
|
290 |
+
|
291 |
+
|
292 |
+
def plot_most_conf_points(self, points, confidences, pred, image, bboxes=None, title=None):
|
293 |
+
'''
|
294 |
+
points: np array of shape (N, 2) where each row is a point in xy format
|
295 |
+
pred: 2d tensor of shape (H, W) where 1 represents foreground and 0 represents background
|
296 |
+
image: 2d tensor of shape (H,W) representing the image
|
297 |
+
bbox: list or np array of shape (N, 4) where each row is a bbox in xyxy format
|
298 |
+
'''
|
299 |
+
warnings.filterwarnings('ignore', category=UserWarning)
|
300 |
+
if isinstance(pred, torch.Tensor):
|
301 |
+
pred = pred.cpu().detach().numpy()
|
302 |
+
if len(image.shape) == 3 and image.shape[0] == 3:
|
303 |
+
image = image.permute(1, 2, 0)
|
304 |
+
if title is None:
|
305 |
+
title="debug/most_conf_points.png"
|
306 |
+
|
307 |
+
fig = plt.figure()
|
308 |
+
image = (image - image.min()) / (image.max() - image.min())
|
309 |
+
plt.imshow(image)
|
310 |
+
plt.imshow(pred, alpha=0.5)
|
311 |
+
for i, point in enumerate(points):
|
312 |
+
plt.scatter(point[0][0], point[0][1], cmap='viridis', marker='*', c='red')
|
313 |
+
if confidences is not None:
|
314 |
+
plt.text(point[0], point[1], f"{confidences[i]:.3f}", fontsize=12, color='red')
|
315 |
+
# assume points is a list of lists
|
316 |
+
if bboxes is not None:
|
317 |
+
for bbox in bboxes:
|
318 |
+
if bbox is None:
|
319 |
+
continue
|
320 |
+
bbox = np.array(bbox)
|
321 |
+
# plt.scatter(bbox[:, 1], bbox[:, 0], c='red')
|
322 |
+
# plot a line connecting the points
|
323 |
+
box = np.array([[bbox[0], bbox[1]], [bbox[2], bbox[1]], [bbox[2], bbox[3]], [bbox[0], bbox[3]]])
|
324 |
+
box = np.vstack([box, box[0]])
|
325 |
+
plt.plot(box[:, 0], box[:, 1], c='green')
|
326 |
+
plt.colorbar()
|
327 |
+
fig.savefig(title)
|
328 |
+
plt.close(fig)
|
329 |
+
|
330 |
+
def plot_sam_preds(self, masks, scores, image, input_point, input_label, input_box=None):
|
331 |
+
if len(image.shape) == 3:
|
332 |
+
image = image.permute(1, 2, 0)
|
333 |
+
image = (image - image.min()) / (image.max() - image.min())
|
334 |
+
for i, (mask, score) in enumerate(zip(masks, scores)):
|
335 |
+
plt.figure(figsize=(10,10))
|
336 |
+
plt.imshow(image)
|
337 |
+
show_mask(mask, plt.gca())
|
338 |
+
if input_point is not None:
|
339 |
+
show_points(input_point, input_label, plt.gca())
|
340 |
+
if input_box is not None:
|
341 |
+
show_box(input_box, plt.gca())
|
342 |
+
plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
|
343 |
+
# plt.axis('off')
|
344 |
+
plt.savefig(f'debug/sam_mask_{i+1}.png')
|
345 |
+
plt.close()
|
346 |
+
if i > 5:
|
347 |
+
break
|
348 |
+
|
349 |
+
def get_sam_input_points(self, conn_components, output_p, get_neg_points=False, l=1):
|
350 |
+
"""
|
351 |
+
args:
|
352 |
+
conn_components: output of cca function
|
353 |
+
output_p: 3d tensor of shape (1, 2, H, W)
|
354 |
+
get_neg_points: bool, if True then return the negative points
|
355 |
+
l: int, number of negative points to get
|
356 |
+
"""
|
357 |
+
sam_input_points = []
|
358 |
+
sam_neg_points = []
|
359 |
+
fg_p = output_p[0, 1].detach().cpu()
|
360 |
+
|
361 |
+
if get_neg_points:
|
362 |
+
# get global negative points
|
363 |
+
bg_p = output_p[0, 0].detach().cpu()
|
364 |
+
bg_p[bg_p < 0.95] = 0
|
365 |
+
bg_pred = torch.where(bg_p > 0, 1, 0)
|
366 |
+
glob_neg_points, _ = self.get_most_conf_points(bg_p, bg_pred, 1)
|
367 |
+
if self.debug:
|
368 |
+
# plot the bg_p as a heatmap
|
369 |
+
plt.figure()
|
370 |
+
plt.imshow(bg_p)
|
371 |
+
plt.colorbar()
|
372 |
+
plt.savefig('debug/bg_p_heatmap.png')
|
373 |
+
plt.close()
|
374 |
+
|
375 |
+
for i, cc_id in enumerate(np.unique(conn_components[1])):
|
376 |
+
# get self.num_points_for_sam most confident points from pred
|
377 |
+
if cc_id == 0:
|
378 |
+
continue # skip background
|
379 |
+
pred = torch.tensor(conn_components[1] == cc_id).float()
|
380 |
+
|
381 |
+
if self.point_mode == CONF_MODE:
|
382 |
+
points, confidences = self.get_most_conf_points(fg_p, pred, self.num_points_for_sam) # (N, 2)
|
383 |
+
elif self.point_mode == CENTROID_MODE:
|
384 |
+
points = conn_components[3][cc_id][None, :] # (1, 2)
|
385 |
+
confidences = [1 for _ in range(len(points))]
|
386 |
+
elif self.point_mode == BOTH_MODE:
|
387 |
+
points, confidences = self.get_most_conf_points(fg_p, pred, self.num_points_for_sam)
|
388 |
+
point = conn_components[3][cc_id][None, :]
|
389 |
+
points = np.vstack([points, point]) # (N+1, 2)
|
390 |
+
confidences.append(1)
|
391 |
+
else:
|
392 |
+
raise NotImplementedError(f"point mode {self.point_mode} not implemented")
|
393 |
+
sam_input_points.append(np.array(points))
|
394 |
+
|
395 |
+
if get_neg_points:
|
396 |
+
pred_uint8 = (pred.numpy() * 255).astype(np.uint8)
|
397 |
+
|
398 |
+
# Dilate the mask to expand it
|
399 |
+
kernel_size = 3 # Size of the dilation kernel, adjust accordingly
|
400 |
+
kernel = np.ones((kernel_size, kernel_size), np.uint8)
|
401 |
+
dilation_iterations = 10 # Number of times dilation is applied, adjust as needed
|
402 |
+
dilated_mask = cv2.dilate(pred_uint8, kernel, iterations=dilation_iterations)
|
403 |
+
|
404 |
+
# Subtract the original mask from the dilated mask
|
405 |
+
# This will give a boundary that is only outside the original mask
|
406 |
+
outside_boundary = dilated_mask - pred_uint8
|
407 |
+
|
408 |
+
# Convert back to torch tensor and normalize
|
409 |
+
boundary = torch.tensor(outside_boundary).float() / 255
|
410 |
+
try:
|
411 |
+
bg_p = output_p[0, 0].detach().cpu()
|
412 |
+
neg_points, neg_confidences = self.get_most_conf_points(bg_p, boundary, l)
|
413 |
+
except RuntimeError as e:
|
414 |
+
# make each point (None, None)
|
415 |
+
neg_points = None
|
416 |
+
# append global negative points to the negative points
|
417 |
+
if neg_points is not None and glob_neg_points is not None:
|
418 |
+
neg_points = np.vstack([neg_points, glob_neg_points])
|
419 |
+
else:
|
420 |
+
neg_points = glob_neg_points if neg_points is None else neg_points
|
421 |
+
if self.debug and neg_points is not None:
|
422 |
+
# draw an image with 2 subplots, one is the pred and the other is the boundary
|
423 |
+
plt.figure()
|
424 |
+
plt.subplot(121)
|
425 |
+
plt.imshow(pred)
|
426 |
+
plt.imshow(boundary, alpha=0.5)
|
427 |
+
# plot the neg points
|
428 |
+
plt.scatter(neg_points[:, 0], neg_points[:, 1], cmap='viridis', marker='*', c='red')
|
429 |
+
plt.subplot(122)
|
430 |
+
plt.imshow(pred)
|
431 |
+
plt.scatter(neg_points[:, 0], neg_points[:, 1], cmap='viridis', marker='*', c='red')
|
432 |
+
plt.savefig('debug/pred_and_boundary.png')
|
433 |
+
plt.close()
|
434 |
+
sam_neg_points.append(neg_points)
|
435 |
+
else:
|
436 |
+
# create a list of None same shape as points
|
437 |
+
sam_neg_points = [None for _ in range(len(sam_input_points))]
|
438 |
+
|
439 |
+
sam_input_labels = np.array([l+1 for l, cc_points in enumerate(sam_input_points) for _ in range(len(cc_points))])
|
440 |
+
sam_input_points = np.stack(sam_input_points) # should be of shape (num_connected_components, num_points_for_sam, 2)
|
441 |
+
# if get_neg_points:
|
442 |
+
sam_neg_input_points = np.stack(sam_neg_points) if sam_neg_points is not None else None
|
443 |
+
if sam_neg_input_points is not None:
|
444 |
+
sam_neg_input_points = sam_neg_points
|
445 |
+
sam_neg_input_labels = np.array([0] * len(sam_neg_input_points) )
|
446 |
+
else:
|
447 |
+
sam_neg_input_points = None
|
448 |
+
sam_neg_input_labels = None
|
449 |
+
|
450 |
+
return sam_input_points, sam_input_labels, sam_neg_input_points, sam_neg_input_labels
|
451 |
+
|
452 |
+
def get_sam_input_mask(self, conn_components):
|
453 |
+
sam_input_masks = []
|
454 |
+
sam_input_mask_lables = []
|
455 |
+
for i, cc_id in enumerate(np.unique(conn_components[1])):
|
456 |
+
# get self.num_points_for_sam most confident points from pred
|
457 |
+
if cc_id == 0:
|
458 |
+
continue
|
459 |
+
pred = torch.tensor(conn_components[1] == cc_id).float()
|
460 |
+
sam_input_masks.append(pred)
|
461 |
+
sam_input_mask_lables.append(cc_id)
|
462 |
+
|
463 |
+
sam_input_masks = np.stack(sam_input_masks)
|
464 |
+
sam_input_mask_lables = np.array(sam_input_mask_lables)
|
465 |
+
|
466 |
+
return sam_input_masks, sam_input_mask_lables
|
467 |
+
|
468 |
+
def predict_w_masks(self, sam_input_masks, qry_img, original_size):
|
469 |
+
masks = []
|
470 |
+
scores = []
|
471 |
+
for in_mask in sam_input_masks:
|
472 |
+
in_mask = cv2.resize(in_mask, (256, 256), interpolation=cv2.INTER_NEAREST)
|
473 |
+
in_mask[in_mask == 1] = 10
|
474 |
+
in_mask[in_mask == 0] = -8
|
475 |
+
assert qry_img.max() <= 255 and qry_img.min() >= 0 and qry_img.dtype == np.uint8
|
476 |
+
self.predictor.set_image(qry_img)
|
477 |
+
mask, score, _ = self.predictor.predict(
|
478 |
+
mask_input=in_mask[None, ...].astype(np.uint8),
|
479 |
+
multimask_output=True)
|
480 |
+
# get max index from score
|
481 |
+
if self.debug:
|
482 |
+
# plot each channel of mask
|
483 |
+
fig, ax = plt.subplots(1, 4, figsize=(15, 5))
|
484 |
+
for i in range(mask.shape[0]):
|
485 |
+
ax[i].imshow(qry_img)
|
486 |
+
ax[i].imshow(mask[i], alpha=0.5)
|
487 |
+
ax[i].set_title(f"Mask {i+1}, Score: {score[i]:.3f}", fontsize=18)
|
488 |
+
# ax[i].axis('off')
|
489 |
+
ax[-1].imshow(cv2.resize(in_mask, original_size, interpolation=cv2.INTER_NEAREST))
|
490 |
+
fig.savefig(f'debug/sam_mask_from_mask_prompts.png')
|
491 |
+
plt.close(fig)
|
492 |
+
|
493 |
+
|
494 |
+
max_index = score.argmax()
|
495 |
+
masks.append(mask[max_index])
|
496 |
+
scores.append(score[max_index])
|
497 |
+
|
498 |
+
return masks, scores
|
499 |
+
|
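A brief note on the constants above: SAM's `mask_input` is interpreted as unthresholded logits on a 256x256 grid, so writing 10 for foreground and -8 for background marks the coarse mask as confident in both directions. A quick check of what those values mean after a sigmoid (illustrative only):

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

print(sigmoid(10.0))   # ~0.99995, confident foreground
print(sigmoid(-8.0))   # ~0.00034, confident background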
500 |
+
def predict_w_points_bbox(self, sam_input_points, bboxes, sam_neg_input_points, qry_img, pred, return_logits=False):
|
501 |
+
masks, scores = [], []
|
502 |
+
self.predictor.set_image(qry_img)
|
503 |
+
# if sam_input_points is None:
|
504 |
+
# sam_input_points = [None for _ in range(len(bboxes))]
|
505 |
+
for point, bbox_xyxy, neg_point in zip(sam_input_points, bboxes, sam_neg_input_points):
|
506 |
+
assert qry_img.max() <= 255 and qry_img.min() >= 0 and qry_img.dtype == np.uint8
|
507 |
+
points = point
|
508 |
+
point_labels = np.array([1] * len(point)) if point is not None else None
|
509 |
+
if self.use_neg_points:
|
510 |
+
neg_points = [npoint for npoint in neg_point if None not in npoint]
|
511 |
+
points = np.vstack([point, *neg_points])
|
512 |
+
point_labels = np.array([1] * len(point) + [0] * len(neg_points))
|
513 |
+
if self.debug:
|
514 |
+
self.plot_most_conf_points(points[:, None, ...], None, pred, qry_img, bboxes=bbox_xyxy[None,...] if bbox_xyxy is not None else None, title="debug/pos_neg_points.png") # TODO add plots for all points not just the first set of points
|
515 |
+
mask, score, _ = self.predictor.predict(
|
516 |
+
point_coords=points,
|
517 |
+
point_labels=point_labels,
|
518 |
+
# box=bbox_xyxy[None, :] if bbox_xyxy is not None else None,
|
519 |
+
box = bbox_xyxy if bbox_xyxy is not None else None,
|
520 |
+
# mask_input=sam_mask_input,
|
521 |
+
return_logits=return_logits,
|
522 |
+
multimask_output=False if self.use_cca else True
|
523 |
+
)
|
524 |
+
# best_pred_idx = np.argmax(score)
|
525 |
+
best_pred_idx = 0
|
526 |
+
masks.append(mask[best_pred_idx])
|
527 |
+
scores.append(score[best_pred_idx])
|
528 |
+
|
529 |
+
if self.debug:
|
530 |
+
# pass
|
531 |
+
self.plot_sam_preds(mask, score, qry_img[...,0], points.reshape(-1,2) if sam_input_points is not None else None, point_labels, input_box=bbox_xyxy if bbox_xyxy is not None else None)
|
532 |
+
|
533 |
+
return masks, scores
|
534 |
+
|
535 |
+
|
536 |
+
def forward(self, query_image, coarse_model_input, degrees_rotate=0):
|
537 |
+
"""
|
538 |
+
query_image: 4d tensor of shape (1, 3, H, W)
|
539 |
+
images should be normalized with mean and std but not to [0, 1]?
|
540 |
+
"""
|
541 |
+
original_size = query_image.shape[-2]
|
542 |
+
# rotate query_image by degrees_rotate
|
543 |
+
start_time = time.time()
|
544 |
+
rotated_img, (rot_h, rot_w) = rotate_tensor_no_crop(query_image, degrees_rotate)
|
545 |
+
# print(f"rotating query image took {time.time() - start_time} seconds")
|
546 |
+
start_time = time.time()
|
547 |
+
coarse_model_input.set_query_images(rotated_img)
|
548 |
+
output_logits_rot = self.coarse_segmentation_model(coarse_model_input)
|
549 |
+
# print(f"ALPNet took {time.time() - start_time} seconds")
|
550 |
+
|
551 |
+
if degrees_rotate != 0:
|
552 |
+
start_time = time.time()
|
553 |
+
output_logits = reverse_tensor(output_logits_rot, rot_h, rot_w, -degrees_rotate)
|
554 |
+
# print(f"reversing rotated output_logits took {time.time() - start_time} seconds")
|
555 |
+
else:
|
556 |
+
output_logits = output_logits_rot
|
557 |
+
|
558 |
+
# check if softmax is needed
|
559 |
+
output_p = output_logits.softmax(dim=1)
|
560 |
+
# output_p = output_logits
|
561 |
+
pred = output_logits.argmax(dim=1)[0]
|
562 |
+
if self.debug:
|
563 |
+
_pred = np.array(output_logits.argmax(dim=1)[0].detach().cpu())
|
564 |
+
plt.subplot(132)
|
565 |
+
plt.imshow(query_image[0,0].detach().cpu())
|
566 |
+
plt.imshow(_pred, alpha=0.5)
|
567 |
+
plt.subplot(131)
|
568 |
+
# plot heatmap of prob of being fg
|
569 |
+
plt.imshow(output_p[0, 1].detach().cpu())
|
570 |
+
# plot rotated query image and rotated pred
|
571 |
+
output_p_rot = output_logits_rot.softmax(dim=1)
|
572 |
+
_pred_rot = np.array(output_p_rot.argmax(dim=1)[0].detach().cpu())
|
573 |
+
_pred_rot = F.interpolate(torch.tensor(_pred_rot).unsqueeze(0).unsqueeze(0).float(), size=original_size, mode='nearest')[0][0]
|
574 |
+
plt.subplot(133)
|
575 |
+
plt.imshow(rotated_img[0, 0].detach().cpu())
|
576 |
+
plt.imshow(_pred_rot, alpha=0.5)
|
577 |
+
plt.savefig('debug/coarse_pred.png')
|
578 |
+
plt.close()
|
579 |
+
|
580 |
+
if self.coarse_pred_only:
|
581 |
+
output_logits = F.interpolate(output_logits, size=original_size, mode='bilinear') if output_logits.shape[-2:] != original_size else output_logits
|
582 |
+
pred = output_logits.argmax(dim=1)[0]
|
583 |
+
conf = get_confidence_from_logits(output_logits)
|
584 |
+
if self.use_cca:
|
585 |
+
_pred = np.array(pred.detach().cpu())
|
586 |
+
_pred, conf = cca(_pred, output_logits, return_conf=True)
|
587 |
+
pred = torch.from_numpy(_pred)
|
588 |
+
if self.training:
|
589 |
+
return output_logits, [conf]
|
590 |
+
# Ensure pred is a float tensor for consistent visualization
|
591 |
+
return pred.float(), [conf]
|
592 |
+
|
593 |
+
if query_image.shape[-2:] != self.image_size:
|
594 |
+
query_image = F.interpolate(query_image, size=self.image_size, mode='bilinear')
|
595 |
+
output_logits = F.interpolate(output_logits, size=self.image_size, mode='bilinear')
|
596 |
+
# if need_softmax(output_logits):
|
597 |
+
# output_logits = output_logits.softmax(dim=1)
|
598 |
+
|
599 |
+
# output_p = output_logits
|
600 |
+
output_p = output_logits.softmax(dim=1)
|
601 |
+
pred = output_p.argmax(dim=1)[0]
|
602 |
+
|
603 |
+
_pred = np.array(output_p.argmax(dim=1)[0].detach().cpu())
|
604 |
+
start_time = time.time()
|
605 |
+
if self.use_cca:
|
606 |
+
conn_components = cca(_pred, output_logits, return_cc=True)
|
607 |
+
conf=None
|
608 |
+
else:
|
609 |
+
conn_components, conf = get_connected_components(_pred, output_logits, return_conf=True)
|
610 |
+
if self.debug:
|
611 |
+
plot_connected_components(conn_components, query_image[0,0].detach().cpu(), conf)
|
612 |
+
# print(f"connected components took {time.time() - start_time} seconds")
|
613 |
+
if _pred.max() == 0:
|
614 |
+
return output_p.argmax(dim=1)[0], [0]
|
615 |
+
|
616 |
+
# get bbox from pred
|
617 |
+
if self.use_bbox:
|
618 |
+
start_time = time.time()
|
619 |
+
try:
|
620 |
+
bboxes = self.get_bbox_per_cc(conn_components)
|
621 |
+
except:
|
622 |
+
bboxes = [None] * conn_components[0]
|
623 |
+
else:
|
624 |
+
bboxes = [None] * conn_components[0]
|
625 |
+
# print(f"getting bboxes took {time.time() - start_time} seconds")
|
626 |
+
|
627 |
+
|
628 |
+
start_time = time.time()
|
629 |
+
if self.use_points:
|
630 |
+
sam_input_points, sam_input_point_labels, sam_neg_input_points, sam_neg_input_labels = self.get_sam_input_points(conn_components, output_p, get_neg_points=self.use_neg_points, l=1)
|
631 |
+
else:
|
632 |
+
sam_input_points = [None] * conn_components[0]
|
633 |
+
sam_input_point_labels = [None] * conn_components[0]
|
634 |
+
sam_neg_input_points = [None] * conn_components[0]
|
635 |
+
sam_neg_input_labels = [None] * conn_components[0]
|
636 |
+
# print(f"getting sam input points took {time.time() - start_time} seconds")
|
637 |
+
|
638 |
+
if self.use_mask:
|
639 |
+
sam_input_masks, sam_input_mask_labels = self.get_sam_input_mask(conn_components)
|
640 |
+
else:
|
641 |
+
sam_input_masks = None
|
642 |
+
sam_input_mask_labels = None
|
643 |
+
|
644 |
+
if self.debug and sam_input_points is not None:
|
645 |
+
title = f'debug/most_conf_points.png'
|
646 |
+
if self.use_cca:
|
647 |
+
title = f'debug/most_conf_points_cca.png'
|
648 |
+
# convert points to a list where each item is a list of 2 elements in xy format
|
649 |
+
self.plot_most_conf_points(sam_input_points, None, _pred, query_image[0, 0].detach().cpu(), bboxes=bboxes, title=title) # TODO add plots for all points not just the first set of points
|
650 |
+
|
651 |
+
# self.sam_trans = None
|
652 |
+
if self.sam_trans is None:
|
653 |
+
query_image = query_image.permute(1, 2, 0).detach().cpu().numpy()
|
654 |
+
else:
|
655 |
+
query_image = self.sam_trans.apply_image_torch(query_image[0])
|
656 |
+
query_image = self.sam_trans.preprocess(query_image)
|
657 |
+
query_image = query_image.permute(1, 2, 0).detach().cpu().numpy()
|
658 |
+
# mask = self.sam_trans.preprocess(mask)
|
659 |
+
|
660 |
+
|
661 |
+
query_image = ((query_image - query_image.min()) / (query_image.max() - query_image.min()) * 255).astype(np.uint8)
|
662 |
+
if self.use_mask:
|
663 |
+
masks, scores = self.predict_w_masks(sam_input_masks, query_image, original_size)
|
664 |
+
|
665 |
+
start_time = time.time()
|
666 |
+
if self.use_points or self.use_bbox:
|
667 |
+
masks, scores = self.predict_w_points_bbox(sam_input_points, bboxes, sam_neg_input_points, query_image, pred, return_logits=True if self.training else False)
|
668 |
+
# print(f"predicting w points/bbox took {time.time() - start_time} seconds")
|
669 |
+
|
670 |
+
pred = sum(masks)
|
671 |
+
if not self.training:
|
672 |
+
pred = pred > 0
|
673 |
+
pred = torch.tensor(pred).float().to(output_p.device)
|
674 |
+
|
675 |
+
# pred = torch.tensor(masks[0]).float().cuda()
|
676 |
+
# resize pred to the size of the input
|
677 |
+
pred = F.interpolate(pred.unsqueeze(0).unsqueeze(0), size=original_size, mode='nearest')[0][0]
|
678 |
+
|
679 |
+
return pred, scores
|
680 |
+
|
681 |
+
|
682 |
+
def show_mask(mask, ax, random_color=False):
|
683 |
+
if random_color:
|
684 |
+
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
685 |
+
else:
|
686 |
+
color = np.array([30/255, 144/255, 255/255, 0.6])
|
687 |
+
h, w = mask.shape[-2:]
|
688 |
+
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
689 |
+
ax.imshow(mask_image)
|
690 |
+
|
691 |
+
def show_points(coords, labels, ax, marker_size=375):
|
692 |
+
pos_points = coords[labels==1]
|
693 |
+
neg_points = coords[labels==0]
|
694 |
+
ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
|
695 |
+
ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
|
696 |
+
|
697 |
+
def show_box(box, ax):
|
698 |
+
x0, y0 = box[0], box[1]
|
699 |
+
w, h = box[2] - box[0], box[3] - box[1]
|
700 |
+
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
|
701 |
+
|
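The three helpers above follow the plotting utilities from the SAM example notebooks. A minimal usage sketch; the image, mask, points and box below are placeholders:

import numpy as np
import matplotlib.pyplot as plt

image = np.zeros((256, 256, 3), dtype=np.uint8)   # placeholder query image
mask = np.zeros((256, 256), dtype=bool)
mask[80:160, 80:160] = True
points = np.array([[120, 120], [30, 30]])
labels = np.array([1, 0])                          # 1 = positive prompt, 0 = negative prompt

fig, ax = plt.subplots()
ax.imshow(image)
show_mask(mask, ax)
show_points(points, labels, ax)
show_box([80, 80, 160, 160], ax)
fig.savefig('debug/prompt_overlay.png')
plt.close(fig)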
702 |
+
def need_softmax(tensor, dim=1):
|
703 |
+
return not torch.all(torch.isclose(tensor.sum(dim=dim), torch.ones_like(tensor.sum(dim=dim))) & (tensor >= 0))
|
704 |
+
|
705 |
+
|
706 |
+
|
707 |
+
|
708 |
+
|
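Taken together, `forward` implements a two-stage pipeline: the coarse few-shot model produces an initial segmentation, connected-component analysis splits it into candidate regions, and each region is turned into point/box/mask prompts that SAM refines. A condensed sketch of that control flow, written as a helper that receives the already-built module and inputs; the names mirror the methods above, but this is an illustration rather than the exact code path:

def protosam_inference_sketch(protosam, coarse_model_input, query_image_uint8, cca):
    # 1) coarse few-shot prediction
    coarse_logits = protosam.coarse_segmentation_model(coarse_model_input)   # (1, 2, h, w)
    output_p = coarse_logits.softmax(dim=1)
    coarse_pred = output_p.argmax(dim=1)[0]                                  # (h, w)
    # 2) split the coarse mask into connected components
    conn_components = cca(coarse_pred.detach().cpu().numpy(), coarse_logits, return_cc=True)
    # 3) derive prompts per component and refine with SAM
    points, labels, neg_points, _ = protosam.get_sam_input_points(conn_components, output_p)
    bboxes = protosam.get_bbox_per_cc(conn_components)
    masks, scores = protosam.predict_w_points_bbox(points, bboxes, neg_points, query_image_uint8, coarse_pred)
    return sum(masks) > 0, scores                                            # union of per-component masks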
models/SamWrapper.py
ADDED
@@ -0,0 +1,68 @@
import torch
import torch.nn as nn
import numpy as np
from models.segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
from models.segment_anything.utils.transforms import ResizeLongestSide
import cv2


def get_iou(mask, label):
    tp = (mask * label).sum()
    fp = (mask * (1 - label)).sum()
    fn = ((1 - mask) * label).sum()
    iou = tp / (tp + fp + fn)
    return iou


class SamWrapper(nn.Module):
    def __init__(self, sam_args):
        """
        sam_args: dict, expected to contain the following keys:
        {
            "model_type": "vit_h",
            "sam_checkpoint": "path to checkpoint", e.g. pretrained_model/sam_vit_h.pth
        }
        """
        super().__init__()
        self.sam = sam_model_registry[sam_args['model_type']](checkpoint=sam_args['sam_checkpoint'])
        self.mask_generator = SamAutomaticMaskGenerator(self.sam)
        self.transform = ResizeLongestSide(self.sam.image_encoder.img_size)

    def forward(self, image, image_labels):
        """
        Generate candidate masks for an image and return the mask with the largest IoU
        against the given label.
        Args:
            image (np.ndarray): The image to generate masks for, in HWC uint8 format.
            image_labels (np.ndarray): Binary label mask used to pick the best candidate.
        """
        image = self.transform.apply_image(image)
        masks = self.mask_generator.generate(image)

        best_index, best_iou = None, 0
        for i, mask in enumerate(masks):
            segmentation = mask['segmentation']
            iou = get_iou(segmentation.astype(np.uint8), image_labels)
            if best_index is None or iou > best_iou:
                best_index = i
                best_iou = iou

        return masks[best_index]['segmentation']

    def to(self, device):
        self.sam.to(device)
        self.mask_generator.to(device)
        self.mask_generator.predictor.to(device)


if __name__ == "__main__":
    sam_args = {
        "model_type": "vit_h",
        "sam_checkpoint": "pretrained_model/sam_vit_h.pth"
    }
    sam_wrapper = SamWrapper(sam_args).cuda()
    image = cv2.imread("./Kheops-Pyramid.jpg")
    image = np.array(image).astype('uint8')
    image_labels = torch.rand(1, 3, 224, 224)
    sam_wrapper(image, image_labels)
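A small worked example of `get_iou`, plus a hedged usage sketch of the wrapper; the checkpoint path is the same placeholder used in the __main__ block, and the label here is a synthetic HxW binary mask rather than the random tensor above:

import numpy as np

a = np.zeros((4, 4), dtype=np.uint8)
a[0, :3] = 1                        # 3 foreground pixels
b = np.zeros((4, 4), dtype=np.uint8)
b[0, 1:4] = 1                       # 3 foreground pixels, 2 overlapping with a
print(get_iou(a, b))                # 2 / (3 + 3 - 2) = 0.5

# wrapper usage (needs the SAM weights and a GPU):
# sam_args = {"model_type": "vit_h", "sam_checkpoint": "pretrained_model/sam_vit_h.pth"}
# wrapper = SamWrapper(sam_args).cuda()
# label = np.zeros(image.shape[:2], dtype=np.uint8)   # HxW binary label to match against
# best_mask = wrapper(image, label)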
models/__init__.py
ADDED
File without changes
|
models/__pycache__/ProtoSAM.cpython-312.pyc
ADDED
Binary file (42.4 kB)
models/__pycache__/SamWrapper.cpython-312.pyc
ADDED
Binary file (3.84 kB)
models/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (161 Bytes)
models/__pycache__/alpmodule.cpython-312.pyc
ADDED
Binary file (12 kB)
models/__pycache__/grid_proto_fewshot.cpython-312.pyc
ADDED
Binary file (23.4 kB)
models/alpmodule.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
ALPModule
|
3 |
+
"""
|
4 |
+
import torch
|
5 |
+
import time
|
6 |
+
import math
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn import functional as F
|
9 |
+
import numpy as np
|
10 |
+
from pdb import set_trace
|
11 |
+
import matplotlib.pyplot as plt
|
12 |
+
# for unit test from spatial_similarity_module import NONLocalBlock2D, LayerNorm
|
13 |
+
|
14 |
+
def safe_norm(x, p = 2, dim = 1, eps = 1e-4):
|
15 |
+
x_norm = torch.norm(x, p = p, dim = dim) # .detach()
|
16 |
+
x_norm = torch.max(x_norm, torch.ones_like(x_norm).cuda() * eps)
|
17 |
+
x = x.div(x_norm.unsqueeze(1).expand_as(x))
|
18 |
+
return x
|
19 |
+
|
20 |
+
|
21 |
+
class MultiProtoAsConv(nn.Module):
|
22 |
+
def __init__(self, proto_grid, feature_hw, embed_dim=768, use_attention=False, upsample_mode = 'bilinear'):
|
23 |
+
"""
|
24 |
+
ALPModule
|
25 |
+
Args:
|
26 |
+
proto_grid: Grid size when doing multi-prototyping. For a 32-by-32 feature map, a size of 16-by-16 leads to a pooling window of 2-by-2
|
27 |
+
feature_hw: Spatial size of input feature map
|
28 |
+
|
29 |
+
"""
|
30 |
+
super(MultiProtoAsConv, self).__init__()
|
31 |
+
self.feature_hw = feature_hw
|
32 |
+
self.proto_grid = proto_grid
|
33 |
+
self.upsample_mode = upsample_mode
|
34 |
+
kernel_size = [ ft_l // grid_l for ft_l, grid_l in zip(feature_hw, proto_grid) ]
|
35 |
+
self.kernel_size = kernel_size
|
36 |
+
print(f"MultiProtoAsConv: kernel_size: {kernel_size}")
|
37 |
+
self.avg_pool_op = nn.AvgPool2d( kernel_size )
|
38 |
+
|
39 |
+
if use_attention:
|
40 |
+
self.proto_fg_attnetion = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=12 if embed_dim == 768 else 8, batch_first=True)
|
41 |
+
self.proto_bg_attnetion = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=12 if embed_dim == 768 else 8, batch_first=True)
|
42 |
+
self.fg_mask_projection = nn.Sequential(
|
43 |
+
nn.Conv2d(embed_dim, 256, kernel_size=1, stride=1, padding=0, bias=True),
|
44 |
+
nn.ReLU(inplace=True),
|
45 |
+
nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0, bias=True),
|
46 |
+
nn.ReLU(inplace=True),
|
47 |
+
nn.Conv2d(128, 1, kernel_size=1, stride=1, padding=0, bias=True),
|
48 |
+
)
|
49 |
+
self.bg_mask_projection = nn.Sequential(
|
50 |
+
nn.Conv2d(embed_dim, 256, kernel_size=1, stride=1, padding=0, bias=True),
|
51 |
+
nn.ReLU(inplace=True),
|
52 |
+
nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0, bias=True),
|
53 |
+
nn.ReLU(inplace=True),
|
54 |
+
nn.Conv2d(128, 1, kernel_size=1, stride=1, padding=0, bias=True),
|
55 |
+
)
|
56 |
+
|
57 |
+
def get_prediction_from_prototypes(self, prototypes, query, mode, vis_sim=False ):
|
58 |
+
if mode == 'mask':
|
59 |
+
pred_mask = F.cosine_similarity(query, prototypes[..., None, None], dim=1, eps = 1e-4) * 20.0 # [1, h, w]
|
60 |
+
# in case more than one prototype lands in the same location, take the max
|
61 |
+
pred_mask = pred_mask.max(dim = 0)[0].unsqueeze(0)
|
62 |
+
vis_dict = {'proto_assign': pred_mask} # things to visualize
|
63 |
+
if vis_sim:
|
64 |
+
vis_dict['raw_local_sims'] = pred_mask
|
65 |
+
return pred_mask.unsqueeze(1), [pred_mask], vis_dict # just a placeholder. pred_mask returned as [1, way(1), h, w]
|
66 |
+
|
67 |
+
elif mode == 'gridconv':
|
68 |
+
dists = F.conv2d(query, prototypes[..., None, None]) * 20
|
69 |
+
|
70 |
+
pred_grid = torch.sum(F.softmax(dists, dim = 1) * dists, dim = 1, keepdim = True)
|
71 |
+
debug_assign = dists.argmax(dim = 1).float().detach()
|
72 |
+
|
73 |
+
vis_dict = {'proto_assign': debug_assign} # things to visualize
|
74 |
+
|
75 |
+
if vis_sim: # return the similarity for visualization
|
76 |
+
vis_dict['raw_local_sims'] = dists.clone().detach()
|
77 |
+
return pred_grid, [debug_assign], vis_dict
|
78 |
+
|
79 |
+
elif mode == 'gridconv+':
|
80 |
+
dists = F.conv2d(query, prototypes[..., None, None]) * 20
|
81 |
+
|
82 |
+
pred_grid = torch.sum(F.softmax(dists, dim = 1) * dists, dim = 1, keepdim = True)
|
83 |
+
# raw_local_sims = dists.detach()
|
84 |
+
|
85 |
+
debug_assign = dists.argmax(dim = 1).float()
|
86 |
+
|
87 |
+
vis_dict = {'proto_assign': debug_assign}
|
88 |
+
if vis_sim:
|
89 |
+
vis_dict['raw_local_sims'] = dists.clone().detach()
|
90 |
+
|
91 |
+
return pred_grid, [debug_assign], vis_dict
|
92 |
+
|
93 |
+
else:
|
94 |
+
raise ValueError(f"Invalid mode: {mode}. Expected 'mask', 'gridconv', or 'gridconv+'.")
|
95 |
+
|
96 |
+
|
97 |
+
def get_prototypes(self, sup_x, sup_y, mode, val_wsize, thresh, isval = False):
|
98 |
+
if mode == 'mask':
|
99 |
+
proto = torch.sum(sup_x * sup_y, dim=(-1, -2)) \
|
100 |
+
/ (sup_y.sum(dim=(-1, -2)) + 1e-5) # nb x C
|
101 |
+
|
102 |
+
pro_n = proto.mean(dim = 0, keepdim = True) # 1 X C, take the mean of everything
|
103 |
+
pro_n = proto
|
104 |
+
proto_grid = sup_y.clone().detach() # a single prototype for the whole image
|
105 |
+
resized_proto_grid = proto_grid
|
106 |
+
non_zero = torch.nonzero(proto_grid)
|
107 |
+
|
108 |
+
elif mode == 'gridconv':
|
109 |
+
nch = sup_x.shape[1]
|
110 |
+
|
111 |
+
sup_nshot = sup_x.shape[0]
|
112 |
+
# if len(sup_x.shape) > 4:
|
113 |
+
# sup_x = sup_x.squeeze()
|
114 |
+
n_sup_x = F.avg_pool2d(sup_x, val_wsize) if isval else self.avg_pool_op( sup_x )
|
115 |
+
n_sup_x = n_sup_x.view(sup_nshot, nch, -1).permute(0,2,1).unsqueeze(0) # way(1),nb, hw, nc
|
116 |
+
n_sup_x = n_sup_x.reshape(1, -1, nch).unsqueeze(0)
|
117 |
+
|
118 |
+
sup_y_g = F.avg_pool2d(sup_y, val_wsize) if isval else self.avg_pool_op(sup_y)
|
119 |
+
|
120 |
+
# get a grid of prototypes
|
121 |
+
proto_grid = sup_y_g.clone().detach()
|
122 |
+
proto_grid[proto_grid < thresh] = 0
|
123 |
+
# interpolate the grid to the original size
|
124 |
+
non_zero = torch.nonzero(proto_grid)
|
125 |
+
|
126 |
+
resized_proto_grid = torch.zeros([1, 1, proto_grid.shape[2]*val_wsize, proto_grid.shape[3]*val_wsize])
|
127 |
+
for index in non_zero:
|
128 |
+
resized_proto_grid[0, 0, index[2]*val_wsize:index[2]*val_wsize + val_wsize, index[3]*val_wsize:index[3]*val_wsize + 2] = proto_grid[0, 0, index[2], index[3]]
|
129 |
+
|
130 |
+
sup_y_g = sup_y_g.view( sup_nshot, 1, -1 ).permute(1, 0, 2).view(1, -1).unsqueeze(0)
|
131 |
+
protos = n_sup_x[sup_y_g > thresh, :] # npro, nc
|
132 |
+
pro_n = safe_norm(protos)
|
133 |
+
|
134 |
+
elif mode == 'gridconv+':
|
135 |
+
nch = sup_x.shape[1]
|
136 |
+
n_sup_x = F.avg_pool2d(sup_x, val_wsize) if isval else self.avg_pool_op( sup_x )
|
137 |
+
sup_nshot = sup_x.shape[0]
|
138 |
+
n_sup_x = n_sup_x.view(sup_nshot, nch, -1).permute(0,2,1).unsqueeze(0)
|
139 |
+
n_sup_x = n_sup_x.reshape(1, -1, nch).unsqueeze(0)
|
140 |
+
sup_y_g = F.avg_pool2d(sup_y, val_wsize) if isval else self.avg_pool_op(sup_y)
|
141 |
+
|
142 |
+
# get a grid of prototypes
|
143 |
+
proto_grid = sup_y_g.clone().detach()
|
144 |
+
proto_grid[proto_grid < thresh] = 0
|
145 |
+
non_zero = torch.nonzero(proto_grid)
|
146 |
+
for i, idx in enumerate(non_zero):
|
147 |
+
proto_grid[0, idx[1], idx[2], idx[3]] = i + 1
|
148 |
+
resized_proto_grid = torch.zeros([1, 1, proto_grid.shape[2]*val_wsize, proto_grid.shape[3]*val_wsize])
|
149 |
+
for index in non_zero:
|
150 |
+
resized_proto_grid[0, 0, index[2]*val_wsize:index[2]*val_wsize + val_wsize, index[3]*val_wsize:index[3]*val_wsize + 2] = proto_grid[0, 0, index[2], index[3]]
|
151 |
+
|
152 |
+
sup_y_g = sup_y_g.view( sup_nshot, 1, -1 ).permute(1, 0, 2).view(1, -1).unsqueeze(0)
|
153 |
+
protos = n_sup_x[sup_y_g > thresh, :]
|
154 |
+
|
155 |
+
glb_proto = torch.sum(sup_x * sup_y, dim=(-1, -2)) \
|
156 |
+
/ (sup_y.sum(dim=(-1, -2)) + 1e-5)
|
157 |
+
|
158 |
+
pro_n = safe_norm(torch.cat( [protos, glb_proto], dim = 0 ))
|
159 |
+
return pro_n, resized_proto_grid, non_zero
|
160 |
+
|
161 |
+
def forward(self, qry, sup_x, sup_y, mode, thresh, isval = False, val_wsize = None, vis_sim = False, get_prototypes=False, **kwargs):
|
162 |
+
"""
|
163 |
+
Now supports
|
164 |
+
Args:
|
165 |
+
mode: 'mask'/ 'grid'. if mask, works as original prototyping
|
166 |
+
qry: [way(1), nc, h, w]
|
167 |
+
sup_x: [nb, nc, h, w]
|
168 |
+
sup_y: [nb, 1, h, w]
|
169 |
+
vis_sim: visualize raw similarities or not
|
170 |
+
New
|
171 |
+
mode: 'mask'/ 'grid'. if mask, works as original prototyping
|
172 |
+
qry: [way(1), nb(1), nc, h, w]
|
173 |
+
sup_x: [way(1), shot, nb(1), nc, h, w]
|
174 |
+
sup_y: [way(1), shot, nb(1), h, w]
|
175 |
+
vis_sim: visualize raw similarities or not
|
176 |
+
"""
|
177 |
+
|
178 |
+
qry = qry.squeeze(1) # [way(1), nb(1), nc, hw] -> [way(1), nc, h, w]
|
179 |
+
sup_x = sup_x.squeeze(0).squeeze(1) # [nshot, nc, h, w]
|
180 |
+
sup_y = sup_y.squeeze(0) # [nshot, 1, h, w]
|
181 |
+
|
182 |
+
def safe_norm(x, p = 2, dim = 1, eps = 1e-4):
|
183 |
+
x_norm = torch.norm(x, p = p, dim = dim) # .detach()
|
184 |
+
x_norm = torch.max(x_norm, torch.ones_like(x_norm).cuda() * eps)
|
185 |
+
x = x.div(x_norm.unsqueeze(1).expand_as(x))
|
186 |
+
return x
|
187 |
+
if val_wsize is None:
|
188 |
+
val_wsize = self.avg_pool_op.kernel_size
|
189 |
+
if isinstance(val_wsize, (tuple, list)):
|
190 |
+
val_wsize = val_wsize[0]
|
191 |
+
sup_y = sup_y.reshape(sup_x.shape[0], 1, sup_x.shape[-2], sup_x.shape[-1])
|
192 |
+
pro_n, proto_grid, proto_indices = self.get_prototypes(sup_x, sup_y, mode, val_wsize, thresh, isval)
|
193 |
+
if 0 in pro_n.shape:
|
194 |
+
print("failed to find prototypes")
|
195 |
+
qry_n = qry if mode == 'mask' else safe_norm(qry)
|
196 |
+
pred_grid, debug_assign, vis_dict = self.get_prediction_from_prototypes(pro_n, qry_n, mode, vis_sim=vis_sim)
|
197 |
+
|
198 |
+
return pred_grid, debug_assign, vis_dict, proto_grid
|
199 |
+
|
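To make the grid-prototype idea concrete, here is a self-contained sketch of the 'gridconv' path used above: average-pool the support features and mask onto the prototype grid, keep the cells whose pooled mask exceeds the threshold, and score the query by cosine similarity implemented as a 1x1 convolution over L2-normalised features. The shapes, pooling window and seed below are illustrative assumptions:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
C, H, W, win, thresh = 16, 32, 32, 2, 0.95
sup_x = torch.randn(1, C, H, W)                    # support features
sup_y = torch.zeros(1, 1, H, W)
sup_y[..., 8:24, 8:24] = 1.0                       # support foreground mask
qry = torch.randn(1, C, H, W)                      # query features

pooled_x = F.avg_pool2d(sup_x, win)                # (1, C, H/win, W/win)
pooled_y = F.avg_pool2d(sup_y, win)                # fraction of fg pixels per grid cell
keep = pooled_y.flatten() > thresh
protos = pooled_x.permute(0, 2, 3, 1).reshape(-1, C)[keep]   # (npro, C) local prototypes
protos = F.normalize(protos, dim=1)

qry_n = F.normalize(qry, dim=1)
dists = F.conv2d(qry_n, protos[..., None, None]) * 20        # cosine similarity per prototype
pred = torch.sum(F.softmax(dists, dim=1) * dists, dim=1, keepdim=True)  # soft max over prototypes
print(pred.shape)                                  # torch.Size([1, 1, 32, 32])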
models/backbone/__init__.py
ADDED
File without changes
|
models/backbone/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (170 Bytes)
models/backbone/__pycache__/torchvision_backbones.cpython-312.pyc
ADDED
Binary file (2.78 kB)
models/backbone/torchvision_backbones.py
ADDED
@@ -0,0 +1,58 @@
"""
Backbones supported by torchvision.
"""
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision

class TVDeeplabRes101Encoder(nn.Module):
    """
    FCN-Resnet101 backbone from torchvision deeplabv3.
    No ASPP is used, as we found empirically that it hurts performance.
    """
    def __init__(self, use_coco_init, aux_dim_keep=64, use_aspp=False):
        super().__init__()
        _model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=use_coco_init, progress=True, num_classes=21, aux_loss=None)
        if use_coco_init:
            print("###### NETWORK: Using ms-coco initialization ######")
        else:
            print("###### NETWORK: Training from scratch ######")

        _model_list = list(_model.children())
        self.aux_dim_keep = aux_dim_keep
        self.backbone = _model_list[0]
        self.localconv = nn.Conv2d(2048, 256, kernel_size=1, stride=1, bias=False)  # reduce feature map dimension
        self.asppconv = nn.Conv2d(256, 256, kernel_size=1, bias=False)

        _aspp = _model_list[1][0]
        _conv256 = _model_list[1][1]
        self.aspp_out = nn.Sequential(*[_aspp, _conv256])
        self.use_aspp = use_aspp

    def forward(self, x_in, low_level):
        """
        Args:
            low_level: whether to also return aggregated low-level features from the FCN
        """
        fts = self.backbone(x_in)
        if self.use_aspp:
            fts256 = self.aspp_out(fts['out'])
            high_level_fts = fts256
        else:
            fts2048 = fts['out']
            high_level_fts = self.localconv(fts2048)

        if low_level:
            low_level_fts = fts['aux'][:, :self.aux_dim_keep]
            return high_level_fts, low_level_fts
        else:
            return high_level_fts
models/grid_proto_fewshot.py
ADDED
@@ -0,0 +1,427 @@
1 |
+
"""
|
2 |
+
ALPNet
|
3 |
+
"""
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from .alpmodule import MultiProtoAsConv
|
8 |
+
from .backbone.torchvision_backbones import TVDeeplabRes101Encoder
|
9 |
+
from util.consts import DEFAULT_FEATURE_SIZE
|
10 |
+
from util.lora import inject_trainable_lora
|
11 |
+
# from util.utils import load_config_from_url, plot_dinov2_fts
|
12 |
+
import math
|
13 |
+
|
14 |
+
# Specify a local path to the repository (or use installed package instead)
|
15 |
+
FG_PROT_MODE = 'gridconv+' # using both local and global prototype
|
16 |
+
# FG_PROT_MODE = 'mask'
|
17 |
+
# using local prototype only. Also 'mask' refers to using global prototype only (as done in vanilla PANet)
|
18 |
+
BG_PROT_MODE = 'gridconv'
|
19 |
+
|
20 |
+
# thresholds for deciding class of prototypes
|
21 |
+
FG_THRESH = 0.95
|
22 |
+
BG_THRESH = 0.95
|
23 |
+
|
24 |
+
|
25 |
+
class FewShotSeg(nn.Module):
|
26 |
+
"""
|
27 |
+
ALPNet
|
28 |
+
Args:
|
29 |
+
in_channels: Number of input channels
|
30 |
+
cfg: Model configurations
|
31 |
+
"""
|
32 |
+
|
33 |
+
def __init__(self, image_size, pretrained_path=None, cfg=None):
|
34 |
+
super(FewShotSeg, self).__init__()
|
35 |
+
self.image_size = image_size
|
36 |
+
self.pretrained_path = pretrained_path
|
37 |
+
print(f'###### Pre-trained path: {self.pretrained_path} ######')
|
38 |
+
self.config = cfg or {
|
39 |
+
'align': False, 'debug': False}
|
40 |
+
self.get_encoder()
|
41 |
+
self.get_cls()
|
42 |
+
if self.pretrained_path:
|
43 |
+
self.load_state_dict(torch.load(self.pretrained_path), strict=True)
|
44 |
+
print(
|
45 |
+
f'###### Pre-trained model {self.pretrained_path} has been loaded ######')
|
46 |
+
|
47 |
+
def get_encoder(self):
|
48 |
+
self.config['feature_hw'] = [DEFAULT_FEATURE_SIZE,
|
49 |
+
DEFAULT_FEATURE_SIZE] # default feature map size
|
50 |
+
if self.config['which_model'] == 'dlfcn_res101' or self.config['which_model'] == 'default':
|
51 |
+
use_coco_init = self.config['use_coco_init']
|
52 |
+
self.encoder = TVDeeplabRes101Encoder(use_coco_init)
|
53 |
+
self.config['feature_hw'] = [
|
54 |
+
math.ceil(self.image_size/8), math.ceil(self.image_size/8)]
|
55 |
+
elif self.config['which_model'] == 'dinov2_l14':
|
56 |
+
self.encoder = torch.hub.load(
|
57 |
+
'facebookresearch/dinov2', 'dinov2_vitl14')
|
58 |
+
self.config['feature_hw'] = [max(
|
59 |
+
self.image_size//14, DEFAULT_FEATURE_SIZE), max(self.image_size//14, DEFAULT_FEATURE_SIZE)]
|
60 |
+
elif self.config['which_model'] == 'dinov2_l14_reg':
|
61 |
+
try:
|
62 |
+
self.encoder = torch.hub.load(
|
63 |
+
'facebookresearch/dinov2', 'dinov2_vitl14_reg')
|
64 |
+
except RuntimeError as e:
|
65 |
+
self.encoder = torch.hub.load(
|
66 |
+
'facebookresearch/dinov2', 'dinov2_vitl14_reg', force_reload=True)
|
67 |
+
self.config['feature_hw'] = [max(
|
68 |
+
self.image_size//14, DEFAULT_FEATURE_SIZE), max(self.image_size//14, DEFAULT_FEATURE_SIZE)]
|
69 |
+
elif self.config['which_model'] == 'dinov2_b14':
|
70 |
+
self.encoder = torch.hub.load(
|
71 |
+
'facebookresearch/dinov2', 'dinov2_vitb14')
|
72 |
+
self.config['feature_hw'] = [max(
|
73 |
+
self.image_size//14, DEFAULT_FEATURE_SIZE), max(self.image_size//14, DEFAULT_FEATURE_SIZE)]
|
74 |
+
else:
|
75 |
+
raise NotImplementedError(
|
76 |
+
f'Backbone network {self.config["which_model"]} not implemented')
|
77 |
+
|
78 |
+
if self.config['lora'] > 0:
|
79 |
+
self.encoder.requires_grad_(False)
|
80 |
+
print(f'Injecting LoRA with rank:{self.config["lora"]}')
|
81 |
+
encoder_lora_params = inject_trainable_lora(
|
82 |
+
self.encoder, r=self.config['lora'])
|
83 |
+
|
84 |
+
def get_features(self, imgs_concat):
|
85 |
+
if self.config['which_model'] == 'dlfcn_res101':
|
86 |
+
img_fts = self.encoder(imgs_concat, low_level=False)
|
87 |
+
elif 'dino' in self.config['which_model']:
|
88 |
+
# resize imgs_concat to the closest size that is divisible by 14
|
89 |
+
imgs_concat = F.interpolate(imgs_concat, size=(
|
90 |
+
self.image_size // 14 * 14, self.image_size // 14 * 14), mode='bilinear')
|
91 |
+
dino_fts = self.encoder.forward_features(imgs_concat)
|
92 |
+
img_fts = dino_fts["x_norm_patchtokens"] # B, HW, C
|
93 |
+
img_fts = img_fts.permute(0, 2, 1) # B, C, HW
|
94 |
+
C, HW = img_fts.shape[-2:]
|
95 |
+
img_fts = img_fts.view(-1, C, int(HW**0.5),
|
96 |
+
int(HW**0.5)) # B, C, H, W
|
97 |
+
if HW < DEFAULT_FEATURE_SIZE ** 2:
|
98 |
+
img_fts = F.interpolate(img_fts, size=(
|
99 |
+
DEFAULT_FEATURE_SIZE, DEFAULT_FEATURE_SIZE), mode='bilinear') # this is if h,w < (32,32)
|
100 |
+
else:
|
101 |
+
raise NotImplementedError(
|
102 |
+
f'Backbone network {self.config["which_model"]} not implemented')
|
103 |
+
|
104 |
+
return img_fts
|
105 |
+
|
106 |
+
def get_cls(self):
|
107 |
+
"""
|
108 |
+
Obtain the similarity-based classifier
|
109 |
+
"""
|
110 |
+
proto_hw = self.config["proto_grid_size"]
|
111 |
+
|
112 |
+
if self.config['cls_name'] == 'grid_proto':
|
113 |
+
embed_dim = 256
|
114 |
+
if 'dinov2_b14' in self.config['which_model']:
|
115 |
+
embed_dim = 768
|
116 |
+
elif 'dinov2_l14' in self.config['which_model']:
|
117 |
+
embed_dim = 1024
|
118 |
+
self.cls_unit = MultiProtoAsConv(proto_grid=[proto_hw, proto_hw], feature_hw=self.config["feature_hw"], embed_dim=embed_dim) # when treating it as ordinary prototype
|
119 |
+
print(f"cls unit feature hw: {self.cls_unit.feature_hw}")
|
120 |
+
else:
|
121 |
+
raise NotImplementedError(
|
122 |
+
f'Classifier {self.config["cls_name"]} not implemented')
|
123 |
+
|
124 |
+
def forward_resolutions(self, resolutions, supp_imgs, fore_mask, back_mask, qry_imgs, isval, val_wsize, show_viz=False, supp_fts=None):
|
125 |
+
predictions = []
|
126 |
+
for res in resolutions:
|
127 |
+
supp_imgs_resized = [[F.interpolate(supp_img[0], size=(
|
128 |
+
res, res), mode='bilinear') for supp_img in supp_imgs]] if supp_imgs[0][0].shape[-1] != res else supp_imgs
|
129 |
+
fore_mask_resized = [[F.interpolate(fore_mask_way[0].unsqueeze(0), size=(res, res), mode='bilinear')[
|
130 |
+
0] for fore_mask_way in fore_mask]] if fore_mask[0][0].shape[-1] != res else fore_mask
|
131 |
+
back_mask_resized = [[F.interpolate(back_mask_way[0].unsqueeze(0), size=(res, res), mode='bilinear')[
|
132 |
+
0] for back_mask_way in back_mask]] if back_mask[0][0].shape[-1] != res else back_mask
|
133 |
+
qry_imgs_resized = [F.interpolate(qry_img, size=(res, res), mode='bilinear')
|
134 |
+
for qry_img in qry_imgs] if qry_imgs[0][0].shape[-1] != res else qry_imgs
|
135 |
+
|
136 |
+
pred = self.forward(supp_imgs_resized, fore_mask_resized, back_mask_resized,
|
137 |
+
qry_imgs_resized, isval, val_wsize, show_viz, supp_fts)[0]
|
138 |
+
predictions.append(pred)
return predictions  # collected after looping over all resolutions
|
139 |
+
|
140 |
+
def resize_inputs_to_image_size(self, supp_imgs, fore_mask, back_mask, qry_imgs):
|
141 |
+
supp_imgs = [[F.interpolate(supp_img, size=(
|
142 |
+
self.image_size, self.image_size), mode='bilinear') for supp_img in supp_imgs_way] for supp_imgs_way in supp_imgs]
|
143 |
+
fore_mask = [[F.interpolate(fore_mask_way[0].unsqueeze(0), size=(self.image_size, self.image_size), mode='bilinear')[
|
144 |
+
0] for fore_mask_way in fore_mask]] if fore_mask[0][0].shape[-1] != self.image_size else fore_mask
|
145 |
+
back_mask = [[F.interpolate(back_mask_way[0].unsqueeze(0), size=(self.image_size, self.image_size), mode='bilinear')[
|
146 |
+
0] for back_mask_way in back_mask]] if back_mask[0][0].shape[-1] != self.image_size else back_mask
|
147 |
+
qry_imgs = [F.interpolate(qry_img, size=(self.image_size, self.image_size), mode='bilinear')
|
148 |
+
for qry_img in qry_imgs] if qry_imgs[0][0].shape[-1] != self.image_size else qry_imgs
|
149 |
+
return supp_imgs, fore_mask, back_mask, qry_imgs
|
150 |
+
|
151 |
+
def forward(self, supp_imgs, fore_mask, back_mask, qry_imgs, isval, val_wsize, show_viz=False, supp_fts=None):
|
152 |
+
"""
|
153 |
+
Args:
|
154 |
+
supp_imgs: support images
|
155 |
+
way x shot x [B x 3 x H x W], list of lists of tensors
|
156 |
+
fore_mask: foreground masks for support images
|
157 |
+
way x shot x [B x H x W], list of lists of tensors
|
158 |
+
back_mask: background masks for support images
|
159 |
+
way x shot x [B x H x W], list of lists of tensors
|
160 |
+
qry_imgs: query images
|
161 |
+
N x [B x 3 x H x W], list of tensors
|
162 |
+
show_viz: return the visualization dictionary
|
163 |
+
"""
|
164 |
+
# ('Please go through this piece of code carefully')
|
165 |
+
# supp_imgs, fore_mask, back_mask, qry_imgs = self.resize_inputs_to_image_size(
|
166 |
+
# supp_imgs, fore_mask, back_mask, qry_imgs)
|
167 |
+
|
168 |
+
n_ways = len(supp_imgs)
|
169 |
+
n_shots = len(supp_imgs[0])
|
170 |
+
n_queries = len(qry_imgs)
|
171 |
+
|
172 |
+
# NOTE: actual shot in support goes in batch dimension
|
173 |
+
assert n_ways == 1, "Multi-shot has not been implemented yet"
|
174 |
+
assert n_queries == 1
|
175 |
+
|
176 |
+
sup_bsize = supp_imgs[0][0].shape[0]
|
177 |
+
img_size = supp_imgs[0][0].shape[-2:]
|
178 |
+
if self.config["cls_name"] == 'grid_proto_3d':
|
179 |
+
img_size = supp_imgs[0][0].shape[-3:]
|
180 |
+
qry_bsize = qry_imgs[0].shape[0]
|
181 |
+
|
182 |
+
imgs_concat = torch.cat([torch.cat(way, dim=0) for way in supp_imgs]
|
183 |
+
+ [torch.cat(qry_imgs, dim=0),], dim=0)
|
184 |
+
|
185 |
+
img_fts = self.get_features(imgs_concat)
|
186 |
+
if len(img_fts.shape) == 5: # for 3D
|
187 |
+
fts_size = img_fts.shape[-3:]
|
188 |
+
else:
|
189 |
+
fts_size = img_fts.shape[-2:]
|
190 |
+
if supp_fts is None:
|
191 |
+
supp_fts = img_fts[:n_ways * n_shots * sup_bsize].view(
|
192 |
+
n_ways, n_shots, sup_bsize, -1, *fts_size) # wa x sh x b x c x h' x w'
|
193 |
+
qry_fts = img_fts[n_ways * n_shots * sup_bsize:].view(
|
194 |
+
n_queries, qry_bsize, -1, *fts_size) # N x B x C x H' x W'
|
195 |
+
else:
|
196 |
+
# N x B x C x H' x W'
|
197 |
+
qry_fts = img_fts.view(n_queries, qry_bsize, -1, *fts_size)
|
198 |
+
|
199 |
+
fore_mask = torch.stack([torch.stack(way, dim=0)
|
200 |
+
for way in fore_mask], dim=0) # Wa x Sh x B x H' x W'
|
201 |
+
fore_mask = torch.autograd.Variable(fore_mask, requires_grad=True)
|
202 |
+
back_mask = torch.stack([torch.stack(way, dim=0)
|
203 |
+
for way in back_mask], dim=0) # Wa x Sh x B x H' x W'
|
204 |
+
|
205 |
+
###### Compute loss ######
|
206 |
+
align_loss = 0
|
207 |
+
outputs = []
|
208 |
+
visualizes = [] # the buffer for visualization
|
209 |
+
|
210 |
+
for epi in range(1): # batch dimension, fixed to 1
|
211 |
+
fg_masks = [] # keep the way part
|
212 |
+
|
213 |
+
'''
|
214 |
+
for way in range(n_ways):
|
215 |
+
# note: index of n_ways starts from 0
|
216 |
+
mean_sup_ft = supp_fts[way].mean(dim = 0) # [ nb, C, H, W]. Just assume batch size is 1 as pytorch only allows this
|
217 |
+
mean_sup_msk = F.interpolate(fore_mask[way].mean(dim = 0).unsqueeze(1), size = mean_sup_ft.shape[-2:], mode = 'bilinear')
|
218 |
+
fg_masks.append( mean_sup_msk )
|
219 |
+
|
220 |
+
mean_bg_msk = F.interpolate(back_mask[way].mean(dim = 0).unsqueeze(1), size = mean_sup_ft.shape[-2:], mode = 'bilinear') # [nb, C, H, W]
|
221 |
+
'''
|
222 |
+
# re-interpolate support mask to the same size as support feature
|
223 |
+
if len(fts_size) == 3: # TODO make more generic
|
224 |
+
res_fg_msk = torch.stack([F.interpolate(fore_mask[0][0].unsqueeze(
|
225 |
+
0), size=fts_size, mode='nearest')], dim=0) # [nway, ns, nb, nd', nh', nw'])
|
226 |
+
res_bg_msk = torch.stack([F.interpolate(back_mask[0][0].unsqueeze(
|
227 |
+
0), size=fts_size, mode='nearest')], dim=0) # [nway, ns, nb, nd', nh', nw'])
|
228 |
+
else:
|
229 |
+
res_fg_msk = torch.stack([F.interpolate(fore_mask_w, size=fts_size, mode='nearest')
|
230 |
+
for fore_mask_w in fore_mask], dim=0) # [nway, ns, nb, nh', nw']
|
231 |
+
res_bg_msk = torch.stack([F.interpolate(back_mask_w, size=fts_size, mode='nearest')
|
232 |
+
for back_mask_w in back_mask], dim=0) # [nway, ns, nb, nh', nw']
|
233 |
+
|
234 |
+
scores = []
|
235 |
+
assign_maps = []
|
236 |
+
bg_sim_maps = []
|
237 |
+
fg_sim_maps = []
|
238 |
+
bg_mode = BG_PROT_MODE
|
239 |
+
|
240 |
+
_raw_score, _, aux_attr, _ = self.cls_unit(
|
241 |
+
qry_fts, supp_fts, res_bg_msk, mode=bg_mode, thresh=BG_THRESH, isval=isval, val_wsize=val_wsize, vis_sim=show_viz)
|
242 |
+
scores.append(_raw_score)
|
243 |
+
assign_maps.append(aux_attr['proto_assign'])
|
244 |
+
|
245 |
+
for way, _msks in enumerate(res_fg_msk):
|
246 |
+
raw_scores = []
|
247 |
+
for i, _msk in enumerate(_msks):
|
248 |
+
_msk = _msk.unsqueeze(0)
|
249 |
+
supp_ft = supp_fts[:, i].unsqueeze(0)
|
250 |
+
if self.config["cls_name"] == 'grid_proto_3d': # 3D
|
251 |
+
k_size = self.cls_unit.kernel_size
|
252 |
+
fg_mode = FG_PROT_MODE if F.avg_pool3d(_msk, k_size).max(
|
253 |
+
) >= FG_THRESH and FG_PROT_MODE != 'mask' else 'mask' # TODO figure out kernel size
|
254 |
+
else:
|
255 |
+
k_size = self.cls_unit.kernel_size
|
256 |
+
fg_mode = FG_PROT_MODE if F.avg_pool2d(_msk, k_size).max(
|
257 |
+
) >= FG_THRESH and FG_PROT_MODE != 'mask' else 'mask'
|
258 |
+
# TODO figure out kernel size
|
259 |
+
_raw_score, _, aux_attr, proto_grid = self.cls_unit(qry_fts, supp_ft, _msk.unsqueeze(
|
260 |
+
0), mode=fg_mode, thresh=FG_THRESH, isval=isval, val_wsize=val_wsize, vis_sim=show_viz)
|
261 |
+
raw_scores.append(_raw_score)
|
262 |
+
|
263 |
+
# create a score where each feature is the max of the raw_score
|
264 |
+
_raw_score = torch.stack(raw_scores, dim=1).max(dim=1)[
|
265 |
+
0]
|
266 |
+
scores.append(_raw_score)
|
267 |
+
assign_maps.append(aux_attr['proto_assign'])
|
268 |
+
if show_viz:
|
269 |
+
fg_sim_maps.append(aux_attr['raw_local_sims'])
|
270 |
+
# print(f"Time for fg: {time.time() - start_time}")
|
271 |
+
pred = torch.cat(scores, dim=1) # N x (1 + Wa) x H' x W'
|
272 |
+
interpolate_mode = 'bilinear'
|
273 |
+
outputs.append(F.interpolate(
|
274 |
+
pred, size=img_size, mode=interpolate_mode))
|
275 |
+
|
276 |
+
###### Prototype alignment loss ######
|
277 |
+
if self.config['align'] and self.training:
|
278 |
+
align_loss_epi = self.alignLoss(qry_fts[:, epi], pred, supp_fts[:, :, epi],
|
279 |
+
fore_mask[:, :, epi], back_mask[:, :, epi])
|
280 |
+
align_loss += align_loss_epi
|
281 |
+
|
282 |
+
output = torch.stack(outputs, dim=1) # N x B x (1 + Wa) x H x W
|
283 |
+
grid_shape = output.shape[2:]
|
284 |
+
if self.config["cls_name"] == 'grid_proto_3d':
|
285 |
+
grid_shape = output.shape[2:]
|
286 |
+
output = output.view(-1, *grid_shape)
|
287 |
+
assign_maps = torch.stack(assign_maps, dim=1) if show_viz else None
|
288 |
+
bg_sim_maps = torch.stack(bg_sim_maps, dim=1) if show_viz else None
|
289 |
+
fg_sim_maps = torch.stack(fg_sim_maps, dim=1) if show_viz else None
|
290 |
+
|
291 |
+
return output, align_loss / sup_bsize, [bg_sim_maps, fg_sim_maps], assign_maps, proto_grid, supp_fts, qry_fts
|
292 |
+
|
293 |
+
|
294 |
+
def alignLoss(self, qry_fts, pred, supp_fts, fore_mask, back_mask):
|
295 |
+
"""
|
296 |
+
Compute the loss for the prototype alignment branch
|
297 |
+
|
298 |
+
Args:
|
299 |
+
qry_fts: embedding features for query images
|
300 |
+
expect shape: N x C x H' x W'
|
301 |
+
pred: predicted segmentation score
|
302 |
+
expect shape: N x (1 + Wa) x H x W
|
303 |
+
supp_fts: embedding features for support images
|
304 |
+
expect shape: Wa x Sh x C x H' x W'
|
305 |
+
fore_mask: foreground masks for support images
|
306 |
+
expect shape: way x shot x H x W
|
307 |
+
back_mask: background masks for support images
|
308 |
+
expect shape: way x shot x H x W
|
309 |
+
"""
|
310 |
+
n_ways, n_shots = len(fore_mask), len(fore_mask[0])
|
311 |
+
|
312 |
+
# Masks for getting query prototype
|
313 |
+
pred_mask = pred.argmax(dim=1).unsqueeze(0) # 1 x N x H' x W'
|
314 |
+
binary_masks = [pred_mask == i for i in range(1 + n_ways)]
|
315 |
+
|
316 |
+
# skip_ways = [i for i in range(n_ways) if binary_masks[i + 1].sum() == 0]
|
317 |
+
# FIXME: fix this in future we here make a stronger assumption that a positive class must be there to avoid undersegmentation/ lazyness
|
318 |
+
skip_ways = []
|
319 |
+
|
320 |
+
# added for matching dimensions to the new data format
|
321 |
+
qry_fts = qry_fts.unsqueeze(0).unsqueeze(
|
322 |
+
2) # added to nway(1) and nb(1)
|
323 |
+
# end of added part
|
324 |
+
|
325 |
+
loss = []
|
326 |
+
for way in range(n_ways):
|
327 |
+
if way in skip_ways:
|
328 |
+
continue
|
329 |
+
# Get the query prototypes
|
330 |
+
for shot in range(n_shots):
|
331 |
+
# actual local query [way(1), nb(1, nb is now nshot), nc, h, w]
|
332 |
+
img_fts = supp_fts[way: way + 1, shot: shot + 1]
|
333 |
+
size = img_fts.shape[-2:]
|
334 |
+
mode = 'bilinear'
|
335 |
+
if self.config["cls_name"] == 'grid_proto_3d':
|
336 |
+
size = img_fts.shape[-3:]
|
337 |
+
mode = 'trilinear'
|
338 |
+
qry_pred_fg_msk = F.interpolate(
|
339 |
+
binary_masks[way + 1].float(), size=size, mode=mode) # [1 (way), n (shot), h, w]
|
340 |
+
|
341 |
+
# background
|
342 |
+
qry_pred_bg_msk = F.interpolate(
|
343 |
+
binary_masks[0].float(), size=size, mode=mode) # 1, n, h ,w
|
344 |
+
scores = []
|
345 |
+
|
346 |
+
bg_mode = BG_PROT_MODE
|
347 |
+
_raw_score_bg, _, _, _ = self.cls_unit(
|
348 |
+
qry=img_fts, sup_x=qry_fts, sup_y=qry_pred_bg_msk.unsqueeze(-3), mode=bg_mode, thresh=BG_THRESH)
|
349 |
+
|
350 |
+
scores.append(_raw_score_bg)
|
351 |
+
if self.config["cls_name"] == 'grid_proto_3d':
|
352 |
+
fg_mode = FG_PROT_MODE if F.avg_pool3d(qry_pred_fg_msk, 4).max(
|
353 |
+
) >= FG_THRESH and FG_PROT_MODE != 'mask' else 'mask'
|
354 |
+
else:
|
355 |
+
fg_mode = FG_PROT_MODE if F.avg_pool2d(qry_pred_fg_msk, 4).max(
|
356 |
+
) >= FG_THRESH and FG_PROT_MODE != 'mask' else 'mask'
|
357 |
+
_raw_score_fg, _, _, _ = self.cls_unit(
|
358 |
+
qry=img_fts, sup_x=qry_fts, sup_y=qry_pred_fg_msk.unsqueeze(2), mode=fg_mode, thresh=FG_THRESH)
|
359 |
+
scores.append(_raw_score_fg)
|
360 |
+
|
361 |
+
supp_pred = torch.cat(scores, dim=1) # N x (1 + Wa) x H' x W'
|
362 |
+
size = fore_mask.shape[-2:]
|
363 |
+
if self.config["cls_name"] == 'grid_proto_3d':
|
364 |
+
size = fore_mask.shape[-3:]
|
365 |
+
supp_pred = F.interpolate(supp_pred, size=size, mode=mode)
|
366 |
+
|
367 |
+
# Construct the support Ground-Truth segmentation
|
368 |
+
supp_label = torch.full_like(fore_mask[way, shot], 255,
|
369 |
+
device=img_fts.device).long()
|
370 |
+
supp_label[fore_mask[way, shot] == 1] = 1
|
371 |
+
supp_label[back_mask[way, shot] == 1] = 0
|
372 |
+
# Compute Loss
|
373 |
+
loss.append(F.cross_entropy(
|
374 |
+
supp_pred.float(), supp_label[None, ...], ignore_index=255) / n_shots / n_ways)
|
375 |
+
|
376 |
+
return torch.sum(torch.stack(loss))
|
377 |
+
|
378 |
+
def dino_cls_loss(self, teacher_cls_tokens, student_cls_tokens):
|
379 |
+
cls_loss_weight = 0.1
|
380 |
+
student_temp = 1
|
381 |
+
teacher_cls_tokens = self.sinkhorn_knopp_teacher(teacher_cls_tokens)
|
382 |
+
lsm = F.log_softmax(student_cls_tokens / student_temp, dim=-1)
|
383 |
+
cls_loss = torch.sum(teacher_cls_tokens * lsm, dim=-1)
|
384 |
+
|
385 |
+
return -cls_loss.mean() * cls_loss_weight
|
386 |
+
|
387 |
+
@torch.no_grad()
|
388 |
+
def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp=1, n_iterations=3):
|
389 |
+
teacher_output = teacher_output.float()
|
390 |
+
# world_size = dist.get_world_size() if dist.is_initialized() else 1
|
391 |
+
# Q is K-by-B for consistency with notations from our paper
|
392 |
+
Q = torch.exp(teacher_output / teacher_temp).t()
|
393 |
+
# B = Q.shape[1] * world_size # number of samples to assign
|
394 |
+
B = Q.shape[1]
|
395 |
+
K = Q.shape[0] # how many prototypes
|
396 |
+
|
397 |
+
# make the matrix sums to 1
|
398 |
+
sum_Q = torch.sum(Q)
|
399 |
+
Q /= sum_Q
|
400 |
+
|
401 |
+
for it in range(n_iterations):
|
402 |
+
# normalize each row: total weight per prototype must be 1/K
|
403 |
+
sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
|
404 |
+
Q /= sum_of_rows
|
405 |
+
Q /= K
|
406 |
+
|
407 |
+
# normalize each column: total weight per sample must be 1/B
|
408 |
+
Q /= torch.sum(Q, dim=0, keepdim=True)
|
409 |
+
Q /= B
|
410 |
+
|
411 |
+
Q *= B # the columns must sum to 1 so that Q is an assignment
|
412 |
+
return Q.t()
|
413 |
+
|
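The Sinkhorn-Knopp loop above alternates row and column normalisation so that the returned matrix behaves like a soft assignment of samples to prototypes: after the final scaling, each row of the transposed result sums to one. A quick hedged check, assuming `fs_model` is an already-constructed FewShotSeg instance:

import torch

teacher_output = torch.randn(8, 5)                       # (batch, prototype logits)
q = fs_model.sinkhorn_knopp_teacher(teacher_output)
print(q.shape)                                           # torch.Size([8, 5])
print(q.sum(dim=1))                                      # each entry ~1: one unit of assignment per sample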
414 |
+
def dino_patch_loss(self, features, masked_features, masks):
|
415 |
+
# for both supp and query features perform the patch wise loss
|
416 |
+
loss = 0.0
|
417 |
+
weight = 0.1
|
418 |
+
B = features.shape[0]
|
419 |
+
for (f, mf, mask) in zip(features, masked_features, masks):
|
420 |
+
# TODO sinkhorn knopp center features
|
421 |
+
f = f[mask]
|
422 |
+
f = self.sinkhorn_knopp_teacher(f)
|
423 |
+
mf = mf[mask]
|
424 |
+
loss += torch.sum(f * F.log_softmax(mf / 1,
|
425 |
+
dim=-1), dim=-1) / mask.sum()
|
426 |
+
|
427 |
+
return -loss.sum() * weight / B
|
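For reference, a hedged sketch of how FewShotSeg.forward is typically called for a 1-way, 1-shot episode. The config keys follow the ones read in get_encoder/get_cls; the concrete values, sizes and masks below are assumptions for illustration:

import torch

cfg = {'align': True, 'debug': False, 'which_model': 'dlfcn_res101', 'use_coco_init': True,
       'cls_name': 'grid_proto', 'proto_grid_size': 8, 'lora': 0}
model = FewShotSeg(image_size=256, pretrained_path=None, cfg=cfg)

fore = (torch.rand(1, 256, 256) > 0.5).float()
supp_imgs = [[torch.randn(1, 3, 256, 256)]]   # way x shot x (B, 3, H, W)
fore_mask = [[fore]]                          # way x shot x (B, H, W)
back_mask = [[1.0 - fore]]
qry_imgs = [torch.randn(1, 3, 256, 256)]      # n_queries x (B, 3, H, W)

output, align_loss, sim_maps, assign_maps, proto_grid, supp_fts, qry_fts = model(
    supp_imgs, fore_mask, back_mask, qry_imgs, isval=False, val_wsize=2)
pred = output.argmax(dim=1)                   # query prediction, (B, H, W)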
models/segment_anything/__init__.py
ADDED
@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .build_sam import (
    build_sam,
    build_sam_vit_h,
    build_sam_vit_l,
    build_sam_vit_b,
    sam_model_registry,
)
from .predictor import SamPredictor
from .automatic_mask_generator import SamAutomaticMaskGenerator
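The package entry points above mirror the upstream segment_anything API. A brief usage sketch; the checkpoint path is the same placeholder used elsewhere in this repository:

import numpy as np
from models.segment_anything import sam_model_registry, SamPredictor

sam = sam_model_registry["vit_h"](checkpoint="pretrained_model/sam_vit_h.pth")
predictor = SamPredictor(sam)

image = np.zeros((512, 512, 3), dtype=np.uint8)     # HWC uint8 image (placeholder)
predictor.set_image(image)
masks, scores, logits = predictor.predict(
    point_coords=np.array([[256, 256]]),
    point_labels=np.array([1]),
    multimask_output=True,
)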
models/segment_anything/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (451 Bytes)
models/segment_anything/__pycache__/automatic_mask_generator.cpython-312.pyc
ADDED
Binary file (16.9 kB)
models/segment_anything/__pycache__/build_sam.cpython-312.pyc
ADDED
Binary file (2.77 kB)
models/segment_anything/__pycache__/predictor.cpython-312.pyc
ADDED
Binary file (13.9 kB)
models/segment_anything/automatic_mask_generator.py
ADDED
@@ -0,0 +1,380 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from torchvision.ops.boxes import batched_nms, box_area # type: ignore
|
10 |
+
|
11 |
+
from typing import Any, Dict, List, Optional, Tuple
|
12 |
+
|
13 |
+
from .modeling import Sam
|
14 |
+
from .predictor import SamPredictor
|
15 |
+
from .utils.amg import (
|
16 |
+
MaskData,
|
17 |
+
area_from_rle,
|
18 |
+
batch_iterator,
|
19 |
+
batched_mask_to_box,
|
20 |
+
box_xyxy_to_xywh,
|
21 |
+
build_all_layer_point_grids,
|
22 |
+
calculate_stability_score,
|
23 |
+
coco_encode_rle,
|
24 |
+
generate_crop_boxes,
|
25 |
+
is_box_near_crop_edge,
|
26 |
+
mask_to_rle_pytorch,
|
27 |
+
remove_small_regions,
|
28 |
+
rle_to_mask,
|
29 |
+
uncrop_boxes_xyxy,
|
30 |
+
uncrop_masks,
|
31 |
+
uncrop_points,
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
class SamAutomaticMaskGenerator:
|
36 |
+
def __init__(
|
37 |
+
self,
|
38 |
+
model: Sam,
|
39 |
+
points_per_side: Optional[int] = 32,
|
40 |
+
points_per_batch: int = 64,
|
41 |
+
pred_iou_thresh: float = 0.88,
|
42 |
+
stability_score_thresh: float = 0.95,
|
43 |
+
stability_score_offset: float = 1.0,
|
44 |
+
box_nms_thresh: float = 0.7,
|
45 |
+
crop_n_layers: int = 0,
|
46 |
+
crop_nms_thresh: float = 0.7,
|
47 |
+
crop_overlap_ratio: float = 512 / 1500,
|
48 |
+
crop_n_points_downscale_factor: int = 1,
|
49 |
+
point_grids: Optional[List[np.ndarray]] = None,
|
50 |
+
min_mask_region_area: int = 0,
|
51 |
+
output_mode: str = "binary_mask",
|
52 |
+
custom_points: bool = "false",
|
53 |
+
) -> None:
|
        """
        Using a SAM model, generates masks for the entire image.
        Generates a grid of point prompts over the image, then filters
        low quality and duplicate masks. The default settings are chosen
        for SAM with a ViT-H backbone.

        Arguments:
          model (Sam): The SAM model to use for mask prediction.
          points_per_side (int or None): The number of points to be sampled
            along one side of the image. The total number of points is
            points_per_side**2. If None, 'point_grids' must provide explicit
            point sampling.
          points_per_batch (int): Sets the number of points run simultaneously
            by the model. Higher numbers may be faster but use more GPU memory.
          pred_iou_thresh (float): A filtering threshold in [0,1], using the
            model's predicted mask quality.
          stability_score_thresh (float): A filtering threshold in [0,1], using
            the stability of the mask under changes to the cutoff used to binarize
            the model's mask predictions.
          stability_score_offset (float): The amount to shift the cutoff when
            calculating the stability score.
          box_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks.
          crop_n_layers (int): If >0, mask prediction will be run again on
            crops of the image. Sets the number of layers to run, where each
            layer has 2**i_layer number of image crops.
          crop_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks between different crops.
          crop_overlap_ratio (float): Sets the degree to which crops overlap.
            In the first crop layer, crops will overlap by this fraction of
            the image length. Later layers with more crops scale down this overlap.
          crop_n_points_downscale_factor (int): The number of points-per-side
            sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
          point_grids (list(np.ndarray) or None): A list over explicit grids
            of points used for sampling, normalized to [0,1]. The nth grid in the
            list is used in the nth crop layer. Exclusive with points_per_side.
          min_mask_region_area (int): If >0, postprocessing will be applied
            to remove disconnected regions and holes in masks with area smaller
            than min_mask_region_area. Requires opencv.
          output_mode (str): The form masks are returned in. Can be 'binary_mask',
            'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
            For large resolutions, 'binary_mask' may consume large amounts of
            memory.
          custom_points (bool): If True, each point batch is treated as half
            positive prompts followed by half negative prompts instead of
            all-positive prompts (see _process_batch).
        """

        assert (points_per_side is None) != (
            point_grids is None
        ), "Exactly one of points_per_side or point_grids must be provided."
        if points_per_side is not None:
            self.point_grids = build_all_layer_point_grids(
                points_per_side,
                crop_n_layers,
                crop_n_points_downscale_factor,
            )
        elif point_grids is not None:
            self.point_grids = point_grids
        else:
            raise ValueError("Can't have both points_per_side and point_grids be None.")

        assert output_mode in [
            "binary_mask",
            "uncompressed_rle",
            "coco_rle",
        ], f"Unknown output_mode {output_mode}."
        if output_mode == "coco_rle":
            from pycocotools import mask as mask_utils  # type: ignore # noqa: F401

        if min_mask_region_area > 0:
            import cv2  # type: ignore # noqa: F401

        self.predictor = SamPredictor(model)
        self.points_per_batch = points_per_batch
        self.pred_iou_thresh = pred_iou_thresh
        self.stability_score_thresh = stability_score_thresh
        self.stability_score_offset = stability_score_offset
        self.box_nms_thresh = box_nms_thresh
        self.crop_n_layers = crop_n_layers
        self.crop_nms_thresh = crop_nms_thresh
        self.crop_overlap_ratio = crop_overlap_ratio
        self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
        self.min_mask_region_area = min_mask_region_area
        self.output_mode = output_mode
        self.custom_points = custom_points

    @torch.no_grad()
    def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
        """
        Generates masks for the given image.

        Arguments:
          image (np.ndarray): The image to generate masks for, in HWC uint8 format.

        Returns:
          list(dict(str, any)): A list over records for masks. Each record is
            a dict containing the following keys:
              segmentation (dict(str, any) or np.ndarray): The mask. If
                output_mode='binary_mask', is an array of shape HW. Otherwise,
                is a dictionary containing the RLE.
              bbox (list(float)): The box around the mask, in XYWH format.
              area (int): The area in pixels of the mask.
              predicted_iou (float): The model's own prediction of the mask's
                quality. This is filtered by the pred_iou_thresh parameter.
              point_coords (list(list(float))): The point coordinates input
                to the model to generate this mask.
              stability_score (float): A measure of the mask's quality. This
                is filtered on using the stability_score_thresh parameter.
              crop_box (list(float)): The crop of the image used to generate
                the mask, given in XYWH format.
        """

        # Generate masks
        mask_data = self._generate_masks(image)

        # Filter small disconnected regions and holes in masks
        if self.min_mask_region_area > 0:
            mask_data = self.postprocess_small_regions(
                mask_data,
                self.min_mask_region_area,
                max(self.box_nms_thresh, self.crop_nms_thresh),
            )

        # Encode masks
        if self.output_mode == "coco_rle":
            mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
        elif self.output_mode == "binary_mask":
            mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
        else:
            mask_data["segmentations"] = mask_data["rles"]

        # Write mask records
        curr_anns = []
        for idx in range(len(mask_data["segmentations"])):
            ann = {
                "segmentation": mask_data["segmentations"][idx],
                "area": area_from_rle(mask_data["rles"][idx]),
                "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
                "predicted_iou": mask_data["iou_preds"][idx].item(),
                "point_coords": [mask_data["points"][idx].tolist()],
                "stability_score": mask_data["stability_score"][idx].item(),
                "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
            }
            curr_anns.append(ann)

        return curr_anns

    def _generate_masks(self, image: np.ndarray) -> MaskData:
        orig_size = image.shape[:2]
        crop_boxes, layer_idxs = generate_crop_boxes(
            orig_size, self.crop_n_layers, self.crop_overlap_ratio
        )

        # Iterate over image crops
        data = MaskData()
        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
            crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
            data.cat(crop_data)

        # Remove duplicate masks between crops
        if len(crop_boxes) > 1:
            # Prefer masks from smaller crops
            scores = 1 / box_area(data["crop_boxes"])
            scores = scores.to(data["boxes"].device)
            keep_by_nms = batched_nms(
                data["boxes"].float(),
                scores,
                torch.zeros_like(data["boxes"][:, 0]),  # categories
                iou_threshold=self.crop_nms_thresh,
            )
            data.filter(keep_by_nms)

        data.to_numpy()
        return data

    def _process_crop(
        self,
        image: np.ndarray,
        crop_box: List[int],
        crop_layer_idx: int,
        orig_size: Tuple[int, ...],
    ) -> MaskData:
        # Crop the image and calculate embeddings
        x0, y0, x1, y1 = crop_box
        cropped_im = image[y0:y1, x0:x1, :]
        cropped_im_size = cropped_im.shape[:2]
        self.predictor.set_image(cropped_im)

        # Get points for this crop
        points_scale = np.array(cropped_im_size)[None, ::-1]
        points_for_image = self.point_grids[crop_layer_idx] * points_scale

        # Generate masks for this crop in batches
        data = MaskData()
        for (points,) in batch_iterator(self.points_per_batch, points_for_image):
            batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
            data.cat(batch_data)
            del batch_data
        self.predictor.reset_image()

        # Remove duplicates within this crop.
        keep_by_nms = batched_nms(
            data["boxes"].float(),
            data["iou_preds"],
            torch.zeros_like(data["boxes"][:, 0]),  # categories
            iou_threshold=self.box_nms_thresh,
        )
        data.filter(keep_by_nms)

        # Return to the original image frame
        data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
        data["points"] = uncrop_points(data["points"], crop_box)
        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])

        return data

    def _process_batch(
        self,
        points: np.ndarray,
        im_size: Tuple[int, ...],
        crop_box: List[int],
        orig_size: Tuple[int, ...],
    ) -> MaskData:
        orig_h, orig_w = orig_size

        # Run model on this batch
        transformed_points = self.predictor.transform.apply_coords(points, im_size)
        in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
        if self.custom_points:
            # Treat the first half of the batch as positive prompts and the
            # second half as negative prompts.
            in_pos_labels = torch.ones(in_points.shape[0] // 2, dtype=torch.int, device=in_points.device)
            in_neg_labels = torch.zeros_like(in_pos_labels)
            in_labels = torch.cat((in_pos_labels, in_neg_labels), dim=0)
        else:
            in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)

        masks, iou_preds, _ = self.predictor.predict_torch(
            in_points[:, None, :],
            in_labels[:, None],
            multimask_output=True,
            return_logits=True,
        )

        # Serialize predictions and store in MaskData
        data = MaskData(
            masks=masks.flatten(0, 1),
            iou_preds=iou_preds.flatten(0, 1),
            points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
        )
        del masks

        # Filter by predicted IoU
        if self.pred_iou_thresh > 0.0:
            keep_mask = data["iou_preds"] > self.pred_iou_thresh
            data.filter(keep_mask)

        # Calculate stability score
        data["stability_score"] = calculate_stability_score(
            data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
        )
        if self.stability_score_thresh > 0.0:
            keep_mask = data["stability_score"] >= self.stability_score_thresh
            data.filter(keep_mask)

        # Threshold masks and calculate boxes
        data["masks"] = data["masks"] > self.predictor.model.mask_threshold
        data["boxes"] = batched_mask_to_box(data["masks"])

        # Filter boxes that touch crop boundaries
        keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
        if not torch.all(keep_mask):
            data.filter(keep_mask)

        # Compress to RLE
        data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
        data["rles"] = mask_to_rle_pytorch(data["masks"])
        del data["masks"]

        return data

    @staticmethod
    def postprocess_small_regions(
        mask_data: MaskData, min_area: int, nms_thresh: float
    ) -> MaskData:
        """
        Removes small disconnected regions and holes in masks, then reruns
        box NMS to remove any new duplicates.

        Edits mask_data in place.

        Requires open-cv as a dependency.
        """
        if len(mask_data["rles"]) == 0:
            return mask_data

        # Filter small disconnected regions and holes
        new_masks = []
        scores = []
        for rle in mask_data["rles"]:
            mask = rle_to_mask(rle)

            mask, changed = remove_small_regions(mask, min_area, mode="holes")
            unchanged = not changed
            mask, changed = remove_small_regions(mask, min_area, mode="islands")
            unchanged = unchanged and not changed

            new_masks.append(torch.as_tensor(mask).unsqueeze(0))
            # Give score=0 to changed masks and score=1 to unchanged masks
            # so NMS will prefer ones that didn't need postprocessing
            scores.append(float(unchanged))

        # Recalculate boxes and remove any new duplicates
        masks = torch.cat(new_masks, dim=0)
        boxes = batched_mask_to_box(masks)
        keep_by_nms = batched_nms(
            boxes.float(),
            torch.as_tensor(scores),
            torch.zeros_like(boxes[:, 0]),  # categories
            iou_threshold=nms_thresh,
        )

        # Only recalculate RLEs for masks that have changed
        for i_mask in keep_by_nms:
            if scores[i_mask] == 0.0:
                mask_torch = masks[i_mask].unsqueeze(0)
                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
                mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
        mask_data.filter(keep_by_nms)

        return mask_data
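A minimal usage sketch of the generator above (illustrative only, not part of the uploaded file; it assumes sam is an already-built Sam model, image is an HWC uint8 numpy array, and that SamAutomaticMaskGenerator is re-exported by this package as in the upstream segment_anything layout):

from models.segment_anything import SamAutomaticMaskGenerator  # import path assumed

# Construct the generator; the defaults are tuned for a ViT-H backbone.
mask_generator = SamAutomaticMaskGenerator(
    sam,                           # a loaded Sam model (assumed to exist)
    points_per_side=32,            # 32*32 point prompts over the image
    pred_iou_thresh=0.88,
    stability_score_thresh=0.95,
    min_mask_region_area=100,      # drop tiny regions and holes (requires opencv)
)

masks = mask_generator.generate(image)  # image: HWC uint8 array (assumed to exist)
for record in masks:
    print(record["area"], record["predicted_iou"], record["bbox"])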
models/segment_anything/build_sam.py
ADDED
@@ -0,0 +1,107 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch

from functools import partial

from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer, SamBatched


def build_sam_vit_h(checkpoint=None):
    return _build_sam(
        encoder_embed_dim=1280,
        encoder_depth=32,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[7, 15, 23, 31],
        checkpoint=checkpoint,
    )


build_sam = build_sam_vit_h


def build_sam_vit_l(checkpoint=None):
    return _build_sam(
        encoder_embed_dim=1024,
        encoder_depth=24,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[5, 11, 17, 23],
        checkpoint=checkpoint,
    )


def build_sam_vit_b(checkpoint=None):
    return _build_sam(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
        checkpoint=checkpoint,
    )


sam_model_registry = {
    "default": build_sam_vit_h,
    "vit_h": build_sam_vit_h,
    "vit_l": build_sam_vit_l,
    "vit_b": build_sam_vit_b,
}


def _build_sam(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
):
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size
    sam = SamBatched(
        image_encoder=ImageEncoderViT(
            depth=encoder_depth,
            embed_dim=encoder_embed_dim,
            img_size=image_size,
            mlp_ratio=4,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            num_heads=encoder_num_heads,
            patch_size=vit_patch_size,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=encoder_global_attn_indexes,
            window_size=14,
            out_chans=prompt_embed_dim,
        ),
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    sam.eval()
    if checkpoint is not None:
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f)
        sam.load_state_dict(state_dict)
    return sam
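A brief sketch of how these builders are typically wired up (illustrative only; the checkpoint path is a placeholder, and the package-level import is assumed to re-export sam_model_registry as in the upstream segment_anything layout):

import torch

from models.segment_anything import SamAutomaticMaskGenerator, sam_model_registry  # import path assumed

# "vit_b" / "vit_l" / "vit_h" select the encoder size; the checkpoint must match the variant.
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")  # placeholder path
sam.to("cuda" if torch.cuda.is_available() else "cpu")

mask_generator = SamAutomaticMaskGenerator(sam)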
models/segment_anything/modeling/__init__.py
ADDED
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .sam import Sam, SamBatched
from .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder
from .transformer import TwoWayTransformer
models/segment_anything/modeling/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (454 Bytes)
models/segment_anything/modeling/__pycache__/common.cpython-312.pyc
ADDED
Binary file (2.91 kB)
models/segment_anything/modeling/__pycache__/image_encoder.cpython-312.pyc
ADDED
Binary file (19.3 kB)
models/segment_anything/modeling/__pycache__/mask_decoder.cpython-312.pyc
ADDED
Binary file (8.53 kB)
models/segment_anything/modeling/__pycache__/prompt_encoder.cpython-312.pyc
ADDED
Binary file (12.3 kB)
models/segment_anything/modeling/__pycache__/sam.cpython-312.pyc
ADDED
Binary file (12.3 kB)